new version

This commit is contained in:
UrloMythus
2026-01-11 14:29:22 +01:00
parent b8a40b5afc
commit 7785e8c604
45 changed files with 5463 additions and 832 deletions

View File

@@ -0,0 +1,39 @@
# Author: Trevor Perrin
# See the LICENSE file for legal information regarding use of this file.
"""Abstract class for AES."""
class AES(object):
    """Abstract base class for AES cipher implementations.

    Validates key/IV/mode combinations and records cipher metadata;
    concrete subclasses supply the actual block operations.
    """

    def __init__(self, key, mode, IV, implementation):
        """Check parameters and set up cipher metadata.

        :param key: AES key, must be 16, 24 or 32 bytes long
        :param mode: mode of operation, 2 (CBC) or 6 (CTR)
        :param IV: initialisation vector; exactly 16 bytes for CBC,
            at most 16 bytes for CTR
        :param implementation: name of the backing implementation
        """
        key_size = len(key)
        if key_size not in (16, 24, 32):
            raise AssertionError()
        if mode not in (2, 6):
            raise AssertionError()
        if mode == 2 and len(IV) != 16:
            raise AssertionError()
        if mode == 6 and len(IV) > 16:
            raise AssertionError()
        self.isBlockCipher = True
        self.isAEAD = False
        self.block_size = 16
        self.implementation = implementation
        # key_size was validated above, so the lookup cannot fail
        self.name = {16: "aes128", 24: "aes192", 32: "aes256"}[key_size]

    def encrypt(self, plaintext):
        # CBC-Mode encryption, returns ciphertext
        # WARNING: *MAY* modify the input as well
        assert len(plaintext) % 16 == 0

    def decrypt(self, ciphertext):
        # CBC-Mode decryption, returns plaintext
        # WARNING: *MAY* modify the input as well
        assert len(ciphertext) % 16 == 0

View File

