Add files via upload

This commit is contained in:
Urlo30
2024-12-29 23:18:53 +01:00
committed by GitHub
parent 76987906b5
commit fb60e99822
40 changed files with 5170 additions and 0 deletions

View File

View File

@@ -0,0 +1,68 @@
from typing import Dict, Optional, Union
import httpx
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
class RouteConfig(BaseModel):
"""Configuration for a specific route"""
proxy: bool = True
proxy_url: Optional[str] = None
verify_ssl: bool = True
class TransportConfig(BaseSettings):
"""Main proxy configuration"""
proxy_url: Optional[str] = Field(
None, description="Primary proxy URL. Example: socks5://user:pass@proxy:1080 or http://proxy:8080"
)
all_proxy: bool = Field(False, description="Enable proxy for all routes by default")
transport_routes: Dict[str, RouteConfig] = Field(
default_factory=dict, description="Pattern-based route configuration"
)
def get_mounts(
self, async_http: bool = True
) -> Dict[str, Optional[Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport]]]:
"""
Get a dictionary of httpx mount points to transport instances.
"""
mounts = {}
transport_cls = httpx.AsyncHTTPTransport if async_http else httpx.HTTPTransport
# Configure specific routes
for pattern, route in self.transport_routes.items():
mounts[pattern] = transport_cls(
verify=route.verify_ssl, proxy=route.proxy_url or self.proxy_url if route.proxy else None
)
# Set default proxy for all routes if enabled
if self.all_proxy:
mounts["all://"] = transport_cls(proxy=self.proxy_url)
return mounts
class Config:
env_file = ".env"
extra = "ignore"
class Settings(BaseSettings):
api_password: str | None = None # The password for protecting the API endpoints.
log_level: str = "INFO" # The logging level to use.
transport_config: TransportConfig = Field(default_factory=TransportConfig) # Configuration for httpx transport.
enable_streaming_progress: bool = False # Whether to enable streaming progress tracking.
user_agent: str = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" # The user agent to use for HTTP requests.
)
class Config:
env_file = ".env"
extra = "ignore"
settings = Settings()

17
mediaflow_proxy/const.py Normal file
View File

@@ -0,0 +1,17 @@
SUPPORTED_RESPONSE_HEADERS = [
"accept-ranges",
"content-type",
"content-length",
"content-range",
"connection",
"transfer-encoding",
"last-modified",
"etag",
"cache-control",
"expires",
]
SUPPORTED_REQUEST_HEADERS = [
"range",
"if-range",
]

View File

@@ -0,0 +1,11 @@
import os
import tempfile
async def create_temp_file(suffix: str, content: bytes = None, prefix: str = None) -> tempfile.NamedTemporaryFile:
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, prefix=prefix)
temp_file.delete_file = lambda: os.unlink(temp_file.name)
if content:
temp_file.write(content)
temp_file.close()
return temp_file

View File

@@ -0,0 +1,778 @@
import argparse
import struct
import sys
from typing import Optional, Union
from Crypto.Cipher import AES
from collections import namedtuple
import array
CENCSampleAuxiliaryDataFormat = namedtuple("CENCSampleAuxiliaryDataFormat", ["is_encrypted", "iv", "sub_samples"])
class MP4Atom:
"""
Represents an MP4 atom, which is a basic unit of data in an MP4 file.
Each atom contains a header (size and type) and data.
"""
__slots__ = ("atom_type", "size", "data")
def __init__(self, atom_type: bytes, size: int, data: Union[memoryview, bytearray]):
"""
Initializes an MP4Atom instance.
Args:
atom_type (bytes): The type of the atom.
size (int): The size of the atom.
data (Union[memoryview, bytearray]): The data contained in the atom.
"""
self.atom_type = atom_type
self.size = size
self.data = data
def __repr__(self):
return f"<MP4Atom type={self.atom_type}, size={self.size}>"
def pack(self):
"""
Packs the atom into binary data.
Returns:
bytes: Packed binary data with size, type, and data.
"""
return struct.pack(">I", self.size) + self.atom_type + self.data
class MP4Parser:
"""
Parses MP4 data to extract atoms and their structure.
"""
def __init__(self, data: memoryview):
"""
Initializes an MP4Parser instance.
Args:
data (memoryview): The binary data of the MP4 file.
"""
self.data = data
self.position = 0
def read_atom(self) -> Optional[MP4Atom]:
"""
Reads the next atom from the data.
Returns:
Optional[MP4Atom]: MP4Atom object or None if no more atoms are available.
"""
pos = self.position
if pos + 8 > len(self.data):
return None
size, atom_type = struct.unpack_from(">I4s", self.data, pos)
pos += 8
if size == 1:
if pos + 8 > len(self.data):
return None
size = struct.unpack_from(">Q", self.data, pos)[0]
pos += 8
if size < 8 or pos + size - 8 > len(self.data):
return None
atom_data = self.data[pos : pos + size - 8]
self.position = pos + size - 8
return MP4Atom(atom_type, size, atom_data)
def list_atoms(self) -> list[MP4Atom]:
"""
Lists all atoms in the data.
Returns:
list[MP4Atom]: List of MP4Atom objects.
"""
atoms = []
original_position = self.position
self.position = 0
while self.position + 8 <= len(self.data):
atom = self.read_atom()
if not atom:
break
atoms.append(atom)
self.position = original_position
return atoms
def _read_atom_at(self, pos: int, end: int) -> Optional[MP4Atom]:
if pos + 8 > end:
return None
size, atom_type = struct.unpack_from(">I4s", self.data, pos)
pos += 8
if size == 1:
if pos + 8 > end:
return None
size = struct.unpack_from(">Q", self.data, pos)[0]
pos += 8
if size < 8 or pos + size - 8 > end:
return None
atom_data = self.data[pos : pos + size - 8]
return MP4Atom(atom_type, size, atom_data)
def print_atoms_structure(self, indent: int = 0):
"""
Prints the structure of all atoms in the data.
Args:
indent (int): The indentation level for printing.
"""
pos = 0
end = len(self.data)
while pos + 8 <= end:
atom = self._read_atom_at(pos, end)
if not atom:
break
self.print_single_atom_structure(atom, pos, indent)
pos += atom.size
def print_single_atom_structure(self, atom: MP4Atom, parent_position: int, indent: int):
"""
Prints the structure of a single atom.
Args:
atom (MP4Atom): The atom to print.
parent_position (int): The position of the parent atom.
indent (int): The indentation level for printing.
"""
try:
atom_type = atom.atom_type.decode("utf-8")
except UnicodeDecodeError:
atom_type = repr(atom.atom_type)
print(" " * indent + f"Type: {atom_type}, Size: {atom.size}")
child_pos = 0
child_end = len(atom.data)
while child_pos + 8 <= child_end:
child_atom = self._read_atom_at(parent_position + 8 + child_pos, parent_position + 8 + child_end)
if not child_atom:
break
self.print_single_atom_structure(child_atom, parent_position, indent + 2)
child_pos += child_atom.size
class MP4Decrypter:
"""
Class to handle the decryption of CENC encrypted MP4 segments.
Attributes:
key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys.
current_key (Optional[bytes]): Current decryption key.
trun_sample_sizes (array.array): Array of sample sizes from the 'trun' box.
current_sample_info (list): List of sample information from the 'senc' box.
encryption_overhead (int): Total size of encryption-related boxes.
"""
def __init__(self, key_map: dict[bytes, bytes]):
"""
Initializes the MP4Decrypter with a key map.
Args:
key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys.
"""
self.key_map = key_map
self.current_key = None
self.trun_sample_sizes = array.array("I")
self.current_sample_info = []
self.encryption_overhead = 0
def decrypt_segment(self, combined_segment: bytes) -> bytes:
"""
Decrypts a combined MP4 segment.
Args:
combined_segment (bytes): Combined initialization and media segment.
Returns:
bytes: Decrypted segment content.
"""
data = memoryview(combined_segment)
parser = MP4Parser(data)
atoms = parser.list_atoms()
atom_process_order = [b"moov", b"moof", b"sidx", b"mdat"]
processed_atoms = {}
for atom_type in atom_process_order:
if atom := next((a for a in atoms if a.atom_type == atom_type), None):
processed_atoms[atom_type] = self._process_atom(atom_type, atom)
result = bytearray()
for atom in atoms:
if atom.atom_type in processed_atoms:
processed_atom = processed_atoms[atom.atom_type]
result.extend(processed_atom.pack())
else:
result.extend(atom.pack())
return bytes(result)
def _process_atom(self, atom_type: bytes, atom: MP4Atom) -> MP4Atom:
"""
Processes an MP4 atom based on its type.
Args:
atom_type (bytes): Type of the atom.
atom (MP4Atom): The atom to process.
Returns:
MP4Atom: Processed atom.
"""
if atom_type == b"moov":
return self._process_moov(atom)
elif atom_type == b"moof":
return self._process_moof(atom)
elif atom_type == b"sidx":
return self._process_sidx(atom)
elif atom_type == b"mdat":
return self._decrypt_mdat(atom)
else:
return atom
def _process_moov(self, moov: MP4Atom) -> MP4Atom:
"""
Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation.
This includes information about tracks, media data, and other movie-level metadata.
Args:
moov (MP4Atom): The 'moov' atom to process.
Returns:
MP4Atom: Processed 'moov' atom with updated track information.
"""
parser = MP4Parser(moov.data)
new_moov_data = bytearray()
for atom in iter(parser.read_atom, None):
if atom.atom_type == b"trak":
new_trak = self._process_trak(atom)
new_moov_data.extend(new_trak.pack())
elif atom.atom_type != b"pssh":
# Skip PSSH boxes as they are not needed in the decrypted output
new_moov_data.extend(atom.pack())
return MP4Atom(b"moov", len(new_moov_data) + 8, new_moov_data)
def _process_moof(self, moof: MP4Atom) -> MP4Atom:
"""
Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation.
This includes information about tracks, media data, and other movie-level metadata.
Args:
moov (MP4Atom): The 'moov' atom to process.
Returns:
MP4Atom: Processed 'moov' atom with updated track information.
"""
parser = MP4Parser(moof.data)
new_moof_data = bytearray()
for atom in iter(parser.read_atom, None):
if atom.atom_type == b"traf":
new_traf = self._process_traf(atom)
new_moof_data.extend(new_traf.pack())
else:
new_moof_data.extend(atom.pack())
return MP4Atom(b"moof", len(new_moof_data) + 8, new_moof_data)
def _process_traf(self, traf: MP4Atom) -> MP4Atom:
"""
Processes the 'traf' (Track Fragment) atom, which contains information about a track fragment.
This includes sample information, sample encryption data, and other track-level metadata.
Args:
traf (MP4Atom): The 'traf' atom to process.
Returns:
MP4Atom: Processed 'traf' atom with updated sample information.
"""
parser = MP4Parser(traf.data)
new_traf_data = bytearray()
tfhd = None
sample_count = 0
sample_info = []
atoms = parser.list_atoms()
# calculate encryption_overhead earlier to avoid dependency on trun
self.encryption_overhead = sum(a.size for a in atoms if a.atom_type in {b"senc", b"saiz", b"saio"})
for atom in atoms:
if atom.atom_type == b"tfhd":
tfhd = atom
new_traf_data.extend(atom.pack())
elif atom.atom_type == b"trun":
sample_count = self._process_trun(atom)
new_trun = self._modify_trun(atom)
new_traf_data.extend(new_trun.pack())
elif atom.atom_type == b"senc":
# Parse senc but don't include it in the new decrypted traf data and similarly don't include saiz and saio
sample_info = self._parse_senc(atom, sample_count)
elif atom.atom_type not in {b"saiz", b"saio"}:
new_traf_data.extend(atom.pack())
if tfhd:
tfhd_track_id = struct.unpack_from(">I", tfhd.data, 4)[0]
self.current_key = self._get_key_for_track(tfhd_track_id)
self.current_sample_info = sample_info
return MP4Atom(b"traf", len(new_traf_data) + 8, new_traf_data)
def _decrypt_mdat(self, mdat: MP4Atom) -> MP4Atom:
"""
Decrypts the 'mdat' (Media Data) atom, which contains the actual media data (audio, video, etc.).
The decryption is performed using the current decryption key and sample information.
Args:
mdat (MP4Atom): The 'mdat' atom to decrypt.
Returns:
MP4Atom: Decrypted 'mdat' atom with decrypted media data.
"""
if not self.current_key or not self.current_sample_info:
return mdat # Return original mdat if we don't have decryption info
decrypted_samples = bytearray()
mdat_data = mdat.data
position = 0
for i, info in enumerate(self.current_sample_info):
if position >= len(mdat_data):
break # No more data to process
sample_size = self.trun_sample_sizes[i] if i < len(self.trun_sample_sizes) else len(mdat_data) - position
sample = mdat_data[position : position + sample_size]
position += sample_size
decrypted_sample = self._process_sample(sample, info, self.current_key)
decrypted_samples.extend(decrypted_sample)
return MP4Atom(b"mdat", len(decrypted_samples) + 8, decrypted_samples)
def _parse_senc(self, senc: MP4Atom, sample_count: int) -> list[CENCSampleAuxiliaryDataFormat]:
"""
Parses the 'senc' (Sample Encryption) atom, which contains encryption information for samples.
This includes initialization vectors (IVs) and sub-sample encryption data.
Args:
senc (MP4Atom): The 'senc' atom to parse.
sample_count (int): The number of samples.
Returns:
list[CENCSampleAuxiliaryDataFormat]: List of sample auxiliary data formats with encryption information.
"""
data = memoryview(senc.data)
version_flags = struct.unpack_from(">I", data, 0)[0]
version, flags = version_flags >> 24, version_flags & 0xFFFFFF
position = 4
if version == 0:
sample_count = struct.unpack_from(">I", data, position)[0]
position += 4
sample_info = []
for _ in range(sample_count):
if position + 8 > len(data):
break
iv = data[position : position + 8].tobytes()
position += 8
sub_samples = []
if flags & 0x000002 and position + 2 <= len(data): # Check if subsample information is present
subsample_count = struct.unpack_from(">H", data, position)[0]
position += 2
for _ in range(subsample_count):
if position + 6 <= len(data):
clear_bytes, encrypted_bytes = struct.unpack_from(">HI", data, position)
position += 6
sub_samples.append((clear_bytes, encrypted_bytes))
else:
break
sample_info.append(CENCSampleAuxiliaryDataFormat(True, iv, sub_samples))
return sample_info
def _get_key_for_track(self, track_id: int) -> bytes:
"""
Retrieves the decryption key for a given track ID from the key map.
Args:
track_id (int): The track ID.
Returns:
bytes: The decryption key for the specified track ID.
"""
if len(self.key_map) == 1:
return next(iter(self.key_map.values()))
key = self.key_map.get(track_id.pack(4, "big"))
if not key:
raise ValueError(f"No key found for track ID {track_id}")
return key
@staticmethod
def _process_sample(
sample: memoryview, sample_info: CENCSampleAuxiliaryDataFormat, key: bytes
) -> Union[memoryview, bytearray, bytes]:
"""
Processes and decrypts a sample using the provided sample information and decryption key.
This includes handling sub-sample encryption if present.
Args:
sample (memoryview): The sample data.
sample_info (CENCSampleAuxiliaryDataFormat): The sample auxiliary data format with encryption information.
key (bytes): The decryption key.
Returns:
Union[memoryview, bytearray, bytes]: The decrypted sample.
"""
if not sample_info.is_encrypted:
return sample
# pad IV to 16 bytes
iv = sample_info.iv + b"\x00" * (16 - len(sample_info.iv))
cipher = AES.new(key, AES.MODE_CTR, initial_value=iv, nonce=b"")
if not sample_info.sub_samples:
# If there are no sub_samples, decrypt the entire sample
return cipher.decrypt(sample)
result = bytearray()
offset = 0
for clear_bytes, encrypted_bytes in sample_info.sub_samples:
result.extend(sample[offset : offset + clear_bytes])
offset += clear_bytes
result.extend(cipher.decrypt(sample[offset : offset + encrypted_bytes]))
offset += encrypted_bytes
# If there's any remaining data, treat it as encrypted
if offset < len(sample):
result.extend(cipher.decrypt(sample[offset:]))
return result
def _process_trun(self, trun: MP4Atom) -> int:
"""
Processes the 'trun' (Track Fragment Run) atom, which contains information about the samples in a track fragment.
This includes sample sizes, durations, flags, and composition time offsets.
Args:
trun (MP4Atom): The 'trun' atom to process.
Returns:
int: The number of samples in the 'trun' atom.
"""
trun_flags, sample_count = struct.unpack_from(">II", trun.data, 0)
data_offset = 8
if trun_flags & 0x000001:
data_offset += 4
if trun_flags & 0x000004:
data_offset += 4
self.trun_sample_sizes = array.array("I")
for _ in range(sample_count):
if trun_flags & 0x000100: # sample-duration-present flag
data_offset += 4
if trun_flags & 0x000200: # sample-size-present flag
sample_size = struct.unpack_from(">I", trun.data, data_offset)[0]
self.trun_sample_sizes.append(sample_size)
data_offset += 4
else:
self.trun_sample_sizes.append(0) # Using 0 instead of None for uniformity in the array
if trun_flags & 0x000400: # sample-flags-present flag
data_offset += 4
if trun_flags & 0x000800: # sample-composition-time-offsets-present flag
data_offset += 4
return sample_count
def _modify_trun(self, trun: MP4Atom) -> MP4Atom:
"""
Modifies the 'trun' (Track Fragment Run) atom to update the data offset.
This is necessary to account for the encryption overhead.
Args:
trun (MP4Atom): The 'trun' atom to modify.
Returns:
MP4Atom: Modified 'trun' atom with updated data offset.
"""
trun_data = bytearray(trun.data)
current_flags = struct.unpack_from(">I", trun_data, 0)[0] & 0xFFFFFF
# If the data-offset-present flag is set, update the data offset to account for encryption overhead
if current_flags & 0x000001:
current_data_offset = struct.unpack_from(">i", trun_data, 8)[0]
struct.pack_into(">i", trun_data, 8, current_data_offset - self.encryption_overhead)
return MP4Atom(b"trun", len(trun_data) + 8, trun_data)
def _process_sidx(self, sidx: MP4Atom) -> MP4Atom:
"""
Processes the 'sidx' (Segment Index) atom, which contains indexing information for media segments.
This includes references to media segments and their durations.
Args:
sidx (MP4Atom): The 'sidx' atom to process.
Returns:
MP4Atom: Processed 'sidx' atom with updated segment references.
"""
sidx_data = bytearray(sidx.data)
current_size = struct.unpack_from(">I", sidx_data, 32)[0]
reference_type = current_size >> 31
current_referenced_size = current_size & 0x7FFFFFFF
# Remove encryption overhead from referenced size
new_referenced_size = current_referenced_size - self.encryption_overhead
new_size = (reference_type << 31) | new_referenced_size
struct.pack_into(">I", sidx_data, 32, new_size)
return MP4Atom(b"sidx", len(sidx_data) + 8, sidx_data)
def _process_trak(self, trak: MP4Atom) -> MP4Atom:
"""
Processes the 'trak' (Track) atom, which contains information about a single track in the movie.
This includes track header, media information, and other track-level metadata.
Args:
trak (MP4Atom): The 'trak' atom to process.
Returns:
MP4Atom: Processed 'trak' atom with updated track information.
"""
parser = MP4Parser(trak.data)
new_trak_data = bytearray()
for atom in iter(parser.read_atom, None):
if atom.atom_type == b"mdia":
new_mdia = self._process_mdia(atom)
new_trak_data.extend(new_mdia.pack())
else:
new_trak_data.extend(atom.pack())
return MP4Atom(b"trak", len(new_trak_data) + 8, new_trak_data)
def _process_mdia(self, mdia: MP4Atom) -> MP4Atom:
"""
Processes the 'mdia' (Media) atom, which contains media information for a track.
This includes media header, handler reference, and media information container.
Args:
mdia (MP4Atom): The 'mdia' atom to process.
Returns:
MP4Atom: Processed 'mdia' atom with updated media information.
"""
parser = MP4Parser(mdia.data)
new_mdia_data = bytearray()
for atom in iter(parser.read_atom, None):
if atom.atom_type == b"minf":
new_minf = self._process_minf(atom)
new_mdia_data.extend(new_minf.pack())
else:
new_mdia_data.extend(atom.pack())
return MP4Atom(b"mdia", len(new_mdia_data) + 8, new_mdia_data)
def _process_minf(self, minf: MP4Atom) -> MP4Atom:
"""
Processes the 'minf' (Media Information) atom, which contains information about the media data in a track.
This includes data information, sample table, and other media-level metadata.
Args:
minf (MP4Atom): The 'minf' atom to process.
Returns:
MP4Atom: Processed 'minf' atom with updated media information.
"""
parser = MP4Parser(minf.data)
new_minf_data = bytearray()
for atom in iter(parser.read_atom, None):
if atom.atom_type == b"stbl":
new_stbl = self._process_stbl(atom)
new_minf_data.extend(new_stbl.pack())
else:
new_minf_data.extend(atom.pack())
return MP4Atom(b"minf", len(new_minf_data) + 8, new_minf_data)
def _process_stbl(self, stbl: MP4Atom) -> MP4Atom:
"""
Processes the 'stbl' (Sample Table) atom, which contains information about the samples in a track.
This includes sample descriptions, sample sizes, sample times, and other sample-level metadata.
Args:
stbl (MP4Atom): The 'stbl' atom to process.
Returns:
MP4Atom: Processed 'stbl' atom with updated sample information.
"""
parser = MP4Parser(stbl.data)
new_stbl_data = bytearray()
for atom in iter(parser.read_atom, None):
if atom.atom_type == b"stsd":
new_stsd = self._process_stsd(atom)
new_stbl_data.extend(new_stsd.pack())
else:
new_stbl_data.extend(atom.pack())
return MP4Atom(b"stbl", len(new_stbl_data) + 8, new_stbl_data)
def _process_stsd(self, stsd: MP4Atom) -> MP4Atom:
"""
Processes the 'stsd' (Sample Description) atom, which contains descriptions of the sample entries in a track.
This includes codec information, sample entry details, and other sample description metadata.
Args:
stsd (MP4Atom): The 'stsd' atom to process.
Returns:
MP4Atom: Processed 'stsd' atom with updated sample descriptions.
"""
parser = MP4Parser(stsd.data)
entry_count = struct.unpack_from(">I", parser.data, 4)[0]
new_stsd_data = bytearray(stsd.data[:8])
parser.position = 8 # Move past version_flags and entry_count
for _ in range(entry_count):
sample_entry = parser.read_atom()
if not sample_entry:
break
processed_entry = self._process_sample_entry(sample_entry)
new_stsd_data.extend(processed_entry.pack())
return MP4Atom(b"stsd", len(new_stsd_data) + 8, new_stsd_data)
def _process_sample_entry(self, entry: MP4Atom) -> MP4Atom:
"""
Processes a sample entry atom, which contains information about a specific type of sample.
This includes codec-specific information and other sample entry details.
Args:
entry (MP4Atom): The sample entry atom to process.
Returns:
MP4Atom: Processed sample entry atom with updated information.
"""
# Determine the size of fixed fields based on sample entry type
if entry.atom_type in {b"mp4a", b"enca"}:
fixed_size = 28 # 8 bytes for size, type and reserved, 20 bytes for fixed fields in Audio Sample Entry.
elif entry.atom_type in {b"mp4v", b"encv", b"avc1", b"hev1", b"hvc1"}:
fixed_size = 78 # 8 bytes for size, type and reserved, 70 bytes for fixed fields in Video Sample Entry.
else:
fixed_size = 16 # 8 bytes for size, type and reserved, 8 bytes for fixed fields in other Sample Entries.
new_entry_data = bytearray(entry.data[:fixed_size])
parser = MP4Parser(entry.data[fixed_size:])
codec_format = None
for atom in iter(parser.read_atom, None):
if atom.atom_type in {b"sinf", b"schi", b"tenc", b"schm"}:
if atom.atom_type == b"sinf":
codec_format = self._extract_codec_format(atom)
continue # Skip encryption-related atoms
new_entry_data.extend(atom.pack())
# Replace the atom type with the extracted codec format
new_type = codec_format if codec_format else entry.atom_type
return MP4Atom(new_type, len(new_entry_data) + 8, new_entry_data)
def _extract_codec_format(self, sinf: MP4Atom) -> Optional[bytes]:
"""
Extracts the codec format from the 'sinf' (Protection Scheme Information) atom.
This includes information about the original format of the protected content.
Args:
sinf (MP4Atom): The 'sinf' atom to extract from.
Returns:
Optional[bytes]: The codec format or None if not found.
"""
parser = MP4Parser(sinf.data)
for atom in iter(parser.read_atom, None):
if atom.atom_type == b"frma":
return atom.data
return None
def decrypt_segment(init_segment: bytes, segment_content: bytes, key_id: str, key: str) -> bytes:
"""
Decrypts a CENC encrypted MP4 segment.
Args:
init_segment (bytes): Initialization segment data.
segment_content (bytes): Encrypted segment content.
key_id (str): Key ID in hexadecimal format.
key (str): Key in hexadecimal format.
"""
key_map = {bytes.fromhex(key_id): bytes.fromhex(key)}
decrypter = MP4Decrypter(key_map)
decrypted_content = decrypter.decrypt_segment(init_segment + segment_content)
return decrypted_content
def cli():
"""
Command line interface for decrypting a CENC encrypted MP4 segment.
"""
init_segment = b""
if args.init and args.segment:
with open(args.init, "rb") as f:
init_segment = f.read()
with open(args.segment, "rb") as f:
segment_content = f.read()
elif args.combined_segment:
with open(args.combined_segment, "rb") as f:
segment_content = f.read()
else:
print("Usage: python mp4decrypt.py --help")
sys.exit(1)
try:
decrypted_segment = decrypt_segment(init_segment, segment_content, args.key_id, args.key)
print(f"Decrypted content size is {len(decrypted_segment)} bytes")
with open(args.output, "wb") as f:
f.write(decrypted_segment)
print(f"Decrypted segment written to {args.output}")
except Exception as e:
print(f"Error: {e}")
sys.exit(1)
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(description="Decrypts a MP4 init and media segment using CENC encryption.")
arg_parser.add_argument("--init", help="Path to the init segment file", required=False)
arg_parser.add_argument("--segment", help="Path to the media segment file", required=False)
arg_parser.add_argument(
"--combined_segment", help="Path to the combined init and media segment file", required=False
)
arg_parser.add_argument("--key_id", help="Key ID in hexadecimal format", required=True)
arg_parser.add_argument("--key", help="Key in hexadecimal format", required=True)
arg_parser.add_argument("--output", help="Path to the output file", required=True)
args = arg_parser.parse_args()
cli()

