diff --git a/mediaflow_proxy/__init__.py b/mediaflow_proxy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mediaflow_proxy/configs.py b/mediaflow_proxy/configs.py new file mode 100644 index 0000000..7d24dba --- /dev/null +++ b/mediaflow_proxy/configs.py @@ -0,0 +1,68 @@ +from typing import Dict, Optional, Union + +import httpx +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings + + +class RouteConfig(BaseModel): + """Configuration for a specific route""" + + proxy: bool = True + proxy_url: Optional[str] = None + verify_ssl: bool = True + + +class TransportConfig(BaseSettings): + """Main proxy configuration""" + + proxy_url: Optional[str] = Field( + None, description="Primary proxy URL. Example: socks5://user:pass@proxy:1080 or http://proxy:8080" + ) + all_proxy: bool = Field(False, description="Enable proxy for all routes by default") + transport_routes: Dict[str, RouteConfig] = Field( + default_factory=dict, description="Pattern-based route configuration" + ) + + def get_mounts( + self, async_http: bool = True + ) -> Dict[str, Optional[Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport]]]: + """ + Get a dictionary of httpx mount points to transport instances. + """ + mounts = {} + transport_cls = httpx.AsyncHTTPTransport if async_http else httpx.HTTPTransport + + # Configure specific routes + for pattern, route in self.transport_routes.items(): + mounts[pattern] = transport_cls( + verify=route.verify_ssl, proxy=route.proxy_url or self.proxy_url if route.proxy else None + ) + + # Set default proxy for all routes if enabled + if self.all_proxy: + mounts["all://"] = transport_cls(proxy=self.proxy_url) + + return mounts + + class Config: + env_file = ".env" + extra = "ignore" + + +class Settings(BaseSettings): + api_password: str | None = None # The password for protecting the API endpoints. + log_level: str = "INFO" # The logging level to use. 
+ transport_config: TransportConfig = Field(default_factory=TransportConfig) # Configuration for httpx transport. + enable_streaming_progress: bool = False # Whether to enable streaming progress tracking. + + user_agent: str = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" # The user agent to use for HTTP requests. + ) + + class Config: + env_file = ".env" + extra = "ignore" + + +settings = Settings() diff --git a/mediaflow_proxy/const.py b/mediaflow_proxy/const.py new file mode 100644 index 0000000..49d2464 --- /dev/null +++ b/mediaflow_proxy/const.py @@ -0,0 +1,17 @@ +SUPPORTED_RESPONSE_HEADERS = [ + "accept-ranges", + "content-type", + "content-length", + "content-range", + "connection", + "transfer-encoding", + "last-modified", + "etag", + "cache-control", + "expires", +] + +SUPPORTED_REQUEST_HEADERS = [ + "range", + "if-range", +] diff --git a/mediaflow_proxy/drm/__init__.py b/mediaflow_proxy/drm/__init__.py new file mode 100644 index 0000000..b92db72 --- /dev/null +++ b/mediaflow_proxy/drm/__init__.py @@ -0,0 +1,11 @@ +import os +import tempfile + + +async def create_temp_file(suffix: str, content: bytes = None, prefix: str = None) -> tempfile.NamedTemporaryFile: + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, prefix=prefix) + temp_file.delete_file = lambda: os.unlink(temp_file.name) + if content: + temp_file.write(content) + temp_file.close() + return temp_file diff --git a/mediaflow_proxy/drm/decrypter.py b/mediaflow_proxy/drm/decrypter.py new file mode 100644 index 0000000..a6514a9 --- /dev/null +++ b/mediaflow_proxy/drm/decrypter.py @@ -0,0 +1,778 @@ +import argparse +import struct +import sys +from typing import Optional, Union + +from Crypto.Cipher import AES +from collections import namedtuple +import array + +CENCSampleAuxiliaryDataFormat = namedtuple("CENCSampleAuxiliaryDataFormat", ["is_encrypted", "iv", "sub_samples"]) + + +class MP4Atom: + """ + 
Represents an MP4 atom, which is a basic unit of data in an MP4 file. + Each atom contains a header (size and type) and data. + """ + + __slots__ = ("atom_type", "size", "data") + + def __init__(self, atom_type: bytes, size: int, data: Union[memoryview, bytearray]): + """ + Initializes an MP4Atom instance. + + Args: + atom_type (bytes): The type of the atom. + size (int): The size of the atom. + data (Union[memoryview, bytearray]): The data contained in the atom. + """ + self.atom_type = atom_type + self.size = size + self.data = data + + def __repr__(self): + return f"" + + def pack(self): + """ + Packs the atom into binary data. + + Returns: + bytes: Packed binary data with size, type, and data. + """ + return struct.pack(">I", self.size) + self.atom_type + self.data + + +class MP4Parser: + """ + Parses MP4 data to extract atoms and their structure. + """ + + def __init__(self, data: memoryview): + """ + Initializes an MP4Parser instance. + + Args: + data (memoryview): The binary data of the MP4 file. + """ + self.data = data + self.position = 0 + + def read_atom(self) -> Optional[MP4Atom]: + """ + Reads the next atom from the data. + + Returns: + Optional[MP4Atom]: MP4Atom object or None if no more atoms are available. + """ + pos = self.position + if pos + 8 > len(self.data): + return None + + size, atom_type = struct.unpack_from(">I4s", self.data, pos) + pos += 8 + + if size == 1: + if pos + 8 > len(self.data): + return None + size = struct.unpack_from(">Q", self.data, pos)[0] + pos += 8 + + if size < 8 or pos + size - 8 > len(self.data): + return None + + atom_data = self.data[pos : pos + size - 8] + self.position = pos + size - 8 + return MP4Atom(atom_type, size, atom_data) + + def list_atoms(self) -> list[MP4Atom]: + """ + Lists all atoms in the data. + + Returns: + list[MP4Atom]: List of MP4Atom objects. 
+ """ + atoms = [] + original_position = self.position + self.position = 0 + while self.position + 8 <= len(self.data): + atom = self.read_atom() + if not atom: + break + atoms.append(atom) + self.position = original_position + return atoms + + def _read_atom_at(self, pos: int, end: int) -> Optional[MP4Atom]: + if pos + 8 > end: + return None + + size, atom_type = struct.unpack_from(">I4s", self.data, pos) + pos += 8 + + if size == 1: + if pos + 8 > end: + return None + size = struct.unpack_from(">Q", self.data, pos)[0] + pos += 8 + + if size < 8 or pos + size - 8 > end: + return None + + atom_data = self.data[pos : pos + size - 8] + return MP4Atom(atom_type, size, atom_data) + + def print_atoms_structure(self, indent: int = 0): + """ + Prints the structure of all atoms in the data. + + Args: + indent (int): The indentation level for printing. + """ + pos = 0 + end = len(self.data) + while pos + 8 <= end: + atom = self._read_atom_at(pos, end) + if not atom: + break + self.print_single_atom_structure(atom, pos, indent) + pos += atom.size + + def print_single_atom_structure(self, atom: MP4Atom, parent_position: int, indent: int): + """ + Prints the structure of a single atom. + + Args: + atom (MP4Atom): The atom to print. + parent_position (int): The position of the parent atom. + indent (int): The indentation level for printing. + """ + try: + atom_type = atom.atom_type.decode("utf-8") + except UnicodeDecodeError: + atom_type = repr(atom.atom_type) + print(" " * indent + f"Type: {atom_type}, Size: {atom.size}") + + child_pos = 0 + child_end = len(atom.data) + while child_pos + 8 <= child_end: + child_atom = self._read_atom_at(parent_position + 8 + child_pos, parent_position + 8 + child_end) + if not child_atom: + break + self.print_single_atom_structure(child_atom, parent_position, indent + 2) + child_pos += child_atom.size + + +class MP4Decrypter: + """ + Class to handle the decryption of CENC encrypted MP4 segments. 
+ + Attributes: + key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys. + current_key (Optional[bytes]): Current decryption key. + trun_sample_sizes (array.array): Array of sample sizes from the 'trun' box. + current_sample_info (list): List of sample information from the 'senc' box. + encryption_overhead (int): Total size of encryption-related boxes. + """ + + def __init__(self, key_map: dict[bytes, bytes]): + """ + Initializes the MP4Decrypter with a key map. + + Args: + key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys. + """ + self.key_map = key_map + self.current_key = None + self.trun_sample_sizes = array.array("I") + self.current_sample_info = [] + self.encryption_overhead = 0 + + def decrypt_segment(self, combined_segment: bytes) -> bytes: + """ + Decrypts a combined MP4 segment. + + Args: + combined_segment (bytes): Combined initialization and media segment. + + Returns: + bytes: Decrypted segment content. + """ + data = memoryview(combined_segment) + parser = MP4Parser(data) + atoms = parser.list_atoms() + + atom_process_order = [b"moov", b"moof", b"sidx", b"mdat"] + + processed_atoms = {} + for atom_type in atom_process_order: + if atom := next((a for a in atoms if a.atom_type == atom_type), None): + processed_atoms[atom_type] = self._process_atom(atom_type, atom) + + result = bytearray() + for atom in atoms: + if atom.atom_type in processed_atoms: + processed_atom = processed_atoms[atom.atom_type] + result.extend(processed_atom.pack()) + else: + result.extend(atom.pack()) + + return bytes(result) + + def _process_atom(self, atom_type: bytes, atom: MP4Atom) -> MP4Atom: + """ + Processes an MP4 atom based on its type. + + Args: + atom_type (bytes): Type of the atom. + atom (MP4Atom): The atom to process. + + Returns: + MP4Atom: Processed atom. 
+ """ + if atom_type == b"moov": + return self._process_moov(atom) + elif atom_type == b"moof": + return self._process_moof(atom) + elif atom_type == b"sidx": + return self._process_sidx(atom) + elif atom_type == b"mdat": + return self._decrypt_mdat(atom) + else: + return atom + + def _process_moov(self, moov: MP4Atom) -> MP4Atom: + """ + Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation. + This includes information about tracks, media data, and other movie-level metadata. + + Args: + moov (MP4Atom): The 'moov' atom to process. + + Returns: + MP4Atom: Processed 'moov' atom with updated track information. + """ + parser = MP4Parser(moov.data) + new_moov_data = bytearray() + + for atom in iter(parser.read_atom, None): + if atom.atom_type == b"trak": + new_trak = self._process_trak(atom) + new_moov_data.extend(new_trak.pack()) + elif atom.atom_type != b"pssh": + # Skip PSSH boxes as they are not needed in the decrypted output + new_moov_data.extend(atom.pack()) + + return MP4Atom(b"moov", len(new_moov_data) + 8, new_moov_data) + + def _process_moof(self, moof: MP4Atom) -> MP4Atom: + """ + Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation. + This includes information about tracks, media data, and other movie-level metadata. + + Args: + moov (MP4Atom): The 'moov' atom to process. + + Returns: + MP4Atom: Processed 'moov' atom with updated track information. + """ + parser = MP4Parser(moof.data) + new_moof_data = bytearray() + + for atom in iter(parser.read_atom, None): + if atom.atom_type == b"traf": + new_traf = self._process_traf(atom) + new_moof_data.extend(new_traf.pack()) + else: + new_moof_data.extend(atom.pack()) + + return MP4Atom(b"moof", len(new_moof_data) + 8, new_moof_data) + + def _process_traf(self, traf: MP4Atom) -> MP4Atom: + """ + Processes the 'traf' (Track Fragment) atom, which contains information about a track fragment. 
+ This includes sample information, sample encryption data, and other track-level metadata. + + Args: + traf (MP4Atom): The 'traf' atom to process. + + Returns: + MP4Atom: Processed 'traf' atom with updated sample information. + """ + parser = MP4Parser(traf.data) + new_traf_data = bytearray() + tfhd = None + sample_count = 0 + sample_info = [] + + atoms = parser.list_atoms() + + # calculate encryption_overhead earlier to avoid dependency on trun + self.encryption_overhead = sum(a.size for a in atoms if a.atom_type in {b"senc", b"saiz", b"saio"}) + + for atom in atoms: + if atom.atom_type == b"tfhd": + tfhd = atom + new_traf_data.extend(atom.pack()) + elif atom.atom_type == b"trun": + sample_count = self._process_trun(atom) + new_trun = self._modify_trun(atom) + new_traf_data.extend(new_trun.pack()) + elif atom.atom_type == b"senc": + # Parse senc but don't include it in the new decrypted traf data and similarly don't include saiz and saio + sample_info = self._parse_senc(atom, sample_count) + elif atom.atom_type not in {b"saiz", b"saio"}: + new_traf_data.extend(atom.pack()) + + if tfhd: + tfhd_track_id = struct.unpack_from(">I", tfhd.data, 4)[0] + self.current_key = self._get_key_for_track(tfhd_track_id) + self.current_sample_info = sample_info + + return MP4Atom(b"traf", len(new_traf_data) + 8, new_traf_data) + + def _decrypt_mdat(self, mdat: MP4Atom) -> MP4Atom: + """ + Decrypts the 'mdat' (Media Data) atom, which contains the actual media data (audio, video, etc.). + The decryption is performed using the current decryption key and sample information. + + Args: + mdat (MP4Atom): The 'mdat' atom to decrypt. + + Returns: + MP4Atom: Decrypted 'mdat' atom with decrypted media data. 
+ """ + if not self.current_key or not self.current_sample_info: + return mdat # Return original mdat if we don't have decryption info + + decrypted_samples = bytearray() + mdat_data = mdat.data + position = 0 + + for i, info in enumerate(self.current_sample_info): + if position >= len(mdat_data): + break # No more data to process + + sample_size = self.trun_sample_sizes[i] if i < len(self.trun_sample_sizes) else len(mdat_data) - position + sample = mdat_data[position : position + sample_size] + position += sample_size + decrypted_sample = self._process_sample(sample, info, self.current_key) + decrypted_samples.extend(decrypted_sample) + + return MP4Atom(b"mdat", len(decrypted_samples) + 8, decrypted_samples) + + def _parse_senc(self, senc: MP4Atom, sample_count: int) -> list[CENCSampleAuxiliaryDataFormat]: + """ + Parses the 'senc' (Sample Encryption) atom, which contains encryption information for samples. + This includes initialization vectors (IVs) and sub-sample encryption data. + + Args: + senc (MP4Atom): The 'senc' atom to parse. + sample_count (int): The number of samples. + + Returns: + list[CENCSampleAuxiliaryDataFormat]: List of sample auxiliary data formats with encryption information. 
+ """ + data = memoryview(senc.data) + version_flags = struct.unpack_from(">I", data, 0)[0] + version, flags = version_flags >> 24, version_flags & 0xFFFFFF + position = 4 + + if version == 0: + sample_count = struct.unpack_from(">I", data, position)[0] + position += 4 + + sample_info = [] + for _ in range(sample_count): + if position + 8 > len(data): + break + + iv = data[position : position + 8].tobytes() + position += 8 + + sub_samples = [] + if flags & 0x000002 and position + 2 <= len(data): # Check if subsample information is present + subsample_count = struct.unpack_from(">H", data, position)[0] + position += 2 + + for _ in range(subsample_count): + if position + 6 <= len(data): + clear_bytes, encrypted_bytes = struct.unpack_from(">HI", data, position) + position += 6 + sub_samples.append((clear_bytes, encrypted_bytes)) + else: + break + + sample_info.append(CENCSampleAuxiliaryDataFormat(True, iv, sub_samples)) + + return sample_info + + def _get_key_for_track(self, track_id: int) -> bytes: + """ + Retrieves the decryption key for a given track ID from the key map. + + Args: + track_id (int): The track ID. + + Returns: + bytes: The decryption key for the specified track ID. + """ + if len(self.key_map) == 1: + return next(iter(self.key_map.values())) + key = self.key_map.get(track_id.pack(4, "big")) + if not key: + raise ValueError(f"No key found for track ID {track_id}") + return key + + @staticmethod + def _process_sample( + sample: memoryview, sample_info: CENCSampleAuxiliaryDataFormat, key: bytes + ) -> Union[memoryview, bytearray, bytes]: + """ + Processes and decrypts a sample using the provided sample information and decryption key. + This includes handling sub-sample encryption if present. + + Args: + sample (memoryview): The sample data. + sample_info (CENCSampleAuxiliaryDataFormat): The sample auxiliary data format with encryption information. + key (bytes): The decryption key. 
+ + Returns: + Union[memoryview, bytearray, bytes]: The decrypted sample. + """ + if not sample_info.is_encrypted: + return sample + + # pad IV to 16 bytes + iv = sample_info.iv + b"\x00" * (16 - len(sample_info.iv)) + cipher = AES.new(key, AES.MODE_CTR, initial_value=iv, nonce=b"") + + if not sample_info.sub_samples: + # If there are no sub_samples, decrypt the entire sample + return cipher.decrypt(sample) + + result = bytearray() + offset = 0 + for clear_bytes, encrypted_bytes in sample_info.sub_samples: + result.extend(sample[offset : offset + clear_bytes]) + offset += clear_bytes + result.extend(cipher.decrypt(sample[offset : offset + encrypted_bytes])) + offset += encrypted_bytes + + # If there's any remaining data, treat it as encrypted + if offset < len(sample): + result.extend(cipher.decrypt(sample[offset:])) + + return result + + def _process_trun(self, trun: MP4Atom) -> int: + """ + Processes the 'trun' (Track Fragment Run) atom, which contains information about the samples in a track fragment. + This includes sample sizes, durations, flags, and composition time offsets. + + Args: + trun (MP4Atom): The 'trun' atom to process. + + Returns: + int: The number of samples in the 'trun' atom. 
+ """ + trun_flags, sample_count = struct.unpack_from(">II", trun.data, 0) + data_offset = 8 + + if trun_flags & 0x000001: + data_offset += 4 + if trun_flags & 0x000004: + data_offset += 4 + + self.trun_sample_sizes = array.array("I") + + for _ in range(sample_count): + if trun_flags & 0x000100: # sample-duration-present flag + data_offset += 4 + if trun_flags & 0x000200: # sample-size-present flag + sample_size = struct.unpack_from(">I", trun.data, data_offset)[0] + self.trun_sample_sizes.append(sample_size) + data_offset += 4 + else: + self.trun_sample_sizes.append(0) # Using 0 instead of None for uniformity in the array + if trun_flags & 0x000400: # sample-flags-present flag + data_offset += 4 + if trun_flags & 0x000800: # sample-composition-time-offsets-present flag + data_offset += 4 + + return sample_count + + def _modify_trun(self, trun: MP4Atom) -> MP4Atom: + """ + Modifies the 'trun' (Track Fragment Run) atom to update the data offset. + This is necessary to account for the encryption overhead. + + Args: + trun (MP4Atom): The 'trun' atom to modify. + + Returns: + MP4Atom: Modified 'trun' atom with updated data offset. + """ + trun_data = bytearray(trun.data) + current_flags = struct.unpack_from(">I", trun_data, 0)[0] & 0xFFFFFF + + # If the data-offset-present flag is set, update the data offset to account for encryption overhead + if current_flags & 0x000001: + current_data_offset = struct.unpack_from(">i", trun_data, 8)[0] + struct.pack_into(">i", trun_data, 8, current_data_offset - self.encryption_overhead) + + return MP4Atom(b"trun", len(trun_data) + 8, trun_data) + + def _process_sidx(self, sidx: MP4Atom) -> MP4Atom: + """ + Processes the 'sidx' (Segment Index) atom, which contains indexing information for media segments. + This includes references to media segments and their durations. + + Args: + sidx (MP4Atom): The 'sidx' atom to process. + + Returns: + MP4Atom: Processed 'sidx' atom with updated segment references. 
+ """ + sidx_data = bytearray(sidx.data) + + current_size = struct.unpack_from(">I", sidx_data, 32)[0] + reference_type = current_size >> 31 + current_referenced_size = current_size & 0x7FFFFFFF + + # Remove encryption overhead from referenced size + new_referenced_size = current_referenced_size - self.encryption_overhead + new_size = (reference_type << 31) | new_referenced_size + struct.pack_into(">I", sidx_data, 32, new_size) + + return MP4Atom(b"sidx", len(sidx_data) + 8, sidx_data) + + def _process_trak(self, trak: MP4Atom) -> MP4Atom: + """ + Processes the 'trak' (Track) atom, which contains information about a single track in the movie. + This includes track header, media information, and other track-level metadata. + + Args: + trak (MP4Atom): The 'trak' atom to process. + + Returns: + MP4Atom: Processed 'trak' atom with updated track information. + """ + parser = MP4Parser(trak.data) + new_trak_data = bytearray() + + for atom in iter(parser.read_atom, None): + if atom.atom_type == b"mdia": + new_mdia = self._process_mdia(atom) + new_trak_data.extend(new_mdia.pack()) + else: + new_trak_data.extend(atom.pack()) + + return MP4Atom(b"trak", len(new_trak_data) + 8, new_trak_data) + + def _process_mdia(self, mdia: MP4Atom) -> MP4Atom: + """ + Processes the 'mdia' (Media) atom, which contains media information for a track. + This includes media header, handler reference, and media information container. + + Args: + mdia (MP4Atom): The 'mdia' atom to process. + + Returns: + MP4Atom: Processed 'mdia' atom with updated media information. 
+ """ + parser = MP4Parser(mdia.data) + new_mdia_data = bytearray() + + for atom in iter(parser.read_atom, None): + if atom.atom_type == b"minf": + new_minf = self._process_minf(atom) + new_mdia_data.extend(new_minf.pack()) + else: + new_mdia_data.extend(atom.pack()) + + return MP4Atom(b"mdia", len(new_mdia_data) + 8, new_mdia_data) + + def _process_minf(self, minf: MP4Atom) -> MP4Atom: + """ + Processes the 'minf' (Media Information) atom, which contains information about the media data in a track. + This includes data information, sample table, and other media-level metadata. + + Args: + minf (MP4Atom): The 'minf' atom to process. + + Returns: + MP4Atom: Processed 'minf' atom with updated media information. + """ + parser = MP4Parser(minf.data) + new_minf_data = bytearray() + + for atom in iter(parser.read_atom, None): + if atom.atom_type == b"stbl": + new_stbl = self._process_stbl(atom) + new_minf_data.extend(new_stbl.pack()) + else: + new_minf_data.extend(atom.pack()) + + return MP4Atom(b"minf", len(new_minf_data) + 8, new_minf_data) + + def _process_stbl(self, stbl: MP4Atom) -> MP4Atom: + """ + Processes the 'stbl' (Sample Table) atom, which contains information about the samples in a track. + This includes sample descriptions, sample sizes, sample times, and other sample-level metadata. + + Args: + stbl (MP4Atom): The 'stbl' atom to process. + + Returns: + MP4Atom: Processed 'stbl' atom with updated sample information. + """ + parser = MP4Parser(stbl.data) + new_stbl_data = bytearray() + + for atom in iter(parser.read_atom, None): + if atom.atom_type == b"stsd": + new_stsd = self._process_stsd(atom) + new_stbl_data.extend(new_stsd.pack()) + else: + new_stbl_data.extend(atom.pack()) + + return MP4Atom(b"stbl", len(new_stbl_data) + 8, new_stbl_data) + + def _process_stsd(self, stsd: MP4Atom) -> MP4Atom: + """ + Processes the 'stsd' (Sample Description) atom, which contains descriptions of the sample entries in a track. 
+ This includes codec information, sample entry details, and other sample description metadata. + + Args: + stsd (MP4Atom): The 'stsd' atom to process. + + Returns: + MP4Atom: Processed 'stsd' atom with updated sample descriptions. + """ + parser = MP4Parser(stsd.data) + entry_count = struct.unpack_from(">I", parser.data, 4)[0] + new_stsd_data = bytearray(stsd.data[:8]) + + parser.position = 8 # Move past version_flags and entry_count + + for _ in range(entry_count): + sample_entry = parser.read_atom() + if not sample_entry: + break + + processed_entry = self._process_sample_entry(sample_entry) + new_stsd_data.extend(processed_entry.pack()) + + return MP4Atom(b"stsd", len(new_stsd_data) + 8, new_stsd_data) + + def _process_sample_entry(self, entry: MP4Atom) -> MP4Atom: + """ + Processes a sample entry atom, which contains information about a specific type of sample. + This includes codec-specific information and other sample entry details. + + Args: + entry (MP4Atom): The sample entry atom to process. + + Returns: + MP4Atom: Processed sample entry atom with updated information. + """ + # Determine the size of fixed fields based on sample entry type + if entry.atom_type in {b"mp4a", b"enca"}: + fixed_size = 28 # 8 bytes for size, type and reserved, 20 bytes for fixed fields in Audio Sample Entry. + elif entry.atom_type in {b"mp4v", b"encv", b"avc1", b"hev1", b"hvc1"}: + fixed_size = 78 # 8 bytes for size, type and reserved, 70 bytes for fixed fields in Video Sample Entry. + else: + fixed_size = 16 # 8 bytes for size, type and reserved, 8 bytes for fixed fields in other Sample Entries. 
+ + new_entry_data = bytearray(entry.data[:fixed_size]) + parser = MP4Parser(entry.data[fixed_size:]) + codec_format = None + + for atom in iter(parser.read_atom, None): + if atom.atom_type in {b"sinf", b"schi", b"tenc", b"schm"}: + if atom.atom_type == b"sinf": + codec_format = self._extract_codec_format(atom) + continue # Skip encryption-related atoms + new_entry_data.extend(atom.pack()) + + # Replace the atom type with the extracted codec format + new_type = codec_format if codec_format else entry.atom_type + return MP4Atom(new_type, len(new_entry_data) + 8, new_entry_data) + + def _extract_codec_format(self, sinf: MP4Atom) -> Optional[bytes]: + """ + Extracts the codec format from the 'sinf' (Protection Scheme Information) atom. + This includes information about the original format of the protected content. + + Args: + sinf (MP4Atom): The 'sinf' atom to extract from. + + Returns: + Optional[bytes]: The codec format or None if not found. + """ + parser = MP4Parser(sinf.data) + for atom in iter(parser.read_atom, None): + if atom.atom_type == b"frma": + return atom.data + return None + + +def decrypt_segment(init_segment: bytes, segment_content: bytes, key_id: str, key: str) -> bytes: + """ + Decrypts a CENC encrypted MP4 segment. + + Args: + init_segment (bytes): Initialization segment data. + segment_content (bytes): Encrypted segment content. + key_id (str): Key ID in hexadecimal format. + key (str): Key in hexadecimal format. + """ + key_map = {bytes.fromhex(key_id): bytes.fromhex(key)} + decrypter = MP4Decrypter(key_map) + decrypted_content = decrypter.decrypt_segment(init_segment + segment_content) + return decrypted_content + + +def cli(): + """ + Command line interface for decrypting a CENC encrypted MP4 segment. 
+ """ + init_segment = b"" + + if args.init and args.segment: + with open(args.init, "rb") as f: + init_segment = f.read() + with open(args.segment, "rb") as f: + segment_content = f.read() + elif args.combined_segment: + with open(args.combined_segment, "rb") as f: + segment_content = f.read() + else: + print("Usage: python mp4decrypt.py --help") + sys.exit(1) + + try: + decrypted_segment = decrypt_segment(init_segment, segment_content, args.key_id, args.key) + print(f"Decrypted content size is {len(decrypted_segment)} bytes") + with open(args.output, "wb") as f: + f.write(decrypted_segment) + print(f"Decrypted segment written to {args.output}") + except Exception as e: + print(f"Error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser(description="Decrypts a MP4 init and media segment using CENC encryption.") + arg_parser.add_argument("--init", help="Path to the init segment file", required=False) + arg_parser.add_argument("--segment", help="Path to the media segment file", required=False) + arg_parser.add_argument( + "--combined_segment", help="Path to the combined init and media segment file", required=False + ) + arg_parser.add_argument("--key_id", help="Key ID in hexadecimal format", required=True) + arg_parser.add_argument("--key", help="Key in hexadecimal format", required=True) + arg_parser.add_argument("--output", help="Path to the output file", required=True) + args = arg_parser.parse_args() + cli() diff --git a/mediaflow_proxy/extractors/__init__.py b/mediaflow_proxy/extractors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mediaflow_proxy/extractors/base.py b/mediaflow_proxy/extractors/base.py new file mode 100644 index 0000000..bf8a15b --- /dev/null +++ b/mediaflow_proxy/extractors/base.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional, Any + +import httpx + +from mediaflow_proxy.configs import settings +from mediaflow_proxy.utils.http_utils import 
create_httpx_client + + +class ExtractorError(Exception): + """Base exception for all extractors.""" + + pass + + +class BaseExtractor(ABC): + """Base class for all URL extractors.""" + + def __init__(self, request_headers: dict): + self.base_headers = { + "user-agent": settings.user_agent, + } + self.mediaflow_endpoint = "proxy_stream_endpoint" + self.base_headers.update(request_headers) + + async def _make_request( + self, url: str, method: str = "GET", headers: Optional[Dict] = None, **kwargs + ) -> httpx.Response: + """Make HTTP request with error handling.""" + try: + async with create_httpx_client() as client: + request_headers = self.base_headers + request_headers.update(headers or {}) + response = await client.request( + method, + url, + headers=request_headers, + **kwargs, + ) + response.raise_for_status() + return response + except httpx.HTTPError as e: + raise ExtractorError(f"HTTP request failed: {str(e)}") + except Exception as e: + raise ExtractorError(f"Request failed: {str(e)}") + + @abstractmethod + async def extract(self, url: str, **kwargs) -> Dict[str, Any]: + """Extract final URL and required headers.""" + pass diff --git a/mediaflow_proxy/extractors/doodstream.py b/mediaflow_proxy/extractors/doodstream.py new file mode 100644 index 0000000..a8f851d --- /dev/null +++ b/mediaflow_proxy/extractors/doodstream.py @@ -0,0 +1,39 @@ +import re +import time +from typing import Dict + +from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError + + +class DoodStreamExtractor(BaseExtractor): + """DoodStream URL extractor.""" + + def __init__(self, request_headers: dict): + super().__init__(request_headers) + self.base_url = "https://d000d.com" + + async def extract(self, url: str, **kwargs) -> Dict[str, str]: + """Extract DoodStream URL.""" + response = await self._make_request(url) + + # Extract URL pattern + pattern = r"(\/pass_md5\/.*?)'.*(\?token=.*?expiry=)" + match = re.search(pattern, response.text, re.DOTALL) + if not match: + raise 
ExtractorError("Failed to extract URL pattern") + + # Build final URL + pass_url = f"{self.base_url}{match[1]}" + referer = f"{self.base_url}/" + headers = {"range": "bytes=0-", "referer": referer} + + response = await self._make_request(pass_url, headers=headers) + timestamp = str(int(time.time())) + final_url = f"{response.text}123456789{match[2]}{timestamp}" + + self.base_headers["referer"] = referer + return { + "destination_url": final_url, + "request_headers": self.base_headers, + "mediaflow_endpoint": self.mediaflow_endpoint, + } diff --git a/mediaflow_proxy/extractors/factory.py b/mediaflow_proxy/extractors/factory.py new file mode 100644 index 0000000..c61d408 --- /dev/null +++ b/mediaflow_proxy/extractors/factory.py @@ -0,0 +1,32 @@ +from typing import Dict, Type + +from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError +from mediaflow_proxy.extractors.doodstream import DoodStreamExtractor +from mediaflow_proxy.extractors.livetv import LiveTVExtractor +from mediaflow_proxy.extractors.mixdrop import MixdropExtractor +from mediaflow_proxy.extractors.uqload import UqloadExtractor +from mediaflow_proxy.extractors.streamtape import StreamtapeExtractor +from mediaflow_proxy.extractors.supervideo import SupervideoExtractor + + + + +class ExtractorFactory: + """Factory for creating URL extractors.""" + + _extractors: Dict[str, Type[BaseExtractor]] = { + "Doodstream": DoodStreamExtractor, + "Uqload": UqloadExtractor, + "Mixdrop": MixdropExtractor, + "Streamtape": StreamtapeExtractor, + "Supervideo": SupervideoExtractor, + "LiveTV": LiveTVExtractor, + } + + @classmethod + def get_extractor(cls, host: str, request_headers: dict) -> BaseExtractor: + """Get appropriate extractor instance for the given host.""" + extractor_class = cls._extractors.get(host) + if not extractor_class: + raise ExtractorError(f"Unsupported host: {host}") + return extractor_class(request_headers) diff --git a/mediaflow_proxy/extractors/livetv.py 
b/mediaflow_proxy/extractors/livetv.py new file mode 100644 index 0000000..fbe0c93 --- /dev/null +++ b/mediaflow_proxy/extractors/livetv.py @@ -0,0 +1,251 @@
import re
from typing import Dict, Tuple, Optional
from urllib.parse import urljoin, urlparse, unquote

from httpx import Response

from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError


class LiveTVExtractor(BaseExtractor):
    """LiveTV URL extractor for both M3U8 and MPD streams."""

    def __init__(self, request_headers: dict):
        super().__init__(request_headers)
        # Default to HLS proxy endpoint, will be updated based on stream type
        self.mediaflow_endpoint = "hls_manifest_proxy"

        # Patterns for stream URL extraction
        # Matches a JS player config: source URL followed by an HLS/DASH mimeType.
        self.fallback_pattern = re.compile(
            r"source: [\'\"](.*?)[\'\"]\s*,\s*[\s\S]*?mimeType: [\'\"](application/x-mpegURL|application/vnd\.apple\.mpegURL|application/dash\+xml)[\'\"]",
            re.IGNORECASE,
        )
        # Last-resort pattern: any absolute .m3u8 URL anywhere in the page.
        self.any_m3u8_pattern = re.compile(
            r'["\']?(https?://.*?\.m3u8(?:\?[^"\']*)?)["\']?',
            re.IGNORECASE,
        )

    async def extract(self, url: str, stream_title: str = None, **kwargs) -> Dict[str, str]:
        """Extract LiveTV URL and required headers.

        Args:
            url: The channel page URL
            stream_title: Optional stream title to filter specific stream

        Returns:
            Dict[str, str]: Extraction result with "destination_url",
            "request_headers", "mediaflow_endpoint" and, for DRM-protected
            MPD streams, a "query_params" entry carrying the clear key.
        """
        try:
            # Get the channel page
            response = await self._make_request(url)
            self.base_headers["referer"] = urljoin(url, "/")

            # Extract player API details
            player_api_base, method = await self._extract_player_api_base(response.text)
            if not player_api_base:
                raise ExtractorError("Failed to extract player API URL")

            # Get player options
            options_data = await self._get_player_options(response.text)
            if not options_data:
                raise ExtractorError("No player options found")

            # Process player options to find matching stream
            for option in options_data:
                current_title = option.get("title")
                if stream_title and current_title != stream_title:
                    continue

                # Get stream URL based on player option
                stream_data = await self._process_player_option(
                    player_api_base, method, option.get("post"), option.get("nume"), option.get("type")
                )

                if stream_data:
                    stream_url = stream_data.get("url")
                    if not stream_url:
                        continue

                    response = {
                        "destination_url": stream_url,
                        "request_headers": self.base_headers,
                        "mediaflow_endpoint": self.mediaflow_endpoint,
                    }

                    # Set endpoint based on stream type
                    if stream_data.get("type") == "mpd":
                        if stream_data.get("drm_key_id") and stream_data.get("drm_key"):
                            response.update(
                                {
                                    "query_params": {
                                        "key_id": stream_data["drm_key_id"],
                                        "key": stream_data["drm_key"],
                                    },
                                    "mediaflow_endpoint": "mpd_manifest_proxy",
                                }
                            )

                    return response

            raise ExtractorError("No valid stream found")

        except Exception as e:
            raise ExtractorError(f"Extraction failed: {str(e)}")

    async def _extract_player_api_base(self, html_content: str) -> Tuple[Optional[str], Optional[str]]:
        """Extract player API base URL and method."""
        # Looks for the dooplay player config JSON embedded in the page.
        admin_ajax_pattern = r'"player_api"\s*:\s*"([^"]+)".*?"play_method"\s*:\s*"([^"]+)"'
        match = re.search(admin_ajax_pattern, html_content)
        if not match:
            return None, None
        url = match.group(1).replace("\\/", "/")
        method = match.group(2)
        if method == "wp_json":
            return url, method
        # Non-wp_json players go through the classic WordPress AJAX endpoint.
        url = urljoin(url, "/wp-admin/admin-ajax.php")
        return url, method

    async def _get_player_options(self, html_content: str) -> list:
        """Extract player options from HTML content."""
        # NOTE(review): this pattern begins with "]*class=" — it looks
        # truncated (presumably an opening tag such as "<li[^" was lost in
        # transfer). TODO confirm against the original source before relying
        # on it.
        pattern = r']*class=["\']dooplay_player_option["\'][^>]*data-type=["\']([^"\']*)["\'][^>]*data-post=["\']([^"\']*)["\'][^>]*data-nume=["\']([^"\']*)["\'][^>]*>.*?([^<]*)'
        matches = re.finditer(pattern, html_content, re.DOTALL)
        return [
            {"type": match.group(1), "post": match.group(2), "nume": match.group(3), "title": match.group(4).strip()}
            for match in matches
        ]

    async def _process_player_option(self, api_base: str, method: str, post: str, nume: str, type_: str) -> Dict:
        """Process player option to get stream URL."""
        if method == "wp_json":
            # wp_json API takes the parameters as path segments.
            api_url = f"{api_base}{post}/{type_}/{nume}"
            response = await self._make_request(api_url)
        else:
            # admin-ajax API takes the parameters as form data.
            form_data = {"action": "doo_player_ajax", "post": post, "nume": nume, "type": type_}
            response = await self._make_request(api_base, method="POST", data=form_data)

        # Get iframe URL from API response
        try:
            data = response.json()
            iframe_url = urljoin(api_base, data.get("embed_url", "").replace("\\/", "/"))

            # Get stream URL from iframe
            iframe_response = await self._make_request(iframe_url)
            stream_data = await self._extract_stream_url(iframe_response, iframe_url)
            return stream_data

        except Exception as e:
            raise ExtractorError(f"Failed to process player option: {str(e)}")

    async def _extract_stream_url(self, iframe_response: Response, iframe_url: str) -> Dict:
        """
        Extract final stream URL from iframe content.
        """
        try:
            # Parse URL components
            query_params = dict(param.split("=") for param in urlparse(iframe_url).query.split("&") if "=" in param)
            parsed_url = urlparse(iframe_url)
            query_params = dict(param.split("=") for param in parsed_url.query.split("&") if "=" in param)

            # Check if content is already a direct M3U8 stream
            content_types = ["application/x-mpegurl", "application/vnd.apple.mpegurl"]

            if any(ext in iframe_response.headers["content-type"] for ext in content_types):
                return {"url": iframe_url, "type": "m3u8"}

            stream_data = {}

            # Check for source parameter in URL
            if "source" in query_params:
                stream_data = {
                    "url": urljoin(iframe_url, unquote(query_params["source"])),
                    "type": "m3u8",
                }

            # Check for MPD stream with DRM
            # "zy" packs "<mpd_url>``<key_id>:<key>" with a double-backtick delimiter.
            elif "zy" in query_params and ".mpd``" in query_params["zy"]:
                data = query_params["zy"].split("``")
                url = data[0]
                key_id, key = data[1].split(":")
                stream_data = {"url": url, "type": "mpd", "drm_key_id": key_id, "drm_key": key}

            # Check for tamilultra specific format
            elif "tamilultra" in iframe_url:
                stream_data = {"url": urljoin(iframe_url, parsed_url.query), "type": "m3u8"}

            # Try pattern matching for stream URLs
            else:
                # NOTE(review): the default here is a list [""], not a string,
                # so the rf-pattern below would embed "['']" when "id" is
                # absent — presumably harmless (no match) but verify intent.
                channel_id = query_params.get("id", [""])
                stream_url = None

                html_content = iframe_response.text

                if channel_id:
                    # Try channel ID specific pattern
                    pattern = rf'{re.escape(channel_id)}["\']:\s*{{\s*["\']?url["\']?\s*:\s*["\']([^"\']+)["\']'
                    match = re.search(pattern, html_content)
                    if match:
                        stream_url = match.group(1)

                # Try fallback patterns if channel ID pattern fails
                if not stream_url:
                    for pattern in [self.fallback_pattern, self.any_m3u8_pattern]:
                        match = pattern.search(html_content)
                        if match:
                            stream_url = match.group(1)
                            break

                if stream_url:
                    stream_data = {"url": stream_url, "type": "m3u8"}  # Default to m3u8, will be updated

                    # Check for MPD stream and extract DRM keys
                    if stream_url.endswith(".mpd"):
                        stream_data["type"] = "mpd"
                        drm_data = await self._extract_drm_keys(html_content, channel_id)
                        if drm_data:
                            stream_data.update(drm_data)

            # If no stream data found, raise error
            if not stream_data:
                raise ExtractorError("No valid stream URL found")

            # Update stream type based on URL if not already set
            if stream_data.get("type") == "m3u8":
                if stream_data["url"].endswith(".mpd"):
                    stream_data["type"] = "mpd"
                elif not any(ext in stream_data["url"] for ext in [".m3u8", ".m3u"]):
                    stream_data["type"] = "m3u8"  # Default to m3u8 if no extension found

            return stream_data

        except Exception as e:
            raise ExtractorError(f"Failed to extract stream URL: {str(e)}")

    async def _extract_drm_keys(self, html_content: str, channel_id: str) -> Dict:
        """
        Extract DRM keys for MPD streams.
        """
        try:
            # Pattern for channel entry
            channel_pattern = rf'"{re.escape(channel_id)}":\s*{{[^}}]+}}'
            channel_match = re.search(channel_pattern, html_content)

            if channel_match:
                channel_data = channel_match.group(0)

                # Try clearkeys pattern first
                clearkey_pattern = r'["\']?clearkeys["\']?\s*:\s*{\s*["\'](.+?)["\']:\s*["\'](.+?)["\']'
                clearkey_match = re.search(clearkey_pattern, channel_data)

                # Try k1/k2 pattern if clearkeys not found
                if not clearkey_match:
                    k1k2_pattern = r'["\']?k1["\']?\s*:\s*["\'](.+?)["\'],\s*["\']?k2["\']?\s*:\s*["\'](.+?)["\']'
                    k1k2_match = re.search(k1k2_pattern, channel_data)

                    if k1k2_match:
                        return {"drm_key_id": k1k2_match.group(1), "drm_key": k1k2_match.group(2)}
                else:
                    return {"drm_key_id": clearkey_match.group(1), "drm_key": clearkey_match.group(2)}

            return {}

        except Exception:
            # Best-effort: DRM keys are optional, so any parse failure
            # falls back to "no keys found".
            return {}
diff --git a/mediaflow_proxy/extractors/mixdrop.py b/mediaflow_proxy/extractors/mixdrop.py new file mode 100644 index 0000000..26d91a7 --- /dev/null +++ b/mediaflow_proxy/extractors/mixdrop.py @@ -0,0 +1,36 @@
import re
import string
from typing import Dict, Any

from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError


class MixdropExtractor(BaseExtractor):
    """Mixdrop URL extractor."""

    async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
        """Extract Mixdrop URL."""
        response = await self._make_request(url, headers={"accept-language": "en-US,en;q=0.5"})

        # Extract and decode URL
        # Targets the p.a.c.k.e.d JS obfuscation: group 1 is the template,
        # group 2 the "|"-separated substitution dictionary.
        match = re.search(r"}\('(.+)',.+,'(.+)'\.split", response.text)
        if not match:
            raise ExtractorError("Failed to extract URL components")

        s1, s2 = match.group(1, 2)
        schema = s1.split(";")[2][5:-1]
        terms = s2.split("|")

        # Build character mapping
        # Empty dictionary entries map a character to itself.
        charset = string.digits + string.ascii_letters
        char_map = {charset[i]: terms[i] or charset[i] for i in range(len(terms))}

        # Construct final URL
        final_url = "https:" + "".join(char_map.get(c, c) for c in schema)

        self.base_headers["referer"] = url
        return {
            "destination_url": final_url,
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
diff --git a/mediaflow_proxy/extractors/streamtape.py b/mediaflow_proxy/extractors/streamtape.py new file mode 100644 index 0000000..6358a7d --- /dev/null +++ b/mediaflow_proxy/extractors/streamtape.py @@ -0,0 +1,32 @@
import re
from typing import Dict, Any

from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError


class StreamtapeExtractor(BaseExtractor):
    """Streamtape URL extractor."""

    async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
        """Extract Streamtape URL.

        NOTE(review): `final_url` may be None when no adjacent duplicate
        "id=..." token is found — presumably the caller treats that as a
        failed extraction; confirm.
        """
        response = await self._make_request(url)

        # Extract and decode URL
        matches = re.findall(r"id=.*?(?=')", response.text)
        if not matches:
            raise ExtractorError("Failed to extract URL components")
        # The real token appears twice in a row in the page; pick the first
        # adjacent duplicate.
        final_url = next(
            (
                f"https://streamtape.com/get_video?{matches[i + 1]}"
                for i in range(len(matches) - 1)
                if matches[i] == matches[i + 1]
            ),
            None,
        )

        self.base_headers["referer"] = url
        return {
            "destination_url": final_url,
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
diff --git a/mediaflow_proxy/extractors/supervideo.py
b/mediaflow_proxy/extractors/supervideo.py new file mode 100644 index 0000000..ba0fc6b --- /dev/null +++ b/mediaflow_proxy/extractors/supervideo.py @@ -0,0 +1,27 @@
import re
from typing import Dict, Any

from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError


class SupervideoExtractor(BaseExtractor):
    """Supervideo URL extractor."""

    async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
        """Extract Supervideo URL."""
        response = await self._make_request(url)
        # Extract and decode URL
        # Parses the p.a.c.k.e.d JS dictionary: tokens between "urlset" and
        # "hls" are the reversed path segments of the HLS master playlist.
        s2 = re.search(r"\}\('(.+)',.+,'(.+)'\.split", response.text).group(2)
        terms = s2.split("|")
        hfs = next(terms[i] for i in range(terms.index("file"), len(terms)) if "hfs" in terms[i])
        result = terms[terms.index("urlset") + 1 : terms.index("hls")]

        base_url = f"https://{hfs}.serversicuro.cc/hls/"
        final_url = base_url + ",".join(reversed(result)) + (".urlset/master.m3u8" if result else "")

        self.base_headers["referer"] = url
        return {
            "destination_url": final_url,
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
diff --git a/mediaflow_proxy/extractors/uqload.py b/mediaflow_proxy/extractors/uqload.py new file mode 100644 index 0000000..19cdbd2 --- /dev/null +++ b/mediaflow_proxy/extractors/uqload.py @@ -0,0 +1,24 @@
import re
from typing import Dict
from urllib.parse import urljoin

from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError


class UqloadExtractor(BaseExtractor):
    """Uqload URL extractor."""

    async def extract(self, url: str, **kwargs) -> Dict[str, str]:
        """Extract Uqload URL."""
        response = await self._make_request(url)

        video_url_match = re.search(r'sources: \["(.*?)"]', response.text)
        if not video_url_match:
            raise ExtractorError("Failed to extract video URL")

        self.base_headers["referer"] = urljoin(url, "/")
        return {
            "destination_url": video_url_match.group(1),
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
diff --git a/mediaflow_proxy/handlers.py b/mediaflow_proxy/handlers.py new file mode 100644 index 0000000..957eb51 --- /dev/null +++ b/mediaflow_proxy/handlers.py @@ -0,0 +1,358 @@
import base64
import logging
from urllib.parse import urlparse

import httpx
from fastapi import Request, Response, HTTPException
from starlette.background import BackgroundTask

from .const import SUPPORTED_RESPONSE_HEADERS
from .mpd_processor import process_manifest, process_playlist, process_segment
from .schemas import HLSManifestParams, ProxyStreamParams, MPDManifestParams, MPDPlaylistParams, MPDSegmentParams
from .utils.cache_utils import get_cached_mpd, get_cached_init_segment
from .utils.http_utils import (
    Streamer,
    DownloadError,
    download_file_with_retry,
    request_with_retry,
    EnhancedStreamingResponse,
    ProxyRequestHeaders,
    create_httpx_client,
)
from .utils.m3u8_processor import M3U8Processor
from .utils.mpd_utils import pad_base64

logger = logging.getLogger(__name__)


async def setup_client_and_streamer() -> tuple[httpx.AsyncClient, Streamer]:
    """
    Set up an HTTP client and a streamer.

    Returns:
        tuple: An httpx.AsyncClient instance and a Streamer instance.
    """
    client = create_httpx_client()
    return client, Streamer(client)


def handle_exceptions(exception: Exception) -> Response:
    """
    Handle exceptions and return appropriate HTTP responses.

    Args:
        exception (Exception): The exception that was raised.

    Returns:
        Response: An HTTP response corresponding to the exception type.
    """
    if isinstance(exception, httpx.HTTPStatusError):
        logger.error(f"Upstream service error while handling request: {exception}")
        return Response(status_code=exception.response.status_code, content=f"Upstream service error: {exception}")
    elif isinstance(exception, DownloadError):
        logger.error(f"Error downloading content: {exception}")
        return Response(status_code=exception.status_code, content=str(exception))
    else:
        # NOTE(review): unexpected errors are reported as 502 (Bad Gateway)
        # while the message says "Internal server error" (usually 500) —
        # presumably intentional for a proxy, but confirm.
        logger.exception(f"Internal server error while handling request: {exception}")
        return Response(status_code=502, content=f"Internal server error: {exception}")


async def handle_hls_stream_proxy(
    request: Request, hls_params: HLSManifestParams, proxy_headers: ProxyRequestHeaders
) -> Response:
    """
    Handle HLS stream proxy requests.

    This function processes HLS manifest files and streams content based on the request parameters.

    Args:
        request (Request): The incoming FastAPI request object.
        hls_params (HLSManifestParams): Parameters for the HLS manifest.
        proxy_headers (ProxyRequestHeaders): Headers to be used in the proxy request.

    Returns:
        Union[Response, EnhancedStreamingResponse]: Either a processed m3u8 playlist or a streaming response.
    """
    # NOTE(review): `client` is bound but never used directly here; the
    # streamer owns it and streamer.close() is relied on for cleanup.
    client, streamer = await setup_client_and_streamer()

    try:
        if urlparse(hls_params.destination).path.endswith((".m3u", ".m3u8")):
            return await fetch_and_process_m3u8(
                streamer, hls_params.destination, proxy_headers, request, hls_params.key_url
            )

        # Create initial streaming response to check content type
        await streamer.create_streaming_response(hls_params.destination, proxy_headers.request)

        if "mpegurl" in streamer.response.headers.get("content-type", "").lower():
            return await fetch_and_process_m3u8(
                streamer, hls_params.destination, proxy_headers, request, hls_params.key_url
            )

        # Handle range requests
        content_range = proxy_headers.request.get("range", "bytes=0-")
        if "NaN" in content_range:
            # Handle invalid range requests "bytes=NaN-NaN"
            raise HTTPException(status_code=416, detail="Invalid Range Header")
        proxy_headers.request.update({"range": content_range})

        # Create new streaming response with updated headers
        await streamer.create_streaming_response(hls_params.destination, proxy_headers.request)
        response_headers = prepare_response_headers(streamer.response.headers, proxy_headers.response)

        return EnhancedStreamingResponse(
            streamer.stream_content(),
            status_code=streamer.response.status_code,
            headers=response_headers,
            background=BackgroundTask(streamer.close),
        )
    except Exception as e:
        await streamer.close()
        return handle_exceptions(e)


async def handle_stream_request(
    method: str,
    video_url: str,
    proxy_headers: ProxyRequestHeaders,
) -> Response:
    """
    Handle general stream requests.

    This function processes both HEAD and GET requests for video streams.

    Args:
        method (str): The HTTP method (e.g., 'GET' or 'HEAD').
        video_url (str): The URL of the video to stream.
        proxy_headers (ProxyRequestHeaders): Headers to be used in the proxy request.

    Returns:
        Union[Response, EnhancedStreamingResponse]: Either a HEAD response with headers or a streaming response.
    """
    client, streamer = await setup_client_and_streamer()

    try:
        await streamer.create_streaming_response(video_url, proxy_headers.request)
        response_headers = prepare_response_headers(streamer.response.headers, proxy_headers.response)

        if method == "HEAD":
            # For HEAD requests, just return the headers without streaming content
            await streamer.close()
            return Response(headers=response_headers, status_code=streamer.response.status_code)
        else:
            # For GET requests, return the streaming response
            return EnhancedStreamingResponse(
                streamer.stream_content(),
                headers=response_headers,
                status_code=streamer.response.status_code,
                background=BackgroundTask(streamer.close),
            )
    except Exception as e:
        await streamer.close()
        return handle_exceptions(e)


def prepare_response_headers(original_headers, proxy_response_headers) -> dict:
    """
    Prepare response headers for the proxy response.

    This function filters the original headers, ensures proper transfer encoding,
    and merges them with the proxy response headers.

    Args:
        original_headers (httpx.Headers): The original headers from the upstream response.
        proxy_response_headers (dict): Additional headers to be included in the proxy response.

    Returns:
        dict: The prepared headers for the proxy response.
    """
    # Only pass through the allow-listed headers; caller-supplied response
    # headers take precedence over upstream ones.
    response_headers = {k: v for k, v in original_headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}
    response_headers.update(proxy_response_headers)
    return response_headers


async def proxy_stream(method: str, stream_params: ProxyStreamParams, proxy_headers: ProxyRequestHeaders):
    """
    Proxies the stream request to the given video URL.

    Args:
        method (str): The HTTP method (e.g., GET, HEAD).
        stream_params (ProxyStreamParams): The parameters for the stream request.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HTTP response with the streamed content.
    """
    return await handle_stream_request(method, stream_params.destination, proxy_headers)


async def fetch_and_process_m3u8(
    streamer: Streamer, url: str, proxy_headers: ProxyRequestHeaders, request: Request, key_url: str = None
):
    """
    Fetches and processes the m3u8 playlist, converting it to an HLS playlist.

    Args:
        streamer (Streamer): The HTTP client to use for streaming.
        url (str): The URL of the m3u8 playlist.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.
        request (Request): The incoming HTTP request.
        key_url (str, optional): The HLS Key URL to replace the original key URL. Defaults to None.

    Returns:
        Response: The HTTP response with the processed m3u8 playlist.
    """
    try:
        content = await streamer.get_text(url, proxy_headers.request)
        processor = M3U8Processor(request, key_url)
        processed_content = await processor.process_m3u8(content, str(streamer.response.url))
        response_headers = {"Content-Disposition": "inline", "Accept-Ranges": "none"}
        response_headers.update(proxy_headers.response)
        return Response(
            content=processed_content,
            media_type="application/vnd.apple.mpegurl",
            headers=response_headers,
        )
    except Exception as e:
        return handle_exceptions(e)
    finally:
        # Always release the streamer's client, success or failure.
        await streamer.close()


async def handle_drm_key_data(key_id, key, drm_info):
    """
    Handles the DRM key data, retrieving the key ID and key from the DRM info if not provided.

    Args:
        key_id (str): The DRM key ID.
        key (str): The DRM key.
        drm_info (dict): The DRM information from the MPD manifest.

    Returns:
        tuple: The key ID and key.
    """
    if drm_info and not drm_info.get("isDrmProtected"):
        return None, None

    if not key_id or not key:
        if "keyId" in drm_info and "key" in drm_info:
            key_id = drm_info["keyId"]
            key = drm_info["key"]
        elif "laUrl" in drm_info and "keyId" in drm_info:
            raise HTTPException(status_code=400, detail="LA URL is not supported yet")
        else:
            raise HTTPException(
                status_code=400, detail="Unable to determine key_id and key, and they were not provided"
            )

    return key_id, key


async def get_manifest(
    request: Request,
    manifest_params: MPDManifestParams,
    proxy_headers: ProxyRequestHeaders,
):
    """
    Retrieves and processes the MPD manifest, converting it to an HLS manifest.

    Args:
        request (Request): The incoming HTTP request.
        manifest_params (MPDManifestParams): The parameters for the manifest request.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HTTP response with the HLS manifest.
    """
    try:
        mpd_dict = await get_cached_mpd(
            manifest_params.destination,
            headers=proxy_headers.request,
            parse_drm=not manifest_params.key_id and not manifest_params.key,
        )
    except DownloadError as e:
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
    drm_info = mpd_dict.get("drmInfo", {})

    if drm_info and not drm_info.get("isDrmProtected"):
        # For non-DRM protected MPD, we still create an HLS manifest
        return await process_manifest(request, mpd_dict, proxy_headers, None, None)

    key_id, key = await handle_drm_key_data(manifest_params.key_id, manifest_params.key, drm_info)

    # check if the provided key_id and key are valid
    # 32 hex chars is the expected raw form; anything else is assumed to be
    # url-safe base64 and is decoded to hex.
    if key_id and len(key_id) != 32:
        key_id = base64.urlsafe_b64decode(pad_base64(key_id)).hex()
    if key and len(key) != 32:
        key = base64.urlsafe_b64decode(pad_base64(key)).hex()

    return await process_manifest(request, mpd_dict, proxy_headers, key_id, key)


async def get_playlist(
    request: Request,
    playlist_params: MPDPlaylistParams,
    proxy_headers: ProxyRequestHeaders,
):
    """
    Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.

    Args:
        request (Request): The incoming HTTP request.
        playlist_params (MPDPlaylistParams): The parameters for the playlist request.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HTTP response with the HLS playlist.
    """
    try:
        mpd_dict = await get_cached_mpd(
            playlist_params.destination,
            headers=proxy_headers.request,
            parse_drm=not playlist_params.key_id and not playlist_params.key,
            parse_segment_profile_id=playlist_params.profile_id,
        )
    except DownloadError as e:
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
    return await process_playlist(request, mpd_dict, playlist_params.profile_id, proxy_headers)


async def get_segment(
    segment_params: MPDSegmentParams,
    proxy_headers: ProxyRequestHeaders,
):
    """
    Retrieves and processes a media segment, decrypting it if necessary.

    Args:
        segment_params (MPDSegmentParams): The parameters for the segment request.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HTTP response with the processed segment.
    """
    try:
        init_content = await get_cached_init_segment(segment_params.init_url, proxy_headers.request)
        segment_content = await download_file_with_retry(segment_params.segment_url, proxy_headers.request)
    except Exception as e:
        return handle_exceptions(e)

    return await process_segment(
        init_content,
        segment_content,
        segment_params.mime_type,
        proxy_headers,
        segment_params.key_id,
        segment_params.key,
    )


async def get_public_ip():
    """
    Retrieves the public IP address of the MediaFlow proxy.

    Returns:
        Response: The HTTP response with the public IP address.
    """
    ip_address_data = await request_with_retry("GET", "https://api.ipify.org?format=json", {})
    return ip_address_data.json()
diff --git a/mediaflow_proxy/main.py b/mediaflow_proxy/main.py new file mode 100644 index 0000000..9b45b81 --- /dev/null +++ b/mediaflow_proxy/main.py @@ -0,0 +1,99 @@
import logging
from importlib import resources

from fastapi import FastAPI, Depends, Security, HTTPException
from fastapi.security import APIKeyQuery, APIKeyHeader
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from starlette.staticfiles import StaticFiles

from mediaflow_proxy.configs import settings
from mediaflow_proxy.routes import proxy_router, extractor_router, speedtest_router
from mediaflow_proxy.schemas import GenerateUrlRequest
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler, EncryptionMiddleware
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url

logging.basicConfig(level=settings.log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
app = FastAPI()
# The API password may arrive either as a query parameter or as a header.
api_password_query = APIKeyQuery(name="api_password", auto_error=False)
api_password_header = APIKeyHeader(name="api_password", auto_error=False)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(EncryptionMiddleware)


async def verify_api_key(api_key: str = Security(api_password_query), api_key_alt: str = Security(api_password_header)):
    """
    Verifies the API key for the request.

    Args:
        api_key (str): The API key to validate.
        api_key_alt (str): The alternative API key to validate.

    Raises:
        HTTPException: If the API key is invalid.
    """
    # No password configured means the API is open.
    if not settings.api_password:
        return

    if api_key == settings.api_password or api_key_alt == settings.api_password:
        return

    raise HTTPException(status_code=403, detail="Could not validate credentials")


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


@app.get("/favicon.ico")
async def get_favicon():
    return RedirectResponse(url="/logo.png")


@app.get("/speedtest")
async def show_speedtest_page():
    return RedirectResponse(url="/speedtest.html")


@app.post("/generate_encrypted_or_encoded_url")
async def generate_encrypted_or_encoded_url(request: GenerateUrlRequest):
    if "api_password" not in request.query_params:
        request.query_params["api_password"] = request.api_password

    encoded_url = encode_mediaflow_proxy_url(
        request.mediaflow_proxy_url,
        request.endpoint,
        request.destination_url,
        request.query_params,
        request.request_headers,
        request.response_headers,
        EncryptionHandler(request.api_password) if request.api_password else None,
        request.expiration,
        str(request.ip) if request.ip else None,
    )
    return {"encoded_url": encoded_url}


app.include_router(proxy_router, prefix="/proxy", tags=["proxy"], dependencies=[Depends(verify_api_key)])
app.include_router(extractor_router, prefix="/extractor", tags=["extractors"], dependencies=[Depends(verify_api_key)])
app.include_router(speedtest_router, prefix="/speedtest", tags=["speedtest"], dependencies=[Depends(verify_api_key)])

# Static assets are served from the package itself; mounted last so API
# routes take precedence.
static_path = resources.files("mediaflow_proxy").joinpath("static")
app.mount("/", StaticFiles(directory=str(static_path), html=True), name="static")


def run():
    import uvicorn

    # NOTE(review): uvicorn ignores `workers` when given an app object
    # instead of an import string — TODO confirm intended worker count.
    uvicorn.run(app, host="0.0.0.0", port=8888, log_level="info", workers=3)


if __name__ == "__main__":
    run()
diff --git a/mediaflow_proxy/mpd_processor.py b/mediaflow_proxy/mpd_processor.py new file mode 100644 index 0000000..76fba69 --- /dev/null +++ b/mediaflow_proxy/mpd_processor.py @@ -0,0 +1,214 @@
import logging
import math
import time

from fastapi import Request, Response, HTTPException

from mediaflow_proxy.drm.decrypter import decrypt_segment
from mediaflow_proxy.utils.crypto_utils import encryption_handler
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme, ProxyRequestHeaders

logger = logging.getLogger(__name__)


async def process_manifest(
    request: Request, mpd_dict: dict, proxy_headers: ProxyRequestHeaders, key_id: str = None, key: str = None
) -> Response:
    """
    Processes the MPD manifest and converts it to an HLS manifest.

    Args:
        request (Request): The incoming HTTP request.
        mpd_dict (dict): The MPD manifest data.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        Response: The HLS manifest as an HTTP response.
    """
    hls_content = build_hls(mpd_dict, request, key_id, key)
    return Response(content=hls_content, media_type="application/vnd.apple.mpegurl", headers=proxy_headers.response)


async def process_playlist(
    request: Request, mpd_dict: dict, profile_id: str, proxy_headers: ProxyRequestHeaders
) -> Response:
    """
    Processes the MPD manifest and converts it to an HLS playlist for a specific profile.

    Args:
        request (Request): The incoming HTTP request.
        mpd_dict (dict): The MPD manifest data.
        profile_id (str): The profile ID to generate the playlist for.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HLS playlist as an HTTP response.

    Raises:
        HTTPException: If the profile is not found in the MPD manifest.
    """
    matching_profiles = [p for p in mpd_dict["profiles"] if p["id"] == profile_id]
    if not matching_profiles:
        raise HTTPException(status_code=404, detail="Profile not found")

    hls_content = build_hls_playlist(mpd_dict, matching_profiles, request)
    return Response(content=hls_content, media_type="application/vnd.apple.mpegurl", headers=proxy_headers.response)


async def process_segment(
    init_content: bytes,
    segment_content: bytes,
    mimetype: str,
    proxy_headers: ProxyRequestHeaders,
    key_id: str = None,
    key: str = None,
) -> Response:
    """
    Processes and decrypts a media segment.

    Args:
        init_content (bytes): The initialization segment content.
        segment_content (bytes): The media segment content.
        mimetype (str): The MIME type of the segment.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        Response: The decrypted segment as an HTTP response.
    """
    if key_id and key:
        # For DRM protected content
        now = time.time()
        decrypted_content = decrypt_segment(init_content, segment_content, key_id, key)
        logger.info(f"Decryption of {mimetype} segment took {time.time() - now:.4f} seconds")
    else:
        # For non-DRM protected content, we just concatenate init and segment content
        decrypted_content = init_content + segment_content

    return Response(content=decrypted_content, media_type=mimetype, headers=proxy_headers.response)


def build_hls(mpd_dict: dict, request: Request, key_id: str = None, key: str = None) -> str:
    """
    Builds an HLS manifest from the MPD manifest.

    Args:
        mpd_dict (dict): The MPD manifest data.
        request (Request): The incoming HTTP request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        str: The HLS manifest as a string.
    """
    hls = ["#EXTM3U", "#EXT-X-VERSION:6"]
    query_params = dict(request.query_params)
    # "has_encrypted" toggles URL encryption for generated child URLs and is
    # removed so it is not forwarded verbatim.
    has_encrypted = query_params.pop("has_encrypted", False)

    video_profiles = {}
    audio_profiles = {}

    # Get the base URL for the playlist_endpoint endpoint
    proxy_url = request.url_for("playlist_endpoint")
    proxy_url = str(proxy_url.replace(scheme=get_original_scheme(request)))

    for profile in mpd_dict["profiles"]:
        query_params.update({"profile_id": profile["id"], "key_id": key_id or "", "key": key or ""})
        playlist_url = encode_mediaflow_proxy_url(
            proxy_url,
            query_params=query_params,
            encryption_handler=encryption_handler if has_encrypted else None,
        )

        if "video" in profile["mimeType"]:
            video_profiles[profile["id"]] = (profile, playlist_url)
        elif "audio" in profile["mimeType"]:
            audio_profiles[profile["id"]] = (profile, playlist_url)

    # Add audio streams
    for i, (profile, playlist_url) in enumerate(audio_profiles.values()):
        is_default = "YES" if i == 0 else "NO"  # Set the first audio track as default
        hls.append(
            f'#EXT-X-MEDIA:TYPE=AUDIO,GROUP-ID="audio",NAME="{profile["id"]}",DEFAULT={is_default},AUTOSELECT={is_default},LANGUAGE="{profile.get("lang", "und")}",URI="{playlist_url}"'
        )

    # Add video streams
    for profile, playlist_url in video_profiles.values():
        hls.append(
            f'#EXT-X-STREAM-INF:BANDWIDTH={profile["bandwidth"]},RESOLUTION={profile["width"]}x{profile["height"]},CODECS="{profile["codecs"]}",FRAME-RATE={profile["frameRate"]},AUDIO="audio"'
        )
        hls.append(playlist_url)

    return "\n".join(hls)


def build_hls_playlist(mpd_dict: dict, profiles: list[dict], request: Request) -> str:
    """
    Builds an HLS playlist from the MPD manifest for specific profiles.

    Args:
        mpd_dict (dict): The MPD manifest data.
        profiles (list[dict]): The profiles to include in the playlist.
        request (Request): The incoming HTTP request.

    Returns:
        str: The HLS playlist as a string.
    """
    hls = ["#EXTM3U", "#EXT-X-VERSION:6"]

    added_segments = 0

    proxy_url = request.url_for("segment_endpoint")
    proxy_url = str(proxy_url.replace(scheme=get_original_scheme(request)))

    for index, profile in enumerate(profiles):
        segments = profile["segments"]
        if not segments:
            logger.warning(f"No segments found for profile {profile['id']}")
            continue

        # Add headers for only the first profile
        if index == 0:
            sequence = segments[0]["number"]
            extinf_values = [f["extinf"] for f in segments if "extinf" in f]
            target_duration = math.ceil(max(extinf_values)) if extinf_values else 3
            hls.extend(
                [
                    f"#EXT-X-TARGETDURATION:{target_duration}",
                    f"#EXT-X-MEDIA-SEQUENCE:{sequence}",
                ]
            )
            if mpd_dict["isLive"]:
                hls.append("#EXT-X-PLAYLIST-TYPE:EVENT")
            else:
                hls.append("#EXT-X-PLAYLIST-TYPE:VOD")

        init_url = profile["initUrl"]

        query_params = dict(request.query_params)
        query_params.pop("profile_id", None)
        query_params.pop("d", None)
        has_encrypted = query_params.pop("has_encrypted", False)

        for segment in segments:
            hls.append(f'#EXTINF:{segment["extinf"]:.3f},')
            query_params.update(
                {"init_url": init_url, "segment_url": segment["media"], "mime_type": profile["mimeType"]}
            )
            hls.append(
                encode_mediaflow_proxy_url(
                    proxy_url,
                    query_params=query_params,
                    encryption_handler=encryption_handler if has_encrypted else None,
                )
            )
            added_segments += 1

    if not mpd_dict["isLive"]:
        hls.append("#EXT-X-ENDLIST")

    logger.info(f"Added {added_segments} segments to HLS playlist")
    return "\n".join(hls)
diff --git a/mediaflow_proxy/routes.py b/mediaflow_proxy/routes.py new file mode 100644 index 0000000..078c243 --- /dev/null +++ b/mediaflow_proxy/routes.py @@ -0,0 +1,164 @@
from fastapi import Request, Depends, APIRouter
from pydantic import HttpUrl

from .handlers import handle_hls_stream_proxy, proxy_stream, get_manifest, get_playlist, get_segment, get_public_ip
from .utils.http_utils import
get_proxy_headers, ProxyRequestHeaders + +proxy_router = APIRouter() + + +@proxy_router.head("/hls") +@proxy_router.get("/hls") +async def hls_stream_proxy( + request: Request, + d: HttpUrl, + proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers), + key_url: HttpUrl | None = None, + verify_ssl: bool = False, + use_request_proxy: bool = True, +): + """ + Proxify HLS stream requests, fetching and processing the m3u8 playlist or streaming the content. + + Args: + request (Request): The incoming HTTP request. + d (HttpUrl): The destination URL to fetch the content from. + key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None. (Useful for bypassing some sneaky protection) + proxy_headers (ProxyRequestHeaders): The headers to include in the request. + verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False. + use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True. + + Returns: + Response: The HTTP response with the processed m3u8 playlist or streamed content. + """ + destination = str(d) + return await handle_hls_stream_proxy(request, destination, proxy_headers, key_url, verify_ssl, use_request_proxy) + + +@proxy_router.head("/stream") +@proxy_router.get("/stream") +async def proxy_stream_endpoint( + request: Request, + d: HttpUrl, + proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers), + verify_ssl: bool = False, + use_request_proxy: bool = True, +): + """ + Proxies stream requests to the given video URL. + + Args: + request (Request): The incoming HTTP request. + d (HttpUrl): The URL of the video to stream. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. + verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False. + use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True. 
+ + Returns: + Response: The HTTP response with the streamed content. + """ + proxy_headers.request.update({"range": proxy_headers.request.get("range", "bytes=0-")}) + return await proxy_stream(request.method, str(d), proxy_headers, verify_ssl, use_request_proxy) + + +@proxy_router.get("/mpd/manifest") +async def manifest_endpoint( + request: Request, + d: HttpUrl, + proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers), + key_id: str = None, + key: str = None, + verify_ssl: bool = False, + use_request_proxy: bool = True, +): + """ + Retrieves and processes the MPD manifest, converting it to an HLS manifest. + + Args: + request (Request): The incoming HTTP request. + d (HttpUrl): The URL of the MPD manifest. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. + key_id (str, optional): The DRM key ID. Defaults to None. + key (str, optional): The DRM key. Defaults to None. + verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False. + use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True. + + Returns: + Response: The HTTP response with the HLS manifest. + """ + return await get_manifest(request, str(d), proxy_headers, key_id, key, verify_ssl, use_request_proxy) + + +@proxy_router.get("/mpd/playlist") +async def playlist_endpoint( + request: Request, + d: HttpUrl, + profile_id: str, + proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers), + key_id: str = None, + key: str = None, + verify_ssl: bool = False, + use_request_proxy: bool = True, +): + """ + Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile. + + Args: + request (Request): The incoming HTTP request. + d (HttpUrl): The URL of the MPD manifest. + profile_id (str): The profile ID to generate the playlist for. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. 
+ key_id (str, optional): The DRM key ID. Defaults to None. + key (str, optional): The DRM key. Defaults to None. + verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False. + use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True. + + Returns: + Response: The HTTP response with the HLS playlist. + """ + return await get_playlist(request, str(d), profile_id, proxy_headers, key_id, key, verify_ssl, use_request_proxy) + + +@proxy_router.get("/mpd/segment") +async def segment_endpoint( + init_url: HttpUrl, + segment_url: HttpUrl, + mime_type: str, + proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers), + key_id: str = None, + key: str = None, + verify_ssl: bool = False, + use_request_proxy: bool = True, +): + """ + Retrieves and processes a media segment, decrypting it if necessary. + + Args: + init_url (HttpUrl): The URL of the initialization segment. + segment_url (HttpUrl): The URL of the media segment. + mime_type (str): The MIME type of the segment. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. + key_id (str, optional): The DRM key ID. Defaults to None. + key (str, optional): The DRM key. Defaults to None. + verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False. + use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True. + + Returns: + Response: The HTTP response with the processed segment. + """ + return await get_segment( + str(init_url), str(segment_url), mime_type, proxy_headers, key_id, key, verify_ssl, use_request_proxy + ) + + +@proxy_router.get("/ip") +async def get_mediaflow_proxy_public_ip( + use_request_proxy: bool = True, +): + """ + Retrieves the public IP address of the MediaFlow proxy server. + + Returns: + Response: The HTTP response with the public IP address in the form of a JSON object. 
{"ip": "xxx.xxx.xxx.xxx"} + """ + return await get_public_ip(use_request_proxy) diff --git a/mediaflow_proxy/routes/__init__.py b/mediaflow_proxy/routes/__init__.py new file mode 100644 index 0000000..4a8a05c --- /dev/null +++ b/mediaflow_proxy/routes/__init__.py @@ -0,0 +1,5 @@ +from .proxy import proxy_router +from .extractor import extractor_router +from .speedtest import speedtest_router + +__all__ = ["proxy_router", "extractor_router", "speedtest_router"] diff --git a/mediaflow_proxy/routes/extractor.py b/mediaflow_proxy/routes/extractor.py new file mode 100644 index 0000000..09981f6 --- /dev/null +++ b/mediaflow_proxy/routes/extractor.py @@ -0,0 +1,61 @@ +import logging +from typing import Annotated + +from fastapi import APIRouter, Query, HTTPException, Request, Depends +from fastapi.responses import RedirectResponse + +from mediaflow_proxy.extractors.base import ExtractorError +from mediaflow_proxy.extractors.factory import ExtractorFactory +from mediaflow_proxy.schemas import ExtractorURLParams +from mediaflow_proxy.utils.cache_utils import get_cached_extractor_result, set_cache_extractor_result +from mediaflow_proxy.utils.http_utils import ( + encode_mediaflow_proxy_url, + get_original_scheme, + ProxyRequestHeaders, + get_proxy_headers, +) + +extractor_router = APIRouter() +logger = logging.getLogger(__name__) + + +@extractor_router.head("/video") +@extractor_router.get("/video") +async def extract_url( + extractor_params: Annotated[ExtractorURLParams, Query()], + request: Request, + proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)], +): + """Extract clean links from various video hosting services.""" + try: + cache_key = f"{extractor_params.host}_{extractor_params.model_dump_json()}" + response = await get_cached_extractor_result(cache_key) + if not response: + extractor = ExtractorFactory.get_extractor(extractor_params.host, proxy_headers.request) + response = await extractor.extract(extractor_params.destination, 
**extractor_params.extra_params) + await set_cache_extractor_result(cache_key, response) + else: + response["request_headers"].update(proxy_headers.request) + + response["mediaflow_proxy_url"] = str( + request.url_for(response.pop("mediaflow_endpoint")).replace(scheme=get_original_scheme(request)) + ) + response["query_params"] = response.get("query_params", {}) + # Add API password to query params + response["query_params"]["api_password"] = request.query_params.get("api_password") + + if extractor_params.redirect_stream: + stream_url = encode_mediaflow_proxy_url( + **response, + response_headers=proxy_headers.response, + ) + return RedirectResponse(url=stream_url, status_code=302) + + return response + + except ExtractorError as e: + logger.error(f"Extraction failed: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.exception(f"Extraction failed: {str(e)}") + raise HTTPException(status_code=500, detail=f"Extraction failed: {str(e)}") diff --git a/mediaflow_proxy/routes/proxy.py b/mediaflow_proxy/routes/proxy.py new file mode 100644 index 0000000..fb936fa --- /dev/null +++ b/mediaflow_proxy/routes/proxy.py @@ -0,0 +1,138 @@ +from typing import Annotated + +from fastapi import Request, Depends, APIRouter, Query, HTTPException + +from mediaflow_proxy.handlers import ( + handle_hls_stream_proxy, + proxy_stream, + get_manifest, + get_playlist, + get_segment, + get_public_ip, +) +from mediaflow_proxy.schemas import ( + MPDSegmentParams, + MPDPlaylistParams, + HLSManifestParams, + ProxyStreamParams, + MPDManifestParams, +) +from mediaflow_proxy.utils.http_utils import get_proxy_headers, ProxyRequestHeaders + +proxy_router = APIRouter() + + +@proxy_router.head("/hls/manifest.m3u8") +@proxy_router.get("/hls/manifest.m3u8") +async def hls_manifest_proxy( + request: Request, + hls_params: Annotated[HLSManifestParams, Query()], + proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)], +): + """ + Proxify HLS 
stream requests, fetching and processing the m3u8 playlist or streaming the content. + + Args: + request (Request): The incoming HTTP request. + hls_params (HLSPlaylistParams): The parameters for the HLS stream request. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. + + Returns: + Response: The HTTP response with the processed m3u8 playlist or streamed content. + """ + return await handle_hls_stream_proxy(request, hls_params, proxy_headers) + + +@proxy_router.head("/stream") +@proxy_router.get("/stream") +async def proxy_stream_endpoint( + request: Request, + stream_params: Annotated[ProxyStreamParams, Query()], + proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)], +): + """ + Proxies stream requests to the given video URL. + + Args: + request (Request): The incoming HTTP request. + stream_params (ProxyStreamParams): The parameters for the stream request. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. + + Returns: + Response: The HTTP response with the streamed content. + """ + content_range = proxy_headers.request.get("range", "bytes=0-") + if "nan" in content_range.casefold(): + # Handle invalid range requests "bytes=NaN-NaN" + raise HTTPException(status_code=416, detail="Invalid Range Header") + proxy_headers.request.update({"range": content_range}) + return await proxy_stream(request.method, stream_params, proxy_headers) + + +@proxy_router.get("/mpd/manifest.m3u8") +async def mpd_manifest_proxy( + request: Request, + manifest_params: Annotated[MPDManifestParams, Query()], + proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)], +): + """ + Retrieves and processes the MPD manifest, converting it to an HLS manifest. + + Args: + request (Request): The incoming HTTP request. + manifest_params (MPDManifestParams): The parameters for the manifest request. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. 
+ + Returns: + Response: The HTTP response with the HLS manifest. + """ + return await get_manifest(request, manifest_params, proxy_headers) + + +@proxy_router.get("/mpd/playlist.m3u8") +async def playlist_endpoint( + request: Request, + playlist_params: Annotated[MPDPlaylistParams, Query()], + proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)], +): + """ + Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile. + + Args: + request (Request): The incoming HTTP request. + playlist_params (MPDPlaylistParams): The parameters for the playlist request. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. + + Returns: + Response: The HTTP response with the HLS playlist. + """ + return await get_playlist(request, playlist_params, proxy_headers) + + +@proxy_router.get("/mpd/segment.mp4") +async def segment_endpoint( + segment_params: Annotated[MPDSegmentParams, Query()], + proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)], +): + """ + Retrieves and processes a media segment, decrypting it if necessary. + + Args: + segment_params (MPDSegmentParams): The parameters for the segment request. + proxy_headers (ProxyRequestHeaders): The headers to include in the request. + + Returns: + Response: The HTTP response with the processed segment. + """ + return await get_segment(segment_params, proxy_headers) + + +@proxy_router.get("/ip") +async def get_mediaflow_proxy_public_ip(): + """ + Retrieves the public IP address of the MediaFlow proxy server. + + Returns: + Response: The HTTP response with the public IP address in the form of a JSON object. 
{"ip": "xxx.xxx.xxx.xxx"} + """ + return await get_public_ip() diff --git a/mediaflow_proxy/routes/speedtest.py b/mediaflow_proxy/routes/speedtest.py new file mode 100644 index 0000000..ec33bfd --- /dev/null +++ b/mediaflow_proxy/routes/speedtest.py @@ -0,0 +1,43 @@ +import uuid + +from fastapi import APIRouter, BackgroundTasks, HTTPException, Request +from fastapi.responses import RedirectResponse + +from mediaflow_proxy.speedtest.service import SpeedTestService, SpeedTestProvider + +speedtest_router = APIRouter() + +# Initialize service +speedtest_service = SpeedTestService() + + +@speedtest_router.get("/", summary="Show speed test interface") +async def show_speedtest_page(): + """Return the speed test HTML interface.""" + return RedirectResponse(url="/speedtest.html") + + +@speedtest_router.post("/start", summary="Start a new speed test", response_model=dict) +async def start_speedtest(background_tasks: BackgroundTasks, provider: SpeedTestProvider, request: Request): + """Start a new speed test for the specified provider.""" + task_id = str(uuid.uuid4()) + api_key = request.headers.get("api_key") + + # Create and initialize the task + await speedtest_service.create_test(task_id, provider, api_key) + + # Schedule the speed test + background_tasks.add_task(speedtest_service.run_speedtest, task_id, provider, api_key) + + return {"task_id": task_id} + + +@speedtest_router.get("/results/{task_id}", summary="Get speed test results") +async def get_speedtest_results(task_id: str): + """Get the results or current status of a speed test.""" + task = await speedtest_service.get_test_results(task_id) + + if not task: + raise HTTPException(status_code=404, detail="Speed test task not found or expired") + + return task.dict() diff --git a/mediaflow_proxy/schemas.py b/mediaflow_proxy/schemas.py new file mode 100644 index 0000000..9a8fa2b --- /dev/null +++ b/mediaflow_proxy/schemas.py @@ -0,0 +1,74 @@ +from typing import Literal, Dict, Any, Optional + +from pydantic import 
BaseModel, Field, IPvAnyAddress, ConfigDict + + +class GenerateUrlRequest(BaseModel): + mediaflow_proxy_url: str = Field(..., description="The base URL for the mediaflow proxy.") + endpoint: Optional[str] = Field(None, description="The specific endpoint to be appended to the base URL.") + destination_url: Optional[str] = Field( + None, description="The destination URL to which the request will be proxied." + ) + query_params: Optional[dict] = Field( + default_factory=dict, description="Query parameters to be included in the request." + ) + request_headers: Optional[dict] = Field(default_factory=dict, description="Headers to be included in the request.") + response_headers: Optional[dict] = Field( + default_factory=dict, description="Headers to be included in the response." + ) + expiration: Optional[int] = Field( + None, description="Expiration time for the URL in seconds. If not provided, the URL will not expire." + ) + api_password: Optional[str] = Field( + None, description="API password for encryption. If not provided, the URL will only be encoded." + ) + ip: Optional[IPvAnyAddress] = Field(None, description="The IP address to restrict the URL to.") + + +class GenericParams(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + +class HLSManifestParams(GenericParams): + destination: str = Field(..., description="The URL of the HLS manifest.", alias="d") + key_url: Optional[str] = Field( + None, + description="The HLS Key URL to replace the original key URL. Defaults to None. 
(Useful for bypassing some sneaky protection)", + ) + + +class ProxyStreamParams(GenericParams): + destination: str = Field(..., description="The URL of the stream.", alias="d") + + +class MPDManifestParams(GenericParams): + destination: str = Field(..., description="The URL of the MPD manifest.", alias="d") + key_id: Optional[str] = Field(None, description="The DRM key ID (optional).") + key: Optional[str] = Field(None, description="The DRM key (optional).") + + +class MPDPlaylistParams(GenericParams): + destination: str = Field(..., description="The URL of the MPD manifest.", alias="d") + profile_id: str = Field(..., description="The profile ID to generate the playlist for.") + key_id: Optional[str] = Field(None, description="The DRM key ID (optional).") + key: Optional[str] = Field(None, description="The DRM key (optional).") + + +class MPDSegmentParams(GenericParams): + init_url: str = Field(..., description="The URL of the initialization segment.") + segment_url: str = Field(..., description="The URL of the media segment.") + mime_type: str = Field(..., description="The MIME type of the segment.") + key_id: Optional[str] = Field(None, description="The DRM key ID (optional).") + key: Optional[str] = Field(None, description="The DRM key (optional).") + + +class ExtractorURLParams(GenericParams): + host: Literal["Doodstream", "Mixdrop", "Uqload", "Streamtape", "Supervideo", "LiveTV"] = Field( + ..., description="The host to extract the URL from." 
+    )
+    destination: str = Field(..., description="The URL of the stream.", alias="d")
+    redirect_stream: bool = Field(False, description="Whether to redirect to the stream endpoint automatically.")
+    extra_params: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional parameters required for specific extractors (e.g., stream_title for LiveTV)",
+    )
diff --git a/mediaflow_proxy/speedtest/__init__.py b/mediaflow_proxy/speedtest/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mediaflow_proxy/speedtest/models.py b/mediaflow_proxy/speedtest/models.py
new file mode 100644
index 0000000..c789fe1
--- /dev/null
+++ b/mediaflow_proxy/speedtest/models.py
@@ -0,0 +1,46 @@
+from datetime import datetime
+from enum import Enum
+from typing import Dict, Optional
+
+from pydantic import BaseModel, Field
+
+
+# Supported speed test backends; str mixin keeps values JSON/query friendly.
+class SpeedTestProvider(str, Enum):
+    REAL_DEBRID = "real_debrid"
+    ALL_DEBRID = "all_debrid"
+
+
+# A provider-reported test server (name + base URL).
+class ServerInfo(BaseModel):
+    url: str
+    name: str
+
+
+# Client identity as reported by the provider; all fields optional because
+# not every provider returns them (RealDebrid returns none).
+class UserInfo(BaseModel):
+    ip: Optional[str] = None
+    isp: Optional[str] = None
+    country: Optional[str] = None
+
+
+# Outcome of a single successful download measurement.
+class SpeedTestResult(BaseModel):
+    speed_mbps: float = Field(..., description="Speed in Mbps")
+    duration: float = Field(..., description="Test duration in seconds")
+    data_transferred: int = Field(..., description="Data transferred in bytes")
+    # NOTE(review): datetime.utcnow yields a NAIVE timestamp and is deprecated
+    # since Python 3.12; datetime.now(timezone.utc) is the replacement —
+    # confirm downstream serialization before changing.
+    timestamp: datetime = Field(default_factory=datetime.utcnow)
+
+
+# Per-location result wrapper: exactly one of `result` or `error` is expected
+# to be populated by the service layer.
+class LocationResult(BaseModel):
+    result: Optional[SpeedTestResult] = None
+    error: Optional[str] = None
+    server_name: str
+    server_url: str
+
+
+# State of one speed test run, persisted in the cache between status polls.
+class SpeedTestTask(BaseModel):
+    task_id: str
+    provider: SpeedTestProvider
+    results: Dict[str, LocationResult] = {}
+    started_at: datetime
+    completed_at: Optional[datetime] = None
+    status: str = "running"
+    user_info: Optional[UserInfo] = None
+    current_location: Optional[str] = None
diff --git a/mediaflow_proxy/speedtest/providers/all_debrid.py b/mediaflow_proxy/speedtest/providers/all_debrid.py
new file mode
100644 index 0000000..9bf3f7d --- /dev/null +++ b/mediaflow_proxy/speedtest/providers/all_debrid.py @@ -0,0 +1,50 @@ +import random +from typing import Dict, Tuple, Optional + +from mediaflow_proxy.configs import settings +from mediaflow_proxy.speedtest.models import ServerInfo, UserInfo +from mediaflow_proxy.speedtest.providers.base import BaseSpeedTestProvider, SpeedTestProviderConfig +from mediaflow_proxy.utils.http_utils import request_with_retry + + +class SpeedTestError(Exception): + pass + + +class AllDebridSpeedTest(BaseSpeedTestProvider): + """AllDebrid speed test provider implementation.""" + + def __init__(self, api_key: str): + self.api_key = api_key + self.servers: Dict[str, ServerInfo] = {} + + async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]: + response = await request_with_retry( + "GET", + "https://alldebrid.com/internalapi/v4/speedtest", + headers={"User-Agent": settings.user_agent}, + params={"agent": "service", "version": "1.0-363869a7", "apikey": self.api_key}, + ) + + if response.status_code != 200: + raise SpeedTestError("Failed to fetch AllDebrid servers") + + data = response.json() + if data["status"] != "success": + raise SpeedTestError("AllDebrid API returned error") + + # Create UserInfo + user_info = UserInfo(ip=data["data"]["ip"], isp=data["data"]["isp"], country=data["data"]["country"]) + + # Store server info + self.servers = {server["name"]: ServerInfo(**server) for server in data["data"]["servers"]} + + # Generate URLs with random number + random_number = f"{random.uniform(1, 2):.24f}".replace(".", "") + urls = {name: f"{server.url}/speedtest/{random_number}" for name, server in self.servers.items()} + + return urls, user_info + + async def get_config(self) -> SpeedTestProviderConfig: + urls, _ = await self.get_test_urls() + return SpeedTestProviderConfig(test_duration=10, test_urls=urls) diff --git a/mediaflow_proxy/speedtest/providers/base.py b/mediaflow_proxy/speedtest/providers/base.py new file mode 
100644 index 0000000..275c901 --- /dev/null +++ b/mediaflow_proxy/speedtest/providers/base.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from typing import Dict, Tuple, Optional +from pydantic import BaseModel + +from mediaflow_proxy.speedtest.models import UserInfo + + +class SpeedTestProviderConfig(BaseModel): + test_duration: int = 10 # seconds + test_urls: Dict[str, str] + + +class BaseSpeedTestProvider(ABC): + """Base class for speed test providers.""" + + @abstractmethod + async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]: + """Get list of test URLs for the provider and optional user info.""" + pass + + @abstractmethod + async def get_config(self) -> SpeedTestProviderConfig: + """Get provider-specific configuration.""" + pass diff --git a/mediaflow_proxy/speedtest/providers/real_debrid.py b/mediaflow_proxy/speedtest/providers/real_debrid.py new file mode 100644 index 0000000..0d38511 --- /dev/null +++ b/mediaflow_proxy/speedtest/providers/real_debrid.py @@ -0,0 +1,32 @@ +from typing import Dict, Tuple, Optional +import random + +from mediaflow_proxy.speedtest.models import UserInfo +from mediaflow_proxy.speedtest.providers.base import BaseSpeedTestProvider, SpeedTestProviderConfig + + +class RealDebridSpeedTest(BaseSpeedTestProvider): + """RealDebrid speed test provider implementation.""" + + async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]: + urls = { + "AMS": "https://45.download.real-debrid.com/speedtest/testDefault.rar/", + "RBX": "https://rbx.download.real-debrid.com/speedtest/test.rar/", + "LON1": "https://lon1.download.real-debrid.com/speedtest/test.rar/", + "HKG1": "https://hkg1.download.real-debrid.com/speedtest/test.rar/", + "SGP1": "https://sgp1.download.real-debrid.com/speedtest/test.rar/", + "SGPO1": "https://sgpo1.download.real-debrid.com/speedtest/test.rar/", + "TYO1": "https://tyo1.download.real-debrid.com/speedtest/test.rar/", + "LAX1": 
"https://lax1.download.real-debrid.com/speedtest/test.rar/", + "TLV1": "https://tlv1.download.real-debrid.com/speedtest/test.rar/", + "MUM1": "https://mum1.download.real-debrid.com/speedtest/test.rar/", + "JKT1": "https://jkt1.download.real-debrid.com/speedtest/test.rar/", + "Cloudflare": "https://45.download.real-debrid.cloud/speedtest/testCloudflare.rar/", + } + # Add random number to prevent caching + urls = {location: f"{base_url}{random.uniform(0, 1):.16f}" for location, base_url in urls.items()} + return urls, None + + async def get_config(self) -> SpeedTestProviderConfig: + urls, _ = await self.get_test_urls() + return SpeedTestProviderConfig(test_duration=10, test_urls=urls) diff --git a/mediaflow_proxy/speedtest/service.py b/mediaflow_proxy/speedtest/service.py new file mode 100644 index 0000000..b40c639 --- /dev/null +++ b/mediaflow_proxy/speedtest/service.py @@ -0,0 +1,129 @@ +import logging +import time +from datetime import datetime, timezone +from typing import Dict, Optional, Type + +from mediaflow_proxy.utils.cache_utils import get_cached_speedtest, set_cache_speedtest +from mediaflow_proxy.utils.http_utils import Streamer, create_httpx_client +from .models import SpeedTestTask, LocationResult, SpeedTestResult, SpeedTestProvider +from .providers.all_debrid import AllDebridSpeedTest +from .providers.base import BaseSpeedTestProvider +from .providers.real_debrid import RealDebridSpeedTest + +logger = logging.getLogger(__name__) + + +class SpeedTestService: + """Service for managing speed tests across different providers.""" + + def __init__(self): + # Provider mapping + self._providers: Dict[SpeedTestProvider, Type[BaseSpeedTestProvider]] = { + SpeedTestProvider.REAL_DEBRID: RealDebridSpeedTest, + SpeedTestProvider.ALL_DEBRID: AllDebridSpeedTest, + } + + def _get_provider(self, provider: SpeedTestProvider, api_key: Optional[str] = None) -> BaseSpeedTestProvider: + """Get the appropriate provider implementation.""" + provider_class = 
self._providers.get(provider) + if not provider_class: + raise ValueError(f"Unsupported provider: {provider}") + + if provider == SpeedTestProvider.ALL_DEBRID and not api_key: + raise ValueError("API key required for AllDebrid") + + return provider_class(api_key) if provider == SpeedTestProvider.ALL_DEBRID else provider_class() + + async def create_test( + self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None + ) -> SpeedTestTask: + """Create a new speed test task.""" + provider_impl = self._get_provider(provider, api_key) + + # Get initial URLs and user info + urls, user_info = await provider_impl.get_test_urls() + + task = SpeedTestTask( + task_id=task_id, provider=provider, started_at=datetime.now(tz=timezone.utc), user_info=user_info + ) + + await set_cache_speedtest(task_id, task) + return task + + @staticmethod + async def get_test_results(task_id: str) -> Optional[SpeedTestTask]: + """Get results for a specific task.""" + return await get_cached_speedtest(task_id) + + async def run_speedtest(self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None): + """Run the speed test with real-time updates.""" + try: + task = await get_cached_speedtest(task_id) + if not task: + raise ValueError(f"Task {task_id} not found") + + provider_impl = self._get_provider(provider, api_key) + config = await provider_impl.get_config() + + async with create_httpx_client() as client: + streamer = Streamer(client) + + for location, url in config.test_urls.items(): + try: + task.current_location = location + await set_cache_speedtest(task_id, task) + result = await self._test_location(location, url, streamer, config.test_duration, provider_impl) + task.results[location] = result + await set_cache_speedtest(task_id, task) + except Exception as e: + logger.error(f"Error testing {location}: {str(e)}") + task.results[location] = LocationResult( + error=str(e), server_name=location, server_url=config.test_urls[location] + ) + await 
set_cache_speedtest(task_id, task) + + # Mark task as completed + task.completed_at = datetime.now(tz=timezone.utc) + task.status = "completed" + task.current_location = None + await set_cache_speedtest(task_id, task) + + except Exception as e: + logger.error(f"Error in speed test task {task_id}: {str(e)}") + if task := await get_cached_speedtest(task_id): + task.status = "failed" + await set_cache_speedtest(task_id, task) + + async def _test_location( + self, location: str, url: str, streamer: Streamer, test_duration: int, provider: BaseSpeedTestProvider + ) -> LocationResult: + """Test speed for a specific location.""" + try: + start_time = time.time() + total_bytes = 0 + + await streamer.create_streaming_response(url, headers={}) + + async for chunk in streamer.stream_content(): + if time.time() - start_time >= test_duration: + break + total_bytes += len(chunk) + + duration = time.time() - start_time + speed_mbps = (total_bytes * 8) / (duration * 1_000_000) + + # Get server info if available (for AllDebrid) + server_info = getattr(provider, "servers", {}).get(location) + server_url = server_info.url if server_info else url + + return LocationResult( + result=SpeedTestResult( + speed_mbps=round(speed_mbps, 2), duration=round(duration, 2), data_transferred=total_bytes + ), + server_name=location, + server_url=server_url, + ) + + except Exception as e: + logger.error(f"Error testing {location}: {str(e)}") + raise # Re-raise to be handled by run_speedtest diff --git a/mediaflow_proxy/static/index.html b/mediaflow_proxy/static/index.html new file mode 100644 index 0000000..ea185a0 --- /dev/null +++ b/mediaflow_proxy/static/index.html @@ -0,0 +1,76 @@ + + + + + + MediaFlow Proxy + + + + +
+ MediaFlow Proxy Logo +

MediaFlow Proxy

+
+

A high-performance proxy server for streaming media, supporting HTTP(S), HLS, and MPEG-DASH with real-time DRM decryption.

+ +

Key Features

+
Convert MPEG-DASH streams (DRM-protected and non-protected) to HLS
+
Support for Clear Key DRM-protected MPD DASH streams
+
Handle both live and video-on-demand (VOD) DASH streams
+
Proxy HTTP/HTTPS links with custom headers
+
Proxy and modify HLS (M3U8) streams in real-time with custom headers and key URL modifications to work around certain access restrictions.
+
Protect against unauthorized access and network bandwidth abuses
+ +

Getting Started

+

Visit the GitHub repository for installation instructions and documentation.

+ +

Premium Hosted Service

+

For a hassle-free experience, check out the premium hosted service on ElfHosted.

+ +

API Documentation

+

Explore the Swagger UI for comprehensive details about the API endpoints and their usage.

+ + + \ No newline at end of file diff --git a/mediaflow_proxy/static/logo.png b/mediaflow_proxy/static/logo.png new file mode 100644 index 0000000..d71e14f Binary files /dev/null and b/mediaflow_proxy/static/logo.png differ diff --git a/mediaflow_proxy/static/speedtest.html b/mediaflow_proxy/static/speedtest.html new file mode 100644 index 0000000..386fa8e --- /dev/null +++ b/mediaflow_proxy/static/speedtest.html @@ -0,0 +1,697 @@ + + + + + + Debrid Speed Test + + + + + + +
+ +
+ +
+ +
+ +
+

+ Enter API Password +

+ +
+
+
+ + +
+
+ + +
+ +
+
+
+ + + + + + + + + + + + + + + + +
+
+ + + + \ No newline at end of file diff --git a/mediaflow_proxy/utils/__init__.py b/mediaflow_proxy/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mediaflow_proxy/utils/cache_utils.py b/mediaflow_proxy/utils/cache_utils.py new file mode 100644 index 0000000..d300a2b --- /dev/null +++ b/mediaflow_proxy/utils/cache_utils.py @@ -0,0 +1,376 @@ +import asyncio +import hashlib +import json +import logging +import os +import tempfile +import threading +import time +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Union, Any + +import aiofiles +import aiofiles.os +from pydantic import ValidationError + +from mediaflow_proxy.speedtest.models import SpeedTestTask +from mediaflow_proxy.utils.http_utils import download_file_with_retry, DownloadError +from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict + +logger = logging.getLogger(__name__) + + +@dataclass +class CacheEntry: + """Represents a cache entry with metadata.""" + + data: bytes + expires_at: float + access_count: int = 0 + last_access: float = 0.0 + size: int = 0 + + +class LRUMemoryCache: + """Thread-safe LRU memory cache with support.""" + + def __init__(self, maxsize: int): + self.maxsize = maxsize + self._cache: OrderedDict[str, CacheEntry] = OrderedDict() + self._lock = threading.Lock() + self._current_size = 0 + + def get(self, key: str) -> Optional[CacheEntry]: + with self._lock: + if key in self._cache: + entry = self._cache.pop(key) # Remove and re-insert for LRU + if time.time() < entry.expires_at: + entry.access_count += 1 + entry.last_access = time.time() + self._cache[key] = entry + return entry + else: + # Remove expired entry + self._current_size -= entry.size + self._cache.pop(key, None) + return None + + def set(self, key: str, entry: CacheEntry) -> None: + with self._lock: + if key in self._cache: + old_entry = 
self._cache[key] + self._current_size -= old_entry.size + + # Check if we need to make space + while self._current_size + entry.size > self.maxsize and self._cache: + _, removed_entry = self._cache.popitem(last=False) + self._current_size -= removed_entry.size + + self._cache[key] = entry + self._current_size += entry.size + + def remove(self, key: str) -> None: + with self._lock: + if key in self._cache: + entry = self._cache.pop(key) + self._current_size -= entry.size + + +class HybridCache: + """High-performance hybrid cache combining memory and file storage.""" + + def __init__( + self, + cache_dir_name: str, + ttl: int, + max_memory_size: int = 100 * 1024 * 1024, # 100MB default + executor_workers: int = 4, + ): + self.cache_dir = Path(tempfile.gettempdir()) / cache_dir_name + self.ttl = ttl + self.memory_cache = LRUMemoryCache(maxsize=max_memory_size) + self._executor = ThreadPoolExecutor(max_workers=executor_workers) + self._lock = asyncio.Lock() + + # Initialize cache directories + self._init_cache_dirs() + + def _init_cache_dirs(self): + """Initialize sharded cache directories.""" + os.makedirs(self.cache_dir, exist_ok=True) + + def _get_md5_hash(self, key: str) -> str: + """Get the MD5 hash of a cache key.""" + return hashlib.md5(key.encode()).hexdigest() + + def _get_file_path(self, key: str) -> Path: + """Get the file path for a cache key.""" + return self.cache_dir / key + + async def get(self, key: str, default: Any = None) -> Optional[bytes]: + """ + Get value from cache, trying memory first then file. 
+ + Args: + key: Cache key + default: Default value if key not found + + Returns: + Cached value or default if not found + """ + key = self._get_md5_hash(key) + # Try memory cache first + entry = self.memory_cache.get(key) + if entry is not None: + return entry.data + + # Try file cache + try: + file_path = self._get_file_path(key) + async with aiofiles.open(file_path, "rb") as f: + metadata_size = await f.read(8) + metadata_length = int.from_bytes(metadata_size, "big") + metadata_bytes = await f.read(metadata_length) + metadata = json.loads(metadata_bytes.decode()) + + # Check expiration + if metadata["expires_at"] < time.time(): + await self.delete(key) + return default + + # Read data + data = await f.read() + + # Update memory cache in background + entry = CacheEntry( + data=data, + expires_at=metadata["expires_at"], + access_count=metadata["access_count"] + 1, + last_access=time.time(), + size=len(data), + ) + self.memory_cache.set(key, entry) + + return data + + except FileNotFoundError: + return default + except Exception as e: + logger.error(f"Error reading from cache: {e}") + return default + + async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool: + """ + Set value in both memory and file cache. 
+ + Args: + key: Cache key + data: Data to cache + ttl: Optional TTL override + + Returns: + bool: Success status + """ + if not isinstance(data, (bytes, bytearray, memoryview)): + raise ValueError("Data must be bytes, bytearray, or memoryview") + + expires_at = time.time() + (ttl or self.ttl) + + # Create cache entry + entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data)) + + key = self._get_md5_hash(key) + # Update memory cache + self.memory_cache.set(key, entry) + file_path = self._get_file_path(key) + temp_path = file_path.with_suffix(".tmp") + + # Update file cache + try: + metadata = {"expires_at": expires_at, "access_count": 0, "last_access": time.time()} + metadata_bytes = json.dumps(metadata).encode() + metadata_size = len(metadata_bytes).to_bytes(8, "big") + + async with aiofiles.open(temp_path, "wb") as f: + await f.write(metadata_size) + await f.write(metadata_bytes) + await f.write(data) + + await aiofiles.os.rename(temp_path, file_path) + return True + + except Exception as e: + logger.error(f"Error writing to cache: {e}") + try: + await aiofiles.os.remove(temp_path) + except: + pass + return False + + async def delete(self, key: str) -> bool: + """Delete item from both caches.""" + self.memory_cache.remove(key) + + try: + file_path = self._get_file_path(key) + await aiofiles.os.remove(file_path) + return True + except FileNotFoundError: + return True + except Exception as e: + logger.error(f"Error deleting from cache: {e}") + return False + + +class AsyncMemoryCache: + """Async wrapper around LRUMemoryCache.""" + + def __init__(self, max_memory_size: int): + self.memory_cache = LRUMemoryCache(maxsize=max_memory_size) + + async def get(self, key: str, default: Any = None) -> Optional[bytes]: + """Get value from cache.""" + entry = self.memory_cache.get(key) + return entry.data if entry is not None else default + + async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: 
Optional[int] = None) -> bool: + """Set value in cache.""" + try: + expires_at = time.time() + (ttl or 3600) # Default 1 hour TTL if not specified + entry = CacheEntry( + data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data) + ) + self.memory_cache.set(key, entry) + return True + except Exception as e: + logger.error(f"Error setting cache value: {e}") + return False + + async def delete(self, key: str) -> bool: + """Delete item from cache.""" + try: + self.memory_cache.remove(key) + return True + except Exception as e: + logger.error(f"Error deleting from cache: {e}") + return False + + +# Create cache instances +INIT_SEGMENT_CACHE = HybridCache( + cache_dir_name="init_segment_cache", + ttl=3600, # 1 hour + max_memory_size=500 * 1024 * 1024, # 500MB for init segments +) + +MPD_CACHE = AsyncMemoryCache( + max_memory_size=100 * 1024 * 1024, # 100MB for MPD files +) + +SPEEDTEST_CACHE = HybridCache( + cache_dir_name="speedtest_cache", + ttl=3600, # 1 hour + max_memory_size=50 * 1024 * 1024, +) + +EXTRACTOR_CACHE = HybridCache( + cache_dir_name="extractor_cache", + ttl=5 * 60, # 5 minutes + max_memory_size=50 * 1024 * 1024, +) + + +# Specific cache implementations +async def get_cached_init_segment(init_url: str, headers: dict) -> Optional[bytes]: + """Get initialization segment from cache or download it.""" + # Try cache first + cached_data = await INIT_SEGMENT_CACHE.get(init_url) + if cached_data is not None: + return cached_data + + # Download if not cached + try: + init_content = await download_file_with_retry(init_url, headers) + if init_content: + await INIT_SEGMENT_CACHE.set(init_url, init_content) + return init_content + except Exception as e: + logger.error(f"Error downloading init segment: {e}") + return None + + +async def get_cached_mpd( + mpd_url: str, + headers: dict, + parse_drm: bool, + parse_segment_profile_id: Optional[str] = None, +) -> dict: + """Get MPD from cache or download and parse it.""" + # Try cache first 
+ cached_data = await MPD_CACHE.get(mpd_url) + if cached_data is not None: + try: + mpd_dict = json.loads(cached_data) + return parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id) + except json.JSONDecodeError: + await MPD_CACHE.delete(mpd_url) + + # Download and parse if not cached + try: + mpd_content = await download_file_with_retry(mpd_url, headers) + mpd_dict = parse_mpd(mpd_content) + parsed_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id) + + # Cache the original MPD dict + await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode(), ttl=parsed_dict["minimumUpdatePeriod"]) + return parsed_dict + except DownloadError as error: + logger.error(f"Error downloading MPD: {error}") + raise error + except Exception as error: + logger.exception(f"Error processing MPD: {error}") + raise error + + +async def get_cached_speedtest(task_id: str) -> Optional[SpeedTestTask]: + """Get speed test results from cache.""" + cached_data = await SPEEDTEST_CACHE.get(task_id) + if cached_data is not None: + try: + return SpeedTestTask.model_validate_json(cached_data.decode()) + except ValidationError as e: + logger.error(f"Error parsing cached speed test data: {e}") + await SPEEDTEST_CACHE.delete(task_id) + return None + + +async def set_cache_speedtest(task_id: str, task: SpeedTestTask) -> bool: + """Cache speed test results.""" + try: + return await SPEEDTEST_CACHE.set(task_id, task.model_dump_json().encode()) + except Exception as e: + logger.error(f"Error caching speed test data: {e}") + return False + + +async def get_cached_extractor_result(key: str) -> Optional[dict]: + """Get extractor result from cache.""" + cached_data = await EXTRACTOR_CACHE.get(key) + if cached_data is not None: + try: + return json.loads(cached_data) + except json.JSONDecodeError: + await EXTRACTOR_CACHE.delete(key) + return None + + +async def set_cache_extractor_result(key: str, result: dict) -> bool: + """Cache extractor result.""" + try: + return await 
EXTRACTOR_CACHE.set(key, json.dumps(result).encode()) + except Exception as e: + logger.error(f"Error caching extractor result: {e}") + return False diff --git a/mediaflow_proxy/utils/crypto_utils.py b/mediaflow_proxy/utils/crypto_utils.py new file mode 100644 index 0000000..056df3b --- /dev/null +++ b/mediaflow_proxy/utils/crypto_utils.py @@ -0,0 +1,110 @@ +import base64 +import json +import logging +import time +import traceback +from typing import Optional +from urllib.parse import urlencode + +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad, unpad +from fastapi import HTTPException, Request +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse + +from mediaflow_proxy.configs import settings + + +class EncryptionHandler: + def __init__(self, secret_key: str): + self.secret_key = secret_key.encode("utf-8").ljust(32)[:32] + + def encrypt_data(self, data: dict, expiration: int = None, ip: str = None) -> str: + if expiration: + data["exp"] = int(time.time()) + expiration + if ip: + data["ip"] = ip + json_data = json.dumps(data).encode("utf-8") + iv = get_random_bytes(16) + cipher = AES.new(self.secret_key, AES.MODE_CBC, iv) + encrypted_data = cipher.encrypt(pad(json_data, AES.block_size)) + return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8") + + def decrypt_data(self, token: str, client_ip: str) -> dict: + try: + encrypted_data = base64.urlsafe_b64decode(token.encode("utf-8")) + iv = encrypted_data[:16] + cipher = AES.new(self.secret_key, AES.MODE_CBC, iv) + decrypted_data = unpad(cipher.decrypt(encrypted_data[16:]), AES.block_size) + data = json.loads(decrypted_data) + + if "exp" in data: + if data["exp"] < time.time(): + raise HTTPException(status_code=401, detail="Token has expired") + del data["exp"] # Remove expiration from the data + + if "ip" in data: + if data["ip"] != client_ip: + raise HTTPException(status_code=403, detail="IP 
address mismatch") + del data["ip"] # Remove IP from the data + + return data + except Exception as e: + raise HTTPException(status_code=401, detail="Invalid or expired token") + + +class EncryptionMiddleware(BaseHTTPMiddleware): + def __init__(self, app): + super().__init__(app) + self.encryption_handler = encryption_handler + + async def dispatch(self, request: Request, call_next): + encrypted_token = request.query_params.get("token") + if encrypted_token and self.encryption_handler: + try: + client_ip = self.get_client_ip(request) + decrypted_data = self.encryption_handler.decrypt_data(encrypted_token, client_ip) + # Modify request query parameters with decrypted data + query_params = dict(request.query_params) + query_params.pop("token") # Remove the encrypted token from query params + query_params.update(decrypted_data) # Add decrypted data to query params + query_params["has_encrypted"] = True + + # Create a new request scope with updated query parameters + new_query_string = urlencode(query_params) + request.scope["query_string"] = new_query_string.encode() + request._query_params = query_params + except HTTPException as e: + return JSONResponse(content={"error": str(e.detail)}, status_code=e.status_code) + + try: + response = await call_next(request) + except Exception: + exc = traceback.format_exc(chain=False) + logging.error("An error occurred while processing the request, error: %s", exc) + return JSONResponse( + content={"error": "An error occurred while processing the request, check the server for logs"}, + status_code=500, + ) + return response + + @staticmethod + def get_client_ip(request: Request) -> Optional[str]: + """ + Extract the client's real IP address from the request headers or fallback to the client host. + """ + x_forwarded_for = request.headers.get("X-Forwarded-For") + if x_forwarded_for: + # In some cases, this header can contain multiple IPs + # separated by commas. + # The first one is the original client's IP. 
+ return x_forwarded_for.split(",")[0].strip() + # Fallback to X-Real-IP if X-Forwarded-For is not available + x_real_ip = request.headers.get("X-Real-IP") + if x_real_ip: + return x_real_ip + return request.client.host if request.client else "127.0.0.1" + + +encryption_handler = EncryptionHandler(settings.api_password) if settings.api_password else None diff --git a/mediaflow_proxy/utils/http_utils.py b/mediaflow_proxy/utils/http_utils.py new file mode 100644 index 0000000..1a594e2 --- /dev/null +++ b/mediaflow_proxy/utils/http_utils.py @@ -0,0 +1,430 @@ +import logging +import typing +from dataclasses import dataclass +from functools import partial +from urllib import parse +from urllib.parse import urlencode + +import anyio +import httpx +import tenacity +from fastapi import Response +from starlette.background import BackgroundTask +from starlette.concurrency import iterate_in_threadpool +from starlette.requests import Request +from starlette.types import Receive, Send, Scope +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type +from tqdm.asyncio import tqdm as tqdm_asyncio + +from mediaflow_proxy.configs import settings +from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS +from mediaflow_proxy.utils.crypto_utils import EncryptionHandler + +logger = logging.getLogger(__name__) + + +class DownloadError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + super().__init__(message) + + +def create_httpx_client(follow_redirects: bool = True, timeout: float = 30.0, **kwargs) -> httpx.AsyncClient: + """Creates an HTTPX client with configured proxy routing""" + mounts = settings.transport_config.get_mounts() + client = httpx.AsyncClient(mounts=mounts, follow_redirects=follow_redirects, timeout=timeout, **kwargs) + return client + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + 
retry=retry_if_exception_type(DownloadError), +) +async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs): + """ + Fetches a URL with retry logic. + + Args: + client (httpx.AsyncClient): The HTTP client to use for the request. + method (str): The HTTP method to use (e.g., GET, POST). + url (str): The URL to fetch. + headers (dict): The headers to include in the request. + follow_redirects (bool, optional): Whether to follow redirects. Defaults to True. + **kwargs: Additional arguments to pass to the request. + + Returns: + httpx.Response: The HTTP response. + + Raises: + DownloadError: If the request fails after retries. + """ + try: + response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs) + response.raise_for_status() + return response + except httpx.TimeoutException: + logger.warning(f"Timeout while downloading {url}") + raise DownloadError(409, f"Timeout while downloading {url}") + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error {e.response.status_code} while downloading {url}") + if e.response.status_code == 404: + logger.error(f"Segment Resource not found: {url}") + raise e + raise DownloadError(e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}") + except Exception as e: + logger.error(f"Error downloading {url}: {e}") + raise + + +class Streamer: + def __init__(self, client): + """ + Initializes the Streamer with an HTTP client. + + Args: + client (httpx.AsyncClient): The HTTP client to use for streaming. + """ + self.client = client + self.response = None + self.progress_bar = None + self.bytes_transferred = 0 + self.start_byte = 0 + self.end_byte = 0 + self.total_size = 0 + + async def create_streaming_response(self, url: str, headers: dict): + """ + Creates and sends a streaming request. + + Args: + url (str): The URL to stream from. + headers (dict): The headers to include in the request. 
+ + """ + request = self.client.build_request("GET", url, headers=headers) + self.response = await self.client.send(request, stream=True, follow_redirects=True) + self.response.raise_for_status() + + async def stream_content(self) -> typing.AsyncGenerator[bytes, None]: + """ + Streams the content from the response. + """ + if not self.response: + raise RuntimeError("No response available for streaming") + + try: + self.parse_content_range() + + if settings.enable_streaming_progress: + with tqdm_asyncio( + total=self.total_size, + initial=self.start_byte, + unit="B", + unit_scale=True, + unit_divisor=1024, + desc="Streaming", + ncols=100, + mininterval=1, + ) as self.progress_bar: + async for chunk in self.response.aiter_bytes(): + yield chunk + chunk_size = len(chunk) + self.bytes_transferred += chunk_size + self.progress_bar.set_postfix_str( + f"📥 : {self.format_bytes(self.bytes_transferred)}", refresh=False + ) + self.progress_bar.update(chunk_size) + else: + async for chunk in self.response.aiter_bytes(): + yield chunk + self.bytes_transferred += len(chunk) + + except httpx.TimeoutException: + logger.warning("Timeout while streaming") + raise DownloadError(409, "Timeout while streaming") + except GeneratorExit: + logger.info("Streaming session stopped by the user") + except Exception as e: + logger.error(f"Error streaming content: {e}") + raise + + @staticmethod + def format_bytes(size) -> str: + power = 2**10 + n = 0 + units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB"} + while size > power: + size /= power + n += 1 + return f"{size:.2f} {units[n]}" + + def parse_content_range(self): + content_range = self.response.headers.get("Content-Range", "") + if content_range: + range_info = content_range.split()[-1] + self.start_byte, self.end_byte, self.total_size = map(int, range_info.replace("/", "-").split("-")) + else: + self.start_byte = 0 + self.total_size = int(self.response.headers.get("Content-Length", 0)) + self.end_byte = self.total_size - 1 if 
self.total_size > 0 else 0 + + async def get_text(self, url: str, headers: dict): + """ + Sends a GET request to a URL and returns the response text. + + Args: + url (str): The URL to send the GET request to. + headers (dict): The headers to include in the request. + + Returns: + str: The response text. + """ + try: + self.response = await fetch_with_retry(self.client, "GET", url, headers) + except tenacity.RetryError as e: + raise e.last_attempt.result() + return self.response.text + + async def close(self): + """ + Closes the HTTP client and response. + """ + if self.response: + await self.response.aclose() + if self.progress_bar: + self.progress_bar.close() + await self.client.aclose() + + +async def download_file_with_retry(url: str, headers: dict): + """ + Downloads a file with retry logic. + + Args: + url (str): The URL of the file to download. + headers (dict): The headers to include in the request. + + Returns: + bytes: The downloaded file content. + + Raises: + DownloadError: If the download fails after retries. + """ + async with create_httpx_client() as client: + try: + response = await fetch_with_retry(client, "GET", url, headers) + return response.content + except DownloadError as e: + logger.error(f"Failed to download file: {e}") + raise e + except tenacity.RetryError as e: + raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}") + + +async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> httpx.Response: + """ + Sends an HTTP request with retry logic. + + Args: + method (str): The HTTP method to use (e.g., GET, POST). + url (str): The URL to send the request to. + headers (dict): The headers to include in the request. + **kwargs: Additional arguments to pass to the request. + + Returns: + httpx.Response: The HTTP response. + + Raises: + DownloadError: If the request fails after retries. 
+ """ + async with create_httpx_client() as client: + try: + response = await fetch_with_retry(client, method, url, headers, **kwargs) + return response + except DownloadError as e: + logger.error(f"Failed to download file: {e}") + raise + + +def encode_mediaflow_proxy_url( + mediaflow_proxy_url: str, + endpoint: typing.Optional[str] = None, + destination_url: typing.Optional[str] = None, + query_params: typing.Optional[dict] = None, + request_headers: typing.Optional[dict] = None, + response_headers: typing.Optional[dict] = None, + encryption_handler: EncryptionHandler = None, + expiration: int = None, + ip: str = None, +) -> str: + """ + Encodes & Encrypt (Optional) a MediaFlow proxy URL with query parameters and headers. + + Args: + mediaflow_proxy_url (str): The base MediaFlow proxy URL. + endpoint (str, optional): The endpoint to append to the base URL. Defaults to None. + destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None. + query_params (dict, optional): Additional query parameters to include. Defaults to None. + request_headers (dict, optional): Headers to include as query parameters. Defaults to None. + response_headers (dict, optional): Headers to include as query parameters. Defaults to None. + encryption_handler (EncryptionHandler, optional): The encryption handler to use. Defaults to None. + expiration (int, optional): The expiration time for the encrypted token. Defaults to None. + ip (str, optional): The public IP address to include in the query parameters. Defaults to None. + + Returns: + str: The encoded MediaFlow proxy URL. 
+ """ + query_params = query_params or {} + if destination_url is not None: + query_params["d"] = destination_url + + # Add headers if provided + if request_headers: + query_params.update( + {key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()} + ) + if response_headers: + query_params.update( + {key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()} + ) + + if encryption_handler: + encrypted_token = encryption_handler.encrypt_data(query_params, expiration, ip) + encoded_params = urlencode({"token": encrypted_token}) + else: + encoded_params = urlencode(query_params) + + # Construct the full URL + if endpoint is None: + return f"{mediaflow_proxy_url}?{encoded_params}" + + base_url = parse.urljoin(mediaflow_proxy_url, endpoint) + return f"{base_url}?{encoded_params}" + + +def get_original_scheme(request: Request) -> str: + """ + Determines the original scheme (http or https) of the request. + + Args: + request (Request): The incoming HTTP request. + + Returns: + str: The original scheme ('http' or 'https') + """ + # Check the X-Forwarded-Proto header first + forwarded_proto = request.headers.get("X-Forwarded-Proto") + if forwarded_proto: + return forwarded_proto + + # Check if the request is secure + if request.url.scheme == "https" or request.headers.get("X-Forwarded-Ssl") == "on": + return "https" + + # Check for other common headers that might indicate HTTPS + if ( + request.headers.get("X-Forwarded-Ssl") == "on" + or request.headers.get("X-Forwarded-Protocol") == "https" + or request.headers.get("X-Url-Scheme") == "https" + ): + return "https" + + # Default to http if no indicators of https are found + return "http" + + +@dataclass +class ProxyRequestHeaders: + request: dict + response: dict + + +def get_proxy_headers(request: Request) -> ProxyRequestHeaders: + """ + Extracts proxy headers from the request query parameters. + + Args: + request (Request): The incoming HTTP request. 
+ + Returns: + ProxyRequest: A named tuple containing the request headers and response headers. + """ + request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS} + request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")}) + response_headers = {k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("r_")} + return ProxyRequestHeaders(request_headers, response_headers) + + +class EnhancedStreamingResponse(Response): + body_iterator: typing.AsyncIterable[typing.Any] + + def __init__( + self, + content: typing.Union[typing.AsyncIterable[typing.Any], typing.Iterable[typing.Any]], + status_code: int = 200, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, + ) -> None: + if isinstance(content, typing.AsyncIterable): + self.body_iterator = content + else: + self.body_iterator = iterate_in_threadpool(content) + self.status_code = status_code + self.media_type = self.media_type if media_type is None else media_type + self.background = background + self.init_headers(headers) + + @staticmethod + async def listen_for_disconnect(receive: Receive) -> None: + try: + while True: + message = await receive() + if message["type"] == "http.disconnect": + logger.debug("Client disconnected") + break + except Exception as e: + logger.error(f"Error in listen_for_disconnect: {str(e)}") + + async def stream_response(self, send: Send) -> None: + try: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + async for chunk in self.body_iterator: + if not isinstance(chunk, (bytes, memoryview)): + chunk = chunk.encode(self.charset) + try: + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + except (ConnectionResetError, anyio.BrokenResourceError): + logger.info("Client disconnected during 
streaming") + return + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + except Exception as e: + logger.exception(f"Error in stream_response: {str(e)}") + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + try: + await func() + except Exception as e: + if not isinstance(e, anyio.get_cancelled_exc_class()): + logger.exception("Error in streaming task") + raise + finally: + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) + + if self.background is not None: + await self.background() diff --git a/mediaflow_proxy/utils/m3u8_processor.py b/mediaflow_proxy/utils/m3u8_processor.py new file mode 100644 index 0000000..83f8958 --- /dev/null +++ b/mediaflow_proxy/utils/m3u8_processor.py @@ -0,0 +1,87 @@ +import re +from urllib import parse + +from mediaflow_proxy.utils.crypto_utils import encryption_handler +from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme + + +class M3U8Processor: + def __init__(self, request, key_url: str = None): + """ + Initializes the M3U8Processor with the request and URL prefix. + + Args: + request (Request): The incoming HTTP request. + key_url (HttpUrl, optional): The URL of the key server. Defaults to None. + """ + self.request = request + self.key_url = parse.urlparse(key_url) if key_url else None + self.mediaflow_proxy_url = str( + request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request)) + ) + + async def process_m3u8(self, content: str, base_url: str) -> str: + """ + Processes the m3u8 content, proxying URLs and handling key lines. + + Args: + content (str): The m3u8 content to process. + base_url (str): The base URL to resolve relative URLs. 
+ + Returns: + str: The processed m3u8 content. + """ + lines = content.splitlines() + processed_lines = [] + for line in lines: + if "URI=" in line: + processed_lines.append(await self.process_key_line(line, base_url)) + elif not line.startswith("#") and line.strip(): + processed_lines.append(await self.proxy_url(line, base_url)) + else: + processed_lines.append(line) + return "\n".join(processed_lines) + + async def process_key_line(self, line: str, base_url: str) -> str: + """ + Processes a key line in the m3u8 content, proxying the URI. + + Args: + line (str): The key line to process. + base_url (str): The base URL to resolve relative URLs. + + Returns: + str: The processed key line. + """ + uri_match = re.search(r'URI="([^"]+)"', line) + if uri_match: + original_uri = uri_match.group(1) + uri = parse.urlparse(original_uri) + if self.key_url: + uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.netloc) + new_uri = await self.proxy_url(uri.geturl(), base_url) + line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"') + return line + + async def proxy_url(self, url: str, base_url: str) -> str: + """ + Proxies a URL, encoding it with the MediaFlow proxy URL. + + Args: + url (str): The URL to proxy. + base_url (str): The base URL to resolve relative URLs. + + Returns: + str: The proxied URL. 
+ """ + full_url = parse.urljoin(base_url, url) + query_params = dict(self.request.query_params) + has_encrypted = query_params.pop("has_encrypted", False) + + return encode_mediaflow_proxy_url( + self.mediaflow_proxy_url, + "", + full_url, + query_params=dict(self.request.query_params), + encryption_handler=encryption_handler if has_encrypted else None, + ) diff --git a/mediaflow_proxy/utils/mpd_utils.py b/mediaflow_proxy/utils/mpd_utils.py new file mode 100644 index 0000000..5603694 --- /dev/null +++ b/mediaflow_proxy/utils/mpd_utils.py @@ -0,0 +1,555 @@ +import logging +import math +import re +from datetime import datetime, timedelta, timezone +from typing import List, Dict, Optional, Union +from urllib.parse import urljoin + +import xmltodict + +logger = logging.getLogger(__name__) + + +def parse_mpd(mpd_content: Union[str, bytes]) -> dict: + """ + Parses the MPD content into a dictionary. + + Args: + mpd_content (Union[str, bytes]): The MPD content to parse. + + Returns: + dict: The parsed MPD content as a dictionary. + """ + return xmltodict.parse(mpd_content) + + +def parse_mpd_dict( + mpd_dict: dict, mpd_url: str, parse_drm: bool = True, parse_segment_profile_id: Optional[str] = None +) -> dict: + """ + Parses the MPD dictionary and extracts relevant information. + + Args: + mpd_dict (dict): The MPD content as a dictionary. + mpd_url (str): The URL of the MPD manifest. + parse_drm (bool, optional): Whether to parse DRM information. Defaults to True. + parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None. + + Returns: + dict: The parsed MPD information including profiles and DRM info. + + This function processes the MPD dictionary to extract profiles, DRM information, and other relevant data. + It handles both live and static MPD manifests. 
+ """ + profiles = [] + parsed_dict = {} + source = "/".join(mpd_url.split("/")[:-1]) + + is_live = mpd_dict["MPD"].get("@type", "static").lower() == "dynamic" + parsed_dict["isLive"] = is_live + + media_presentation_duration = mpd_dict["MPD"].get("@mediaPresentationDuration") + + # Parse additional MPD attributes for live streams + if is_live: + parsed_dict["minimumUpdatePeriod"] = parse_duration(mpd_dict["MPD"].get("@minimumUpdatePeriod", "PT0S")) + parsed_dict["timeShiftBufferDepth"] = parse_duration(mpd_dict["MPD"].get("@timeShiftBufferDepth", "PT2M")) + parsed_dict["availabilityStartTime"] = datetime.fromisoformat( + mpd_dict["MPD"]["@availabilityStartTime"].replace("Z", "+00:00") + ) + parsed_dict["publishTime"] = datetime.fromisoformat( + mpd_dict["MPD"].get("@publishTime", "").replace("Z", "+00:00") + ) + + periods = mpd_dict["MPD"]["Period"] + periods = periods if isinstance(periods, list) else [periods] + + for period in periods: + parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S")) + for adaptation in period["AdaptationSet"]: + representations = adaptation["Representation"] + representations = representations if isinstance(representations, list) else [representations] + + for representation in representations: + profile = parse_representation( + parsed_dict, + representation, + adaptation, + source, + media_presentation_duration, + parse_segment_profile_id, + ) + if profile: + profiles.append(profile) + parsed_dict["profiles"] = profiles + + if parse_drm: + drm_info = extract_drm_info(periods, mpd_url) + else: + drm_info = {} + parsed_dict["drmInfo"] = drm_info + + return parsed_dict + + +def pad_base64(encoded_key_id): + """ + Pads a base64 encoded key ID to make its length a multiple of 4. + + Args: + encoded_key_id (str): The base64 encoded key ID. + + Returns: + str: The padded base64 encoded key ID. 
+    return encoded_key_id + "=" * (-len(encoded_key_id) % 4)  # no-op when length is already a multiple of 4 (old expression appended "====")
+ """ + if not isinstance(content_protection, list): + content_protection = [content_protection] + + for protection in content_protection: + drm_info["isDrmProtected"] = True + scheme_id_uri = protection.get("@schemeIdUri", "").lower() + + if "clearkey" in scheme_id_uri: + drm_info["drmSystem"] = "clearkey" + if "clearkey:Laurl" in protection: + la_url = protection["clearkey:Laurl"].get("#text") + if la_url and "laUrl" not in drm_info: + drm_info["laUrl"] = la_url + + elif "widevine" in scheme_id_uri or "edef8ba9-79d6-4ace-a3c8-27dcd51d21ed" in scheme_id_uri: + drm_info["drmSystem"] = "widevine" + pssh = protection.get("cenc:pssh", {}).get("#text") + if pssh: + drm_info["pssh"] = pssh + + elif "playready" in scheme_id_uri or "9a04f079-9840-4286-ab92-e65be0885f95" in scheme_id_uri: + drm_info["drmSystem"] = "playready" + + if "@cenc:default_KID" in protection: + key_id = protection["@cenc:default_KID"].replace("-", "") + if "keyId" not in drm_info: + drm_info["keyId"] = key_id + + if "ms:laurl" in protection: + la_url = protection["ms:laurl"].get("@licenseUrl") + if la_url and "laUrl" not in drm_info: + drm_info["laUrl"] = la_url + + return drm_info + + +def parse_representation( + parsed_dict: dict, + representation: dict, + adaptation: dict, + source: str, + media_presentation_duration: str, + parse_segment_profile_id: Optional[str], +) -> Optional[dict]: + """ + Parses a representation and extracts profile information. + + Args: + parsed_dict (dict): The parsed MPD data. + representation (dict): The representation data. + adaptation (dict): The adaptation set data. + source (str): The source URL. + media_presentation_duration (str): The media presentation duration. + parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None. + + Returns: + Optional[dict]: The parsed profile information or None if not applicable. 
+ """ + mime_type = _get_key(adaptation, representation, "@mimeType") or ( + "video/mp4" if "avc" in representation["@codecs"] else "audio/mp4" + ) + if "video" not in mime_type and "audio" not in mime_type: + return None + + profile = { + "id": representation.get("@id") or adaptation.get("@id"), + "mimeType": mime_type, + "lang": representation.get("@lang") or adaptation.get("@lang"), + "codecs": representation.get("@codecs") or adaptation.get("@codecs"), + "bandwidth": int(representation.get("@bandwidth") or adaptation.get("@bandwidth")), + "startWithSAP": (_get_key(adaptation, representation, "@startWithSAP") or "1") == "1", + "mediaPresentationDuration": media_presentation_duration, + } + + if "audio" in profile["mimeType"]: + profile["audioSamplingRate"] = representation.get("@audioSamplingRate") or adaptation.get("@audioSamplingRate") + profile["channels"] = representation.get("AudioChannelConfiguration", {}).get("@value", "2") + else: + profile["width"] = int(representation["@width"]) + profile["height"] = int(representation["@height"]) + frame_rate = representation.get("@frameRate") or adaptation.get("@maxFrameRate") or "30000/1001" + frame_rate = frame_rate if "/" in frame_rate else f"{frame_rate}/1" + profile["frameRate"] = round(int(frame_rate.split("/")[0]) / int(frame_rate.split("/")[1]), 3) + profile["sar"] = representation.get("@sar", "1:1") + + if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id: + return profile + + item = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate") + if item: + profile["segments"] = parse_segment_template(parsed_dict, item, profile, source) + else: + profile["segments"] = parse_segment_base(representation, source) + + return profile + + +def _get_key(adaptation: dict, representation: dict, key: str) -> Optional[str]: + """ + Retrieves a key from the representation or adaptation set. + + Args: + adaptation (dict): The adaptation set data. 
+ representation (dict): The representation data. + key (str): The key to retrieve. + + Returns: + Optional[str]: The value of the key or None if not found. + """ + return representation.get(key, adaptation.get(key, None)) + + +def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]: + """ + Parses a segment template and extracts segment information. + + Args: + parsed_dict (dict): The parsed MPD data. + item (dict): The segment template data. + profile (dict): The profile information. + source (str): The source URL. + + Returns: + List[Dict]: The list of parsed segments. + """ + segments = [] + timescale = int(item.get("@timescale", 1)) + + # Initialization + if "@initialization" in item: + media = item["@initialization"] + media = media.replace("$RepresentationID$", profile["id"]) + media = media.replace("$Bandwidth$", str(profile["bandwidth"])) + if not media.startswith("http"): + media = f"{source}/{media}" + profile["initUrl"] = media + + # Segments + if "SegmentTimeline" in item: + segments.extend(parse_segment_timeline(parsed_dict, item, profile, source, timescale)) + elif "@duration" in item: + segments.extend(parse_segment_duration(parsed_dict, item, profile, source, timescale)) + + return segments + + +def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]: + """ + Parses a segment timeline and extracts segment information. + + Args: + parsed_dict (dict): The parsed MPD data. + item (dict): The segment timeline data. + profile (dict): The profile information. + source (str): The source URL. + timescale (int): The timescale for the segments. + + Returns: + List[Dict]: The list of parsed segments. 
+ """ + timelines = item["SegmentTimeline"]["S"] + timelines = timelines if isinstance(timelines, list) else [timelines] + period_start = parsed_dict["availabilityStartTime"] + timedelta(seconds=parsed_dict.get("PeriodStart", 0)) + presentation_time_offset = int(item.get("@presentationTimeOffset", 0)) + start_number = int(item.get("@startNumber", 1)) + + segments = [ + create_segment_data(timeline, item, profile, source, timescale) + for timeline in preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale) + ] + return segments + + +def preprocess_timeline( + timelines: List[Dict], start_number: int, period_start: datetime, presentation_time_offset: int, timescale: int +) -> List[Dict]: + """ + Preprocesses the segment timeline data. + + Args: + timelines (List[Dict]): The list of timeline segments. + start_number (int): The starting segment number. + period_start (datetime): The start time of the period. + presentation_time_offset (int): The presentation time offset. + timescale (int): The timescale for the segments. + + Returns: + List[Dict]: The list of preprocessed timeline segments. 
+ """ + processed_data = [] + current_time = 0 + for timeline in timelines: + repeat = int(timeline.get("@r", 0)) + duration = int(timeline["@d"]) + start_time = int(timeline.get("@t", current_time)) + + for _ in range(repeat + 1): + segment_start_time = period_start + timedelta(seconds=(start_time - presentation_time_offset) / timescale) + segment_end_time = segment_start_time + timedelta(seconds=duration / timescale) + processed_data.append( + { + "number": start_number, + "start_time": segment_start_time, + "end_time": segment_end_time, + "duration": duration, + "time": start_time, + } + ) + start_time += duration + start_number += 1 + + current_time = start_time + + return processed_data + + +def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]: + """ + Parses segment duration and extracts segment information. + This is used for static or live MPD manifests. + + Args: + parsed_dict (dict): The parsed MPD data. + item (dict): The segment duration data. + profile (dict): The profile information. + source (str): The source URL. + timescale (int): The timescale for the segments. + + Returns: + List[Dict]: The list of parsed segments. + """ + duration = int(item["@duration"]) + start_number = int(item.get("@startNumber", 1)) + segment_duration_sec = duration / timescale + + if parsed_dict["isLive"]: + segments = generate_live_segments(parsed_dict, segment_duration_sec, start_number) + else: + segments = generate_vod_segments(profile, duration, timescale, start_number) + + return [create_segment_data(seg, item, profile, source, timescale) for seg in segments] + + +def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]: + """ + Generates live segments based on the segment duration and start number. + This is used for live MPD manifests. + + Args: + parsed_dict (dict): The parsed MPD data. + segment_duration_sec (float): The segment duration in seconds. 
+ start_number (int): The starting segment number. + + Returns: + List[Dict]: The list of generated live segments. + """ + time_shift_buffer_depth = timedelta(seconds=parsed_dict.get("timeShiftBufferDepth", 60)) + segment_count = math.ceil(time_shift_buffer_depth.total_seconds() / segment_duration_sec) + current_time = datetime.now(tz=timezone.utc) + earliest_segment_number = max( + start_number + + math.floor((current_time - parsed_dict["availabilityStartTime"]).total_seconds() / segment_duration_sec) + - segment_count, + start_number, + ) + + return [ + { + "number": number, + "start_time": parsed_dict["availabilityStartTime"] + + timedelta(seconds=(number - start_number) * segment_duration_sec), + "duration": segment_duration_sec, + } + for number in range(earliest_segment_number, earliest_segment_number + segment_count) + ] + + +def generate_vod_segments(profile: dict, duration: int, timescale: int, start_number: int) -> List[Dict]: + """ + Generates VOD segments based on the segment duration and start number. + This is used for static MPD manifests. + + Args: + profile (dict): The profile information. + duration (int): The segment duration. + timescale (int): The timescale for the segments. + start_number (int): The starting segment number. + + Returns: + List[Dict]: The list of generated VOD segments. + """ + total_duration = profile.get("mediaPresentationDuration") or 0 + if isinstance(total_duration, str): + total_duration = parse_duration(total_duration) + segment_count = math.ceil(total_duration * timescale / duration) + + return [{"number": start_number + i, "duration": duration / timescale} for i in range(segment_count)] + + +def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: Optional[int] = None) -> Dict: + """ + Creates segment data based on the segment information. This includes the segment URL and metadata. + + Args: + segment (Dict): The segment information. + item (dict): The segment template data. 
+ profile (dict): The profile information. + source (str): The source URL. + timescale (int, optional): The timescale for the segments. Defaults to None. + + Returns: + Dict: The created segment data. + """ + media_template = item["@media"] + media = media_template.replace("$RepresentationID$", profile["id"]) + media = media.replace("$Number%04d$", f"{segment['number']:04d}") + media = media.replace("$Number$", str(segment["number"])) + media = media.replace("$Bandwidth$", str(profile["bandwidth"])) + + if "time" in segment and timescale is not None: + media = media.replace("$Time$", str(int(segment["time"] * timescale))) + + if not media.startswith("http"): + media = f"{source}/{media}" + + segment_data = { + "type": "segment", + "media": media, + "number": segment["number"], + } + + if "start_time" in segment and "end_time" in segment: + segment_data.update( + { + "start_time": segment["start_time"], + "end_time": segment["end_time"], + "extinf": (segment["end_time"] - segment["start_time"]).total_seconds(), + "program_date_time": segment["start_time"].isoformat() + "Z", + } + ) + elif "start_time" in segment and "duration" in segment: + duration = segment["duration"] + segment_data.update( + { + "start_time": segment["start_time"], + "end_time": segment["start_time"] + timedelta(seconds=duration), + "extinf": duration, + "program_date_time": segment["start_time"].isoformat() + "Z", + } + ) + elif "duration" in segment: + segment_data["extinf"] = segment["duration"] + + return segment_data + + +def parse_segment_base(representation: dict, source: str) -> List[Dict]: + """ + Parses segment base information and extracts segment data. This is used for single-segment representations. + + Args: + representation (dict): The representation data. + source (str): The source URL. + + Returns: + List[Dict]: The list of parsed segments. 
+ """ + segment = representation["SegmentBase"] + start, end = map(int, segment["@indexRange"].split("-")) + if "Initialization" in segment: + start, _ = map(int, segment["Initialization"]["@range"].split("-")) + + return [ + { + "type": "segment", + "range": f"{start}-{end}", + "media": f"{source}/{representation['BaseURL']}", + } + ] + + +def parse_duration(duration_str: str) -> float: + """ + Parses a duration ISO 8601 string into seconds. + + Args: + duration_str (str): The duration string to parse. + + Returns: + float: The parsed duration in seconds. + """ + pattern = re.compile(r"P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)D)?T?(?:(\d+)H)?(?:(\d+)M)?(?:(\d+(?:\.\d+)?)S)?") + match = pattern.match(duration_str) + if not match: + raise ValueError(f"Invalid duration format: {duration_str}") + + years, months, days, hours, minutes, seconds = [float(g) if g else 0 for g in match.groups()] + return years * 365 * 24 * 3600 + months * 30 * 24 * 3600 + days * 24 * 3600 + hours * 3600 + minutes * 60 + seconds diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2376853 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +bs4 +dateparser +python-dotenv +fastapi +uvicorn +tzdata +lxml +curl_cffi +fake-headers +pydantic_settings +httpx +Crypto +pycryptodome +tenacity +xmltodict +starlette +cachetools +tqdm +aiofiles diff --git a/run.py b/run.py new file mode 100644 index 0000000..e7970cb --- /dev/null +++ b/run.py @@ -0,0 +1,18 @@ +from fastapi import FastAPI +from mediaflow_proxy.main import app as mediaflow_app # Import mediaflow app +import httpx +import re +import string + +# Initialize the main FastAPI application +main_app = FastAPI() + +# Manually add only non-static routes from mediaflow_app +for route in mediaflow_app.routes: + if route.path != "/": # Exclude the static file path + main_app.router.routes.append(route) + +# Run the main app +if __name__ == "__main__": + import uvicorn + uvicorn.run(main_app, host="0.0.0.0", port=8080)