@@ -0,0 +1,193 @@
# Author: Google
# See the LICENSE file for legal information regarding use of this file.
# GCM derived from Go's implementation in crypto/cipher.
#
# https://golang.org/src/crypto/cipher/gcm.go
# GCM works over elements of the field GF(2^128), each of which is a 128-bit
# polynomial. Throughout this implementation, polynomials are represented as
# Python integers with the low-order terms at the most significant bits. So a
# 128-bit polynomial is an integer from 0 to 2^128-1 with the most significant
# bit representing the x^0 term and the least significant bit representing the
# x^127 term. This bit reversal also applies to polynomials used as indices in a
# look-up table.
from __future__ import division
from . import python_aes
from .constanttime import ct_compare_digest
from .cryptomath import bytesToNumber, numberToByteArray
class AESGCM(object):
    """
    AES-GCM implementation. Note: this implementation does not attempt
    to be side-channel resistant. It's also rather slow.

    Polynomials over GF(2^128) are stored as Python ints with the
    low-order terms at the *most* significant bits (see module comment),
    so "multiply by x" is a right shift.
    """

    def __init__(self, key, implementation, rawAesEncrypt):
        """Initialise GCM state for the given AES key.

        :param key: 16 or 32 byte AES key
        :param implementation: name of the backing AES implementation
        :param rawAesEncrypt: callable that encrypts one 16-byte block
        """
        self.isBlockCipher = False
        self.isAEAD = True
        self.nonceLength = 12
        self.tagLength = 16
        self.implementation = implementation
        if len(key) == 16:
            self.name = "aes128gcm"
        elif len(key) == 32:
            self.name = "aes256gcm"
        else:
            raise AssertionError()

        self.key = key
        self._rawAesEncrypt = rawAesEncrypt

        # CTR-mode cipher (mode 6) reused for the bulk keystream; the
        # counter block is set per-call in seal()/open().
        self._ctr = python_aes.new(self.key, 6, bytearray(b'\x00' * 16))

        # The GCM key is AES(0).
        h = bytesToNumber(self._rawAesEncrypt(bytearray(16)))

        # Pre-compute all 4-bit multiples of h. Note that bits are reversed
        # because our polynomial representation places low-order terms at the
        # most significant bit. Thus x^0 * h = h is at index 0b1000 = 8 and
        # x^1 * h is at index 0b0100 = 4.
        self._productTable = [0] * 16
        self._productTable[self._reverseBits(1)] = h
        for i in range(2, 16, 2):
            self._productTable[self._reverseBits(i)] = \
                self._gcmShift(self._productTable[self._reverseBits(i//2)])
            self._productTable[self._reverseBits(i+1)] = \
                self._gcmAdd(self._productTable[self._reverseBits(i)], h)

    def _auth(self, ciphertext, ad, tagMask):
        """Compute the GHASH-based tag over ad and ciphertext."""
        y = 0
        y = self._update(y, ad)
        y = self._update(y, ciphertext)
        # Final block encodes the bit lengths of ad and ciphertext.
        y ^= (len(ad) << (3 + 64)) | (len(ciphertext) << 3)
        y = self._mul(y)
        y ^= bytesToNumber(tagMask)
        return numberToByteArray(y, 16)

    def _update(self, y, data):
        """Absorb data into the GHASH accumulator y, 16 bytes at a time."""
        for i in range(0, len(data) // 16):
            y ^= bytesToNumber(data[16*i:16*i+16])
            y = self._mul(y)
        extra = len(data) % 16
        if extra != 0:
            # Trailing partial block is zero-padded on the right.
            block = bytearray(16)
            block[:extra] = data[-extra:]
            y ^= bytesToNumber(block)
            y = self._mul(y)
        return y

    def _mul(self, y):
        """ Returns y*H, where H is the GCM key. """
        ret = 0
        # Multiply H by y 4 bits at a time, starting with the highest power
        # terms.
        for i in range(0, 128, 4):
            # Multiply by x^4. The reduction for the top four terms is
            # precomputed.
            retHigh = ret & 0xf
            ret >>= 4
            ret ^= (AESGCM._gcmReductionTable[retHigh] << (128-16))

            # Add in y' * H where y' are the next four terms of y, shifted down
            # to the x^0..x^4. This is one of the pre-computed multiples of
            # H. The multiplication by x^4 shifts them back into place.
            ret ^= self._productTable[y & 0xf]
            y >>= 4
        assert y == 0
        return ret

    def seal(self, nonce, plaintext, data=''):
        """
        Encrypts and authenticates plaintext using nonce and data. Returns the
        ciphertext, consisting of the encrypted plaintext and tag concatenated.

        :raises ValueError: when the nonce is not 12 bytes long
        """
        if len(nonce) != 12:
            raise ValueError("Bad nonce length")

        # The initial counter value is the nonce, followed by a 32-bit counter
        # that starts at 1. It's used to compute the tag mask.
        counter = bytearray(16)
        counter[:12] = nonce
        counter[-1] = 1
        tagMask = self._rawAesEncrypt(counter)

        # The counter starts at 2 for the actual encryption.
        counter[-1] = 2
        self._ctr.counter = counter
        ciphertext = self._ctr.encrypt(plaintext)

        tag = self._auth(ciphertext, data, tagMask)

        return ciphertext + tag

    def open(self, nonce, ciphertext, data=''):
        """
        Decrypts and authenticates ciphertext using nonce and data. If the
        tag is valid, the plaintext is returned. If the tag is invalid,
        returns None.

        :raises ValueError: when the nonce is not 12 bytes long
        """
        if len(nonce) != 12:
            raise ValueError("Bad nonce length")
        if len(ciphertext) < 16:
            return None

        tag = ciphertext[-16:]
        ciphertext = ciphertext[:-16]

        # The initial counter value is the nonce, followed by a 32-bit counter
        # that starts at 1. It's used to compute the tag mask.
        counter = bytearray(16)
        counter[:12] = nonce
        counter[-1] = 1
        tagMask = self._rawAesEncrypt(counter)

        # FIX: the tag must be verified unconditionally. The previous code
        # guarded this check with `if data and ...`, which skipped
        # authentication entirely whenever no associated data was supplied
        # (the default), accepting forged ciphertexts.
        if not ct_compare_digest(tag, self._auth(ciphertext, data, tagMask)):
            return None

        # The counter starts at 2 for the actual decryption.
        counter[-1] = 2
        self._ctr.counter = counter
        return self._ctr.decrypt(ciphertext)

    @staticmethod
    def _reverseBits(i):
        """Reverse the order of the 4 lowest bits of i."""
        assert i < 16
        i = ((i << 2) & 0xc) | ((i >> 2) & 0x3)
        i = ((i << 1) & 0xa) | ((i >> 1) & 0x5)
        return i

    @staticmethod
    def _gcmAdd(x, y):
        """Addition in GF(2^128) is XOR."""
        return x ^ y

    @staticmethod
    def _gcmShift(x):
        # Multiplying by x is a right shift, due to bit order.
        highTermSet = x & 1
        x >>= 1
        if highTermSet:
            # The x^127 term was shifted up to x^128, so subtract a 1+x+x^2+x^7
            # term. This is 0b11100001 or 0xe1 when represented as an 8-bit
            # polynomial.
            x ^= 0xe1 << (128-8)
        return x

    @staticmethod
    def _inc32(counter):
        """Increment the low 32 bits of the counter block, with wrap-around."""
        for i in range(len(counter)-1, len(counter)-5, -1):
            counter[i] = (counter[i] + 1) % 256
            if counter[i] != 0:
                break
        return counter

    # _gcmReductionTable[i] is i * (1+x+x^2+x^7) for all 4-bit polynomials i. The
    # result is stored as a 16-bit polynomial. This is used in the reduction step to
    # multiply elements of GF(2^128) by x^4.
    _gcmReductionTable = [
        0x0000, 0x1c20, 0x3840, 0x2460, 0x7080, 0x6ca0, 0x48c0, 0x54e0,
        0xe100, 0xfd20, 0xd940, 0xc560, 0x9180, 0x8da0, 0xa9c0, 0xb5e0,
    ]

View File

@@ -175,12 +175,27 @@ class HybridCache:
if not isinstance(data, (bytes, bytearray, memoryview)):
raise ValueError("Data must be bytes, bytearray, or memoryview")
expires_at = time.time() + (ttl or self.ttl)
ttl_seconds = self.ttl if ttl is None else ttl
key = self._get_md5_hash(key)
if ttl_seconds <= 0:
# Explicit request to avoid caching - remove any previous entry and return success
self.memory_cache.remove(key)
try:
file_path = self._get_file_path(key)
await aiofiles.os.remove(file_path)
except FileNotFoundError:
pass
except Exception as e:
logger.error(f"Error removing cache file: {e}")
return True
expires_at = time.time() + ttl_seconds
# Create cache entry
entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))
key = self._get_md5_hash(key)
# Update memory cache
self.memory_cache.set(key, entry)
file_path = self._get_file_path(key)
@@ -210,10 +225,11 @@ class HybridCache:
async def delete(self, key: str) -> bool:
"""Delete item from both caches."""
self.memory_cache.remove(key)
hashed_key = self._get_md5_hash(key)
self.memory_cache.remove(hashed_key)
try:
file_path = self._get_file_path(key)
file_path = self._get_file_path(hashed_key)
await aiofiles.os.remove(file_path)
return True
except FileNotFoundError:
@@ -237,7 +253,13 @@ class AsyncMemoryCache:
async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
"""Set value in cache."""
try:
expires_at = time.time() + (ttl or 3600) # Default 1 hour TTL if not specified
ttl_seconds = 3600 if ttl is None else ttl
if ttl_seconds <= 0:
self.memory_cache.remove(key)
return True
expires_at = time.time() + ttl_seconds
entry = CacheEntry(
data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data)
)
@@ -276,18 +298,35 @@ EXTRACTOR_CACHE = HybridCache(
# Specific cache implementations
async def get_cached_init_segment(init_url: str, headers: dict) -> Optional[bytes]:
"""Get initialization segment from cache or download it."""
# Try cache first
cached_data = await INIT_SEGMENT_CACHE.get(init_url)
if cached_data is not None:
return cached_data
async def get_cached_init_segment(
init_url: str,
headers: dict,
cache_token: str | None = None,
ttl: Optional[int] = None,
) -> Optional[bytes]:
"""Get initialization segment from cache or download it.
cache_token allows differentiating entries that share the same init_url but
rely on different DRM keys or initialization payloads (e.g. key rotation).
ttl overrides the default cache TTL; pass a value <= 0 to skip caching entirely.
"""
use_cache = ttl is None or ttl > 0
cache_key = f"{init_url}|{cache_token}" if cache_token else init_url
if use_cache:
cached_data = await INIT_SEGMENT_CACHE.get(cache_key)
if cached_data is not None:
return cached_data
else:
# Remove any previously cached entry when caching is disabled
await INIT_SEGMENT_CACHE.delete(cache_key)
# Download if not cached
try:
init_content = await download_file_with_retry(init_url, headers)
if init_content:
await INIT_SEGMENT_CACHE.set(init_url, init_content)
if init_content and use_cache:
await INIT_SEGMENT_CACHE.set(cache_key, init_content, ttl=ttl)
return init_content
except Exception as e:
logger.error(f"Error downloading init segment: {e}")

View File

@@ -0,0 +1,465 @@
# Author: Trevor Perrin
# See the LICENSE file for legal information regarding use of this file.
"""Classes for reading/writing binary data (such as TLS records)."""
from __future__ import division
import sys
import struct
from struct import pack
from .compat import bytes_to_int
# NOTE(review): subclassing SyntaxError is historical API — see the TODO in
# Parser's docstring about not reusing a language-parser exception.
class DecodeError(SyntaxError):
    """Exception raised in case of decoding errors."""
    pass


class BadCertificateError(SyntaxError):
    """Exception raised in case of bad certificate."""
    pass
class Writer(object):
    """Serialisation helper for complex byte-based structures.

    Several methods are defined conditionally at class-creation time,
    depending on the running interpreter version, so each instance only
    carries the variant appropriate for that interpreter.
    """

    def __init__(self):
        """Initialise the serializer with no data."""
        self.bytes = bytearray(0)

    def addOne(self, val):
        """Add a single-byte wide element to buffer, see add()."""
        self.bytes.append(val)

    if sys.version_info < (2, 7):
        # struct.pack on Python2.6 does not raise exception if the value
        # is larger than can fit inside the specified size, so check the
        # range explicitly before packing
        def addTwo(self, val):
            """Add a double-byte wide element to buffer, see add()."""
            if not 0 <= val <= 0xffff:
                raise ValueError("Can't represent value in specified length")
            self.bytes += pack('>H', val)

        def addThree(self, val):
            """Add a three-byte wide element to buffer, see add()."""
            if not 0 <= val <= 0xffffff:
                raise ValueError("Can't represent value in specified length")
            self.bytes += pack('>BH', val >> 16, val & 0xffff)

        def addFour(self, val):
            """Add a four-byte wide element to buffer, see add()."""
            if not 0 <= val <= 0xffffffff:
                raise ValueError("Can't represent value in specified length")
            self.bytes += pack('>I', val)
    else:
        # on newer interpreters struct.pack itself raises struct.error for
        # out-of-range values; translate it to ValueError for API stability
        def addTwo(self, val):
            """Add a double-byte wide element to buffer, see add()."""
            try:
                self.bytes += pack('>H', val)
            except struct.error:
                raise ValueError("Can't represent value in specified length")

        def addThree(self, val):
            """Add a three-byte wide element to buffer, see add()."""
            try:
                self.bytes += pack('>BH', val >> 16, val & 0xffff)
            except struct.error:
                raise ValueError("Can't represent value in specified length")

        def addFour(self, val):
            """Add a four-byte wide element to buffer, see add()."""
            try:
                self.bytes += pack('>I', val)
            except struct.error:
                raise ValueError("Can't represent value in specified length")

    if sys.version_info >= (3, 0):
        # the method is called thousands of times, so it's better to extern
        # the version info check
        def add(self, x, length):
            """
            Add a single positive integer value x, encode it in length bytes

            Encode positive integer x in big-endian format using length bytes,
            add to the internal buffer.

            :type x: int
            :param x: value to encode

            :type length: int
            :param length: number of bytes to use for encoding the value
            """
            try:
                self.bytes += x.to_bytes(length, 'big')
            except OverflowError:
                raise ValueError("Can't represent value in specified length")
    else:
        # dispatch table bound at class-creation time; maps encoding width
        # to the fast fixed-width writer defined above
        _addMethods = {1: addOne, 2: addTwo, 3: addThree, 4: addFour}

        def add(self, x, length):
            """
            Add a single positive integer value x, encode it in length bytes

            Encode positive integer x in big-endian format using length bytes,
            add to the internal buffer.

            :type x: int
            :param x: value to encode

            :type length: int
            :param length: number of bytes to use for encoding the value
            """
            try:
                self._addMethods[length](self, x)
            except KeyError:
                # widths without a dedicated writer: append zeroes and fill
                # them in back-to-front, byte by byte
                self.bytes += bytearray(length)
                newIndex = len(self.bytes) - 1
                for i in range(newIndex, newIndex - length, -1):
                    self.bytes[i] = x & 0xFF
                    x >>= 8
                if x != 0:
                    raise ValueError("Can't represent value in specified "
                                     "length")

    def addFixSeq(self, seq, length):
        """
        Add a list of items, encode every item in length bytes

        Uses the unbounded iterable seq to produce items, each of
        which is then encoded to length bytes

        :type seq: iterable of int
        :param seq: list of positive integers to encode

        :type length: int
        :param length: number of bytes to which encode every element
        """
        for e in seq:
            self.add(e, length)

    if sys.version_info < (2, 7):
        # struct.pack on Python2.6 does not raise exception if the value
        # is larger than can fit inside the specified size
        def _addVarSeqTwo(self, seq):
            """Helper method for addVarSeq"""
            if not all(0 <= i <= 0xffff for i in seq):
                raise ValueError("Can't represent value in specified "
                                 "length")
            self.bytes += pack('>' + 'H' * len(seq), *seq)

        def addVarSeq(self, seq, length, lengthLength):
            """
            Add a bounded list of same-sized values

            Create a list of specific length with all items being of the same
            size

            :type seq: list of int
            :param seq: list of positive integers to encode

            :type length: int
            :param length: amount of bytes in which to encode every item

            :type lengthLength: int
            :param lengthLength: amount of bytes in which to encode the overall
                length of the array
            """
            self.add(len(seq)*length, lengthLength)
            if length == 1:
                self.bytes.extend(seq)
            elif length == 2:
                self._addVarSeqTwo(seq)
            else:
                for i in seq:
                    self.add(i, length)
    else:
        def addVarSeq(self, seq, length, lengthLength):
            """
            Add a bounded list of same-sized values

            Create a list of specific length with all items being of the same
            size

            :type seq: list of int
            :param seq: list of positive integers to encode

            :type length: int
            :param length: amount of bytes in which to encode every item

            :type lengthLength: int
            :param lengthLength: amount of bytes in which to encode the overall
                length of the array
            """
            seqLen = len(seq)
            self.add(seqLen*length, lengthLength)
            if length == 1:
                self.bytes.extend(seq)
            elif length == 2:
                try:
                    self.bytes += pack('>' + 'H' * seqLen, *seq)
                except struct.error:
                    raise ValueError("Can't represent value in specified "
                                     "length")
            else:
                for i in seq:
                    self.add(i, length)

    def addVarTupleSeq(self, seq, length, lengthLength):
        """
        Add a variable length list of same-sized element tuples.

        Note that all tuples must have the same size.

        Inverse of Parser.getVarTupleList()

        :type seq: enumerable
        :param seq: list of tuples

        :type length: int
        :param length: length of single element in tuple

        :type lengthLength: int
        :param lengthLength: length in bytes of overall length field
        """
        if not seq:
            self.add(0, lengthLength)
        else:
            startPos = len(self.bytes)
            dataLength = len(seq) * len(seq[0]) * length
            self.add(dataLength, lengthLength)
            # since at the time of writing, all the calls encode single byte
            # elements, and it's very easy to speed up that case, give it
            # special case
            if length == 1:
                for elemTuple in seq:
                    self.bytes.extend(elemTuple)
            else:
                for elemTuple in seq:
                    self.addFixSeq(elemTuple, length)
            # verify the advertised length against what was actually written;
            # a mismatch means the tuples were not all the same size
            if startPos + dataLength + lengthLength != len(self.bytes):
                raise ValueError("Tuples of different lengths")

    def add_var_bytes(self, data, length_length):
        """
        Add a variable length array of bytes.

        Inverse of Parser.getVarBytes()

        :type data: bytes
        :param data: bytes to add to the buffer

        :param int length_length: size of the field to represent the length
            of the data string
        """
        length = len(data)
        self.add(length, length_length)
        self.bytes += data
class Parser(object):
    """
    Parser for TLV and LV byte-based encodings.

    Parser that can handle arbitrary byte-based encodings usually employed in
    Type-Length-Value or Length-Value binary encoding protocols like ASN.1
    or TLS

    Note: if the raw bytes don't match expected values (like trying to
    read a 4-byte integer from a 2-byte buffer), most methods will raise a
    DecodeError exception.

    TODO: don't use an exception used by language parser to indicate errors
    in application code.

    :vartype bytes: bytearray
    :ivar bytes: data to be interpreted (buffer)

    :vartype index: int
    :ivar index: current position in the buffer

    :vartype lengthCheck: int
    :ivar lengthCheck: size of struct being parsed

    :vartype indexCheck: int
    :ivar indexCheck: position at which the structure begins in buffer
    """

    def __init__(self, bytes):
        """
        Bind raw bytes with parser.

        :type bytes: bytearray
        :param bytes: bytes to be parsed/interpreted
        """
        self.bytes = bytes
        self.index = 0
        self.indexCheck = 0
        self.lengthCheck = 0

    def get(self, length):
        """
        Read a single big-endian integer value encoded in 'length' bytes.

        :type length: int
        :param length: number of bytes in which the value is encoded in

        :rtype: int
        """
        raw = self.getFixBytes(length)
        return bytes_to_int(raw, 'big')

    def getFixBytes(self, lengthBytes):
        """
        Read a string of bytes encoded in 'lengthBytes' bytes.

        :type lengthBytes: int
        :param lengthBytes: number of bytes to return

        :rtype: bytearray
        """
        stop = self.index + lengthBytes
        if stop > len(self.bytes):
            raise DecodeError("Read past end of buffer")
        chunk = self.bytes[self.index:stop]
        self.index = stop
        return chunk

    def skip_bytes(self, length):
        """Move the internal pointer ahead length bytes."""
        if length + self.index > len(self.bytes):
            raise DecodeError("Read past end of buffer")
        self.index += length

    def getVarBytes(self, lengthLength):
        """
        Read a variable length string with a fixed length.

        see Writer.add_var_bytes() for an inverse of this method

        :type lengthLength: int
        :param lengthLength: number of bytes in which the length of the string
            is encoded in

        :rtype: bytearray
        """
        size = self.get(lengthLength)
        return self.getFixBytes(size)

    def getFixList(self, length, lengthList):
        """
        Read a list of static length with same-sized ints.

        :type length: int
        :param length: size in bytes of a single element in list

        :type lengthList: int
        :param lengthList: number of elements in list

        :rtype: list of int
        """
        return [self.get(length) for _ in range(lengthList)]

    def getVarList(self, length, lengthLength):
        """
        Read a variable length list of same-sized integers.

        :type length: int
        :param length: size in bytes of a single element

        :type lengthLength: int
        :param lengthLength: size of the encoded length of the list

        :rtype: list of int
        """
        encodedLength = self.get(lengthLength)
        if encodedLength % length != 0:
            raise DecodeError("Encoded length not a multiple of element "
                              "length")
        return [self.get(length)
                for _ in range(encodedLength // length)]

    def getVarTupleList(self, elemLength, elemNum, lengthLength):
        """
        Read a variable length list of same sized tuples.

        :type elemLength: int
        :param elemLength: length in bytes of single tuple element

        :type elemNum: int
        :param elemNum: number of elements in tuple

        :type lengthLength: int
        :param lengthLength: length in bytes of the list length variable

        :rtype: list of tuple of int
        """
        encodedLength = self.get(lengthLength)
        tupleSize = elemLength * elemNum
        if encodedLength % tupleSize != 0:
            raise DecodeError("Encoded length not a multiple of element "
                              "length")
        return [tuple(self.get(elemLength) for _ in range(elemNum))
                for _ in range(encodedLength // tupleSize)]

    def startLengthCheck(self, lengthLength):
        """
        Read length of struct and start a length check for parsing.

        :type lengthLength: int
        :param lengthLength: number of bytes in which the length is encoded
        """
        self.lengthCheck = self.get(lengthLength)
        self.indexCheck = self.index

    def setLengthCheck(self, length):
        """
        Set length of struct and start a length check for parsing.

        :type length: int
        :param length: expected size of parsed struct in bytes
        """
        self.lengthCheck = length
        self.indexCheck = self.index

    def stopLengthCheck(self):
        """
        Stop struct parsing, verify that no under- or overflow occurred.

        In case the expected length was mismatched with actual length of
        processed data, raises an exception.
        """
        if (self.index - self.indexCheck) != self.lengthCheck:
            raise DecodeError("Under- or over-flow while reading buffer")

    def atLengthCheck(self):
        """
        Check if there is data in structure left for parsing.

        Returns True if the whole structure was parsed, False if there is
        some data left.

        Will raise an exception if overflow occured (amount of data read was
        greater than expected size)
        """
        consumed = self.index - self.indexCheck
        if consumed < self.lengthCheck:
            return False
        if consumed == self.lengthCheck:
            return True
        raise DecodeError("Read past end of buffer")

    def getRemainingLength(self):
        """Return amount of data remaining in struct being parsed."""
        return len(self.bytes) - self.index

View File

@@ -0,0 +1,231 @@
# Author: Trevor Perrin
# See the LICENSE file for legal information regarding use of this file.
"""Miscellaneous functions to mask Python version differences."""
import sys
import re
import platform
import binascii
import traceback
import time
# Define the whole compatibility layer once, at import time, based on the
# running interpreter: the first branch is Python 3, the else branch is
# Python 2 (with a nested split for 2.6/old-2.7/Jython).
if sys.version_info >= (3,0):
    # identity on Python 3; exists for API compatibility with Python 2.6
    def compat26Str(x): return x

    # Python 3.3 requires bytes instead of bytearrays for HMAC
    # So, python 2.6 requires strings, python 3 requires 'bytes',
    # and python 2.7 and 3.5 can handle bytearrays...
    # pylint: disable=invalid-name
    # we need to keep compatHMAC and `x` for API compatibility
    if sys.version_info < (3, 4):
        def compatHMAC(x):
            """Convert bytes-like input to format acceptable for HMAC."""
            return bytes(x)
    else:
        def compatHMAC(x):
            """Convert bytes-like input to format acceptable for HMAC."""
            return x
    # pylint: enable=invalid-name

    def compatAscii2Bytes(val):
        """Convert ASCII string to bytes."""
        if isinstance(val, str):
            return bytes(val, 'ascii')
        return val

    def compat_b2a(val):
        """Convert an ASCII bytes string to string."""
        return str(val, 'ascii')

    # Python 3 renamed raw_input to input; keep the old name available
    def raw_input(s):
        return input(s)

    # So, the python3 binascii module deals with bytearrays, and python2
    # deals with strings... I would rather deal with the "a" part as
    # strings, and the "b" part as bytearrays, regardless of python version,
    # so...
    def a2b_hex(s):
        """Decode a hex string to a bytearray; raise SyntaxError if invalid."""
        try:
            b = bytearray(binascii.a2b_hex(bytearray(s, "ascii")))
        except Exception as e:
            raise SyntaxError("base16 error: %s" % e)
        return b

    def a2b_base64(s):
        """Decode a base64 string to a bytearray; raise SyntaxError if invalid."""
        try:
            if isinstance(s, str):
                s = bytearray(s, "ascii")
            b = bytearray(binascii.a2b_base64(s))
        except Exception as e:
            raise SyntaxError("base64 error: %s" % e)
        return b

    def b2a_hex(b):
        """Encode binary data as a hex str."""
        return binascii.b2a_hex(b).decode("ascii")

    def b2a_base64(b):
        """Encode binary data as a base64 str."""
        return binascii.b2a_base64(b).decode("ascii")

    def readStdinBinary():
        """Read the whole of stdin as binary data."""
        return sys.stdin.buffer.read()

    # Python 3 has no separate long type
    def compatLong(num):
        return int(num)

    int_types = tuple([int])

    def formatExceptionTrace(e):
        """Return exception information formatted as string"""
        return str(e)

    def time_stamp():
        """Returns system time as a float"""
        if sys.version_info >= (3, 3):
            return time.perf_counter()
        # time.clock() was removed in 3.8; only reached on 3.0-3.2
        return time.clock()

    def remove_whitespace(text):
        """Removes all whitespace from passed in string"""
        return re.sub(r"\s+", "", text, flags=re.UNICODE)

    # pylint: disable=invalid-name
    # pylint is stupid here and doesn't notice it's a function, not
    # constant
    bytes_to_int = int.from_bytes
    # pylint: enable=invalid-name

    def bit_length(val):
        """Return number of bits necessary to represent an integer."""
        return val.bit_length()

    def int_to_bytes(val, length=None, byteorder="big"):
        """Return number converted to bytes"""
        if length is None:
            if val:
                length = byte_length(val)
            else:
                length = 1
        # for gmpy we need to convert back to native int
        if type(val) != int:
            val = int(val)
        return bytearray(val.to_bytes(length=length, byteorder=byteorder))
else:
    # Python 2.6 requires strings instead of bytearrays in a couple places,
    # so we define this function so it does the conversion if needed.
    # same thing with very old 2.7 versions
    # or on Jython
    # NOTE(review): the first condition subsumes the second; kept as-is for
    # byte-compatibility with upstream
    if sys.version_info < (2, 7) or sys.version_info < (2, 7, 4) \
            or platform.system() == 'Java':
        def compat26Str(x): return str(x)

        def remove_whitespace(text):
            """Removes all whitespace from passed in string"""
            return re.sub(r"\s+", "", text)

        def bit_length(val):
            """Return number of bits necessary to represent an integer."""
            if val == 0:
                return 0
            # bin() prefixes the digits with '0b'
            return len(bin(val))-2
    else:
        def compat26Str(x): return x

        def remove_whitespace(text):
            """Removes all whitespace from passed in string"""
            return re.sub(r"\s+", "", text, flags=re.UNICODE)

        def bit_length(val):
            """Return number of bits necessary to represent an integer."""
            return val.bit_length()

    def compatAscii2Bytes(val):
        """Convert ASCII string to bytes."""
        return val

    def compat_b2a(val):
        """Convert an ASCII bytes string to string."""
        return str(val)

    # So, python 2.6 requires strings, python 3 requires 'bytes',
    # and python 2.7 can handle bytearrays...
    def compatHMAC(x): return compat26Str(x)

    def a2b_hex(s):
        """Decode a hex string to a bytearray; raise SyntaxError if invalid."""
        try:
            b = bytearray(binascii.a2b_hex(s))
        except Exception as e:
            raise SyntaxError("base16 error: %s" % e)
        return b

    def a2b_base64(s):
        """Decode a base64 string to a bytearray; raise SyntaxError if invalid."""
        try:
            b = bytearray(binascii.a2b_base64(s))
        except Exception as e:
            raise SyntaxError("base64 error: %s" % e)
        return b

    def b2a_hex(b):
        """Encode binary data as a hex string."""
        return binascii.b2a_hex(compat26Str(b))

    def b2a_base64(b):
        """Encode binary data as a base64 string."""
        return binascii.b2a_base64(compat26Str(b))

    def compatLong(num):
        return long(num)

    int_types = (int, long)

    # pylint on Python3 goes nuts for the sys dereferences...
    #pylint: disable=no-member
    def formatExceptionTrace(e):
        """Return exception information formatted as string"""
        newStr = "".join(traceback.format_exception(sys.exc_type,
                                                    sys.exc_value,
                                                    sys.exc_traceback))
        return newStr
    #pylint: enable=no-member

    def time_stamp():
        """Returns system time as a float"""
        return time.clock()

    def bytes_to_int(val, byteorder):
        """Convert bytes to an int."""
        if not val:
            return 0
        if byteorder == "big":
            return int(b2a_hex(val), 16)
        if byteorder == "little":
            return int(b2a_hex(val[::-1]), 16)
        raise ValueError("Only 'big' and 'little' endian supported")

    def int_to_bytes(val, length=None, byteorder="big"):
        """Return number converted to bytes"""
        if length is None:
            if val:
                length = byte_length(val)
            else:
                length = 1
        if byteorder == "big":
            return bytearray((val >> i) & 0xff
                             for i in reversed(range(0, length*8, 8)))
        if byteorder == "little":
            return bytearray((val >> i) & 0xff
                             for i in range(0, length*8, 8))
        raise ValueError("Only 'big' or 'little' endian supported")
def byte_length(val):
    """Return number of bytes necessary to represent an integer."""
    bits = bit_length(val)
    return (bits + 7) // 8
# Capability flags: this build ships without the optional backends that
# would provide extra ECDSA curves or post-quantum (ML-KEM/ML-DSA) support.
ecdsaAllCurves = False
ML_KEM_AVAILABLE = False
ML_DSA_AVAILABLE = False

View File

@@ -0,0 +1,218 @@
# Copyright (c) 2015, Hubert Kario
#
# See the LICENSE file for legal information regarding use of this file.
"""Various constant time functions for processing sensitive data"""
from __future__ import division
from .compat import compatHMAC
import hmac
def ct_lt_u32(val_a, val_b):
    """
    Returns 1 if val_a < val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    a = val_a & 0xffffffff
    b = val_b & 0xffffffff
    # branch-free: the sign bit of the (wrapped) 32-bit subtraction,
    # corrected for operands of differing sign, selects the result
    return (a ^ ((a ^ b) | (((a - b) & 0xffffffff) ^ b))) >> 31


def ct_gt_u32(val_a, val_b):
    """
    Return 1 if val_a > val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    # a > b is exactly b < a
    return ct_lt_u32(val_b, val_a)


def ct_le_u32(val_a, val_b):
    """
    Return 1 if val_a <= val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    # a <= b is the negation of a > b; XOR with 1 flips the 0/1 result
    return ct_gt_u32(val_a, val_b) ^ 1
def ct_lsb_prop_u8(val):
    """Propagate LSB to all 8 bits of the returned int. Constant time."""
    # the isolated bit is 0 or 1; multiplying spreads it to a full mask
    return (val & 0x01) * 0xff


def ct_lsb_prop_u16(val):
    """Propagate LSB to all 16 bits of the returned int. Constant time."""
    return (val & 0x01) * 0xffff
def ct_isnonzero_u32(val):
    """
    Returns 1 if val is != 0, 0 otherwise. Constant time.

    :type val: int
    :param val: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    v = val & 0xffffffff
    # for any non-zero v, either v or its two's complement has the top
    # bit set; for zero, neither does
    return (v | (-v & 0xffffffff)) >> 31
def ct_neq_u32(val_a, val_b):
    """
    Return 1 if val_a != val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    a = val_a & 0xffffffff
    b = val_b & 0xffffffff
    # if a != b, one of the wrapped differences has the sign bit set
    diff = ((a - b) & 0xffffffff) | ((b - a) & 0xffffffff)
    return diff >> 31


def ct_eq_u32(val_a, val_b):
    """
    Return 1 if val_a == val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    # equality is the negation of inequality; XOR with 1 flips 0/1
    return ct_neq_u32(val_a, val_b) ^ 1
def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
                             block_size=16):
    """
    Check CBC cipher HMAC and padding. Close to constant time.

    Processing is deliberately branch-free on secret data: failures are
    accumulated into `result` via masked ORs rather than early returns.

    :type data: bytearray
    :param data: data with HMAC value to test and padding

    :type mac: hashlib mac
    :param mac: empty HMAC, initialised with a key

    :type seqnumBytes: bytearray
    :param seqnumBytes: TLS sequence number, used as input to HMAC

    :type contentType: int
    :param contentType: a single byte, used as input to HMAC

    :type version: tuple of int
    :param version: a tuple of two ints, used as input to HMAC and to guide
        checking of padding

    :rtype: boolean
    :returns: True if MAC and pad is ok, False otherwise
    """
    assert version in ((3, 0), (3, 1), (3, 2), (3, 3))

    data_len = len(data)
    if mac.digest_size + 1 > data_len:  # data_len is public
        return False

    # 0 - OK
    result = 0x00

    #
    # check padding
    #
    pad_length = data[data_len-1]
    pad_start = data_len - pad_length - 1
    # clamp to the buffer; pad_length is attacker-influenced
    pad_start = max(0, pad_start)

    if version == (3, 0):  # version is public
        # in SSLv3 we can only check if pad is not longer than the cipher
        # block size

        # subtract 1 for the pad length byte
        mask = ct_lsb_prop_u8(ct_lt_u32(block_size, pad_length))
        result |= mask
    else:
        # TLS requires every pad byte to equal pad_length; scan a fixed
        # window (up to 256 trailing bytes) regardless of actual pad size
        start_pos = max(0, data_len - 256)
        for i in range(start_pos, data_len):
            # if pad_start < i: mask = 0xff; else: mask = 0x00
            mask = ct_lsb_prop_u8(ct_le_u32(pad_start, i))
            # if data[i] != pad_length and "inside_pad": result = False
            result |= (data[i] ^ pad_length) & mask

    #
    # check MAC
    #

    # real place where mac starts and data ends
    mac_start = pad_start - mac.digest_size
    mac_start = max(0, mac_start)

    # place to start processing (rounded down to a whole HMAC block so the
    # amount of hashed data is independent of the secret pad length)
    start_pos = max(0, data_len - (256 + mac.digest_size)) // mac.block_size
    start_pos *= mac.block_size

    # add start data
    data_mac = mac.copy()
    data_mac.update(compatHMAC(seqnumBytes))
    data_mac.update(compatHMAC(bytearray([contentType])))
    if version != (3, 0):  # version is public
        data_mac.update(compatHMAC(bytearray([version[0]])))
        data_mac.update(compatHMAC(bytearray([version[1]])))
    # 16-bit big-endian length of the MACed data
    data_mac.update(compatHMAC(bytearray([mac_start >> 8])))
    data_mac.update(compatHMAC(bytearray([mac_start & 0xff])))
    data_mac.update(compatHMAC(data[:start_pos]))

    # don't check past the array end (already checked to be >= zero)
    end_pos = data_len - mac.digest_size

    # calculate all possible
    for i in range(start_pos, end_pos):  # constant for given overall length
        cur_mac = data_mac.copy()
        cur_mac.update(compatHMAC(data[start_pos:i]))
        mac_compare = bytearray(cur_mac.digest())
        # compare the hash for real only if it's the place where mac is
        # supposed to be
        mask = ct_lsb_prop_u8(ct_eq_u32(i, mac_start))
        for j in range(0, mac.digest_size):  # digest_size is public
            result |= (data[i+j] ^ mac_compare[j]) & mask

    # return python boolean
    return result == 0
# Use the standard library's native constant-time comparison when available;
# otherwise fall back to a pure-Python version (note: the fallback's length
# check returns early, so only the content comparison is constant time).
if hasattr(hmac, 'compare_digest'):
    ct_compare_digest = hmac.compare_digest
else:
    def ct_compare_digest(val_a, val_b):
        """Compares if string like objects are equal. Constant time."""
        if len(val_a) != len(val_b):
            return False

        result = 0
        # accumulate every byte difference; no data-dependent branches
        for x, y in zip(val_a, val_b):
            result |= x ^ y

        return result == 0

View File

@@ -0,0 +1,366 @@
# Authors:
# Trevor Perrin
# Martin von Loewis - python 3 port
# Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
#
# See the LICENSE file for legal information regarding use of this file.
"""cryptomath module
This module has basic math/crypto code."""
from __future__ import print_function
import os
import math
import base64
import binascii
from .compat import compat26Str, compatHMAC, compatLong, \
bytes_to_int, int_to_bytes, bit_length, byte_length
from .codec import Writer
from . import tlshashlib as hashlib
from . import tlshmac as hmac
m2cryptoLoaded = False
gmpyLoaded = False
GMPY2_LOADED = False
pycryptoLoaded = False
# **************************************************************************
# PRNG Functions
# **************************************************************************
# Check that os.urandom works
import zlib
assert len(zlib.compress(os.urandom(1000))) > 900
def getRandomBytes(howMany):
    """Return `howMany` cryptographically secure random bytes as a bytearray."""
    rand = bytearray(os.urandom(howMany))
    # os.urandom is documented to return exactly the requested length
    assert len(rand) == howMany
    return rand
prngName = "os.urandom"
# **************************************************************************
# Simple hash functions
# **************************************************************************
def MD5(b):
    """Return the MD5 digest of `b` as a bytearray."""
    return secureHash(b, 'md5')

def SHA1(b):
    """Return the SHA1 digest of `b` as a bytearray."""
    return secureHash(b, 'sha1')

def secureHash(data, algorithm):
    """Return a digest of `data` using `algorithm`.

    :param data: bytes-like object to hash
    :param str algorithm: any name accepted by :py:func:`hashlib.new`
    :rtype: bytearray
    """
    hashInstance = hashlib.new(algorithm)
    # compat26Str handles old-Python buffer/memoryview differences
    hashInstance.update(compat26Str(data))
    return bytearray(hashInstance.digest())

def secureHMAC(k, b, algorithm):
    """Return a HMAC of `b` keyed with `k` using `algorithm`.

    :rtype: bytearray
    """
    k = compatHMAC(k)
    b = compatHMAC(b)
    return bytearray(hmac.new(k, b, getattr(hashlib, algorithm)).digest())

def HMAC_MD5(k, b):
    # convenience wrapper for HMAC-MD5
    return secureHMAC(k, b, 'md5')

def HMAC_SHA1(k, b):
    # convenience wrapper for HMAC-SHA1
    return secureHMAC(k, b, 'sha1')

def HMAC_SHA256(k, b):
    # convenience wrapper for HMAC-SHA256
    return secureHMAC(k, b, 'sha256')

def HMAC_SHA384(k, b):
    # convenience wrapper for HMAC-SHA384
    return secureHMAC(k, b, 'sha384')
def HKDF_expand(PRK, info, L, algorithm):
    """RFC 5869 HKDF-Expand: derive `L` bytes of keying material from `PRK`.

    :param bytearray PRK: pseudorandom key (HKDF-Extract output)
    :param bytearray info: context and application specific information
    :param int L: number of bytes of output keying material
    :param str algorithm: hash name understood by the hashlib module
    :rtype: bytearray
    """
    digest_size = getattr(hashlib, algorithm)().digest_size
    N = divceil(L, digest_size)
    T = bytearray()
    Titer = bytearray()
    # T(i) = HMAC-Hash(PRK, T(i-1) | info | i); concatenate exactly N blocks.
    # (The previous version iterated N+1 times and discarded the last,
    # superfluous, HMAC computation.)
    for x in range(1, N + 1):
        Titer = secureHMAC(PRK, Titer + info + bytearray([x]), algorithm)
        T += Titer
    return T[:L]
def HKDF_expand_label(secret, label, hashValue, length, algorithm):
    """
    TLS1.3 key derivation function (HKDF-Expand-Label).

    :param bytearray secret: the key from which to derive the keying material
    :param bytearray label: label used to differentiate the keying materials
    :param bytearray hashValue: bytes used to "salt" the produced keying
        material
    :param int length: number of bytes to produce
    :param str algorithm: name of the secure hash algorithm used as the
        basis of the HKDF
    :rtype: bytearray
    """
    hkdfLabel = Writer()
    # serialised HkdfLabel structure: uint16 length, then two
    # 1-byte-length-prefixed vectors: "tls13 " + label, and the context hash
    hkdfLabel.addTwo(length)
    hkdfLabel.addVarSeq(bytearray(b"tls13 ") + label, 1, 1)
    hkdfLabel.addVarSeq(hashValue, 1, 1)
    return HKDF_expand(secret, hkdfLabel.bytes, length, algorithm)
def derive_secret(secret, label, handshake_hashes, algorithm):
    """
    TLS1.3 key derivation function (Derive-Secret).

    :param bytearray secret: secret key used to derive the keying material
    :param bytearray label: label used to differentiate the keying materials
    :param HandshakeHashes handshake_hashes: hashes of the handshake messages
        or `None` if no handshake transcript is to be used for derivation of
        keying material
    :param str algorithm: name of the secure hash algorithm used as the
        basis of the HKDF algorithm - governs how much keying material will
        be generated
    :rtype: bytearray
    """
    # an absent transcript is equivalent to hashing the empty string
    if handshake_hashes is not None:
        transcript_hash = handshake_hashes.digest(algorithm)
    else:
        transcript_hash = secureHash(bytearray(b''), algorithm)
    out_len = getattr(hashlib, algorithm)().digest_size
    return HKDF_expand_label(secret, label, transcript_hash, out_len,
                             algorithm)
# **************************************************************************
# Converter Functions
# **************************************************************************
def bytesToNumber(b, endian="big"):
    """
    Convert a number stored in bytearray to an integer.

    By default assumes big-endian encoding of the number.

    :param b: bytes-like object holding the encoded integer
    :param str endian: "big" or "little" byte order
    :rtype: int
    """
    return bytes_to_int(b, endian)
def numberToByteArray(n, howManyBytes=None, endian="big"):
    """
    Convert an integer into a bytearray, zero-pad to howManyBytes.

    The returned bytearray may be smaller than howManyBytes, but will
    not be larger. The returned bytearray will contain a big- or little-endian
    encoding of the input integer (n). Big endian encoding is used by default.
    """
    if howManyBytes is not None:
        length = byte_length(n)
        if howManyBytes < length:
            # requested size too small to hold the number: encode in full,
            # then keep only the `howManyBytes` least significant bytes
            ret = int_to_bytes(n, length, endian)
            if endian == "big":
                # in big-endian the least significant bytes are at the end
                return ret[length-howManyBytes:length]
            return ret[:howManyBytes]
    # howManyBytes is None (minimal encoding) or large enough to zero-pad
    return int_to_bytes(n, howManyBytes, endian)
def mpiToNumber(mpi):
    """Convert a MPI (OpenSSL bignum string) to an integer."""
    data = bytearray(mpi)
    # the first four bytes are the big-endian body length; the top bit of
    # the first body byte is the sign bit and must be clear for positives
    if data[4] & 0x80:
        raise ValueError("Input must be a positive integer")
    return bytesToNumber(data[4:])
def numberToMPI(n):
    """Convert an integer to an OpenSSL MPI: 4-byte length + big-endian body."""
    b = numberToByteArray(n)
    ext = 0
    #If the high-order bit is going to be set,
    #add an extra byte of zeros
    if (numBits(n) & 0x7)==0:
        ext = 1
    length = numBytes(n) + ext
    # prepend the 4-byte header (plus the optional zero sign byte)
    b = bytearray(4+ext) + b
    b[0] = (length >> 24) & 0xFF
    b[1] = (length >> 16) & 0xFF
    b[2] = (length >> 8) & 0xFF
    b[3] = length & 0xFF
    return bytes(b)
# **************************************************************************
# Misc. Utility Functions
# **************************************************************************
# pylint: disable=invalid-name
# pylint recognises them as constants, not function names, also
# we can't change their names without API change
# aliases for bit/byte length of an integer, kept for API compatibility
numBits = bit_length
numBytes = byte_length
# pylint: enable=invalid-name
# **************************************************************************
# Big Number Math
# **************************************************************************
def getRandomNumber(low, high):
    """Return a uniformly random integer n such that low <= n < high.

    Uses rejection sampling: draws `numBytes(high)` random bytes, masks the
    top byte down to the bit length of `high`, and retries until the
    candidate falls in range.
    """
    assert low < high
    howManyBits = numBits(high)
    howManyBytes = numBytes(high)
    lastBits = howManyBits % 8
    # fix: local previously shadowed the builtin `bytes`
    while True:
        candBytes = getRandomBytes(howManyBytes)
        if lastBits:
            # keep only the bits that can occur in the top byte of `high`
            candBytes[0] = candBytes[0] % (1 << lastBits)
        cand = bytesToNumber(candBytes)
        if low <= cand < high:
            return cand
def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm)."""
    big, small = max(a, b), min(a, b)
    while small:
        big, small = small, big % small
    return big
def lcm(a, b):
    """Return the least common multiple of a and b."""
    # a*b is always divisible by gcd(a, b), so the division is exact
    return (a * b) // gcd(a, b)
# pylint: disable=invalid-name
# disable pylint check as the (a, b) are part of the API
if GMPY2_LOADED:
    def invMod(a, b):
        """Return inverse of a mod b, zero if none."""
        if a == 0:
            return 0
        # gmpy2's powmod with exponent -1 computes the modular inverse
        return powmod(a, -1, b)
else:
    # Use Extended Euclidean Algorithm
    def invMod(a, b):
        """Return inverse of a mod b, zero if none."""
        c, d = a, b
        uc, ud = 1, 0
        # loop invariant: uc*a == c (mod b) and ud*a == d (mod b)
        while c != 0:
            q = d // c
            c, d = d-(q*c), c
            uc, ud = ud - (q * uc), uc
        # d is now gcd(a, b); an inverse exists only when it equals 1
        if d == 1:
            return ud % b
        return 0
# pylint: enable=invalid-name
if gmpyLoaded or GMPY2_LOADED:
    def powMod(base, power, modulus):
        """Modular exponentiation through gmpy's mpz type (faster than pure Python)."""
        base = mpz(base)
        power = mpz(power)
        modulus = mpz(modulus)
        result = pow(base, power, modulus)
        # convert the gmpy result back to a native Python integer
        return compatLong(result)
else:
    # the built-in three-argument pow() already performs modular exponentiation
    powMod = pow
def divceil(divident, divisor):
    """Integer division with rounding up"""
    quotient, remainder = divmod(divident, divisor)
    if remainder:
        quotient += 1
    return quotient
#Pre-calculate a sieve of the ~100 primes < 1000:
def makeSieve(n):
    """Return the list of all primes below n (sieve of Eratosthenes)."""
    flags = [True] * n
    for cand in range(2, int(math.sqrt(n)) + 1):
        if not flags[cand]:
            # already known composite, its multiples were crossed out
            continue
        for multiple in range(cand * 2, n, cand):
            flags[multiple] = False
    return [x for x in range(2, n) if flags[x]]
def isPrime(n, iterations=5, display=False, sieve=makeSieve(1000)):
    """Probabilistic primality test: trial division, then Rabin-Miller.

    :param int n: candidate to test
    :param int iterations: number of Rabin-Miller rounds (0 = trial
        division only)
    :param bool display: print progress markers when True
    :param sieve: precomputed list of small primes for trial division
        (built once at import time; it is only read, never mutated)
    :rtype: bool
    """
    #Trial division with sieve
    for x in sieve:
        if x >= n: return True
        if n % x == 0: return False
    #Passed trial division, proceed to Rabin-Miller
    #Rabin-Miller implemented per Ferguson & Schneier
    #Compute s, t for Rabin-Miller
    if display: print("*", end=' ')
    # write n-1 as s * 2**t with s odd
    s, t = n-1, 0
    while s % 2 == 0:
        s, t = s//2, t+1
    #Repeat Rabin-Miller x times
    a = 2 #Use 2 as a base for first iteration speedup, per HAC
    for count in range(iterations):
        v = powMod(a, s, n)
        if v==1:
            continue
        i = 0
        # square v repeatedly; a witness never reaches n-1 before i == t-1
        while v != n-1:
            if i == t-1:
                return False
            else:
                v, i = powMod(v, 2, n), i+1
        # subsequent rounds use a random base
        a = getRandomNumber(2, n)
    return True
def getRandomPrime(bits, display=False):
    """
    Generate a random prime number of a given size.

    the number will be 'bits' bits long (i.e. generated number will be
    larger than `(2^(bits-1) * 3 ) / 2` but smaller than 2^bits.
    """
    assert bits >= 10
    #The 1.5 ensures the 2 MSBs are set
    #Thus, when used for p,q in RSA, n will have its MSB set
    #
    #Since 30 is lcm(2,3,5), we'll set our test numbers to
    #29 % 30 and keep them there
    # NOTE(review): the mod-30 comment above describes the safe-prime
    # variant below; this function only forces candidates to be odd
    low = ((2 ** (bits-1)) * 3) // 2
    high = 2 ** bits - 30
    while True:
        if display:
            print(".", end=' ')
        cand_p = getRandomNumber(low, high)
        # make odd
        if cand_p % 2 == 0:
            cand_p += 1
        if isPrime(cand_p, display=display):
            return cand_p
#Unused at the moment...
def getRandomSafePrime(bits, display=False):
    """Generate a random safe prime.

    Will generate a prime `bits` bits long (see getRandomPrime) such that
    the (p-1)/2 will also be prime.
    """
    assert bits >= 10
    #The 1.5 ensures the 2 MSBs are set
    #Thus, when used for p,q in RSA, n will have its MSB set
    #
    #Since 30 is lcm(2,3,5), we'll set our test numbers to
    #29 % 30 and keep them there
    low = (2 ** (bits-2)) * 3//2
    high = (2 ** (bits-1)) - 30
    q = getRandomNumber(low, high)
    q += 29 - (q % 30)
    while 1:
        if display: print(".", end=' ')
        # step candidates by 30, re-drawing when the range is exhausted
        q += 30
        if (q >= high):
            q = getRandomNumber(low, high)
            q += 29 - (q % 30)
        #Ideas from Tom Wu's SRP code
        #Do trial division on p and q before Rabin-Miller
        if isPrime(q, 0, display=display):
            # iterations=0 above means trial division only (cheap pre-filter)
            p = (2 * q) + 1
            if isPrime(p, display=display):
                if isPrime(q, display=display):
                    return p

View File

@@ -0,0 +1,218 @@
# Copyright (c) 2018 Hubert Kario
#
# See the LICENSE file for legal information regarding use of this file.
"""Methods for deprecating old names for arguments or attributes."""
import warnings
import inspect
from functools import wraps
def deprecated_class_name(old_name,
                          warn="Class name '{old_name}' is deprecated, "
                               "please use '{new_name}'"):
    """
    Class decorator to deprecate a use of class.

    :param str old_name: the deprecated name that will be registered, but
        will raise warnings if used.

    :param str warn: DeprecationWarning format string for informing the
        user what is the current class name, uses 'old_name' for the
        deprecated keyword name and the 'new_name' for the current one.
        Example: "Old name: {old_name}, use '{new_name}' instead".
    """
    def _wrap(obj):
        assert callable(obj)

        def _warn():
            # stacklevel=3 skips _warn and _func so the warning points at
            # the caller's line
            warnings.warn(warn.format(old_name=old_name,
                                      new_name=obj.__name__),
                          DeprecationWarning,
                          stacklevel=3)

        def _wrap_with_warn(func, is_inspect):
            @wraps(func)
            def _func(*args, **kwargs):
                if is_inspect:
                    # XXX: If use another name to call,
                    # you will not get the warning.
                    # we do this instead of subclassing or metaclass as
                    # we want to isinstance(new_name(), old_name) and
                    # isinstance(old_name(), new_name) to work
                    frame = inspect.currentframe().f_back
                    code = inspect.getframeinfo(frame).code_context
                    if [line for line in code
                            if '{0}('.format(old_name) in line]:
                        _warn()
                else:
                    _warn()
                return func(*args, **kwargs)
            return _func

        # Make old name available.
        frame = inspect.currentframe().f_back
        if old_name in frame.f_globals:
            raise NameError("Name '{0}' already in use.".format(old_name))

        if inspect.isclass(obj):
            # for classes, only __init__ is wrapped so alias and new name
            # stay the very same class object (isinstance works both ways)
            obj.__init__ = _wrap_with_warn(obj.__init__, True)
            placeholder = obj
        else:
            placeholder = _wrap_with_warn(obj, False)

        # register the deprecated alias in the caller's module namespace
        frame.f_globals[old_name] = placeholder

        return obj
    return _wrap
def deprecated_params(names, warn="Param name '{old_name}' is deprecated, "
                      "please use '{new_name}'"):
    """Decorator that maps deprecated keyword names onto their new names.

    :param dict names: dictionary with pairs of new_name: old_name
        that will be used for translating obsolete param names to new names
    :param str warn: DeprecationWarning format string for informing the user
        what is the current parameter name, uses 'old_name' for the
        deprecated keyword name and 'new_name' for the current one.
        Example: "Old name: {old_name}, use {new_name} instead".
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for current, obsolete in names.items():
                if obsolete not in kwargs:
                    continue
                # refuse a call that supplies both spellings
                if current in kwargs:
                    raise TypeError("got multiple values for keyword "
                                    "argument '{0}'".format(current))
                warnings.warn(warn.format(old_name=obsolete,
                                          new_name=current),
                              DeprecationWarning,
                              stacklevel=2)
                kwargs[current] = kwargs.pop(obsolete)
            return func(*args, **kwargs)
        return wrapper
    return decorator
def deprecated_instance_attrs(names,
                              warn="Attribute '{old_name}' is deprecated, "
                                   "please use '{new_name}'"):
    """Decorator to deprecate class instance attributes.

    Translates all names in `names` to use new names and emits warnings
    if the translation was necessary. Does apply only to instance variables
    and attributes (won't modify behaviour of class variables, static methods,
    etc.)

    :param dict names: dictionary with pairs of new_name: old_name that will
        be used to translate the calls
    :param str warn: DeprecationWarning format string for informing the user
        what is the current parameter name, uses 'old_name' for the
        deprecated keyword name and 'new_name' for the current one.
        Example: "Old name: {old_name}, use {new_name} instead".
    """
    # reverse the dict as we're looking for old attributes, not new ones
    names = dict((j, i) for i, j in names.items())

    def decorator(clazz):
        # the default-argument trick binds the pre-decoration hook (if any)
        # at definition time, so the wrapper can chain to it
        def getx(self, name, __old_getx=getattr(clazz, "__getattr__", None)):
            if name in names:
                warnings.warn(warn.format(old_name=name,
                                          new_name=names[name]),
                              DeprecationWarning,
                              stacklevel=2)
                return getattr(self, names[name])
            if __old_getx:
                if hasattr(__old_getx, "__func__"):
                    # unbound method (Python 2) - call through __func__
                    return __old_getx.__func__(self, name)
                return __old_getx(self, name)
            raise AttributeError("'{0}' object has no attribute '{1}'"
                                 .format(clazz.__name__, name))

        getx.__name__ = "__getattr__"
        clazz.__getattr__ = getx

        def setx(self, name, value, __old_setx=getattr(clazz, "__setattr__")):
            if name in names:
                warnings.warn(warn.format(old_name=name,
                                          new_name=names[name]),
                              DeprecationWarning,
                              stacklevel=2)
                setattr(self, names[name], value)
            else:
                __old_setx(self, name, value)

        setx.__name__ = "__setattr__"
        clazz.__setattr__ = setx

        def delx(self, name, __old_delx=getattr(clazz, "__delattr__")):
            if name in names:
                warnings.warn(warn.format(old_name=name,
                                          new_name=names[name]),
                              DeprecationWarning,
                              stacklevel=2)
                delattr(self, names[name])
            else:
                __old_delx(self, name)

        delx.__name__ = "__delattr__"
        clazz.__delattr__ = delx
        return clazz
    return decorator
def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, "
                     "please use '{new_name}'"):
    """Decorator to deprecate all specified attributes in class.

    Translates all names in `names` to use new names and emits warnings
    if the translation was necessary.

    Note: uses metaclass magic so is incompatible with other metaclass uses

    :param dict names: dictionary with pairs of new_name: old_name that will
        be used to translate the calls
    :param str warn: DeprecationWarning format string for informing the user
        what is the current parameter name, uses 'old_name' for the
        deprecated keyword name and 'new_name' for the current one.
        Example: "Old name: {old_name}, use {new_name} instead".
    """
    # prepare metaclass for handling all the class methods, class variables
    # and static methods (as they don't go through instance's __getattr__)
    class DeprecatedProps(type):
        pass
    metaclass = deprecated_instance_attrs(names, warn)(DeprecatedProps)

    def wrapper(cls):
        # instance-level attribute translation first
        cls = deprecated_instance_attrs(names, warn)(cls)
        # apply metaclass: recreate the class under the deprecating metaclass
        orig_vars = cls.__dict__.copy()
        slots = orig_vars.get('__slots__')
        if slots is not None:
            if isinstance(slots, str):
                slots = [slots]
            for slots_var in slots:
                # slot descriptors must not be copied into the new class dict
                orig_vars.pop(slots_var)
        orig_vars.pop('__dict__', None)
        orig_vars.pop('__weakref__', None)
        return metaclass(cls.__name__, cls.__bases__, orig_vars)
    return wrapper
def deprecated_method(message):
    """Decorator for deprecating methods.

    :param str message: text appended to the emitted deprecation warning.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            warning_text = "{0} is a deprecated method. {1}".format(
                func.__name__, message)
            warnings.warn(warning_text, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator

View File

@@ -6,6 +6,9 @@ from urllib.parse import urlparse
import httpx
from mediaflow_proxy.utils.http_utils import create_httpx_client
from mediaflow_proxy.configs import settings
from collections import OrderedDict
import time
from urllib.parse import urljoin
logger = logging.getLogger(__name__)
@@ -23,12 +26,21 @@ class HLSPreBuffer:
max_cache_size (int): Maximum number of segments to cache (uses config if None)
prebuffer_segments (int): Number of segments to pre-buffer ahead (uses config if None)
"""
from collections import OrderedDict
import time
from urllib.parse import urljoin
self.max_cache_size = max_cache_size or settings.hls_prebuffer_cache_size
self.prebuffer_segments = prebuffer_segments or settings.hls_prebuffer_segments
self.max_memory_percent = settings.hls_prebuffer_max_memory_percent
self.emergency_threshold = settings.hls_prebuffer_emergency_threshold
self.segment_cache: Dict[str, bytes] = {}
# Cache LRU
self.segment_cache: "OrderedDict[str, bytes]" = OrderedDict()
# Mappa playlist -> lista segmenti
self.segment_urls: Dict[str, List[str]] = {}
# Mappa inversa segmento -> (playlist_url, index)
self.segment_to_playlist: Dict[str, tuple[str, int]] = {}
# Stato per playlist: {headers, last_access, refresh_task, target_duration}
self.playlist_state: Dict[str, dict] = {}
self.client = create_httpx_client()
async def prebuffer_playlist(self, playlist_url: str, headers: Dict[str, str]) -> None:
@@ -41,37 +53,44 @@ class HLSPreBuffer:
"""
try:
logger.debug(f"Starting pre-buffer for playlist: {playlist_url}")
# Download and parse playlist
response = await self.client.get(playlist_url, headers=headers)
response.raise_for_status()
playlist_content = response.text
# Check if this is a master playlist (contains variants)
# Se master playlist: prendi la prima variante (fix relativo)
if "#EXT-X-STREAM-INF" in playlist_content:
logger.debug(f"Master playlist detected, finding first variant")
# Extract variant URLs
variant_urls = self._extract_variant_urls(playlist_content, playlist_url)
if variant_urls:
# Pre-buffer the first variant
first_variant_url = variant_urls[0]
logger.debug(f"Pre-buffering first variant: {first_variant_url}")
await self.prebuffer_playlist(first_variant_url, headers)
else:
logger.warning("No variants found in master playlist")
return
# Extract segment URLs
# Media playlist: estrai segmenti, salva stato e lancia refresh loop
segment_urls = self._extract_segment_urls(playlist_content, playlist_url)
# Store segment URLs for this playlist
self.segment_urls[playlist_url] = segment_urls
# Pre-buffer first few segments
# aggiorna mappa inversa
for idx, u in enumerate(segment_urls):
self.segment_to_playlist[u] = (playlist_url, idx)
# prebuffer iniziale
await self._prebuffer_segments(segment_urls[:self.prebuffer_segments], headers)
logger.info(f"Pre-buffered {min(self.prebuffer_segments, len(segment_urls))} segments for {playlist_url}")
# setup refresh loop se non già attivo
target_duration = self._parse_target_duration(playlist_content) or 6
st = self.playlist_state.get(playlist_url, {})
if not st.get("refresh_task") or st["refresh_task"].done():
task = asyncio.create_task(self._refresh_playlist_loop(playlist_url, headers, target_duration))
self.playlist_state[playlist_url] = {
"headers": headers,
"last_access": asyncio.get_event_loop().time(),
"refresh_task": task,
"target_duration": target_duration,
}
except Exception as e:
logger.warning(f"Failed to pre-buffer playlist {playlist_url}: {e}")
@@ -124,34 +143,24 @@ class HLSPreBuffer:
def _extract_variant_urls(self, playlist_content: str, base_url: str) -> List[str]:
"""
Extract variant URLs from master playlist content.
Args:
playlist_content (str): Content of the master playlist
base_url (str): Base URL for resolving relative URLs
Returns:
List[str]: List of variant URLs
Estrae le varianti dal master playlist. Corretto per gestire URI relativi:
prende la riga non-commento successiva a #EXT-X-STREAM-INF e la risolve rispetto a base_url.
"""
from urllib.parse import urljoin
variant_urls = []
lines = playlist_content.split('\n')
lines = [l.strip() for l in playlist_content.split('\n')]
take_next_uri = False
for line in lines:
line = line.strip()
if line and not line.startswith('#') and ('http://' in line or 'https://' in line):
# Resolve relative URLs
if line.startswith('http'):
variant_urls.append(line)
else:
# Join with base URL for relative paths
parsed_base = urlparse(base_url)
variant_url = f"{parsed_base.scheme}://{parsed_base.netloc}{line}"
variant_urls.append(variant_url)
if line.startswith("#EXT-X-STREAM-INF"):
take_next_uri = True
continue
if take_next_uri:
take_next_uri = False
if line and not line.startswith('#'):
variant_urls.append(urljoin(base_url, line))
logger.debug(f"Extracted {len(variant_urls)} variant URLs from master playlist")
if variant_urls:
logger.debug(f"First variant URL: {variant_urls[0]}")
return variant_urls
async def _prebuffer_segments(self, segment_urls: List[str], headers: Dict[str, str]) -> None:
@@ -166,7 +175,6 @@ class HLSPreBuffer:
for url in segment_urls:
if url not in self.segment_cache:
tasks.append(self._download_segment(url, headers))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
@@ -196,16 +204,16 @@ class HLSPreBuffer:
def _emergency_cache_cleanup(self) -> None:
"""
Perform emergency cache cleanup when memory usage is high.
Esegue cleanup LRU rimuovendo il 50% più vecchio.
"""
if self._check_memory_threshold():
logger.warning("Emergency cache cleanup triggered due to high memory usage")
# Clear 50% of cache
cache_size = len(self.segment_cache)
keys_to_remove = list(self.segment_cache.keys())[:cache_size // 2]
for key in keys_to_remove:
del self.segment_cache[key]
logger.info(f"Emergency cleanup removed {len(keys_to_remove)} segments from cache")
to_remove = max(1, len(self.segment_cache) // 2)
removed = 0
while removed < to_remove and self.segment_cache:
self.segment_cache.popitem(last=False) # rimuovi LRU
removed += 1
logger.info(f"Emergency cleanup removed {removed} segments from cache")
async def _download_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
"""
@@ -216,29 +224,26 @@ class HLSPreBuffer:
headers (Dict[str, str]): Headers to use for request
"""
try:
# Check memory usage before downloading
memory_percent = self._get_memory_usage_percent()
if memory_percent > self.max_memory_percent:
logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
return
response = await self.client.get(segment_url, headers=headers)
response.raise_for_status()
# Cache the segment
# Cache LRU
self.segment_cache[segment_url] = response.content
# Check for emergency cleanup
self.segment_cache.move_to_end(segment_url, last=True)
if self._check_memory_threshold():
self._emergency_cache_cleanup()
# Maintain cache size
elif len(self.segment_cache) > self.max_cache_size:
# Remove oldest entries (simple FIFO)
oldest_key = next(iter(self.segment_cache))
del self.segment_cache[oldest_key]
# Evict LRU finché non rientra
while len(self.segment_cache) > self.max_cache_size:
self.segment_cache.popitem(last=False)
logger.debug(f"Cached segment: {segment_url}")
except Exception as e:
logger.warning(f"Failed to download segment {segment_url}: {e}")
@@ -256,38 +261,64 @@ class HLSPreBuffer:
# Check cache first
if segment_url in self.segment_cache:
logger.debug(f"Cache hit for segment: {segment_url}")
return self.segment_cache[segment_url]
# Check memory usage before downloading
# LRU touch
data = self.segment_cache[segment_url]
self.segment_cache.move_to_end(segment_url, last=True)
# aggiorna last_access per la playlist se mappata
pl = self.segment_to_playlist.get(segment_url)
if pl:
st = self.playlist_state.get(pl[0])
if st:
st["last_access"] = asyncio.get_event_loop().time()
return data
memory_percent = self._get_memory_usage_percent()
if memory_percent > self.max_memory_percent:
logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
return None
# Download if not in cache
try:
response = await self.client.get(segment_url, headers=headers)
response.raise_for_status()
segment_data = response.content
# Cache the segment
# Cache LRU
self.segment_cache[segment_url] = segment_data
# Check for emergency cleanup
self.segment_cache.move_to_end(segment_url, last=True)
if self._check_memory_threshold():
self._emergency_cache_cleanup()
# Maintain cache size
elif len(self.segment_cache) > self.max_cache_size:
oldest_key = next(iter(self.segment_cache))
del self.segment_cache[oldest_key]
while len(self.segment_cache) > self.max_cache_size:
self.segment_cache.popitem(last=False)
# aggiorna last_access per playlist
pl = self.segment_to_playlist.get(segment_url)
if pl:
st = self.playlist_state.get(pl[0])
if st:
st["last_access"] = asyncio.get_event_loop().time()
logger.debug(f"Downloaded and cached segment: {segment_url}")
return segment_data
except Exception as e:
logger.warning(f"Failed to get segment {segment_url}: {e}")
return None
async def prebuffer_from_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
"""
Dato un URL di segmento, prebuffer i successivi in base alla playlist e all'indice mappato.
"""
mapped = self.segment_to_playlist.get(segment_url)
if not mapped:
return
playlist_url, idx = mapped
# aggiorna access time
st = self.playlist_state.get(playlist_url)
if st:
st["last_access"] = asyncio.get_event_loop().time()
await self.prebuffer_next_segments(playlist_url, idx, headers)
async def prebuffer_next_segments(self, playlist_url: str, current_segment_index: int, headers: Dict[str, str]) -> None:
"""
Pre-buffer next segments based on current playback position.
@@ -299,10 +330,8 @@ class HLSPreBuffer:
"""
if playlist_url not in self.segment_urls:
return
segment_urls = self.segment_urls[playlist_url]
next_segments = segment_urls[current_segment_index + 1:current_segment_index + 1 + self.prebuffer_segments]
if next_segments:
await self._prebuffer_segments(next_segments, headers)
@@ -310,6 +339,8 @@ class HLSPreBuffer:
"""Clear the segment cache."""
self.segment_cache.clear()
self.segment_urls.clear()
self.segment_to_playlist.clear()
self.playlist_state.clear()
logger.info("HLS pre-buffer cache cleared")
async def close(self) -> None:
@@ -318,4 +349,142 @@ class HLSPreBuffer:
# Global pre-buffer instance
hls_prebuffer = HLSPreBuffer()
hls_prebuffer = HLSPreBuffer()
class HLSPreBuffer:
    # NOTE(review): this class re-declares HLSPreBuffer after the module-level
    # instance above - confirm this is intentional and merge if not.

    def _parse_target_duration(self, playlist_content: str) -> Optional[int]:
        """
        Parse EXT-X-TARGETDURATION from a media playlist and return duration in seconds.

        Returns None if not present or unparsable.
        """
        for line in playlist_content.splitlines():
            line = line.strip()
            if line.startswith("#EXT-X-TARGETDURATION:"):
                try:
                    value = line.split(":", 1)[1].strip()
                    # tolerate fractional values; truncate to whole seconds
                    return int(float(value))
                except Exception:
                    # first matching line is malformed: give up
                    return None
        return None

    async def _refresh_playlist_loop(self, playlist_url: str, headers: Dict[str, str], target_duration: int) -> None:
        """
        Periodically refresh the playlist to follow the sliding window and
        keep the cache coherent. Stops and cleans up after prolonged
        inactivity.
        """
        # poll at roughly the segment duration, clamped to [2, 15] seconds
        sleep_s = max(2, min(15, int(target_duration)))
        inactivity_timeout = 600  # 10 minutes
        while True:
            try:
                st = self.playlist_state.get(playlist_url)
                now = asyncio.get_event_loop().time()
                if not st:
                    # state was cleared elsewhere: stop the loop
                    return
                if now - st.get("last_access", now) > inactivity_timeout:
                    # playlist-specific cleanup
                    urls = set(self.segment_urls.get(playlist_url, []))
                    if urls:
                        # evict from the cache only this playlist's segments
                        for u in list(self.segment_cache.keys()):
                            if u in urls:
                                self.segment_cache.pop(u, None)
                        # drop the reverse segment -> playlist mapping
                        for u in urls:
                            self.segment_to_playlist.pop(u, None)
                    self.segment_urls.pop(playlist_url, None)
                    self.playlist_state.pop(playlist_url, None)
                    logger.info(f"Stopped HLS prebuffer for inactive playlist: {playlist_url}")
                    return
                # refresh manifest
                resp = await self.client.get(playlist_url, headers=headers)
                resp.raise_for_status()
                content = resp.text
                new_target = self._parse_target_duration(content)
                if new_target:
                    sleep_s = max(2, min(15, int(new_target)))
                new_urls = self._extract_segment_urls(content, playlist_url)
                if new_urls:
                    self.segment_urls[playlist_url] = new_urls
                    # rebuild the reverse map only for the most recent
                    # entries (bounds memory use)
                    for idx, u in enumerate(new_urls[-(self.max_cache_size * 2):]):
                        # remapping overwrites any stale entries
                        real_idx = len(new_urls) - (self.max_cache_size * 2) + idx if len(new_urls) > (self.max_cache_size * 2) else idx
                        self.segment_to_playlist[u] = (playlist_url, real_idx)
                # proactive prebuffering is skipped here: the current
                # playback index is unknown, so nothing aggressive is done
            except Exception as e:
                logger.debug(f"Playlist refresh error for {playlist_url}: {e}")
            await asyncio.sleep(sleep_s)

    def _extract_segment_urls(self, playlist_content: str, base_url: str) -> List[str]:
        """
        Extract segment URLs from HLS playlist content.

        Args:
            playlist_content (str): Content of the HLS playlist
            base_url (str): Base URL for resolving relative URLs

        Returns:
            List[str]: List of segment URLs
        """
        segment_urls = []
        lines = playlist_content.split('\n')
        logger.debug(f"Analyzing playlist with {len(lines)} lines")
        for line in lines:
            line = line.strip()
            if line and not line.startswith('#'):
                # Check if line contains a URL (http/https) or is a relative path
                if 'http://' in line or 'https://' in line:
                    segment_urls.append(line)
                    logger.debug(f"Found absolute URL: {line}")
                elif line and not line.startswith('#'):
                    # NOTE(review): this condition is always true here (the
                    # outer `if` already checked it) - effectively an `else`
                    # This might be a relative path to a segment
                    parsed_base = urlparse(base_url)
                    # Ensure proper path joining
                    if line.startswith('/'):
                        segment_url = f"{parsed_base.scheme}://{parsed_base.netloc}{line}"
                    else:
                        # Get the directory path from base_url
                        base_path = parsed_base.path.rsplit('/', 1)[0] if '/' in parsed_base.path else ''
                        segment_url = f"{parsed_base.scheme}://{parsed_base.netloc}{base_path}/{line}"
                    segment_urls.append(segment_url)
                    logger.debug(f"Found relative path: {line} -> {segment_url}")
        logger.debug(f"Extracted {len(segment_urls)} segment URLs from playlist")
        if segment_urls:
            logger.debug(f"First segment URL: {segment_urls[0]}")
        else:
            logger.debug("No segment URLs found in playlist")
            # Log first few lines for debugging
            for i, line in enumerate(lines[:10]):
                logger.debug(f"Line {i}: {line}")
        return segment_urls

    def _extract_variant_urls(self, playlist_content: str, base_url: str) -> List[str]:
        """
        Extract variant URLs from a master playlist. Handles relative URIs:
        takes the non-comment line following #EXT-X-STREAM-INF and resolves
        it against base_url.
        """
        from urllib.parse import urljoin
        variant_urls = []
        lines = [l.strip() for l in playlist_content.split('\n')]
        take_next_uri = False
        for line in lines:
            if line.startswith("#EXT-X-STREAM-INF"):
                # the variant URI lives on the next non-comment line
                take_next_uri = True
                continue
            if take_next_uri:
                take_next_uri = False
                if line and not line.startswith('#'):
                    variant_urls.append(urljoin(base_url, line))
        logger.debug(f"Extracted {len(variant_urls)} variant URLs from master playlist")
        if variant_urls:
            logger.debug(f"First variant URL: {variant_urls[0]}")
        return variant_urls

View File

@@ -0,0 +1,54 @@
import logging
import re
from typing import List, Dict, Any, Optional, Tuple
from urllib.parse import urljoin
logger = logging.getLogger(__name__)
def parse_hls_playlist(playlist_content: str, base_url: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Parse an HLS master playlist and extract stream variant information.

    Args:
        playlist_content (str): The content of the M3U8 master playlist.
        base_url (str, optional): The base URL of the playlist for resolving
            relative stream URLs. If None, relative URLs are returned as-is.

    Returns:
        List[Dict[str, Any]]: One dictionary per stream variant, containing the
        raw #EXT-X-STREAM-INF line ('raw_stream_inf'), parsed attributes
        (lower-cased keys, hyphens replaced with underscores), a 'resolution'
        (width, height) tuple when present, and the variant 'url'.
    """
    streams: List[Dict[str, Any]] = []
    # Strip each line individually so CRLF line endings ('\r') and stray
    # whitespace do not break tag detection or URL extraction.
    lines = [line.strip() for line in playlist_content.strip().split('\n')]
    # Regex to capture attributes from #EXT-X-STREAM-INF
    stream_inf_pattern = re.compile(r'#EXT-X-STREAM-INF:(.*)')
    for i, line in enumerate(lines):
        if not line.startswith('#EXT-X-STREAM-INF'):
            continue
        stream_info: Dict[str, Any] = {'raw_stream_inf': line}
        match = stream_inf_pattern.match(line)
        if not match:
            logger.warning(f"Could not parse #EXT-X-STREAM-INF line: {line}")
            continue
        attributes_str = match.group(1)
        # Parse attributes like BANDWIDTH, RESOLUTION, etc.; values may be
        # quoted (and may then contain commas) or unquoted.
        attributes = re.findall(r'([A-Z-]+)=("([^"]+)"|([^,]+))', attributes_str)
        for key, _, quoted_val, unquoted_val in attributes:
            value = quoted_val if quoted_val else unquoted_val
            if key == 'RESOLUTION':
                try:
                    width, height = map(int, value.split('x'))
                    stream_info['resolution'] = (width, height)
                except ValueError:
                    # Malformed RESOLUTION attribute; record a sentinel value.
                    stream_info['resolution'] = (0, 0)
            else:
                stream_info[key.lower().replace('-', '_')] = value
        # The variant URI is the next non-blank, non-comment line. Skipping
        # blank lines avoids emitting an empty/base URL for playlists that
        # separate the tag from its URI with whitespace.
        for next_line in lines[i + 1:]:
            if not next_line:
                continue
            if next_line.startswith('#'):
                # Another tag before any URI: this variant has no URL.
                break
            stream_info['url'] = urljoin(base_url, next_line) if base_url else next_line
            streams.append(stream_info)
            break
    return streams

View File

@@ -3,7 +3,7 @@ import typing
from dataclasses import dataclass
from functools import partial
from urllib import parse
from urllib.parse import urlencode
from urllib.parse import urlencode, urlparse
import anyio
import h11
@@ -81,6 +81,10 @@ async def fetch_with_retry(client, method, url, headers, follow_redirects=True,
class Streamer:
# PNG signature and IEND marker for fake PNG header detection (StreamWish/FileMoon)
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
_PNG_IEND_MARKER = b"\x49\x45\x4E\x44\xAE\x42\x60\x82"
def __init__(self, client):
"""
Initializes the Streamer with an HTTP client.
@@ -132,13 +136,48 @@ class Streamer:
logger.error(f"Error creating streaming response: {e}")
raise RuntimeError(f"Error creating streaming response: {e}")
    @staticmethod
    def _strip_fake_png_wrapper(chunk: bytes) -> bytes:
        """
        Strip fake PNG wrapper from chunk data.

        Some streaming services (StreamWish, FileMoon) prepend a fake PNG image
        to video data to evade detection. This method detects and removes it.

        NOTE(review): callers apply this to the first chunk of a stream only;
        if the fake PNG spans beyond this chunk (no IEND marker found here),
        the data is returned unchanged rather than partially stripped.

        Args:
            chunk: The raw chunk data that may contain a fake PNG header.

        Returns:
            The chunk with fake PNG wrapper removed, or original chunk if not present.
        """
        # Fast path: no PNG signature at the start means nothing to strip.
        if not chunk.startswith(Streamer._PNG_SIGNATURE):
            return chunk
        # Find the IEND marker that signals end of PNG data
        iend_pos = chunk.find(Streamer._PNG_IEND_MARKER)
        if iend_pos == -1:
            # IEND not found in this chunk - return as-is to avoid data corruption
            logger.debug("PNG signature detected but IEND marker not found in chunk")
            return chunk
        # Calculate position after IEND marker
        content_start = iend_pos + len(Streamer._PNG_IEND_MARKER)
        # Skip any padding bytes (null or 0xFF) between PNG and actual content
        while content_start < len(chunk) and chunk[content_start] in (0x00, 0xFF):
            content_start += 1
        stripped_bytes = content_start
        logger.debug(f"Stripped {stripped_bytes} bytes of fake PNG wrapper from stream")
        return chunk[content_start:]
async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
"""
Streams the content from the response.
"""
if not self.response:
raise RuntimeError("No response available for streaming")
is_first_chunk = True
try:
self.parse_content_range()
@@ -154,15 +193,19 @@ class Streamer:
mininterval=1,
) as self.progress_bar:
async for chunk in self.response.aiter_bytes():
if is_first_chunk:
is_first_chunk = False
chunk = self._strip_fake_png_wrapper(chunk)
yield chunk
chunk_size = len(chunk)
self.bytes_transferred += chunk_size
self.progress_bar.set_postfix_str(
f"📥 : {self.format_bytes(self.bytes_transferred)}", refresh=False
)
self.progress_bar.update(chunk_size)
self.bytes_transferred += len(chunk)
self.progress_bar.update(len(chunk))
else:
async for chunk in self.response.aiter_bytes():
if is_first_chunk:
is_first_chunk = False
chunk = self._strip_fake_png_wrapper(chunk)
yield chunk
self.bytes_transferred += len(chunk)
@@ -187,10 +230,19 @@ class Streamer:
raise DownloadError(502, f"Protocol error while streaming: {e}")
except GeneratorExit:
logger.info("Streaming session stopped by the user")
except httpx.ReadError as e:
# Handle network read errors gracefully - these occur when upstream connection drops
logger.warning(f"ReadError while streaming: {e}")
if self.bytes_transferred > 0:
logger.info(f"Partial content received ({self.bytes_transferred} bytes) before ReadError. Graceful termination.")
return
else:
raise DownloadError(502, f"ReadError while streaming: {e}")
except Exception as e:
logger.error(f"Error streaming content: {e}")
raise
@staticmethod
def format_bytes(size) -> str:
power = 2**10
@@ -490,6 +542,23 @@ def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
"""
request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS}
request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")})
request_headers.setdefault("user-agent", settings.user_agent)
# Handle common misspelling of referer
if "referrer" in request_headers:
if "referer" not in request_headers:
request_headers["referer"] = request_headers.pop("referrer")
dest = request.query_params.get("d", "")
host = urlparse(dest).netloc.lower()
if "vidoza" in host or "videzz" in host:
# Remove ALL empty headers
for h in list(request_headers.keys()):
v = request_headers[h]
if v is None or v.strip() == "":
request_headers.pop(h, None)
response_headers = {k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("r_")}
return ProxyRequestHeaders(request_headers, response_headers)
@@ -527,21 +596,14 @@ class EnhancedStreamingResponse(Response):
logger.error(f"Error in listen_for_disconnect: {str(e)}")
async def stream_response(self, send: Send) -> None:
# Track if response headers have been sent to prevent duplicate headers
response_started = False
# Track if response finalization (more_body: False) has been sent to prevent ASGI protocol violation
finalization_sent = False
try:
# Initialize headers
headers = list(self.raw_headers)
# Set the transfer-encoding to chunked for streamed responses with content-length
# when content-length is present. This ensures we don't hit protocol errors
# if the upstream connection is closed prematurely.
for i, (name, _) in enumerate(headers):
if name.lower() == b"content-length":
# Replace content-length with transfer-encoding: chunked for streaming
headers[i] = (b"transfer-encoding", b"chunked")
headers = [h for h in headers if h[0].lower() != b"content-length"]
logger.debug("Switched from content-length to chunked transfer-encoding for streaming")
break
# Start the response
await send(
{
@@ -550,6 +612,7 @@ class EnhancedStreamingResponse(Response):
"headers": headers,
}
)
response_started = True
# Track if we've sent any data
data_sent = False
@@ -568,27 +631,29 @@ class EnhancedStreamingResponse(Response):
# Successfully streamed all content
await send({"type": "http.response.body", "body": b"", "more_body": False})
except (httpx.RemoteProtocolError, h11._util.LocalProtocolError) as e:
# Handle connection closed errors
finalization_sent = True
except (httpx.RemoteProtocolError, httpx.ReadError, h11._util.LocalProtocolError) as e:
# Handle connection closed / read errors gracefully
if data_sent:
# We've sent some data to the client, so try to complete the response
logger.warning(f"Remote protocol error after partial streaming: {e}")
logger.warning(f"Upstream connection error after partial streaming: {e}")
try:
await send({"type": "http.response.body", "body": b"", "more_body": False})
finalization_sent = True
logger.info(
f"Response finalized after partial content ({self.actual_content_length} bytes transferred)"
)
except Exception as close_err:
logger.warning(f"Could not finalize response after remote error: {close_err}")
logger.warning(f"Could not finalize response after upstream error: {close_err}")
else:
# No data was sent, re-raise the error
logger.error(f"Protocol error before any data was streamed: {e}")
logger.error(f"Upstream error before any data was streamed: {e}")
raise
except Exception as e:
logger.exception(f"Error in stream_response: {str(e)}")
if not isinstance(e, (ConnectionResetError, anyio.BrokenResourceError)):
if not isinstance(e, (ConnectionResetError, anyio.BrokenResourceError)) and not response_started:
# Only attempt to send error response if headers haven't been sent yet
try:
# Try to send an error response if client is still connected
await send(
{
"type": "http.response.start",
@@ -598,36 +663,39 @@ class EnhancedStreamingResponse(Response):
)
error_message = f"Streaming error: {str(e)}".encode("utf-8")
await send({"type": "http.response.body", "body": error_message, "more_body": False})
finalization_sent = True
except Exception:
# If we can't send an error response, just log it
pass
elif response_started and not finalization_sent:
# Response already started but not finalized - gracefully close the stream
try:
await send({"type": "http.response.body", "body": b"", "more_body": False})
finalization_sent = True
except Exception:
pass
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:
streaming_completed = False
stream_func = partial(self.stream_response, send)
listen_func = partial(self.listen_for_disconnect, receive)
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
try:
await func()
# If this is the stream_response function and it completes successfully, mark as done
if func == stream_func:
nonlocal streaming_completed
streaming_completed = True
except Exception as e:
if isinstance(e, (httpx.RemoteProtocolError, h11._util.LocalProtocolError)):
# Handle protocol errors more gracefully
logger.warning(f"Protocol error during streaming: {e}")
elif not isinstance(e, anyio.get_cancelled_exc_class()):
logger.exception("Error in streaming task")
# Only re-raise if it's not a protocol error or cancellation
# Note: stream_response and listen_for_disconnect handle their own exceptions
# internally. This is a safety net for any unexpected exceptions that might
# escape due to future code changes.
if not isinstance(e, anyio.get_cancelled_exc_class()):
logger.exception(f"Unexpected error in streaming task: {type(e).__name__}: {e}")
# Re-raise unexpected errors to surface bugs rather than silently swallowing them
raise
finally:
# Only cancel the task group if we're in disconnect listener or
# if streaming_completed is True (meaning we finished normally)
if func == listen_func or streaming_completed:
task_group.cancel_scope.cancel()
# Cancel task group when either task completes or fails:
# - stream_func finished (success or failure) -> stop listening for disconnect
# - listen_func finished (client disconnected) -> stop streaming
task_group.cancel_scope.cancel()
# Start the streaming response in a separate task
task_group.start_soon(wrap, stream_func)

View File

@@ -11,7 +11,7 @@ from mediaflow_proxy.utils.hls_prebuffer import hls_prebuffer
class M3U8Processor:
def __init__(self, request, key_url: str = None, force_playlist_proxy: bool = None):
def __init__(self, request, key_url: str = None, force_playlist_proxy: bool = None, key_only_proxy: bool = False, no_proxy: bool = False):
"""
Initializes the M3U8Processor with the request and URL prefix.
@@ -19,9 +19,13 @@ class M3U8Processor:
request (Request): The incoming HTTP request.
key_url (HttpUrl, optional): The URL of the key server. Defaults to None.
force_playlist_proxy (bool, optional): Force all playlist URLs to be proxied through MediaFlow. Defaults to None.
key_only_proxy (bool, optional): Only proxy the key URL, leaving segment URLs direct. Defaults to False.
no_proxy (bool, optional): If True, returns the manifest without proxying any URLs. Defaults to False.
"""
self.request = request
self.key_url = parse.urlparse(key_url) if key_url else None
self.key_only_proxy = key_only_proxy
self.no_proxy = no_proxy
self.force_playlist_proxy = force_playlist_proxy
self.mediaflow_proxy_url = str(
request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
@@ -174,6 +178,15 @@ class M3U8Processor:
Returns:
str: The processed key line.
"""
# If no_proxy is enabled, just resolve relative URLs without proxying
if self.no_proxy:
uri_match = re.search(r'URI="([^"]+)"', line)
if uri_match:
original_uri = uri_match.group(1)
full_url = parse.urljoin(base_url, original_uri)
line = line.replace(f'URI="{original_uri}"', f'URI="{full_url}"')
return line
uri_match = re.search(r'URI="([^"]+)"', line)
if uri_match:
original_uri = uri_match.group(1)
@@ -197,6 +210,14 @@ class M3U8Processor:
"""
full_url = parse.urljoin(base_url, url)
# If no_proxy is enabled, return the direct URL without any proxying
if self.no_proxy:
return full_url
# If key_only_proxy is enabled, return the direct URL for segments
if self.key_only_proxy and not url.endswith((".m3u", ".m3u8")):
return full_url
# Determine routing strategy based on configuration
routing_strategy = settings.m3u8_content_routing

View File

@@ -270,7 +270,7 @@ def parse_representation(
if item:
profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
else:
profile["segments"] = parse_segment_base(representation, source)
profile["segments"] = parse_segment_base(representation, profile, source)
return profile
@@ -547,7 +547,7 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
return segment_data
def parse_segment_base(representation: dict, source: str) -> List[Dict]:
def parse_segment_base(representation: dict, profile: dict, source: str) -> List[Dict]:
"""
Parses segment base information and extracts segment data. This is used for single-segment representations.
@@ -562,6 +562,12 @@ def parse_segment_base(representation: dict, source: str) -> List[Dict]:
start, end = map(int, segment["@indexRange"].split("-"))
if "Initialization" in segment:
start, _ = map(int, segment["Initialization"]["@range"].split("-"))
# Set initUrl for SegmentBase
if not representation['BaseURL'].startswith("http"):
profile["initUrl"] = f"{source}/{representation['BaseURL']}"
else:
profile["initUrl"] = representation['BaseURL']
return [
{

View File

@@ -0,0 +1,122 @@
# Author: Trevor Perrin
# See the LICENSE file for legal information regarding use of this file.
"""Pure-Python AES implementation."""
import sys
from .aes import AES
from .rijndael import Rijndael
from .cryptomath import bytesToNumber, numberToByteArray
__all__ = ['new', 'Python_AES']
def new(key, mode, IV):
    """Return a pure-Python AES object: CBC for mode 2, CTR for mode 6."""
    # IV argument name is a part of the interface
    # pylint: disable=invalid-name
    if mode == 2:
        return Python_AES(key, mode, IV)
    if mode == 6:
        return Python_AES_CTR(key, mode, IV)
    raise NotImplementedError()
class Python_AES(AES):
    """Pure-Python AES in CBC mode, backed by the Rijndael implementation."""

    def __init__(self, key, mode, IV):
        # IV argument/field names are a part of the interface
        # pylint: disable=invalid-name
        key, IV = bytearray(key), bytearray(IV)
        super(Python_AES, self).__init__(key, mode, IV, "python")
        self.rijndael = Rijndael(key, 16)
        self.IV = IV

    def encrypt(self, plaintext):
        """CBC-encrypt plaintext (length must be a multiple of 16); returns ciphertext."""
        super(Python_AES, self).encrypt(plaintext)
        buf = bytearray(plaintext)
        # The chaining value starts as the IV and becomes each ciphertext block.
        chain = self.IV[:]
        for offset in range(0, len(buf), 16):
            # XOR the plaintext block with the chaining block, then encrypt it.
            mixed = bytearray(a ^ b for a, b in zip(buf[offset:offset + 16], chain))
            chain = self.rijndael.encrypt(mixed)
            # Overwrite the input block with the ciphertext block.
            buf[offset:offset + 16] = chain
        # Persist the last ciphertext block as the IV for the next call.
        self.IV = chain[:]
        return buf

    def decrypt(self, ciphertext):
        """CBC-decrypt ciphertext (length must be a multiple of 16); returns plaintext."""
        super(Python_AES, self).decrypt(ciphertext)
        buf = ciphertext[:]
        chain = self.IV[:]
        for offset in range(0, len(buf), 16):
            # Keep the ciphertext block: it is the next chaining value.
            block = buf[offset:offset + 16]
            decrypted = self.rijndael.decrypt(block)
            # XOR with the chaining block and overwrite the input with plaintext.
            buf[offset:offset + 16] = bytearray(a ^ b for a, b in zip(decrypted, chain))
            chain = block
        # Persist the last ciphertext block as the IV for the next call.
        self.IV = chain[:]
        return buf
class Python_AES_CTR(AES):
    """Pure-Python AES in CTR (counter) mode.

    The IV supplied at construction is the fixed nonce prefix of the 16-byte
    counter block; the remaining low-order bytes form the running counter,
    initialised to zero.
    """
    def __init__(self, key, mode, IV):
        super(Python_AES_CTR, self).__init__(key, mode, IV, "python")
        self.rijndael = Rijndael(key, 16)
        self.IV = IV
        # Number of low-order bytes that act as the incrementing counter.
        self._counter_bytes = 16 - len(self.IV)
        # Full 16-byte counter block: nonce prefix followed by zeroed counter.
        self._counter = self.IV + bytearray(b'\x00' * self._counter_bytes)
    @property
    def counter(self):
        """The current 16-byte counter block."""
        return self._counter
    @counter.setter
    def counter(self, ctr):
        self._counter = ctr
    def _counter_update(self):
        """Increment the counter block; raise OverflowError when exhausted."""
        counter_int = bytesToNumber(self._counter) + 1
        self._counter = numberToByteArray(counter_int, 16)
        # Raise as soon as the counter bytes reach all-0xFF: allowing a wrap
        # would reuse keystream, which breaks CTR-mode confidentiality.
        if self._counter_bytes > 0 and \
                self._counter[-self._counter_bytes:] == \
                bytearray(b'\xff' * self._counter_bytes):
            raise OverflowError("CTR counter overflowed")
    def encrypt(self, plaintext):
        """XOR plaintext with the AES-CTR keystream; returns the result."""
        # Generate keystream by encrypting successive counter blocks.
        mask = bytearray()
        while len(mask) < len(plaintext):
            mask += self.rijndael.encrypt(self._counter)
            self._counter_update()
        # XOR input with keystream (Python 2 iterates str, hence ord()).
        if sys.version_info < (3, 0):
            inp_bytes = bytearray(ord(i) ^ j for i, j in zip(plaintext, mask))
        else:
            inp_bytes = bytearray(i ^ j for i, j in zip(plaintext, mask))
        return inp_bytes
    def decrypt(self, ciphertext):
        """CTR decryption is identical to encryption (XOR with keystream)."""
        return self.encrypt(ciphertext)

View File

@@ -0,0 +1,12 @@
# mediaflow_proxy/utils/python_aesgcm.py
from .aesgcm import AESGCM
from .rijndael import Rijndael
def new(key: bytes) -> AESGCM:
    """
    Mirror ResolveURL's python_aesgcm.new(key) API:
    returns an AESGCM instance with pure-Python Rijndael backend.
    """
    block_cipher = Rijndael(key, 16)
    return AESGCM(key, "python", block_cipher.encrypt)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,32 @@
# Author: Hubert Kario (c) 2015
# see LICENCE file for legal information regarding use of this file
"""hashlib that handles FIPS mode."""
# Because we are extending the hashlib module, we need to import all its
# fields to suppport the same uses
# pylint: disable=unused-wildcard-import, wildcard-import
from hashlib import *
# pylint: enable=unused-wildcard-import, wildcard-import
import hashlib
def _fipsFunction(func, *args, **kwargs):
"""Make hash function support FIPS mode."""
try:
return func(*args, **kwargs)
except ValueError:
return func(*args, usedforsecurity=False, **kwargs)
# redefining the function is exactly what we intend to do
# pylint: disable=function-redefined
def md5(*args, **kwargs):
    """MD5 constructor that works in FIPS mode."""
    constructor = hashlib.md5
    return _fipsFunction(constructor, *args, **kwargs)
def new(*args, **kwargs):
    """General constructor that works in FIPS mode."""
    constructor = hashlib.new
    return _fipsFunction(constructor, *args, **kwargs)
# pylint: enable=function-redefined

View File

@@ -0,0 +1,88 @@
# Author: Hubert Kario (c) 2019
# see LICENCE file for legal information regarding use of this file
"""
HMAC module that works in FIPS mode.
Note that this makes this code FIPS non-compliant!
"""
# Because we are extending the hashlib module, we need to import all its
# fields to suppport the same uses
from . import tlshashlib
from .compat import compatHMAC
try:
from hmac import compare_digest
__all__ = ["new", "compare_digest", "HMAC"]
except ImportError:
__all__ = ["new", "HMAC"]
try:
from hmac import HMAC, new
# if we can calculate HMAC on MD5, then use the built-in HMAC
# implementation
_val = HMAC(b'some key', b'msg', 'md5')
_val.digest()
del _val
except Exception:
# fallback only when MD5 doesn't work
    class HMAC(object):
        """Hacked version of HMAC that works in FIPS mode even with MD5.

        Fallback used only when the built-in hmac module cannot compute
        HMAC-MD5 (FIPS-restricted interpreters). Implements RFC 2104 directly
        on top of clonable hash objects.
        """
        def __init__(self, key, msg=None, digestmod=None):
            """
            Initialise the HMAC and hash the first portion of data.

            :param key: secret key for the MAC
            :param msg: optional initial data to hash
            :param digestmod: hash name, hash constructor, or hash object that
                can be cloned via copy(); defaults to 'md5'
            """
            self.key = key
            if digestmod is None:
                digestmod = 'md5'
            # Normalise digestmod to a hash object: call constructors, and
            # resolve names through the FIPS-aware tlshashlib.
            if callable(digestmod):
                digestmod = digestmod()
            if not hasattr(digestmod, 'digest_size'):
                digestmod = tlshashlib.new(digestmod)
            self.block_size = digestmod.block_size
            self.digest_size = digestmod.digest_size
            self.digestmod = digestmod
            # RFC 2104: keys longer than the block size are first hashed...
            if len(key) > self.block_size:
                k_hash = digestmod.copy()
                k_hash.update(compatHMAC(key))
                key = k_hash.digest()
            # ...and shorter keys are zero-padded up to the block size.
            if len(key) < self.block_size:
                key = key + b'\x00' * (self.block_size - len(key))
            key = bytearray(key)
            # RFC 2104 inner and outer pad constants.
            ipad = bytearray(b'\x36' * self.block_size)
            opad = bytearray(b'\x5c' * self.block_size)
            i_key = bytearray(i ^ j for i, j in zip(key, ipad))
            self._o_key = bytearray(i ^ j for i, j in zip(key, opad))
            # Inner context accumulates H(key ^ ipad || msg...).
            self._context = digestmod.copy()
            self._context.update(compatHMAC(i_key))
            if msg:
                self._context.update(compatHMAC(msg))
        def update(self, msg):
            """Feed more data into the MAC computation."""
            self._context.update(compatHMAC(msg))
        def digest(self):
            """Return the MAC: H(key ^ opad || H(key ^ ipad || msg))."""
            i_digest = self._context.digest()
            o_hash = self.digestmod.copy()
            o_hash.update(compatHMAC(self._o_key))
            o_hash.update(compatHMAC(i_digest))
            return o_hash.digest()
        def copy(self):
            """Return an independent copy of this HMAC's current state."""
            # Bypass __init__ (which would re-derive the pads) and clone the
            # internal state directly; only _context needs a deep copy.
            new = HMAC.__new__(HMAC)
            new.key = self.key
            new.digestmod = self.digestmod
            new.block_size = self.block_size
            new.digest_size = self.digest_size
            new._o_key = self._o_key
            new._context = self._context.copy()
            return new
def new(*args, **kwargs):
    """General constructor that works in FIPS mode."""
    hmac_obj = HMAC(*args, **kwargs)
    return hmac_obj