View File

View File

@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import httpx
from mediaflow_proxy.configs import settings
from mediaflow_proxy.utils.http_utils import create_httpx_client
class ExtractorError(Exception):
"""Base exception for all extractors."""
pass
class BaseExtractor(ABC):
"""Base class for all URL extractors."""
def __init__(self, request_headers: dict):
self.base_headers = {
"user-agent": settings.user_agent,
}
self.mediaflow_endpoint = "proxy_stream_endpoint"
self.base_headers.update(request_headers)
async def _make_request(
self, url: str, method: str = "GET", headers: Optional[Dict] = None, **kwargs
) -> httpx.Response:
"""Make HTTP request with error handling."""
try:
async with create_httpx_client() as client:
request_headers = self.base_headers
request_headers.update(headers or {})
response = await client.request(
method,
url,
headers=request_headers,
**kwargs,
)
response.raise_for_status()
return response
except httpx.HTTPError as e:
raise ExtractorError(f"HTTP request failed: {str(e)}")
except Exception as e:
raise ExtractorError(f"Request failed: {str(e)}")
@abstractmethod
async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
"""Extract final URL and required headers."""
pass

View File

@@ -0,0 +1,39 @@
import re
import time
from typing import Dict
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
class DoodStreamExtractor(BaseExtractor):
"""DoodStream URL extractor."""
def __init__(self, request_headers: dict):
super().__init__(request_headers)
self.base_url = "https://d000d.com"
async def extract(self, url: str, **kwargs) -> Dict[str, str]:
"""Extract DoodStream URL."""
response = await self._make_request(url)
# Extract URL pattern
pattern = r"(\/pass_md5\/.*?)'.*(\?token=.*?expiry=)"
match = re.search(pattern, response.text, re.DOTALL)
if not match:
raise ExtractorError("Failed to extract URL pattern")
# Build final URL
pass_url = f"{self.base_url}{match[1]}"
referer = f"{self.base_url}/"
headers = {"range": "bytes=0-", "referer": referer}
response = await self._make_request(pass_url, headers=headers)
timestamp = str(int(time.time()))
final_url = f"{response.text}123456789{match[2]}{timestamp}"
self.base_headers["referer"] = referer
return {
"destination_url": final_url,
"request_headers": self.base_headers,
"mediaflow_endpoint": self.mediaflow_endpoint,
}

View File

@@ -0,0 +1,32 @@
from typing import Dict, Type
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
from mediaflow_proxy.extractors.doodstream import DoodStreamExtractor
from mediaflow_proxy.extractors.livetv import LiveTVExtractor
from mediaflow_proxy.extractors.mixdrop import MixdropExtractor
from mediaflow_proxy.extractors.uqload import UqloadExtractor
from mediaflow_proxy.extractors.streamtape import StreamtapeExtractor
from mediaflow_proxy.extractors.supervideo import SupervideoExtractor
class ExtractorFactory:
"""Factory for creating URL extractors."""
_extractors: Dict[str, Type[BaseExtractor]] = {
"Doodstream": DoodStreamExtractor,
"Uqload": UqloadExtractor,
"Mixdrop": MixdropExtractor,
"Streamtape": StreamtapeExtractor,
"Supervideo": SupervideoExtractor,
"LiveTV": LiveTVExtractor,
}
@classmethod
def get_extractor(cls, host: str, request_headers: dict) -> BaseExtractor:
"""Get appropriate extractor instance for the given host."""
extractor_class = cls._extractors.get(host)
if not extractor_class:
raise ExtractorError(f"Unsupported host: {host}")
return extractor_class(request_headers)

View File

@@ -0,0 +1,251 @@
import re
from typing import Dict, Tuple, Optional
from urllib.parse import urljoin, urlparse, unquote
from httpx import Response
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
class LiveTVExtractor(BaseExtractor):
"""LiveTV URL extractor for both M3U8 and MPD streams."""
def __init__(self, request_headers: dict):
super().__init__(request_headers)
# Default to HLS proxy endpoint, will be updated based on stream type
self.mediaflow_endpoint = "hls_manifest_proxy"
# Patterns for stream URL extraction
self.fallback_pattern = re.compile(
r"source: [\'\"](.*?)[\'\"]\s*,\s*[\s\S]*?mimeType: [\'\"](application/x-mpegURL|application/vnd\.apple\.mpegURL|application/dash\+xml)[\'\"]",
re.IGNORECASE,
)
self.any_m3u8_pattern = re.compile(
r'["\']?(https?://.*?\.m3u8(?:\?[^"\']*)?)["\']?',
re.IGNORECASE,
)
async def extract(self, url: str, stream_title: str = None, **kwargs) -> Dict[str, str]:
"""Extract LiveTV URL and required headers.
Args:
url: The channel page URL
stream_title: Optional stream title to filter specific stream
Returns:
Tuple[str, Dict[str, str]]: Stream URL and required headers
"""
try:
# Get the channel page
response = await self._make_request(url)
self.base_headers["referer"] = urljoin(url, "/")
# Extract player API details
player_api_base, method = await self._extract_player_api_base(response.text)
if not player_api_base:
raise ExtractorError("Failed to extract player API URL")
# Get player options
options_data = await self._get_player_options(response.text)
if not options_data:
raise ExtractorError("No player options found")
# Process player options to find matching stream
for option in options_data:
current_title = option.get("title")
if stream_title and current_title != stream_title:
continue
# Get stream URL based on player option
stream_data = await self._process_player_option(
player_api_base, method, option.get("post"), option.get("nume"), option.get("type")
)
if stream_data:
stream_url = stream_data.get("url")
if not stream_url:
continue
response = {
"destination_url": stream_url,
"request_headers": self.base_headers,
"mediaflow_endpoint": self.mediaflow_endpoint,
}
# Set endpoint based on stream type
if stream_data.get("type") == "mpd":
if stream_data.get("drm_key_id") and stream_data.get("drm_key"):
response.update(
{
"query_params": {
"key_id": stream_data["drm_key_id"],
"key": stream_data["drm_key"],
},
"mediaflow_endpoint": "mpd_manifest_proxy",
}
)
return response
raise ExtractorError("No valid stream found")
except Exception as e:
raise ExtractorError(f"Extraction failed: {str(e)}")
async def _extract_player_api_base(self, html_content: str) -> Tuple[Optional[str], Optional[str]]:
"""Extract player API base URL and method."""
admin_ajax_pattern = r'"player_api"\s*:\s*"([^"]+)".*?"play_method"\s*:\s*"([^"]+)"'
match = re.search(admin_ajax_pattern, html_content)
if not match:
return None, None
url = match.group(1).replace("\\/", "/")
method = match.group(2)
if method == "wp_json":
return url, method
url = urljoin(url, "/wp-admin/admin-ajax.php")
return url, method
async def _get_player_options(self, html_content: str) -> list:
"""Extract player options from HTML content."""
pattern = r'<li[^>]*class=["\']dooplay_player_option["\'][^>]*data-type=["\']([^"\']*)["\'][^>]*data-post=["\']([^"\']*)["\'][^>]*data-nume=["\']([^"\']*)["\'][^>]*>.*?<span class=["\']title["\']>([^<]*)</span>'
matches = re.finditer(pattern, html_content, re.DOTALL)
return [
{"type": match.group(1), "post": match.group(2), "nume": match.group(3), "title": match.group(4).strip()}
for match in matches
]
async def _process_player_option(self, api_base: str, method: str, post: str, nume: str, type_: str) -> Dict:
"""Process player option to get stream URL."""
if method == "wp_json":
api_url = f"{api_base}{post}/{type_}/{nume}"
response = await self._make_request(api_url)
else:
form_data = {"action": "doo_player_ajax", "post": post, "nume": nume, "type": type_}
response = await self._make_request(api_base, method="POST", data=form_data)
# Get iframe URL from API response
try:
data = response.json()
iframe_url = urljoin(api_base, data.get("embed_url", "").replace("\\/", "/"))
# Get stream URL from iframe
iframe_response = await self._make_request(iframe_url)
stream_data = await self._extract_stream_url(iframe_response, iframe_url)
return stream_data
except Exception as e:
raise ExtractorError(f"Failed to process player option: {str(e)}")
async def _extract_stream_url(self, iframe_response: Response, iframe_url: str) -> Dict:
"""
Extract final stream URL from iframe content.
"""
try:
# Parse URL components
parsed_url = urlparse(iframe_url)
query_params = dict(param.split("=") for param in parsed_url.query.split("&") if "=" in param)
# Check if content is already a direct M3U8 stream
content_types = ["application/x-mpegurl", "application/vnd.apple.mpegurl"]
if any(ext in iframe_response.headers["content-type"] for ext in content_types):
return {"url": iframe_url, "type": "m3u8"}
stream_data = {}
# Check for source parameter in URL
if "source" in query_params:
stream_data = {
"url": urljoin(iframe_url, unquote(query_params["source"])),
"type": "m3u8",
}
# Check for MPD stream with DRM
elif "zy" in query_params and ".mpd``" in query_params["zy"]:
data = query_params["zy"].split("``")
url = data[0]
key_id, key = data[1].split(":")
stream_data = {"url": url, "type": "mpd", "drm_key_id": key_id, "drm_key": key}
# Check for tamilultra specific format
elif "tamilultra" in iframe_url:
stream_data = {"url": urljoin(iframe_url, parsed_url.query), "type": "m3u8"}
# Try pattern matching for stream URLs
else:
channel_id = query_params.get("id", [""])
stream_url = None
html_content = iframe_response.text
if channel_id:
# Try channel ID specific pattern
pattern = rf'{re.escape(channel_id)}["\']:\s*{{\s*["\']?url["\']?\s*:\s*["\']([^"\']+)["\']'
match = re.search(pattern, html_content)
if match:
stream_url = match.group(1)
# Try fallback patterns if channel ID pattern fails
if not stream_url:
for pattern in [self.fallback_pattern, self.any_m3u8_pattern]:
match = pattern.search(html_content)
if match:
stream_url = match.group(1)
break
if stream_url:
stream_data = {"url": stream_url, "type": "m3u8"} # Default to m3u8, will be updated
# Check for MPD stream and extract DRM keys
if stream_url.endswith(".mpd"):
stream_data["type"] = "mpd"
drm_data = await self._extract_drm_keys(html_content, channel_id)
if drm_data:
stream_data.update(drm_data)
# If no stream data found, raise error
if not stream_data:
raise ExtractorError("No valid stream URL found")
# Update stream type based on URL if not already set
if stream_data.get("type") == "m3u8":
if stream_data["url"].endswith(".mpd"):
stream_data["type"] = "mpd"
elif not any(ext in stream_data["url"] for ext in [".m3u8", ".m3u"]):
stream_data["type"] = "m3u8" # Default to m3u8 if no extension found
return stream_data
except Exception as e:
raise ExtractorError(f"Failed to extract stream URL: {str(e)}")
async def _extract_drm_keys(self, html_content: str, channel_id: str) -> Dict:
"""
Extract DRM keys for MPD streams.
"""
try:
# Pattern for channel entry
channel_pattern = rf'"{re.escape(channel_id)}":\s*{{[^}}]+}}'
channel_match = re.search(channel_pattern, html_content)
if channel_match:
channel_data = channel_match.group(0)
# Try clearkeys pattern first
clearkey_pattern = r'["\']?clearkeys["\']?\s*:\s*{\s*["\'](.+?)["\']:\s*["\'](.+?)["\']'
clearkey_match = re.search(clearkey_pattern, channel_data)
# Try k1/k2 pattern if clearkeys not found
if not clearkey_match:
k1k2_pattern = r'["\']?k1["\']?\s*:\s*["\'](.+?)["\'],\s*["\']?k2["\']?\s*:\s*["\'](.+?)["\']'
k1k2_match = re.search(k1k2_pattern, channel_data)
if k1k2_match:
return {"drm_key_id": k1k2_match.group(1), "drm_key": k1k2_match.group(2)}
else:
return {"drm_key_id": clearkey_match.group(1), "drm_key": clearkey_match.group(2)}
return {}
except Exception:
return {}

