mirror of
https://github.com/UrloMythus/UnHided.git
synced 2026-04-09 02:40:47 +00:00
Add files via upload
This commit is contained in:
0
mediaflow_proxy/__init__.py
Normal file
0
mediaflow_proxy/__init__.py
Normal file
68
mediaflow_proxy/configs.py
Normal file
68
mediaflow_proxy/configs.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class RouteConfig(BaseModel):
    """Configuration for a specific route"""

    # Whether requests matching this route go through a proxy at all.
    proxy: bool = True
    # Route-specific proxy URL; when None, get_mounts() falls back to the
    # global TransportConfig.proxy_url.
    proxy_url: Optional[str] = None
    # Whether to verify TLS certificates for this route.
    verify_ssl: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class TransportConfig(BaseSettings):
    """Main proxy configuration"""

    proxy_url: Optional[str] = Field(
        None, description="Primary proxy URL. Example: socks5://user:pass@proxy:1080 or http://proxy:8080"
    )
    all_proxy: bool = Field(False, description="Enable proxy for all routes by default")
    transport_routes: Dict[str, RouteConfig] = Field(
        default_factory=dict, description="Pattern-based route configuration"
    )

    def get_mounts(
        self, async_http: bool = True
    ) -> Dict[str, Optional[Union[httpx.HTTPTransport, httpx.AsyncHTTPTransport]]]:
        """
        Get a dictionary of httpx mount points to transport instances.

        Args:
            async_http: Build httpx.AsyncHTTPTransport instances when True,
                synchronous httpx.HTTPTransport instances otherwise.

        Returns:
            Mapping of URL patterns (httpx mount keys) to configured
            transport instances.
        """
        mounts = {}
        transport_cls = httpx.AsyncHTTPTransport if async_http else httpx.HTTPTransport

        # Configure specific routes
        for pattern, route in self.transport_routes.items():
            # NOTE(review): the proxy expression parses as
            #   (route.proxy_url or self.proxy_url) if route.proxy else None
            # because the conditional expression binds looser than `or`:
            # proxying is disabled entirely when route.proxy is False,
            # otherwise the route URL falls back to the global one.
            mounts[pattern] = transport_cls(
                verify=route.verify_ssl, proxy=route.proxy_url or self.proxy_url if route.proxy else None
            )

        # Set default proxy for all routes if enabled
        if self.all_proxy:
            mounts["all://"] = transport_cls(proxy=self.proxy_url)

        return mounts

    class Config:
        # Load values from a local .env file; ignore unknown env keys.
        env_file = ".env"
        extra = "ignore"
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Application-wide settings, loaded from the environment / .env file."""

    api_password: str | None = None  # The password for protecting the API endpoints.
    log_level: str = "INFO"  # The logging level to use.
    transport_config: TransportConfig = Field(default_factory=TransportConfig)  # Configuration for httpx transport.
    enable_streaming_progress: bool = False  # Whether to enable streaming progress tracking.

    user_agent: str = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"  # The user agent to use for HTTP requests.
    )

    class Config:
        # Load values from a local .env file; ignore unknown env keys.
        env_file = ".env"
        extra = "ignore"
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton settings instance shared by the rest of the app.
settings = Settings()
|
||||||
17
mediaflow_proxy/const.py
Normal file
17
mediaflow_proxy/const.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Response headers (lowercase) that the proxy is allowed to pass through —
# presumably filtered against upstream responses at the call sites; confirm.
SUPPORTED_RESPONSE_HEADERS = [
    "accept-ranges",
    "content-type",
    "content-length",
    "content-range",
    "connection",
    "transfer-encoding",
    "last-modified",
    "etag",
    "cache-control",
    "expires",
]

# Request headers (lowercase) passed through to upstream; both relate to
# HTTP byte-range requests.
SUPPORTED_REQUEST_HEADERS = [
    "range",
    "if-range",
]
|
||||||
11
mediaflow_proxy/drm/__init__.py
Normal file
11
mediaflow_proxy/drm/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
|
||||||
|
async def create_temp_file(suffix: str, content: bytes = None, prefix: str = None) -> tempfile.NamedTemporaryFile:
    """
    Create a named temporary file on disk that survives being closed.

    The file is created with ``delete=False``, optionally pre-filled with
    *content*, closed, and returned. A ``delete_file`` callable is attached
    to the returned object so callers can remove the file explicitly later.

    Args:
        suffix (str): Filename suffix (e.g. ".mp4").
        content (bytes): Optional bytes to write before closing. An empty
            bytes value is treated like None (nothing written).
        prefix (str): Optional filename prefix.

    Returns:
        The closed NamedTemporaryFile wrapper; use ``.name`` for the path
        and ``.delete_file()`` to unlink it.
    """
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, prefix=prefix)
    # Convenience hook: the caller owns the file's lifetime from here on.
    handle.delete_file = lambda: os.unlink(handle.name)
    if content:
        handle.write(content)
    handle.close()
    return handle
|
||||||
778
mediaflow_proxy/drm/decrypter.py
Normal file
778
mediaflow_proxy/drm/decrypter.py
Normal file
@@ -0,0 +1,778 @@
|
|||||||
|
import argparse
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
from collections import namedtuple
|
||||||
|
import array
|
||||||
|
|
||||||
|
CENCSampleAuxiliaryDataFormat = namedtuple("CENCSampleAuxiliaryDataFormat", ["is_encrypted", "iv", "sub_samples"])
|
||||||
|
|
||||||
|
|
||||||
|
class MP4Atom:
    """
    A single MP4 box ("atom"): a 4-byte type tag, the declared box size,
    and the payload that follows the header.
    """

    __slots__ = ("atom_type", "size", "data")

    def __init__(self, atom_type: bytes, size: int, data: Union[memoryview, bytearray]):
        """
        Build an atom from its already-parsed parts.

        Args:
            atom_type (bytes): Four-byte box type (e.g. ``b"moov"``).
            size (int): Declared box size including the 8-byte header.
            data (Union[memoryview, bytearray]): Payload bytes (header excluded).
        """
        self.atom_type = atom_type
        self.size = size
        self.data = data

    def __repr__(self):
        return "<MP4Atom type={}, size={}>".format(self.atom_type, self.size)

    def pack(self):
        """
        Serialize the atom back to bytes.

        Returns:
            bytes: 4-byte big-endian size, then the type tag, then the payload.
        """
        header = struct.pack(">I", self.size)
        return b"".join((header, self.atom_type, self.data))
|
||||||
|
|
||||||
|
|
||||||
|
class MP4Parser:
    """
    Sequential parser over MP4 box ("atom") data held in a memoryview.

    Maintains a read cursor (``position``) advanced by read_atom(); also
    offers cursor-independent helpers for listing and debug-printing boxes.
    """

    def __init__(self, data: memoryview):
        """
        Initializes an MP4Parser instance.

        Args:
            data (memoryview): The binary data of the MP4 file, or a parent
                atom's payload when parsing nested boxes.
        """
        self.data = data
        # Read cursor into ``data``; advanced past each box by read_atom().
        self.position = 0

    def read_atom(self) -> Optional[MP4Atom]:
        """
        Reads the next atom at the current cursor position.

        Returns:
            Optional[MP4Atom]: The next atom, or None when the remaining
            data is exhausted, truncated, or holds an impossible size.
        """
        pos = self.position
        # Need at least the 8-byte header (32-bit size + 4-byte type).
        if pos + 8 > len(self.data):
            return None

        size, atom_type = struct.unpack_from(">I4s", self.data, pos)
        pos += 8

        # size == 1 signals a 64-bit "largesize" field follows the header.
        if size == 1:
            if pos + 8 > len(self.data):
                return None
            size = struct.unpack_from(">Q", self.data, pos)[0]
            pos += 8
            # NOTE(review): for largesize atoms the header is 16 bytes, so
            # the payload should arguably be ``size - 16``; the ``size - 8``
            # used below would over-read by 8 — confirm with real 64-bit atoms.

        # Reject impossible sizes and payloads that run past the buffer.
        if size < 8 or pos + size - 8 > len(self.data):
            return None

        atom_data = self.data[pos : pos + size - 8]
        self.position = pos + size - 8
        return MP4Atom(atom_type, size, atom_data)

    def list_atoms(self) -> list[MP4Atom]:
        """
        Lists all top-level atoms in the data.

        The read cursor is saved and restored, so an in-progress
        read_atom() iteration is not disturbed.

        Returns:
            list[MP4Atom]: List of MP4Atom objects.
        """
        atoms = []
        original_position = self.position
        self.position = 0
        while self.position + 8 <= len(self.data):
            atom = self.read_atom()
            if not atom:
                break
            atoms.append(atom)
        self.position = original_position
        return atoms

    def _read_atom_at(self, pos: int, end: int) -> Optional[MP4Atom]:
        # Cursor-independent variant of read_atom(): parse the box starting
        # at ``pos``, bounded by ``end``, without touching self.position.
        if pos + 8 > end:
            return None

        size, atom_type = struct.unpack_from(">I4s", self.data, pos)
        pos += 8

        # 64-bit "largesize" extension, as in read_atom().
        if size == 1:
            if pos + 8 > end:
                return None
            size = struct.unpack_from(">Q", self.data, pos)[0]
            pos += 8

        if size < 8 or pos + size - 8 > end:
            return None

        atom_data = self.data[pos : pos + size - 8]
        return MP4Atom(atom_type, size, atom_data)

    def print_atoms_structure(self, indent: int = 0):
        """
        Prints the structure of all top-level atoms (debug helper).

        Args:
            indent (int): The indentation level for printing.
        """
        pos = 0
        end = len(self.data)
        while pos + 8 <= end:
            atom = self._read_atom_at(pos, end)
            if not atom:
                break
            self.print_single_atom_structure(atom, pos, indent)
            pos += atom.size

    def print_single_atom_structure(self, atom: MP4Atom, parent_position: int, indent: int):
        """
        Prints one atom and attempts to recursively print its children.

        Args:
            atom (MP4Atom): The atom to print.
            parent_position (int): Offset of the parent atom within self.data.
            indent (int): The indentation level for printing.
        """
        try:
            atom_type = atom.atom_type.decode("utf-8")
        except UnicodeDecodeError:
            # Non-UTF-8 box types fall back to a bytes repr.
            atom_type = repr(atom.atom_type)
        print(" " * indent + f"Type: {atom_type}, Size: {atom.size}")

        # Treat the payload as a container and try to parse nested boxes.
        # NOTE(review): child offsets are computed from parent_position,
        # which only lines up when ``atom`` itself starts at
        # parent_position — looks incorrect for deeper nesting; debug-only
        # code, but worth confirming.
        child_pos = 0
        child_end = len(atom.data)
        while child_pos + 8 <= child_end:
            child_atom = self._read_atom_at(parent_position + 8 + child_pos, parent_position + 8 + child_end)
            if not child_atom:
                break
            self.print_single_atom_structure(child_atom, parent_position, indent + 2)
            child_pos += child_atom.size
|
||||||
|
|
||||||
|
|
||||||
|
class MP4Decrypter:
|
||||||
|
"""
|
||||||
|
Class to handle the decryption of CENC encrypted MP4 segments.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys.
|
||||||
|
current_key (Optional[bytes]): Current decryption key.
|
||||||
|
trun_sample_sizes (array.array): Array of sample sizes from the 'trun' box.
|
||||||
|
current_sample_info (list): List of sample information from the 'senc' box.
|
||||||
|
encryption_overhead (int): Total size of encryption-related boxes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key_map: dict[bytes, bytes]):
|
||||||
|
"""
|
||||||
|
Initializes the MP4Decrypter with a key map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys.
|
||||||
|
"""
|
||||||
|
self.key_map = key_map
|
||||||
|
self.current_key = None
|
||||||
|
self.trun_sample_sizes = array.array("I")
|
||||||
|
self.current_sample_info = []
|
||||||
|
self.encryption_overhead = 0
|
||||||
|
|
||||||
|
def decrypt_segment(self, combined_segment: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
Decrypts a combined MP4 segment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
combined_segment (bytes): Combined initialization and media segment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: Decrypted segment content.
|
||||||
|
"""
|
||||||
|
data = memoryview(combined_segment)
|
||||||
|
parser = MP4Parser(data)
|
||||||
|
atoms = parser.list_atoms()
|
||||||
|
|
||||||
|
atom_process_order = [b"moov", b"moof", b"sidx", b"mdat"]
|
||||||
|
|
||||||
|
processed_atoms = {}
|
||||||
|
for atom_type in atom_process_order:
|
||||||
|
if atom := next((a for a in atoms if a.atom_type == atom_type), None):
|
||||||
|
processed_atoms[atom_type] = self._process_atom(atom_type, atom)
|
||||||
|
|
||||||
|
result = bytearray()
|
||||||
|
for atom in atoms:
|
||||||
|
if atom.atom_type in processed_atoms:
|
||||||
|
processed_atom = processed_atoms[atom.atom_type]
|
||||||
|
result.extend(processed_atom.pack())
|
||||||
|
else:
|
||||||
|
result.extend(atom.pack())
|
||||||
|
|
||||||
|
return bytes(result)
|
||||||
|
|
||||||
|
def _process_atom(self, atom_type: bytes, atom: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes an MP4 atom based on its type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
atom_type (bytes): Type of the atom.
|
||||||
|
atom (MP4Atom): The atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed atom.
|
||||||
|
"""
|
||||||
|
if atom_type == b"moov":
|
||||||
|
return self._process_moov(atom)
|
||||||
|
elif atom_type == b"moof":
|
||||||
|
return self._process_moof(atom)
|
||||||
|
elif atom_type == b"sidx":
|
||||||
|
return self._process_sidx(atom)
|
||||||
|
elif atom_type == b"mdat":
|
||||||
|
return self._decrypt_mdat(atom)
|
||||||
|
else:
|
||||||
|
return atom
|
||||||
|
|
||||||
|
def _process_moov(self, moov: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation.
|
||||||
|
This includes information about tracks, media data, and other movie-level metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
moov (MP4Atom): The 'moov' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'moov' atom with updated track information.
|
||||||
|
"""
|
||||||
|
parser = MP4Parser(moov.data)
|
||||||
|
new_moov_data = bytearray()
|
||||||
|
|
||||||
|
for atom in iter(parser.read_atom, None):
|
||||||
|
if atom.atom_type == b"trak":
|
||||||
|
new_trak = self._process_trak(atom)
|
||||||
|
new_moov_data.extend(new_trak.pack())
|
||||||
|
elif atom.atom_type != b"pssh":
|
||||||
|
# Skip PSSH boxes as they are not needed in the decrypted output
|
||||||
|
new_moov_data.extend(atom.pack())
|
||||||
|
|
||||||
|
return MP4Atom(b"moov", len(new_moov_data) + 8, new_moov_data)
|
||||||
|
|
||||||
|
def _process_moof(self, moof: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation.
|
||||||
|
This includes information about tracks, media data, and other movie-level metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
moov (MP4Atom): The 'moov' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'moov' atom with updated track information.
|
||||||
|
"""
|
||||||
|
parser = MP4Parser(moof.data)
|
||||||
|
new_moof_data = bytearray()
|
||||||
|
|
||||||
|
for atom in iter(parser.read_atom, None):
|
||||||
|
if atom.atom_type == b"traf":
|
||||||
|
new_traf = self._process_traf(atom)
|
||||||
|
new_moof_data.extend(new_traf.pack())
|
||||||
|
else:
|
||||||
|
new_moof_data.extend(atom.pack())
|
||||||
|
|
||||||
|
return MP4Atom(b"moof", len(new_moof_data) + 8, new_moof_data)
|
||||||
|
|
||||||
|
def _process_traf(self, traf: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'traf' (Track Fragment) atom, which contains information about a track fragment.
|
||||||
|
This includes sample information, sample encryption data, and other track-level metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
traf (MP4Atom): The 'traf' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'traf' atom with updated sample information.
|
||||||
|
"""
|
||||||
|
parser = MP4Parser(traf.data)
|
||||||
|
new_traf_data = bytearray()
|
||||||
|
tfhd = None
|
||||||
|
sample_count = 0
|
||||||
|
sample_info = []
|
||||||
|
|
||||||
|
atoms = parser.list_atoms()
|
||||||
|
|
||||||
|
# calculate encryption_overhead earlier to avoid dependency on trun
|
||||||
|
self.encryption_overhead = sum(a.size for a in atoms if a.atom_type in {b"senc", b"saiz", b"saio"})
|
||||||
|
|
||||||
|
for atom in atoms:
|
||||||
|
if atom.atom_type == b"tfhd":
|
||||||
|
tfhd = atom
|
||||||
|
new_traf_data.extend(atom.pack())
|
||||||
|
elif atom.atom_type == b"trun":
|
||||||
|
sample_count = self._process_trun(atom)
|
||||||
|
new_trun = self._modify_trun(atom)
|
||||||
|
new_traf_data.extend(new_trun.pack())
|
||||||
|
elif atom.atom_type == b"senc":
|
||||||
|
# Parse senc but don't include it in the new decrypted traf data and similarly don't include saiz and saio
|
||||||
|
sample_info = self._parse_senc(atom, sample_count)
|
||||||
|
elif atom.atom_type not in {b"saiz", b"saio"}:
|
||||||
|
new_traf_data.extend(atom.pack())
|
||||||
|
|
||||||
|
if tfhd:
|
||||||
|
tfhd_track_id = struct.unpack_from(">I", tfhd.data, 4)[0]
|
||||||
|
self.current_key = self._get_key_for_track(tfhd_track_id)
|
||||||
|
self.current_sample_info = sample_info
|
||||||
|
|
||||||
|
return MP4Atom(b"traf", len(new_traf_data) + 8, new_traf_data)
|
||||||
|
|
||||||
|
def _decrypt_mdat(self, mdat: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Decrypts the 'mdat' (Media Data) atom, which contains the actual media data (audio, video, etc.).
|
||||||
|
The decryption is performed using the current decryption key and sample information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mdat (MP4Atom): The 'mdat' atom to decrypt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Decrypted 'mdat' atom with decrypted media data.
|
||||||
|
"""
|
||||||
|
if not self.current_key or not self.current_sample_info:
|
||||||
|
return mdat # Return original mdat if we don't have decryption info
|
||||||
|
|
||||||
|
decrypted_samples = bytearray()
|
||||||
|
mdat_data = mdat.data
|
||||||
|
position = 0
|
||||||
|
|
||||||
|
for i, info in enumerate(self.current_sample_info):
|
||||||
|
if position >= len(mdat_data):
|
||||||
|
break # No more data to process
|
||||||
|
|
||||||
|
sample_size = self.trun_sample_sizes[i] if i < len(self.trun_sample_sizes) else len(mdat_data) - position
|
||||||
|
sample = mdat_data[position : position + sample_size]
|
||||||
|
position += sample_size
|
||||||
|
decrypted_sample = self._process_sample(sample, info, self.current_key)
|
||||||
|
decrypted_samples.extend(decrypted_sample)
|
||||||
|
|
||||||
|
return MP4Atom(b"mdat", len(decrypted_samples) + 8, decrypted_samples)
|
||||||
|
|
||||||
|
def _parse_senc(self, senc: MP4Atom, sample_count: int) -> list[CENCSampleAuxiliaryDataFormat]:
|
||||||
|
"""
|
||||||
|
Parses the 'senc' (Sample Encryption) atom, which contains encryption information for samples.
|
||||||
|
This includes initialization vectors (IVs) and sub-sample encryption data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
senc (MP4Atom): The 'senc' atom to parse.
|
||||||
|
sample_count (int): The number of samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[CENCSampleAuxiliaryDataFormat]: List of sample auxiliary data formats with encryption information.
|
||||||
|
"""
|
||||||
|
data = memoryview(senc.data)
|
||||||
|
version_flags = struct.unpack_from(">I", data, 0)[0]
|
||||||
|
version, flags = version_flags >> 24, version_flags & 0xFFFFFF
|
||||||
|
position = 4
|
||||||
|
|
||||||
|
if version == 0:
|
||||||
|
sample_count = struct.unpack_from(">I", data, position)[0]
|
||||||
|
position += 4
|
||||||
|
|
||||||
|
sample_info = []
|
||||||
|
for _ in range(sample_count):
|
||||||
|
if position + 8 > len(data):
|
||||||
|
break
|
||||||
|
|
||||||
|
iv = data[position : position + 8].tobytes()
|
||||||
|
position += 8
|
||||||
|
|
||||||
|
sub_samples = []
|
||||||
|
if flags & 0x000002 and position + 2 <= len(data): # Check if subsample information is present
|
||||||
|
subsample_count = struct.unpack_from(">H", data, position)[0]
|
||||||
|
position += 2
|
||||||
|
|
||||||
|
for _ in range(subsample_count):
|
||||||
|
if position + 6 <= len(data):
|
||||||
|
clear_bytes, encrypted_bytes = struct.unpack_from(">HI", data, position)
|
||||||
|
position += 6
|
||||||
|
sub_samples.append((clear_bytes, encrypted_bytes))
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
sample_info.append(CENCSampleAuxiliaryDataFormat(True, iv, sub_samples))
|
||||||
|
|
||||||
|
return sample_info
|
||||||
|
|
||||||
|
def _get_key_for_track(self, track_id: int) -> bytes:
|
||||||
|
"""
|
||||||
|
Retrieves the decryption key for a given track ID from the key map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
track_id (int): The track ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: The decryption key for the specified track ID.
|
||||||
|
"""
|
||||||
|
if len(self.key_map) == 1:
|
||||||
|
return next(iter(self.key_map.values()))
|
||||||
|
key = self.key_map.get(track_id.pack(4, "big"))
|
||||||
|
if not key:
|
||||||
|
raise ValueError(f"No key found for track ID {track_id}")
|
||||||
|
return key
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_sample(
|
||||||
|
sample: memoryview, sample_info: CENCSampleAuxiliaryDataFormat, key: bytes
|
||||||
|
) -> Union[memoryview, bytearray, bytes]:
|
||||||
|
"""
|
||||||
|
Processes and decrypts a sample using the provided sample information and decryption key.
|
||||||
|
This includes handling sub-sample encryption if present.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (memoryview): The sample data.
|
||||||
|
sample_info (CENCSampleAuxiliaryDataFormat): The sample auxiliary data format with encryption information.
|
||||||
|
key (bytes): The decryption key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[memoryview, bytearray, bytes]: The decrypted sample.
|
||||||
|
"""
|
||||||
|
if not sample_info.is_encrypted:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
# pad IV to 16 bytes
|
||||||
|
iv = sample_info.iv + b"\x00" * (16 - len(sample_info.iv))
|
||||||
|
cipher = AES.new(key, AES.MODE_CTR, initial_value=iv, nonce=b"")
|
||||||
|
|
||||||
|
if not sample_info.sub_samples:
|
||||||
|
# If there are no sub_samples, decrypt the entire sample
|
||||||
|
return cipher.decrypt(sample)
|
||||||
|
|
||||||
|
result = bytearray()
|
||||||
|
offset = 0
|
||||||
|
for clear_bytes, encrypted_bytes in sample_info.sub_samples:
|
||||||
|
result.extend(sample[offset : offset + clear_bytes])
|
||||||
|
offset += clear_bytes
|
||||||
|
result.extend(cipher.decrypt(sample[offset : offset + encrypted_bytes]))
|
||||||
|
offset += encrypted_bytes
|
||||||
|
|
||||||
|
# If there's any remaining data, treat it as encrypted
|
||||||
|
if offset < len(sample):
|
||||||
|
result.extend(cipher.decrypt(sample[offset:]))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _process_trun(self, trun: MP4Atom) -> int:
|
||||||
|
"""
|
||||||
|
Processes the 'trun' (Track Fragment Run) atom, which contains information about the samples in a track fragment.
|
||||||
|
This includes sample sizes, durations, flags, and composition time offsets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trun (MP4Atom): The 'trun' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The number of samples in the 'trun' atom.
|
||||||
|
"""
|
||||||
|
trun_flags, sample_count = struct.unpack_from(">II", trun.data, 0)
|
||||||
|
data_offset = 8
|
||||||
|
|
||||||
|
if trun_flags & 0x000001:
|
||||||
|
data_offset += 4
|
||||||
|
if trun_flags & 0x000004:
|
||||||
|
data_offset += 4
|
||||||
|
|
||||||
|
self.trun_sample_sizes = array.array("I")
|
||||||
|
|
||||||
|
for _ in range(sample_count):
|
||||||
|
if trun_flags & 0x000100: # sample-duration-present flag
|
||||||
|
data_offset += 4
|
||||||
|
if trun_flags & 0x000200: # sample-size-present flag
|
||||||
|
sample_size = struct.unpack_from(">I", trun.data, data_offset)[0]
|
||||||
|
self.trun_sample_sizes.append(sample_size)
|
||||||
|
data_offset += 4
|
||||||
|
else:
|
||||||
|
self.trun_sample_sizes.append(0) # Using 0 instead of None for uniformity in the array
|
||||||
|
if trun_flags & 0x000400: # sample-flags-present flag
|
||||||
|
data_offset += 4
|
||||||
|
if trun_flags & 0x000800: # sample-composition-time-offsets-present flag
|
||||||
|
data_offset += 4
|
||||||
|
|
||||||
|
return sample_count
|
||||||
|
|
||||||
|
def _modify_trun(self, trun: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Modifies the 'trun' (Track Fragment Run) atom to update the data offset.
|
||||||
|
This is necessary to account for the encryption overhead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trun (MP4Atom): The 'trun' atom to modify.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Modified 'trun' atom with updated data offset.
|
||||||
|
"""
|
||||||
|
trun_data = bytearray(trun.data)
|
||||||
|
current_flags = struct.unpack_from(">I", trun_data, 0)[0] & 0xFFFFFF
|
||||||
|
|
||||||
|
# If the data-offset-present flag is set, update the data offset to account for encryption overhead
|
||||||
|
if current_flags & 0x000001:
|
||||||
|
current_data_offset = struct.unpack_from(">i", trun_data, 8)[0]
|
||||||
|
struct.pack_into(">i", trun_data, 8, current_data_offset - self.encryption_overhead)
|
||||||
|
|
||||||
|
return MP4Atom(b"trun", len(trun_data) + 8, trun_data)
|
||||||
|
|
||||||
|
def _process_sidx(self, sidx: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'sidx' (Segment Index) atom, which contains indexing information for media segments.
|
||||||
|
This includes references to media segments and their durations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sidx (MP4Atom): The 'sidx' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'sidx' atom with updated segment references.
|
||||||
|
"""
|
||||||
|
sidx_data = bytearray(sidx.data)
|
||||||
|
|
||||||
|
current_size = struct.unpack_from(">I", sidx_data, 32)[0]
|
||||||
|
reference_type = current_size >> 31
|
||||||
|
current_referenced_size = current_size & 0x7FFFFFFF
|
||||||
|
|
||||||
|
# Remove encryption overhead from referenced size
|
||||||
|
new_referenced_size = current_referenced_size - self.encryption_overhead
|
||||||
|
new_size = (reference_type << 31) | new_referenced_size
|
||||||
|
struct.pack_into(">I", sidx_data, 32, new_size)
|
||||||
|
|
||||||
|
return MP4Atom(b"sidx", len(sidx_data) + 8, sidx_data)
|
||||||
|
|
||||||
|
def _process_trak(self, trak: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'trak' (Track) atom, which contains information about a single track in the movie.
|
||||||
|
This includes track header, media information, and other track-level metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trak (MP4Atom): The 'trak' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'trak' atom with updated track information.
|
||||||
|
"""
|
||||||
|
parser = MP4Parser(trak.data)
|
||||||
|
new_trak_data = bytearray()
|
||||||
|
|
||||||
|
for atom in iter(parser.read_atom, None):
|
||||||
|
if atom.atom_type == b"mdia":
|
||||||
|
new_mdia = self._process_mdia(atom)
|
||||||
|
new_trak_data.extend(new_mdia.pack())
|
||||||
|
else:
|
||||||
|
new_trak_data.extend(atom.pack())
|
||||||
|
|
||||||
|
return MP4Atom(b"trak", len(new_trak_data) + 8, new_trak_data)
|
||||||
|
|
||||||
|
def _process_mdia(self, mdia: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'mdia' (Media) atom, which contains media information for a track.
|
||||||
|
This includes media header, handler reference, and media information container.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mdia (MP4Atom): The 'mdia' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'mdia' atom with updated media information.
|
||||||
|
"""
|
||||||
|
parser = MP4Parser(mdia.data)
|
||||||
|
new_mdia_data = bytearray()
|
||||||
|
|
||||||
|
for atom in iter(parser.read_atom, None):
|
||||||
|
if atom.atom_type == b"minf":
|
||||||
|
new_minf = self._process_minf(atom)
|
||||||
|
new_mdia_data.extend(new_minf.pack())
|
||||||
|
else:
|
||||||
|
new_mdia_data.extend(atom.pack())
|
||||||
|
|
||||||
|
return MP4Atom(b"mdia", len(new_mdia_data) + 8, new_mdia_data)
|
||||||
|
|
||||||
|
def _process_minf(self, minf: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'minf' (Media Information) atom, which contains information about the media data in a track.
|
||||||
|
This includes data information, sample table, and other media-level metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
minf (MP4Atom): The 'minf' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'minf' atom with updated media information.
|
||||||
|
"""
|
||||||
|
parser = MP4Parser(minf.data)
|
||||||
|
new_minf_data = bytearray()
|
||||||
|
|
||||||
|
for atom in iter(parser.read_atom, None):
|
||||||
|
if atom.atom_type == b"stbl":
|
||||||
|
new_stbl = self._process_stbl(atom)
|
||||||
|
new_minf_data.extend(new_stbl.pack())
|
||||||
|
else:
|
||||||
|
new_minf_data.extend(atom.pack())
|
||||||
|
|
||||||
|
return MP4Atom(b"minf", len(new_minf_data) + 8, new_minf_data)
|
||||||
|
|
||||||
|
def _process_stbl(self, stbl: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'stbl' (Sample Table) atom, which contains information about the samples in a track.
|
||||||
|
This includes sample descriptions, sample sizes, sample times, and other sample-level metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stbl (MP4Atom): The 'stbl' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'stbl' atom with updated sample information.
|
||||||
|
"""
|
||||||
|
parser = MP4Parser(stbl.data)
|
||||||
|
new_stbl_data = bytearray()
|
||||||
|
|
||||||
|
for atom in iter(parser.read_atom, None):
|
||||||
|
if atom.atom_type == b"stsd":
|
||||||
|
new_stsd = self._process_stsd(atom)
|
||||||
|
new_stbl_data.extend(new_stsd.pack())
|
||||||
|
else:
|
||||||
|
new_stbl_data.extend(atom.pack())
|
||||||
|
|
||||||
|
return MP4Atom(b"stbl", len(new_stbl_data) + 8, new_stbl_data)
|
||||||
|
|
||||||
|
def _process_stsd(self, stsd: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes the 'stsd' (Sample Description) atom, which contains descriptions of the sample entries in a track.
|
||||||
|
This includes codec information, sample entry details, and other sample description metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stsd (MP4Atom): The 'stsd' atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed 'stsd' atom with updated sample descriptions.
|
||||||
|
"""
|
||||||
|
parser = MP4Parser(stsd.data)
|
||||||
|
entry_count = struct.unpack_from(">I", parser.data, 4)[0]
|
||||||
|
new_stsd_data = bytearray(stsd.data[:8])
|
||||||
|
|
||||||
|
parser.position = 8 # Move past version_flags and entry_count
|
||||||
|
|
||||||
|
for _ in range(entry_count):
|
||||||
|
sample_entry = parser.read_atom()
|
||||||
|
if not sample_entry:
|
||||||
|
break
|
||||||
|
|
||||||
|
processed_entry = self._process_sample_entry(sample_entry)
|
||||||
|
new_stsd_data.extend(processed_entry.pack())
|
||||||
|
|
||||||
|
return MP4Atom(b"stsd", len(new_stsd_data) + 8, new_stsd_data)
|
||||||
|
|
||||||
|
def _process_sample_entry(self, entry: MP4Atom) -> MP4Atom:
|
||||||
|
"""
|
||||||
|
Processes a sample entry atom, which contains information about a specific type of sample.
|
||||||
|
This includes codec-specific information and other sample entry details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry (MP4Atom): The sample entry atom to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MP4Atom: Processed sample entry atom with updated information.
|
||||||
|
"""
|
||||||
|
# Determine the size of fixed fields based on sample entry type
|
||||||
|
if entry.atom_type in {b"mp4a", b"enca"}:
|
||||||
|
fixed_size = 28 # 8 bytes for size, type and reserved, 20 bytes for fixed fields in Audio Sample Entry.
|
||||||
|
elif entry.atom_type in {b"mp4v", b"encv", b"avc1", b"hev1", b"hvc1"}:
|
||||||
|
fixed_size = 78 # 8 bytes for size, type and reserved, 70 bytes for fixed fields in Video Sample Entry.
|
||||||
|
else:
|
||||||
|
fixed_size = 16 # 8 bytes for size, type and reserved, 8 bytes for fixed fields in other Sample Entries.
|
||||||
|
|
||||||
|
new_entry_data = bytearray(entry.data[:fixed_size])
|
||||||
|
parser = MP4Parser(entry.data[fixed_size:])
|
||||||
|
codec_format = None
|
||||||
|
|
||||||
|
for atom in iter(parser.read_atom, None):
|
||||||
|
if atom.atom_type in {b"sinf", b"schi", b"tenc", b"schm"}:
|
||||||
|
if atom.atom_type == b"sinf":
|
||||||
|
codec_format = self._extract_codec_format(atom)
|
||||||
|
continue # Skip encryption-related atoms
|
||||||
|
new_entry_data.extend(atom.pack())
|
||||||
|
|
||||||
|
# Replace the atom type with the extracted codec format
|
||||||
|
new_type = codec_format if codec_format else entry.atom_type
|
||||||
|
return MP4Atom(new_type, len(new_entry_data) + 8, new_entry_data)
|
||||||
|
|
||||||
|
def _extract_codec_format(self, sinf: MP4Atom) -> Optional[bytes]:
    """
    Extracts the original codec format from a 'sinf' atom.

    The 'frma' child of the Protection Scheme Information box records the
    fourcc the content had before encryption was applied.

    Args:
        sinf (MP4Atom): The 'sinf' atom to extract from.

    Returns:
        Optional[bytes]: The codec format or None if not found.
    """
    child_parser = MP4Parser(sinf.data)
    while (child := child_parser.read_atom()) is not None:
        if child.atom_type == b"frma":
            return child.data
    return None
|
||||||
|
|
||||||
|
def decrypt_segment(init_segment: bytes, segment_content: bytes, key_id: str, key: str) -> bytes:
    """
    Decrypts a CENC encrypted MP4 segment.

    Args:
        init_segment (bytes): Initialization segment data.
        segment_content (bytes): Encrypted segment content.
        key_id (str): Key ID in hexadecimal format.
        key (str): Key in hexadecimal format.

    Returns:
        bytes: The decrypted segment data.
    """
    # The decrypter expects the init segment and media segment concatenated.
    decrypter = MP4Decrypter({bytes.fromhex(key_id): bytes.fromhex(key)})
    return decrypter.decrypt_segment(init_segment + segment_content)
|
||||||
|
|
||||||
|
def cli():
    """
    Command line interface for decrypting a CENC encrypted MP4 segment.

    Reads input/output paths and key material from the module-level ``args``
    namespace populated by the ``__main__`` argument parser, then writes the
    decrypted segment to the requested output file. Exits with status 1 on
    missing inputs or decryption failure.
    """
    init_segment = b""

    if args.init and args.segment:
        # Separate init + media segment files.
        with open(args.init, "rb") as init_file:
            init_segment = init_file.read()
        with open(args.segment, "rb") as media_file:
            segment_content = media_file.read()
    elif args.combined_segment:
        # Single file already containing both init and media data.
        with open(args.combined_segment, "rb") as combined_file:
            segment_content = combined_file.read()
    else:
        print("Usage: python mp4decrypt.py --help")
        sys.exit(1)

    try:
        decrypted_segment = decrypt_segment(init_segment, segment_content, args.key_id, args.key)
        print(f"Decrypted content size is {len(decrypted_segment)} bytes")
        with open(args.output, "wb") as out_file:
            out_file.write(decrypted_segment)
        print(f"Decrypted segment written to {args.output}")
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI entry point: parse decryption arguments, then hand off to cli(),
    # which reads the module-level `args` namespace populated here.
    arg_parser = argparse.ArgumentParser(description="Decrypts a MP4 init and media segment using CENC encryption.")
    # Input selection: either --init + --segment, or a single --combined_segment.
    arg_parser.add_argument("--init", help="Path to the init segment file", required=False)
    arg_parser.add_argument("--segment", help="Path to the media segment file", required=False)
    arg_parser.add_argument(
        "--combined_segment", help="Path to the combined init and media segment file", required=False
    )
    # CENC clearkey parameters and the output path are mandatory.
    arg_parser.add_argument("--key_id", help="Key ID in hexadecimal format", required=True)
    arg_parser.add_argument("--key", help="Key in hexadecimal format", required=True)
    arg_parser.add_argument("--output", help="Path to the output file", required=True)
    args = arg_parser.parse_args()
    cli()
0
mediaflow_proxy/extractors/__init__.py
Normal file
0
mediaflow_proxy/extractors/__init__.py
Normal file
50
mediaflow_proxy/extractors/base.py
Normal file
50
mediaflow_proxy/extractors/base.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from mediaflow_proxy.configs import settings
|
||||||
|
from mediaflow_proxy.utils.http_utils import create_httpx_client
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractorError(Exception):
    """Base exception for all extractors."""
||||||
|
|
||||||
|
class BaseExtractor(ABC):
    """Base class for all URL extractors.

    Holds the shared request headers (seeded with the configured user agent)
    and the default mediaflow endpoint used by concrete extractors.
    """

    def __init__(self, request_headers: dict):
        # Headers sent with every extractor request; caller-provided headers
        # override the defaults.
        self.base_headers = {
            "user-agent": settings.user_agent,
        }
        self.mediaflow_endpoint = "proxy_stream_endpoint"
        self.base_headers.update(request_headers)

    async def _make_request(
        self, url: str, method: str = "GET", headers: Optional[Dict] = None, **kwargs
    ) -> httpx.Response:
        """Make HTTP request with error handling.

        Args:
            url: Target URL.
            method: HTTP method (default "GET").
            headers: Per-request headers merged over the base headers.
            **kwargs: Extra arguments forwarded to ``httpx.AsyncClient.request``.

        Returns:
            httpx.Response: The successful (2xx) response.

        Raises:
            ExtractorError: On any HTTP or transport failure.
        """
        try:
            async with create_httpx_client() as client:
                # BUG FIX: the original aliased self.base_headers and then
                # mutated it with update(), leaking per-request headers (e.g.
                # "range") into every subsequent request. Merge into a copy.
                request_headers = {**self.base_headers, **(headers or {})}
                response = await client.request(
                    method,
                    url,
                    headers=request_headers,
                    **kwargs,
                )
                response.raise_for_status()
                return response
        except httpx.HTTPError as e:
            raise ExtractorError(f"HTTP request failed: {str(e)}") from e
        except Exception as e:
            raise ExtractorError(f"Request failed: {str(e)}") from e

    @abstractmethod
    async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
        """Extract final URL and required headers."""
        pass
|
||||||
39
mediaflow_proxy/extractors/doodstream.py
Normal file
39
mediaflow_proxy/extractors/doodstream.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
|
||||||
|
|
||||||
|
|
||||||
|
class DoodStreamExtractor(BaseExtractor):
    """DoodStream URL extractor."""

    def __init__(self, request_headers: dict):
        super().__init__(request_headers)
        self.base_url = "https://d000d.com"

    async def extract(self, url: str, **kwargs) -> Dict[str, str]:
        """Extract DoodStream URL."""
        page = await self._make_request(url)

        # Locate the pass_md5 path and the token/expiry query fragment.
        url_match = re.search(r"(\/pass_md5\/.*?)'.*(\?token=.*?expiry=)", page.text, re.DOTALL)
        if url_match is None:
            raise ExtractorError("Failed to extract URL pattern")

        referer = f"{self.base_url}/"
        pass_response = await self._make_request(
            f"{self.base_url}{url_match.group(1)}",
            headers={"range": "bytes=0-", "referer": referer},
        )

        # Final URL = pass_md5 body + filler digits + token fragment + timestamp.
        timestamp = str(int(time.time()))
        final_url = f"{pass_response.text}123456789{url_match.group(2)}{timestamp}"

        self.base_headers["referer"] = referer
        return {
            "destination_url": final_url,
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
|
||||||
32
mediaflow_proxy/extractors/factory.py
Normal file
32
mediaflow_proxy/extractors/factory.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from typing import Dict, Type
|
||||||
|
|
||||||
|
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
|
||||||
|
from mediaflow_proxy.extractors.doodstream import DoodStreamExtractor
|
||||||
|
from mediaflow_proxy.extractors.livetv import LiveTVExtractor
|
||||||
|
from mediaflow_proxy.extractors.mixdrop import MixdropExtractor
|
||||||
|
from mediaflow_proxy.extractors.uqload import UqloadExtractor
|
||||||
|
from mediaflow_proxy.extractors.streamtape import StreamtapeExtractor
|
||||||
|
from mediaflow_proxy.extractors.supervideo import SupervideoExtractor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractorFactory:
    """Factory for creating URL extractors."""

    # Registry mapping host identifiers to their extractor classes.
    _extractors: Dict[str, Type[BaseExtractor]] = {
        "Doodstream": DoodStreamExtractor,
        "Uqload": UqloadExtractor,
        "Mixdrop": MixdropExtractor,
        "Streamtape": StreamtapeExtractor,
        "Supervideo": SupervideoExtractor,
        "LiveTV": LiveTVExtractor,
    }

    @classmethod
    def get_extractor(cls, host: str, request_headers: dict) -> BaseExtractor:
        """Get appropriate extractor instance for the given host."""
        try:
            extractor_class = cls._extractors[host]
        except KeyError:
            raise ExtractorError(f"Unsupported host: {host}") from None
        return extractor_class(request_headers)
|
||||||
251
mediaflow_proxy/extractors/livetv.py
Normal file
251
mediaflow_proxy/extractors/livetv.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
import re
|
||||||
|
from typing import Dict, Tuple, Optional
|
||||||
|
from urllib.parse import urljoin, urlparse, unquote
|
||||||
|
|
||||||
|
from httpx import Response
|
||||||
|
|
||||||
|
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
|
||||||
|
|
||||||
|
|
||||||
|
class LiveTVExtractor(BaseExtractor):
    """LiveTV URL extractor for both M3U8 and MPD streams."""

    def __init__(self, request_headers: dict):
        super().__init__(request_headers)
        # Default to HLS proxy endpoint, will be updated based on stream type
        self.mediaflow_endpoint = "hls_manifest_proxy"

        # Patterns for stream URL extraction
        self.fallback_pattern = re.compile(
            r"source: [\'\"](.*?)[\'\"]\s*,\s*[\s\S]*?mimeType: [\'\"](application/x-mpegURL|application/vnd\.apple\.mpegURL|application/dash\+xml)[\'\"]",
            re.IGNORECASE,
        )
        self.any_m3u8_pattern = re.compile(
            r'["\']?(https?://.*?\.m3u8(?:\?[^"\']*)?)["\']?',
            re.IGNORECASE,
        )

    async def extract(self, url: str, stream_title: str = None, **kwargs) -> Dict[str, str]:
        """Extract LiveTV URL and required headers.

        Args:
            url: The channel page URL
            stream_title: Optional stream title to filter specific stream

        Returns:
            Dict[str, str]: destination URL, request headers and the target
            mediaflow endpoint (previous docstring incorrectly said Tuple).

        Raises:
            ExtractorError: If no playable stream could be resolved.
        """
        try:
            # Get the channel page
            page_response = await self._make_request(url)
            self.base_headers["referer"] = urljoin(url, "/")

            # Extract player API details
            player_api_base, method = await self._extract_player_api_base(page_response.text)
            if not player_api_base:
                raise ExtractorError("Failed to extract player API URL")

            # Get player options
            options_data = await self._get_player_options(page_response.text)
            if not options_data:
                raise ExtractorError("No player options found")

            # Process player options to find a matching stream
            for option in options_data:
                if stream_title and option.get("title") != stream_title:
                    continue

                # Get stream URL based on player option
                stream_data = await self._process_player_option(
                    player_api_base, method, option.get("post"), option.get("nume"), option.get("type")
                )
                if not stream_data:
                    continue

                stream_url = stream_data.get("url")
                if not stream_url:
                    continue

                result = {
                    "destination_url": stream_url,
                    "request_headers": self.base_headers,
                    "mediaflow_endpoint": self.mediaflow_endpoint,
                }

                # MPD streams with DRM keys go through the MPD proxy endpoint
                # and carry the clearkey pair as query parameters.
                if stream_data.get("type") == "mpd":
                    if stream_data.get("drm_key_id") and stream_data.get("drm_key"):
                        result.update(
                            {
                                "query_params": {
                                    "key_id": stream_data["drm_key_id"],
                                    "key": stream_data["drm_key"],
                                },
                                "mediaflow_endpoint": "mpd_manifest_proxy",
                            }
                        )

                return result

            raise ExtractorError("No valid stream found")

        except Exception as e:
            raise ExtractorError(f"Extraction failed: {str(e)}")

    async def _extract_player_api_base(self, html_content: str) -> Tuple[Optional[str], Optional[str]]:
        """Extract player API base URL and play method from the page config JSON."""
        admin_ajax_pattern = r'"player_api"\s*:\s*"([^"]+)".*?"play_method"\s*:\s*"([^"]+)"'
        match = re.search(admin_ajax_pattern, html_content)
        if not match:
            return None, None
        url = match.group(1).replace("\\/", "/")
        method = match.group(2)
        if method == "wp_json":
            return url, method
        # Non-REST installs post to the classic admin-ajax endpoint instead.
        url = urljoin(url, "/wp-admin/admin-ajax.php")
        return url, method

    async def _get_player_options(self, html_content: str) -> list:
        """Extract player options (type/post/nume/title) from HTML content."""
        pattern = r'<li[^>]*class=["\']dooplay_player_option["\'][^>]*data-type=["\']([^"\']*)["\'][^>]*data-post=["\']([^"\']*)["\'][^>]*data-nume=["\']([^"\']*)["\'][^>]*>.*?<span class=["\']title["\']>([^<]*)</span>'
        matches = re.finditer(pattern, html_content, re.DOTALL)
        return [
            {"type": match.group(1), "post": match.group(2), "nume": match.group(3), "title": match.group(4).strip()}
            for match in matches
        ]

    async def _process_player_option(self, api_base: str, method: str, post: str, nume: str, type_: str) -> Dict:
        """Process a player option to get the stream URL.

        Raises:
            ExtractorError: If the API response or iframe cannot be processed.
        """
        if method == "wp_json":
            api_url = f"{api_base}{post}/{type_}/{nume}"
            response = await self._make_request(api_url)
        else:
            form_data = {"action": "doo_player_ajax", "post": post, "nume": nume, "type": type_}
            response = await self._make_request(api_base, method="POST", data=form_data)

        # Get iframe URL from API response
        try:
            data = response.json()
            iframe_url = urljoin(api_base, data.get("embed_url", "").replace("\\/", "/"))

            # Get stream URL from iframe
            iframe_response = await self._make_request(iframe_url)
            return await self._extract_stream_url(iframe_response, iframe_url)

        except Exception as e:
            raise ExtractorError(f"Failed to process player option: {str(e)}")

    async def _extract_stream_url(self, iframe_response: Response, iframe_url: str) -> Dict:
        """
        Extract the final stream URL (and DRM keys, if any) from iframe content.

        Raises:
            ExtractorError: If no valid stream URL can be found.
        """
        try:
            # Parse URL components
            parsed_url = urlparse(iframe_url)
            query_params = dict(param.split("=") for param in parsed_url.query.split("&") if "=" in param)

            # Check if content is already a direct M3U8 stream
            content_types = ["application/x-mpegurl", "application/vnd.apple.mpegurl"]
            if any(ext in iframe_response.headers["content-type"] for ext in content_types):
                return {"url": iframe_url, "type": "m3u8"}

            stream_data = {}

            # Check for source parameter in URL
            if "source" in query_params:
                stream_data = {
                    "url": urljoin(iframe_url, unquote(query_params["source"])),
                    "type": "m3u8",
                }

            # Check for MPD stream with DRM: "url``key_id:key" packed in "zy"
            elif "zy" in query_params and ".mpd``" in query_params["zy"]:
                data = query_params["zy"].split("``")
                url = data[0]
                key_id, key = data[1].split(":")
                stream_data = {"url": url, "type": "mpd", "drm_key_id": key_id, "drm_key": key}

            # Check for tamilultra specific format
            elif "tamilultra" in iframe_url:
                stream_data = {"url": urljoin(iframe_url, parsed_url.query), "type": "m3u8"}

            # Try pattern matching for stream URLs
            else:
                # BUG FIX: the default used to be the truthy list [""], which
                # passed a list into re.escape() whenever "id" was absent;
                # default to an empty string so the channel-ID branch is skipped.
                channel_id = query_params.get("id", "")
                stream_url = None

                html_content = iframe_response.text

                if channel_id:
                    # Try channel ID specific pattern
                    pattern = rf'{re.escape(channel_id)}["\']:\s*{{\s*["\']?url["\']?\s*:\s*["\']([^"\']+)["\']'
                    match = re.search(pattern, html_content)
                    if match:
                        stream_url = match.group(1)

                # Try fallback patterns if channel ID pattern fails
                if not stream_url:
                    for pattern in [self.fallback_pattern, self.any_m3u8_pattern]:
                        match = pattern.search(html_content)
                        if match:
                            stream_url = match.group(1)
                            break

                if stream_url:
                    stream_data = {"url": stream_url, "type": "m3u8"}  # provisional; fixed up below

                    # Check for MPD stream and extract DRM keys
                    if stream_url.endswith(".mpd"):
                        stream_data["type"] = "mpd"
                        drm_data = await self._extract_drm_keys(html_content, channel_id)
                        if drm_data:
                            stream_data.update(drm_data)

            # If no stream data found, raise error
            if not stream_data:
                raise ExtractorError("No valid stream URL found")

            # Normalize the type for branches that defaulted to m3u8 but
            # actually point at an MPD manifest.
            if stream_data.get("type") == "m3u8" and stream_data["url"].endswith(".mpd"):
                stream_data["type"] = "mpd"

            return stream_data

        except Exception as e:
            raise ExtractorError(f"Failed to extract stream URL: {str(e)}")

    async def _extract_drm_keys(self, html_content: str, channel_id: str) -> Dict:
        """
        Extract DRM clearkey id/key for MPD streams from the channel config blob.

        Returns:
            Dict: {"drm_key_id": ..., "drm_key": ...} or {} when not found.
        """
        try:
            # Pattern for the channel's config entry
            channel_pattern = rf'"{re.escape(channel_id)}":\s*{{[^}}]+}}'
            channel_match = re.search(channel_pattern, html_content)
            if not channel_match:
                return {}

            channel_data = channel_match.group(0)

            # Try clearkeys pattern first
            clearkey_pattern = r'["\']?clearkeys["\']?\s*:\s*{\s*["\'](.+?)["\']:\s*["\'](.+?)["\']'
            clearkey_match = re.search(clearkey_pattern, channel_data)
            if clearkey_match:
                return {"drm_key_id": clearkey_match.group(1), "drm_key": clearkey_match.group(2)}

            # Fall back to the k1/k2 naming scheme. (The original code reached
            # "return {}" only via an AttributeError on a None match; this is
            # the same result without exception-driven control flow.)
            k1k2_pattern = r'["\']?k1["\']?\s*:\s*["\'](.+?)["\'],\s*["\']?k2["\']?\s*:\s*["\'](.+?)["\']'
            k1k2_match = re.search(k1k2_pattern, channel_data)
            if k1k2_match:
                return {"drm_key_id": k1k2_match.group(1), "drm_key": k1k2_match.group(2)}

            return {}

        except Exception:
            return {}
|
||||||
36
mediaflow_proxy/extractors/mixdrop.py
Normal file
36
mediaflow_proxy/extractors/mixdrop.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import re
|
||||||
|
import string
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
|
||||||
|
|
||||||
|
|
||||||
|
class MixdropExtractor(BaseExtractor):
    """Mixdrop URL extractor."""

    async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
        """Extract Mixdrop URL."""
        page = await self._make_request(url, headers={"accept-language": "en-US,en;q=0.5"})

        # The player URL is hidden in a packed (p,a,c,k,e,d)-style script blob.
        packed = re.search(r"}\('(.+)',.+,'(.+)'\.split", page.text)
        if packed is None:
            raise ExtractorError("Failed to extract URL components")

        payload = packed.group(1)
        terms = packed.group(2).split("|")
        schema = payload.split(";")[2][5:-1]

        # Substitute each base-62 digit with its dictionary word; empty words
        # and characters outside the dictionary pass through unchanged.
        charset = string.digits + string.ascii_letters
        char_map = {charset[i]: terms[i] or charset[i] for i in range(len(terms))}
        final_url = "https:" + "".join(char_map.get(symbol, symbol) for symbol in schema)

        self.base_headers["referer"] = url
        return {
            "destination_url": final_url,
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
|
||||||
32
mediaflow_proxy/extractors/streamtape.py
Normal file
32
mediaflow_proxy/extractors/streamtape.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import re
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
|
||||||
|
|
||||||
|
|
||||||
|
class StreamtapeExtractor(BaseExtractor):
    """Streamtape URL extractor."""

    async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
        """Extract Streamtape URL.

        Raises:
            ExtractorError: If the video token cannot be located in the page.
        """
        response = await self._make_request(url)

        # Collect every "id=..." fragment; the real token is the one that
        # appears twice in a row in the page source.
        matches = re.findall(r"id=.*?(?=')", response.text)
        if not matches:
            raise ExtractorError("Failed to extract URL components")

        final_url = next(
            (
                f"https://streamtape.com/get_video?{matches[i + 1]}"
                for i in range(len(matches) - 1)
                if matches[i] == matches[i + 1]
            ),
            None,
        )
        # BUG FIX: previously a missing duplicate token silently produced
        # destination_url=None; fail loudly so callers can handle it.
        if final_url is None:
            raise ExtractorError("Failed to extract URL components")

        self.base_headers["referer"] = url
        return {
            "destination_url": final_url,
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
|
||||||
27
mediaflow_proxy/extractors/supervideo.py
Normal file
27
mediaflow_proxy/extractors/supervideo.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import re
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
|
||||||
|
|
||||||
|
|
||||||
|
class SupervideoExtractor(BaseExtractor):
    """Supervideo URL extractor."""

    async def extract(self, url: str, **kwargs) -> Dict[str, Any]:
        """Extract Supervideo URL.

        Raises:
            ExtractorError: If the packed player script cannot be decoded.
        """
        response = await self._make_request(url)

        # The stream pieces live in the dictionary of a packed JS blob.
        packed = re.search(r"\}\('(.+)',.+,'(.+)'\.split", response.text)
        # BUG FIX: a missing match used to raise a bare AttributeError from
        # .group(); raise the extractor's error type like the sibling extractors.
        if packed is None:
            raise ExtractorError("Failed to extract URL components")
        terms = packed.group(2).split("|")

        try:
            # Hostname token ("hfs...") and the comma-joined urlset segments.
            hfs = next(terms[i] for i in range(terms.index("file"), len(terms)) if "hfs" in terms[i])
            result = terms[terms.index("urlset") + 1 : terms.index("hls")]
        except (ValueError, StopIteration) as e:
            # BUG FIX: terms.index()/next() raised raw ValueError/StopIteration
            # on unexpected pages; surface a consistent ExtractorError instead.
            raise ExtractorError("Failed to extract URL components") from e

        base_url = f"https://{hfs}.serversicuro.cc/hls/"
        final_url = base_url + ",".join(reversed(result)) + (".urlset/master.m3u8" if result else "")

        self.base_headers["referer"] = url
        return {
            "destination_url": final_url,
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
|
||||||
24
mediaflow_proxy/extractors/uqload.py
Normal file
24
mediaflow_proxy/extractors/uqload.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import re
|
||||||
|
from typing import Dict
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
from mediaflow_proxy.extractors.base import BaseExtractor, ExtractorError
|
||||||
|
|
||||||
|
|
||||||
|
class UqloadExtractor(BaseExtractor):
    """Uqload URL extractor."""

    async def extract(self, url: str, **kwargs) -> Dict[str, str]:
        """Extract Uqload URL."""
        page = await self._make_request(url)

        # The direct video URL is embedded in the player's sources array.
        source_match = re.search(r'sources: \["(.*?)"]', page.text)
        if source_match is None:
            raise ExtractorError("Failed to extract video URL")

        self.base_headers["referer"] = urljoin(url, "/")
        return {
            "destination_url": source_match.group(1),
            "request_headers": self.base_headers,
            "mediaflow_endpoint": self.mediaflow_endpoint,
        }
|
||||||
358
mediaflow_proxy/handlers.py
Normal file
358
mediaflow_proxy/handlers.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import Request, Response, HTTPException
|
||||||
|
from starlette.background import BackgroundTask
|
||||||
|
|
||||||
|
from .const import SUPPORTED_RESPONSE_HEADERS
|
||||||
|
from .mpd_processor import process_manifest, process_playlist, process_segment
|
||||||
|
from .schemas import HLSManifestParams, ProxyStreamParams, MPDManifestParams, MPDPlaylistParams, MPDSegmentParams
|
||||||
|
from .utils.cache_utils import get_cached_mpd, get_cached_init_segment
|
||||||
|
from .utils.http_utils import (
|
||||||
|
Streamer,
|
||||||
|
DownloadError,
|
||||||
|
download_file_with_retry,
|
||||||
|
request_with_retry,
|
||||||
|
EnhancedStreamingResponse,
|
||||||
|
ProxyRequestHeaders,
|
||||||
|
create_httpx_client,
|
||||||
|
)
|
||||||
|
from .utils.m3u8_processor import M3U8Processor
|
||||||
|
from .utils.mpd_utils import pad_base64
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_client_and_streamer() -> tuple[httpx.AsyncClient, Streamer]:
    """
    Set up an HTTP client and a streamer.

    Returns:
        tuple: An httpx.AsyncClient instance and a Streamer instance.
    """
    http_client = create_httpx_client()
    streamer = Streamer(http_client)
    return http_client, streamer
|
||||||
|
|
||||||
|
|
||||||
|
def handle_exceptions(exception: Exception) -> Response:
    """
    Handle exceptions and return appropriate HTTP responses.

    Args:
        exception (Exception): The exception that was raised.

    Returns:
        Response: An HTTP response corresponding to the exception type.
    """
    # Upstream 4xx/5xx: mirror the upstream status code back to the caller.
    if isinstance(exception, httpx.HTTPStatusError):
        logger.error(f"Upstream service error while handling request: {exception}")
        return Response(status_code=exception.response.status_code, content=f"Upstream service error: {exception}")
    # Download failures carry their own status code.
    if isinstance(exception, DownloadError):
        logger.error(f"Error downloading content: {exception}")
        return Response(status_code=exception.status_code, content=str(exception))
    # Anything else is reported as a gateway error, with a full traceback logged.
    logger.exception(f"Internal server error while handling request: {exception}")
    return Response(status_code=502, content=f"Internal server error: {exception}")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_hls_stream_proxy(
    request: Request, hls_params: HLSManifestParams, proxy_headers: ProxyRequestHeaders
) -> Response:
    """
    Handle HLS stream proxy requests.

    This function processes HLS manifest files and streams content based on the request parameters.
    Manifest URLs (by extension or by upstream content-type) are rewritten via
    fetch_and_process_m3u8; anything else is streamed through with range support.

    Args:
        request (Request): The incoming FastAPI request object.
        hls_params (HLSManifestParams): Parameters for the HLS manifest.
        proxy_headers (ProxyRequestHeaders): Headers to be used in the proxy request.

    Returns:
        Union[Response, EnhancedStreamingResponse]: Either a processed m3u8 playlist or a streaming response.
    """
    client, streamer = await setup_client_and_streamer()

    try:
        # Fast path: the URL itself says it is a playlist — no probe needed.
        if urlparse(hls_params.destination).path.endswith((".m3u", ".m3u8")):
            return await fetch_and_process_m3u8(
                streamer, hls_params.destination, proxy_headers, request, hls_params.key_url
            )

        # Create initial streaming response to check content type
        await streamer.create_streaming_response(hls_params.destination, proxy_headers.request)

        # Upstream says it's a playlist even though the path didn't — process it.
        if "mpegurl" in streamer.response.headers.get("content-type", "").lower():
            return await fetch_and_process_m3u8(
                streamer, hls_params.destination, proxy_headers, request, hls_params.key_url
            )

        # Handle range requests
        content_range = proxy_headers.request.get("range", "bytes=0-")
        if "NaN" in content_range:
            # Handle invalid range requests "bytes=NaN-NaN"
            raise HTTPException(status_code=416, detail="Invalid Range Header")
        proxy_headers.request.update({"range": content_range})

        # Create new streaming response with updated headers
        await streamer.create_streaming_response(hls_params.destination, proxy_headers.request)
        response_headers = prepare_response_headers(streamer.response.headers, proxy_headers.response)

        # Stream the body through; the background task closes the streamer
        # once the response has been fully sent.
        return EnhancedStreamingResponse(
            streamer.stream_content(),
            status_code=streamer.response.status_code,
            headers=response_headers,
            background=BackgroundTask(streamer.close),
        )
    except Exception as e:
        # Always release the upstream connection before reporting the error.
        await streamer.close()
        return handle_exceptions(e)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_stream_request(
    method: str,
    video_url: str,
    proxy_headers: ProxyRequestHeaders,
) -> Response:
    """
    Handle general stream requests.

    Processes both HEAD and GET requests for video streams: HEAD returns only
    the upstream headers, GET streams the body through.

    Args:
        method (str): The HTTP method (e.g., 'GET' or 'HEAD').
        video_url (str): The URL of the video to stream.
        proxy_headers (ProxyRequestHeaders): Headers to be used in the proxy request.

    Returns:
        Union[Response, EnhancedStreamingResponse]: Either a HEAD response with headers or a streaming response.
    """
    client, streamer = await setup_client_and_streamer()

    try:
        await streamer.create_streaming_response(video_url, proxy_headers.request)
        upstream_headers = prepare_response_headers(streamer.response.headers, proxy_headers.response)

        if method == "HEAD":
            # HEAD: report the upstream headers/status without streaming a body.
            await streamer.close()
            return Response(headers=upstream_headers, status_code=streamer.response.status_code)

        # GET: stream the body through; the background task closes the
        # streamer once the response has been fully sent.
        return EnhancedStreamingResponse(
            streamer.stream_content(),
            headers=upstream_headers,
            status_code=streamer.response.status_code,
            background=BackgroundTask(streamer.close),
        )
    except Exception as e:
        await streamer.close()
        return handle_exceptions(e)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_response_headers(original_headers, proxy_response_headers) -> dict:
    """
    Build the header set for the proxied response.

    Keeps only the upstream headers listed in SUPPORTED_RESPONSE_HEADERS and
    then overlays the caller-supplied proxy response headers on top (the
    proxy headers win on conflict).

    Args:
        original_headers (httpx.Headers): Headers from the upstream response.
        proxy_response_headers (dict): Extra headers for the proxy response.

    Returns:
        dict: The prepared response headers.
    """
    prepared = {}
    for name, value in original_headers.multi_items():
        if name in SUPPORTED_RESPONSE_HEADERS:
            prepared[name] = value
    prepared.update(proxy_response_headers)
    return prepared
|
||||||
|
|
||||||
|
|
||||||
|
async def proxy_stream(method: str, stream_params: ProxyStreamParams, proxy_headers: ProxyRequestHeaders):
    """
    Proxy a stream request to the destination carried by ``stream_params``.

    Args:
        method: The HTTP method (e.g. GET, HEAD).
        stream_params: Parameters of the stream request; only the destination
            URL is consumed here.
        proxy_headers: Headers to forward with the proxied request.

    Returns:
        Response: The HTTP response with the streamed content.
    """
    destination = stream_params.destination
    return await handle_stream_request(method, destination, proxy_headers)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_and_process_m3u8(
    streamer: Streamer, url: str, proxy_headers: ProxyRequestHeaders, request: Request, key_url: str = None
):
    """
    Fetch an m3u8 playlist, rewrite it for proxying, and return it as HLS.

    Args:
        streamer: HTTP streamer used to download the playlist.
        url: URL of the m3u8 playlist.
        proxy_headers: Headers to forward and to attach to the response.
        request: The incoming HTTP request (used to build proxied URLs).
        key_url: Optional HLS key URL that replaces the original one.

    Returns:
        Response: The processed m3u8 playlist, or an error response on failure.
    """
    try:
        playlist_text = await streamer.get_text(url, proxy_headers.request)
        processor = M3U8Processor(request, key_url)
        # Resolve relative playlist entries against the final (post-redirect) URL.
        processed = await processor.process_m3u8(playlist_text, str(streamer.response.url))

        headers = {"Content-Disposition": "inline", "Accept-Ranges": "none", **proxy_headers.response}
        return Response(
            content=processed,
            media_type="application/vnd.apple.mpegurl",
            headers=headers,
        )
    except Exception as exc:
        return handle_exceptions(exc)
    finally:
        await streamer.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_drm_key_data(key_id, key, drm_info):
    """
    Resolve the DRM key pair for a manifest.

    Falls back to the key material embedded in ``drm_info`` when the caller
    did not supply both ``key_id`` and ``key``.

    Args:
        key_id: The DRM key ID, or a falsy value to use the manifest's.
        key: The DRM key, or a falsy value to use the manifest's.
        drm_info: DRM information parsed from the MPD manifest.

    Returns:
        tuple: ``(key_id, key)``; ``(None, None)`` for unprotected content.

    Raises:
        HTTPException: If no usable key material can be determined.
    """
    # Content explicitly marked as not DRM protected needs no keys at all.
    if drm_info and not drm_info.get("isDrmProtected"):
        return None, None

    if not key_id or not key:
        if "keyId" in drm_info and "key" in drm_info:
            key_id, key = drm_info["keyId"], drm_info["key"]
        elif "laUrl" in drm_info and "keyId" in drm_info:
            # License-acquisition URLs would need a license-server round trip.
            raise HTTPException(status_code=400, detail="LA URL is not supported yet")
        else:
            raise HTTPException(
                status_code=400, detail="Unable to determine key_id and key, and they were not provided"
            )

    return key_id, key
|
||||||
|
|
||||||
|
|
||||||
|
async def get_manifest(
    request: Request,
    manifest_params: MPDManifestParams,
    proxy_headers: ProxyRequestHeaders,
):
    """
    Download an MPD manifest and return it converted to an HLS manifest.

    Args:
        request: The incoming HTTP request.
        manifest_params: Parameters of the manifest request (destination, keys).
        proxy_headers: Headers to forward with the proxied request.

    Returns:
        Response: The HLS manifest.

    Raises:
        HTTPException: If the MPD manifest cannot be downloaded.
    """
    # Only parse DRM data out of the MPD when the caller supplied no keys.
    need_drm_parse = not manifest_params.key_id and not manifest_params.key
    try:
        mpd_dict = await get_cached_mpd(
            manifest_params.destination,
            headers=proxy_headers.request,
            parse_drm=need_drm_parse,
        )
    except DownloadError as e:
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")

    drm_info = mpd_dict.get("drmInfo", {})
    if drm_info and not drm_info.get("isDrmProtected"):
        # Unprotected content still gets an HLS manifest, just without keys.
        return await process_manifest(request, mpd_dict, proxy_headers, None, None)

    key_id, key = await handle_drm_key_data(manifest_params.key_id, manifest_params.key, drm_info)

    # Keys may arrive base64url-encoded; normalise them to 32-char hex.
    if key_id and len(key_id) != 32:
        key_id = base64.urlsafe_b64decode(pad_base64(key_id)).hex()
    if key and len(key) != 32:
        key = base64.urlsafe_b64decode(pad_base64(key)).hex()

    return await process_manifest(request, mpd_dict, proxy_headers, key_id, key)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_playlist(
    request: Request,
    playlist_params: MPDPlaylistParams,
    proxy_headers: ProxyRequestHeaders,
):
    """
    Download an MPD manifest and return the HLS playlist for one profile.

    Args:
        request: The incoming HTTP request.
        playlist_params: Parameters of the playlist request (destination,
            profile ID, keys).
        proxy_headers: Headers to forward with the proxied request.

    Returns:
        Response: The HLS playlist for the requested profile.

    Raises:
        HTTPException: If the MPD manifest cannot be downloaded.
    """
    profile_id = playlist_params.profile_id
    try:
        mpd_dict = await get_cached_mpd(
            playlist_params.destination,
            headers=proxy_headers.request,
            # Skip DRM parsing when the caller supplied keys explicitly.
            parse_drm=not playlist_params.key_id and not playlist_params.key,
            parse_segment_profile_id=profile_id,
        )
    except DownloadError as e:
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")

    return await process_playlist(request, mpd_dict, profile_id, proxy_headers)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_segment(
    segment_params: MPDSegmentParams,
    proxy_headers: ProxyRequestHeaders,
):
    """
    Download a media segment plus its init segment and return it processed.

    The init segment is served from cache when possible; decryption (if DRM
    keys are present) happens in ``process_segment``.

    Args:
        segment_params: Parameters of the segment request.
        proxy_headers: Headers to forward with the proxied request.

    Returns:
        Response: The processed segment, or an error response on failure.
    """
    try:
        init_content = await get_cached_init_segment(segment_params.init_url, proxy_headers.request)
        segment_content = await download_file_with_retry(segment_params.segment_url, proxy_headers.request)
    except Exception as exc:
        return handle_exceptions(exc)

    return await process_segment(
        init_content,
        segment_content,
        segment_params.mime_type,
        proxy_headers,
        segment_params.key_id,
        segment_params.key,
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def get_public_ip():
    """
    Look up the public IP address of the MediaFlow proxy via api.ipify.org.

    Returns:
        dict: JSON payload of the form ``{"ip": "x.x.x.x"}``.
    """
    response = await request_with_retry("GET", "https://api.ipify.org?format=json", {})
    return response.json()
|
||||||
99
mediaflow_proxy/main.py
Normal file
99
mediaflow_proxy/main.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import logging
|
||||||
|
from importlib import resources
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Depends, Security, HTTPException
|
||||||
|
from fastapi.security import APIKeyQuery, APIKeyHeader
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.responses import RedirectResponse
|
||||||
|
from starlette.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from mediaflow_proxy.configs import settings
|
||||||
|
from mediaflow_proxy.routes import proxy_router, extractor_router, speedtest_router
|
||||||
|
from mediaflow_proxy.schemas import GenerateUrlRequest
|
||||||
|
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler, EncryptionMiddleware
|
||||||
|
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url
|
||||||
|
|
||||||
|
# Configure root logging once for the whole application.
logging.basicConfig(level=settings.log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
app = FastAPI()
# The API password may arrive either as a query parameter or as a header;
# auto_error=False so a missing credential falls through to verify_api_key.
api_password_query = APIKeyQuery(name="api_password", auto_error=False)
api_password_header = APIKeyHeader(name="api_password", auto_error=False)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# the most permissive CORS posture — confirm this is intended for deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Decrypts encrypted proxy URLs before they reach the routers.
app.add_middleware(EncryptionMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_api_key(api_key: str = Security(api_password_query), api_key_alt: str = Security(api_password_header)):
    """
    Verify the API password supplied via query parameter or header.

    Args:
        api_key: Candidate password from the ``api_password`` query parameter
            (may be None when absent).
        api_key_alt: Candidate password from the ``api_password`` header
            (may be None when absent).

    Raises:
        HTTPException: 403 if neither candidate matches the configured password.
    """
    import secrets

    # No password configured means authentication is disabled.
    if not settings.api_password:
        return

    # Use a constant-time comparison so the check does not leak information
    # about the configured password through response timing; encode to bytes
    # so non-ASCII passwords are also accepted by compare_digest.
    expected = str(settings.api_password).encode("utf-8")
    for candidate in (api_key, api_key_alt):
        if candidate is not None and secrets.compare_digest(str(candidate).encode("utf-8"), expected):
            return

    raise HTTPException(status_code=403, detail="Could not validate credentials")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
async def health_check():
    """Liveness probe endpoint; always reports the service as healthy."""
    return {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/favicon.ico")
async def get_favicon():
    """Serve the favicon by redirecting to the static logo image."""
    return RedirectResponse(url="/logo.png")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/speedtest")
async def show_speedtest_page():
    """Redirect to the static speed-test page."""
    return RedirectResponse(url="/speedtest.html")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate_encrypted_or_encoded_url")
async def generate_encrypted_or_encoded_url(request: GenerateUrlRequest):
    """
    Build a MediaFlow proxy URL, encrypted when an API password is supplied.

    The request's ``api_password`` is propagated into the generated URL's
    query parameters so the resulting link is self-authenticating.
    """
    # Make sure the generated URL carries the API password as a query parameter.
    if "api_password" not in request.query_params:
        request.query_params["api_password"] = request.api_password

    # Positional arguments: proxy base URL, endpoint, destination URL, query
    # params, request headers, response headers, optional encryption handler,
    # expiration, and optional IP restriction.
    encoded_url = encode_mediaflow_proxy_url(
        request.mediaflow_proxy_url,
        request.endpoint,
        request.destination_url,
        request.query_params,
        request.request_headers,
        request.response_headers,
        EncryptionHandler(request.api_password) if request.api_password else None,
        request.expiration,
        str(request.ip) if request.ip else None,
    )
    return {"encoded_url": encoded_url}
|
||||||
|
|
||||||
|
|
||||||
|
# All functional routers sit behind the API-password dependency.
app.include_router(proxy_router, prefix="/proxy", tags=["proxy"], dependencies=[Depends(verify_api_key)])
app.include_router(extractor_router, prefix="/extractor", tags=["extractors"], dependencies=[Depends(verify_api_key)])
app.include_router(speedtest_router, prefix="/speedtest", tags=["speedtest"], dependencies=[Depends(verify_api_key)])
# Serve bundled static assets (logo, speedtest page, ...) from the package
# directory; mounted at "/" last so it only catches unrouted paths.
static_path = resources.files("mediaflow_proxy").joinpath("static")
app.mount("/", StaticFiles(directory=str(static_path), html=True), name="static")
|
||||||
|
|
||||||
|
|
||||||
|
def run():
    """
    Start the application with uvicorn.

    The app is passed as an import string rather than an object: uvicorn only
    honours the ``workers`` setting (and ``reload``) when given an import
    string, and silently falls back to a single worker otherwise.
    """
    import uvicorn

    uvicorn.run("mediaflow_proxy.main:app", host="0.0.0.0", port=8888, log_level="info", workers=3)


if __name__ == "__main__":
    run()
|
||||||
214
mediaflow_proxy/mpd_processor.py
Normal file
214
mediaflow_proxy/mpd_processor.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
from fastapi import Request, Response, HTTPException
|
||||||
|
|
||||||
|
from mediaflow_proxy.drm.decrypter import decrypt_segment
|
||||||
|
from mediaflow_proxy.utils.crypto_utils import encryption_handler
|
||||||
|
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme, ProxyRequestHeaders
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_manifest(
    request: Request, mpd_dict: dict, proxy_headers: ProxyRequestHeaders, key_id: str = None, key: str = None
) -> Response:
    """
    Convert an MPD manifest into an HLS master-manifest response.

    Args:
        request: The incoming HTTP request.
        mpd_dict: Parsed MPD manifest data.
        proxy_headers: Headers to attach to the response.
        key_id: Optional DRM key ID propagated into the playlist URLs.
        key: Optional DRM key propagated into the playlist URLs.

    Returns:
        Response: The HLS master manifest.
    """
    manifest = build_hls(mpd_dict, request, key_id, key)
    return Response(content=manifest, media_type="application/vnd.apple.mpegurl", headers=proxy_headers.response)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_playlist(
    request: Request, mpd_dict: dict, profile_id: str, proxy_headers: ProxyRequestHeaders
) -> Response:
    """
    Convert an MPD manifest into an HLS media playlist for one profile.

    Args:
        request: The incoming HTTP request.
        mpd_dict: Parsed MPD manifest data.
        profile_id: ID of the profile to build the playlist for.
        proxy_headers: Headers to attach to the response.

    Returns:
        Response: The HLS playlist.

    Raises:
        HTTPException: 404 if the profile does not exist in the manifest.
    """
    selected = [profile for profile in mpd_dict["profiles"] if profile["id"] == profile_id]
    if not selected:
        raise HTTPException(status_code=404, detail="Profile not found")

    playlist = build_hls_playlist(mpd_dict, selected, request)
    return Response(content=playlist, media_type="application/vnd.apple.mpegurl", headers=proxy_headers.response)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_segment(
    init_content: bytes,
    segment_content: bytes,
    mimetype: str,
    proxy_headers: ProxyRequestHeaders,
    key_id: str = None,
    key: str = None,
) -> Response:
    """
    Assemble (and, for DRM content, decrypt) a playable media segment.

    Args:
        init_content: Initialization-segment bytes.
        segment_content: Media-segment bytes.
        mimetype: MIME type for the response.
        proxy_headers: Headers to attach to the response.
        key_id: Optional DRM key ID; decryption runs only when both
            ``key_id`` and ``key`` are provided.
        key: Optional DRM key.

    Returns:
        Response: The playable segment bytes.
    """
    if key_id and key:
        # DRM path: decrypt init+segment and log how long it took.
        now = time.time()
        payload = decrypt_segment(init_content, segment_content, key_id, key)
        logger.info(f"Decryption of {mimetype} segment took {time.time() - now:.4f} seconds")
    else:
        # Clear content: a playable segment is just init followed by media.
        payload = init_content + segment_content

    return Response(content=payload, media_type=mimetype, headers=proxy_headers.response)
|
||||||
|
|
||||||
|
|
||||||
|
def build_hls(mpd_dict: dict, request: Request, key_id: str = None, key: str = None) -> str:
    """
    Builds an HLS master manifest from the MPD manifest.

    Args:
        mpd_dict (dict): The MPD manifest data.
        request (Request): The incoming HTTP request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        str: The HLS manifest as a string.
    """
    hls = ["#EXTM3U", "#EXT-X-VERSION:6"]
    query_params = dict(request.query_params)
    # NOTE(review): query values are strings, so has_encrypted="false" is
    # still truthy here — confirm callers only send this flag when enabled.
    has_encrypted = query_params.pop("has_encrypted", False)

    video_profiles = {}
    audio_profiles = {}

    # Get the base URL for the playlist_endpoint endpoint
    proxy_url = request.url_for("playlist_endpoint")
    # Rebuild with the scheme the client originally used (e.g. behind a TLS terminator).
    proxy_url = str(proxy_url.replace(scheme=get_original_scheme(request)))

    for profile in mpd_dict["profiles"]:
        # Each variant playlist URL carries the profile ID and any DRM keys forward.
        query_params.update({"profile_id": profile["id"], "key_id": key_id or "", "key": key or ""})
        playlist_url = encode_mediaflow_proxy_url(
            proxy_url,
            query_params=query_params,
            encryption_handler=encryption_handler if has_encrypted else None,
        )

        if "video" in profile["mimeType"]:
            video_profiles[profile["id"]] = (profile, playlist_url)
        elif "audio" in profile["mimeType"]:
            audio_profiles[profile["id"]] = (profile, playlist_url)

    # Add audio streams
    for i, (profile, playlist_url) in enumerate(audio_profiles.values()):
        is_default = "YES" if i == 0 else "NO"  # Set the first audio track as default
        hls.append(
            f'#EXT-X-MEDIA:TYPE=AUDIO,GROUP-ID="audio",NAME="{profile["id"]}",DEFAULT={is_default},AUTOSELECT={is_default},LANGUAGE="{profile.get("lang", "und")}",URI="{playlist_url}"'
        )

    # Add video streams
    for profile, playlist_url in video_profiles.values():
        hls.append(
            f'#EXT-X-STREAM-INF:BANDWIDTH={profile["bandwidth"]},RESOLUTION={profile["width"]}x{profile["height"]},CODECS="{profile["codecs"]}",FRAME-RATE={profile["frameRate"]},AUDIO="audio"'
        )
        hls.append(playlist_url)

    return "\n".join(hls)
|
||||||
|
|
||||||
|
|
||||||
|
def build_hls_playlist(mpd_dict: dict, profiles: list[dict], request: Request) -> str:
    """
    Builds an HLS media playlist from the MPD manifest for specific profiles.

    Args:
        mpd_dict (dict): The MPD manifest data.
        profiles (list[dict]): The profiles to include in the playlist.
        request (Request): The incoming HTTP request.

    Returns:
        str: The HLS playlist as a string.
    """
    hls = ["#EXTM3U", "#EXT-X-VERSION:6"]

    added_segments = 0

    proxy_url = request.url_for("segment_endpoint")
    # Rebuild with the scheme the client originally used (e.g. behind a TLS terminator).
    proxy_url = str(proxy_url.replace(scheme=get_original_scheme(request)))

    for index, profile in enumerate(profiles):
        segments = profile["segments"]
        if not segments:
            logger.warning(f"No segments found for profile {profile['id']}")
            continue

        # Add headers for only the first profile
        if index == 0:
            sequence = segments[0]["number"]
            extinf_values = [f["extinf"] for f in segments if "extinf" in f]
            # Target duration must cover the longest segment, rounded up; 3 is a fallback.
            target_duration = math.ceil(max(extinf_values)) if extinf_values else 3
            hls.extend(
                [
                    f"#EXT-X-TARGETDURATION:{target_duration}",
                    f"#EXT-X-MEDIA-SEQUENCE:{sequence}",
                ]
            )
            if mpd_dict["isLive"]:
                hls.append("#EXT-X-PLAYLIST-TYPE:EVENT")
            else:
                hls.append("#EXT-X-PLAYLIST-TYPE:VOD")

        init_url = profile["initUrl"]

        query_params = dict(request.query_params)
        # Strip parameters that applied only to the playlist request itself.
        query_params.pop("profile_id", None)
        query_params.pop("d", None)
        # NOTE(review): string query values are always truthy, even "false" —
        # confirm callers only send has_encrypted when enabled.
        has_encrypted = query_params.pop("has_encrypted", False)

        for segment in segments:
            hls.append(f'#EXTINF:{segment["extinf"]:.3f},')
            query_params.update(
                {"init_url": init_url, "segment_url": segment["media"], "mime_type": profile["mimeType"]}
            )
            hls.append(
                encode_mediaflow_proxy_url(
                    proxy_url,
                    query_params=query_params,
                    encryption_handler=encryption_handler if has_encrypted else None,
                )
            )
            added_segments += 1

    if not mpd_dict["isLive"]:
        hls.append("#EXT-X-ENDLIST")

    logger.info(f"Added {added_segments} segments to HLS playlist")
    return "\n".join(hls)
|
||||||
164
mediaflow_proxy/routes.py
Normal file
164
mediaflow_proxy/routes.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
from fastapi import Request, Depends, APIRouter
|
||||||
|
from pydantic import HttpUrl
|
||||||
|
|
||||||
|
from .handlers import handle_hls_stream_proxy, proxy_stream, get_manifest, get_playlist, get_segment, get_public_ip
|
||||||
|
from .utils.http_utils import get_proxy_headers, ProxyRequestHeaders
|
||||||
|
|
||||||
|
proxy_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.head("/hls")
@proxy_router.get("/hls")
async def hls_stream_proxy(
    request: Request,
    d: HttpUrl,
    proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
    key_url: HttpUrl | None = None,
    verify_ssl: bool = False,
    use_request_proxy: bool = True,
):
    """
    Proxy an HLS request: process an m3u8 playlist or stream raw content.

    Args:
        request: The incoming HTTP request.
        d: The destination URL to fetch the content from.
        proxy_headers: Headers to forward with the proxied request.
        key_url: Optional replacement for the original HLS key URL (useful for
            bypassing some sneaky protection).
        verify_ssl: Whether to verify the destination's SSL certificate.
        use_request_proxy: Whether to route through the configured MediaFlow proxy.

    Returns:
        Response: The processed m3u8 playlist or the streamed content.
    """
    return await handle_hls_stream_proxy(request, str(d), proxy_headers, key_url, verify_ssl, use_request_proxy)
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.head("/stream")
@proxy_router.get("/stream")
async def proxy_stream_endpoint(
    request: Request,
    d: HttpUrl,
    proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
    verify_ssl: bool = False,
    use_request_proxy: bool = True,
):
    """
    Proxy a plain stream request to the given video URL.

    Args:
        request: The incoming HTTP request.
        d: The URL of the video to stream.
        proxy_headers: Headers to forward with the proxied request.
        verify_ssl: Whether to verify the destination's SSL certificate.
        use_request_proxy: Whether to route through the configured MediaFlow proxy.

    Returns:
        Response: The streamed content.
    """
    # Always send an explicit Range header upstream; default to the whole
    # resource when the client did not provide one.
    proxy_headers.request.setdefault("range", "bytes=0-")
    return await proxy_stream(request.method, str(d), proxy_headers, verify_ssl, use_request_proxy)
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.get("/mpd/manifest")
async def manifest_endpoint(
    request: Request,
    d: HttpUrl,
    proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = False,
    use_request_proxy: bool = True,
):
    """
    Convert the MPD manifest at ``d`` into an HLS manifest.

    Args:
        request: The incoming HTTP request.
        d: The URL of the MPD manifest.
        proxy_headers: Headers to forward with the proxied request.
        key_id: Optional DRM key ID.
        key: Optional DRM key.
        verify_ssl: Whether to verify the destination's SSL certificate.
        use_request_proxy: Whether to route through the configured MediaFlow proxy.

    Returns:
        Response: The HLS manifest.
    """
    destination = str(d)
    return await get_manifest(request, destination, proxy_headers, key_id, key, verify_ssl, use_request_proxy)
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.get("/mpd/playlist")
async def playlist_endpoint(
    request: Request,
    d: HttpUrl,
    profile_id: str,
    proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = False,
    use_request_proxy: bool = True,
):
    """
    Convert the MPD manifest at ``d`` into an HLS playlist for one profile.

    Args:
        request: The incoming HTTP request.
        d: The URL of the MPD manifest.
        profile_id: ID of the profile to build the playlist for.
        proxy_headers: Headers to forward with the proxied request.
        key_id: Optional DRM key ID.
        key: Optional DRM key.
        verify_ssl: Whether to verify the destination's SSL certificate.
        use_request_proxy: Whether to route through the configured MediaFlow proxy.

    Returns:
        Response: The HLS playlist.
    """
    destination = str(d)
    return await get_playlist(request, destination, profile_id, proxy_headers, key_id, key, verify_ssl, use_request_proxy)
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.get("/mpd/segment")
async def segment_endpoint(
    init_url: HttpUrl,
    segment_url: HttpUrl,
    mime_type: str,
    proxy_headers: ProxyRequestHeaders = Depends(get_proxy_headers),
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = False,
    use_request_proxy: bool = True,
):
    """
    Fetch and process a media segment, decrypting it when DRM keys are given.

    Args:
        init_url: The URL of the initialization segment.
        segment_url: The URL of the media segment.
        mime_type: The MIME type of the segment.
        proxy_headers: Headers to forward with the proxied request.
        key_id: Optional DRM key ID.
        key: Optional DRM key.
        verify_ssl: Whether to verify the destination's SSL certificate.
        use_request_proxy: Whether to route through the configured MediaFlow proxy.

    Returns:
        Response: The processed segment.
    """
    init_str, segment_str = str(init_url), str(segment_url)
    return await get_segment(
        init_str, segment_str, mime_type, proxy_headers, key_id, key, verify_ssl, use_request_proxy
    )
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.get("/ip")
async def get_mediaflow_proxy_public_ip(
    use_request_proxy: bool = True,
):
    """
    Report the public IP address of the MediaFlow proxy server.

    Args:
        use_request_proxy: Whether to resolve the IP through the configured proxy.

    Returns:
        Response: JSON of the form ``{"ip": "xxx.xxx.xxx.xxx"}``.
    """
    return await get_public_ip(use_request_proxy)
|
||||||
5
mediaflow_proxy/routes/__init__.py
Normal file
5
mediaflow_proxy/routes/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .proxy import proxy_router
|
||||||
|
from .extractor import extractor_router
|
||||||
|
from .speedtest import speedtest_router
|
||||||
|
|
||||||
|
__all__ = ["proxy_router", "extractor_router", "speedtest_router"]
|
||||||
61
mediaflow_proxy/routes/extractor.py
Normal file
61
mediaflow_proxy/routes/extractor.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query, HTTPException, Request, Depends
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
from mediaflow_proxy.extractors.base import ExtractorError
|
||||||
|
from mediaflow_proxy.extractors.factory import ExtractorFactory
|
||||||
|
from mediaflow_proxy.schemas import ExtractorURLParams
|
||||||
|
from mediaflow_proxy.utils.cache_utils import get_cached_extractor_result, set_cache_extractor_result
|
||||||
|
from mediaflow_proxy.utils.http_utils import (
|
||||||
|
encode_mediaflow_proxy_url,
|
||||||
|
get_original_scheme,
|
||||||
|
ProxyRequestHeaders,
|
||||||
|
get_proxy_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
extractor_router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@extractor_router.head("/video")
@extractor_router.get("/video")
async def extract_url(
    extractor_params: Annotated[ExtractorURLParams, Query()],
    request: Request,
    proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
    """Extract clean links from various video hosting services."""
    try:
        # Cache on host + full parameter set so distinct options don't collide.
        cache_key = f"{extractor_params.host}_{extractor_params.model_dump_json()}"
        response = await get_cached_extractor_result(cache_key)
        if not response:
            extractor = ExtractorFactory.get_extractor(extractor_params.host, proxy_headers.request)
            response = await extractor.extract(extractor_params.destination, **extractor_params.extra_params)
            await set_cache_extractor_result(cache_key, response)
        else:
            # Cached entries keep their original request headers; refresh them
            # with the headers of the current request.
            response["request_headers"].update(proxy_headers.request)

        # Point the result at this deployment's proxy endpoint, preserving the
        # scheme the client originally used.
        response["mediaflow_proxy_url"] = str(
            request.url_for(response.pop("mediaflow_endpoint")).replace(scheme=get_original_scheme(request))
        )
        response["query_params"] = response.get("query_params", {})
        # Add API password to query params
        response["query_params"]["api_password"] = request.query_params.get("api_password")

        if extractor_params.redirect_stream:
            stream_url = encode_mediaflow_proxy_url(
                **response,
                response_headers=proxy_headers.response,
            )
            return RedirectResponse(url=stream_url, status_code=302)

        return response

    except ExtractorError as e:
        logger.error(f"Extraction failed: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Extraction failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Extraction failed: {str(e)}")
|
||||||
138
mediaflow_proxy/routes/proxy.py
Normal file
138
mediaflow_proxy/routes/proxy.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Request, Depends, APIRouter, Query, HTTPException
|
||||||
|
|
||||||
|
from mediaflow_proxy.handlers import (
|
||||||
|
handle_hls_stream_proxy,
|
||||||
|
proxy_stream,
|
||||||
|
get_manifest,
|
||||||
|
get_playlist,
|
||||||
|
get_segment,
|
||||||
|
get_public_ip,
|
||||||
|
)
|
||||||
|
from mediaflow_proxy.schemas import (
|
||||||
|
MPDSegmentParams,
|
||||||
|
MPDPlaylistParams,
|
||||||
|
HLSManifestParams,
|
||||||
|
ProxyStreamParams,
|
||||||
|
MPDManifestParams,
|
||||||
|
)
|
||||||
|
from mediaflow_proxy.utils.http_utils import get_proxy_headers, ProxyRequestHeaders
|
||||||
|
|
||||||
|
proxy_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.head("/hls/manifest.m3u8")
@proxy_router.get("/hls/manifest.m3u8")
async def hls_manifest_proxy(
    request: Request,
    hls_params: Annotated[HLSManifestParams, Query()],
    proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
    """Proxy an HLS stream request: fetch and process the m3u8 playlist, or stream the content.

    Args:
        request (Request): The incoming HTTP request.
        hls_params (HLSManifestParams): The parameters for the HLS stream request.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The processed m3u8 playlist or the streamed content.
    """
    result = await handle_hls_stream_proxy(request, hls_params, proxy_headers)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.head("/stream")
@proxy_router.get("/stream")
async def proxy_stream_endpoint(
    request: Request,
    stream_params: Annotated[ProxyStreamParams, Query()],
    proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
    """Proxy a plain HTTP(S) stream request to the destination URL.

    Args:
        request (Request): The incoming HTTP request.
        stream_params (ProxyStreamParams): The parameters for the stream request.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HTTP response with the streamed content.

    Raises:
        HTTPException: 416 for malformed "bytes=NaN-NaN" range headers.
    """
    # Default to the full resource when the client supplied no Range header.
    content_range = proxy_headers.request.get("range", "bytes=0-")

    # Some players emit "bytes=NaN-NaN"; reject it here instead of forwarding upstream.
    if "nan" in content_range.casefold():
        raise HTTPException(status_code=416, detail="Invalid Range Header")

    proxy_headers.request["range"] = content_range
    return await proxy_stream(request.method, stream_params, proxy_headers)
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.get("/mpd/manifest.m3u8")
async def mpd_manifest_proxy(
    request: Request,
    manifest_params: Annotated[MPDManifestParams, Query()],
    proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
    """Fetch an MPD manifest and convert it into an HLS manifest.

    Args:
        request (Request): The incoming HTTP request.
        manifest_params (MPDManifestParams): Destination URL plus optional DRM key material.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HTTP response carrying the converted HLS manifest.
    """
    hls_manifest = await get_manifest(request, manifest_params, proxy_headers)
    return hls_manifest
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.get("/mpd/playlist.m3u8")
async def playlist_endpoint(
    request: Request,
    playlist_params: Annotated[MPDPlaylistParams, Query()],
    proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
    """Build an HLS playlist for one profile of an MPD manifest.

    Args:
        request (Request): The incoming HTTP request.
        playlist_params (MPDPlaylistParams): Destination URL, profile ID, and optional DRM keys.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HTTP response carrying the HLS playlist.
    """
    playlist = await get_playlist(request, playlist_params, proxy_headers)
    return playlist
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.get("/mpd/segment.mp4")
async def segment_endpoint(
    segment_params: Annotated[MPDSegmentParams, Query()],
    proxy_headers: Annotated[ProxyRequestHeaders, Depends(get_proxy_headers)],
):
    """Fetch a media segment, decrypting it when DRM keys are supplied.

    Args:
        segment_params (MPDSegmentParams): Init/segment URLs, MIME type, and optional DRM keys.
        proxy_headers (ProxyRequestHeaders): The headers to include in the request.

    Returns:
        Response: The HTTP response carrying the processed segment.
    """
    segment = await get_segment(segment_params, proxy_headers)
    return segment
|
||||||
|
|
||||||
|
|
||||||
|
@proxy_router.get("/ip")
async def get_mediaflow_proxy_public_ip():
    """Look up the public IP address of the MediaFlow proxy server.

    Returns:
        Response: JSON of the form {"ip": "xxx.xxx.xxx.xxx"}.
    """
    public_ip = await get_public_ip()
    return public_ip
|
||||||
43
mediaflow_proxy/routes/speedtest.py
Normal file
43
mediaflow_proxy/routes/speedtest.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
from mediaflow_proxy.speedtest.service import SpeedTestService, SpeedTestProvider
|
||||||
|
|
||||||
|
speedtest_router = APIRouter()
|
||||||
|
|
||||||
|
# Initialize service
|
||||||
|
speedtest_service = SpeedTestService()
|
||||||
|
|
||||||
|
|
||||||
|
@speedtest_router.get("/", summary="Show speed test interface")
async def show_speedtest_page():
    """Send the client to the static speed test HTML page."""
    redirect = RedirectResponse(url="/speedtest.html")
    return redirect
|
||||||
|
|
||||||
|
|
||||||
|
@speedtest_router.post("/start", summary="Start a new speed test", response_model=dict)
async def start_speedtest(background_tasks: BackgroundTasks, provider: SpeedTestProvider, request: Request):
    """Create a speed test task for the given provider and run it in the background.

    Returns a dict with the generated task_id; poll /results/{task_id} for progress.
    """
    # Provider credentials, if any, arrive via the api_key request header.
    api_key = request.headers.get("api_key")
    task_id = str(uuid.uuid4())

    # Persist the task record first so the client can poll immediately.
    await speedtest_service.create_test(task_id, provider, api_key)

    # The actual download test runs after the response is sent.
    background_tasks.add_task(speedtest_service.run_speedtest, task_id, provider, api_key)

    return {"task_id": task_id}
|
||||||
|
|
||||||
|
|
||||||
|
@speedtest_router.get("/results/{task_id}", summary="Get speed test results")
async def get_speedtest_results(task_id: str):
    """Return the results or current status of a speed test.

    Args:
        task_id (str): Identifier returned by the /start endpoint.

    Raises:
        HTTPException: 404 when the task does not exist or expired from the cache.
    """
    task = await speedtest_service.get_test_results(task_id)

    if not task:
        raise HTTPException(status_code=404, detail="Speed test task not found or expired")

    # Pydantic v2: .dict() is deprecated. The project already uses the v2 API
    # (model_dump_json() in the extractor route), so serialize consistently.
    return task.model_dump()
|
||||||
74
mediaflow_proxy/schemas.py
Normal file
74
mediaflow_proxy/schemas.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from typing import Literal, Dict, Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, IPvAnyAddress, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateUrlRequest(BaseModel):
    """Request payload for generating an encoded (optionally encrypted) MediaFlow proxy URL."""

    mediaflow_proxy_url: str = Field(..., description="The base URL for the mediaflow proxy.")
    endpoint: Optional[str] = Field(None, description="The specific endpoint to be appended to the base URL.")
    destination_url: Optional[str] = Field(
        None, description="The destination URL to which the request will be proxied."
    )
    query_params: Optional[dict] = Field(
        default_factory=dict, description="Query parameters to be included in the request."
    )
    request_headers: Optional[dict] = Field(default_factory=dict, description="Headers to be included in the request.")
    response_headers: Optional[dict] = Field(
        default_factory=dict, description="Headers to be included in the response."
    )
    expiration: Optional[int] = Field(
        None, description="Expiration time for the URL in seconds. If not provided, the URL will not expire."
    )
    api_password: Optional[str] = Field(
        None, description="API password for encryption. If not provided, the URL will only be encoded."
    )
    ip: Optional[IPvAnyAddress] = Field(None, description="The IP address to restrict the URL to.")
|
||||||
|
|
||||||
|
|
||||||
|
class GenericParams(BaseModel):
    """Shared base for query-parameter models; fields may be populated by name or by alias."""

    model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
|
||||||
|
class HLSManifestParams(GenericParams):
    """Query parameters for the HLS manifest proxy endpoint."""

    destination: str = Field(..., description="The URL of the HLS manifest.", alias="d")
    key_url: Optional[str] = Field(
        None,
        description="The HLS Key URL to replace the original key URL. Defaults to None. (Useful for bypassing some sneaky protection)",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyStreamParams(GenericParams):
    """Query parameters for the generic stream proxy endpoint."""

    destination: str = Field(..., description="The URL of the stream.", alias="d")
|
||||||
|
|
||||||
|
|
||||||
|
class MPDManifestParams(GenericParams):
    """Query parameters for MPD-to-HLS manifest conversion, with optional ClearKey DRM material."""

    destination: str = Field(..., description="The URL of the MPD manifest.", alias="d")
    key_id: Optional[str] = Field(None, description="The DRM key ID (optional).")
    key: Optional[str] = Field(None, description="The DRM key (optional).")
|
||||||
|
|
||||||
|
|
||||||
|
class MPDPlaylistParams(GenericParams):
    """Query parameters for building an HLS playlist for one MPD profile."""

    destination: str = Field(..., description="The URL of the MPD manifest.", alias="d")
    profile_id: str = Field(..., description="The profile ID to generate the playlist for.")
    key_id: Optional[str] = Field(None, description="The DRM key ID (optional).")
    key: Optional[str] = Field(None, description="The DRM key (optional).")
|
||||||
|
|
||||||
|
|
||||||
|
class MPDSegmentParams(GenericParams):
    """Query parameters for fetching (and optionally decrypting) one media segment."""

    init_url: str = Field(..., description="The URL of the initialization segment.")
    segment_url: str = Field(..., description="The URL of the media segment.")
    mime_type: str = Field(..., description="The MIME type of the segment.")
    key_id: Optional[str] = Field(None, description="The DRM key ID (optional).")
    key: Optional[str] = Field(None, description="The DRM key (optional).")
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractorURLParams(GenericParams):
    """Query parameters for the extractor endpoint (/extractor/video)."""

    host: Literal["Doodstream", "Mixdrop", "Uqload", "Streamtape", "Supervideo", "LiveTV"] = Field(
        ..., description="The host to extract the URL from."
    )
    destination: str = Field(..., description="The URL of the stream.", alias="d")
    redirect_stream: bool = Field(False, description="Whether to redirect to the stream endpoint automatically.")
    extra_params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional parameters required for specific extractors (e.g., stream_title for LiveTV)",
    )
|
||||||
0
mediaflow_proxy/speedtest/__init__.py
Normal file
0
mediaflow_proxy/speedtest/__init__.py
Normal file
46
mediaflow_proxy/speedtest/models.py
Normal file
46
mediaflow_proxy/speedtest/models.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedTestProvider(str, Enum):
    """Debrid services supported by the speed test; values double as API identifiers."""

    REAL_DEBRID = "real_debrid"
    ALL_DEBRID = "all_debrid"
|
||||||
|
|
||||||
|
|
||||||
|
class ServerInfo(BaseModel):
    """A single speed-test server as reported by a provider."""

    url: str
    name: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserInfo(BaseModel):
    """Caller metadata returned by some providers (e.g. AllDebrid); all fields optional."""

    ip: Optional[str] = None
    isp: Optional[str] = None
    country: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedTestResult(BaseModel):
    """Measured outcome of a single location's download test."""

    speed_mbps: float = Field(..., description="Speed in Mbps")
    duration: float = Field(..., description="Test duration in seconds")
    data_transferred: int = Field(..., description="Data transferred in bytes")
    # Use an aware UTC timestamp: datetime.utcnow() is naive and deprecated
    # since Python 3.12, and the service layer already records aware datetimes
    # via datetime.now(tz=timezone.utc); mixing naive and aware values breaks
    # datetime arithmetic/comparison.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
class LocationResult(BaseModel):
    """Per-server outcome: either a measurement in `result` or a message in `error`."""

    result: Optional[SpeedTestResult] = None
    error: Optional[str] = None
    server_name: str
    server_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedTestTask(BaseModel):
    """Mutable record of one speed-test run; persisted to the cache after each update."""

    task_id: str
    provider: SpeedTestProvider
    # Per-location results; pydantic deep-copies field defaults per instance,
    # so this mutable default is not shared between tasks.
    results: Dict[str, LocationResult] = {}
    started_at: datetime
    completed_at: Optional[datetime] = None
    # Lifecycle state set by the service layer: "running", "completed", or "failed".
    status: str = "running"
    user_info: Optional[UserInfo] = None
    # Location currently being tested, or None when idle/finished.
    current_location: Optional[str] = None
|
||||||
50
mediaflow_proxy/speedtest/providers/all_debrid.py
Normal file
50
mediaflow_proxy/speedtest/providers/all_debrid.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import random
|
||||||
|
from typing import Dict, Tuple, Optional
|
||||||
|
|
||||||
|
from mediaflow_proxy.configs import settings
|
||||||
|
from mediaflow_proxy.speedtest.models import ServerInfo, UserInfo
|
||||||
|
from mediaflow_proxy.speedtest.providers.base import BaseSpeedTestProvider, SpeedTestProviderConfig
|
||||||
|
from mediaflow_proxy.utils.http_utils import request_with_retry
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedTestError(Exception):
    """Raised when a provider's speed-test API call fails or returns an error status."""

    pass
|
||||||
|
|
||||||
|
|
||||||
|
class AllDebridSpeedTest(BaseSpeedTestProvider):
    """AllDebrid speed test provider implementation.

    Fetches the server list (plus the caller's user info) from AllDebrid's
    internal speedtest API and builds one download URL per server.
    """

    def __init__(self, api_key: str):
        # AllDebrid's speedtest endpoint requires an account API key.
        self.api_key = api_key
        # Populated by get_test_urls(); later read by the service layer to
        # report server metadata alongside results.
        self.servers: Dict[str, ServerInfo] = {}

    async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]:
        """Fetch the AllDebrid server list and return ({name: url}, user info).

        Raises:
            SpeedTestError: on a non-200 response or a non-"success" API status.
        """
        response = await request_with_retry(
            "GET",
            "https://alldebrid.com/internalapi/v4/speedtest",
            headers={"User-Agent": settings.user_agent},
            params={"agent": "service", "version": "1.0-363869a7", "apikey": self.api_key},
        )

        if response.status_code != 200:
            raise SpeedTestError("Failed to fetch AllDebrid servers")

        data = response.json()
        if data["status"] != "success":
            raise SpeedTestError("AllDebrid API returned error")

        # Create UserInfo
        user_info = UserInfo(ip=data["data"]["ip"], isp=data["data"]["isp"], country=data["data"]["country"])

        # Store server info
        self.servers = {server["name"]: ServerInfo(**server) for server in data["data"]["servers"]}

        # Generate URLs with random number
        # NOTE(review): the digits appended here appear to be a cache-buster
        # mimicking the AllDebrid web client — confirm against that client.
        random_number = f"{random.uniform(1, 2):.24f}".replace(".", "")
        urls = {name: f"{server.url}/speedtest/{random_number}" for name, server in self.servers.items()}

        return urls, user_info

    async def get_config(self) -> SpeedTestProviderConfig:
        """Return the provider config: a 10-second window over freshly fetched URLs.

        NOTE(review): this triggers a second API call after create_test() already
        called get_test_urls(); consider reusing the cached self.servers.
        """
        urls, _ = await self.get_test_urls()
        return SpeedTestProviderConfig(test_duration=10, test_urls=urls)
|
||||||
24
mediaflow_proxy/speedtest/providers/base.py
Normal file
24
mediaflow_proxy/speedtest/providers/base.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Tuple, Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from mediaflow_proxy.speedtest.models import UserInfo
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedTestProviderConfig(BaseModel):
    """Per-provider test configuration consumed by the speed-test service."""

    # Maximum time to stream from each server, in seconds.
    test_duration: int = 10  # seconds
    # Mapping of location/server name -> download URL to measure.
    test_urls: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSpeedTestProvider(ABC):
    """Base class for speed test providers.

    Concrete providers supply the per-location download URLs and a test
    configuration; the SpeedTestService drives the actual measurement.
    """

    @abstractmethod
    async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]:
        """Get list of test URLs for the provider and optional user info."""
        pass

    @abstractmethod
    async def get_config(self) -> SpeedTestProviderConfig:
        """Get provider-specific configuration."""
        pass
|
||||||
32
mediaflow_proxy/speedtest/providers/real_debrid.py
Normal file
32
mediaflow_proxy/speedtest/providers/real_debrid.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from typing import Dict, Tuple, Optional
|
||||||
|
import random
|
||||||
|
|
||||||
|
from mediaflow_proxy.speedtest.models import UserInfo
|
||||||
|
from mediaflow_proxy.speedtest.providers.base import BaseSpeedTestProvider, SpeedTestProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
class RealDebridSpeedTest(BaseSpeedTestProvider):
    """Speed test provider backed by RealDebrid's public test files."""

    async def get_test_urls(self) -> Tuple[Dict[str, str], Optional[UserInfo]]:
        """Return per-location test URLs; RealDebrid exposes no user info."""
        base_urls = {
            "AMS": "https://45.download.real-debrid.com/speedtest/testDefault.rar/",
            "RBX": "https://rbx.download.real-debrid.com/speedtest/test.rar/",
            "LON1": "https://lon1.download.real-debrid.com/speedtest/test.rar/",
            "HKG1": "https://hkg1.download.real-debrid.com/speedtest/test.rar/",
            "SGP1": "https://sgp1.download.real-debrid.com/speedtest/test.rar/",
            "SGPO1": "https://sgpo1.download.real-debrid.com/speedtest/test.rar/",
            "TYO1": "https://tyo1.download.real-debrid.com/speedtest/test.rar/",
            "LAX1": "https://lax1.download.real-debrid.com/speedtest/test.rar/",
            "TLV1": "https://tlv1.download.real-debrid.com/speedtest/test.rar/",
            "MUM1": "https://mum1.download.real-debrid.com/speedtest/test.rar/",
            "JKT1": "https://jkt1.download.real-debrid.com/speedtest/test.rar/",
            "Cloudflare": "https://45.download.real-debrid.cloud/speedtest/testCloudflare.rar/",
        }
        # Append a random fractional suffix so intermediate caches never serve a stale file.
        cache_busted = {}
        for location, base_url in base_urls.items():
            cache_busted[location] = f"{base_url}{random.uniform(0, 1):.16f}"
        return cache_busted, None

    async def get_config(self) -> SpeedTestProviderConfig:
        """Return the provider config: a 10-second window over the generated URLs."""
        test_urls, _ = await self.get_test_urls()
        return SpeedTestProviderConfig(test_duration=10, test_urls=test_urls)
|
||||||
129
mediaflow_proxy/speedtest/service.py
Normal file
129
mediaflow_proxy/speedtest/service.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, Optional, Type
|
||||||
|
|
||||||
|
from mediaflow_proxy.utils.cache_utils import get_cached_speedtest, set_cache_speedtest
|
||||||
|
from mediaflow_proxy.utils.http_utils import Streamer, create_httpx_client
|
||||||
|
from .models import SpeedTestTask, LocationResult, SpeedTestResult, SpeedTestProvider
|
||||||
|
from .providers.all_debrid import AllDebridSpeedTest
|
||||||
|
from .providers.base import BaseSpeedTestProvider
|
||||||
|
from .providers.real_debrid import RealDebridSpeedTest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedTestService:
    """Service for managing speed tests across different providers.

    Tasks live in the shared cache; every state transition is written back so
    clients polling the results endpoint see live progress.
    """

    def __init__(self):
        # Provider mapping: enum value -> concrete provider class.
        self._providers: Dict[SpeedTestProvider, Type[BaseSpeedTestProvider]] = {
            SpeedTestProvider.REAL_DEBRID: RealDebridSpeedTest,
            SpeedTestProvider.ALL_DEBRID: AllDebridSpeedTest,
        }

    def _get_provider(self, provider: SpeedTestProvider, api_key: Optional[str] = None) -> BaseSpeedTestProvider:
        """Get the appropriate provider implementation.

        Raises:
            ValueError: for unknown providers, or AllDebrid without an API key.
        """
        provider_class = self._providers.get(provider)
        if not provider_class:
            raise ValueError(f"Unsupported provider: {provider}")

        if provider == SpeedTestProvider.ALL_DEBRID and not api_key:
            raise ValueError("API key required for AllDebrid")

        # Only AllDebrid's constructor takes an API key.
        return provider_class(api_key) if provider == SpeedTestProvider.ALL_DEBRID else provider_class()

    async def create_test(
        self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None
    ) -> SpeedTestTask:
        """Create a new speed test task and persist it to the cache."""
        provider_impl = self._get_provider(provider, api_key)

        # Get initial URLs and user info (may hit the provider's API).
        urls, user_info = await provider_impl.get_test_urls()

        task = SpeedTestTask(
            task_id=task_id, provider=provider, started_at=datetime.now(tz=timezone.utc), user_info=user_info
        )

        await set_cache_speedtest(task_id, task)
        return task

    @staticmethod
    async def get_test_results(task_id: str) -> Optional[SpeedTestTask]:
        """Get results for a specific task, or None if unknown/expired."""
        return await get_cached_speedtest(task_id)

    async def run_speedtest(self, task_id: str, provider: SpeedTestProvider, api_key: Optional[str] = None):
        """Run the speed test with real-time updates.

        Intended to run as a background task; each per-location result (or
        failure) is written to the cache immediately so polling clients see it.
        """
        try:
            task = await get_cached_speedtest(task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")

            provider_impl = self._get_provider(provider, api_key)
            config = await provider_impl.get_config()

            async with create_httpx_client() as client:
                streamer = Streamer(client)

                for location, url in config.test_urls.items():
                    try:
                        # Publish which location is being tested before starting it.
                        task.current_location = location
                        await set_cache_speedtest(task_id, task)
                        result = await self._test_location(location, url, streamer, config.test_duration, provider_impl)
                        task.results[location] = result
                        await set_cache_speedtest(task_id, task)
                    except Exception as e:
                        # One bad server must not abort the remaining locations;
                        # record the error and continue.
                        logger.error(f"Error testing {location}: {str(e)}")
                        task.results[location] = LocationResult(
                            error=str(e), server_name=location, server_url=config.test_urls[location]
                        )
                        await set_cache_speedtest(task_id, task)

            # Mark task as completed
            task.completed_at = datetime.now(tz=timezone.utc)
            task.status = "completed"
            task.current_location = None
            await set_cache_speedtest(task_id, task)

        except Exception as e:
            # Setup-level failure (bad provider/config/cache): flag the whole task.
            logger.error(f"Error in speed test task {task_id}: {str(e)}")
            if task := await get_cached_speedtest(task_id):
                task.status = "failed"
                await set_cache_speedtest(task_id, task)

    async def _test_location(
        self, location: str, url: str, streamer: Streamer, test_duration: int, provider: BaseSpeedTestProvider
    ) -> LocationResult:
        """Test speed for a specific location by streaming for up to test_duration seconds."""
        try:
            start_time = time.time()
            total_bytes = 0

            await streamer.create_streaming_response(url, headers={})

            async for chunk in streamer.stream_content():
                # Stop counting once the test window has elapsed.
                if time.time() - start_time >= test_duration:
                    break
                total_bytes += len(chunk)

            duration = time.time() - start_time
            # bytes -> bits, then divide by elapsed megabit-seconds.
            speed_mbps = (total_bytes * 8) / (duration * 1_000_000)

            # Get server info if available (for AllDebrid)
            server_info = getattr(provider, "servers", {}).get(location)
            server_url = server_info.url if server_info else url

            return LocationResult(
                result=SpeedTestResult(
                    speed_mbps=round(speed_mbps, 2), duration=round(duration, 2), data_transferred=total_bytes
                ),
                server_name=location,
                server_url=server_url,
            )

        except Exception as e:
            logger.error(f"Error testing {location}: {str(e)}")
            raise  # Re-raise to be handled by run_speedtest
|
||||||
76
mediaflow_proxy/static/index.html
Normal file
76
mediaflow_proxy/static/index.html
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>MediaFlow Proxy</title>
|
||||||
|
<link rel="icon" href="/logo.png" type="image/x-icon">
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
line-height: 1.6;
|
||||||
|
color: #333;
|
||||||
|
max-width: 800px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 20px;
|
||||||
|
background-color: #f9f9f9;
|
||||||
|
}
|
||||||
|
|
||||||
|
header {
|
||||||
|
background-color: #90aacc;
|
||||||
|
color: #fff;
|
||||||
|
padding: 10px 0;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
header img {
|
||||||
|
width: 200px;
|
||||||
|
height: 200px;
|
||||||
|
vertical-align: middle;
|
||||||
|
border-radius: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
header h1 {
|
||||||
|
display: inline;
|
||||||
|
margin-left: 20px;
|
||||||
|
font-size: 36px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.feature {
|
||||||
|
background-color: #f4f4f4;
|
||||||
|
border-left: 4px solid #3498db;
|
||||||
|
padding: 10px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
a {
|
||||||
|
color: #3498db;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<header>
|
||||||
|
<img src="/logo.png" alt="MediaFlow Proxy Logo">
|
||||||
|
<h1>MediaFlow Proxy</h1>
|
||||||
|
</header>
|
||||||
|
<p>A high-performance proxy server for streaming media, supporting HTTP(S), HLS, and MPEG-DASH with real-time DRM decryption.</p>
|
||||||
|
|
||||||
|
<h2>Key Features</h2>
|
||||||
|
<div class="feature">Convert MPEG-DASH streams (DRM-protected and non-protected) to HLS</div>
|
||||||
|
<div class="feature">Support for Clear Key DRM-protected MPD DASH streams</div>
|
||||||
|
<div class="feature">Handle both live and video-on-demand (VOD) DASH streams</div>
|
||||||
|
<div class="feature">Proxy HTTP/HTTPS links with custom headers</div>
|
||||||
|
<div class="feature">Proxy and modify HLS (M3U8) streams in real-time with custom headers and key URL modifications for bypassing some sneaky restrictions.</div>
|
||||||
|
<div class="feature">Protect against unauthorized access and network bandwidth abuses</div>
|
||||||
|
|
||||||
|
<h2>Getting Started</h2>
|
||||||
|
<p>Visit the <a href="https://github.com/mhdzumair/mediaflow-proxy">GitHub repository</a> for installation instructions and documentation.</p>
|
||||||
|
|
||||||
|
<h2>Premium Hosted Service</h2>
|
||||||
|
<p>For a hassle-free experience, check out the <a href="https://store.elfhosted.com/product/mediaflow-proxy">premium hosted service on ElfHosted</a>.</p>
|
||||||
|
|
||||||
|
<h2>API Documentation</h2>
|
||||||
|
<p>Explore the <a href="/docs">Swagger UI</a> for comprehensive details about the API endpoints and their usage.</p>
|
||||||
|
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
BIN
mediaflow_proxy/static/logo.png
Normal file
BIN
mediaflow_proxy/static/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 85 KiB |
697
mediaflow_proxy/static/speedtest.html
Normal file
697
mediaflow_proxy/static/speedtest.html
Normal file
@@ -0,0 +1,697 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en" class="h-full">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Debrid Speed Test</title>
|
||||||
|
<script src="https://cdn.tailwindcss.com"></script>
|
||||||
|
<script>
|
||||||
|
tailwind.config = {
|
||||||
|
darkMode: 'class',
|
||||||
|
theme: {
|
||||||
|
extend: {
|
||||||
|
animation: {
|
||||||
|
'progress': 'progress 180s linear forwards',
|
||||||
|
},
|
||||||
|
keyframes: {
|
||||||
|
progress: {
|
||||||
|
'0%': {width: '0%'},
|
||||||
|
'100%': {width: '100%'}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
<style>
|
||||||
|
.provider-card {
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.provider-card:hover {
|
||||||
|
transform: translateY(-5px);
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes slideIn {
|
||||||
|
from {
|
||||||
|
transform: translateY(20px);
|
||||||
|
opacity: 0;
|
||||||
|
}
|
||||||
|
to {
|
||||||
|
transform: translateY(0);
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.slide-in {
|
||||||
|
animation: slideIn 0.3s ease-out forwards;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body class="bg-gray-100 dark:bg-gray-900 min-h-full">
|
||||||
|
<!-- Theme Toggle -->
|
||||||
|
<div class="fixed top-4 right-4 z-50">
|
||||||
|
<button id="themeToggle" class="p-2 rounded-full bg-gray-200 dark:bg-gray-700 hover:bg-gray-300 dark:hover:bg-gray-600 transition-colors">
|
||||||
|
<svg id="sunIcon" class="w-6 h-6 text-yellow-500 hidden dark:block" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
|
||||||
|
d="M12 3v1m0 16v1m9-9h-1M4 12H3m15.364 6.364l-.707-.707M6.343 6.343l-.707-.707m12.728 0l-.707.707M6.343 17.657l-.707.707M16 12a4 4 0 11-8 0 4 4 0 018 0z"/>
|
||||||
|
</svg>
|
||||||
|
<svg id="moonIcon" class="w-6 h-6 text-gray-700 block dark:hidden" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M20.354 15.354A9 9 0 018.646 3.646 9.003 9.003 0 0012 21a9.003 9.003 0 008.354-5.646z"/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<main class="container mx-auto px-4 py-8">
|
||||||
|
<!-- Views Container -->
|
||||||
|
<div id="views-container">
|
||||||
|
<!-- API Password View -->
|
||||||
|
<div id="passwordView" class="space-y-8">
|
||||||
|
<h1 class="text-3xl font-bold text-center text-gray-800 dark:text-white mb-8">
|
||||||
|
Enter API Password
|
||||||
|
</h1>
|
||||||
|
|
||||||
|
<div class="max-w-md mx-auto">
|
||||||
|
<form id="passwordForm" class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6 space-y-4">
|
||||||
|
<div class="space-y-2">
|
||||||
|
<label for="apiPassword" class="block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
API Password
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
type="password"
|
||||||
|
id="apiPassword"
|
||||||
|
class="w-full px-4 py-2 rounded-md border border-gray-300 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
required
|
||||||
|
>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center space-x-2">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
id="rememberPassword"
|
||||||
|
class="rounded border-gray-300 dark:border-gray-600"
|
||||||
|
>
|
||||||
|
<label for="rememberPassword" class="text-sm text-gray-600 dark:text-gray-400">
|
||||||
|
Remember password
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
type="submit"
|
||||||
|
class="w-full px-4 py-2 bg-blue-500 text-white rounded-md hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2 transition-colors"
|
||||||
|
>
|
||||||
|
Continue
|
||||||
|
</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Provider Selection View -->
|
||||||
|
<div id="selectionView" class="space-y-8 hidden">
|
||||||
|
<h1 class="text-3xl font-bold text-center text-gray-800 dark:text-white mb-8">
|
||||||
|
Select Debrid Service for Speed Test
|
||||||
|
</h1>
|
||||||
|
<div class="grid grid-cols-1 md:grid-cols-2 gap-6 max-w-4xl mx-auto">
|
||||||
|
<!-- Real-Debrid Card -->
|
||||||
|
<button onclick="startTest('real_debrid')" class="provider-card bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6 text-left hover:shadow-xl transition-shadow">
|
||||||
|
<h2 class="text-xl font-semibold text-gray-800 dark:text-white mb-2">Real-Debrid</h2>
|
||||||
|
<p class="text-gray-600 dark:text-gray-300">Test speeds across multiple Real-Debrid servers worldwide</p>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<!-- AllDebrid Card -->
|
||||||
|
<button onclick="showAllDebridSetup()" class="provider-card bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6 text-left hover:shadow-xl transition-shadow">
|
||||||
|
<h2 class="text-xl font-semibold text-gray-800 dark:text-white mb-2">AllDebrid</h2>
|
||||||
|
<p class="text-gray-600 dark:text-gray-300">Measure download speeds from AllDebrid servers</p>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- AllDebrid Setup View -->
|
||||||
|
<div id="allDebridSetupView" class="max-w-md mx-auto space-y-6 hidden">
|
||||||
|
<h2 class="text-2xl font-bold text-center text-gray-800 dark:text-white mb-8">
|
||||||
|
AllDebrid Setup
|
||||||
|
</h2>
|
||||||
|
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6">
|
||||||
|
<form id="allDebridForm" class="space-y-4">
|
||||||
|
<div class="space-y-2">
|
||||||
|
<label for="adApiKey" class="block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
AllDebrid API Key
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
type="password"
|
||||||
|
id="adApiKey"
|
||||||
|
class="w-full px-4 py-2 rounded-md border border-gray-300 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||||
|
required
|
||||||
|
>
|
||||||
|
<p class="text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
You can find your API key in the AllDebrid dashboard
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center space-x-2">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
id="rememberAdKey"
|
||||||
|
class="rounded border-gray-300 dark:border-gray-600"
|
||||||
|
>
|
||||||
|
<label for="rememberAdKey" class="text-sm text-gray-600 dark:text-gray-400">
|
||||||
|
Remember API key
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<div class="flex space-x-3">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onclick="showView('selectionView')"
|
||||||
|
class="flex-1 px-4 py-2 border border-gray-300 dark:border-gray-600 text-gray-700 dark:text-gray-300 rounded-md hover:bg-gray-50 dark:hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-blue-500 transition-colors"
|
||||||
|
>
|
||||||
|
Back
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="submit"
|
||||||
|
class="flex-1 px-4 py-2 bg-blue-500 text-white rounded-md hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2 transition-colors"
|
||||||
|
>
|
||||||
|
Start Test
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Testing View -->
|
||||||
|
<div id="testingView" class="max-w-4xl mx-auto space-y-6 hidden">
|
||||||
|
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6">
|
||||||
|
<!-- User Info Section -->
|
||||||
|
<div id="userInfo" class="mb-6 hidden">
|
||||||
|
<!-- User info will be populated dynamically -->
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Progress Section -->
|
||||||
|
<div class="space-y-4">
|
||||||
|
<div class="text-center text-gray-600 dark:text-gray-300" id="currentLocation">
|
||||||
|
Initializing test...
|
||||||
|
</div>
|
||||||
|
<div class="h-2 bg-gray-200 dark:bg-gray-700 rounded-full overflow-hidden">
|
||||||
|
<div class="h-full bg-blue-500 animate-progress" id="progressBar"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Results Container -->
|
||||||
|
<div id="resultsContainer" class="mt-8">
|
||||||
|
<!-- Results will be populated dynamically -->
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Results View -->
|
||||||
|
<div id="resultsView" class="max-w-4xl mx-auto space-y-6 hidden">
|
||||||
|
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6">
|
||||||
|
<div class="space-y-6">
|
||||||
|
<!-- Summary Section -->
|
||||||
|
<div class="border-b border-gray-200 dark:border-gray-700 pb-4">
|
||||||
|
<h3 class="text-lg font-semibold text-gray-800 dark:text-white mb-4">Test Summary</h3>
|
||||||
|
<div class="grid grid-cols-2 md:grid-cols-3 gap-4">
|
||||||
|
<div class="space-y-1">
|
||||||
|
<div class="text-sm text-gray-500 dark:text-gray-400">Fastest Server</div>
|
||||||
|
<div id="fastestServer" class="font-medium text-gray-900 dark:text-white"></div>
|
||||||
|
</div>
|
||||||
|
<div class="space-y-1">
|
||||||
|
<div class="text-sm text-gray-500 dark:text-gray-400">Top Speed</div>
|
||||||
|
<div id="topSpeed" class="font-medium text-green-500"></div>
|
||||||
|
</div>
|
||||||
|
<div class="space-y-1">
|
||||||
|
<div class="text-sm text-gray-500 dark:text-gray-400">Average Speed</div>
|
||||||
|
<div id="avgSpeed" class="font-medium text-blue-500"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Detailed Results -->
|
||||||
|
<div id="finalResults" class="space-y-4">
|
||||||
|
<!-- Results will be populated here -->
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="text-center mt-6">
|
||||||
|
<button onclick="resetTest()" class="px-6 py-2 bg-blue-500 text-white rounded-lg hover:bg-blue-600 transition-colors">
|
||||||
|
Test Another Provider
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Error View -->
|
||||||
|
<div id="errorView" class="max-w-4xl mx-auto space-y-6 hidden">
|
||||||
|
<div class="bg-red-50 dark:bg-red-900/50 border-l-4 border-red-500 p-4 rounded">
|
||||||
|
<div class="flex">
|
||||||
|
<div class="flex-shrink-0">
|
||||||
|
<svg class="h-5 w-5 text-red-400" viewBox="0 0 20 20" fill="currentColor">
|
||||||
|
<path fill-rule="evenodd"
|
||||||
|
d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.707 7.293a1 1 0 00-1.414 1.414L8.586 10l-1.293 1.293a1 1 0 101.414 1.414L10 11.414l1.293 1.293a1 1 0 001.414-1.414L11.414 10l1.293-1.293a1 1 0 00-1.414-1.414L10 8.586 8.707 7.293z"
|
||||||
|
clip-rule="evenodd"/>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<div class="ml-3">
|
||||||
|
<p class="text-sm text-red-700 dark:text-red-200" id="errorMessage"></p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="text-center">
|
||||||
|
<button onclick="resetTest()"
|
||||||
|
class="inline-flex items-center px-4 py-2 border border-transparent text-sm font-medium rounded-md shadow-sm text-white bg-blue-600 hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 dark:hover:bg-blue-500 transition-colors duration-200">
|
||||||
|
Try Again
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</main>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
// Config and State
|
||||||
|
// Config and State
// Shared mutable client state for the whole page. Credentials are hydrated
// from localStorage so a "remembered" password/key survives page reloads.
const STATE = {
    apiPassword: localStorage.getItem('speedtest_api_password'),  // MediaFlow API password header value
    adApiKey: localStorage.getItem('ad_api_key'),                 // AllDebrid API key (optional)
    currentTaskId: null,   // id of the speed-test task currently being polled
    resultsCount: 0,
};
|
||||||
|
|
||||||
|
// Theme handling
|
||||||
|
// Apply the given theme ('dark' or anything else = light) and persist it.
function setTheme(theme) {
    const isDark = theme === 'dark';
    document.documentElement.classList.toggle('dark', isDark);
    localStorage.theme = theme;
}
|
||||||
|
|
||||||
|
// Restore the persisted theme, falling back to the OS color-scheme preference.
function initTheme() {
    const stored = localStorage.theme;
    const systemDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
    setTheme(stored || (systemDark ? 'dark' : 'light'));
}
|
||||||
|
|
||||||
|
// View management
|
||||||
|
// Show exactly one top-level view inside #views-container, hiding its siblings.
function showView(viewId) {
    const views = document.querySelectorAll('#views-container > div');
    views.forEach((view) => {
        const visible = view.id === viewId;
        view.classList.toggle('hidden', !visible);
    });
}
|
||||||
|
|
||||||
|
|
||||||
|
// Render one failed test entry (location, optional server name, error detail,
// raw server URL) as an HTML fragment for the live results list.
// NOTE(review): values are interpolated without HTML escaping — assumed to
// come from the trusted backend; confirm before feeding untrusted data here.
function createErrorResult(location, data) {
    return `
        <div class="py-4">
            <div class="flex justify-between items-center">
                <div>
                    <span class="font-medium text-gray-800 dark:text-white">${location}</span>
                    <span class="ml-2 text-sm text-gray-500 dark:text-gray-400">${data.server_name || ''}</span>
                </div>
                <span class="text-sm text-red-500 dark:text-red-400">
                    Failed
                </span>
            </div>
            <div class="mt-1 text-sm text-red-400 dark:text-red-300">
                ${data.error || 'Test failed'}
            </div>
            <div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
                Server: ${data.server_url}
            </div>
        </div>
    `;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Format a byte count as a human-readable string, e.g. 1536 -> "1.50 KB".
// Scales by powers of 1024. Previously capped at GB (so 1 TB rendered as
// "1024.00 GB"); the scale now extends through TB, anything larger stays in TB.
function formatBytes(bytes) {
    const units = ['B', 'KB', 'MB', 'GB', 'TB'];
    let value = bytes;
    let unitIndex = 0;

    while (value >= 1024 && unitIndex < units.length - 1) {
        value /= 1024;
        unitIndex++;
    }

    return `${value.toFixed(2)} ${units[unitIndex]}`;
}
|
||||||
|
|
||||||
|
// A 403 from the backend means the stored API password is bad:
// forget it (memory + localStorage) and send the user to the error view.
function handleAuthError() {
    STATE.apiPassword = null;
    localStorage.removeItem('speedtest_api_password');
    showError('Authentication failed. Please check your API password.');
}
|
||||||
|
|
||||||
|
// Display `message` in the dedicated error view and switch to it.
function showError(message) {
    const target = document.getElementById('errorMessage');
    target.textContent = message;
    showView('errorView');
}
|
||||||
|
|
||||||
|
// Start over: a full page reload clears all in-memory test state and
// returns to the initial view chosen by initializeView().
function resetTest() {
    window.location.reload();
}
|
||||||
|
|
||||||
|
// Switch to the AllDebrid API-key entry form.
function showAllDebridSetup() {
    showView('allDebridSetupView');
}
|
||||||
|
|
||||||
|
// Kick off a speed test for the given provider ('real_debrid' or 'all_debrid').
// Detours to the AllDebrid setup form if that provider is chosen without a
// stored API key; otherwise POSTs /speedtest/start and begins polling the
// returned task until completion. Errors land in the error view.
async function startTest(provider) {
    if (provider === 'all_debrid' && !STATE.adApiKey) {
        showAllDebridSetup();
        return;
    }

    showView('testingView');
    initializeResultsContainer();

    try {
        const params = new URLSearchParams({provider});
        // Credentials travel as request headers, not query parameters.
        const headers = {'api_password': STATE.apiPassword};

        if (provider === 'all_debrid' && STATE.adApiKey) {
            headers['api_key'] = STATE.adApiKey;
        }

        const response = await fetch(`/speedtest/start?${params}`, {
            method: 'POST',
            headers
        });

        if (!response.ok) {
            // 403 = bad API password: clear it and prompt again.
            if (response.status === 403) {
                handleAuthError();
                return;
            }
            throw new Error('Failed to start speed test');
        }

        const {task_id} = await response.json();
        STATE.currentTaskId = task_id;
        await pollResults(task_id);
    } catch (error) {
        showError(error.message);
    }
}
|
||||||
|
|
||||||
|
// Reset the testing view's results area to its empty scaffold before a run.
function initializeResultsContainer() {
    const container = document.getElementById('resultsContainer');
    container.innerHTML = `
        <div class="space-y-4">
            <div id="locationResults" class="divide-y divide-gray-200 dark:divide-gray-700">
                <!-- Results will be populated here -->
            </div>
            <div id="summaryStats" class="hidden pt-4">
                <!-- Summary stats will be populated here -->
            </div>
        </div>
    `;
}
|
||||||
|
|
||||||
|
// Poll /speedtest/results/{taskId} every 2 seconds until the task reports
// 'completed' (render final results) or 'failed' (raise). Transient HTTP
// errors are retried up to maxRetries consecutive times; a 403 aborts via
// handleAuthError. Each successful response refreshes the live results UI.
async function pollResults(taskId) {
    let retryCount = 0;
    const maxRetries = 10;
    try {
        while (true) {
            const response = await fetch(`/speedtest/results/${taskId}`, {
                headers: {'api_password': STATE.apiPassword}
            });

            if (!response.ok) {
                if (response.status === 403) {
                    handleAuthError();
                    return;
                }
                // Transient failure: back off 2s and retry.
                if (retryCount < maxRetries) {
                    retryCount++
                    await new Promise(resolve => setTimeout(resolve, 2000));
                    continue;
                }
                throw new Error('Failed to fetch results after multiple attempts');
            }

            const data = await response.json();
            retryCount = 0; // any successful fetch resets the retry budget

            if (data.status === 'failed') {
                throw new Error('Speed test failed');
            }

            updateUI(data);

            if (data.status === 'completed') {
                showFinalResults(data);
                break;
            }

            // Still running: wait before the next poll.
            await new Promise(resolve => setTimeout(resolve, 2000));
        }
    } catch (error) {
        showError(error.message);
    }
}
|
||||||
|
|
||||||
|
// Refresh the live testing view from one polling payload: user banner,
// current-location label, and the per-server results list.
function updateUI(data) {
    if (data.user_info) {
        updateUserInfo(data.user_info);
    }

    const location = data.current_location;
    if (location) {
        const label = document.getElementById('currentLocation');
        label.textContent = `Testing server ${location}...`;
    }

    updateResults(data.results);
}
|
||||||
|
|
||||||
|
// Populate and reveal the user-info banner (IP address / ISP / country).
// Fix: fields the backend omits previously rendered as the literal text
// "undefined" (e.g. `${userInfo.ip}` with an undefined value); they now
// fall back to an empty string.
function updateUserInfo(userInfo) {
    const userInfoDiv = document.getElementById('userInfo');
    const ip = userInfo.ip || '';
    const isp = userInfo.isp || '';
    const country = (userInfo.country || '').toUpperCase();
    userInfoDiv.innerHTML = `
        <div class="grid grid-cols-2 md:grid-cols-4 gap-4 p-4 bg-gray-50 dark:bg-gray-700 rounded-lg">
            <div class="space-y-1">
                <div class="text-sm text-gray-500 dark:text-gray-400">IP Address</div>
                <div class="font-medium text-gray-900 dark:text-white">${ip}</div>
            </div>
            <div class="space-y-1">
                <div class="text-sm text-gray-500 dark:text-gray-400">ISP</div>
                <div class="font-medium text-gray-900 dark:text-white">${isp}</div>
            </div>
            <div class="space-y-1">
                <div class="text-sm text-gray-500 dark:text-gray-400">Country</div>
                <div class="font-medium text-gray-900 dark:text-white">${country}</div>
            </div>
        </div>
    `;
    userInfoDiv.classList.remove('hidden');
}
|
||||||
|
|
||||||
|
// Re-render the live results panel from the polling payload's `results` map
// ({location: {result, error, server_name, server_url}}). Successful tests
// come first, sorted fastest-first, followed by failures; the summary banner
// above them is rebuilt on every refresh.
function updateResults(results) {
    const container = document.getElementById('resultsContainer');
    const validResults = Object.entries(results)
        .filter(([, data]) => data.result !== null && !data.error)
        .sort(([, a], [, b]) => (b.result.speed_mbps) - (a.result.speed_mbps));

    const failedResults = Object.entries(results)
        .filter(([, data]) => data.error || data.result === null);

    // Generate HTML for results
    const resultsHTML = [
        // Successful results
        ...validResults.map(([location, data]) => createSuccessResult(location, data)),
        // Failed results
        ...failedResults.map(([location, data]) => createErrorResult(location, data))
    ].join('');

    container.innerHTML = `
        <div class="space-y-4">
            <!-- Summary Stats -->
            ${createSummaryStats(validResults)}
            <!-- Individual Results -->
            <div class="mt-6 divide-y divide-gray-200 dark:divide-gray-700">
                ${resultsHTML}
            </div>
        </div>
    `;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Build the summary banner (fastest server / top speed / average speed)
// from the already-sorted list of successful [location, data] entries.
// Returns '' while there are no successful results yet.
function createSummaryStats(validResults) {
    if (validResults.length === 0) return '';

    const speeds = validResults.map(([, data]) => data.result.speed_mbps);
    const maxSpeed = Math.max(...speeds);
    const avgSpeed = speeds.reduce((a, b) => a + b, 0) / speeds.length;
    const fastestServer = validResults[0][0]; // First server after fastest-first sorting

    return `
        <div class="grid grid-cols-1 md:grid-cols-3 gap-4 bg-gray-50 dark:bg-gray-800 p-4 rounded-lg">
            <div class="text-center">
                <div class="text-sm text-gray-500 dark:text-gray-400">Fastest Server</div>
                <div class="font-medium text-gray-900 dark:text-white">${fastestServer}</div>
            </div>
            <div class="text-center">
                <div class="text-sm text-gray-500 dark:text-gray-400">Top Speed</div>
                <div class="font-medium text-green-500">${maxSpeed.toFixed(2)} Mbps</div>
            </div>
            <div class="text-center">
                <div class="text-sm text-gray-500 dark:text-gray-400">Average Speed</div>
                <div class="font-medium text-blue-500">${avgSpeed.toFixed(2)} Mbps</div>
            </div>
        </div>
    `;
}
|
||||||
|
|
||||||
|
// Render one successful test entry: location, server name, colour-coded
// speed (via getSpeedClass), duration, bytes transferred, raw server URL.
function createSuccessResult(location, data) {
    const speedClass = getSpeedClass(data.result.speed_mbps);
    return `
        <div class="py-4">
            <div class="flex justify-between items-center">
                <div>
                    <span class="font-medium text-gray-800 dark:text-white">${location}</span>
                    <span class="ml-2 text-sm text-gray-500 dark:text-gray-400">${data.server_name || ''}</span>
                </div>
                <span class="text-lg font-semibold ${speedClass}">${data.result.speed_mbps.toFixed(2)} Mbps</span>
            </div>
            <div class="mt-1 text-sm text-gray-500 dark:text-gray-400">
                Duration: ${data.result.duration.toFixed(2)}s •
                Data: ${formatBytes(data.result.data_transferred)}
            </div>
            <div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
                Server: ${data.server_url}
            </div>
        </div>
    `;
}
|
||||||
|
|
||||||
|
// Map a speed in Mbps to a Tailwind colour class:
// >=10 green, >=5 blue, >=2 yellow, otherwise red.
function getSpeedClass(speed) {
    const bands = [
        [10, 'text-green-500 dark:text-green-400'],
        [5, 'text-blue-500 dark:text-blue-400'],
        [2, 'text-yellow-500 dark:text-yellow-400'],
    ];
    for (const [threshold, cls] of bands) {
        if (speed >= threshold) return cls;
    }
    return 'text-red-500 dark:text-red-400';
}
|
||||||
|
|
||||||
|
// Render the final results screen once polling reports 'completed'.
// Splits results into successes (sorted fastest-first) and failures, fills
// the summary stats, builds the detailed result cards, carries the user-info
// banner over from the testing view, then switches to the results view.
function showFinalResults(data) {
    // Stop the progress animation
    document.querySelector('#progressBar').style.animation = 'none';

    // Update the final results view
    const validResults = Object.entries(data.results)
        .filter(([, data]) => data.result !== null && !data.error)
        .sort(([, a], [, b]) => (b.result.speed_mbps) - (a.result.speed_mbps));

    const failedResults = Object.entries(data.results)
        .filter(([, data]) => data.error || data.result === null);

    // Update summary stats (left blank when every test failed)
    if (validResults.length > 0) {
        const speeds = validResults.map(([, data]) => data.result.speed_mbps);
        const maxSpeed = Math.max(...speeds);
        const avgSpeed = speeds.reduce((a, b) => a + b, 0) / speeds.length;
        const fastestServer = validResults[0][0]; // first after fastest-first sort

        document.getElementById('fastestServer').textContent = fastestServer;
        document.getElementById('topSpeed').textContent = `${maxSpeed.toFixed(2)} Mbps`;
        document.getElementById('avgSpeed').textContent = `${avgSpeed.toFixed(2)} Mbps`;
    }

    // Generate detailed results HTML
    const finalResultsHTML = `
        ${validResults.map(([location, data]) => `
            <div class="bg-white dark:bg-gray-800 rounded-lg p-4 shadow-sm">
                <div class="flex justify-between items-center">
                    <div>
                        <h3 class="text-lg font-medium text-gray-900 dark:text-white">${location}</h3>
                        <p class="text-sm text-gray-500 dark:text-gray-400">${data.server_name || ''}</p>
                    </div>
                    <div class="text-right">
                        <p class="text-2xl font-bold ${getSpeedClass(data.result.speed_mbps)}">
                            ${data.result.speed_mbps.toFixed(2)} Mbps
                        </p>
                        <p class="text-sm text-gray-500 dark:text-gray-400">
                            ${data.result.duration.toFixed(2)}s • ${formatBytes(data.result.data_transferred)}
                        </p>
                    </div>
                </div>
                <div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
                    ${data.server_url}
                </div>
            </div>
        `).join('')}

        ${failedResults.length > 0 ? `
            <div class="mt-6">
                <h3 class="text-lg font-medium text-gray-900 dark:text-white mb-4">Failed Tests</h3>
                ${failedResults.map(([location, data]) => `
                    <div class="bg-red-50 dark:bg-red-900/20 rounded-lg p-4 mb-4">
                        <div class="flex justify-between items-center">
                            <div>
                                <h4 class="font-medium text-red-800 dark:text-red-200">
                                    ${location} ${data.server_name ? `(${data.server_name})` : ''}
                                </h4>
                                <p class="text-sm text-red-700 dark:text-red-300">
                                    ${data.error || 'Test failed'}
                                </p>
                                <p class="text-xs text-red-600 dark:text-red-400 mt-1">
                                    ${data.server_url}
                                </p>
                            </div>
                        </div>
                    </div>
                `).join('')}
            </div>
        ` : ''}
    `;

    document.getElementById('finalResults').innerHTML = finalResultsHTML;

    // If we have user info from AllDebrid, copy it to the final view
    const userInfoDiv = document.getElementById('userInfo');
    if (!userInfoDiv.classList.contains('hidden') && data.user_info) {
        const userInfoContent = userInfoDiv.innerHTML;
        document.getElementById('finalResults').insertAdjacentHTML('afterbegin', `
            <div class="mb-6">
                ${userInfoContent}
            </div>
        `);
    }

    // Show the final results view
    showView('resultsView');
}
|
||||||
|
|
||||||
|
// Apply the theme, then land on provider selection if an API password is
// already known, otherwise on the password prompt.
function initializeView() {
    initTheme();
    const target = STATE.apiPassword ? 'selectionView' : 'passwordView';
    showView(target);
}
|
||||||
|
|
||||||
|
// Attach submit handlers for the password and AllDebrid setup forms.
// Credentials are kept in STATE for the session and, when the matching
// "remember" checkbox is ticked, persisted to localStorage.
function initializeFormHandlers() {
    // Password form handler
    document.getElementById('passwordForm').addEventListener('submit', (e) => {
        e.preventDefault();
        const password = document.getElementById('apiPassword').value;
        const remember = document.getElementById('rememberPassword').checked;

        if (remember) {
            localStorage.setItem('speedtest_api_password', password);
        }
        STATE.apiPassword = password;
        showView('selectionView');
    });

    // AllDebrid form handler: saving the key immediately starts the test.
    document.getElementById('allDebridForm').addEventListener('submit', async (e) => {
        e.preventDefault();
        const apiKey = document.getElementById('adApiKey').value;
        const remember = document.getElementById('rememberAdKey').checked;

        if (remember) {
            localStorage.setItem('ad_api_key', apiKey);
        }
        STATE.adApiKey = apiKey;
        await startTest('all_debrid');
    });
}
|
||||||
|
|
||||||
|
// Entry point: wire up the page once the DOM is ready.
document.addEventListener('DOMContentLoaded', () => {
    initializeView();
    initializeFormHandlers();
});
|
||||||
|
|
||||||
|
// Theme Toggle Event Listener
|
||||||
|
// Flip between light and dark themes when the toggle button is clicked.
document.getElementById('themeToggle').addEventListener('click', () => {
    setTheme(document.documentElement.classList.contains('dark') ? 'light' : 'dark');
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
0
mediaflow_proxy/utils/__init__.py
Normal file
0
mediaflow_proxy/utils/__init__.py
Normal file
376
mediaflow_proxy/utils/cache_utils.py
Normal file
376
mediaflow_proxy/utils/cache_utils.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union, Any
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import aiofiles.os
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from mediaflow_proxy.speedtest.models import SpeedTestTask
|
||||||
|
from mediaflow_proxy.utils.http_utils import download_file_with_retry, DownloadError
|
||||||
|
from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CacheEntry:
    """A cached payload plus the bookkeeping needed for TTL and LRU eviction."""

    data: bytes  # raw cached payload
    expires_at: float  # absolute epoch time after which the entry is stale
    access_count: int = 0  # number of successful lookups
    last_access: float = 0.0  # epoch time of the most recent hit
    size: int = 0  # payload size in bytes, used for capacity accounting


class LRUMemoryCache:
    """Thread-safe, size-bounded LRU cache for CacheEntry objects.

    ``maxsize`` is a budget in *bytes* (the sum of entry sizes), not an entry
    count. Expired entries are evicted lazily, on lookup.

    Fixes vs. previous revision:
    - ``set`` on an existing key subtracted the old entry's size but left the
      old entry in the dict; the eviction loop could then pop that same stale
      entry and subtract its size a second time, corrupting ``_current_size``.
      The old entry is now removed before eviction runs.
    - ``get`` on an expired key popped the entry twice (the second
      ``pop(key, None)`` was always a no-op); the dead call is removed.
    """

    def __init__(self, maxsize: int):
        self.maxsize = maxsize  # byte budget for all entries combined
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.Lock()
        self._current_size = 0  # sum of entry.size for entries currently held

    def get(self, key: str) -> Optional[CacheEntry]:
        """Return the live entry for ``key`` (refreshing its LRU position), or None."""
        with self._lock:
            if key in self._cache:
                entry = self._cache.pop(key)  # remove and re-insert for LRU order
                if time.time() < entry.expires_at:
                    entry.access_count += 1
                    entry.last_access = time.time()
                    self._cache[key] = entry
                    return entry
                # Expired: already popped above, just release its size budget.
                self._current_size -= entry.size
            return None

    def set(self, key: str, entry: CacheEntry) -> None:
        """Insert or replace ``entry`` under ``key``, evicting LRU entries to fit."""
        with self._lock:
            # Remove any existing entry first so eviction below can never
            # double-count it.
            old_entry = self._cache.pop(key, None)
            if old_entry is not None:
                self._current_size -= old_entry.size

            # Evict least-recently-used entries until the new entry fits.
            while self._current_size + entry.size > self.maxsize and self._cache:
                _, removed_entry = self._cache.popitem(last=False)
                self._current_size -= removed_entry.size

            self._cache[key] = entry
            self._current_size += entry.size

    def remove(self, key: str) -> None:
        """Delete ``key`` if present and release its size budget."""
        with self._lock:
            if key in self._cache:
                entry = self._cache.pop(key)
                self._current_size -= entry.size
|
||||||
|
|
||||||
|
|
||||||
|
class HybridCache:
|
||||||
|
"""High-performance hybrid cache combining memory and file storage."""
|
||||||
|
|
||||||
|
    def __init__(
        self,
        cache_dir_name: str,
        ttl: int,
        max_memory_size: int = 100 * 1024 * 1024,  # 100MB default
        executor_workers: int = 4,
    ):
        """Create a hybrid (memory + file) cache.

        Args:
            cache_dir_name: Directory name created under the system temp dir
                for the file-backed layer.
            ttl: Default entry time-to-live in seconds.
            max_memory_size: Byte budget for the in-memory LRU layer.
            executor_workers: Thread pool size for offloaded blocking work.
        """
        self.cache_dir = Path(tempfile.gettempdir()) / cache_dir_name
        self.ttl = ttl
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)
        self._executor = ThreadPoolExecutor(max_workers=executor_workers)
        self._lock = asyncio.Lock()

        # Initialize cache directories
        self._init_cache_dirs()
|
||||||
|
|
||||||
|
def _init_cache_dirs(self):
|
||||||
|
"""Initialize sharded cache directories."""
|
||||||
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def _get_md5_hash(self, key: str) -> str:
|
||||||
|
"""Get the MD5 hash of a cache key."""
|
||||||
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
def _get_file_path(self, key: str) -> Path:
|
||||||
|
"""Get the file path for a cache key."""
|
||||||
|
return self.cache_dir / key
|
||||||
|
|
||||||
|
    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """
        Get value from cache, trying memory first then file.

        On-disk file layout: an 8-byte big-endian metadata length, a JSON
        metadata document (expiry, access stats), then the raw payload bytes.

        Args:
            key: Cache key (hashed internally before lookup)
            default: Default value if key not found

        Returns:
            Cached value or default if not found
        """
        # All storage layers are addressed by the MD5 of the caller's key.
        key = self._get_md5_hash(key)
        # Try memory cache first
        entry = self.memory_cache.get(key)
        if entry is not None:
            return entry.data

        # Try file cache
        try:
            file_path = self._get_file_path(key)
            async with aiofiles.open(file_path, "rb") as f:
                # Header: 8-byte big-endian length of the JSON metadata block.
                metadata_size = await f.read(8)
                metadata_length = int.from_bytes(metadata_size, "big")
                metadata_bytes = await f.read(metadata_length)
                metadata = json.loads(metadata_bytes.decode())

                # Check expiration
                if metadata["expires_at"] < time.time():
                    # Stale file: remove it and report a miss.
                    await self.delete(key)
                    return default

                # Read data (everything after the metadata header)
                data = await f.read()

                # Promote the hit into the memory layer for faster re-reads.
                entry = CacheEntry(
                    data=data,
                    expires_at=metadata["expires_at"],
                    access_count=metadata["access_count"] + 1,
                    last_access=time.time(),
                    size=len(data),
                )
                self.memory_cache.set(key, entry)

                return data

        except FileNotFoundError:
            return default
        except Exception as e:
            # Best-effort cache: corrupt or unreadable files degrade to a miss.
            logger.error(f"Error reading from cache: {e}")
            return default
|
||||||
|
|
||||||
|
async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Set value in both memory and file cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
data: Data to cache
|
||||||
|
ttl: Optional TTL override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Success status
|
||||||
|
"""
|
||||||
|
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||||
|
raise ValueError("Data must be bytes, bytearray, or memoryview")
|
||||||
|
|
||||||
|
expires_at = time.time() + (ttl or self.ttl)
|
||||||
|
|
||||||
|
# Create cache entry
|
||||||
|
entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))
|
||||||
|
|
||||||
|
key = self._get_md5_hash(key)
|
||||||
|
# Update memory cache
|
||||||
|
self.memory_cache.set(key, entry)
|
||||||
|
file_path = self._get_file_path(key)
|
||||||
|
temp_path = file_path.with_suffix(".tmp")
|
||||||
|
|
||||||
|
# Update file cache
|
||||||
|
try:
|
||||||
|
metadata = {"expires_at": expires_at, "access_count": 0, "last_access": time.time()}
|
||||||
|
metadata_bytes = json.dumps(metadata).encode()
|
||||||
|
metadata_size = len(metadata_bytes).to_bytes(8, "big")
|
||||||
|
|
||||||
|
async with aiofiles.open(temp_path, "wb") as f:
|
||||||
|
await f.write(metadata_size)
|
||||||
|
await f.write(metadata_bytes)
|
||||||
|
await f.write(data)
|
||||||
|
|
||||||
|
await aiofiles.os.rename(temp_path, file_path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error writing to cache: {e}")
|
||||||
|
try:
|
||||||
|
await aiofiles.os.remove(temp_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""Delete item from both caches."""
|
||||||
|
self.memory_cache.remove(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_path = self._get_file_path(key)
|
||||||
|
await aiofiles.os.remove(file_path)
|
||||||
|
return True
|
||||||
|
except FileNotFoundError:
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting from cache: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMemoryCache:
    """Async facade over the synchronous LRUMemoryCache."""

    def __init__(self, max_memory_size: int):
        """Create the backing LRU store with *max_memory_size* bytes of budget."""
        self.memory_cache = LRUMemoryCache(maxsize=max_memory_size)

    async def get(self, key: str, default: Any = None) -> Optional[bytes]:
        """Return the cached payload for *key*, or *default* when absent."""
        entry = self.memory_cache.get(key)
        if entry is None:
            return default
        return entry.data

    async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
        """Store *data* under *key*; returns True on success, False on error."""
        try:
            expiry = time.time() + (ttl or 3600)  # Default 1 hour TTL if not specified
            new_entry = CacheEntry(
                data=data, expires_at=expiry, access_count=0, last_access=time.time(), size=len(data)
            )
            self.memory_cache.set(key, new_entry)
        except Exception as e:
            logger.error(f"Error setting cache value: {e}")
            return False
        return True

    async def delete(self, key: str) -> bool:
        """Remove *key* from the cache; returns True on success, False on error."""
        try:
            self.memory_cache.remove(key)
        except Exception as e:
            logger.error(f"Error deleting from cache: {e}")
            return False
        return True
|
||||||
|
|
||||||
|
|
||||||
|
# Create cache instances
_MIB = 1024 * 1024  # bytes per mebibyte

INIT_SEGMENT_CACHE = HybridCache(
    cache_dir_name="init_segment_cache",
    ttl=3600,  # 1 hour
    max_memory_size=500 * _MIB,  # 500MB for init segments
)

MPD_CACHE = AsyncMemoryCache(
    max_memory_size=100 * _MIB,  # 100MB for MPD files
)

SPEEDTEST_CACHE = HybridCache(
    cache_dir_name="speedtest_cache",
    ttl=3600,  # 1 hour
    max_memory_size=50 * _MIB,
)

EXTRACTOR_CACHE = HybridCache(
    cache_dir_name="extractor_cache",
    ttl=5 * 60,  # 5 minutes
    max_memory_size=50 * _MIB,
)
|
||||||
|
|
||||||
|
|
||||||
|
# Specific cache implementations
|
||||||
|
async def get_cached_init_segment(init_url: str, headers: dict) -> Optional[bytes]:
    """Get initialization segment from cache or download it.

    Returns the segment bytes, or None when the download fails.
    """
    # Fast path: serve from the hybrid cache when present.
    cached_data = await INIT_SEGMENT_CACHE.get(init_url)
    if cached_data is not None:
        return cached_data

    # Cache miss: fetch the segment and persist it for subsequent requests.
    try:
        init_content = await download_file_with_retry(init_url, headers)
        if init_content:
            await INIT_SEGMENT_CACHE.set(init_url, init_content)
        return init_content
    except Exception as e:
        logger.error(f"Error downloading init segment: {e}")
    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_cached_mpd(
    mpd_url: str,
    headers: dict,
    parse_drm: bool,
    parse_segment_profile_id: Optional[str] = None,
) -> dict:
    """Get MPD from cache or download and parse it.

    Raises:
        DownloadError: When the manifest cannot be downloaded.
        Exception: When parsing/processing the manifest fails.
    """
    # Serve from cache when a valid JSON payload is present.
    cached_data = await MPD_CACHE.get(mpd_url)
    if cached_data is not None:
        try:
            return parse_mpd_dict(json.loads(cached_data), mpd_url, parse_drm, parse_segment_profile_id)
        except json.JSONDecodeError:
            # Corrupt cache entry: purge it and fall through to a fresh download.
            await MPD_CACHE.delete(mpd_url)

    # Download and parse if not cached
    try:
        raw_manifest = await download_file_with_retry(mpd_url, headers)
        manifest_dict = parse_mpd(raw_manifest)
        parsed_dict = parse_mpd_dict(manifest_dict, mpd_url, parse_drm, parse_segment_profile_id)

        # Cache the original MPD dict; expire when the manifest says to refresh.
        await MPD_CACHE.set(mpd_url, json.dumps(manifest_dict).encode(), ttl=parsed_dict["minimumUpdatePeriod"])
        return parsed_dict
    except DownloadError as error:
        logger.error(f"Error downloading MPD: {error}")
        raise error
    except Exception as error:
        logger.exception(f"Error processing MPD: {error}")
        raise error
|
||||||
|
|
||||||
|
|
||||||
|
async def get_cached_speedtest(task_id: str) -> Optional[SpeedTestTask]:
    """Get speed test results from cache; returns None on miss or bad payload."""
    cached_data = await SPEEDTEST_CACHE.get(task_id)
    if cached_data is None:
        return None
    try:
        return SpeedTestTask.model_validate_json(cached_data.decode())
    except ValidationError as e:
        # Stale or schema-incompatible payload: log, evict, report a miss.
        logger.error(f"Error parsing cached speed test data: {e}")
        await SPEEDTEST_CACHE.delete(task_id)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def set_cache_speedtest(task_id: str, task: SpeedTestTask) -> bool:
    """Cache speed test results; returns True on success, False on failure."""
    try:
        serialized = task.model_dump_json().encode()
        return await SPEEDTEST_CACHE.set(task_id, serialized)
    except Exception as e:
        logger.error(f"Error caching speed test data: {e}")
        return False
|
||||||
|
|
||||||
|
|
||||||
|
async def get_cached_extractor_result(key: str) -> Optional[dict]:
    """Get extractor result from cache; returns None on miss or corrupt JSON."""
    cached_data = await EXTRACTOR_CACHE.get(key)
    if cached_data is None:
        return None
    try:
        return json.loads(cached_data)
    except json.JSONDecodeError:
        # Corrupt payload: evict so the next call recomputes it.
        await EXTRACTOR_CACHE.delete(key)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def set_cache_extractor_result(key: str, result: dict) -> bool:
    """Cache extractor result; returns True on success, False on failure."""
    try:
        serialized = json.dumps(result).encode()
        return await EXTRACTOR_CACHE.set(key, serialized)
    except Exception as e:
        logger.error(f"Error caching extractor result: {e}")
        return False
|
||||||
110
mediaflow_proxy/utils/crypto_utils.py
Normal file
110
mediaflow_proxy/utils/crypto_utils.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
from Crypto.Random import get_random_bytes
|
||||||
|
from Crypto.Util.Padding import pad, unpad
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
from mediaflow_proxy.configs import settings
|
||||||
|
|
||||||
|
|
||||||
|
class EncryptionHandler:
    """Encrypts and decrypts URL token payloads with AES-256-CBC.

    Tokens are URL-safe base64 of (16-byte IV || ciphertext). The JSON payload
    may carry an "exp" unix timestamp and a pinned client "ip"; both claims are
    enforced and stripped on decryption.
    """

    def __init__(self, secret_key: str):
        # Derive a fixed 32-byte key by padding/truncating the configured secret.
        self.secret_key = secret_key.encode("utf-8").ljust(32)[:32]

    def encrypt_data(self, data: dict, expiration: int = None, ip: str = None) -> str:
        """Serialize *data* to JSON and encrypt it into a URL-safe token.

        Args:
            data: Payload to protect (mutated in place: "exp"/"ip" may be added).
            expiration: Optional lifetime in seconds from now.
            ip: Optional client IP the token is pinned to.

        Returns:
            URL-safe base64 token string.
        """
        if expiration:
            data["exp"] = int(time.time()) + expiration
        if ip:
            data["ip"] = ip
        json_data = json.dumps(data).encode("utf-8")
        iv = get_random_bytes(16)
        cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
        encrypted_data = cipher.encrypt(pad(json_data, AES.block_size))
        return base64.urlsafe_b64encode(iv + encrypted_data).decode("utf-8")

    def decrypt_data(self, token: str, client_ip: str) -> dict:
        """Decrypt *token* and enforce its expiration / IP-pinning claims.

        Raises:
            HTTPException: 401 when the token is invalid or expired,
                403 when the embedded IP does not match *client_ip*.
        """
        try:
            encrypted_data = base64.urlsafe_b64decode(token.encode("utf-8"))
            iv = encrypted_data[:16]
            cipher = AES.new(self.secret_key, AES.MODE_CBC, iv)
            decrypted_data = unpad(cipher.decrypt(encrypted_data[16:]), AES.block_size)
            data = json.loads(decrypted_data)

            if "exp" in data:
                if data["exp"] < time.time():
                    raise HTTPException(status_code=401, detail="Token has expired")
                del data["exp"]  # Remove expiration from the data

            if "ip" in data:
                if data["ip"] != client_ip:
                    raise HTTPException(status_code=403, detail="IP address mismatch")
                del data["ip"]  # Remove IP from the data

            return data
        except HTTPException:
            # FIX: re-raise our own claim-validation errors untouched. Previously
            # the generic handler below swallowed them, turning the 403 IP
            # mismatch (and the specific expiry message) into a generic 401.
            raise
        except Exception:
            raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||||
|
|
||||||
|
|
||||||
|
class EncryptionMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that transparently decrypts a `token` query parameter.

    When a request carries `?token=...` and an encryption handler is configured,
    the token payload is decrypted and merged back into the request's query
    parameters before the downstream handler runs.
    """

    def __init__(self, app):
        super().__init__(app)
        # Module-level singleton; None when no api_password is configured,
        # which disables decryption entirely.
        self.encryption_handler = encryption_handler

    async def dispatch(self, request: Request, call_next):
        """Decrypt the token (if any), rewrite query params, then forward the request."""
        encrypted_token = request.query_params.get("token")
        if encrypted_token and self.encryption_handler:
            try:
                client_ip = self.get_client_ip(request)
                decrypted_data = self.encryption_handler.decrypt_data(encrypted_token, client_ip)
                # Modify request query parameters with decrypted data
                query_params = dict(request.query_params)
                query_params.pop("token")  # Remove the encrypted token from query params
                query_params.update(decrypted_data)  # Add decrypted data to query params
                # Marker so downstream code can tell the params came from a token.
                query_params["has_encrypted"] = True

                # Create a new request scope with updated query parameters
                new_query_string = urlencode(query_params)
                request.scope["query_string"] = new_query_string.encode()
                # NOTE(review): also patches Starlette's cached attribute so later
                # reads of request.query_params observe the decrypted values.
                request._query_params = query_params
            except HTTPException as e:
                # Token validation failures (401/403) become JSON error responses.
                return JSONResponse(content={"error": str(e.detail)}, status_code=e.status_code)

        try:
            response = await call_next(request)
        except Exception:
            # Last-resort guard: log the traceback and return a generic 500.
            exc = traceback.format_exc(chain=False)
            logging.error("An error occurred while processing the request, error: %s", exc)
            return JSONResponse(
                content={"error": "An error occurred while processing the request, check the server for logs"},
                status_code=500,
            )
        return response

    @staticmethod
    def get_client_ip(request: Request) -> Optional[str]:
        """
        Extract the client's real IP address from the request headers or fallback to the client host.
        """
        x_forwarded_for = request.headers.get("X-Forwarded-For")
        if x_forwarded_for:
            # In some cases, this header can contain multiple IPs
            # separated by commas.
            # The first one is the original client's IP.
            return x_forwarded_for.split(",")[0].strip()
        # Fallback to X-Real-IP if X-Forwarded-For is not available
        x_real_ip = request.headers.get("X-Real-IP")
        if x_real_ip:
            return x_real_ip
        return request.client.host if request.client else "127.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
# Shared handler used by EncryptionMiddleware; None (disabled) when no API password is set.
encryption_handler = EncryptionHandler(settings.api_password) if settings.api_password else None
|
||||||
430
mediaflow_proxy/utils/http_utils.py
Normal file
430
mediaflow_proxy/utils/http_utils.py
Normal file
@@ -0,0 +1,430 @@
|
|||||||
|
import logging
|
||||||
|
import typing
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from urllib import parse
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
import httpx
|
||||||
|
import tenacity
|
||||||
|
from fastapi import Response
|
||||||
|
from starlette.background import BackgroundTask
|
||||||
|
from starlette.concurrency import iterate_in_threadpool
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.types import Receive, Send, Scope
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||||
|
from tqdm.asyncio import tqdm as tqdm_asyncio
|
||||||
|
|
||||||
|
from mediaflow_proxy.configs import settings
|
||||||
|
from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS
|
||||||
|
from mediaflow_proxy.utils.crypto_utils import EncryptionHandler
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadError(Exception):
    """Raised when an HTTP download fails; carries the associated status code."""

    def __init__(self, status_code, message):
        """Store the status code and message, then initialize the base Exception."""
        super().__init__(message)
        self.status_code = status_code
        self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
def create_httpx_client(follow_redirects: bool = True, timeout: float = 30.0, **kwargs) -> httpx.AsyncClient:
    """Creates an HTTPX client with configured proxy routing"""
    # Route traffic through the transports derived from the proxy settings.
    return httpx.AsyncClient(
        mounts=settings.transport_config.get_mounts(),
        follow_redirects=follow_redirects,
        timeout=timeout,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(DownloadError),
)
async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs):
    """
    Fetches a URL with retry logic.

    Only DownloadError triggers a retry (up to 3 attempts, exponential backoff);
    other exceptions propagate immediately.

    Args:
        client (httpx.AsyncClient): The HTTP client to use for the request.
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to fetch.
        headers (dict): The headers to include in the request.
        follow_redirects (bool, optional): Whether to follow redirects. Defaults to True.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: If the request fails after retries.
    """
    try:
        response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs)
        response.raise_for_status()
        return response
    except httpx.TimeoutException:
        # Wrapped as DownloadError(409) so tenacity retries timeouts.
        logger.warning(f"Timeout while downloading {url}")
        raise DownloadError(409, f"Timeout while downloading {url}")
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error {e.response.status_code} while downloading {url}")
        if e.response.status_code == 404:
            # 404 is re-raised unwrapped so it is NOT retried by tenacity.
            logger.error(f"Segment Resource not found: {url}")
            raise e
        # Other status errors become DownloadError, which IS retried.
        raise DownloadError(e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}")
    except Exception as e:
        # Unexpected failures propagate unchanged (no retry).
        logger.error(f"Error downloading {url}: {e}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
class Streamer:
    """Wraps an httpx.AsyncClient to stream remote content, with byte-range
    bookkeeping and an optional tqdm progress bar."""

    def __init__(self, client):
        """
        Initializes the Streamer with an HTTP client.

        Args:
            client (httpx.AsyncClient): The HTTP client to use for streaming.
        """
        self.client = client
        self.response = None
        self.progress_bar = None
        self.bytes_transferred = 0
        self.start_byte = 0
        self.end_byte = 0
        self.total_size = 0

    async def create_streaming_response(self, url: str, headers: dict):
        """
        Creates and sends a streaming request.

        Args:
            url (str): The URL to stream from.
            headers (dict): The headers to include in the request.
        """
        request = self.client.build_request("GET", url, headers=headers)
        self.response = await self.client.send(request, stream=True, follow_redirects=True)
        self.response.raise_for_status()

    async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Streams the content from the response, optionally reporting progress.

        Raises:
            RuntimeError: If called before create_streaming_response().
            DownloadError: On timeout while streaming (status 409).
        """
        if not self.response:
            raise RuntimeError("No response available for streaming")

        try:
            self.parse_content_range()

            if settings.enable_streaming_progress:
                with tqdm_asyncio(
                    total=self.total_size,
                    initial=self.start_byte,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    desc="Streaming",
                    ncols=100,
                    mininterval=1,
                ) as self.progress_bar:
                    async for chunk in self.response.aiter_bytes():
                        # Yield first so the consumer is never blocked by UI work.
                        yield chunk
                        chunk_size = len(chunk)
                        self.bytes_transferred += chunk_size
                        self.progress_bar.set_postfix_str(
                            f"📥 : {self.format_bytes(self.bytes_transferred)}", refresh=False
                        )
                        self.progress_bar.update(chunk_size)
            else:
                async for chunk in self.response.aiter_bytes():
                    yield chunk
                    self.bytes_transferred += len(chunk)

        except httpx.TimeoutException:
            logger.warning("Timeout while streaming")
            raise DownloadError(409, "Timeout while streaming")
        except GeneratorExit:
            # Consumer stopped iterating; not an error.
            logger.info("Streaming session stopped by the user")
        except Exception as e:
            logger.error(f"Error streaming content: {e}")
            raise

    @staticmethod
    def format_bytes(size) -> str:
        """Render a byte count as a human-readable string (B .. TB)."""
        power = 2**10
        n = 0
        units = {0: "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB"}
        # FIX: use >= so exact multiples roll over to the next unit
        # (1024 -> "1.00 KB", previously "1024.00 B"), and cap the index so
        # sizes beyond TB cannot walk past the unit table.
        while size >= power and n < len(units) - 1:
            size /= power
            n += 1
        return f"{size:.2f} {units[n]}"

    def parse_content_range(self):
        """Populate start/end/total byte counters from the response headers."""
        content_range = self.response.headers.get("Content-Range", "")
        if content_range:
            # "bytes a-b/total" -> "a-b/total" -> "a-b-total" -> (a, b, total)
            range_info = content_range.split()[-1]
            self.start_byte, self.end_byte, self.total_size = map(int, range_info.replace("/", "-").split("-"))
        else:
            self.start_byte = 0
            self.total_size = int(self.response.headers.get("Content-Length", 0))
            self.end_byte = self.total_size - 1 if self.total_size > 0 else 0

    async def get_text(self, url: str, headers: dict):
        """
        Sends a GET request to a URL and returns the response text.

        Args:
            url (str): The URL to send the GET request to.
            headers (dict): The headers to include in the request.

        Returns:
            str: The response text.
        """
        try:
            self.response = await fetch_with_retry(self.client, "GET", url, headers)
        except tenacity.RetryError as e:
            # Re-raise the underlying exception from the final failed attempt.
            raise e.last_attempt.result()
        return self.response.text

    async def close(self):
        """
        Closes the HTTP response, progress bar, and client.
        """
        if self.response:
            await self.response.aclose()
        if self.progress_bar:
            self.progress_bar.close()
        await self.client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
async def download_file_with_retry(url: str, headers: dict):
    """
    Downloads a file with retry logic.

    Args:
        url (str): The URL of the file to download.
        headers (dict): The headers to include in the request.

    Returns:
        bytes: The downloaded file content.

    Raises:
        DownloadError: If the download fails after retries.
    """
    async with create_httpx_client() as client:
        try:
            fetched = await fetch_with_retry(client, "GET", url, headers)
            return fetched.content
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise e
        except tenacity.RetryError as e:
            # Retries exhausted: surface the failure as a 502 DownloadError.
            raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
|
||||||
|
|
||||||
|
|
||||||
|
async def request_with_retry(method: str, url: str, headers: dict, **kwargs) -> httpx.Response:
    """
    Sends an HTTP request with retry logic.

    Args:
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to send the request to.
        headers (dict): The headers to include in the request.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: If the request fails after retries.
    """
    async with create_httpx_client() as client:
        try:
            return await fetch_with_retry(client, method, url, headers, **kwargs)
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise
|
||||||
|
|
||||||
|
|
||||||
|
def encode_mediaflow_proxy_url(
    mediaflow_proxy_url: str,
    endpoint: typing.Optional[str] = None,
    destination_url: typing.Optional[str] = None,
    query_params: typing.Optional[dict] = None,
    request_headers: typing.Optional[dict] = None,
    response_headers: typing.Optional[dict] = None,
    encryption_handler: typing.Optional["EncryptionHandler"] = None,
    expiration: typing.Optional[int] = None,
    ip: typing.Optional[str] = None,
) -> str:
    """
    Encodes & Encrypt (Optional) a MediaFlow proxy URL with query parameters and headers.

    FIX: parameters defaulting to None are now annotated Optional[...] per
    PEP 484, and the EncryptionHandler annotation is a lazy forward reference.

    Args:
        mediaflow_proxy_url (str): The base MediaFlow proxy URL.
        endpoint (str, optional): The endpoint to append to the base URL. Defaults to None.
        destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
        query_params (dict, optional): Additional query parameters to include. Defaults to None.
        request_headers (dict, optional): Headers to include as "h_" query parameters. Defaults to None.
        response_headers (dict, optional): Headers to include as "r_" query parameters. Defaults to None.
        encryption_handler (EncryptionHandler, optional): The encryption handler to use. Defaults to None.
        expiration (int, optional): The expiration time for the encrypted token. Defaults to None.
        ip (str, optional): The public IP address to include in the query parameters. Defaults to None.

    Returns:
        str: The encoded MediaFlow proxy URL.
    """
    query_params = query_params or {}
    if destination_url is not None:
        query_params["d"] = destination_url

    # Add headers if provided; request headers get an "h_" prefix, response
    # headers an "r_" prefix, unless the caller already prefixed them.
    if request_headers:
        query_params.update(
            {key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
        )
    if response_headers:
        query_params.update(
            {key if key.startswith("r_") else f"r_{key}": value for key, value in response_headers.items()}
        )

    # Either wrap everything in a single encrypted token, or pass params in clear.
    if encryption_handler:
        encrypted_token = encryption_handler.encrypt_data(query_params, expiration, ip)
        encoded_params = urlencode({"token": encrypted_token})
    else:
        encoded_params = urlencode(query_params)

    # Construct the full URL
    if endpoint is None:
        return f"{mediaflow_proxy_url}?{encoded_params}"

    base_url = parse.urljoin(mediaflow_proxy_url, endpoint)
    return f"{base_url}?{encoded_params}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_original_scheme(request: "Request") -> str:
    """
    Determines the original scheme (http or https) of the request.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        str: The original scheme ('http' or 'https')
    """
    # Check the X-Forwarded-Proto header first; it wins outright when present.
    forwarded_proto = request.headers.get("X-Forwarded-Proto")
    if forwarded_proto:
        return forwarded_proto

    # Check whether the request itself is secure, or whether any of the common
    # proxy headers indicate TLS termination upstream.
    # FIX: the X-Forwarded-Ssl check was previously duplicated across two
    # separate conditionals; it is evaluated exactly once here.
    if (
        request.url.scheme == "https"
        or request.headers.get("X-Forwarded-Ssl") == "on"
        or request.headers.get("X-Forwarded-Protocol") == "https"
        or request.headers.get("X-Url-Scheme") == "https"
    ):
        return "https"

    # Default to http if no indicators of https are found
    return "http"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ProxyRequestHeaders:
    """Headers extracted from a proxy request, split by direction."""

    # Headers forwarded with the outgoing (upstream) request.
    request: dict
    # Headers to apply to the response returned to the client.
    response: dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
    """
    Extracts proxy headers from the request query parameters.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        ProxyRequestHeaders: The forwarded request headers and the response headers.
    """
    # Start from the whitelisted incoming headers, then let explicit
    # "h_<name>" query parameters override or extend them.
    forward = {name: value for name, value in request.headers.items() if name in SUPPORTED_REQUEST_HEADERS}
    for param, value in request.query_params.items():
        if param.startswith("h_"):
            forward[param[2:].lower()] = value
    # "r_<name>" query parameters define headers for the proxied response.
    respond = {param[2:].lower(): value for param, value in request.query_params.items() if param.startswith("r_")}
    return ProxyRequestHeaders(forward, respond)
|
||||||
|
|
||||||
|
|
||||||
|
class EnhancedStreamingResponse(Response):
    """Streaming response that tolerates client disconnects mid-stream.

    Connection resets while sending chunks are caught and logged instead of
    propagating, and a disconnect listener cancels the stream early.
    """

    # Source of body chunks (async, possibly a wrapped sync iterable).
    body_iterator: typing.AsyncIterable[typing.Any]

    def __init__(
        self,
        content: typing.Union[typing.AsyncIterable[typing.Any], typing.Iterable[typing.Any]],
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        if isinstance(content, typing.AsyncIterable):
            self.body_iterator = content
        else:
            # Sync iterables are pumped via a threadpool so iteration
            # does not block the event loop.
            self.body_iterator = iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.init_headers(headers)

    @staticmethod
    async def listen_for_disconnect(receive: Receive) -> None:
        """Consume ASGI receive events until the client disconnects."""
        try:
            while True:
                message = await receive()
                if message["type"] == "http.disconnect":
                    logger.debug("Client disconnected")
                    break
        except Exception as e:
            logger.error(f"Error in listen_for_disconnect: {str(e)}")

    async def stream_response(self, send: Send) -> None:
        """Send the response start message, then each body chunk, then EOF."""
        try:
            await send(
                {
                    "type": "http.response.start",
                    "status": self.status_code,
                    "headers": self.raw_headers,
                }
            )
            async for chunk in self.body_iterator:
                if not isinstance(chunk, (bytes, memoryview)):
                    chunk = chunk.encode(self.charset)
                try:
                    await send({"type": "http.response.body", "body": chunk, "more_body": True})
                except (ConnectionResetError, anyio.BrokenResourceError):
                    # Client went away mid-stream: stop quietly.
                    logger.info("Client disconnected during streaming")
                    return

            # Empty body with more_body=False signals end-of-response.
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        except Exception as e:
            logger.exception(f"Error in stream_response: {str(e)}")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: stream the body while watching for disconnects."""
        async with anyio.create_task_group() as task_group:

            async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
                try:
                    await func()
                except Exception as e:
                    if not isinstance(e, anyio.get_cancelled_exc_class()):
                        logger.exception("Error in streaming task")
                        raise
                finally:
                    # Whichever task finishes first cancels the other.
                    task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()
|
||||||
87
mediaflow_proxy/utils/m3u8_processor.py
Normal file
87
mediaflow_proxy/utils/m3u8_processor.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import re
|
||||||
|
from urllib import parse
|
||||||
|
|
||||||
|
from mediaflow_proxy.utils.crypto_utils import encryption_handler
|
||||||
|
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
|
||||||
|
|
||||||
|
|
||||||
|
class M3U8Processor:
    """Rewrites m3u8 playlists so every URL (segments and keys) is routed through the proxy."""

    def __init__(self, request, key_url: str = None):
        """
        Initializes the M3U8Processor with the request and URL prefix.

        Args:
            request (Request): The incoming HTTP request.
            key_url (HttpUrl, optional): The URL of the key server. Defaults to None.
        """
        self.request = request
        self.key_url = parse.urlparse(key_url) if key_url else None
        # Proxy endpoint rebuilt with the scheme the client originally used
        # (the service may sit behind a TLS-terminating reverse proxy).
        self.mediaflow_proxy_url = str(
            request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
        )

    async def process_m3u8(self, content: str, base_url: str) -> str:
        """
        Processes the m3u8 content, proxying URLs and handling key lines.

        Args:
            content (str): The m3u8 content to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed m3u8 content.
        """
        processed_lines = []
        for line in content.splitlines():
            if "URI=" in line:
                # Tag line carrying a URI attribute (e.g. #EXT-X-KEY).
                processed_lines.append(await self.process_key_line(line, base_url))
            elif not line.startswith("#") and line.strip():
                # Plain non-comment, non-empty line: a media/variant URL.
                processed_lines.append(await self.proxy_url(line, base_url))
            else:
                processed_lines.append(line)
        return "\n".join(processed_lines)

    async def process_key_line(self, line: str, base_url: str) -> str:
        """
        Processes a key line in the m3u8 content, proxying the URI.

        Args:
            line (str): The key line to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed key line.
        """
        uri_match = re.search(r'URI="([^"]+)"', line)
        if uri_match:
            original_uri = uri_match.group(1)
            uri = parse.urlparse(original_uri)
            if self.key_url:
                # Redirect the key request to the configured key server,
                # keeping the original path and query intact.
                uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.netloc)
            new_uri = await self.proxy_url(uri.geturl(), base_url)
            line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"')
        return line

    async def proxy_url(self, url: str, base_url: str) -> str:
        """
        Proxies a URL, encoding it with the MediaFlow proxy URL.

        Args:
            url (str): The URL to proxy.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The proxied URL.
        """
        full_url = parse.urljoin(base_url, url)
        query_params = dict(self.request.query_params)
        # Consume the flag so it is NOT forwarded with the proxied URL.
        has_encrypted = query_params.pop("has_encrypted", False)

        return encode_mediaflow_proxy_url(
            self.mediaflow_proxy_url,
            "",
            full_url,
            # Bug fix: previously passed dict(self.request.query_params),
            # re-introducing the "has_encrypted" flag that was just popped.
            query_params=query_params,
            encryption_handler=encryption_handler if has_encrypted else None,
        )
|
||||||
555
mediaflow_proxy/utils/mpd_utils.py
Normal file
555
mediaflow_proxy/utils/mpd_utils.py
Normal file
@@ -0,0 +1,555 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import List, Dict, Optional, Union
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import xmltodict
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_mpd(mpd_content: Union[str, bytes]) -> dict:
    """
    Parses the MPD content into a dictionary.

    Args:
        mpd_content (Union[str, bytes]): The MPD content to parse.

    Returns:
        dict: The parsed MPD content as a dictionary.
    """
    # Thin wrapper over xmltodict: XML attributes appear as "@name" keys and
    # repeated elements may come back as either a dict or a list of dicts —
    # downstream parsers normalise for both shapes.
    return xmltodict.parse(mpd_content)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_mpd_dict(
    mpd_dict: dict, mpd_url: str, parse_drm: bool = True, parse_segment_profile_id: Optional[str] = None
) -> dict:
    """
    Parses the MPD dictionary and extracts relevant information.

    Args:
        mpd_dict (dict): The MPD content as a dictionary.
        mpd_url (str): The URL of the MPD manifest.
        parse_drm (bool, optional): Whether to parse DRM information. Defaults to True.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        dict: The parsed MPD information including profiles and DRM info.

    This function processes the MPD dictionary to extract profiles, DRM information, and other relevant data.
    It handles both live and static MPD manifests.
    """
    profiles = []
    parsed_dict = {}
    # Directory of the manifest, used to resolve relative segment URLs.
    source = "/".join(mpd_url.split("/")[:-1])

    is_live = mpd_dict["MPD"].get("@type", "static").lower() == "dynamic"
    parsed_dict["isLive"] = is_live

    media_presentation_duration = mpd_dict["MPD"].get("@mediaPresentationDuration")

    # Parse additional MPD attributes for live streams
    if is_live:
        parsed_dict["minimumUpdatePeriod"] = parse_duration(mpd_dict["MPD"].get("@minimumUpdatePeriod", "PT0S"))
        parsed_dict["timeShiftBufferDepth"] = parse_duration(mpd_dict["MPD"].get("@timeShiftBufferDepth", "PT2M"))
        parsed_dict["availabilityStartTime"] = datetime.fromisoformat(
            mpd_dict["MPD"]["@availabilityStartTime"].replace("Z", "+00:00")
        )
        # Bug fix: @publishTime is optional in DASH; the previous code called
        # datetime.fromisoformat("") and raised ValueError whenever it was
        # absent. Fall back to availabilityStartTime in that case.
        publish_time = mpd_dict["MPD"].get("@publishTime")
        parsed_dict["publishTime"] = (
            datetime.fromisoformat(publish_time.replace("Z", "+00:00"))
            if publish_time
            else parsed_dict["availabilityStartTime"]
        )

    # xmltodict returns a bare dict when there is a single <Period>.
    periods = mpd_dict["MPD"]["Period"]
    periods = periods if isinstance(periods, list) else [periods]

    for period in periods:
        # NOTE: overwritten per period; later periods win (original behavior).
        parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S"))
        for adaptation in period["AdaptationSet"]:
            representations = adaptation["Representation"]
            representations = representations if isinstance(representations, list) else [representations]

            for representation in representations:
                profile = parse_representation(
                    parsed_dict,
                    representation,
                    adaptation,
                    source,
                    media_presentation_duration,
                    parse_segment_profile_id,
                )
                if profile:
                    profiles.append(profile)
    parsed_dict["profiles"] = profiles

    # DRM scanning walks every ContentProtection element; callers may skip it.
    drm_info = extract_drm_info(periods, mpd_url) if parse_drm else {}
    parsed_dict["drmInfo"] = drm_info

    return parsed_dict
|
||||||
|
|
||||||
|
|
||||||
|
def pad_base64(encoded_key_id):
    """
    Pads a base64 encoded key ID to make its length a multiple of 4.

    Args:
        encoded_key_id (str): The base64 encoded key ID (possibly unpadded).

    Returns:
        str: The padded base64 encoded key ID.
    """
    # Bug fix: the previous expression `"=" * (4 - len % 4)` appended FOUR
    # '=' characters when the length was already a multiple of 4, producing
    # invalid base64. `-len % 4` yields 0, 1, 2 or 3 — exactly the padding
    # required.
    return encoded_key_id + "=" * (-len(encoded_key_id) % 4)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_drm_info(periods: List[Dict], mpd_url: str) -> Dict:
    """
    Collect DRM information from every period of an MPD manifest.

    Args:
        periods (List[Dict]): The list of periods in the MPD.
        mpd_url (str): The URL of the MPD manifest (used to absolutise laUrl).

    Returns:
        Dict: The accumulated DRM information (system, pssh, keyId, laUrl).

    ContentProtection elements are inspected both at AdaptationSet level and
    inside each Representation; a relative license URL is resolved against
    the manifest URL.
    """
    drm_info = {"isDrmProtected": False}

    def _as_list(value):
        # xmltodict yields a bare dict for single child elements.
        return value if isinstance(value, list) else [value]

    for period in periods:
        for adaptation_set in _as_list(period.get("AdaptationSet", [])):
            # ContentProtection declared on the adaptation set itself.
            process_content_protection(adaptation_set.get("ContentProtection", []), drm_info)

            # ContentProtection declared per representation.
            for representation in _as_list(adaptation_set.get("Representation", [])):
                process_content_protection(representation.get("ContentProtection", []), drm_info)

    # Make a relative license-acquisition URL absolute.
    la_url = drm_info.get("laUrl")
    if la_url is not None and not la_url.startswith(("http://", "https://")):
        drm_info["laUrl"] = urljoin(mpd_url, la_url)

    return drm_info
|
||||||
|
|
||||||
|
|
||||||
|
def process_content_protection(content_protection: Union[list[dict], dict], drm_info: dict):
    """
    Fold ContentProtection elements into *drm_info*.

    Args:
        content_protection (Union[list[dict], dict]): One ContentProtection
            element or a list of them.
        drm_info (dict): Accumulator; mutated in place and also returned.

    The first value seen wins for ``laUrl`` and ``keyId``; ``drmSystem`` is
    set by each recognised scheme as it is encountered.
    """
    elements = content_protection if isinstance(content_protection, list) else [content_protection]

    for element in elements:
        # Any ContentProtection element at all marks the stream as protected.
        drm_info["isDrmProtected"] = True
        scheme = element.get("@schemeIdUri", "").lower()

        if "clearkey" in scheme:
            drm_info["drmSystem"] = "clearkey"
            if "clearkey:Laurl" in element:
                la_url = element["clearkey:Laurl"].get("#text")
                if la_url and "laUrl" not in drm_info:
                    drm_info["laUrl"] = la_url
        elif "widevine" in scheme or "edef8ba9-79d6-4ace-a3c8-27dcd51d21ed" in scheme:
            drm_info["drmSystem"] = "widevine"
            pssh = element.get("cenc:pssh", {}).get("#text")
            if pssh:
                drm_info["pssh"] = pssh
        elif "playready" in scheme or "9a04f079-9840-4286-ab92-e65be0885f95" in scheme:
            drm_info["drmSystem"] = "playready"

        # Default KID and ms:laurl may appear on any scheme's element.
        if "@cenc:default_KID" in element:
            kid = element["@cenc:default_KID"].replace("-", "")
            if "keyId" not in drm_info:
                drm_info["keyId"] = kid

        if "ms:laurl" in element:
            la_url = element["ms:laurl"].get("@licenseUrl")
            if la_url and "laUrl" not in drm_info:
                drm_info["laUrl"] = la_url

    return drm_info
|
||||||
|
|
||||||
|
|
||||||
|
def parse_representation(
    parsed_dict: dict,
    representation: dict,
    adaptation: dict,
    source: str,
    media_presentation_duration: str,
    parse_segment_profile_id: Optional[str],
) -> Optional[dict]:
    """
    Parses a representation and extracts profile information.

    Args:
        parsed_dict (dict): The parsed MPD data.
        representation (dict): The representation data.
        adaptation (dict): The adaptation set data (fallback for attributes).
        source (str): The source URL.
        media_presentation_duration (str): The media presentation duration.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        Optional[dict]: The parsed profile information or None if not applicable.
    """
    # When @mimeType is missing, guess from the codec string: "avc" codecs
    # are treated as video, everything else as audio.
    mime_type = _get_key(adaptation, representation, "@mimeType") or (
        "video/mp4" if "avc" in representation["@codecs"] else "audio/mp4"
    )
    # Skip non-AV adaptation sets (e.g. subtitles/images).
    if "video" not in mime_type and "audio" not in mime_type:
        return None

    profile = {
        "id": representation.get("@id") or adaptation.get("@id"),
        "mimeType": mime_type,
        "lang": representation.get("@lang") or adaptation.get("@lang"),
        "codecs": representation.get("@codecs") or adaptation.get("@codecs"),
        "bandwidth": int(representation.get("@bandwidth") or adaptation.get("@bandwidth")),
        "startWithSAP": (_get_key(adaptation, representation, "@startWithSAP") or "1") == "1",
        "mediaPresentationDuration": media_presentation_duration,
    }

    if "audio" in profile["mimeType"]:
        profile["audioSamplingRate"] = representation.get("@audioSamplingRate") or adaptation.get("@audioSamplingRate")
        profile["channels"] = representation.get("AudioChannelConfiguration", {}).get("@value", "2")
    else:
        profile["width"] = int(representation["@width"])
        profile["height"] = int(representation["@height"])
        # Frame rate may be fractional ("30000/1001"); normalise to a float.
        frame_rate = representation.get("@frameRate") or adaptation.get("@maxFrameRate") or "30000/1001"
        frame_rate = frame_rate if "/" in frame_rate else f"{frame_rate}/1"
        profile["frameRate"] = round(int(frame_rate.split("/")[0]) / int(frame_rate.split("/")[1]), 3)
        profile["sar"] = representation.get("@sar", "1:1")

    # Segment expansion is expensive; do it only for the requested profile.
    if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id:
        return profile

    item = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
    if item:
        profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
    else:
        profile["segments"] = parse_segment_base(representation, source)

    return profile
|
||||||
|
|
||||||
|
|
||||||
|
def _get_key(adaptation: dict, representation: dict, key: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Retrieves a key from the representation or adaptation set.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
adaptation (dict): The adaptation set data.
|
||||||
|
representation (dict): The representation data.
|
||||||
|
key (str): The key to retrieve.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: The value of the key or None if not found.
|
||||||
|
"""
|
||||||
|
return representation.get(key, adaptation.get(key, None))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]:
    """
    Build the segment list for a SegmentTemplate element.

    Side effect: stores the resolved initialization-segment URL on
    ``profile["initUrl"]`` when the template declares one.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The segment template data.
        profile (dict): The profile information.
        source (str): The source URL for resolving relative templates.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    timescale = int(item.get("@timescale", 1))

    # Resolve the initialization URL, expanding the template variables.
    if "@initialization" in item:
        init_url = (
            item["@initialization"]
            .replace("$RepresentationID$", profile["id"])
            .replace("$Bandwidth$", str(profile["bandwidth"]))
        )
        if not init_url.startswith("http"):
            init_url = f"{source}/{init_url}"
        profile["initUrl"] = init_url

    # Segment addressing: explicit timeline or fixed duration.
    if "SegmentTimeline" in item:
        return parse_segment_timeline(parsed_dict, item, profile, source, timescale)
    if "@duration" in item:
        return parse_segment_duration(parsed_dict, item, profile, source, timescale)
    return []
|
||||||
|
|
||||||
|
|
||||||
|
def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Parses a segment timeline and extracts segment information.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The segment timeline data.
        profile (dict): The profile information.
        source (str): The source URL.
        timescale (int): The timescale for the segments.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    timelines = item["SegmentTimeline"]["S"]
    # xmltodict yields a bare dict when there is a single <S> element.
    timelines = timelines if isinstance(timelines, list) else [timelines]
    # Wall-clock start of the period: manifest availability start plus the
    # period's own start offset. NOTE(review): assumes "availabilityStartTime"
    # was populated (set only for live manifests) — confirm for static inputs.
    period_start = parsed_dict["availabilityStartTime"] + timedelta(seconds=parsed_dict.get("PeriodStart", 0))
    presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
    start_number = int(item.get("@startNumber", 1))

    segments = [
        create_segment_data(timeline, item, profile, source, timescale)
        for timeline in preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
    ]
    return segments
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_timeline(
    timelines: List[Dict], start_number: int, period_start: datetime, presentation_time_offset: int, timescale: int
) -> List[Dict]:
    """
    Expand SegmentTimeline ``S`` entries into one record per segment.

    Each ``S`` entry may repeat (``@r``); the expansion tracks a running
    media-time cursor so entries without an explicit ``@t`` continue from
    the previous one.

    Args:
        timelines (List[Dict]): The list of ``S`` timeline entries.
        start_number (int): The starting segment number.
        period_start (datetime): Wall-clock start time of the period.
        presentation_time_offset (int): Offset in timescale units.
        timescale (int): Units per second for ``@t``/``@d`` values.

    Returns:
        List[Dict]: One dict per segment with number, wall-clock start/end,
        duration (timescale units) and media time (timescale units).
    """
    expanded = []
    segment_number = start_number
    cursor = 0  # running media time in timescale units

    for entry in timelines:
        duration = int(entry["@d"])
        repeat_count = int(entry.get("@r", 0))
        media_time = int(entry.get("@t", cursor))

        for _ in range(repeat_count + 1):
            wall_start = period_start + timedelta(seconds=(media_time - presentation_time_offset) / timescale)
            expanded.append(
                {
                    "number": segment_number,
                    "start_time": wall_start,
                    "end_time": wall_start + timedelta(seconds=duration / timescale),
                    "duration": duration,
                    "time": media_time,
                }
            )
            media_time += duration
            segment_number += 1

        cursor = media_time

    return expanded
|
||||||
|
|
||||||
|
|
||||||
|
def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Parses segment duration and extracts segment information.
    This is used for static or live MPD manifests.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The segment duration data.
        profile (dict): The profile information.
        source (str): The source URL.
        timescale (int): The timescale for the segments.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    # Fixed per-segment duration, expressed in timescale units.
    duration = int(item["@duration"])
    start_number = int(item.get("@startNumber", 1))
    segment_duration_sec = duration / timescale

    # Live streams derive segment numbers from wall-clock time; VOD derives
    # them from the total presentation duration.
    if parsed_dict["isLive"]:
        segments = generate_live_segments(parsed_dict, segment_duration_sec, start_number)
    else:
        segments = generate_vod_segments(profile, duration, timescale, start_number)

    return [create_segment_data(seg, item, profile, source, timescale) for seg in segments]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
    """
    Generates live segments based on the segment duration and start number.
    This is used for live MPD manifests.

    Args:
        parsed_dict (dict): The parsed MPD data.
        segment_duration_sec (float): The segment duration in seconds.
        start_number (int): The starting segment number.

    Returns:
        List[Dict]: The list of generated live segments.
    """
    # Size of the DVR window the server keeps available.
    time_shift_buffer_depth = timedelta(seconds=parsed_dict.get("timeShiftBufferDepth", 60))
    segment_count = math.ceil(time_shift_buffer_depth.total_seconds() / segment_duration_sec)
    current_time = datetime.now(tz=timezone.utc)
    # Newest available segment index follows wall-clock time since the stream
    # became available; step back `segment_count` segments to the start of the
    # buffer, never going below the declared start number.
    earliest_segment_number = max(
        start_number
        + math.floor((current_time - parsed_dict["availabilityStartTime"]).total_seconds() / segment_duration_sec)
        - segment_count,
        start_number,
    )

    return [
        {
            "number": number,
            "start_time": parsed_dict["availabilityStartTime"]
            + timedelta(seconds=(number - start_number) * segment_duration_sec),
            "duration": segment_duration_sec,
        }
        for number in range(earliest_segment_number, earliest_segment_number + segment_count)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_vod_segments(profile: dict, duration: int, timescale: int, start_number: int) -> List[Dict]:
    """
    Generates VOD segments based on the segment duration and start number.
    This is used for static MPD manifests.

    Args:
        profile (dict): The profile information (supplies the total duration).
        duration (int): The per-segment duration in timescale units.
        timescale (int): Units per second.
        start_number (int): The starting segment number.

    Returns:
        List[Dict]: The list of generated VOD segments.
    """
    total_duration = profile.get("mediaPresentationDuration") or 0
    if isinstance(total_duration, str):
        # ISO 8601 duration string from the MPD root; convert to seconds.
        total_duration = parse_duration(total_duration)

    segment_count = math.ceil(total_duration * timescale / duration)
    segment_length_sec = duration / timescale

    return [
        {"number": start_number + offset, "duration": segment_length_sec}
        for offset in range(segment_count)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: Optional[int] = None) -> Dict:
    """
    Creates segment data based on the segment information. This includes the segment URL and metadata.

    Args:
        segment (Dict): The segment information (number, optional time/duration bounds).
        item (dict): The segment template data (supplies the "@media" URL template).
        profile (dict): The profile information (id, bandwidth).
        source (str): The source URL used to resolve relative media URLs.
        timescale (int, optional): Kept for interface compatibility; no longer
            used for "$Time$" (see bug-fix note below). Defaults to None.

    Returns:
        Dict: The created segment data.
    """
    media = (
        item["@media"]
        .replace("$RepresentationID$", profile["id"])
        .replace("$Number%04d$", f"{segment['number']:04d}")
        .replace("$Number$", str(segment["number"]))
        .replace("$Bandwidth$", str(profile["bandwidth"]))
    )

    if "time" in segment:
        # Bug fix: segment["time"] already carries the SegmentTimeline
        # "@t"/"@d" value, which is expressed in timescale units (the
        # timeline code divides it by `timescale` to get seconds), so
        # multiplying by `timescale` again produced wildly wrong $Time$ URLs.
        media = media.replace("$Time$", str(int(segment["time"])))

    if not media.startswith("http"):
        media = f"{source}/{media}"

    segment_data = {
        "type": "segment",
        "media": media,
        "number": segment["number"],
    }

    # Attach HLS-style timing metadata depending on what the caller provided.
    if "start_time" in segment and "end_time" in segment:
        segment_data.update(
            {
                "start_time": segment["start_time"],
                "end_time": segment["end_time"],
                "extinf": (segment["end_time"] - segment["start_time"]).total_seconds(),
                "program_date_time": segment["start_time"].isoformat() + "Z",
            }
        )
    elif "start_time" in segment and "duration" in segment:
        duration = segment["duration"]
        segment_data.update(
            {
                "start_time": segment["start_time"],
                "end_time": segment["start_time"] + timedelta(seconds=duration),
                "extinf": duration,
                "program_date_time": segment["start_time"].isoformat() + "Z",
            }
        )
    elif "duration" in segment:
        segment_data["extinf"] = segment["duration"]

    return segment_data
|
||||||
|
|
||||||
|
|
||||||
|
def parse_segment_base(representation: dict, source: str) -> List[Dict]:
    """
    Build the single-segment list for a SegmentBase representation.

    The byte range spans from the start of the Initialization range (when
    present) to the end of the index range, so one request covers both the
    init data and the indexed media.

    Args:
        representation (dict): The representation data.
        source (str): The base URL used to resolve the BaseURL.

    Returns:
        List[Dict]: A one-element list describing the segment.
    """
    segment_base = representation["SegmentBase"]

    index_start, index_end = (int(part) for part in segment_base["@indexRange"].split("-"))
    range_start = index_start
    if "Initialization" in segment_base:
        # Start from the initialization range instead of the index range.
        range_start = int(segment_base["Initialization"]["@range"].split("-")[0])

    return [
        {
            "type": "segment",
            "range": f"{range_start}-{index_end}",
            "media": f"{source}/{representation['BaseURL']}",
        }
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_duration(duration_str: str) -> float:
    """
    Convert an ISO 8601 duration string (e.g. ``"PT1H30M"``) to seconds.

    Months and years use the fixed approximations of 30 and 365 days,
    as in the original implementation.

    Args:
        duration_str (str): The duration string to parse.

    Returns:
        float: The duration in seconds.

    Raises:
        ValueError: If the string does not match the expected format.
    """
    pattern = re.compile(
        r"P(?:(?P<years>\d+)Y)?(?:(?P<months>\d+)M)?(?:(?P<days>\d+)D)?"
        r"T?(?:(?P<hours>\d+)H)?(?:(?P<minutes>\d+)M)?(?:(?P<seconds>\d+(?:\.\d+)?)S)?"
    )
    match = pattern.match(duration_str)
    if match is None:
        raise ValueError(f"Invalid duration format: {duration_str}")

    seconds_per_unit = {
        "years": 365 * 24 * 3600,
        "months": 30 * 24 * 3600,
        "days": 24 * 3600,
        "hours": 3600,
        "minutes": 60,
        "seconds": 1,
    }
    return sum(
        float(value) * seconds_per_unit[unit]
        for unit, value in match.groupdict().items()
        if value is not None
    )
|
||||||
19
requirements.txt
Normal file
19
requirements.txt
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
bs4
|
||||||
|
dateparser
|
||||||
|
python-dotenv
|
||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
tzdata
|
||||||
|
lxml
|
||||||
|
curl_cffi
|
||||||
|
fake-headers
|
||||||
|
pydantic_settings
|
||||||
|
httpx
|
||||||
|
# NOTE: legacy "Crypto" package removed — it is an abandoned PyPI project that conflicts with pycryptodome (listed below), which provides the Crypto namespace
|
||||||
|
pycryptodome
|
||||||
|
tenacity
|
||||||
|
xmltodict
|
||||||
|
starlette
|
||||||
|
cachetools
|
||||||
|
tqdm
|
||||||
|
aiofiles
|
||||||
18
run.py
Normal file
18
run.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""Entry point that re-exposes the mediaflow proxy routes on a fresh FastAPI app."""

from fastapi import FastAPI

from mediaflow_proxy.main import app as mediaflow_app  # Import mediaflow app

# Removed unused imports (httpx, re, string) — nothing in this script used them.

# Initialize the main FastAPI application
main_app = FastAPI()

# Manually add only non-static routes from mediaflow_app.
# NOTE(review): route-level copying skips mediaflow_app's middleware and
# exception handlers — confirm none are required.
for route in mediaflow_app.routes:
    if route.path != "/":  # Exclude the static file path
        main_app.router.routes.append(route)

# Run the main app
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(main_app, host="0.0.0.0", port=8080)
|
||||||
Reference in New Issue
Block a user