import logging import typing from dataclasses import dataclass from functools import partial from urllib import parse from urllib.parse import urlencode, urlparse import anyio import h11 import httpx import tenacity from fastapi import Response from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.requests import Request from starlette.types import Receive, Send, Scope from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from tqdm.asyncio import tqdm as tqdm_asyncio from mediaflow_proxy.configs import settings from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS from mediaflow_proxy.utils.crypto_utils import EncryptionHandler logger = logging.getLogger(__name__) class DownloadError(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message super().__init__(message) def create_httpx_client(follow_redirects: bool = True, **kwargs) -> httpx.AsyncClient: """Creates an HTTPX client with configured proxy routing""" mounts = settings.transport_config.get_mounts() kwargs.setdefault("timeout", settings.transport_config.timeout) client = httpx.AsyncClient(mounts=mounts, follow_redirects=follow_redirects, **kwargs) return client @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type(DownloadError), ) async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs): """ Fetches a URL with retry logic. Args: client (httpx.AsyncClient): The HTTP client to use for the request. method (str): The HTTP method to use (e.g., GET, POST). url (str): The URL to fetch. headers (dict): The headers to include in the request. follow_redirects (bool, optional): Whether to follow redirects. Defaults to True. **kwargs: Additional arguments to pass to the request. Returns: httpx.Response: The HTTP response. Raises: DownloadError: If the request fails after retries. """ try: response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs) response.raise_for_status() return response except httpx.TimeoutException: logger.warning(f"Timeout while downloading {url}") raise DownloadError(409, f"Timeout while downloading {url}") except httpx.HTTPStatusError as e: logger.error(f"HTTP error {e.response.status_code} while downloading {url}") if e.response.status_code == 404: logger.error(f"Segment Resource not found: {url}") raise e raise DownloadError(e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}") except Exception as e: logger.error(f"Error downloading {url}: {e}") raise class Streamer: # PNG signature and IEND marker for fake PNG header detection (StreamWish/FileMoon) _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" _PNG_IEND_MARKER = b"\x49\x45\x4E\x44\xAE\x42\x60\x82" def __init__(self, client): """ Initializes the Streamer with an HTTP client. Args: client (httpx.AsyncClient): The HTTP client to use for streaming. """ self.client = client self.response = None self.progress_bar = None self.bytes_transferred = 0 self.start_byte = 0 self.end_byte = 0 self.total_size = 0 @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type(DownloadError), ) async def create_streaming_response(self, url: str, headers: dict): """ Creates and sends a streaming request. Args: url (str): The URL to stream from. headers (dict): The headers to include in the request. """ try: request = self.client.build_request("GET", url, headers=headers) self.response = await self.client.send(request, stream=True, follow_redirects=True) self.response.raise_for_status() except httpx.TimeoutException: logger.warning("Timeout while creating streaming response") raise DownloadError(409, "Timeout while creating streaming response") except httpx.HTTPStatusError as e: logger.error(f"HTTP error {e.response.status_code} while creating streaming response") if e.response.status_code == 404: logger.error(f"Segment Resource not found: {url}") raise e raise DownloadError( e.response.status_code, f"HTTP error {e.response.status_code} while creating streaming response" ) except httpx.RequestError as e: logger.error(f"Error creating streaming response: {e}") raise DownloadError(502, f"Error creating streaming response: {e}") except Exception as e: logger.error(f"Error creating streaming response: {e}") raise RuntimeError(f"Error creating streaming response: {e}") @staticmethod def _strip_fake_png_wrapper(chunk: bytes) -> bytes: """ Strip fake PNG wrapper from chunk data. Some streaming services (StreamWish, FileMoon) prepend a fake PNG image to video data to evade detection. This method detects and removes it. Args: chunk: The raw chunk data that may contain a fake PNG header. Returns: The chunk with fake PNG wrapper removed, or original chunk if not present. """ if not chunk.startswith(Streamer._PNG_SIGNATURE): return chunk # Find the IEND marker that signals end of PNG data iend_pos = chunk.find(Streamer._PNG_IEND_MARKER) if iend_pos == -1: # IEND not found in this chunk - return as-is to avoid data corruption logger.debug("PNG signature detected but IEND marker not found in chunk") return chunk # Calculate position after IEND marker content_start = iend_pos + len(Streamer._PNG_IEND_MARKER) # Skip any padding bytes (null or 0xFF) between PNG and actual content while content_start < len(chunk) and chunk[content_start] in (0x00, 0xFF): content_start += 1 stripped_bytes = content_start logger.debug(f"Stripped {stripped_bytes} bytes of fake PNG wrapper from stream") return chunk[content_start:] async def stream_content(self) -> typing.AsyncGenerator[bytes, None]: if not self.response: raise RuntimeError("No response available for streaming") is_first_chunk = True try: self.parse_content_range() if settings.enable_streaming_progress: with tqdm_asyncio( total=self.total_size, initial=self.start_byte, unit="B", unit_scale=True, unit_divisor=1024, desc="Streaming", ncols=100, mininterval=1, ) as self.progress_bar: async for chunk in self.response.aiter_bytes(): if is_first_chunk: is_first_chunk = False chunk = self._strip_fake_png_wrapper(chunk) yield chunk self.bytes_transferred += len(chunk) self.progress_bar.update(len(chunk)) else: async for chunk in self.response.aiter_bytes(): if is_first_chunk: is_first_chunk = False chunk = self._strip_fake_png_wrapper(chunk) yield chunk self.bytes_transferred += len(chunk) except httpx.TimeoutException: logger.warning("Timeout while streaming") raise DownloadError(409, "Timeout while streaming") except httpx.RemoteProtocolError as e: # Special handling for connection closed errors if "peer closed connection without sending complete message body" in str(e): logger.warning(f"Remote server closed connection prematurely: {e}") # If we've received some data, just log the warning and return normally if self.bytes_transferred > 0: logger.info( f"Partial content received ({self.bytes_transferred} bytes). Continuing with available data." ) return else: # If we haven't received any data, raise an error raise DownloadError(502, f"Remote server closed connection without sending any data: {e}") else: logger.error(f"Protocol error while streaming: {e}") raise DownloadError(502, f"Protocol error while streaming: {e}") except GeneratorExit: logger.info("Streaming session stopped by the user") except httpx.ReadError as e: # Handle network read errors gracefully - these occur when upstream connection drops logger.warning(f"ReadError while streaming: {e}") if self.bytes_transferred > 0: logger.info(f"Partial content received ({self.bytes_transferred} bytes) before ReadError. Graceful termination.") return else: raise DownloadError(502, f"ReadError while streaming: {e}") except Exception as e: logger.error(f"Error streaming content: {e}") raise @staticmethod def format_bytes(size) -> str: power = 2**10 n = 0 units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB"} while size > power: size /= power n += 1 return f"{size:.2f} {units[n]}" def parse_content_range(self): content_range = self.response.headers.get("Content-Range", "") if content_range: range_info = content_range.split()[-1] self.start_byte, self.end_byte, self.total_size = map(int, range_info.replace("/", "-").split("-")) else: self.start_byte = 0 self.total_size = int(self.response.headers.get("Content-Length", 0)) self.end_byte = self.total_size - 1 if self.total_size > 0 else 0 async def get_text(self, url: str, headers: dict): """ Sends a GET request to a URL and returns the response text. Args: url (str): The URL to send the GET request to. headers (dict): The headers to include in the request. Returns: str: The response text. """ try: self.response = await fetch_with_retry(self.client, "GET", url, headers) except tenacity.RetryError as e: raise e.last_attempt.result() return self.response.text async def close(self): """ Closes the HTTP client and response. """ if self.response: await self.response.aclose() if self.progress_bar: self.progress_bar.close() await self.client.aclose() async def download_file_with_retry(url: str, headers: dict): """ Downloads a file with retry logic. Args: url (str): The URL of the file to download. headers (dict): The headers to include in the request. Returns: bytes: The downloaded file content. Raises: DownloadError: If the download fails after retries. """ async with create_httpx_client() as client: try: response = await fetch_with_retry(client, "GET", url, headers) return response.content except DownloadError as e: logger.error(f"Failed to download file: {e}") raise e except tenacity.RetryError as e: raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}") async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> httpx.Response: """ Sends an HTTP request with retry logic. Args: method (str): The HTTP method to use (e.g., GET, POST). url (str): The URL to send the request to. headers (dict): The headers to include in the request. **kwargs: Additional arguments to pass to the request. Returns: httpx.Response: The HTTP response. Raises: DownloadError: If the request fails after retries. """ async with create_httpx_client() as client: try: response = await fetch_with_retry(client, method, url, headers, **kwargs) return response except DownloadError as e: logger.error(f"Failed to download file: {e}") raise def encode_mediaflow_proxy_url( mediaflow_proxy_url: str, endpoint: typing.Optional[str] = None, destination_url: typing.Optional[str] = None, query_params: typing.Optional[dict] = None, request_headers: typing.Optional[dict] = None, response_headers: typing.Optional[dict] = None, encryption_handler: EncryptionHandler = None, expiration: int = None, ip: str = None, filename: typing.Optional[str] = None, ) -> str: """ Encodes & Encrypt (Optional) a MediaFlow proxy URL with query parameters and headers. Args: mediaflow_proxy_url (str): The base MediaFlow proxy URL. endpoint (str, optional): The endpoint to append to the base URL. Defaults to None. destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None. query_params (dict, optional): Additional query parameters to include. Defaults to None. request_headers (dict, optional): Headers to include as query parameters. Defaults to None. response_headers (dict, optional): Headers to include as query parameters. Defaults to None. encryption_handler (EncryptionHandler, optional): The encryption handler to use. Defaults to None. expiration (int, optional): The expiration time for the encrypted token. Defaults to None. ip (str, optional): The public IP address to include in the query parameters. Defaults to None. filename (str, optional): Filename to be preserved for media players like Infuse. Defaults to None. Returns: str: The encoded MediaFlow proxy URL. """ # Prepare query parameters query_params = query_params or {} if destination_url is not None: query_params["d"] = destination_url # Add headers if provided if request_headers: query_params.update( {key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()} ) if response_headers: query_params.update( {key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()} ) # Construct the base URL if endpoint is None: base_url = mediaflow_proxy_url else: base_url = parse.urljoin(mediaflow_proxy_url, endpoint) # Ensure base_url doesn't end with a slash for consistent handling if base_url.endswith("/"): base_url = base_url[:-1] # Handle encryption if needed if encryption_handler: encrypted_token = encryption_handler.encrypt_data(query_params, expiration, ip) # Parse the base URL to get its components parsed_url = parse.urlparse(base_url) # Insert the token at the beginning of the path new_path = f"/_token_{encrypted_token}{parsed_url.path}" # Reconstruct the URL with the token at the beginning of the path url_parts = list(parsed_url) url_parts[2] = new_path # Update the path component # Build the URL url = parse.urlunparse(url_parts) # Add filename at the end if provided if filename: url = f"{url}/{parse.quote(filename)}" return url else: # No encryption, use regular query parameters url = base_url if filename: url = f"{url}/{parse.quote(filename)}" if query_params: return f"{url}?{urlencode(query_params)}" return url def encode_stremio_proxy_url( stremio_proxy_url: str, destination_url: str, request_headers: typing.Optional[dict] = None, response_headers: typing.Optional[dict] = None, ) -> str: """ Encodes a Stremio proxy URL with destination URL and headers. Format: http://127.0.0.1:11470/proxy/d=&h=&r=/ Args: stremio_proxy_url (str): The base Stremio proxy URL. destination_url (str): The destination URL to proxy. request_headers (dict, optional): Headers to include as query parameters. Defaults to None. response_headers (dict, optional): Response headers to include as query parameters. Defaults to None. Returns: str: The encoded Stremio proxy URL. """ # Parse the destination URL to separate origin, path, and query parsed_dest = parse.urlparse(destination_url) dest_origin = f"{parsed_dest.scheme}://{parsed_dest.netloc}" dest_path = parsed_dest.path.lstrip("/") dest_query = parsed_dest.query # Prepare query parameters list for proper handling of multiple headers query_parts = [] # Add destination origin (scheme + netloc only) with proper encoding query_parts.append(f"d={parse.quote_plus(dest_origin)}") # Add request headers if request_headers: for key, value in request_headers.items(): header_string = f"{key}:{value}" query_parts.append(f"h={parse.quote_plus(header_string)}") # Add response headers if response_headers: for key, value in response_headers.items(): header_string = f"{key}:{value}" query_parts.append(f"r={parse.quote_plus(header_string)}") # Ensure base_url doesn't end with a slash for consistent handling base_url = stremio_proxy_url.rstrip("/") # Construct the URL path with query string query_string = "&".join(query_parts) # Build the final URL: /proxy/{opts}/{pathname}{search} url_path = f"/proxy/{query_string}" # Append the path from destination URL if dest_path: url_path = f"{url_path}/{dest_path}" # Append the query string from destination URL if dest_query: url_path = f"{url_path}?{dest_query}" return f"{base_url}{url_path}" def get_original_scheme(request: Request) -> str: """ Determines the original scheme (http or https) of the request. Args: request (Request): The incoming HTTP request. Returns: str: The original scheme ('http' or 'https') """ # Check the X-Forwarded-Proto header first forwarded_proto = request.headers.get("X-Forwarded-Proto") if forwarded_proto: return forwarded_proto # Check if the request is secure if request.url.scheme == "https" or request.headers.get("X-Forwarded-Ssl") == "on": return "https" # Check for other common headers that might indicate HTTPS if ( request.headers.get("X-Forwarded-Ssl") == "on" or request.headers.get("X-Forwarded-Protocol") == "https" or request.headers.get("X-Url-Scheme") == "https" ): return "https" # Default to http if no indicators of https are found return "http" @dataclass class ProxyRequestHeaders: request: dict response: dict def get_proxy_headers(request: Request) -> ProxyRequestHeaders: """ Extracts proxy headers from the request query parameters. Args: request (Request): The incoming HTTP request. Returns: ProxyRequest: A named tuple containing the request headers and response headers. """ request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS} request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")}) request_headers.setdefault("user-agent", settings.user_agent) # Handle common misspelling of referer if "referrer" in request_headers: if "referer" not in request_headers: request_headers["referer"] = request_headers.pop("referrer") dest = request.query_params.get("d", "") host = urlparse(dest).netloc.lower() if "vidoza" in host or "videzz" in host: # Remove ALL empty headers for h in list(request_headers.keys()): v = request_headers[h] if v is None or v.strip() == "": request_headers.pop(h, None) response_headers = {k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("r_")} return ProxyRequestHeaders(request_headers, response_headers) class EnhancedStreamingResponse(Response): body_iterator: typing.AsyncIterable[typing.Any] def __init__( self, content: typing.Union[typing.AsyncIterable[typing.Any], typing.Iterable[typing.Any]], status_code: int = 200, headers: typing.Optional[typing.Mapping[str, str]] = None, media_type: typing.Optional[str] = None, background: typing.Optional[BackgroundTask] = None, ) -> None: if isinstance(content, typing.AsyncIterable): self.body_iterator = content else: self.body_iterator = iterate_in_threadpool(content) self.status_code = status_code self.media_type = self.media_type if media_type is None else media_type self.background = background self.init_headers(headers) self.actual_content_length = 0 @staticmethod async def listen_for_disconnect(receive: Receive) -> None: try: while True: message = await receive() if message["type"] == "http.disconnect": logger.debug("Client disconnected") break except Exception as e: logger.error(f"Error in listen_for_disconnect: {str(e)}") async def stream_response(self, send: Send) -> None: # Track if response headers have been sent to prevent duplicate headers response_started = False # Track if response finalization (more_body: False) has been sent to prevent ASGI protocol violation finalization_sent = False try: # Initialize headers headers = list(self.raw_headers) # Start the response await send( { "type": "http.response.start", "status": self.status_code, "headers": headers, } ) response_started = True # Track if we've sent any data data_sent = False try: async for chunk in self.body_iterator: if not isinstance(chunk, (bytes, memoryview)): chunk = chunk.encode(self.charset) try: await send({"type": "http.response.body", "body": chunk, "more_body": True}) data_sent = True self.actual_content_length += len(chunk) except (ConnectionResetError, anyio.BrokenResourceError): logger.info("Client disconnected during streaming") return # Successfully streamed all content await send({"type": "http.response.body", "body": b"", "more_body": False}) finalization_sent = True except (httpx.RemoteProtocolError, httpx.ReadError, h11._util.LocalProtocolError) as e: # Handle connection closed / read errors gracefully if data_sent: # We've sent some data to the client, so try to complete the response logger.warning(f"Upstream connection error after partial streaming: {e}") try: await send({"type": "http.response.body", "body": b"", "more_body": False}) finalization_sent = True logger.info( f"Response finalized after partial content ({self.actual_content_length} bytes transferred)" ) except Exception as close_err: logger.warning(f"Could not finalize response after upstream error: {close_err}") else: # No data was sent, re-raise the error logger.error(f"Upstream error before any data was streamed: {e}") raise except Exception as e: logger.exception(f"Error in stream_response: {str(e)}") if not isinstance(e, (ConnectionResetError, anyio.BrokenResourceError)) and not response_started: # Only attempt to send error response if headers haven't been sent yet try: await send( { "type": "http.response.start", "status": 502, "headers": [(b"content-type", b"text/plain")], } ) error_message = f"Streaming error: {str(e)}".encode("utf-8") await send({"type": "http.response.body", "body": error_message, "more_body": False}) finalization_sent = True except Exception: # If we can't send an error response, just log it pass elif response_started and not finalization_sent: # Response already started but not finalized - gracefully close the stream try: await send({"type": "http.response.body", "body": b"", "more_body": False}) finalization_sent = True except Exception: pass async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: stream_func = partial(self.stream_response, send) listen_func = partial(self.listen_for_disconnect, receive) async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: try: await func() except Exception as e: # Note: stream_response and listen_for_disconnect handle their own exceptions # internally. This is a safety net for any unexpected exceptions that might # escape due to future code changes. if not isinstance(e, anyio.get_cancelled_exc_class()): logger.exception(f"Unexpected error in streaming task: {type(e).__name__}: {e}") # Re-raise unexpected errors to surface bugs rather than silently swallowing them raise finally: # Cancel task group when either task completes or fails: # - stream_func finished (success or failure) -> stop listening for disconnect # - listen_func finished (client disconnected) -> stop streaming task_group.cancel_scope.cancel() # Start the streaming response in a separate task task_group.start_soon(wrap, stream_func) # Listen for disconnect events await wrap(listen_func) if self.background is not None: await self.background()