View File

@@ -0,0 +1,36 @@
import re
import string
from typing import Dict, Any
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
class MixdropExtractor(BaseExtractor):
"""Mixdrop URL extractor."""
async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
"""Extract Mixdrop URL."""
response = await self._make_request(url, headers={"accept-language": "en-US,en;q=0.5"})
# Extract and decode URL
match = re.search(r"}\('(.+)',.+,'(.+)'\.split", response.text)
if not match:
raise ExtractorError("Failed to extract URL components")
s1, s2 = match.group(1, 2)
schema = s1.split(";")[2][5:-1]
terms = s2.split("|")
# Build character mapping
charset = string.digits + string.ascii_letters
char_map = {charset[i]: terms[i] or charset[i] for i in range(len(terms))}
# Construct final URL
final_url = "https:" + "".join(char_map.get(c, c) for c in schema)
self.base_headers["referer"] = url
return {
"destination_url": final_url,
"request_headers": self.base_headers,
"mediaflow_endpoint": self.mediaflow_endpoint,
}

View File

@@ -0,0 +1,32 @@
import re
from typing import Dict, Any
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
class StreamtapeExtractor(BaseExtractor):
"""Streamtape URL extractor."""
async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
"""Extract Streamtape URL."""
response = await self._make_request(url)
# Extract and decode URL
matches = re.findall(r"id=.*?(?=')", response.text)
if not matches:
raise ExtractorError("Failed to extract URL components")
final_url = next(
(
f"https://streamtape.com/get_video?{matches[i + 1]}"
for i in range(len(matches) - 1)
if matches[i] == matches[i + 1]
),
None,
)
self.base_headers["referer"] = url
return {
"destination_url": final_url,
"request_headers": self.base_headers,
"mediaflow_endpoint": self.mediaflow_endpoint,
}

View File

@@ -0,0 +1,27 @@
import re
from typing import Dict, Any
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
class SupervideoExtractor(BaseExtractor):
"""Supervideo URL extractor."""
async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
"""Extract Supervideo URL."""
response = await self._make_request(url)
# Extract and decode URL
s2 = re.search(r"\}\('(.+)',.+,'(.+)'\.split", response.text).group(2)
terms = s2.split("|")
hfs = next(terms[i] for i in range(terms.index("file"), len(terms)) if "hfs" in terms[i])
result = terms[terms.index("urlset") + 1 : terms.index("hls")]
base_url = f"https://{hfs}.serversicuro.cc/hls/"
final_url = base_url + ",".join(reversed(result)) + (".urlset/master.m3u8" if result else "")
self.base_headers["referer"] = url
return {
"destination_url": final_url,
"request_headers": self.base_headers,
"mediaflow_endpoint": self.mediaflow_endpoint,
}

View File

@@ -0,0 +1,24 @@
import re
from typing import Dict
from urllib.parse import urljoin
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
class UqloadExtractor(BaseExtractor):
"""Uqload URL extractor."""
async def extract(self, url: str, **kwargs) -> Dict[str, str]:
"""Extract Uqload URL."""
response = await self._make_request(url)
video_url_match = re.search(r'sources: \["(.*?)"]', response.text)
if not video_url_match:
raise ExtractorError("Failed to extract video URL")
self.base_headers["referer"] = urljoin(url, "/")
return {
"destination_url": video_url_match.group(1),
"request_headers": self.base_headers,
"mediaflow_endpoint": self.mediaflow_endpoint,
}

358
mediaflow_proxy/handlers.py Normal file
View File

@@ -0,0 +1,358 @@
import base64
import logging
from urllib.parse import urlparse
import httpx
from fastapi import Request, Response, HTTPException
from starlette.background import BackgroundTask
from .const import SUPPORTED_RESPONSE_HEADERS
from .mpd_processor import process_manifest, process_playlist, process_segment
from .schemas import HLSManifestParams, ProxyStreamParams, MPDManifestParams, MPDPlaylistParams, MPDSegmentParams
from .utils.cache_utils import get_cached_mpd, get_cached_init_segment
from .utils.http_utils import (
Streamer,
DownloadError,
download_file_with_retry,
request_with_retry,
EnhancedStreamingResponse,
ProxyRequestHeaders,
create_httpx_client,
)
from .utils.m3u8_processor import M3U8Processor
from .utils.mpd_utils import pad_base64
logger = logging.getLogger(__name__)
async def setup_client_and_streamer() -> tuple[httpx.AsyncClient, Streamer]:
"""
Set up an HTTP client and a streamer.
Returns:
tuple: An httpx.AsyncClient instance and a Streamer instance.
"""
client = create_httpx_client()
return client, Streamer(client)
def handle_exceptions(exception: Exception) -> Response:
"""
Handle exceptions and return appropriate HTTP responses.
Args:
exception (Exception): The exception that was raised.
Returns:
Response: An HTTP response corresponding to the exception type.
"""
if isinstance(exception, httpx.HTTPStatusError):
logger.error(f"Upstream service error while handling request: {exception}")
return Response(status_code=exception.response.status_code, content=f"Upstream service error: {exception}")
elif isinstance(exception, DownloadError):
logger.error(f"Error downloading content: {exception}")
return Response(status_code=exception.status_code, content=str(exception))
else:
logger.exception(f"Internal server error while handling request: {exception}")
return Response(status_code=502, content=f"Internal server error: {exception}")
async def handle_hls_stream_proxy(
request: Request, hls_params: HLSManifestParams, proxy_headers: ProxyRequestHeaders
) -> Response:
"""
Handle HLS stream proxy requests.
This function processes HLS manifest files and streams content based on the request parameters.
Args:
request (Request): The incoming FastAPI request object.
hls_params (HLSManifestParams): Parameters for the HLS manifest.
proxy_headers (ProxyRequestHeaders): Headers to be used in the proxy request.
Returns:
Union[Response, EnhancedStreamingResponse]: Either a processed m3u8 playlist or a streaming response.
"""
client, streamer = await setup_client_and_streamer()
try:
if urlparse(hls_params.destination).path.endswith((".m3u", ".m3u8")):
return await fetch_and_process_m3u8(
streamer, hls_params.destination, proxy_headers, request, hls_params.key_url
)
# Create initial streaming response to check content type
await streamer.create_streaming_response(hls_params.destination, proxy_headers.request)
if "mpegurl" in streamer.response.headers.get("content-type", "").lower():
return await fetch_and_process_m3u8(
streamer, hls_params.destination, proxy_headers, request, hls_params.key_url
)
# Handle range requests
content_range = proxy_headers.request.get("range", "bytes=0-")
if "NaN" in content_range:
# Handle invalid range requests "bytes=NaN-NaN"
raise HTTPException(status_code=416, detail="Invalid Range Header")
proxy_headers.request.update({"range": content_range})
# Create new streaming response with updated headers
await streamer.create_streaming_response(hls_params.destination, proxy_headers.request)
response_headers = prepare_response_headers(streamer.response.headers, proxy_headers.response)
return EnhancedStreamingResponse(
streamer.stream_content(),
status_code=streamer.response.status_code,
headers=response_headers,
background=BackgroundTask(streamer.close),
)
except Exception as e:
await streamer.close()
return handle_exceptions(e)
async def handle_stream_request(
method: str,
video_url: str,
proxy_headers: ProxyRequestHeaders,
) -> Response:
"""
Handle general stream requests.
This function processes both HEAD and GET requests for video streams.
Args:
method (str): The HTTP method (e.g., 'GET' or 'HEAD').
video_url (str): The URL of the video to stream.
proxy_headers (ProxyRequestHeaders): Headers to be used in the proxy request.
Returns:
Union[Response, EnhancedStreamingResponse]: Either a HEAD response with headers or a streaming response.
"""
client, streamer = await setup_client_and_streamer()
try:
await streamer.create_streaming_response(video_url, proxy_headers.request)
response_headers = prepare_response_headers(streamer.response.headers, proxy_headers.response)
if method == "HEAD":
# For HEAD requests, just return the headers without streaming content
await streamer.close()
return Response(headers=response_headers, status_code=streamer.response.status_code)
else:
# For GET requests, return the streaming response
return EnhancedStreamingResponse(
streamer.stream_content(),
headers=response_headers,
status_code=streamer.response.status_code,
background=BackgroundTask(streamer.close),
)
except Exception as e:
await streamer.close()
return handle_exceptions(e)
def prepare_response_headers(original_headers, proxy_response_headers) -> dict:
"""
Prepare response headers for the proxy response.
This function filters the original headers, ensures proper transfer encoding,
and merges them with the proxy response headers.
Args:
original_headers (httpx.Headers): The original headers from the upstream response.
proxy_response_headers (dict): Additional headers to be included in the proxy response.
Returns:
dict: The prepared headers for the proxy response.
"""
response_headers = {k: v for k, v in original_headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}
response_headers.update(proxy_response_headers)
return response_headers
async def proxy_stream(method: str, stream_params: ProxyStreamParams, proxy_headers: ProxyRequestHeaders):
"""
Proxies the stream request to the given video URL.
Args:
method (str): The HTTP method (e.g., GET, HEAD).
stream_params (ProxyStreamParams): The parameters for the stream request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the streamed content.
"""
return await handle_stream_request(method, stream_params.destination, proxy_headers)
async def fetch_and_process_m3u8(
streamer: Streamer, url: str, proxy_headers: ProxyRequestHeaders, request: Request, key_url: str = None
):
"""
Fetches and processes the m3u8 playlist, converting it to an HLS playlist.
Args:
streamer (Streamer): The HTTP client to use for streaming.
url (str): The URL of the m3u8 playlist.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
request (Request): The incoming HTTP request.
key_url (str, optional): The HLS Key URL to replace the original key URL. Defaults to None.
Returns:
Response: The HTTP response with the processed m3u8 playlist.
"""
try:
content = await streamer.get_text(url, proxy_headers.request)
processor = M3U8Processor(request, key_url)
processed_content = await processor.process_m3u8(content, str(streamer.response.url))
response_headers = {"Content-Disposition": "inline", "Accept-Ranges": "none"}
response_headers.update(proxy_headers.response)
return Response(
content=processed_content,
media_type="application/vnd.apple.mpegurl",
headers=response_headers,
)
except Exception as e:
return handle_exceptions(e)
finally:
await streamer.close()
async def handle_drm_key_data(key_id, key, drm_info):
"""
Handles the DRM key data, retrieving the key ID and key from the DRM info if not provided.
Args:
key_id (str): The DRM key ID.
key (str): The DRM key.
drm_info (dict): The DRM information from the MPD manifest.
Returns:
tuple: The key ID and key.
"""
if drm_info and not drm_info.get("isDrmProtected"):
return None, None
if not key_id or not key:
if "keyId" in drm_info and "key" in drm_info:
key_id = drm_info["keyId"]
key = drm_info["key"]
elif "laUrl" in drm_info and "keyId" in drm_info:
raise HTTPException(status_code=400, detail="LA URL is not supported yet")
else:
raise HTTPException(
status_code=400, detail="Unable to determine key_id and key, and they were not provided"
)
return key_id, key
async def get_manifest(
request: Request,
manifest_params: MPDManifestParams,
proxy_headers: ProxyRequestHeaders,
):
"""
Retrieves and processes the MPD manifest, converting it to an HLS manifest.
Args:
request (Request): The incoming HTTP request.
manifest_params (MPDManifestParams): The parameters for the manifest request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the HLS manifest.
"""
try:
mpd_dict = await get_cached_mpd(
manifest_params.destination,
headers=proxy_headers.request,
parse_drm=not manifest_params.key_id and not manifest_params.key,
)
except DownloadError as e:
raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
drm_info = mpd_dict.get("drmInfo", {})
if drm_info and not drm_info.get("isDrmProtected"):
# For non-DRM protected MPD, we still create an HLS manifest
return await process_manifest(request, mpd_dict, proxy_headers, None, None)
key_id, key = await handle_drm_key_data(manifest_params.key_id, manifest_params.key, drm_info)
# check if the provided key_id and key are valid
if key_id and len(key_id) != 32:
key_id = base64.urlsafe_b64decode(pad_base64(key_id)).hex()
if key and len(key) != 32:
key = base64.urlsafe_b64decode(pad_base64(key)).hex()
return await process_manifest(request, mpd_dict, proxy_headers, key_id, key)
async def get_playlist(
request: Request,
playlist_params: MPDPlaylistParams,
proxy_headers: ProxyRequestHeaders,
):
"""
Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.
Args:
request (Request): The incoming HTTP request.
playlist_params (MPDPlaylistParams): The parameters for the playlist request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the HLS playlist.
"""
try:
mpd_dict = await get_cached_mpd(
playlist_params.destination,
headers=proxy_headers.request,
parse_drm=not playlist_params.key_id and not playlist_params.key,
parse_segment_profile_id=playlist_params.profile_id,
)
except DownloadError as e:
raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
return await process_playlist(request, mpd_dict, playlist_params.profile_id, proxy_headers)
async def get_segment(
segment_params: MPDSegmentParams,
proxy_headers: ProxyRequestHeaders,
):
"""
Retrieves and processes a media segment, decrypting it if necessary.
Args:
segment_params (MPDSegmentParams): The parameters for the segment request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the processed segment.
"""
try:
init_content = await get_cached_init_segment(segment_params.init_url, proxy_headers.request)
segment_content = await download_file_with_retry(segment_params.segment_url, proxy_headers.request)
except Exception as e:
return handle_exceptions(e)
return await process_segment(
init_content,
segment_content,
segment_params.mime_type,
proxy_headers,
segment_params.key_id,
segment_params.key,
)
async def get_public_ip():
"""
Retrieves the public IP address of the MediaFlow proxy.
Returns:
Response: The HTTP response with the public IP address.
"""
ip_address_data = await request_with_retry("GET", "https://api.ipify.org?format=json", {})
return ip_address_data.json()

99
mediaflow_proxy/main.py Normal file
View File

@@ -0,0 +1,99 @@
import logging
from importlib import resources
from fastapi import FastAPI, Depends, Security, HTTPException
from fastapi.security import APIKeyQuery, APIKeyHeader
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from starlette.staticfiles import StaticFiles
from mediaflow_proxy.configs import settings
from mediaflow_proxy.routes import proxy_router, extractor_router, speedtest_router
from mediaflow_proxy.schemas import GenerateUrlRequest
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler, EncryptionMiddleware
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url
logging.basicConfig(level=settings.log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
app = FastAPI()
api_password_query = APIKeyQuery(name="api_password", auto_error=False)
api_password_header = APIKeyHeader(name="api_password", auto_error=False)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(EncryptionMiddleware)
async def verify_api_key(api_key: str = Security(api_password_query), api_key_alt: str = Security(api_password_header)):
"""
Verifies the API key for the request.
Args:
api_key (str): The API key to validate.
api_key_alt (str): The alternative API key to validate.
Raises:
HTTPException: If the API key is invalid.
"""
if not settings.api_password:
return
if api_key == settings.api_password or api_key_alt == settings.api_password:
return
raise HTTPException(status_code=403, detail="Could not validate credentials")
@app.get("/health")
async def health_check():
return {"status": "healthy"}
@app.get("/favicon.ico")
async def get_favicon():
return RedirectResponse(url="/logo.png")
@app.get("/speedtest")
async def show_speedtest_page():
return RedirectResponse(url="/speedtest.html")
@app.post("/generate_encrypted_or_encoded_url")
async def generate_encrypted_or_encoded_url(request: GenerateUrlRequest):
if "api_password" not in request.query_params:
request.query_params["api_password"] = request.api_password
encoded_url = encode_mediaflow_proxy_url(
request.mediaflow_proxy_url,
request.endpoint,
request.destination_url,
request.query_params,
request.request_headers,
request.response_headers,
EncryptionHandler(request.api_password) if request.api_password else None,
request.expiration,
str(request.ip) if request.ip else None,
)
return {"encoded_url": encoded_url}
app.include_router(proxy_router, prefix="/proxy", tags=["proxy"], dependencies=[Depends(verify_api_key)])
app.include_router(extractor_router, prefix="/extractor", tags=["extractors"], dependencies=[Depends(verify_api_key)])
app.include_router(speedtest_router, prefix="/speedtest", tags=["speedtest"], dependencies=[Depends(verify_api_key)])
static_path = resources.files("mediaflow_proxy").joinpath("static")
app.mount("/", StaticFiles(directory=str(static_path), html=True), name="static")
def run():
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8888, log_level="info", workers=3)
if __name__ == "__main__":
run()

View File

@@ -0,0 +1,214 @@
import logging
import math
import time
from fastapi import Request, Response, HTTPException
from mediaflow_proxy.drm.decrypter import decrypt_segment
from mediaflow_proxy.utils.crypto_utils import encryption_handler
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme, ProxyRequestHeaders
logger = logging.getLogger(__name__)
async def process_manifest(
request: Request, mpd_dict: dict, proxy_headers: ProxyRequestHeaders, key_id: str = None, key: str = None
) -> Response:
"""
Processes the MPD manifest and converts it to an HLS manifest.
Args:
request (Request): The incoming HTTP request.
mpd_dict (dict): The MPD manifest data.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
key_id (str, optional): The DRM key ID. Defaults to None.
key (str, optional): The DRM key. Defaults to None.
Returns:
Response: The HLS manifest as an HTTP response.
"""
hls_content = build_hls(mpd_dict, request, key_id, key)
return Response(content=hls_content, media_type="application/vnd.apple.mpegurl", headers=proxy_headers.response)
async def process_playlist(
request: Request, mpd_dict: dict, profile_id: str, proxy_headers: ProxyRequestHeaders
) -> Response:
"""
Processes the MPD manifest and converts it to an HLS playlist for a specific profile.
Args:
request (Request): The incoming HTTP request.
mpd_dict (dict): The MPD manifest data.
profile_id (str): The profile ID to generate the playlist for.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HLS playlist as an HTTP response.
Raises:
HTTPException: If the profile is not found in the MPD manifest.
"""
matching_profiles = [p for p in mpd_dict["profiles"] if p["id"] == profile_id]
if not matching_profiles:
raise HTTPException(status_code=404, detail="Profile not found")
hls_content = build_hls_playlist(mpd_dict, matching_profiles, request)
return Response(content=hls_content, media_type="application/vnd.apple.mpegurl", headers=proxy_headers.response)
async def process_segment(
init_content: bytes,
segment_content: bytes,
mimetype: str,
proxy_headers: ProxyRequestHeaders,
key_id: str = None,
key: str = None,
) -> Response:
"""
Processes and decrypts a media segment.
Args:
init_content (bytes): The initialization segment content.
segment_content (bytes): The media segment content.
mimetype (str): The MIME type of the segment.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
key_id (str, optional): The DRM key ID. Defaults to None.
key (str, optional): The DRM key. Defaults to None.
Returns:
Response: The decrypted segment as an HTTP response.
"""
if key_id and key:
# For DRM protected content
now = time.time()
decrypted_content = decrypt_segment(init_content, segment_content, key_id, key)
logger.info(f"Decryption of {mimetype} segment took {time.time() - now:.4f} seconds")
else:
# For non-DRM protected content, we just concatenate init and segment content
decrypted_content = init_content + segment_content
return Response(content=decrypted_content, media_type=mimetype, headers=proxy_headers.response)
def build_hls(mpd_dict: dict, request: Request, key_id: str = None, key: str = None) -> str:
"""
Builds an HLS manifest from the MPD manifest.
Args:
mpd_dict (dict): The MPD manifest data.
request (Request): The incoming HTTP request.
key_id (str, optional): The DRM key ID. Defaults to None.
key (str, optional): The DRM key. Defaults to None.
Returns:
str: The HLS manifest as a string.
"""
hls = ["#EXTM3U", "#EXT-X-VERSION:6"]
query_params = dict(request.query_params)
has_encrypted = query_params.pop("has_encrypted", False)
video_profiles = {}
audio_profiles = {}
# Get the base URL for the playlist_endpoint endpoint
proxy_url = request.url_for("playlist_endpoint")
proxy_url = str(proxy_url.replace(scheme=get_original_scheme(request)))
for profile in mpd_dict["profiles"]:
query_params.update({"profile_id": profile["id"], "key_id": key_id or "", "key": key or ""})
playlist_url = encode_mediaflow_proxy_url(
proxy_url,
query_params=query_params,
encryption_handler=encryption_handler if has_encrypted else None,
)
if "video" in profile["mimeType"]:
video_profiles[profile["id"]] = (profile, playlist_url)
elif "audio" in profile["mimeType"]:
audio_profiles[profile["id"]] = (profile, playlist_url)
# Add audio streams
for i, (profile, playlist_url) in enumerate(audio_profiles.values()):
is_default = "YES" if i == 0 else "NO" # Set the first audio track as default
hls.append(
f'#EXT-X-MEDIA:TYPE=AUDIO,GROUP-ID="audio",NAME="{profile["id"]}",DEFAULT={is_default},AUTOSELECT={is_default},LANGUAGE="{profile.get("lang", "und")}",URI="{playlist_url}"'
)
# Add video streams
for profile, playlist_url in video_profiles.values():
hls.append(
f'#EXT-X-STREAM-INF:BANDWIDTH={profile["bandwidth"]},RESOLUTION={profile["width"]}x{profile["height"]},CODECS="{profile["codecs"]}",FRAME-RATE={profile["frameRate"]},AUDIO="audio"'
)
hls.append(playlist_url)
return "\n".join(hls)
def build_hls_playlist(mpd_dict: dict, profiles: list[dict], request: Request) -> str:
"""
Builds an HLS playlist from the MPD manifest for specific profiles.
Args:
mpd_dict (dict): The MPD manifest data.
profiles (list[dict]): The profiles to include in the playlist.
request (Request): The incoming HTTP request.
Returns:
str: The HLS playlist as a string.
"""
hls = ["#EXTM3U", "#EXT-X-VERSION:6"]
added_segments = 0
proxy_url = request.url_for("segment_endpoint")
proxy_url = str(proxy_url.replace(scheme=get_original_scheme(request)))
for index, profile in enumerate(profiles):
segments = profile["segments"]
if not segments:
logger.warning(f"No segments found for profile {profile['id']}")
continue
# Add headers for only the first profile
if index == 0:
sequence = segments[0]["number"]
extinf_values = [f["extinf"] for f in segments if "extinf" in f]
target_duration = math.ceil(max(extinf_values)) if extinf_values else 3
hls.extend(
[
f"#EXT-X-TARGETDURATION:{target_duration}",
f"#EXT-X-MEDIA-SEQUENCE:{sequence}",
]
)
if mpd_dict["isLive"]:
hls.append("#EXT-X-PLAYLIST-TYPE:EVENT")
else:
hls.append("#EXT-X-PLAYLIST-TYPE:VOD")
init_url = profile["initUrl"]
query_params = dict(request.query_params)
query_params.pop("profile_id", None)
query_params.pop("d", None)
has_encrypted = query_params.pop("has_encrypted", False)
for segment in segments:
hls.append(f'#EXTINF:{segment["extinf"]:.3f},')
query_params.update(
{"init_url": init_url, "segment_url": segment["media"], "mime_type": profile["mimeType"]}
)
hls.append(
encode_mediaflow_proxy_url(
proxy_url,
query_params=query_params,
encryption_handler=encryption_handler if has_encrypted else None,
)
)
added_segments += 1
if not mpd_dict["isLive"]:
hls.append("#EXT-X-ENDLIST")
logger.info(f"Added {added_segments} segments to HLS playlist")
return "\n".join(hls)

164
mediaflow_proxy/routes.py Normal file
View File

@@ -0,0 +1,164 @@
from fastapi import Request, Depends, APIRouter
from pydantic import HttpUrl
from .handlers import handle_hls_stream_proxy, proxy_stream, get_manifest, get_playlist, get_segment, get_public_ip
from .utils.http_utils import get_proxy_headers, ProxyRequestHeaders
proxy_router = APIRouter()
@proxy_router.head("/hls")
@proxy_router.get("/hls")
async def hls_stream_proxy(
request: Request,
d: HttpUrl,
proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
key_url: HttpUrl | None = None,
verify_ssl: bool = False,
use_request_proxy: bool = True,
):
"""
Proxify HLS stream requests, fetching and processing the m3u8 playlist or streaming the content.
Args:
request (Request): The incoming HTTP request.
d (HttpUrl): The destination URL to fetch the content from.
key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None. (Useful for bypassing some sneaky protection)
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
Returns:
Response: The HTTP response with the processed m3u8 playlist or streamed content.
"""
destination = str(d)
return await handle_hls_stream_proxy(request, destination, proxy_headers, key_url, verify_ssl, use_request_proxy)
@proxy_router.head("/stream")
@proxy_router.get("/stream")
async def proxy_stream_endpoint(
request: Request,
d: HttpUrl,
proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
verify_ssl: bool = False,
use_request_proxy: bool = True,
):
"""
Proxies stream requests to the given video URL.
Args:
request (Request): The incoming HTTP request.
d (HttpUrl): The URL of the video to stream.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
Returns:
Response: The HTTP response with the streamed content.
"""
proxy_headers.request.update({"range": proxy_headers.request.get("range", "bytes=0-")})
return await proxy_stream(request.method, str(d), proxy_headers, verify_ssl, use_request_proxy)
@proxy_router.get("/mpd/manifest")
async def manifest_endpoint(
request: Request,
d: HttpUrl,
proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
key_id: str = None,
key: str = None,
verify_ssl: bool = False,
use_request_proxy: bool = True,
):
"""
Retrieves and processes the MPD manifest, converting it to an HLS manifest.
Args:
request (Request): The incoming HTTP request.
d (HttpUrl): The URL of the MPD manifest.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
key_id (str, optional): The DRM key ID. Defaults to None.
key (str, optional): The DRM key. Defaults to None.
verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
Returns:
Response: The HTTP response with the HLS manifest.
"""
return await get_manifest(request, str(d), proxy_headers, key_id, key, verify_ssl, use_request_proxy)
@proxy_router.get("/mpd/playlist")
async def playlist_endpoint(
request: Request,
d: HttpUrl,
profile_id: str,
proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
key_id: str = None,
key: str = None,
verify_ssl: bool = False,
use_request_proxy: bool = True,
):
"""
Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.
Args:
request (Request): The incoming HTTP request.
d (HttpUrl): The URL of the MPD manifest.
profile_id (str): The profile ID to generate the playlist for.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
key_id (str, optional): The DRM key ID. Defaults to None.
key (str, optional): The DRM key. Defaults to None.
verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
Returns:
Response: The HTTP response with the HLS playlist.
"""
return await get_playlist(request, str(d), profile_id, proxy_headers, key_id, key, verify_ssl, use_request_proxy)
@proxy_router.get("/mpd/segment")
async def segment_endpoint(
init_url: HttpUrl,
segment_url: HttpUrl,
mime_type: str,
proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
key_id: str = None,
key: str = None,
verify_ssl: bool = False,
use_request_proxy: bool = True,
):
"""
Retrieves and processes a media segment, decrypting it if necessary.
Args:
init_url (HttpUrl): The URL of the initialization segment.
segment_url (HttpUrl): The URL of the media segment.
mime_type (str): The MIME type of the segment.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
key_id (str, optional): The DRM key ID. Defaults to None.
key (str, optional): The DRM key. Defaults to None.
verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
use_request_proxy (bool, optional): Whether to use the MediaFlow proxy configuration. Defaults to True.
Returns:
Response: The HTTP response with the processed segment.
"""
return await get_segment(
str(init_url), str(segment_url), mime_type, proxy_headers, key_id, key, verify_ssl, use_request_proxy
)
@proxy_router.get("/ip")
async def get_mediaflow_proxy_public_ip(
use_request_proxy: bool = True,
):
"""
Retrieves the public IP address of the MediaFlow proxy server.
Returns:
Response: The HTTP response with the public IP address in the form of a JSON object. {"ip": "xxx.xxx.xxx.xxx"}
"""
return await get_public_ip(use_request_proxy)

View File

@@ -0,0 +1,5 @@
from .proxy import proxy_router
from .extractor import extractor_router
from .speedtest import speedtest_router
__all__ = ["proxy_router", "extractor_router", "speedtest_router"]

View File

@@ -0,0 +1,61 @@
import logging
from typing import Annotated
from fastapi import APIRouter, Query, HTTPException, Request, Depends
from fastapi.responses import RedirectResponse
from mediaflow_proxy.extractors.base import ExtractorError
from mediaflow_proxy.extractors.factory import ExtractorFactory
from mediaflow_proxy.schemas import ExtractorURLParams
from mediaflow_proxy.utils.cache_utils import get_cached_extractor_result, set_cache_extractor_result
from mediaflow_proxy.utils.http_utils import (
encode_mediaflow_proxy_url,
get_original_scheme,
ProxyRequestHeaders,
get_proxy_headers,
)
extractor_router = APIRouter()
logger = logging.getLogger(__name__)
@extractor_router.head("/video")
@extractor_router.get("/video")
async def extract_url(
extractor_params: Annotated[ExtractorURLParams, Query()],
request: Request,
proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
"""Extract clean links from various video hosting services."""
try:
cache_key = f"{extractor_params.host}_{extractor_params.model_dump_json()}"
response = await get_cached_extractor_result(cache_key)
if not response:
extractor = ExtractorFactory.get_extractor(extractor_params.host, proxy_headers.request)
response = await extractor.extract(extractor_params.destination, **extractor_params.extra_params)
await set_cache_extractor_result(cache_key, response)
else:
response["request_headers"].update(proxy_headers.request)
response["mediaflow_proxy_url"] = str(
request.url_for(response.pop("mediaflow_endpoint")).replace(scheme=get_original_scheme(request))
)
response["query_params"] = response.get("query_params", {})
# Add API password to query params
response["query_params"]["api_password"] = request.query_params.get("api_password")
if extractor_params.redirect_stream:
stream_url = encode_mediaflow_proxy_url(
**response,
response_headers=proxy_headers.response,
)
return RedirectResponse(url=stream_url, status_code=302)
return response
except ExtractorError as e:
logger.error(f"Extraction failed: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception(f"Extraction failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Extraction failed: {str(e)}")

View File

@@ -0,0 +1,138 @@
from typing import Annotated
from fastapi import Request, Depends, APIRouter, Query, HTTPException
from mediaflow_proxy.handlers import (
handle_hls_stream_proxy,
proxy_stream,
get_manifest,
get_playlist,
get_segment,
get_public_ip,
)
from mediaflow_proxy.schemas import (
MPDSegmentParams,
MPDPlaylistParams,
HLSManifestParams,
ProxyStreamParams,
MPDManifestParams,
)
from mediaflow_proxy.utils.http_utils import get_proxy_headers, ProxyRequestHeaders
proxy_router = APIRouter()
@proxy_router.head("/hls/manifest.m3u8")
@proxy_router.get("/hls/manifest.m3u8")
async def hls_manifest_proxy(
request: Request,
hls_params: Annotated[HLSManifestParams, Query()],
proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
"""
Proxify HLS stream requests, fetching and processing the m3u8 playlist or streaming the content.
Args:
request (Request): The incoming HTTP request.
hls_params (HLSPlaylistParams): The parameters for the HLS stream request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the processed m3u8 playlist or streamed content.
"""
return await handle_hls_stream_proxy(request, hls_params, proxy_headers)
@proxy_router.head("/stream")
@proxy_router.get("/stream")
async def proxy_stream_endpoint(
request: Request,
stream_params: Annotated[ProxyStreamParams, Query()],
proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
"""
Proxies stream requests to the given video URL.
Args:
request (Request): The incoming HTTP request.
stream_params (ProxyStreamParams): The parameters for the stream request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the streamed content.
"""
content_range = proxy_headers.request.get("range", "bytes=0-")
if "nan" in content_range.casefold():
# Handle invalid range requests "bytes=NaN-NaN"
raise HTTPException(status_code=416, detail="Invalid Range Header")
proxy_headers.request.update({"range": content_range})
return await proxy_stream(request.method, stream_params, proxy_headers)
@proxy_router.get("/mpd/manifest.m3u8")
async def mpd_manifest_proxy(
request: Request,
manifest_params: Annotated[MPDManifestParams, Query()],
proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
"""
Retrieves and processes the MPD manifest, converting it to an HLS manifest.
Args:
request (Request): The incoming HTTP request.
manifest_params (MPDManifestParams): The parameters for the manifest request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the HLS manifest.
"""
return await get_manifest(request, manifest_params, proxy_headers)
@proxy_router.get("/mpd/playlist.m3u8")
async def playlist_endpoint(
request: Request,
playlist_params: Annotated[MPDPlaylistParams, Query()],
proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
"""
Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.
Args:
request (Request): The incoming HTTP request.
playlist_params (MPDPlaylistParams): The parameters for the playlist request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the HLS playlist.
"""
return await get_playlist(request, playlist_params, proxy_headers)
@proxy_router.get("/mpd/segment.mp4")
async def segment_endpoint(
segment_params: Annotated[MPDSegmentParams, Query()],
proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
"""
Retrieves and processes a media segment, decrypting it if necessary.
Args:
segment_params (MPDSegmentParams): The parameters for the segment request.
proxy_headers (ProxyRequestHeaders): The headers to include in the request.
Returns:
Response: The HTTP response with the processed segment.
"""
return await get_segment(segment_params, proxy_headers)
@proxy_router.get("/ip")
async def get_mediaflow_proxy_public_ip():
"""
Retrieves the public IP address of the MediaFlow proxy server.
Returns:
Response: The HTTP response with the public IP address in the form of a JSON object. {"ip": "xxx.xxx.xxx.xxx"}
"""
return await get_public_ip()

View File

@@ -0,0 +1,43 @@
import uuid
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
from fastapi.responses import RedirectResponse
from mediaflow_proxy.speedtest.service import SpeedTestService, SpeedTestProvider
speedtest_router = APIRouter()
# Initialize service
speedtest_service = SpeedTestService()
@speedtest_router.get("/", summary="Show speed test interface")
async def show_speedtest_page():
"""Return the speed test HTML interface."""
return RedirectResponse(url="/speedtest.html")
@speedtest_router.post("/start", summary="Start a new speed test", response_model=dict)
async def start_speedtest(background_tasks: BackgroundTasks, provider: SpeedTestProvider, request: Request):
"""Start a new speed test for the specified provider."""
task_id = str(uuid.uuid4())
api_key = request.headers.get("api_key")
# Create and initialize the task
await speedtest_service.create_test(task_id, provider, api_key)
# Schedule the speed test
background_tasks.add_task(speedtest_service.run_speedtest, task_id, provider, api_key)
return {"task_id": task_id}
@speedtest_router.get("/results/{task_id}", summary="Get speed test results")
async def get_speedtest_results(task_id: str):
"""Get the results or current status of a speed test."""
task = await speedtest_service.get_test_results(task_id)
if not task:
raise HTTPException(status_code=404, detail="Speed test task not found or expired")
return task.dict()

View File

@@ -0,0 +1,74 @@
from typing import Literal, Dict, Any, Optional
from pydantic import BaseModel, Field, IPvAnyAddress, ConfigDict
class GenerateUrlRequest(BaseModel):
mediaflow_proxy_url: str = Field(..., description="The base URL for the mediaflow proxy.")
endpoint: Optional[str] = Field(None, description="The specific endpoint to be appended to the base URL.")
destination_url: Optional[str] = Field(
None, description="The destination URL to which the request will be proxied."
)
query_params: Optional[dict] = Field(
default_factory=dict, description="Query parameters to be included in the request."
)
request_headers: Optional[dict] = Field(default_factory=dict, description="Headers to be included in the request.")
response_headers: Optional[dict] = Field(
default_factory=dict, description="Headers to be included in the response."
)
expiration: Optional[int] = Field(
None, description="Expiration time for the URL in seconds. If not provided, the URL will not expire."
)
api_password: Optional[str] = Field(
None, description="API password for encryption. If not provided, the URL will only be encoded."
)
ip: Optional[IPvAnyAddress] = Field(None, description="The IP address to restrict the URL to.")
class GenericParams(BaseModel):
model_config = ConfigDict(populate_by_name=True)
class HLSManifestParams(GenericParams):
destination: str = Field(..., description="The URL of the HLS manifest.", alias="d")
key_url: Optional[str] = Field(
None,
description="The HLS Key URL to replace the original key URL. Defaults to None. (Useful for bypassing some sneaky protection)",
)
class ProxyStreamParams(GenericParams):
destination: str = Field(..., description="The URL of the stream.", alias="d")
class MPDManifestParams(GenericParams):
destination: str = Field(..., description="The URL of the MPD manifest.", alias="d")
key_id: Optional[str] = Field(None, description="The DRM key ID (optional).")
key: Optional[str] = Field(None, description="The DRM key (optional).")
class MPDPlaylistParams(GenericParams):
destination: str = Field(..., description="The URL of the MPD manifest.", alias="d")
profile_id: str = Field(..., description="The profile ID to generate the playlist for.")
key_id: Optional[str] = Field(None, description="The DRM key ID (optional).")
key: Optional[str] = Field(None, description="The DRM key (optional).")
class MPDSegmentParams(GenericParams):
init_url: str = Field(..., description="The URL of the initialization segment.")
segment_url: str = Field(..., description="The URL of the media segment.")
mime_type: str = Field(..., description="The MIME type of the segment.")
key_id: Optional[str] = Field(None, description="The DRM key ID (optional).")
key: Optional[str] = Field(None, description="The DRM key (optional).")
class ExtractorURLParams(GenericParams):
host: Literal["Doodstream", "Mixdrop", "Uqload", "Streamtape", "Supervideo", "LiveTV"] = Field(
..., description="The host to extract the URL from."
)
destination: str = Field(..., description="The URL of the stream.", alias="d")
redirect_stream: bool = Field(False, description="Whether to redirect to the stream endpoint automatically.")
extra_params: Dict[str, Any] = Field(
default_factory=dict,
description="Additional parameters required for specific extractors (e.g., stream_title for LiveTV)",
)

View File

View File

@@ -0,0 +1,46 @@
from datetime import datetime
from enum import Enum
from typing import Dict, Optional
from pydantic import BaseModel, Field
class SpeedTestProvider(str, Enum):
REAL_DEBRID = "real_debrid"
ALL_DEBRID = "all_debrid"
class ServerInfo(BaseModel):
url: str
name: str
class UserInfo(BaseModel):
ip: Optional[str] = None
isp: Optional[str] = None
country: Optional[str] = None
class SpeedTestResult(BaseModel):
speed_mbps: float = Field(..., description="Speed in Mbps")
duration: float = Field(..., description="Test duration in seconds")
data_transferred: int = Field(..., description="Data transferred in bytes")
timestamp: datetime = Field(default_factory=datetime.utcnow)
class LocationResult(BaseModel):
result: Optional[SpeedTestResult] = None
error: Optional[str] = None
server_name: str
server_url: str
class SpeedTestTask(BaseModel):
task_id: str
provider: SpeedTestProvider
results: Dict[str, LocationResult] = {}
started_at: datetime
completed_at: Optional[datetime] = None
status: str = "running"
user_info: Optional[UserInfo] = None
current_location: Optional[str] = None

View File

@@ -0,0 +1,50 @@
import random
from typing import Dict, Tuple, Optional
from mediaflow_proxy.configs import settings
from mediaflow_proxy.speedtest.models import ServerInfo, UserInfo
from mediaflow_proxy.speedtest.providers.base import BaseSpeedTestProvider, SpeedTestProviderConfig
from mediaflow_proxy.utils.http_utils import request_with_retry
class SpeedTestError(Exception):
pass
class AllDebridSpeedTest(BaseSpeedTestProvider):
"""AllDebrid speed test provider implementation."""
def __init__(self, api_key: str):
self.api_key = api_key
self.servers: Dict[str, ServerInfo] = {}
async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]:
response = await request_with_retry(
"GET",
"https://alldebrid.com/internalapi/v4/speedtest",
headers={"User-Agent": settings.user_agent},
params={"agent": "service", "version": "1.0-363869a7", "apikey": self.api_key},
)
if response.status_code != 200:
raise SpeedTestError("Failed to fetch AllDebrid servers")
data = response.json()
if data["status"] != "success":
raise SpeedTestError("AllDebrid API returned error")
# Create UserInfo
user_info = UserInfo(ip=data["data"]["ip"], isp=data["data"]["isp"], country=data["data"]["country"])
# Store server info
self.servers = {server["name"]: ServerInfo(**server) for server in data["data"]["servers"]}
# Generate URLs with random number
random_number = f"{random.uniform(1, 2):.24f}".replace(".", "")
urls = {name: f"{server.url}/speedtest/{random_number}" for name, server in self.servers.items()}
return urls, user_info
async def get_config(self) -> SpeedTestProviderConfig:
urls, _ = await self.get_test_urls()
return SpeedTestProviderConfig(test_duration=10, test_urls=urls)

View File

@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional
from pydantic import BaseModel
from mediaflow_proxy.speedtest.models import UserInfo
class SpeedTestProviderConfig(BaseModel):
test_duration: int = 10 # seconds
test_urls: Dict[str, str]
class BaseSpeedTestProvider(ABC):
"""Base class for speed test providers."""
@abstractmethod
async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]:
"""Get list of test URLs for the provider and optional user info."""
pass
@abstractmethod
async def get_config(self) -> SpeedTestProviderConfig:
"""Get provider-specific configuration."""
pass

View File

@@ -0,0 +1,32 @@
from typing import Dict, Tuple, Optional
import random
from mediaflow_proxy.speedtest.models import UserInfo
from mediaflow_proxy.speedtest.providers.base import BaseSpeedTestProvider, SpeedTestProviderConfig
class RealDebridSpeedTest(BaseSpeedTestProvider):
"""RealDebrid speed test provider implementation."""
async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]:
urls = {
"AMS": "https://45.download.real-debrid.com/speedtest/testDefault.rar/",
"RBX": "https://rbx.download.real-debrid.com/speedtest/test.rar/",
"LON1": "https://lon1.download.real-debrid.com/speedtest/test.rar/",
"HKG1": "https://hkg1.download.real-debrid.com/speedtest/test.rar/",
"SGP1": "https://sgp1.download.real-debrid.com/speedtest/test.rar/",
"SGPO1": "https://sgpo1.download.real-debrid.com/speedtest/test.rar/",
"TYO1": "https://tyo1.download.real-debrid.com/speedtest/test.rar/",
"LAX1": "https://lax1.download.real-debrid.com/speedtest/test.rar/",
"TLV1": "https://tlv1.download.real-debrid.com/speedtest/test.rar/",
"MUM1": "https://mum1.download.real-debrid.com/speedtest/test.rar/",
"JKT1": "https://jkt1.download.real-debrid.com/speedtest/test.rar/",
"Cloudflare": "https://45.download.real-debrid.cloud/speedtest/testCloudflare.rar/",
}
# Add random number to prevent caching
urls = {location: f"{base_url}{random.uniform(0, 1):.16f}" for location, base_url in urls.items()}
return urls, None
async def get_config(self) -> SpeedTestProviderConfig:
urls, _ = await self.get_test_urls()
return SpeedTestProviderConfig(test_duration=10, test_urls=urls)

View File

@@ -0,0 +1,129 @@
import logging
import time
from datetime import datetime, timezone
from typing import Dict, Optional, Type
from mediaflow_proxy.utils.cache_utils import get_cached_speedtest, set_cache_speedtest
from mediaflow_proxy.utils.http_utils import Streamer, create_httpx_client
from .models import SpeedTestTask, LocationResult, SpeedTestResult, SpeedTestProvider
from .providers.all_debrid import AllDebridSpeedTest
from .providers.base import BaseSpeedTestProvider
from .providers.real_debrid import RealDebridSpeedTest
logger = logging.getLogger(__name__)
class SpeedTestService:
"""Service for managing speed tests across different providers."""
def __init__(self):
# Provider mapping
self._providers: Dict[SpeedTestProvider, Type[BaseSpeedTestProvider]] = {
SpeedTestProvider.REAL_DEBRID: RealDebridSpeedTest,
SpeedTestProvider.ALL_DEBRID: AllDebridSpeedTest,
}
def _get_provider(self, provider: SpeedTestProvider, api_key: Optional[str] = None) -> BaseSpeedTestProvider:
"""Get the appropriate provider implementation."""
provider_class = self._providers.get(provider)
if not provider_class:
raise ValueError(f"Unsupported provider: {provider}")
if provider == SpeedTestProvider.ALL_DEBRID and not api_key:
raise ValueError("API key required for AllDebrid")
return provider_class(api_key) if provider == SpeedTestProvider.ALL_DEBRID else provider_class()
async def create_test(
self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None
) -> SpeedTestTask:
"""Create a new speed test task."""
provider_impl = self._get_provider(provider, api_key)
# Get initial URLs and user info
urls, user_info = await provider_impl.get_test_urls()
task = SpeedTestTask(
task_id=task_id, provider=provider, started_at=datetime.now(tz=timezone.utc), user_info=user_info
)
await set_cache_speedtest(task_id, task)
return task
@staticmethod
async def get_test_results(task_id: str) -> Optional[SpeedTestTask]:
"""Get results for a specific task."""
return await get_cached_speedtest(task_id)
async def run_speedtest(self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None):
"""Run the speed test with real-time updates."""
try:
task = await get_cached_speedtest(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
provider_impl = self._get_provider(provider, api_key)
config = await provider_impl.get_config()
async with create_httpx_client() as client:
streamer = Streamer(client)
for location, url in config.test_urls.items():
try:
task.current_location = location
await set_cache_speedtest(task_id, task)
result = await self._test_location(location, url, streamer, config.test_duration, provider_impl)
task.results[location] = result
await set_cache_speedtest(task_id, task)
except Exception as e:
logger.error(f"Error testing {location}: {str(e)}")
task.results[location] = LocationResult(
error=str(e), server_name=location, server_url=config.test_urls[location]
)
await set_cache_speedtest(task_id, task)
# Mark task as completed
task.completed_at = datetime.now(tz=timezone.utc)
task.status = "completed"
task.current_location = None
await set_cache_speedtest(task_id, task)
except Exception as e:
logger.error(f"Error in speed test task {task_id}: {str(e)}")
if task := await get_cached_speedtest(task_id):
task.status = "failed"
await set_cache_speedtest(task_id, task)
async def _test_location(
self, location: str, url: str, streamer: Streamer, test_duration: int, provider: BaseSpeedTestProvider
) -> LocationResult:
"""Test speed for a specific location."""
try:
start_time = time.time()
total_bytes = 0
await streamer.create_streaming_response(url, headers={})
async for chunk in streamer.stream_content():
if time.time() - start_time >= test_duration:
break
total_bytes += len(chunk)
duration = time.time() - start_time
speed_mbps = (total_bytes * 8) / (duration * 1_000_000)
# Get server info if available (for AllDebrid)
server_info = getattr(provider, "servers", {}).get(location)
server_url = server_info.url if server_info else url
return LocationResult(
result=SpeedTestResult(
speed_mbps=round(speed_mbps, 2), duration=round(duration, 2), data_transferred=total_bytes
),
server_name=location,
server_url=server_url,
)
except Exception as e:
logger.error(f"Error testing {location}: {str(e)}")
raise # Re-raise to be handled by run_speedtest

View File

@@ -0,0 +1,76 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MediaFlow Proxy</title>
<link rel="icon" href="/logo.png" type="image/x-icon">
<style>
body {
font-family: Arial, sans-serif;
line-height: 1.6;
color: #333;
max-width: 800px;
margin: 0 auto;
padding: 20px;
background-color: #f9f9f9;
}
header {
background-color: #90aacc;
color: #fff;
padding: 10px 0;
text-align: center;
}
header img {
width: 200px;
height: 200px;
vertical-align: middle;
border-radius: 15px;
}
header h1 {
display: inline;
margin-left: 20px;
font-size: 36px;
}
.feature {
background-color: #f4f4f4;
border-left: 4px solid #3498db;
padding: 10px;
margin-bottom: 10px;
}
a {
color: #3498db;
}
</style>
</head>
<body>
<header>
<img src="/logo.png" alt="MediaFlow Proxy Logo">
<h1>MediaFlow Proxy</h1>
</header>
<p>A high-performance proxy server for streaming media, supporting HTTP(S), HLS, and MPEG-DASH with real-time DRM decryption.</p>
<h2>Key Features</h2>
<div class="feature">Convert MPEG-DASH streams (DRM-protected and non-protected) to HLS</div>
<div class="feature">Support for Clear Key DRM-protected MPD DASH streams</div>
<div class="feature">Handle both live and video-on-demand (VOD) DASH streams</div>
<div class="feature">Proxy HTTP/HTTPS links with custom headers</div>
<div class="feature">Proxy and modify HLS (M3U8) streams in real-time with custom headers and key URL modifications for bypassing some sneaky restrictions.</div>
<div class="feature">Protect against unauthorized access and network bandwidth abuses</div>
<h2>Getting Started</h2>
<p>Visit the <a href="https://github.com/mhdzumair/mediaflow-proxy">GitHub repository</a> for installation instructions and documentation.</p>
<h2>Premium Hosted Service</h2>
<p>For a hassle-free experience, check out <a href="https://store.elfhosted.com/product/mediaflow-proxy">premium hosted service on ElfHosted</a>.</p>
<h2>API Documentation</h2>
<p>Explore the <a href="/docs">Swagger UI</a> for comprehensive details about the API endpoints and their usage.</p>
</body>
</html>

Binary file not shown.

After

Width:  |  Height:  |  Size: 85 KiB

View File

@@ -0,0 +1,697 @@
<!DOCTYPE html>
<html lang="en" class="h-full">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Debrid Speed Test</title>
<script src="https://cdn.tailwindcss.com"></script>
<script>
tailwind.config = {
darkMode: 'class',
theme: {
extend: {
animation: {
'progress': 'progress 180s linear forwards',
},
keyframes: {
progress: {
'0%': {width: '0%'},
'100%': {width: '100%'}
}
}
}
}
}
</script>
<style>
.provider-card {
transition: all 0.3s ease;
}
.provider-card:hover {
transform: translateY(-5px);
}
@keyframes slideIn {
from {
transform: translateY(20px);
opacity: 0;
}
to {
transform: translateY(0);
opacity: 1;
}
}
.slide-in {
animation: slideIn 0.3s ease-out forwards;
}
</style>
</head>
<body class="bg-gray-100 dark:bg-gray-900 min-h-full">
<!-- Theme Toggle -->
<div class="fixed top-4 right-4 z-50">
<button id="themeToggle" class="p-2 rounded-full bg-gray-200 dark:bg-gray-700 hover:bg-gray-300 dark:hover:bg-gray-600 transition-colors">
<svg id="sunIcon" class="w-6 h-6 text-yellow-500 hidden dark:block" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
d="M12 3v1m0 16v1m9-9h-1M4 12H3m15.364 6.364l-.707-.707M6.343 6.343l-.707-.707m12.728 0l-.707.707M6.343 17.657l-.707.707M16 12a4 4 0 11-8 0 4 4 0 018 0z"/>
</svg>
<svg id="moonIcon" class="w-6 h-6 text-gray-700 block dark:hidden" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M20.354 15.354A9 9 0 018.646 3.646 9.003 9.003 0 0012 21a9.003 9.003 0 008.354-5.646z"/>
</svg>
</button>
</div>
<main class="container mx-auto px-4 py-8">
<!-- Views Container -->
<div id="views-container">
<!-- API Password View -->
<div id="passwordView" class="space-y-8">
<h1 class="text-3xl font-bold text-center text-gray-800 dark:text-white mb-8">
Enter API Password
</h1>
<div class="max-w-md mx-auto">
<form id="passwordForm" class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6 space-y-4">
<div class="space-y-2">
<label for="apiPassword" class="block text-sm font-medium text-gray-700 dark:text-gray-300">
API Password
</label>
<input
type="password"
id="apiPassword"
class="w-full px-4 py-2 rounded-md border border-gray-300 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-2 focus:ring-blue-500"
required
>
</div>
<div class="flex items-center space-x-2">
<input
type="checkbox"
id="rememberPassword"
class="rounded border-gray-300 dark:border-gray-600"
>
<label for="rememberPassword" class="text-sm text-gray-600 dark:text-gray-400">
Remember password
</label>
</div>
<button
type="submit"
class="w-full px-4 py-2 bg-blue-500 text-white rounded-md hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2 transition-colors"
>
Continue
</button>
</form>
</div>
</div>
<!-- Provider Selection View -->
<div id="selectionView" class="space-y-8 hidden">
<h1 class="text-3xl font-bold text-center text-gray-800 dark:text-white mb-8">
Select Debrid Service for Speed Test
</h1>
<div class="grid grid-cols-1 md:grid-cols-2 gap-6 max-w-4xl mx-auto">
<!-- Real-Debrid Card -->
<button onclick="startTest('real_debrid')" class="provider-card bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6 text-left hover:shadow-xl transition-shadow">
<h2 class="text-xl font-semibold text-gray-800 dark:text-white mb-2">Real-Debrid</h2>
<p class="text-gray-600 dark:text-gray-300">Test speeds across multiple Real-Debrid servers worldwide</p>
</button>
<!-- AllDebrid Card -->
<button onclick="showAllDebridSetup()" class="provider-card bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6 text-left hover:shadow-xl transition-shadow">
<h2 class="text-xl font-semibold text-gray-800 dark:text-white mb-2">AllDebrid</h2>
<p class="text-gray-600 dark:text-gray-300">Measure download speeds from AllDebrid servers</p>
</button>
</div>
</div>
<!-- AllDebrid Setup View -->
<div id="allDebridSetupView" class="max-w-md mx-auto space-y-6 hidden">
<h2 class="text-2xl font-bold text-center text-gray-800 dark:text-white mb-8">
AllDebrid Setup
</h2>
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6">
<form id="allDebridForm" class="space-y-4">
<div class="space-y-2">
<label for="adApiKey" class="block text-sm font-medium text-gray-700 dark:text-gray-300">
AllDebrid API Key
</label>
<input
type="password"
id="adApiKey"
class="w-full px-4 py-2 rounded-md border border-gray-300 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-2 focus:ring-blue-500"
required
>
<p class="text-sm text-gray-500 dark:text-gray-400">
You can find your API key in the AllDebrid dashboard
</p>
</div>
<div class="flex items-center space-x-2">
<input
type="checkbox"
id="rememberAdKey"
class="rounded border-gray-300 dark:border-gray-600"
>
<label for="rememberAdKey" class="text-sm text-gray-600 dark:text-gray-400">
Remember API key
</label>
</div>
<div class="flex space-x-3">
<button
type="button"
onclick="showView('selectionView')"
class="flex-1 px-4 py-2 border border-gray-300 dark:border-gray-600 text-gray-700 dark:text-gray-300 rounded-md hover:bg-gray-50 dark:hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 transition-colors"
>
Back
</button>
<button
type="submit"
class="flex-1 px-4 py-2 bg-blue-500 text-white rounded-md hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2 transition-colors"
>
Start Test
</button>
</div>
</form>
</div>
</div>
<!-- Testing View -->
<div id="testingView" class="max-w-4xl mx-auto space-y-6 hidden">
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6">
<!-- User Info Section -->
<div id="userInfo" class="mb-6 hidden">
<!-- User info will be populated dynamically -->
</div>
<!-- Progress Section -->
<div class="space-y-4">
<div class="text-center text-gray-600 dark:text-gray-300" id="currentLocation">
Initializing test...
</div>
<div class="h-2 bg-gray-200 dark:bg-gray-700 rounded-full overflow-hidden">
<div class="h-full bg-blue-500 animate-progress" id="progressBar"></div>
</div>
</div>
<!-- Results Container -->
<div id="resultsContainer" class="mt-8">
<!-- Results will be populated dynamically -->
</div>
</div>
</div>
<!-- Results View -->
<div id="resultsView" class="max-w-4xl mx-auto space-y-6 hidden">
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6">
<div class="space-y-6">
<!-- Summary Section -->
<div class="border-b border-gray-200 dark:border-gray-700 pb-4">
<h3 class="text-lg font-semibold text-gray-800 dark:text-white mb-4">Test Summary</h3>
<div class="grid grid-cols-2 md:grid-cols-3 gap-4">
<div class="space-y-1">
<div class="text-sm text-gray-500 dark:text-gray-400">Fastest Server</div>
<div id="fastestServer" class="font-medium text-gray-900 dark:text-white"></div>
</div>
<div class="space-y-1">
<div class="text-sm text-gray-500 dark:text-gray-400">Top Speed</div>
<div id="topSpeed" class="font-medium text-green-500"></div>
</div>
<div class="space-y-1">
<div class="text-sm text-gray-500 dark:text-gray-400">Average Speed</div>
<div id="avgSpeed" class="font-medium text-blue-500"></div>
</div>
</div>
</div>
<!-- Detailed Results -->
<div id="finalResults" class="space-y-4">
<!-- Results will be populated here -->
</div>
</div>
</div>
<div class="text-center mt-6">
<button onclick="resetTest()" class="px-6 py-2 bg-blue-500 text-white rounded-lg hover:bg-blue-600 transition-colors">
Test Another Provider
</button>
</div>
</div>
<!-- Error View -->
<div id="errorView" class="max-w-4xl mx-auto space-y-6 hidden">
<div class="bg-red-50 dark:bg-red-900/50 border-l-4 border-red-500 p-4 rounded">
<div class="flex">
<div class="flex-shrink-0">
<svg class="h-5 w-5 text-red-400" viewBox="0 0 20 20" fill="currentColor">
<path fill-rule="evenodd"
d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.707 7.293a1 1 0 00-1.414 1.414L8.586 10l-1.293 1.293a1 1 0 101.414 1.414L10 11.414l1.293 1.293a1 1 0 001.414-1.414L11.414 10l1.293-1.293a1 1 0 00-1.414-1.414L10 8.586 8.707 7.293z"
clip-rule="evenodd"/>
</svg>
</div>
<div class="ml-3">
<p class="text-sm text-red-700 dark:text-red-200" id="errorMessage"></p>
</div>
</div>
</div>
<div class="text-center">
<button onclick="resetTest()"
class="inline-flex items-center px-4 py-2 border border-transparent text-sm font-medium rounded-md shadow-sm text-white bg-blue-600 hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 dark:hover:bg-blue-500 transition-colors duration-200">
Try Again
</button>
</div>
</div>
</div>
</main>
<script>
// Config and State
const STATE = {
apiPassword: localStorage.getItem('speedtest_api_password'),
adApiKey: localStorage.getItem('ad_api_key'),
currentTaskId: null,
resultsCount: 0,
};
// Theme handling
function setTheme(theme) {
document.documentElement.classList.toggle('dark', theme === 'dark');
localStorage.theme = theme;
}
function initTheme() {
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
setTheme(localStorage.theme || (prefersDark ? 'dark' : 'light'));
}
// View management
function showView(viewId) {
document.querySelectorAll('#views-container > div').forEach(view => {
view.classList.toggle('hidden', view.id !== viewId);
});
}
function createErrorResult(location, data) {
return `
<div class="py-4">
<div class="flex justify-between items-center">
<div>
<span class="font-medium text-gray-800 dark:text-white">${location}</span>
<span class="ml-2 text-sm text-gray-500 dark:text-gray-400">${data.server_name || ''}</span>
</div>
<span class="text-sm text-red-500 dark:text-red-400">
Failed
</span>
</div>
<div class="mt-1 text-sm text-red-400 dark:text-red-300">
${data.error || 'Test failed'}
</div>
<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
Server: ${data.server_url}
</div>
</div>
`;
}
function formatBytes(bytes) {
const units = ['B', 'KB', 'MB', 'GB'];
let value = bytes;
let unitIndex = 0;
while (value >= 1024 && unitIndex < units.length - 1) {
value /= 1024;
unitIndex++;
}
return `${value.toFixed(2)} ${units[unitIndex]}`;
}
function handleAuthError() {
localStorage.removeItem('speedtest_api_password');
STATE.apiPassword = null;
showError('Authentication failed. Please check your API password.');
}
function showError(message) {
document.getElementById('errorMessage').textContent = message;
showView('errorView');
}
function resetTest() {
window.location.reload();
}
function showAllDebridSetup() {
showView('allDebridSetupView');
}
async function startTest(provider) {
if (provider === 'all_debrid' && !STATE.adApiKey) {
showAllDebridSetup();
return;
}
showView('testingView');
initializeResultsContainer();
try {
const params = new URLSearchParams({provider});
const headers = {'api_password': STATE.apiPassword};
if (provider === 'all_debrid' && STATE.adApiKey) {
headers['api_key'] = STATE.adApiKey;
}
const response = await fetch(`/speedtest/start?${params}`, {
method: 'POST',
headers
});
if (!response.ok) {
if (response.status === 403) {
handleAuthError();
return;
}
throw new Error('Failed to start speed test');
}
const {task_id} = await response.json();
STATE.currentTaskId = task_id;
await pollResults(task_id);
} catch (error) {
showError(error.message);
}
}
function initializeResultsContainer() {
const container = document.getElementById('resultsContainer');
container.innerHTML = `
<div class="space-y-4">
<div id="locationResults" class="divide-y divide-gray-200 dark:divide-gray-700">
<!-- Results will be populated here -->
</div>
<div id="summaryStats" class="hidden pt-4">
<!-- Summary stats will be populated here -->
</div>
</div>
`;
}
async function pollResults(taskId) {
let retryCount = 0;
const maxRetries = 10;
try {
while (true) {
const response = await fetch(`/speedtest/results/${taskId}`, {
headers: {'api_password': STATE.apiPassword}
});
if (!response.ok) {
if (response.status === 403) {
handleAuthError();
return;
}
if (retryCount < maxRetries) {
retryCount++
await new Promise(resolve => setTimeout(resolve, 2000));
continue;
}
throw new Error('Failed to fetch results after multiple attempts');
}
const data = await response.json();
retryCount = 0; //reset the retry count
if (data.status === 'failed') {
throw new Error('Speed test failed');
}
updateUI(data);
if (data.status === 'completed') {
showFinalResults(data);
break;
}
await new Promise(resolve => setTimeout(resolve, 2000));
}
} catch (error) {
showError(error.message);
}
}
function updateUI(data) {
if (data.user_info) {
updateUserInfo(data.user_info);
}
if (data.current_location) {
document.getElementById('currentLocation').textContent =
`Testing server ${data.current_location}...`;
}
updateResults(data.results);
}
function updateUserInfo(userInfo) {
const userInfoDiv = document.getElementById('userInfo');
userInfoDiv.innerHTML = `
<div class="grid grid-cols-2 md:grid-cols-4 gap-4 p-4 bg-gray-50 dark:bg-gray-700 rounded-lg">
<div class="space-y-1">
<div class="text-sm text-gray-500 dark:text-gray-400">IP Address</div>
<div class="font-medium text-gray-900 dark:text-white">${userInfo.ip}</div>
</div>
<div class="space-y-1">
<div class="text-sm text-gray-500 dark:text-gray-400">ISP</div>
<div class="font-medium text-gray-900 dark:text-white">${userInfo.isp}</div>
</div>
<div class="space-y-1">
<div class="text-sm text-gray-500 dark:text-gray-400">Country</div>
<div class="font-medium text-gray-900 dark:text-white">${userInfo.country?.toUpperCase()}</div>
</div>
</div>
`;
userInfoDiv.classList.remove('hidden');
}
function updateResults(results) {
const container = document.getElementById('resultsContainer');
const validResults = Object.entries(results)
.filter(([, data]) => data.result !== null && !data.error)
.sort(([, a], [, b]) => (b.result.speed_mbps) - (a.result.speed_mbps));
const failedResults = Object.entries(results)
.filter(([, data]) => data.error || data.result === null);
// Generate HTML for results
const resultsHTML = [
// Successful results
...validResults.map(([location, data]) => createSuccessResult(location, data)),
// Failed results
...failedResults.map(([location, data]) => createErrorResult(location, data))
].join('');
container.innerHTML = `
<div class="space-y-4">
<!-- Summary Stats -->
${createSummaryStats(validResults)}
<!-- Individual Results -->
<div class="mt-6 divide-y divide-gray-200 dark:divide-gray-700">
${resultsHTML}
</div>
</div>
`;
}
function createSummaryStats(validResults) {
if (validResults.length === 0) return '';
const speeds = validResults.map(([, data]) => data.result.speed_mbps);
const maxSpeed = Math.max(...speeds);
const avgSpeed = speeds.reduce((a, b) => a + b, 0) / speeds.length;
const fastestServer = validResults[0][0]; // First server after sorting
return `
<div class="grid grid-cols-1 md:grid-cols-3 gap-4 bg-gray-50 dark:bg-gray-800 p-4 rounded-lg">
<div class="text-center">
<div class="text-sm text-gray-500 dark:text-gray-400">Fastest Server</div>
<div class="font-medium text-gray-900 dark:text-white">${fastestServer}</div>
</div>
<div class="text-center">
<div class="text-sm text-gray-500 dark:text-gray-400">Top Speed</div>
<div class="font-medium text-green-500">${maxSpeed.toFixed(2)} Mbps</div>
</div>
<div class="text-center">
<div class="text-sm text-gray-500 dark:text-gray-400">Average Speed</div>
<div class="font-medium text-blue-500">${avgSpeed.toFixed(2)} Mbps</div>
</div>
</div>
`;
}
function createSuccessResult(location, data) {
const speedClass = getSpeedClass(data.result.speed_mbps);
return `
<div class="py-4">
<div class="flex justify-between items-center">
<div>
<span class="font-medium text-gray-800 dark:text-white">${location}</span>
<span class="ml-2 text-sm text-gray-500 dark:text-gray-400">${data.server_name || ''}</span>
</div>
<span class="text-lg font-semibold ${speedClass}">${data.result.speed_mbps.toFixed(2)} Mbps</span>
</div>
<div class="mt-1 text-sm text-gray-500 dark:text-gray-400">
Duration: ${data.result.duration.toFixed(2)}s •
Data: ${formatBytes(data.result.data_transferred)}
</div>
<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
Server: ${data.server_url}
</div>
</div>
`;
}
function getSpeedClass(speed) {
if (speed >= 10) return 'text-green-500 dark:text-green-400';
if (speed >= 5) return 'text-blue-500 dark:text-blue-400';
if (speed >= 2) return 'text-yellow-500 dark:text-yellow-400';
return 'text-red-500 dark:text-red-400';
}
function showFinalResults(data) {
// Stop the progress animation
document.querySelector('#progressBar').style.animation = 'none';
// Update the final results view
const validResults = Object.entries(data.results)
.filter(([, data]) => data.result !== null && !data.error)
.sort(([, a], [, b]) => (b.result.speed_mbps) - (a.result.speed_mbps));
const failedResults = Object.entries(data.results)
.filter(([, data]) => data.error || data.result === null);
// Update summary stats
if (validResults.length > 0) {
const speeds = validResults.map(([, data]) => data.result.speed_mbps);
const maxSpeed = Math.max(...speeds);
const avgSpeed = speeds.reduce((a, b) => a + b, 0) / speeds.length;
const fastestServer = validResults[0][0];
document.getElementById('fastestServer').textContent = fastestServer;
document.getElementById('topSpeed').textContent = `${maxSpeed.toFixed(2)} Mbps`;
document.getElementById('avgSpeed').textContent = `${avgSpeed.toFixed(2)} Mbps`;
}
// Generate detailed results HTML
const finalResultsHTML = `
${validResults.map(([location, data]) => `
<div class="bg-white dark:bg-gray-800 rounded-lg p-4 shadow-sm">
<div class="flex justify-between items-center">
<div>
<h3 class="text-lg font-medium text-gray-900 dark:text-white">${location}</h3>
<p class="text-sm text-gray-500 dark:text-gray-400">${data.server_name || ''}</p>
</div>
<div class="text-right">
<p class="text-2xl font-bold ${getSpeedClass(data.result.speed_mbps)}">
${data.result.speed_mbps.toFixed(2)} Mbps
</p>
<p class="text-sm text-gray-500 dark:text-gray-400">
${data.result.duration.toFixed(2)}s • ${formatBytes(data.result.data_transferred)}
</p>
</div>
</div>
<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
${data.server_url}
</div>
</div>
`).join('')}
${failedResults.length > 0 ? `
<div class="mt-6">
<h3 class="text-lg font-medium text-gray-900 dark:text-white mb-4">Failed Tests</h3>
${failedResults.map(([location, data]) => `
<div class="bg-red-50 dark:bg-red-900/20 rounded-lg p-4 mb-4">
<div class="flex justify-between items-center">
<div>
<h4 class="font-medium text-red-800 dark:text-red-200">
${location} ${data.server_name ? `(${data.server_name})` : ''}
</h4>
<p class="text-sm text-red-700 dark:text-red-300">
${data.error || 'Test failed'}
</p>
<p class="text-xs text-red-600 dark:text-red-400 mt-1">
${data.server_url}
</p>
</div>
</div>
</div>
`).join('')}
</div>
` : ''}
`;
document.getElementById('finalResults').innerHTML = finalResultsHTML;
// If we have user info from AllDebrid, copy it to the final view
const userInfoDiv = document.getElementById('userInfo');
if (!userInfoDiv.classList.contains('hidden') && data.user_info) {
const userInfoContent = userInfoDiv.innerHTML;
document.getElementById('finalResults').insertAdjacentHTML('afterbegin', `
<div class="mb-6">
${userInfoContent}
</div>
`);
}
// Show the final results view
showView('resultsView');
}
function initializeView() {
initTheme();
showView(STATE.apiPassword ? 'selectionView' : 'passwordView');
}
function initializeFormHandlers() {
// Password form handler
document.getElementById('passwordForm').addEventListener('submit', (e) => {
e.preventDefault();
const password = document.getElementById('apiPassword').value;
const remember = document.getElementById('rememberPassword').checked;
if (remember) {
localStorage.setItem('speedtest_api_password', password);
}
STATE.apiPassword = password;
showView('selectionView');
});
// AllDebrid form handler
document.getElementById('allDebridForm').addEventListener('submit', async (e) => {
e.preventDefault();
const apiKey = document.getElementById('adApiKey').value;
const remember = document.getElementById('rememberAdKey').checked;
if (remember) {
localStorage.setItem('ad_api_key', apiKey);
}
STATE.adApiKey = apiKey;
await startTest('all_debrid');
});
}
document.addEventListener('DOMContentLoaded', () => {
initializeView();
initializeFormHandlers();
});
// Theme Toggle Event Listener
document.getElementById('themeToggle').addEventListener('click', () => {
setTheme(document.documentElement.classList.contains('dark') ? 'light' : 'dark');
});
</script>
</body>
</html>

View File

View File

@@ -0,0 +1,376 @@
import asyncio
import hashlib
import json
import logging
import os
import tempfile
import threading
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, Any
import aiofiles
import aiofiles.os
from pydantic import ValidationError
from mediaflow_proxy.speedtest.models import SpeedTestTask
from mediaflow_proxy.utils.http_utils import download_file_with_retry, DownloadError
from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
"""Represents a cache entry with metadata."""
data: bytes
expires_at: float
access_count: int = 0
last_access: float = 0.0
size: int = 0
class LRUMemoryCache:
"""Thread-safe LRU memory cache with support."""
def __init__(self, maxsize: int):
self.maxsize = maxsize
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._lock = threading.Lock()
self._current_size = 0
def get(self, key: str) -> Optional[CacheEntry]:
with self._lock:
if key in self._cache:
entry = self._cache.pop(key) # Remove and re-insert for LRU
if time.time() < entry.expires_at:
entry.access_count += 1
entry.last_access = time.time()
self._cache[key] = entry
return entry
else:
# Remove expired entry
self._current_size -= entry.size
self._cache.pop(key, None)
return None
def set(self, key: str, entry: CacheEntry) -> None:
with self._lock:
if key in self._cache:
old_entry = self._cache[key]
self._current_size -= old_entry.size
# Check if we need to make space
while self._current_size + entry.size > self.maxsize and self._cache:
_, removed_entry = self._cache.popitem(last=False)
self._current_size -= removed_entry.size
self._cache[key] = entry
self._current_size += entry.size
def remove(self, key: str) -> None:
with self._lock:
if key in self._cache:
entry = self._cache.pop(key)
self._current_size -= entry.size
class HybridCache:
"""High-performance hybrid cache combining memory and file storage."""
def __init__(
self,
cache_dir_name: str,
ttl: int,
max_memory_size: int = 100 * 1024 * 1024, # 100MB default
executor_workers: int = 4,
):
self.cache_dir = Path(tempfile.gettempdir()) / cache_dir_name
self.ttl = ttl
self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)
self._executor = ThreadPoolExecutor(max_workers=executor_workers)
self._lock = asyncio.Lock()
# Initialize cache directories
self._init_cache_dirs()
def _init_cache_dirs(self):
"""Initialize sharded cache directories."""
os.makedirs(self.cache_dir, exist_ok=True)
def _get_md5_hash(self, key: str) -> str:
"""Get the MD5 hash of a cache key."""
return hashlib.md5(key.encode()).hexdigest()
def _get_file_path(self, key: str) -> Path:
"""Get the file path for a cache key."""
return self.cache_dir / key
async def get(self, key: str, default: Any = None) -> Optional[bytes]:
"""
Get value from cache, trying memory first then file.
Args:
key: Cache key
default: Default value if key not found
Returns:
Cached value or default if not found
"""
key = self._get_md5_hash(key)
# Try memory cache first
entry = self.memory_cache.get(key)
if entry is not None:
return entry.data
# Try file cache
try:
file_path = self._get_file_path(key)
async with aiofiles.open(file_path, "rb") as f:
metadata_size = await f.read(8)
metadata_length = int.from_bytes(metadata_size, "big")
metadata_bytes = await f.read(metadata_length)
metadata = json.loads(metadata_bytes.decode())
# Check expiration
if metadata["expires_at"] < time.time():
await self.delete(key)
return default
# Read data
data = await f.read()
# Update memory cache in background
entry = CacheEntry(
data=data,
expires_at=metadata["expires_at"],
access_count=metadata["access_count"] + 1,
last_access=time.time(),
size=len(data),
)
self.memory_cache.set(key, entry)
return data
except FileNotFoundError:
return default
except Exception as e:
logger.error(f"Error reading from cache: {e}")
return default
async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
"""
Set value in both memory and file cache.
Args:
key: Cache key
data: Data to cache
ttl: Optional TTL override
Returns:
bool: Success status
"""
if not isinstance(data, (bytes, bytearray, memoryview)):
raise ValueError("Data must be bytes, bytearray, or memoryview")
expires_at = time.time() + (ttl or self.ttl)
# Create cache entry
entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))
key = self._get_md5_hash(key)
# Update memory cache
self.memory_cache.set(key, entry)
file_path = self._get_file_path(key)
temp_path = file_path.with_suffix(".tmp")
# Update file cache
try:
metadata = {"expires_at": expires_at, "access_count": 0, "last_access": time.time()}
metadata_bytes = json.dumps(metadata).encode()
metadata_size = len(metadata_bytes).to_bytes(8, "big")
async with aiofiles.open(temp_path, "wb") as f:
await f.write(metadata_size)
await f.write(metadata_bytes)
await f.write(data)
await aiofiles.os.rename(temp_path, file_path)
return True
except Exception as e:
logger.error(f"Error writing to cache: {e}")
try:
await aiofiles.os.remove(temp_path)
except:
pass
return False
async def delete(self, key: str) -> bool:
"""Delete item from both caches."""
self.memory_cache.remove(key)
try:
file_path = self._get_file_path(key)
await aiofiles.os.remove(file_path)
return True
except FileNotFoundError:
return True
except Exception as e:
logger.error(f"Error deleting from cache: {e}")
return False
class AsyncMemoryCache:
"""Async wrapper around LRUMemoryCache."""
def __init__(self, max_memory_size: int):
self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)
async def get(self, key: str, default: Any = None) -> Optional[bytes]:
"""Get value from cache."""
entry = self.memory_cache.get(key)
return entry.data if entry is not None else default
async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
"""Set value in cache."""
try:
expires_at = time.time() + (ttl or 3600) # Default 1 hour TTL if not specified
entry = CacheEntry(
data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data)
)
self.memory_cache.set(key, entry)
return True
except Exception as e:
logger.error(f"Error setting cache value: {e}")
return False
async def delete(self, key: str) -> bool:
"""Delete item from cache."""
try:
self.memory_cache.remove(key)
return True
except Exception as e:
logger.error(f"Error deleting from cache: {e}")
return False
# Create cache instances
INIT_SEGMENT_CACHE = HybridCache(
cache_dir_name="init_segment_cache",
ttl=3600, # 1 hour
max_memory_size=500 * 1024 * 1024, # 500MB for init segments
)
MPD_CACHE = AsyncMemoryCache(
max_memory_size=100 * 1024 * 1024, # 100MB for MPD files
)
SPEEDTEST_CACHE = HybridCache(
cache_dir_name="speedtest_cache",
ttl=3600, # 1 hour
max_memory_size=50 * 1024 * 1024,
)
EXTRACTOR_CACHE = HybridCache(
cache_dir_name="extractor_cache",
ttl=5 * 60, # 5 minutes
max_memory_size=50 * 1024 * 1024,
)
# Specific cache implementations
async def get_cached_init_segment(init_url: str, headers: dict) -> Optional[bytes]:
"""Get initialization segment from cache or download it."""
# Try cache first
cached_data = await INIT_SEGMENT_CACHE.get(init_url)
if cached_data is not None:
return cached_data
# Download if not cached
try:
init_content = await download_file_with_retry(init_url, headers)
if init_content:
await INIT_SEGMENT_CACHE.set(init_url, init_content)
return init_content
except Exception as e:
logger.error(f"Error downloading init segment: {e}")
return None
async def get_cached_mpd(
mpd_url: str,
headers: dict,
parse_drm: bool,
parse_segment_profile_id: Optional[str] = None,
) -> dict:
"""Get MPD from cache or download and parse it."""
# Try cache first
cached_data = await MPD_CACHE.get(mpd_url)
if cached_data is not None:
try:
mpd_dict = json.loads(cached_data)
return parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
except json.JSONDecodeError:
await MPD_CACHE.delete(mpd_url)
# Download and parse if not cached
try:
mpd_content = await download_file_with_retry(mpd_url, headers)
mpd_dict = parse_mpd(mpd_content)
parsed_dict = parse_mpd_dict(mpd_dict, mpd_url, parse_drm, parse_segment_profile_id)
# Cache the original MPD dict
await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode(), ttl=parsed_dict["minimumUpdatePeriod"])
return parsed_dict
except DownloadError as error:
logger.error(f"Error downloading MPD: {error}")
raise error
except Exception as error:
logger.exception(f"Error processing MPD: {error}")
raise error
async def get_cached_speedtest(task_id: str) -> Optional[SpeedTestTask]:
"""Get speed test results from cache."""
cached_data = await SPEEDTEST_CACHE.get(task_id)
if cached_data is not None:
try:
return SpeedTestTask.model_validate_json(cached_data.decode())
except ValidationError as e:
logger.error(f"Error parsing cached speed test data: {e}")
await SPEEDTEST_CACHE.delete(task_id)
return None
async def set_cache_speedtest(task_id: str, task: SpeedTestTask) -> bool:
"""Cache speed test results."""
try:
return await SPEEDTEST_CACHE.set(task_id, task.model_dump_json().encode())
except Exception as e:
logger.error(f"Error caching speed test data: {e}")
return False
async def get_cached_extractor_result(key: str) -> Optional[dict]:
"""Get extractor result from cache."""
cached_data = await EXTRACTOR_CACHE.get(key)
if cached_data is not None:
try:
return json.loads(cached_data)
except json.JSONDecodeError:
await EXTRACTOR_CACHE.delete(key)
return None
async def set_cache_extractor_result(key: str, result: dict) -> bool:
"""Cache extractor result."""
try:
return await EXTRACTOR_CACHE.set(key, json.dumps(result).encode())
except Exception as e:
logger.error(f"Error caching extractor result: {e}")
return False

View File

@@ -0,0 +1,110 @@
import base64
import json
import logging
import time
import traceback
from typing import Optional
from urllib.parse import urlencode
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from mediaflow_proxy.configs import settings
class EncryptionHandler:
def __init__(self, secret_key: str):
self.secret_key = secret_key.encode("utf-8").ljust(32)[:32]
def encrypt_data(self, data: dict, expiration: int = None, ip: str = None) -> str:
if expiration:
data["exp"] = int(time.time()) + expiration
if ip:
data["ip"] = ip
json_data = json.dumps(data).encode("utf-8")
iv = get_random_bytes(16)
cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
encrypted_data = cipher.encrypt(pad(json_data, AES.block_size))
return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8")
def decrypt_data(self, token: str, client_ip: str) -> dict:
try:
encrypted_data = base64.urlsafe_b64decode(token.encode("utf-8"))
iv = encrypted_data[:16]
cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
decrypted_data = unpad(cipher.decrypt(encrypted_data[16:]), AES.block_size)
data = json.loads(decrypted_data)
if "exp" in data:
if data["exp"] < time.time():
raise HTTPException(status_code=401, detail="Token has expired")
del data["exp"] # Remove expiration from the data
if "ip" in data:
if data["ip"] != client_ip:
raise HTTPException(status_code=403, detail="IP address mismatch")
del data["ip"] # Remove IP from the data
return data
except Exception as e:
raise HTTPException(status_code=401, detail="Invalid or expired token")
class EncryptionMiddleware(BaseHTTPMiddleware):
def __init__(self, app):
super().__init__(app)
self.encryption_handler = encryption_handler
async def dispatch(self, request: Request, call_next):
encrypted_token = request.query_params.get("token")
if encrypted_token and self.encryption_handler:
try:
client_ip = self.get_client_ip(request)
decrypted_data = self.encryption_handler.decrypt_data(encrypted_token, client_ip)
# Modify request query parameters with decrypted data
query_params = dict(request.query_params)
query_params.pop("token") # Remove the encrypted token from query params
query_params.update(decrypted_data) # Add decrypted data to query params
query_params["has_encrypted"] = True
# Create a new request scope with updated query parameters
new_query_string = urlencode(query_params)
request.scope["query_string"] = new_query_string.encode()
request._query_params = query_params
except HTTPException as e:
return JSONResponse(content={"error": str(e.detail)}, status_code=e.status_code)
try:
response = await call_next(request)
except Exception:
exc = traceback.format_exc(chain=False)
logging.error("An error occurred while processing the request, error: %s", exc)
return JSONResponse(
content={"error": "An error occurred while processing the request, check the server for logs"},
status_code=500,
)
return response
@staticmethod
def get_client_ip(request: Request) -> Optional[str]:
"""
Extract the client's real IP address from the request headers or fallback to the client host.
"""
x_forwarded_for = request.headers.get("X-Forwarded-For")
if x_forwarded_for:
# In some cases, this header can contain multiple IPs
# separated by commas.
# The first one is the original client's IP.
return x_forwarded_for.split(",")[0].strip()
# Fallback to X-Real-IP if X-Forwarded-For is not available
x_real_ip = request.headers.get("X-Real-IP")
if x_real_ip:
return x_real_ip
return request.client.host if request.client else "127.0.0.1"
encryption_handler = EncryptionHandler(settings.api_password) if settings.api_password else None

View File

@@ -0,0 +1,430 @@
import logging
import typing
from dataclasses import dataclass
from functools import partial
from urllib import parse
from urllib.parse import urlencode
import anyio
import httpx
import tenacity
from fastapi import Response
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.requests import Request
from starlette.types import Receive, Send, Scope
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tqdm.asyncio import tqdm as tqdm_asyncio
from mediaflow_proxy.configs import settings
from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler
logger = logging.getLogger(__name__)
class DownloadError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
super().__init__(message)
def create_httpx_client(follow_redirects: bool = True, timeout: float = 30.0, **kwargs) -> httpx.AsyncClient:
"""Creates an HTTPX client with configured proxy routing"""
mounts = settings.transport_config.get_mounts()
client = httpx.AsyncClient(mounts=mounts, follow_redirects=follow_redirects, timeout=timeout, **kwargs)
return client
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(DownloadError),
)
async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs):
"""
Fetches a URL with retry logic.
Args:
client (httpx.AsyncClient): The HTTP client to use for the request.
method (str): The HTTP method to use (e.g., GET, POST).
url (str): The URL to fetch.
headers (dict): The headers to include in the request.
follow_redirects (bool, optional): Whether to follow redirects. Defaults to True.
**kwargs: Additional arguments to pass to the request.
Returns:
httpx.Response: The HTTP response.
Raises:
DownloadError: If the request fails after retries.
"""
try:
response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs)
response.raise_for_status()
return response
except httpx.TimeoutException:
logger.warning(f"Timeout while downloading {url}")
raise DownloadError(409, f"Timeout while downloading {url}")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error {e.response.status_code} while downloading {url}")
if e.response.status_code == 404:
logger.error(f"Segment Resource not found: {url}")
raise e
raise DownloadError(e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}")
except Exception as e:
logger.error(f"Error downloading {url}: {e}")
raise
class Streamer:
def __init__(self, client):
"""
Initializes the Streamer with an HTTP client.
Args:
client (httpx.AsyncClient): The HTTP client to use for streaming.
"""
self.client = client
self.response = None
self.progress_bar = None
self.bytes_transferred = 0
self.start_byte = 0
self.end_byte = 0
self.total_size = 0
async def create_streaming_response(self, url: str, headers: dict):
"""
Creates and sends a streaming request.
Args:
url (str): The URL to stream from.
headers (dict): The headers to include in the request.
"""
request = self.client.build_request("GET", url, headers=headers)
self.response = await self.client.send(request, stream=True, follow_redirects=True)
self.response.raise_for_status()
async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
"""
Streams the content from the response.
"""
if not self.response:
raise RuntimeError("No response available for streaming")
try:
self.parse_content_range()
if settings.enable_streaming_progress:
with tqdm_asyncio(
total=self.total_size,
initial=self.start_byte,
unit="B",
unit_scale=True,
unit_divisor=1024,
desc="Streaming",
ncols=100,
mininterval=1,
) as self.progress_bar:
async for chunk in self.response.aiter_bytes():
yield chunk
chunk_size = len(chunk)
self.bytes_transferred += chunk_size
self.progress_bar.set_postfix_str(
f"📥 : {self.format_bytes(self.bytes_transferred)}", refresh=False
)
self.progress_bar.update(chunk_size)
else:
async for chunk in self.response.aiter_bytes():
yield chunk
self.bytes_transferred += len(chunk)
except httpx.TimeoutException:
logger.warning("Timeout while streaming")
raise DownloadError(409, "Timeout while streaming")
except GeneratorExit:
logger.info("Streaming session stopped by the user")
except Exception as e:
logger.error(f"Error streaming content: {e}")
raise
@staticmethod
def format_bytes(size) -> str:
power = 2**10
n = 0
units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB"}
while size > power:
size /= power
n += 1
return f"{size:.2f} {units[n]}"
def parse_content_range(self):
content_range = self.response.headers.get("Content-Range", "")
if content_range:
range_info = content_range.split()[-1]
self.start_byte, self.end_byte, self.total_size = map(int, range_info.replace("/", "-").split("-"))
else:
self.start_byte = 0
self.total_size = int(self.response.headers.get("Content-Length", 0))
self.end_byte = self.total_size - 1 if self.total_size > 0 else 0
async def get_text(self, url: str, headers: dict):
"""
Sends a GET request to a URL and returns the response text.
Args:
url (str): The URL to send the GET request to.
headers (dict): The headers to include in the request.
Returns:
str: The response text.
"""
try:
self.response = await fetch_with_retry(self.client, "GET", url, headers)
except tenacity.RetryError as e:
raise e.last_attempt.result()
return self.response.text
async def close(self):
"""
Closes the HTTP client and response.
"""
if self.response:
await self.response.aclose()
if self.progress_bar:
self.progress_bar.close()
await self.client.aclose()
async def download_file_with_retry(url: str, headers: dict):
"""
Downloads a file with retry logic.
Args:
url (str): The URL of the file to download.
headers (dict): The headers to include in the request.
Returns:
bytes: The downloaded file content.
Raises:
DownloadError: If the download fails after retries.
"""
async with create_httpx_client() as client:
try:
response = await fetch_with_retry(client, "GET", url, headers)
return response.content
except DownloadError as e:
logger.error(f"Failed to download file: {e}")
raise e
except tenacity.RetryError as e:
raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> httpx.Response:
"""
Sends an HTTP request with retry logic.
Args:
method (str): The HTTP method to use (e.g., GET, POST).
url (str): The URL to send the request to.
headers (dict): The headers to include in the request.
**kwargs: Additional arguments to pass to the request.
Returns:
httpx.Response: The HTTP response.
Raises:
DownloadError: If the request fails after retries.
"""
async with create_httpx_client() as client:
try:
response = await fetch_with_retry(client, method, url, headers, **kwargs)
return response
except DownloadError as e:
logger.error(f"Failed to download file: {e}")
raise
def encode_mediaflow_proxy_url(
mediaflow_proxy_url: str,
endpoint: typing.Optional[str] = None,
destination_url: typing.Optional[str] = None,
query_params: typing.Optional[dict] = None,
request_headers: typing.Optional[dict] = None,
response_headers: typing.Optional[dict] = None,
encryption_handler: EncryptionHandler = None,
expiration: int = None,
ip: str = None,
) -> str:
"""
Encodes & Encrypt (Optional) a MediaFlow proxy URL with query parameters and headers.
Args:
mediaflow_proxy_url (str): The base MediaFlow proxy URL.
endpoint (str, optional): The endpoint to append to the base URL. Defaults to None.
destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
query_params (dict, optional): Additional query parameters to include. Defaults to None.
request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
response_headers (dict, optional): Headers to include as query parameters. Defaults to None.
encryption_handler (EncryptionHandler, optional): The encryption handler to use. Defaults to None.
expiration (int, optional): The expiration time for the encrypted token. Defaults to None.
ip (str, optional): The public IP address to include in the query parameters. Defaults to None.
Returns:
str: The encoded MediaFlow proxy URL.
"""
query_params = query_params or {}
if destination_url is not None:
query_params["d"] = destination_url
# Add headers if provided
if request_headers:
query_params.update(
{key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
)
if response_headers:
query_params.update(
{key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()}
)
if encryption_handler:
encrypted_token = encryption_handler.encrypt_data(query_params, expiration, ip)
encoded_params = urlencode({"token": encrypted_token})
else:
encoded_params = urlencode(query_params)
# Construct the full URL
if endpoint is None:
return f"{mediaflow_proxy_url}?{encoded_params}"
base_url = parse.urljoin(mediaflow_proxy_url, endpoint)
return f"{base_url}?{encoded_params}"
def get_original_scheme(request: Request) -> str:
"""
Determines the original scheme (http or https) of the request.
Args:
request (Request): The incoming HTTP request.
Returns:
str: The original scheme ('http' or 'https')
"""
# Check the X-Forwarded-Proto header first
forwarded_proto = request.headers.get("X-Forwarded-Proto")
if forwarded_proto:
return forwarded_proto
# Check if the request is secure
if request.url.scheme == "https" or request.headers.get("X-Forwarded-Ssl") == "on":
return "https"
# Check for other common headers that might indicate HTTPS
if (
request.headers.get("X-Forwarded-Ssl") == "on"
or request.headers.get("X-Forwarded-Protocol") == "https"
or request.headers.get("X-Url-Scheme") == "https"
):
return "https"
# Default to http if no indicators of https are found
return "http"
@dataclass
class ProxyRequestHeaders:
request: dict
response: dict
def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
"""
Extracts proxy headers from the request query parameters.
Args:
request (Request): The incoming HTTP request.
Returns:
ProxyRequest: A named tuple containing the request headers and response headers.
"""
request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS}
request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")})
response_headers = {k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("r_")}
return ProxyRequestHeaders(request_headers, response_headers)
class EnhancedStreamingResponse(Response):
body_iterator: typing.AsyncIterable[typing.Any]
def __init__(
self,
content: typing.Union[typing.AsyncIterable[typing.Any], typing.Iterable[typing.Any]],
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
) -> None:
if isinstance(content, typing.AsyncIterable):
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content)
self.status_code = status_code
self.media_type = self.media_type if media_type is None else media_type
self.background = background
self.init_headers(headers)
@staticmethod
async def listen_for_disconnect(receive: Receive) -> None:
try:
while True:
message = await receive()
if message["type"] == "http.disconnect":
logger.debug("Client disconnected")
break
except Exception as e:
logger.error(f"Error in listen_for_disconnect: {str(e)}")
async def stream_response(self, send: Send) -> None:
try:
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
async for chunk in self.body_iterator:
if not isinstance(chunk, (bytes, memoryview)):
chunk = chunk.encode(self.charset)
try:
await send({"type": "http.response.body", "body": chunk, "more_body": True})
except (ConnectionResetError, anyio.BrokenResourceError):
logger.info("Client disconnected during streaming")
return
await send({"type": "http.response.body", "body": b"", "more_body": False})
except Exception as e:
logger.exception(f"Error in stream_response: {str(e)}")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
try:
await func()
except Exception as e:
if not isinstance(e, anyio.get_cancelled_exc_class()):
logger.exception("Error in streaming task")
raise
finally:
task_group.cancel_scope.cancel()
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
if self.background is not None:
await self.background()

View File

@@ -0,0 +1,87 @@
import re
from urllib import parse
from mediaflow_proxy.utils.crypto_utils import encryption_handler
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
class M3U8Processor:
def __init__(self, request, key_url: str = None):
"""
Initializes the M3U8Processor with the request and URL prefix.
Args:
request (Request): The incoming HTTP request.
key_url (HttpUrl, optional): The URL of the key server. Defaults to None.
"""
self.request = request
self.key_url = parse.urlparse(key_url) if key_url else None
self.mediaflow_proxy_url = str(
request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
)
async def process_m3u8(self, content: str, base_url: str) -> str:
"""
Processes the m3u8 content, proxying URLs and handling key lines.
Args:
content (str): The m3u8 content to process.
base_url (str): The base URL to resolve relative URLs.
Returns:
str: The processed m3u8 content.
"""
lines = content.splitlines()
processed_lines = []
for line in lines:
if "URI=" in line:
processed_lines.append(await self.process_key_line(line, base_url))
elif not line.startswith("#") and line.strip():
processed_lines.append(await self.proxy_url(line, base_url))
else:
processed_lines.append(line)
return "\n".join(processed_lines)
async def process_key_line(self, line: str, base_url: str) -> str:
"""
Processes a key line in the m3u8 content, proxying the URI.
Args:
line (str): The key line to process.
base_url (str): The base URL to resolve relative URLs.
Returns:
str: The processed key line.
"""
uri_match = re.search(r'URI="([^"]+)"', line)
if uri_match:
original_uri = uri_match.group(1)
uri = parse.urlparse(original_uri)
if self.key_url:
uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.netloc)
new_uri = await self.proxy_url(uri.geturl(), base_url)
line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"')
return line
async def proxy_url(self, url: str, base_url: str) -> str:
"""
Proxies a URL, encoding it with the MediaFlow proxy URL.
Args:
url (str): The URL to proxy.
base_url (str): The base URL to resolve relative URLs.
Returns:
str: The proxied URL.
"""
full_url = parse.urljoin(base_url, url)
query_params = dict(self.request.query_params)
has_encrypted = query_params.pop("has_encrypted", False)
return encode_mediaflow_proxy_url(
self.mediaflow_proxy_url,
"",
full_url,
query_params=dict(self.request.query_params),
encryption_handler=encryption_handler if has_encrypted else None,
)

View File

@@ -0,0 +1,555 @@
import logging
import math
import re
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Optional, Union
from urllib.parse import urljoin
import xmltodict
logger = logging.getLogger(__name__)
def parse_mpd(mpd_content: Union[str, bytes]) -> dict:
"""
Parses the MPD content into a dictionary.
Args:
mpd_content (Union[str, bytes]): The MPD content to parse.
Returns:
dict: The parsed MPD content as a dictionary.
"""
return xmltodict.parse(mpd_content)
def parse_mpd_dict(
mpd_dict: dict, mpd_url: str, parse_drm: bool = True, parse_segment_profile_id: Optional[str] = None
) -> dict:
"""
Parses the MPD dictionary and extracts relevant information.
Args:
mpd_dict (dict): The MPD content as a dictionary.
mpd_url (str): The URL of the MPD manifest.
parse_drm (bool, optional): Whether to parse DRM information. Defaults to True.
parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
Returns:
dict: The parsed MPD information including profiles and DRM info.
This function processes the MPD dictionary to extract profiles, DRM information, and other relevant data.
It handles both live and static MPD manifests.
"""
profiles = []
parsed_dict = {}
source = "/".join(mpd_url.split("/")[:-1])
is_live = mpd_dict["MPD"].get("@type", "static").lower() == "dynamic"
parsed_dict["isLive"] = is_live
media_presentation_duration = mpd_dict["MPD"].get("@mediaPresentationDuration")
# Parse additional MPD attributes for live streams
if is_live:
parsed_dict["minimumUpdatePeriod"] = parse_duration(mpd_dict["MPD"].get("@minimumUpdatePeriod", "PT0S"))
parsed_dict["timeShiftBufferDepth"] = parse_duration(mpd_dict["MPD"].get("@timeShiftBufferDepth", "PT2M"))
parsed_dict["availabilityStartTime"] = datetime.fromisoformat(
mpd_dict["MPD"]["@availabilityStartTime"].replace("Z", "+00:00")
)
parsed_dict["publishTime"] = datetime.fromisoformat(
mpd_dict["MPD"].get("@publishTime", "").replace("Z", "+00:00")
)
periods = mpd_dict["MPD"]["Period"]
periods = periods if isinstance(periods, list) else [periods]
for period in periods:
parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S"))
for adaptation in period["AdaptationSet"]:
representations = adaptation["Representation"]
representations = representations if isinstance(representations, list) else [representations]
for representation in representations:
profile = parse_representation(
parsed_dict,
representation,
adaptation,
source,
media_presentation_duration,
parse_segment_profile_id,
)
if profile:
profiles.append(profile)
parsed_dict["profiles"] = profiles
if parse_drm:
drm_info = extract_drm_info(periods, mpd_url)
else:
drm_info = {}
parsed_dict["drmInfo"] = drm_info
return parsed_dict
def pad_base64(encoded_key_id):
"""
Pads a base64 encoded key ID to make its length a multiple of 4.
Args:
encoded_key_id (str): The base64 encoded key ID.
Returns:
str: The padded base64 encoded key ID.
"""
return encoded_key_id + "=" * (4 - len(encoded_key_id) % 4)
def extract_drm_info(periods: List[Dict], mpd_url: str) -> Dict:
"""
Extracts DRM information from the MPD periods.
Args:
periods (List[Dict]): The list of periods in the MPD.
mpd_url (str): The URL of the MPD manifest.
Returns:
Dict: The extracted DRM information.
This function processes the ContentProtection elements in the MPD to extract DRM system information,
such as ClearKey, Widevine, and PlayReady.
"""
drm_info = {"isDrmProtected": False}
for period in periods:
adaptation_sets: Union[list[dict], dict] = period.get("AdaptationSet", [])
if not isinstance(adaptation_sets, list):
adaptation_sets = [adaptation_sets]
for adaptation_set in adaptation_sets:
# Check ContentProtection in AdaptationSet
process_content_protection(adaptation_set.get("ContentProtection", []), drm_info)
# Check ContentProtection inside each Representation
representations: Union[list[dict], dict] = adaptation_set.get("Representation", [])
if not isinstance(representations, list):
representations = [representations]
for representation in representations:
process_content_protection(representation.get("ContentProtection", []), drm_info)
# If we have a license acquisition URL, make sure it's absolute
if "laUrl" in drm_info and not drm_info["laUrl"].startswith(("http://", "https://")):
drm_info["laUrl"] = urljoin(mpd_url, drm_info["laUrl"])
return drm_info
def process_content_protection(content_protection: Union[list[dict], dict], drm_info: dict):
"""
Processes the ContentProtection elements to extract DRM information.
Args:
content_protection (Union[list[dict], dict]): The ContentProtection elements.
drm_info (dict): The dictionary to store DRM information.
This function updates the drm_info dictionary with DRM system information found in the ContentProtection elements.
"""
if not isinstance(content_protection, list):
content_protection = [content_protection]
for protection in content_protection:
drm_info["isDrmProtected"] = True
scheme_id_uri = protection.get("@schemeIdUri", "").lower()
if "clearkey" in scheme_id_uri:
drm_info["drmSystem"] = "clearkey"
if "clearkey:Laurl" in protection:
la_url = protection["clearkey:Laurl"].get("#text")
if la_url and "laUrl" not in drm_info:
drm_info["laUrl"] = la_url
elif "widevine" in scheme_id_uri or "edef8ba9-79d6-4ace-a3c8-27dcd51d21ed" in scheme_id_uri:
drm_info["drmSystem"] = "widevine"
pssh = protection.get("cenc:pssh", {}).get("#text")
if pssh:
drm_info["pssh"] = pssh
elif "playready" in scheme_id_uri or "9a04f079-9840-4286-ab92-e65be0885f95" in scheme_id_uri:
drm_info["drmSystem"] = "playready"
if "@cenc:default_KID" in protection:
key_id = protection["@cenc:default_KID"].replace("-", "")
if "keyId" not in drm_info:
drm_info["keyId"] = key_id
if "ms:laurl" in protection:
la_url = protection["ms:laurl"].get("@licenseUrl")
if la_url and "laUrl" not in drm_info:
drm_info["laUrl"] = la_url
return drm_info
def parse_representation(
parsed_dict: dict,
representation: dict,
adaptation: dict,
source: str,
media_presentation_duration: str,
parse_segment_profile_id: Optional[str],
) -> Optional[dict]:
"""
Parses a representation and extracts profile information.
Args:
parsed_dict (dict): The parsed MPD data.
representation (dict): The representation data.
adaptation (dict): The adaptation set data.
source (str): The source URL.
media_presentation_duration (str): The media presentation duration.
parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
Returns:
Optional[dict]: The parsed profile information or None if not applicable.
"""
mime_type = _get_key(adaptation, representation, "@mimeType") or (
"video/mp4" if "avc" in representation["@codecs"] else "audio/mp4"
)
if "video" not in mime_type and "audio" not in mime_type:
return None
profile = {
"id": representation.get("@id") or adaptation.get("@id"),
"mimeType": mime_type,
"lang": representation.get("@lang") or adaptation.get("@lang"),
"codecs": representation.get("@codecs") or adaptation.get("@codecs"),
"bandwidth": int(representation.get("@bandwidth") or adaptation.get("@bandwidth")),
"startWithSAP": (_get_key(adaptation, representation, "@startWithSAP") or "1") == "1",
"mediaPresentationDuration": media_presentation_duration,
}
if "audio" in profile["mimeType"]:
profile["audioSamplingRate"] = representation.get("@audioSamplingRate") or adaptation.get("@audioSamplingRate")
profile["channels"] = representation.get("AudioChannelConfiguration", {}).get("@value", "2")
else:
profile["width"] = int(representation["@width"])
profile["height"] = int(representation["@height"])
frame_rate = representation.get("@frameRate") or adaptation.get("@maxFrameRate") or "30000/1001"
frame_rate = frame_rate if "/" in frame_rate else f"{frame_rate}/1"
profile["frameRate"] = round(int(frame_rate.split("/")[0]) / int(frame_rate.split("/")[1]), 3)
profile["sar"] = representation.get("@sar", "1:1")
if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id:
return profile
item = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
if item:
profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
else:
profile["segments"] = parse_segment_base(representation, source)
return profile
def _get_key(adaptation: dict, representation: dict, key: str) -> Optional[str]:
"""
Retrieves a key from the representation or adaptation set.
Args:
adaptation (dict): The adaptation set data.
representation (dict): The representation data.
key (str): The key to retrieve.
Returns:
Optional[str]: The value of the key or None if not found.
"""
return representation.get(key, adaptation.get(key, None))
def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]:
"""
Parses a segment template and extracts segment information.
Args:
parsed_dict (dict): The parsed MPD data.
item (dict): The segment template data.
profile (dict): The profile information.
source (str): The source URL.
Returns:
List[Dict]: The list of parsed segments.
"""
segments = []
timescale = int(item.get("@timescale", 1))
# Initialization
if "@initialization" in item:
media = item["@initialization"]
media = media.replace("$RepresentationID$", profile["id"])
media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
if not media.startswith("http"):
media = f"{source}/{media}"
profile["initUrl"] = media
# Segments
if "SegmentTimeline" in item:
segments.extend(parse_segment_timeline(parsed_dict, item, profile, source, timescale))
elif "@duration" in item:
segments.extend(parse_segment_duration(parsed_dict, item, profile, source, timescale))
return segments
def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
"""
Parses a segment timeline and extracts segment information.
Args:
parsed_dict (dict): The parsed MPD data.
item (dict): The segment timeline data.
profile (dict): The profile information.
source (str): The source URL.
timescale (int): The timescale for the segments.
Returns:
List[Dict]: The list of parsed segments.
"""
timelines = item["SegmentTimeline"]["S"]
timelines = timelines if isinstance(timelines, list) else [timelines]
period_start = parsed_dict["availabilityStartTime"] + timedelta(seconds=parsed_dict.get("PeriodStart", 0))
presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
start_number = int(item.get("@startNumber", 1))
segments = [
create_segment_data(timeline, item, profile, source, timescale)
for timeline in preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
]
return segments
def preprocess_timeline(
timelines: List[Dict], start_number: int, period_start: datetime, presentation_time_offset: int, timescale: int
) -> List[Dict]:
"""
Preprocesses the segment timeline data.
Args:
timelines (List[Dict]): The list of timeline segments.
start_number (int): The starting segment number.
period_start (datetime): The start time of the period.
presentation_time_offset (int): The presentation time offset.
timescale (int): The timescale for the segments.
Returns:
List[Dict]: The list of preprocessed timeline segments.
"""
processed_data = []
current_time = 0
for timeline in timelines:
repeat = int(timeline.get("@r", 0))
duration = int(timeline["@d"])
start_time = int(timeline.get("@t", current_time))
for _ in range(repeat + 1):
segment_start_time = period_start + timedelta(seconds=(start_time - presentation_time_offset) / timescale)
segment_end_time = segment_start_time + timedelta(seconds=duration / timescale)
processed_data.append(
{
"number": start_number,
"start_time": segment_start_time,
"end_time": segment_end_time,
"duration": duration,
"time": start_time,
}
)
start_time += duration
start_number += 1
current_time = start_time
return processed_data
def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
"""
Parses segment duration and extracts segment information.
This is used for static or live MPD manifests.
Args:
parsed_dict (dict): The parsed MPD data.
item (dict): The segment duration data.
profile (dict): The profile information.
source (str): The source URL.
timescale (int): The timescale for the segments.
Returns:
List[Dict]: The list of parsed segments.
"""
duration = int(item["@duration"])
start_number = int(item.get("@startNumber", 1))
segment_duration_sec = duration / timescale
if parsed_dict["isLive"]:
segments = generate_live_segments(parsed_dict, segment_duration_sec, start_number)
else:
segments = generate_vod_segments(profile, duration, timescale, start_number)
return [create_segment_data(seg, item, profile, source, timescale) for seg in segments]
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
"""
Generates live segments based on the segment duration and start number.
This is used for live MPD manifests.
Args:
parsed_dict (dict): The parsed MPD data.
segment_duration_sec (float): The segment duration in seconds.
start_number (int): The starting segment number.
Returns:
List[Dict]: The list of generated live segments.
"""
time_shift_buffer_depth = timedelta(seconds=parsed_dict.get("timeShiftBufferDepth", 60))
segment_count = math.ceil(time_shift_buffer_depth.total_seconds() / segment_duration_sec)
current_time = datetime.now(tz=timezone.utc)
earliest_segment_number = max(
start_number
+ math.floor((current_time - parsed_dict["availabilityStartTime"]).total_seconds() / segment_duration_sec)
- segment_count,
start_number,
)
return [
{
"number": number,
"start_time": parsed_dict["availabilityStartTime"]
+ timedelta(seconds=(number - start_number) * segment_duration_sec),
"duration": segment_duration_sec,
}
for number in range(earliest_segment_number, earliest_segment_number + segment_count)
]
def generate_vod_segments(profile: dict, duration: int, timescale: int, start_number: int) -> List[Dict]:
"""
Generates VOD segments based on the segment duration and start number.
This is used for static MPD manifests.
Args:
profile (dict): The profile information.
duration (int): The segment duration.
timescale (int): The timescale for the segments.
start_number (int): The starting segment number.
Returns:
List[Dict]: The list of generated VOD segments.
"""
total_duration = profile.get("mediaPresentationDuration") or 0
if isinstance(total_duration, str):
total_duration = parse_duration(total_duration)
segment_count = math.ceil(total_duration * timescale / duration)
return [{"number": start_number + i, "duration": duration / timescale} for i in range(segment_count)]
def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: Optional[int] = None) -> Dict:
"""
Creates segment data based on the segment information. This includes the segment URL and metadata.
Args:
segment (Dict): The segment information.
item (dict): The segment template data.
profile (dict): The profile information.
source (str): The source URL.
timescale (int, optional): The timescale for the segments. Defaults to None.
Returns:
Dict: The created segment data.
"""
media_template = item["@media"]
media = media_template.replace("$RepresentationID$", profile["id"])
media = media.replace("$Number%04d$", f"{segment['number']:04d}")
media = media.replace("$Number$", str(segment["number"]))
media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
if "time" in segment and timescale is not None:
media = media.replace("$Time$", str(int(segment["time"] * timescale)))
if not media.startswith("http"):
media = f"{source}/{media}"
segment_data = {
"type": "segment",
"media": media,
"number": segment["number"],
}
if "start_time" in segment and "end_time" in segment:
segment_data.update(
{
"start_time": segment["start_time"],
"end_time": segment["end_time"],
"extinf": (segment["end_time"] - segment["start_time"]).total_seconds(),
"program_date_time": segment["start_time"].isoformat() + "Z",
}
)
elif "start_time" in segment and "duration" in segment:
duration = segment["duration"]
segment_data.update(
{
"start_time": segment["start_time"],
"end_time": segment["start_time"] + timedelta(seconds=duration),
"extinf": duration,
"program_date_time": segment["start_time"].isoformat() + "Z",
}
)
elif "duration" in segment:
segment_data["extinf"] = segment["duration"]
return segment_data
def parse_segment_base(representation: dict, source: str) -> List[Dict]:
"""
Parses segment base information and extracts segment data. This is used for single-segment representations.
Args:
representation (dict): The representation data.
source (str): The source URL.
Returns:
List[Dict]: The list of parsed segments.
"""
segment = representation["SegmentBase"]
start, end = map(int, segment["@indexRange"].split("-"))
if "Initialization" in segment:
start, _ = map(int, segment["Initialization"]["@range"].split("-"))
return [
{
"type": "segment",
"range": f"{start}-{end}",
"media": f"{source}/{representation['BaseURL']}",
}
]
def parse_duration(duration_str: str) -> float:
"""
Parses a duration ISO 8601 string into seconds.
Args:
duration_str (str): The duration string to parse.
Returns:
float: The parsed duration in seconds.
"""
pattern = re.compile(r"P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)D)?T?(?:(\d+)H)?(?:(\d+)M)?(?:(\d+(?:\.\d+)?)S)?")
match = pattern.match(duration_str)
if not match:
raise ValueError(f"Invalid duration format: {duration_str}")
years, months, days, hours, minutes, seconds = [float(g) if g else 0 for g in match.groups()]
return years * 365 * 24 * 3600 + months * 30 * 24 * 3600 + days * 24 * 3600 + hours * 3600 + minutes * 60 + seconds

19
requirements.txt Normal file
View File

@@ -0,0 +1,19 @@
bs4
dateparser
python-dotenv
fastapi
uvicorn
tzdata
lxml
curl_cffi
fake-headers
pydantic_settings
httpx
Crypto
pycryptodome
tenacity
xmltodict
starlette
cachetools
tqdm
aiofiles

18
run.py Normal file
View File

@@ -0,0 +1,18 @@
from fastapi import FastAPI
from mediaflow_proxy.main import app as mediaflow_app # Import mediaflow app
import httpx
import re
import string
# Initialize the main FastAPI application
main_app = FastAPI()
# Manually add only non-static routes from mediaflow_app
for route in mediaflow_app.routes:
if route.path != "/": # Exclude the static file path
main_app.router.routes.append(route)
# Run the main app
if __name__ == "__main__":
import uvicorn
uvicorn.run(main_app, host="0.0.0.0", port=8080)