mirror of
https://github.com/UrloMythus/UnHided.git
synced 2026-04-11 03:40:54 +00:00
new version
This commit is contained in:
39
mediaflow_proxy/utils/aes.py
Normal file
39
mediaflow_proxy/utils/aes.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Author: Trevor Perrin
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
|
||||
"""Abstract class for AES."""
|
||||
|
||||
class AES(object):
    """Abstract base class for AES in CBC (mode 2) or CTR (mode 6)."""

    def __init__(self, key, mode, IV, implementation):
        """Validate the key/mode/IV combination and record cipher metadata.

        :param key: AES key; must be 16, 24 or 32 bytes long
        :param mode: mode of operation, 2 for CBC or 6 for CTR
        :param IV: initialization vector; exactly 16 bytes for CBC,
            at most 16 bytes for CTR
        :param implementation: name of the backing implementation
        """
        if len(key) not in (16, 24, 32):
            raise AssertionError()
        if mode not in [2, 6]:
            raise AssertionError()
        if mode == 2 and len(IV) != 16:
            raise AssertionError()
        if mode == 6 and len(IV) > 16:
            raise AssertionError()
        self.isBlockCipher = True
        self.isAEAD = False
        self.block_size = 16
        self.implementation = implementation
        # Key length was validated above, so this lookup cannot fail.
        self.name = {16: "aes128", 24: "aes192", 32: "aes256"}[len(key)]

    # CBC-Mode encryption, returns ciphertext
    # WARNING: *MAY* modify the input as well
    def encrypt(self, plaintext):
        """Encrypt plaintext; length must be a multiple of the block size."""
        assert len(plaintext) % 16 == 0

    # CBC-Mode decryption, returns plaintext
    # WARNING: *MAY* modify the input as well
    def decrypt(self, ciphertext):
        """Decrypt ciphertext; length must be a multiple of the block size."""
        assert len(ciphertext) % 16 == 0
|
||||
193
mediaflow_proxy/utils/aesgcm.py
Normal file
193
mediaflow_proxy/utils/aesgcm.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# Author: Google
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
|
||||
# GCM derived from Go's implementation in crypto/cipher.
|
||||
#
|
||||
# https://golang.org/src/crypto/cipher/gcm.go
|
||||
|
||||
# GCM works over elements of the field GF(2^128), each of which is a 128-bit
|
||||
# polynomial. Throughout this implementation, polynomials are represented as
|
||||
# Python integers with the low-order terms at the most significant bits. So a
|
||||
# 128-bit polynomial is an integer from 0 to 2^128-1 with the most significant
|
||||
# bit representing the x^0 term and the least significant bit representing the
|
||||
# x^127 term. This bit reversal also applies to polynomials used as indices in a
|
||||
# look-up table.
|
||||
|
||||
from __future__ import division
|
||||
from . import python_aes
|
||||
from .constanttime import ct_compare_digest
|
||||
from .cryptomath import bytesToNumber, numberToByteArray
|
||||
|
||||
class AESGCM(object):
    """
    AES-GCM implementation. Note: this implementation does not attempt
    to be side-channel resistant. It's also rather slow.
    """

    def __init__(self, key, implementation, rawAesEncrypt):
        """Set up GCM state for the given AES key.

        :param key: 16- or 32-byte AES key
        :param implementation: name of the backing AES implementation
        :param rawAesEncrypt: callable encrypting one 16-byte block with
            the raw AES cipher (no mode of operation)
        """
        self.isBlockCipher = False
        self.isAEAD = True
        self.nonceLength = 12
        self.tagLength = 16
        self.implementation = implementation
        if len(key) == 16:
            self.name = "aes128gcm"
        elif len(key) == 32:
            self.name = "aes256gcm"
        else:
            raise AssertionError()
        self.key = key

        self._rawAesEncrypt = rawAesEncrypt
        # Mode 6 is CTR; used for the actual payload encryption/decryption.
        self._ctr = python_aes.new(self.key, 6, bytearray(b'\x00' * 16))

        # The GCM key is AES(0).
        h = bytesToNumber(self._rawAesEncrypt(bytearray(16)))

        # Pre-compute all 4-bit multiples of h. Note that bits are reversed
        # because our polynomial representation places low-order terms at the
        # most significant bit. Thus x^0 * h = h is at index 0b1000 = 8 and
        # x^1 * h is at index 0b0100 = 4.
        self._productTable = [0] * 16
        self._productTable[self._reverseBits(1)] = h
        for i in range(2, 16, 2):
            self._productTable[self._reverseBits(i)] = \
                self._gcmShift(self._productTable[self._reverseBits(i//2)])
            self._productTable[self._reverseBits(i+1)] = \
                self._gcmAdd(self._productTable[self._reverseBits(i)], h)

    def _auth(self, ciphertext, ad, tagMask):
        """Compute the 16-byte GHASH-based authentication tag."""
        y = 0
        y = self._update(y, ad)
        y = self._update(y, ciphertext)
        # Append the bit lengths of AD and ciphertext, per the GCM spec.
        y ^= (len(ad) << (3 + 64)) | (len(ciphertext) << 3)
        y = self._mul(y)
        y ^= bytesToNumber(tagMask)
        return numberToByteArray(y, 16)

    def _update(self, y, data):
        """Absorb data into the GHASH accumulator y, 16 bytes at a time."""
        for i in range(0, len(data) // 16):
            y ^= bytesToNumber(data[16*i:16*i+16])
            y = self._mul(y)
        extra = len(data) % 16
        if extra != 0:
            # Zero-pad the final partial block.
            block = bytearray(16)
            block[:extra] = data[-extra:]
            y ^= bytesToNumber(block)
            y = self._mul(y)
        return y

    def _mul(self, y):
        """ Returns y*H, where H is the GCM key. """
        ret = 0
        # Multiply H by y 4 bits at a time, starting with the highest power
        # terms.
        for i in range(0, 128, 4):
            # Multiply by x^4. The reduction for the top four terms is
            # precomputed.
            retHigh = ret & 0xf
            ret >>= 4
            ret ^= (AESGCM._gcmReductionTable[retHigh] << (128-16))

            # Add in y' * H where y' are the next four terms of y, shifted down
            # to the x^0..x^4. This is one of the pre-computed multiples of
            # H. The multiplication by x^4 shifts them back into place.
            ret ^= self._productTable[y & 0xf]
            y >>= 4
        assert y == 0
        return ret

    def seal(self, nonce, plaintext, data=''):
        """
        Encrypts and authenticates plaintext using nonce and data. Returns the
        ciphertext, consisting of the encrypted plaintext and tag concatenated.

        :raises ValueError: if the nonce is not 12 bytes long
        """

        if len(nonce) != 12:
            raise ValueError("Bad nonce length")

        # The initial counter value is the nonce, followed by a 32-bit counter
        # that starts at 1. It's used to compute the tag mask.
        counter = bytearray(16)
        counter[:12] = nonce
        counter[-1] = 1
        tagMask = self._rawAesEncrypt(counter)

        # The counter starts at 2 for the actual encryption.
        counter[-1] = 2
        self._ctr.counter = counter
        ciphertext = self._ctr.encrypt(plaintext)

        tag = self._auth(ciphertext, data, tagMask)

        return ciphertext + tag

    def open(self, nonce, ciphertext, data=''):
        """
        Decrypts and authenticates ciphertext using nonce and data. If the
        tag is valid, the plaintext is returned. If the tag is invalid,
        returns None.

        :raises ValueError: if the nonce is not 12 bytes long
        """

        if len(nonce) != 12:
            raise ValueError("Bad nonce length")
        if len(ciphertext) < 16:
            return None

        tag = ciphertext[-16:]
        ciphertext = ciphertext[:-16]

        # The initial counter value is the nonce, followed by a 32-bit counter
        # that starts at 1. It's used to compute the tag mask.
        counter = bytearray(16)
        counter[:12] = nonce
        counter[-1] = 1
        tagMask = self._rawAesEncrypt(counter)

        # BUGFIX: verify the tag unconditionally. The previous code gated the
        # comparison on `data` being truthy, so any message with empty
        # associated data was decrypted WITHOUT authentication — defeating
        # the entire point of GCM.
        if not ct_compare_digest(tag, self._auth(ciphertext, data, tagMask)):
            return None

        # The counter starts at 2 for the actual decryption.
        counter[-1] = 2
        self._ctr.counter = counter
        return self._ctr.decrypt(ciphertext)

    @staticmethod
    def _reverseBits(i):
        """Reverse the order of the low four bits of i."""
        assert i < 16
        i = ((i << 2) & 0xc) | ((i >> 2) & 0x3)
        i = ((i << 1) & 0xa) | ((i >> 1) & 0x5)
        return i

    @staticmethod
    def _gcmAdd(x, y):
        """Addition in GF(2^128) is XOR."""
        return x ^ y

    @staticmethod
    def _gcmShift(x):
        """Multiply x by the polynomial 'x' in GF(2^128)."""
        # Multiplying by x is a right shift, due to bit order.
        highTermSet = x & 1
        x >>= 1
        if highTermSet:
            # The x^127 term was shifted up to x^128, so subtract a 1+x+x^2+x^7
            # term. This is 0b11100001 or 0xe1 when represented as an 8-bit
            # polynomial.
            x ^= 0xe1 << (128-8)
        return x

    @staticmethod
    def _inc32(counter):
        """Increment the low 32 bits of counter (big-endian) in place."""
        for i in range(len(counter)-1, len(counter)-5, -1):
            counter[i] = (counter[i] + 1) % 256
            if counter[i] != 0:
                break
        return counter

    # _gcmReductionTable[i] is i * (1+x+x^2+x^7) for all 4-bit polynomials i. The
    # result is stored as a 16-bit polynomial. This is used in the reduction step to
    # multiply elements of GF(2^128) by x^4.
    _gcmReductionTable = [
        0x0000, 0x1c20, 0x3840, 0x2460, 0x7080, 0x6ca0, 0x48c0, 0x54e0,
        0xe100, 0xfd20, 0xd940, 0xc560, 0x9180, 0x8da0, 0xa9c0, 0xb5e0,
    ]
|
||||
@@ -175,12 +175,27 @@ class HybridCache:
|
||||
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||
raise ValueError("Data must be bytes, bytearray, or memoryview")
|
||||
|
||||
expires_at = time.time() + (ttl or self.ttl)
|
||||
ttl_seconds = self.ttl if ttl is None else ttl
|
||||
|
||||
key = self._get_md5_hash(key)
|
||||
|
||||
if ttl_seconds <= 0:
|
||||
# Explicit request to avoid caching - remove any previous entry and return success
|
||||
self.memory_cache.remove(key)
|
||||
try:
|
||||
file_path = self._get_file_path(key)
|
||||
await aiofiles.os.remove(file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing cache file: {e}")
|
||||
return True
|
||||
|
||||
expires_at = time.time() + ttl_seconds
|
||||
|
||||
# Create cache entry
|
||||
entry = CacheEntry(data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data))
|
||||
|
||||
key = self._get_md5_hash(key)
|
||||
# Update memory cache
|
||||
self.memory_cache.set(key, entry)
|
||||
file_path = self._get_file_path(key)
|
||||
@@ -210,10 +225,11 @@ class HybridCache:
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete item from both caches."""
|
||||
self.memory_cache.remove(key)
|
||||
hashed_key = self._get_md5_hash(key)
|
||||
self.memory_cache.remove(hashed_key)
|
||||
|
||||
try:
|
||||
file_path = self._get_file_path(key)
|
||||
file_path = self._get_file_path(hashed_key)
|
||||
await aiofiles.os.remove(file_path)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
@@ -237,7 +253,13 @@ class AsyncMemoryCache:
|
||||
async def set(self, key: str, data: Union[bytes, bytearray, memoryview], ttl: Optional[int] = None) -> bool:
|
||||
"""Set value in cache."""
|
||||
try:
|
||||
expires_at = time.time() + (ttl or 3600) # Default 1 hour TTL if not specified
|
||||
ttl_seconds = 3600 if ttl is None else ttl
|
||||
|
||||
if ttl_seconds <= 0:
|
||||
self.memory_cache.remove(key)
|
||||
return True
|
||||
|
||||
expires_at = time.time() + ttl_seconds
|
||||
entry = CacheEntry(
|
||||
data=data, expires_at=expires_at, access_count=0, last_access=time.time(), size=len(data)
|
||||
)
|
||||
@@ -276,18 +298,35 @@ EXTRACTOR_CACHE = HybridCache(
|
||||
|
||||
|
||||
# Specific cache implementations
|
||||
async def get_cached_init_segment(init_url: str, headers: dict) -> Optional[bytes]:
|
||||
"""Get initialization segment from cache or download it."""
|
||||
# Try cache first
|
||||
cached_data = await INIT_SEGMENT_CACHE.get(init_url)
|
||||
if cached_data is not None:
|
||||
return cached_data
|
||||
async def get_cached_init_segment(
|
||||
init_url: str,
|
||||
headers: dict,
|
||||
cache_token: str | None = None,
|
||||
ttl: Optional[int] = None,
|
||||
) -> Optional[bytes]:
|
||||
"""Get initialization segment from cache or download it.
|
||||
|
||||
cache_token allows differentiating entries that share the same init_url but
|
||||
rely on different DRM keys or initialization payloads (e.g. key rotation).
|
||||
|
||||
ttl overrides the default cache TTL; pass a value <= 0 to skip caching entirely.
|
||||
"""
|
||||
|
||||
use_cache = ttl is None or ttl > 0
|
||||
cache_key = f"{init_url}|{cache_token}" if cache_token else init_url
|
||||
|
||||
if use_cache:
|
||||
cached_data = await INIT_SEGMENT_CACHE.get(cache_key)
|
||||
if cached_data is not None:
|
||||
return cached_data
|
||||
else:
|
||||
# Remove any previously cached entry when caching is disabled
|
||||
await INIT_SEGMENT_CACHE.delete(cache_key)
|
||||
|
||||
# Download if not cached
|
||||
try:
|
||||
init_content = await download_file_with_retry(init_url, headers)
|
||||
if init_content:
|
||||
await INIT_SEGMENT_CACHE.set(init_url, init_content)
|
||||
if init_content and use_cache:
|
||||
await INIT_SEGMENT_CACHE.set(cache_key, init_content, ttl=ttl)
|
||||
return init_content
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading init segment: {e}")
|
||||
|
||||
465
mediaflow_proxy/utils/codec.py
Normal file
465
mediaflow_proxy/utils/codec.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# Author: Trevor Perrin
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
|
||||
"""Classes for reading/writing binary data (such as TLS records)."""
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import sys
|
||||
import struct
|
||||
from struct import pack
|
||||
from .compat import bytes_to_int
|
||||
|
||||
|
||||
class DecodeError(SyntaxError):
    """Raised when the parsed data does not match the expected encoding."""
|
||||
|
||||
|
||||
class BadCertificateError(SyntaxError):
    """Raised when a certificate fails to parse or is malformed."""
|
||||
|
||||
|
||||
class Writer(object):
    """Serialisation helper for complex byte-based structures.

    Values are appended big-endian to an internal :class:`bytearray`.

    NOTE: the legacy Python 2.6/2.x code paths were removed; the rest of
    the codebase already requires Python 3 (it uses f-strings), and on
    Python 3 the retained branches are exactly the ones that executed
    before, so behaviour is unchanged.
    """

    def __init__(self):
        """Initialise the serializer with no data."""
        self.bytes = bytearray(0)

    def addOne(self, val):
        """Add a single-byte wide element to buffer, see add()."""
        self.bytes.append(val)

    def addTwo(self, val):
        """Add a double-byte wide element to buffer, see add().

        :raises ValueError: if val does not fit in two bytes
        """
        try:
            self.bytes += pack('>H', val)
        except struct.error:
            raise ValueError("Can't represent value in specified length")

    def addThree(self, val):
        """Add a three-byte wide element to buffer, see add().

        :raises ValueError: if val does not fit in three bytes
        """
        try:
            self.bytes += pack('>BH', val >> 16, val & 0xffff)
        except struct.error:
            raise ValueError("Can't represent value in specified length")

    def addFour(self, val):
        """Add a four-byte wide element to buffer, see add().

        :raises ValueError: if val does not fit in four bytes
        """
        try:
            self.bytes += pack('>I', val)
        except struct.error:
            raise ValueError("Can't represent value in specified length")

    def add(self, x, length):
        """
        Add a single positive integer value x, encode it in length bytes

        Encode positive integer x in big-endian format using length bytes,
        add to the internal buffer.

        :type x: int
        :param x: value to encode

        :type length: int
        :param length: number of bytes to use for encoding the value

        :raises ValueError: if x does not fit in length bytes
        """
        try:
            self.bytes += x.to_bytes(length, 'big')
        except OverflowError:
            raise ValueError("Can't represent value in specified length")

    def addFixSeq(self, seq, length):
        """
        Add a list of items, encode every item in length bytes

        Uses the unbounded iterable seq to produce items, each of
        which is then encoded to length bytes

        :type seq: iterable of int
        :param seq: list of positive integers to encode

        :type length: int
        :param length: number of bytes to which encode every element
        """
        for e in seq:
            self.add(e, length)

    def addVarSeq(self, seq, length, lengthLength):
        """
        Add a bounded list of same-sized values

        Create a list of specific length with all items being of the same
        size

        :type seq: list of int
        :param seq: list of positive integers to encode

        :type length: int
        :param length: amount of bytes in which to encode every item

        :type lengthLength: int
        :param lengthLength: amount of bytes in which to encode the overall
            length of the array
        """
        seqLen = len(seq)
        self.add(seqLen*length, lengthLength)
        if length == 1:
            # single-byte elements can be bulk-appended
            self.bytes.extend(seq)
        elif length == 2:
            # two-byte elements can be packed in one struct call
            try:
                self.bytes += pack('>' + 'H' * seqLen, *seq)
            except struct.error:
                raise ValueError("Can't represent value in specified "
                                 "length")
        else:
            for i in seq:
                self.add(i, length)

    def addVarTupleSeq(self, seq, length, lengthLength):
        """
        Add a variable length list of same-sized element tuples.

        Note that all tuples must have the same size.

        Inverse of Parser.getVarTupleList()

        :type seq: enumerable
        :param seq: list of tuples

        :type length: int
        :param length: length of single element in tuple

        :type lengthLength: int
        :param lengthLength: length in bytes of overall length field

        :raises ValueError: if the tuples have mismatched lengths
        """
        if not seq:
            self.add(0, lengthLength)
        else:
            startPos = len(self.bytes)
            dataLength = len(seq) * len(seq[0]) * length
            self.add(dataLength, lengthLength)
            # since at the time of writing, all the calls encode single byte
            # elements, and it's very easy to speed up that case, give it
            # special case
            if length == 1:
                for elemTuple in seq:
                    self.bytes.extend(elemTuple)
            else:
                for elemTuple in seq:
                    self.addFixSeq(elemTuple, length)
            if startPos + dataLength + lengthLength != len(self.bytes):
                raise ValueError("Tuples of different lengths")

    def add_var_bytes(self, data, length_length):
        """
        Add a variable length array of bytes.

        Inverse of Parser.getVarBytes()

        :type data: bytes
        :param data: bytes to add to the buffer

        :param int length_length: size of the field to represent the length
            of the data string
        """
        length = len(data)
        self.add(length, length_length)
        self.bytes += data
|
||||
|
||||
|
||||
class Parser(object):
    """
    Parser for TLV and LV byte-based encodings.

    Handles arbitrary byte-based encodings usually employed in
    Type-Length-Value or Length-Value binary protocols like ASN.1 or TLS.

    Note: if the raw bytes don't match expected values (like trying to
    read a 4-byte integer from a 2-byte buffer), most methods will raise a
    DecodeError exception.

    TODO: don't use an exception used by language parser to indicate errors
    in application code.

    :vartype bytes: bytearray
    :ivar bytes: data to be interpreted (buffer)

    :vartype index: int
    :ivar index: current position in the buffer

    :vartype lengthCheck: int
    :ivar lengthCheck: size of struct being parsed

    :vartype indexCheck: int
    :ivar indexCheck: position at which the structure begins in buffer
    """

    def __init__(self, bytes):
        """
        Bind raw bytes with parser.

        :type bytes: bytearray
        :param bytes: bytes to be parsed/interpreted
        """
        self.bytes = bytes
        self.index = 0
        self.indexCheck = 0
        self.lengthCheck = 0

    def get(self, length):
        """
        Read a single big-endian integer value encoded in 'length' bytes.

        :type length: int
        :param length: number of bytes in which the value is encoded in

        :rtype: int
        """
        return bytes_to_int(self.getFixBytes(length), 'big')

    def getFixBytes(self, lengthBytes):
        """
        Read a string of bytes encoded in 'lengthBytes' bytes.

        :type lengthBytes: int
        :param lengthBytes: number of bytes to return

        :rtype: bytearray
        """
        end = self.index + lengthBytes
        if end > len(self.bytes):
            raise DecodeError("Read past end of buffer")
        chunk = self.bytes[self.index:end]
        self.index = end
        return chunk

    def skip_bytes(self, length):
        """Move the internal pointer ahead length bytes."""
        if self.index + length > len(self.bytes):
            raise DecodeError("Read past end of buffer")
        self.index += length

    def getVarBytes(self, lengthLength):
        """
        Read a variable length string with a fixed length.

        see Writer.add_var_bytes() for an inverse of this method

        :type lengthLength: int
        :param lengthLength: number of bytes in which the length of the string
            is encoded in

        :rtype: bytearray
        """
        return self.getFixBytes(self.get(lengthLength))

    def getFixList(self, length, lengthList):
        """
        Read a list of static length with same-sized ints.

        :type length: int
        :param length: size in bytes of a single element in list

        :type lengthList: int
        :param lengthList: number of elements in list

        :rtype: list of int
        """
        return [self.get(length) for _ in range(lengthList)]

    def getVarList(self, length, lengthLength):
        """
        Read a variable length list of same-sized integers.

        :type length: int
        :param length: size in bytes of a single element

        :type lengthLength: int
        :param lengthLength: size of the encoded length of the list

        :rtype: list of int
        """
        encodedLength = self.get(lengthLength)
        if encodedLength % length != 0:
            raise DecodeError("Encoded length not a multiple of element "
                              "length")
        return [self.get(length) for _ in range(encodedLength // length)]

    def getVarTupleList(self, elemLength, elemNum, lengthLength):
        """
        Read a variable length list of same sized tuples.

        :type elemLength: int
        :param elemLength: length in bytes of single tuple element

        :type elemNum: int
        :param elemNum: number of elements in tuple

        :type lengthLength: int
        :param lengthLength: length in bytes of the list length variable

        :rtype: list of tuple of int
        """
        encodedLength = self.get(lengthLength)
        if encodedLength % (elemLength * elemNum) != 0:
            raise DecodeError("Encoded length not a multiple of element "
                              "length")
        tupleCount = encodedLength // (elemLength * elemNum)
        return [tuple(self.get(elemLength) for _ in range(elemNum))
                for _ in range(tupleCount)]

    def startLengthCheck(self, lengthLength):
        """
        Read length of struct and start a length check for parsing.

        :type lengthLength: int
        :param lengthLength: number of bytes in which the length is encoded
        """
        self.lengthCheck = self.get(lengthLength)
        self.indexCheck = self.index

    def setLengthCheck(self, length):
        """
        Set length of struct and start a length check for parsing.

        :type length: int
        :param length: expected size of parsed struct in bytes
        """
        self.lengthCheck = length
        self.indexCheck = self.index

    def stopLengthCheck(self):
        """
        Stop struct parsing, verify that no under- or overflow occurred.

        In case the expected length was mismatched with actual length of
        processed data, raises an exception.
        """
        if self.index - self.indexCheck != self.lengthCheck:
            raise DecodeError("Under- or over-flow while reading buffer")

    def atLengthCheck(self):
        """
        Check if there is data in structure left for parsing.

        Returns True if the whole structure was parsed, False if there is
        some data left.

        Will raise an exception if overflow occured (amount of data read was
        greater than expected size)
        """
        consumed = self.index - self.indexCheck
        if consumed < self.lengthCheck:
            return False
        if consumed == self.lengthCheck:
            return True
        raise DecodeError("Read past end of buffer")

    def getRemainingLength(self):
        """Return amount of data remaining in struct being parsed."""
        return len(self.bytes) - self.index
|
||||
231
mediaflow_proxy/utils/compat.py
Normal file
231
mediaflow_proxy/utils/compat.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# Author: Trevor Perrin
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
|
||||
"""Miscellaneous functions to mask Python version differences."""
|
||||
|
||||
import sys
|
||||
import re
|
||||
import platform
|
||||
import binascii
|
||||
import traceback
|
||||
import time
|
||||
|
||||
|
||||
if sys.version_info >= (3,0):
|
||||
|
||||
def compat26Str(x): return x
|
||||
|
||||
# Python 3.3 requires bytes instead of bytearrays for HMAC
|
||||
# So, python 2.6 requires strings, python 3 requires 'bytes',
|
||||
# and python 2.7 and 3.5 can handle bytearrays...
|
||||
# pylint: disable=invalid-name
|
||||
# we need to keep compatHMAC and `x` for API compatibility
|
||||
if sys.version_info < (3, 4):
|
||||
def compatHMAC(x):
|
||||
"""Convert bytes-like input to format acceptable for HMAC."""
|
||||
return bytes(x)
|
||||
else:
|
||||
def compatHMAC(x):
|
||||
"""Convert bytes-like input to format acceptable for HMAC."""
|
||||
return x
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
def compatAscii2Bytes(val):
|
||||
"""Convert ASCII string to bytes."""
|
||||
if isinstance(val, str):
|
||||
return bytes(val, 'ascii')
|
||||
return val
|
||||
|
||||
def compat_b2a(val):
|
||||
"""Convert an ASCII bytes string to string."""
|
||||
return str(val, 'ascii')
|
||||
|
||||
def raw_input(s):
|
||||
return input(s)
|
||||
|
||||
# So, the python3 binascii module deals with bytearrays, and python2
|
||||
# deals with strings... I would rather deal with the "a" part as
|
||||
# strings, and the "b" part as bytearrays, regardless of python version,
|
||||
# so...
|
||||
def a2b_hex(s):
|
||||
try:
|
||||
b = bytearray(binascii.a2b_hex(bytearray(s, "ascii")))
|
||||
except Exception as e:
|
||||
raise SyntaxError("base16 error: %s" % e)
|
||||
return b
|
||||
|
||||
def a2b_base64(s):
|
||||
try:
|
||||
if isinstance(s, str):
|
||||
s = bytearray(s, "ascii")
|
||||
b = bytearray(binascii.a2b_base64(s))
|
||||
except Exception as e:
|
||||
raise SyntaxError("base64 error: %s" % e)
|
||||
return b
|
||||
|
||||
def b2a_hex(b):
|
||||
return binascii.b2a_hex(b).decode("ascii")
|
||||
|
||||
def b2a_base64(b):
|
||||
return binascii.b2a_base64(b).decode("ascii")
|
||||
|
||||
def readStdinBinary():
|
||||
return sys.stdin.buffer.read()
|
||||
|
||||
def compatLong(num):
|
||||
return int(num)
|
||||
|
||||
int_types = tuple([int])
|
||||
|
||||
def formatExceptionTrace(e):
|
||||
"""Return exception information formatted as string"""
|
||||
return str(e)
|
||||
|
||||
def time_stamp():
|
||||
"""Returns system time as a float"""
|
||||
if sys.version_info >= (3, 3):
|
||||
return time.perf_counter()
|
||||
return time.clock()
|
||||
|
||||
def remove_whitespace(text):
|
||||
"""Removes all whitespace from passed in string"""
|
||||
return re.sub(r"\s+", "", text, flags=re.UNICODE)
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
# pylint is stupid here and deson't notice it's a function, not
|
||||
# constant
|
||||
bytes_to_int = int.from_bytes
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
def bit_length(val):
|
||||
"""Return number of bits necessary to represent an integer."""
|
||||
return val.bit_length()
|
||||
|
||||
def int_to_bytes(val, length=None, byteorder="big"):
|
||||
"""Return number converted to bytes"""
|
||||
if length is None:
|
||||
if val:
|
||||
length = byte_length(val)
|
||||
else:
|
||||
length = 1
|
||||
# for gmpy we need to convert back to native int
|
||||
if type(val) != int:
|
||||
val = int(val)
|
||||
return bytearray(val.to_bytes(length=length, byteorder=byteorder))
|
||||
|
||||
else:
|
||||
# Python 2.6 requires strings instead of bytearrays in a couple places,
|
||||
# so we define this function so it does the conversion if needed.
|
||||
# same thing with very old 2.7 versions
|
||||
# or on Jython
|
||||
if sys.version_info < (2, 7) or sys.version_info < (2, 7, 4) \
        or platform.system() == 'Java':
    # NOTE(review): the first comparison is subsumed by the second
    # (anything < (2, 7) is also < (2, 7, 4)); kept as-is
    # coerce bytearrays to str for stdlib APIs that require strings here
    def compat26Str(x): return str(x)

    def remove_whitespace(text):
        """Removes all whitespace from passed in string"""
        # re.sub() gained the flags parameter only in Python 2.7
        return re.sub(r"\s+", "", text)

    def bit_length(val):
        """Return number of bits necessary to represent an integer."""
        if val == 0:
            return 0
        # bin() renders '0b...'; strip the two prefix characters
        return len(bin(val))-2
else:
    # modern Python 2.7: bytearrays are accepted, pass values through
    def compat26Str(x): return x

    def remove_whitespace(text):
        """Removes all whitespace from passed in string"""
        return re.sub(r"\s+", "", text, flags=re.UNICODE)

    def bit_length(val):
        """Return number of bits necessary to represent an integer."""
        return val.bit_length()
|
||||
|
||||
def compatAscii2Bytes(val):
    """Convert ASCII string to bytes."""
    # on Python 2 str already is a byte string, so no conversion is needed
    return val


def compat_b2a(val):
    """Convert an ASCII bytes string to string."""
    return str(val)
|
||||
|
||||
# So, python 2.6 requires strings, python 3 requires 'bytes',
# and python 2.7 can handle bytearrays...
# delegate to compat26Str, which converts only when the platform needs it
def compatHMAC(x): return compat26Str(x)
|
||||
|
||||
def a2b_hex(s):
    """Decode a hex string into a bytearray.

    :raises SyntaxError: when the input is not valid base16
    """
    try:
        b = bytearray(binascii.a2b_hex(s))
    except Exception as e:
        # historical API: decoding problems surface as SyntaxError
        raise SyntaxError("base16 error: %s" % e)
    return b


def a2b_base64(s):
    """Decode a base64 string into a bytearray.

    :raises SyntaxError: when the input is not valid base64
    """
    try:
        b = bytearray(binascii.a2b_base64(s))
    except Exception as e:
        # historical API: decoding problems surface as SyntaxError
        raise SyntaxError("base64 error: %s" % e)
    return b
|
||||
|
||||
def b2a_hex(b):
    """Encode binary data as a hex string."""
    # compat26Str: Python 2.6 binascii needs str, not bytearray
    return binascii.b2a_hex(compat26Str(b))


def b2a_base64(b):
    """Encode binary data as a base64 string (newline-terminated)."""
    return binascii.b2a_base64(compat26Str(b))
|
||||
|
||||
def compatLong(num):
    """Coerce a number to the arbitrary-precision integer type.

    On Python 2 (this branch) that is ``long``.
    """
    return long(num)


# integer types accepted by isinstance() checks on Python 2
int_types = (int, long)
|
||||
|
||||
# pylint on Python3 goes nuts for the sys dereferences...
|
||||
|
||||
#pylint: disable=no-member
def formatExceptionTrace(e):
    """Return exception information formatted as string"""
    # Python 2-only API: sys.exc_type/exc_value/exc_traceback describe the
    # exception currently being handled (removed in Python 3).
    # NOTE(review): parameter *e* is unused; the traceback comes from
    # interpreter state instead - confirm callers invoke this inside
    # an except block
    newStr = "".join(traceback.format_exception(sys.exc_type,
                                                sys.exc_value,
                                                sys.exc_traceback))
    return newStr
#pylint: enable=no-member
|
||||
|
||||
def time_stamp():
    """Returns system time as a float"""
    # time.clock() semantics differ by platform (CPU time on Unix, wall
    # clock on Windows); it is the only option on Python 2
    return time.clock()
|
||||
|
||||
def bytes_to_int(val, byteorder):
    """Convert bytes to an int.

    Python 2 fallback for int.from_bytes: decodes through a hex string.

    :param val: bytes-like value to decode
    :param byteorder: "big" or "little"
    :raises ValueError: for any other byte order
    """
    if not val:
        # int("", 16) would raise, so handle empty input explicitly
        return 0
    if byteorder == "big":
        return int(b2a_hex(val), 16)
    if byteorder == "little":
        # reverse the bytes so the hex string reads big-endian
        return int(b2a_hex(val[::-1]), 16)
    raise ValueError("Only 'big' and 'little' endian supported")
|
||||
|
||||
def int_to_bytes(val, length=None, byteorder="big"):
    """Return number converted to bytes

    :param val: non-negative integer to encode
    :param length: number of bytes to emit; derived from the value when None
    :param byteorder: "big" or "little"
    :raises ValueError: for any other byte order
    """
    if length is None:
        if val:
            length = byte_length(val)
        else:
            # zero still needs one byte
            length = 1
    if byteorder == "big":
        # most significant byte first
        return bytearray((val >> i) & 0xff
                         for i in reversed(range(0, length*8, 8)))
    if byteorder == "little":
        # least significant byte first
        return bytearray((val >> i) & 0xff
                         for i in range(0, length*8, 8))
    raise ValueError("Only 'big' or 'little' endian supported")
|
||||
|
||||
|
||||
def byte_length(val):
    """Return number of bytes necessary to represent an integer."""
    bits = bit_length(val)
    # round up to whole bytes
    return (bits + 7) // 8
|
||||
|
||||
|
||||
# feature flags - all False here; presumably set True elsewhere when the
# optional backends (extra ECDSA curves, ML-KEM/ML-DSA post-quantum
# algorithms) are importable - TODO confirm against the full module
ecdsaAllCurves = False
ML_KEM_AVAILABLE = False
ML_DSA_AVAILABLE = False
|
||||
218
mediaflow_proxy/utils/constanttime.py
Normal file
218
mediaflow_proxy/utils/constanttime.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Copyright (c) 2015, Hubert Kario
|
||||
#
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
"""Various constant time functions for processing sensitive data"""
|
||||
|
||||
from __future__ import division
|
||||
|
||||
from .compat import compatHMAC
|
||||
import hmac
|
||||
|
||||
def ct_lt_u32(val_a, val_b):
    """
    Returns 1 if val_a < val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    a = val_a & 0xffffffff
    b = val_b & 0xffffffff
    diff = (a - b) & 0xffffffff
    # sign-bit trick: bit 31 of the mix is set exactly when a < b
    return (a ^ ((a ^ b) | (diff ^ b))) >> 31
|
||||
|
||||
|
||||
def ct_gt_u32(val_a, val_b):
    """
    Return 1 if val_a > val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    # inlined ct_lt_u32(val_b, val_a): a > b is exactly b < a
    lhs = val_b & 0xffffffff
    rhs = val_a & 0xffffffff
    return (lhs ^ ((lhs ^ rhs) | (((lhs - rhs) & 0xffffffff) ^ rhs))) >> 31
|
||||
|
||||
|
||||
def ct_le_u32(val_a, val_b):
    """
    Return 1 if val_a <= val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    # a <= b is the negation of a > b; XOR with 1 flips the 0/1 result
    greater = ct_gt_u32(val_a, val_b)
    return greater ^ 1
|
||||
|
||||
|
||||
def ct_lsb_prop_u8(val):
    """Propagate LSB to all 8 bits of the returned int. Constant time."""
    bit = val & 0x01
    # doubling the shift each step fills 2, then 4, then 8 bits
    for shift in (1, 2, 4):
        bit |= bit << shift
    return bit
|
||||
|
||||
|
||||
def ct_lsb_prop_u16(val):
    """Propagate LSB to all 16 bits of the returned int. Constant time."""
    bit = val & 0x01
    # doubling the shift each step fills 2, 4, 8, then 16 bits
    for shift in (1, 2, 4, 8):
        bit |= bit << shift
    return bit
|
||||
|
||||
|
||||
def ct_isnonzero_u32(val):
    """
    Returns 1 if val is != 0, 0 otherwise. Constant time.

    :type val: int
    :param val: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    masked = val & 0xffffffff
    negated = -masked & 0xffffffff
    # for any nonzero value, either it or its two's complement has the
    # sign bit set
    return (masked | negated) >> 31
|
||||
|
||||
|
||||
def ct_neq_u32(val_a, val_b):
    """
    Return 1 if val_a != val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    a = val_a & 0xffffffff
    b = val_b & 0xffffffff
    # one of the two subtractions borrows (sets the sign bit) iff a != b
    forward = (a - b) & 0xffffffff
    backward = (b - a) & 0xffffffff
    return (forward | backward) >> 31
|
||||
|
||||
def ct_eq_u32(val_a, val_b):
    """
    Return 1 if val_a == val_b, 0 otherwise. Constant time.

    :type val_a: int
    :type val_b: int
    :param val_a: an unsigned integer representable as a 32 bit value
    :param val_b: an unsigned integer representable as a 32 bit value
    :rtype: int
    """
    # equality is the negation of inequality; XOR with 1 flips 0/1
    not_equal = ct_neq_u32(val_a, val_b)
    return not_equal ^ 1
|
||||
|
||||
def ct_check_cbc_mac_and_pad(data, mac, seqnumBytes, contentType, version,
                             block_size=16):
    """
    Check CBC cipher HMAC and padding. Close to constant time.

    :type data: bytearray
    :param data: data with HMAC value to test and padding

    :type mac: hashlib mac
    :param mac: empty HMAC, initialised with a key

    :type seqnumBytes: bytearray
    :param seqnumBytes: TLS sequence number, used as input to HMAC

    :type contentType: int
    :param contentType: a single byte, used as input to HMAC

    :type version: tuple of int
    :param version: a tuple of two ints, used as input to HMAC and to guide
        checking of padding

    :rtype: boolean
    :returns: True if MAC and pad is ok, False otherwise
    """
    assert version in ((3, 0), (3, 1), (3, 2), (3, 3))

    data_len = len(data)
    if mac.digest_size + 1 > data_len: # data_len is public
        return False

    # 0 - OK; any set bit accumulated below marks a pad or MAC failure
    result = 0x00

    #
    # check padding
    #
    # last byte of the record is the pad length
    pad_length = data[data_len-1]
    pad_start = data_len - pad_length - 1
    # clamp so a bogus pad_length can't index before the buffer
    pad_start = max(0, pad_start)

    if version == (3, 0): # version is public
        # in SSLv3 we can only check if pad is not longer than the cipher
        # block size

        # subtract 1 for the pad length byte
        mask = ct_lsb_prop_u8(ct_lt_u32(block_size, pad_length))
        result |= mask
    else:
        # TLS: every pad byte must equal pad_length; scan a fixed window
        # (last 256 bytes at most) so timing does not depend on pad size
        start_pos = max(0, data_len - 256)
        for i in range(start_pos, data_len):
            # if pad_start < i: mask = 0xff; else: mask = 0x00
            mask = ct_lsb_prop_u8(ct_le_u32(pad_start, i))
            # if data[i] != pad_length and "inside_pad": result = False
            result |= (data[i] ^ pad_length) & mask

    #
    # check MAC
    #

    # real place where mac starts and data ends
    mac_start = pad_start - mac.digest_size
    mac_start = max(0, mac_start)

    # place to start processing (rounded down to a whole hash block so the
    # prefix can be hashed once and reused below)
    start_pos = max(0, data_len - (256 + mac.digest_size)) // mac.block_size
    start_pos *= mac.block_size

    # add start data
    data_mac = mac.copy()
    data_mac.update(compatHMAC(seqnumBytes))
    data_mac.update(compatHMAC(bytearray([contentType])))
    if version != (3, 0): # version is public
        data_mac.update(compatHMAC(bytearray([version[0]])))
        data_mac.update(compatHMAC(bytearray([version[1]])))
    # two-byte big-endian length of the authenticated data
    data_mac.update(compatHMAC(bytearray([mac_start >> 8])))
    data_mac.update(compatHMAC(bytearray([mac_start & 0xff])))
    data_mac.update(compatHMAC(data[:start_pos]))

    # don't check past the array end (already checked to be >= zero)
    end_pos = data_len - mac.digest_size

    # calculate all possible MAC placements so the amount of hashing done
    # depends only on the (public) record length
    for i in range(start_pos, end_pos): # constant for given overall length
        cur_mac = data_mac.copy()
        cur_mac.update(compatHMAC(data[start_pos:i]))
        mac_compare = bytearray(cur_mac.digest())
        # compare the hash for real only if it's the place where mac is
        # supposed to be
        mask = ct_lsb_prop_u8(ct_eq_u32(i, mac_start))
        for j in range(0, mac.digest_size): # digest_size is public
            result |= (data[i+j] ^ mac_compare[j]) & mask

    # return python boolean
    return result == 0
|
||||
|
||||
if hasattr(hmac, 'compare_digest'):
    # stdlib provides a C-level constant-time comparison; prefer it
    ct_compare_digest = hmac.compare_digest
else:
    def ct_compare_digest(val_a, val_b):
        """Compares if string like objects are equal. Constant time."""
        # length is public, so an early exit here leaks nothing secret
        if len(val_a) != len(val_b):
            return False

        diff = 0
        for byte_a, byte_b in zip(val_a, val_b):
            diff |= byte_a ^ byte_b

        return diff == 0
|
||||
366
mediaflow_proxy/utils/cryptomath.py
Normal file
366
mediaflow_proxy/utils/cryptomath.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# Authors:
|
||||
# Trevor Perrin
|
||||
# Martin von Loewis - python 3 port
|
||||
# Yngve Pettersen (ported by Paul Sokolovsky) - TLS 1.2
|
||||
#
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
|
||||
"""cryptomath module
|
||||
|
||||
This module has basic math/crypto code."""
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import math
|
||||
import base64
|
||||
import binascii
|
||||
|
||||
from .compat import compat26Str, compatHMAC, compatLong, \
|
||||
bytes_to_int, int_to_bytes, bit_length, byte_length
|
||||
from .codec import Writer
|
||||
|
||||
from . import tlshashlib as hashlib
|
||||
from . import tlshmac as hmac
|
||||
|
||||
|
||||
m2cryptoLoaded = False
|
||||
gmpyLoaded = False
|
||||
GMPY2_LOADED = False
|
||||
pycryptoLoaded = False
|
||||
|
||||
|
||||
# **************************************************************************
|
||||
# PRNG Functions
|
||||
# **************************************************************************
|
||||
|
||||
# Check that os.urandom works
|
||||
import zlib
|
||||
assert len(zlib.compress(os.urandom(1000))) > 900
|
||||
|
||||
def getRandomBytes(howMany):
    """Return *howMany* cryptographically secure random bytes."""
    result = bytearray(os.urandom(howMany))
    assert len(result) == howMany
    return result
|
||||
|
||||
prngName = "os.urandom"
|
||||
|
||||
# **************************************************************************
|
||||
# Simple hash functions
|
||||
# **************************************************************************
|
||||
|
||||
def MD5(b):
    """Return a MD5 digest of data"""
    return secureHash(b, 'md5')


def SHA1(b):
    """Return a SHA1 digest of data"""
    return secureHash(b, 'sha1')


def secureHash(data, algorithm):
    """Return a digest of `data` using `algorithm`

    :param data: bytes-like data to hash
    :param algorithm: hash name accepted by hashlib.new(), e.g. 'sha256'
    :rtype: bytearray
    """
    hashInstance = hashlib.new(algorithm)
    # compat26Str works around old Python 2.6 hashlib rejecting bytearrays
    hashInstance.update(compat26Str(data))
    return bytearray(hashInstance.digest())
|
||||
|
||||
def secureHMAC(k, b, algorithm):
    """Return a HMAC using `b` and `k` using `algorithm`

    :param k: HMAC key
    :param b: message to authenticate
    :param algorithm: hash name; must be an attribute of the hashlib module
    :rtype: bytearray
    """
    # compatHMAC converts the inputs to a type old hmac modules accept
    k = compatHMAC(k)
    b = compatHMAC(b)
    return bytearray(hmac.new(k, b, getattr(hashlib, algorithm)).digest())


def HMAC_MD5(k, b):
    """Return HMAC-MD5 of *b* keyed with *k*."""
    return secureHMAC(k, b, 'md5')


def HMAC_SHA1(k, b):
    """Return HMAC-SHA1 of *b* keyed with *k*."""
    return secureHMAC(k, b, 'sha1')


def HMAC_SHA256(k, b):
    """Return HMAC-SHA256 of *b* keyed with *k*."""
    return secureHMAC(k, b, 'sha256')


def HMAC_SHA384(k, b):
    """Return HMAC-SHA384 of *b* keyed with *k*."""
    return secureHMAC(k, b, 'sha384')
|
||||
|
||||
def HKDF_expand(PRK, info, L, algorithm):
    """HKDF-Expand step (RFC 5869): stretch PRK into L bytes of keying
    material.

    :param PRK: pseudorandom key (output of HKDF-Extract)
    :param info: context/application specific information
    :param L: number of bytes of output keying material to produce
    :param algorithm: hash name used for the underlying HMAC
    :rtype: bytearray
    """
    N = divceil(L, getattr(hashlib, algorithm)().digest_size)
    T = bytearray()
    Titer = bytearray()
    # the original looped to N+1, computing one extra HMAC block that was
    # immediately discarded; computing exactly N blocks (T(1)..T(N))
    # yields the identical output with one fewer HMAC invocation
    for x in range(1, N + 1):
        Titer = secureHMAC(PRK, Titer + info + bytearray([x]), algorithm)
        T += Titer
    return T[:L]
|
||||
|
||||
def HKDF_expand_label(secret, label, hashValue, length, algorithm):
    """
    TLS1.3 key derivation function (HKDF-Expand-Label).

    :param bytearray secret: the key from which to derive the keying material
    :param bytearray label: label used to differentiate the keying materials
    :param bytearray hashValue: bytes used to "salt" the produced keying
        material
    :param int length: number of bytes to produce
    :param str algorithm: name of the secure hash algorithm used as the
        basis of the HKDF
    :rtype: bytearray
    """
    # serialise the HkdfLabel structure: uint16 length, then the label
    # (with the mandatory "tls13 " prefix) and the hash value, each as a
    # one-byte-length-prefixed vector
    hkdfLabel = Writer()
    hkdfLabel.addTwo(length)
    hkdfLabel.addVarSeq(bytearray(b"tls13 ") + label, 1, 1)
    hkdfLabel.addVarSeq(hashValue, 1, 1)

    return HKDF_expand(secret, hkdfLabel.bytes, length, algorithm)
|
||||
|
||||
def derive_secret(secret, label, handshake_hashes, algorithm):
    """
    TLS1.3 key derivation function (Derive-Secret).

    :param bytearray secret: secret key used to derive the keying material
    :param bytearray label: label used to differentiate they keying materials
    :param HandshakeHashes handshake_hashes: hashes of the handshake messages
        or `None` if no handshake transcript is to be used for derivation of
        keying material
    :param str algorithm: name of the secure hash algorithm used as the
        basis of the HKDF algorithm - governs how much keying material will
        be generated
    :rtype: bytearray
    """
    if handshake_hashes is None:
        # empty transcript: use the hash of the zero-length string
        hs_hash = secureHash(bytearray(b''), algorithm)
    else:
        hs_hash = handshake_hashes.digest(algorithm)
    # output length equals the digest size of the chosen hash
    return HKDF_expand_label(secret, label, hs_hash,
                             getattr(hashlib, algorithm)().digest_size,
                             algorithm)
|
||||
|
||||
# **************************************************************************
|
||||
# Converter Functions
|
||||
# **************************************************************************
|
||||
|
||||
def bytesToNumber(b, endian="big"):
    """
    Convert a number stored in bytearray to an integer.

    By default assumes big-endian encoding of the number.
    """
    value = bytes_to_int(b, endian)
    return value
|
||||
|
||||
|
||||
def numberToByteArray(n, howManyBytes=None, endian="big"):
    """
    Convert an integer into a bytearray, zero-pad to howManyBytes.

    The returned bytearray may be smaller than howManyBytes, but will
    not be larger. The returned bytearray will contain a big- or little-endian
    encoding of the input integer (n). Big endian encoding is used by default.
    """
    if howManyBytes is not None:
        length = byte_length(n)
        # if the requested size is smaller than the natural encoding,
        # truncate: keep only the least significant howManyBytes bytes
        if howManyBytes < length:
            ret = int_to_bytes(n, length, endian)
            if endian == "big":
                # least significant bytes sit at the end in big-endian
                return ret[length-howManyBytes:length]
            return ret[:howManyBytes]
    # int_to_bytes derives the length itself when howManyBytes is None
    return int_to_bytes(n, howManyBytes, endian)
|
||||
|
||||
|
||||
def mpiToNumber(mpi):
    """Convert a MPI (OpenSSL bignum string) to an integer."""
    byte = bytearray(mpi)
    # bytes 0-3 are the big-endian length header (see numberToMPI);
    # byte 4 is the first value byte - a set high bit marks a negative
    # number, which this codebase does not support
    if byte[4] & 0x80:
        raise ValueError("Input must be a positive integer")
    return bytesToNumber(byte[4:])
|
||||
|
||||
|
||||
def numberToMPI(n):
    """Convert an integer to a MPI (OpenSSL bignum string)."""
    b = numberToByteArray(n)
    ext = 0
    #If the high-order bit is going to be set,
    #add an extra byte of zeros
    if (numBits(n) & 0x7)==0:
        ext = 1
    length = numBytes(n) + ext
    # 4-byte big-endian length header, plus the optional zero pad byte
    b = bytearray(4+ext) + b
    b[0] = (length >> 24) & 0xFF
    b[1] = (length >> 16) & 0xFF
    b[2] = (length >> 8) & 0xFF
    b[3] = length & 0xFF
    return bytes(b)
|
||||
|
||||
|
||||
# **************************************************************************
|
||||
# Misc. Utility Functions
|
||||
# **************************************************************************
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
# pylint recognises them as constants, not function names, also
# we can't change their names without API change
# legacy alias: number of bits needed to represent an integer
numBits = bit_length


# legacy alias: number of bytes needed to represent an integer
numBytes = byte_length
# pylint: enable=invalid-name
|
||||
|
||||
|
||||
# **************************************************************************
|
||||
# Big Number Math
|
||||
# **************************************************************************
|
||||
|
||||
def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Uses rejection sampling over cryptographically secure random bytes.
    """
    assert low < high
    howManyBits = numBits(high)
    howManyBytes = numBytes(high)
    lastBits = howManyBits % 8
    # `while True` replaces the dated `while 1`; `rand_bytes` replaces the
    # original local named `bytes`, which shadowed the builtin
    while True:
        rand_bytes = getRandomBytes(howManyBytes)
        if lastBits:
            # mask the top byte so candidates never exceed high's bit
            # length (keeps the rejection rate below 50%)
            rand_bytes[0] = rand_bytes[0] % (1 << lastBits)
        n = bytesToNumber(rand_bytes)
        if low <= n < high:
            return n
|
||||
|
||||
def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm)."""
    larger, smaller = max(a, b), min(a, b)
    while smaller:
        larger, smaller = smaller, larger % smaller
    return larger
|
||||
|
||||
def lcm(a, b):
    """Return the least common multiple of a and b."""
    product = a * b
    return product // gcd(a, b)
|
||||
|
||||
# pylint: disable=invalid-name
# disable pylint check as the (a, b) are part of the API
if GMPY2_LOADED:
    def invMod(a, b):
        """Return inverse of a mod b, zero if none."""
        if a == 0:
            return 0
        # gmpy2's powmod computes the modular inverse for exponent -1
        return powmod(a, -1, b)
else:
    # Use Extended Euclidean Algorithm
    def invMod(a, b):
        """Return inverse of a mod b, zero if none."""
        c, d = a, b
        uc, ud = 1, 0
        while c != 0:
            q = d // c
            c, d = d-(q*c), c
            uc, ud = ud - (q * uc), uc
        # d now holds gcd(a, b); an inverse exists only when it is 1
        if d == 1:
            return ud % b
        return 0
# pylint: enable=invalid-name
|
||||
|
||||
|
||||
if gmpyLoaded or GMPY2_LOADED:
    def powMod(base, power, modulus):
        """Modular exponentiation accelerated by gmpy's mpz type."""
        base = mpz(base)
        power = mpz(power)
        modulus = mpz(modulus)
        result = pow(base, power, modulus)
        # convert the gmpy result back to a native integer type
        return compatLong(result)
else:
    # the builtin three-argument pow() already performs modular
    # exponentiation
    powMod = pow
|
||||
|
||||
|
||||
def divceil(divident, divisor):
    """Integer division with rounding up"""
    # negate-divide-negate turns floor division into ceiling division
    return -(-divident // divisor)
|
||||
|
||||
|
||||
#Pre-calculate a sieve of the ~100 primes < 1000:
|
||||
def makeSieve(n):
    """Return the list of primes below *n* (sieve of Eratosthenes)."""
    marks = list(range(n))
    for candidate in range(2, int(math.sqrt(n)) + 1):
        if not marks[candidate]:
            # already crossed out as a multiple of a smaller prime
            continue
        # cross out every multiple, starting at twice the candidate
        for multiple in range(candidate * 2, n, candidate):
            marks[multiple] = 0
    return [value for value in marks[2:] if value]
|
||||
|
||||
def isPrime(n, iterations=5, display=False, sieve=makeSieve(1000)):
    """Probabilistic primality test: trial division then Rabin-Miller.

    :param n: candidate to test
    :param iterations: number of Rabin-Miller rounds to run
    :param display: print progress markers when True
    :param sieve: list of small primes for trial division; the default is
        computed once at definition time (deliberate use of a mutable
        default as a cache - the list is only read, never modified)
    """
    #Trial division with sieve
    for x in sieve:
        if x >= n: return True
        if n % x == 0: return False
    #Passed trial division, proceed to Rabin-Miller
    #Rabin-Miller implemented per Ferguson & Schneier
    #Compute s, t for Rabin-Miller
    if display: print("*", end=' ')
    # write n-1 as s * 2^t with s odd
    s, t = n-1, 0
    while s % 2 == 0:
        s, t = s//2, t+1
    #Repeat Rabin-Miller x times
    a = 2 #Use 2 as a base for first iteration speedup, per HAC
    for count in range(iterations):
        v = powMod(a, s, n)
        if v==1:
            continue
        i = 0
        while v != n-1:
            if i == t-1:
                return False
            else:
                v, i = powMod(v, 2, n), i+1
        # pick a fresh random base for the next round
        a = getRandomNumber(2, n)
    return True
|
||||
|
||||
|
||||
def getRandomPrime(bits, display=False):
    """
    Generate a random prime number of a given size.

    the number will be 'bits' bits long (i.e. generated number will be
    larger than `(2^(bits-1) * 3 ) / 2` but smaller than 2^bits.
    """
    assert bits >= 10
    #The 1.5 ensures the 2 MSBs are set
    #Thus, when used for p,q in RSA, n will have its MSB set
    #
    #Since 30 is lcm(2,3,5), we'll set our test numbers to
    #29 % 30 and keep them there
    low = ((2 ** (bits-1)) * 3) // 2
    high = 2 ** bits - 30
    while True:
        if display:
            print(".", end=' ')
        cand_p = getRandomNumber(low, high)
        # make odd
        if cand_p % 2 == 0:
            cand_p += 1
        if isPrime(cand_p, display=display):
            return cand_p
|
||||
|
||||
|
||||
#Unused at the moment...
|
||||
def getRandomSafePrime(bits, display=False):
    """Generate a random safe prime.

    Will generate a prime `bits` bits long (see getRandomPrime) such that
    the (p-1)/2 will also be prime.
    """
    assert bits >= 10
    #The 1.5 ensures the 2 MSBs are set
    #Thus, when used for p,q in RSA, n will have its MSB set
    #
    #Since 30 is lcm(2,3,5), we'll set our test numbers to
    #29 % 30 and keep them there
    low = (2 ** (bits-2)) * 3//2
    high = (2 ** (bits-1)) - 30
    q = getRandomNumber(low, high)
    # align q to 29 mod 30 so stepping by 30 skips multiples of 2, 3 and 5
    q += 29 - (q % 30)
    while 1:
        if display: print(".", end=' ')
        q += 30
        if (q >= high):
            # wrapped past the range: draw a fresh candidate
            q = getRandomNumber(low, high)
            q += 29 - (q % 30)
        #Ideas from Tom Wu's SRP code
        #Do trial division on p and q before Rabin-Miller
        # isPrime(q, 0, ...) runs only the trial-division stage
        # (zero Rabin-Miller rounds) as a cheap pre-filter
        if isPrime(q, 0, display=display):
            p = (2 * q) + 1
            if isPrime(p, display=display):
                # full probabilistic test on q only after p also passed
                if isPrime(q, display=display):
                    return p
|
||||
218
mediaflow_proxy/utils/deprecations.py
Normal file
218
mediaflow_proxy/utils/deprecations.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Copyright (c) 2018 Hubert Kario
|
||||
#
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
"""Methods for deprecating old names for arguments or attributes."""
|
||||
import warnings
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
|
||||
def deprecated_class_name(old_name,
                          warn="Class name '{old_name}' is deprecated, "
                               "please use '{new_name}'"):
    """
    Class decorator to deprecate a use of class.

    :param str old_name: the deprecated name that will be registered, but
        will raise warnings if used.

    :param str warn: DeprecationWarning format string for informing the
        user what is the current class name, uses 'old_name' for the deprecated
        keyword name and the 'new_name' for the current one.
        Example: "Old name: {old_name}, use '{new_name}' instead".
    """
    def _wrap(obj):
        assert callable(obj)

        # emit the deprecation warning pointing at the caller's caller
        def _warn():
            warnings.warn(warn.format(old_name=old_name,
                                      new_name=obj.__name__),
                          DeprecationWarning,
                          stacklevel=3)

        def _wrap_with_warn(func, is_inspect):
            @wraps(func)
            def _func(*args, **kwargs):
                if is_inspect:
                    # XXX: If use another name to call,
                    # you will not get the warning.
                    # we do this instead of subclassing or metaclass as
                    # we want to isinstance(new_name(), old_name) and
                    # isinstance(old_name(), new_name) to work
                    # inspect the calling source line to tell whether the
                    # deprecated or the current name was used
                    frame = inspect.currentframe().f_back
                    code = inspect.getframeinfo(frame).code_context
                    if [line for line in code
                            if '{0}('.format(old_name) in line]:
                        _warn()
                else:
                    _warn()
                return func(*args, **kwargs)
            return _func

        # Make old name available in the caller's module namespace.
        frame = inspect.currentframe().f_back
        if old_name in frame.f_globals:
            raise NameError("Name '{0}' already in use.".format(old_name))

        if inspect.isclass(obj):
            # classes keep their identity; only __init__ gains the warning
            obj.__init__ = _wrap_with_warn(obj.__init__, True)
            placeholder = obj
        else:
            placeholder = _wrap_with_warn(obj, False)

        frame.f_globals[old_name] = placeholder

        return obj
    return _wrap
|
||||
|
||||
|
||||
def deprecated_params(names, warn="Param name '{old_name}' is deprecated, "
                                  "please use '{new_name}'"):
    """Decorator to translate obsolete names and warn about their use.

    :param dict names: dictionary with pairs of new_name: old_name
        that will be used for translating obsolete param names to new names

    :param str warn: DeprecationWarning format string for informing the user
        what is the current parameter name, uses 'old_name' for the
        deprecated keyword name and 'new_name' for the current one.
        Example: "Old name: {old_name}, use {new_name} instead".
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for current, obsolete in names.items():
                # guard clause: skip names the caller did not use
                if obsolete not in kwargs:
                    continue
                if current in kwargs:
                    raise TypeError("got multiple values for keyword "
                                    "argument '{0}'".format(current))
                warnings.warn(warn.format(old_name=obsolete,
                                          new_name=current),
                              DeprecationWarning,
                              stacklevel=2)
                # translate the obsolete keyword to its current name
                kwargs[current] = kwargs.pop(obsolete)
            return func(*args, **kwargs)
        return wrapper
    return decorator
|
||||
|
||||
|
||||
def deprecated_instance_attrs(names,
                              warn="Attribute '{old_name}' is deprecated, "
                                   "please use '{new_name}'"):
    """Decorator to deprecate class instance attributes.

    Translates all names in `names` to use new names and emits warnings
    if the translation was necessary. Does apply only to instance variables
    and attributes (won't modify behaviour of class variables, static methods,
    etc.)

    :param dict names: dictionary with pairs of new_name: old_name that will
        be used to translate the calls
    :param str warn: DeprecationWarning format string for informing the user
        what is the current parameter name, uses 'old_name' for the
        deprecated keyword name and 'new_name' for the current one.
        Example: "Old name: {old_name}, use {new_name} instead".
    """
    # reverse the dict as we're looking for old attributes, not new ones
    names = dict((j, i) for i, j in names.items())

    def decorator(clazz):
        # default-argument trick: capture the pre-existing hook (if any)
        # at decoration time
        def getx(self, name, __old_getx=getattr(clazz, "__getattr__", None)):
            if name in names:
                warnings.warn(warn.format(old_name=name,
                                          new_name=names[name]),
                              DeprecationWarning,
                              stacklevel=2)
                return getattr(self, names[name])
            if __old_getx:
                # unbound methods (Python 2) expose the function as __func__
                if hasattr(__old_getx, "__func__"):
                    return __old_getx.__func__(self, name)
                return __old_getx(self, name)
            raise AttributeError("'{0}' object has no attribute '{1}'"
                                 .format(clazz.__name__, name))

        getx.__name__ = "__getattr__"
        clazz.__getattr__ = getx

        def setx(self, name, value, __old_setx=getattr(clazz, "__setattr__")):
            if name in names:
                warnings.warn(warn.format(old_name=name,
                                          new_name=names[name]),
                              DeprecationWarning,
                              stacklevel=2)
                # write through to the new attribute name
                setattr(self, names[name], value)
            else:
                __old_setx(self, name, value)

        setx.__name__ = "__setattr__"
        clazz.__setattr__ = setx

        def delx(self, name, __old_delx=getattr(clazz, "__delattr__")):
            if name in names:
                warnings.warn(warn.format(old_name=name,
                                          new_name=names[name]),
                              DeprecationWarning,
                              stacklevel=2)
                # delete the new attribute the old name maps to
                delattr(self, names[name])
            else:
                __old_delx(self, name)

        delx.__name__ = "__delattr__"
        clazz.__delattr__ = delx

        return clazz
    return decorator
|
||||
|
||||
|
||||
def deprecated_attrs(names, warn="Attribute '{old_name}' is deprecated, "
                                 "please use '{new_name}'"):
    """Decorator to deprecate all specified attributes in class.

    Translates all names in `names` to use new names and emits warnings
    if the translation was necessary.

    Note: uses metaclass magic so is incompatible with other metaclass uses

    :param dict names: dictionary with pairs of new_name: old_name that will
        be used to translate the calls
    :param str warn: DeprecationWarning format string for informing the user
        what is the current parameter name, uses 'old_name' for the
        deprecated keyword name and 'new_name' for the current one.
        Example: "Old name: {old_name}, use {new_name} instead".
    """
    # prepare metaclass for handling all the class methods, class variables
    # and static methods (as they don't go through instance's __getattr__)
    class DeprecatedProps(type):
        pass

    metaclass = deprecated_instance_attrs(names, warn)(DeprecatedProps)

    def wrapper(cls):
        # instance attributes are handled by the instance-level decorator
        cls = deprecated_instance_attrs(names, warn)(cls)

        # apply metaclass by rebuilding the class from its original body
        orig_vars = cls.__dict__.copy()
        slots = orig_vars.get('__slots__')
        if slots is not None:
            if isinstance(slots, str):
                slots = [slots]
            # slot descriptors are recreated by the class machinery;
            # copying the old ones over would shadow the new descriptors
            for slots_var in slots:
                orig_vars.pop(slots_var)
        orig_vars.pop('__dict__', None)
        orig_vars.pop('__weakref__', None)
        return metaclass(cls.__name__, cls.__bases__, orig_vars)
    return wrapper
|
||||
|
||||
def deprecated_method(message):
    """Decorator for deprecating methods.

    :param str message: The message you want to display.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            warning_text = "{0} is a deprecated method. {1}".format(
                func.__name__, message)
            warnings.warn(warning_text, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator
|
||||
@@ -6,6 +6,9 @@ from urllib.parse import urlparse
|
||||
import httpx
|
||||
from mediaflow_proxy.utils.http_utils import create_httpx_client
|
||||
from mediaflow_proxy.configs import settings
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
from urllib.parse import urljoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,12 +26,21 @@ class HLSPreBuffer:
|
||||
max_cache_size (int): Maximum number of segments to cache (uses config if None)
|
||||
prebuffer_segments (int): Number of segments to pre-buffer ahead (uses config if None)
|
||||
"""
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
from urllib.parse import urljoin
|
||||
self.max_cache_size = max_cache_size or settings.hls_prebuffer_cache_size
|
||||
self.prebuffer_segments = prebuffer_segments or settings.hls_prebuffer_segments
|
||||
self.max_memory_percent = settings.hls_prebuffer_max_memory_percent
|
||||
self.emergency_threshold = settings.hls_prebuffer_emergency_threshold
|
||||
self.segment_cache: Dict[str, bytes] = {}
|
||||
# Cache LRU
|
||||
self.segment_cache: "OrderedDict[str, bytes]" = OrderedDict()
|
||||
# Mappa playlist -> lista segmenti
|
||||
self.segment_urls: Dict[str, List[str]] = {}
|
||||
# Mappa inversa segmento -> (playlist_url, index)
|
||||
self.segment_to_playlist: Dict[str, tuple[str, int]] = {}
|
||||
# Stato per playlist: {headers, last_access, refresh_task, target_duration}
|
||||
self.playlist_state: Dict[str, dict] = {}
|
||||
self.client = create_httpx_client()
|
||||
|
||||
async def prebuffer_playlist(self, playlist_url: str, headers: Dict[str, str]) -> None:
|
||||
@@ -41,37 +53,44 @@ class HLSPreBuffer:
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Starting pre-buffer for playlist: {playlist_url}")
|
||||
|
||||
# Download and parse playlist
|
||||
response = await self.client.get(playlist_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
playlist_content = response.text
|
||||
|
||||
# Check if this is a master playlist (contains variants)
|
||||
|
||||
# Se master playlist: prendi la prima variante (fix relativo)
|
||||
if "#EXT-X-STREAM-INF" in playlist_content:
|
||||
logger.debug(f"Master playlist detected, finding first variant")
|
||||
# Extract variant URLs
|
||||
variant_urls = self._extract_variant_urls(playlist_content, playlist_url)
|
||||
if variant_urls:
|
||||
# Pre-buffer the first variant
|
||||
first_variant_url = variant_urls[0]
|
||||
logger.debug(f"Pre-buffering first variant: {first_variant_url}")
|
||||
await self.prebuffer_playlist(first_variant_url, headers)
|
||||
else:
|
||||
logger.warning("No variants found in master playlist")
|
||||
return
|
||||
|
||||
# Extract segment URLs
|
||||
|
||||
# Media playlist: estrai segmenti, salva stato e lancia refresh loop
|
||||
segment_urls = self._extract_segment_urls(playlist_content, playlist_url)
|
||||
|
||||
# Store segment URLs for this playlist
|
||||
self.segment_urls[playlist_url] = segment_urls
|
||||
|
||||
# Pre-buffer first few segments
|
||||
# aggiorna mappa inversa
|
||||
for idx, u in enumerate(segment_urls):
|
||||
self.segment_to_playlist[u] = (playlist_url, idx)
|
||||
|
||||
# prebuffer iniziale
|
||||
await self._prebuffer_segments(segment_urls[:self.prebuffer_segments], headers)
|
||||
|
||||
logger.info(f"Pre-buffered {min(self.prebuffer_segments, len(segment_urls))} segments for {playlist_url}")
|
||||
|
||||
|
||||
# setup refresh loop se non già attivo
|
||||
target_duration = self._parse_target_duration(playlist_content) or 6
|
||||
st = self.playlist_state.get(playlist_url, {})
|
||||
if not st.get("refresh_task") or st["refresh_task"].done():
|
||||
task = asyncio.create_task(self._refresh_playlist_loop(playlist_url, headers, target_duration))
|
||||
self.playlist_state[playlist_url] = {
|
||||
"headers": headers,
|
||||
"last_access": asyncio.get_event_loop().time(),
|
||||
"refresh_task": task,
|
||||
"target_duration": target_duration,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to pre-buffer playlist {playlist_url}: {e}")
|
||||
|
||||
@@ -124,34 +143,24 @@ class HLSPreBuffer:
|
||||
|
||||
def _extract_variant_urls(self, playlist_content: str, base_url: str) -> List[str]:
|
||||
"""
|
||||
Extract variant URLs from master playlist content.
|
||||
|
||||
Args:
|
||||
playlist_content (str): Content of the master playlist
|
||||
base_url (str): Base URL for resolving relative URLs
|
||||
|
||||
Returns:
|
||||
List[str]: List of variant URLs
|
||||
Estrae le varianti dal master playlist. Corretto per gestire URI relativi:
|
||||
prende la riga non-commento successiva a #EXT-X-STREAM-INF e la risolve rispetto a base_url.
|
||||
"""
|
||||
from urllib.parse import urljoin
|
||||
variant_urls = []
|
||||
lines = playlist_content.split('\n')
|
||||
|
||||
lines = [l.strip() for l in playlist_content.split('\n')]
|
||||
take_next_uri = False
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and ('http://' in line or 'https://' in line):
|
||||
# Resolve relative URLs
|
||||
if line.startswith('http'):
|
||||
variant_urls.append(line)
|
||||
else:
|
||||
# Join with base URL for relative paths
|
||||
parsed_base = urlparse(base_url)
|
||||
variant_url = f"{parsed_base.scheme}://{parsed_base.netloc}{line}"
|
||||
variant_urls.append(variant_url)
|
||||
|
||||
if line.startswith("#EXT-X-STREAM-INF"):
|
||||
take_next_uri = True
|
||||
continue
|
||||
if take_next_uri:
|
||||
take_next_uri = False
|
||||
if line and not line.startswith('#'):
|
||||
variant_urls.append(urljoin(base_url, line))
|
||||
logger.debug(f"Extracted {len(variant_urls)} variant URLs from master playlist")
|
||||
if variant_urls:
|
||||
logger.debug(f"First variant URL: {variant_urls[0]}")
|
||||
|
||||
return variant_urls
|
||||
|
||||
async def _prebuffer_segments(self, segment_urls: List[str], headers: Dict[str, str]) -> None:
|
||||
@@ -166,7 +175,6 @@ class HLSPreBuffer:
|
||||
for url in segment_urls:
|
||||
if url not in self.segment_cache:
|
||||
tasks.append(self._download_segment(url, headers))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@@ -196,16 +204,16 @@ class HLSPreBuffer:
|
||||
|
||||
def _emergency_cache_cleanup(self) -> None:
|
||||
"""
|
||||
Perform emergency cache cleanup when memory usage is high.
|
||||
Esegue cleanup LRU rimuovendo il 50% più vecchio.
|
||||
"""
|
||||
if self._check_memory_threshold():
|
||||
logger.warning("Emergency cache cleanup triggered due to high memory usage")
|
||||
# Clear 50% of cache
|
||||
cache_size = len(self.segment_cache)
|
||||
keys_to_remove = list(self.segment_cache.keys())[:cache_size // 2]
|
||||
for key in keys_to_remove:
|
||||
del self.segment_cache[key]
|
||||
logger.info(f"Emergency cleanup removed {len(keys_to_remove)} segments from cache")
|
||||
to_remove = max(1, len(self.segment_cache) // 2)
|
||||
removed = 0
|
||||
while removed < to_remove and self.segment_cache:
|
||||
self.segment_cache.popitem(last=False) # rimuovi LRU
|
||||
removed += 1
|
||||
logger.info(f"Emergency cleanup removed {removed} segments from cache")
|
||||
|
||||
async def _download_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
|
||||
"""
|
||||
@@ -216,29 +224,26 @@ class HLSPreBuffer:
|
||||
headers (Dict[str, str]): Headers to use for request
|
||||
"""
|
||||
try:
|
||||
# Check memory usage before downloading
|
||||
memory_percent = self._get_memory_usage_percent()
|
||||
if memory_percent > self.max_memory_percent:
|
||||
logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
|
||||
return
|
||||
|
||||
|
||||
response = await self.client.get(segment_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
# Cache the segment
|
||||
|
||||
# Cache LRU
|
||||
self.segment_cache[segment_url] = response.content
|
||||
|
||||
# Check for emergency cleanup
|
||||
self.segment_cache.move_to_end(segment_url, last=True)
|
||||
|
||||
if self._check_memory_threshold():
|
||||
self._emergency_cache_cleanup()
|
||||
# Maintain cache size
|
||||
elif len(self.segment_cache) > self.max_cache_size:
|
||||
# Remove oldest entries (simple FIFO)
|
||||
oldest_key = next(iter(self.segment_cache))
|
||||
del self.segment_cache[oldest_key]
|
||||
|
||||
# Evict LRU finché non rientra
|
||||
while len(self.segment_cache) > self.max_cache_size:
|
||||
self.segment_cache.popitem(last=False)
|
||||
|
||||
logger.debug(f"Cached segment: {segment_url}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to download segment {segment_url}: {e}")
|
||||
|
||||
@@ -256,38 +261,64 @@ class HLSPreBuffer:
|
||||
# Check cache first
|
||||
if segment_url in self.segment_cache:
|
||||
logger.debug(f"Cache hit for segment: {segment_url}")
|
||||
return self.segment_cache[segment_url]
|
||||
|
||||
# Check memory usage before downloading
|
||||
# LRU touch
|
||||
data = self.segment_cache[segment_url]
|
||||
self.segment_cache.move_to_end(segment_url, last=True)
|
||||
# aggiorna last_access per la playlist se mappata
|
||||
pl = self.segment_to_playlist.get(segment_url)
|
||||
if pl:
|
||||
st = self.playlist_state.get(pl[0])
|
||||
if st:
|
||||
st["last_access"] = asyncio.get_event_loop().time()
|
||||
return data
|
||||
|
||||
memory_percent = self._get_memory_usage_percent()
|
||||
if memory_percent > self.max_memory_percent:
|
||||
logger.warning(f"Memory usage {memory_percent}% exceeds limit {self.max_memory_percent}%, skipping download")
|
||||
return None
|
||||
|
||||
# Download if not in cache
|
||||
|
||||
try:
|
||||
response = await self.client.get(segment_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
segment_data = response.content
|
||||
|
||||
# Cache the segment
|
||||
|
||||
# Cache LRU
|
||||
self.segment_cache[segment_url] = segment_data
|
||||
|
||||
# Check for emergency cleanup
|
||||
self.segment_cache.move_to_end(segment_url, last=True)
|
||||
|
||||
if self._check_memory_threshold():
|
||||
self._emergency_cache_cleanup()
|
||||
# Maintain cache size
|
||||
elif len(self.segment_cache) > self.max_cache_size:
|
||||
oldest_key = next(iter(self.segment_cache))
|
||||
del self.segment_cache[oldest_key]
|
||||
|
||||
while len(self.segment_cache) > self.max_cache_size:
|
||||
self.segment_cache.popitem(last=False)
|
||||
|
||||
# aggiorna last_access per playlist
|
||||
pl = self.segment_to_playlist.get(segment_url)
|
||||
if pl:
|
||||
st = self.playlist_state.get(pl[0])
|
||||
if st:
|
||||
st["last_access"] = asyncio.get_event_loop().time()
|
||||
|
||||
logger.debug(f"Downloaded and cached segment: {segment_url}")
|
||||
return segment_data
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get segment {segment_url}: {e}")
|
||||
return None
|
||||
|
||||
async def prebuffer_from_segment(self, segment_url: str, headers: Dict[str, str]) -> None:
|
||||
"""
|
||||
Dato un URL di segmento, prebuffer i successivi in base alla playlist e all'indice mappato.
|
||||
"""
|
||||
mapped = self.segment_to_playlist.get(segment_url)
|
||||
if not mapped:
|
||||
return
|
||||
playlist_url, idx = mapped
|
||||
# aggiorna access time
|
||||
st = self.playlist_state.get(playlist_url)
|
||||
if st:
|
||||
st["last_access"] = asyncio.get_event_loop().time()
|
||||
await self.prebuffer_next_segments(playlist_url, idx, headers)
|
||||
|
||||
async def prebuffer_next_segments(self, playlist_url: str, current_segment_index: int, headers: Dict[str, str]) -> None:
|
||||
"""
|
||||
Pre-buffer next segments based on current playback position.
|
||||
@@ -299,10 +330,8 @@ class HLSPreBuffer:
|
||||
"""
|
||||
if playlist_url not in self.segment_urls:
|
||||
return
|
||||
|
||||
segment_urls = self.segment_urls[playlist_url]
|
||||
next_segments = segment_urls[current_segment_index + 1:current_segment_index + 1 + self.prebuffer_segments]
|
||||
|
||||
if next_segments:
|
||||
await self._prebuffer_segments(next_segments, headers)
|
||||
|
||||
@@ -310,6 +339,8 @@ class HLSPreBuffer:
|
||||
"""Clear the segment cache."""
|
||||
self.segment_cache.clear()
|
||||
self.segment_urls.clear()
|
||||
self.segment_to_playlist.clear()
|
||||
self.playlist_state.clear()
|
||||
logger.info("HLS pre-buffer cache cleared")
|
||||
|
||||
async def close(self) -> None:
|
||||
@@ -318,4 +349,142 @@ class HLSPreBuffer:
|
||||
|
||||
|
||||
# Global pre-buffer instance
|
||||
hls_prebuffer = HLSPreBuffer()
|
||||
hls_prebuffer = HLSPreBuffer()
|
||||
|
||||
|
||||
class HLSPreBuffer:
|
||||
def _parse_target_duration(self, playlist_content: str) -> Optional[int]:
|
||||
"""
|
||||
Parse EXT-X-TARGETDURATION from a media playlist and return duration in seconds.
|
||||
Returns None if not present or unparsable.
|
||||
"""
|
||||
for line in playlist_content.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("#EXT-X-TARGETDURATION:"):
|
||||
try:
|
||||
value = line.split(":", 1)[1].strip()
|
||||
return int(float(value))
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def _refresh_playlist_loop(self, playlist_url: str, headers: Dict[str, str], target_duration: int) -> None:
|
||||
"""
|
||||
Aggiorna periodicamente la playlist per seguire la sliding window e mantenere la cache coerente.
|
||||
Interrompe e pulisce dopo inattività prolungata.
|
||||
"""
|
||||
sleep_s = max(2, min(15, int(target_duration)))
|
||||
inactivity_timeout = 600 # 10 minuti
|
||||
while True:
|
||||
try:
|
||||
st = self.playlist_state.get(playlist_url)
|
||||
now = asyncio.get_event_loop().time()
|
||||
if not st:
|
||||
return
|
||||
if now - st.get("last_access", now) > inactivity_timeout:
|
||||
# cleanup specifico della playlist
|
||||
urls = set(self.segment_urls.get(playlist_url, []))
|
||||
if urls:
|
||||
# rimuovi dalla cache solo i segmenti di questa playlist
|
||||
for u in list(self.segment_cache.keys()):
|
||||
if u in urls:
|
||||
self.segment_cache.pop(u, None)
|
||||
# rimuovi mapping
|
||||
for u in urls:
|
||||
self.segment_to_playlist.pop(u, None)
|
||||
self.segment_urls.pop(playlist_url, None)
|
||||
self.playlist_state.pop(playlist_url, None)
|
||||
logger.info(f"Stopped HLS prebuffer for inactive playlist: {playlist_url}")
|
||||
return
|
||||
|
||||
# refresh manifest
|
||||
resp = await self.client.get(playlist_url, headers=headers)
|
||||
resp.raise_for_status()
|
||||
content = resp.text
|
||||
new_target = self._parse_target_duration(content)
|
||||
if new_target:
|
||||
sleep_s = max(2, min(15, int(new_target)))
|
||||
|
||||
new_urls = self._extract_segment_urls(content, playlist_url)
|
||||
if new_urls:
|
||||
self.segment_urls[playlist_url] = new_urls
|
||||
# rebuild reverse map per gli ultimi N (limita la memoria)
|
||||
for idx, u in enumerate(new_urls[-(self.max_cache_size * 2):]):
|
||||
# rimappiando sovrascrivi eventuali entry
|
||||
real_idx = len(new_urls) - (self.max_cache_size * 2) + idx if len(new_urls) > (self.max_cache_size * 2) else idx
|
||||
self.segment_to_playlist[u] = (playlist_url, real_idx)
|
||||
|
||||
# tenta un prebuffer proattivo: se conosciamo l'ultimo segmento accessibile, anticipa i successivi
|
||||
# Non conosciamo l'indice di riproduzione corrente qui, quindi non facciamo nulla di aggressivo.
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Playlist refresh error for {playlist_url}: {e}")
|
||||
await asyncio.sleep(sleep_s)
|
||||
def _extract_segment_urls(self, playlist_content: str, base_url: str) -> List[str]:
|
||||
"""
|
||||
Extract segment URLs from HLS playlist content.
|
||||
|
||||
Args:
|
||||
playlist_content (str): Content of the HLS playlist
|
||||
base_url (str): Base URL for resolving relative URLs
|
||||
|
||||
Returns:
|
||||
List[str]: List of segment URLs
|
||||
"""
|
||||
segment_urls = []
|
||||
lines = playlist_content.split('\n')
|
||||
|
||||
logger.debug(f"Analyzing playlist with {len(lines)} lines")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
# Check if line contains a URL (http/https) or is a relative path
|
||||
if 'http://' in line or 'https://' in line:
|
||||
segment_urls.append(line)
|
||||
logger.debug(f"Found absolute URL: {line}")
|
||||
elif line and not line.startswith('#'):
|
||||
# This might be a relative path to a segment
|
||||
parsed_base = urlparse(base_url)
|
||||
# Ensure proper path joining
|
||||
if line.startswith('/'):
|
||||
segment_url = f"{parsed_base.scheme}://{parsed_base.netloc}{line}"
|
||||
else:
|
||||
# Get the directory path from base_url
|
||||
base_path = parsed_base.path.rsplit('/', 1)[0] if '/' in parsed_base.path else ''
|
||||
segment_url = f"{parsed_base.scheme}://{parsed_base.netloc}{base_path}/{line}"
|
||||
segment_urls.append(segment_url)
|
||||
logger.debug(f"Found relative path: {line} -> {segment_url}")
|
||||
|
||||
logger.debug(f"Extracted {len(segment_urls)} segment URLs from playlist")
|
||||
if segment_urls:
|
||||
logger.debug(f"First segment URL: {segment_urls[0]}")
|
||||
else:
|
||||
logger.debug("No segment URLs found in playlist")
|
||||
# Log first few lines for debugging
|
||||
for i, line in enumerate(lines[:10]):
|
||||
logger.debug(f"Line {i}: {line}")
|
||||
|
||||
return segment_urls
|
||||
|
||||
def _extract_variant_urls(self, playlist_content: str, base_url: str) -> List[str]:
|
||||
"""
|
||||
Estrae le varianti dal master playlist. Corretto per gestire URI relativi:
|
||||
prende la riga non-commento successiva a #EXT-X-STREAM-INF e la risolve rispetto a base_url.
|
||||
"""
|
||||
from urllib.parse import urljoin
|
||||
variant_urls = []
|
||||
lines = [l.strip() for l in playlist_content.split('\n')]
|
||||
take_next_uri = False
|
||||
for line in lines:
|
||||
if line.startswith("#EXT-X-STREAM-INF"):
|
||||
take_next_uri = True
|
||||
continue
|
||||
if take_next_uri:
|
||||
take_next_uri = False
|
||||
if line and not line.startswith('#'):
|
||||
variant_urls.append(urljoin(base_url, line))
|
||||
logger.debug(f"Extracted {len(variant_urls)} variant URLs from master playlist")
|
||||
if variant_urls:
|
||||
logger.debug(f"First variant URL: {variant_urls[0]}")
|
||||
return variant_urls
|
||||
54
mediaflow_proxy/utils/hls_utils.py
Normal file
54
mediaflow_proxy/utils/hls_utils.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from urllib.parse import urljoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_hls_playlist(playlist_content: str, base_url: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parses an HLS master playlist to extract stream information.
|
||||
|
||||
Args:
|
||||
playlist_content (str): The content of the M3U8 master playlist.
|
||||
base_url (str, optional): The base URL of the playlist for resolving relative stream URLs. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries, each representing a stream variant.
|
||||
"""
|
||||
streams = []
|
||||
lines = playlist_content.strip().split('\n')
|
||||
|
||||
# Regex to capture attributes from #EXT-X-STREAM-INF
|
||||
stream_inf_pattern = re.compile(r'#EXT-X-STREAM-INF:(.*)')
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith('#EXT-X-STREAM-INF'):
|
||||
stream_info = {'raw_stream_inf': line}
|
||||
match = stream_inf_pattern.match(line)
|
||||
if not match:
|
||||
logger.warning(f"Could not parse #EXT-X-STREAM-INF line: {line}")
|
||||
continue
|
||||
attributes_str = match.group(1)
|
||||
|
||||
# Parse attributes like BANDWIDTH, RESOLUTION, etc.
|
||||
attributes = re.findall(r'([A-Z-]+)=("([^"]+)"|([^,]+))', attributes_str)
|
||||
for key, _, quoted_val, unquoted_val in attributes:
|
||||
value = quoted_val if quoted_val else unquoted_val
|
||||
if key == 'RESOLUTION':
|
||||
try:
|
||||
width, height = map(int, value.split('x'))
|
||||
stream_info['resolution'] = (width, height)
|
||||
except ValueError:
|
||||
stream_info['resolution'] = (0, 0)
|
||||
else:
|
||||
stream_info[key.lower().replace('-', '_')] = value
|
||||
|
||||
# The next line should be the stream URL
|
||||
if i + 1 < len(lines) and not lines[i + 1].startswith('#'):
|
||||
stream_url = lines[i + 1].strip()
|
||||
stream_info['url'] = urljoin(base_url, stream_url) if base_url else stream_url
|
||||
streams.append(stream_info)
|
||||
|
||||
return streams
|
||||
@@ -3,7 +3,7 @@ import typing
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from urllib import parse
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import anyio
|
||||
import h11
|
||||
@@ -81,6 +81,10 @@ async def fetch_with_retry(client, method, url, headers, follow_redirects=True,
|
||||
|
||||
|
||||
class Streamer:
|
||||
# PNG signature and IEND marker for fake PNG header detection (StreamWish/FileMoon)
|
||||
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
|
||||
_PNG_IEND_MARKER = b"\x49\x45\x4E\x44\xAE\x42\x60\x82"
|
||||
|
||||
def __init__(self, client):
|
||||
"""
|
||||
Initializes the Streamer with an HTTP client.
|
||||
@@ -132,13 +136,48 @@ class Streamer:
|
||||
logger.error(f"Error creating streaming response: {e}")
|
||||
raise RuntimeError(f"Error creating streaming response: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _strip_fake_png_wrapper(chunk: bytes) -> bytes:
|
||||
"""
|
||||
Strip fake PNG wrapper from chunk data.
|
||||
|
||||
Some streaming services (StreamWish, FileMoon) prepend a fake PNG image
|
||||
to video data to evade detection. This method detects and removes it.
|
||||
|
||||
Args:
|
||||
chunk: The raw chunk data that may contain a fake PNG header.
|
||||
|
||||
Returns:
|
||||
The chunk with fake PNG wrapper removed, or original chunk if not present.
|
||||
"""
|
||||
if not chunk.startswith(Streamer._PNG_SIGNATURE):
|
||||
return chunk
|
||||
|
||||
# Find the IEND marker that signals end of PNG data
|
||||
iend_pos = chunk.find(Streamer._PNG_IEND_MARKER)
|
||||
if iend_pos == -1:
|
||||
# IEND not found in this chunk - return as-is to avoid data corruption
|
||||
logger.debug("PNG signature detected but IEND marker not found in chunk")
|
||||
return chunk
|
||||
|
||||
# Calculate position after IEND marker
|
||||
content_start = iend_pos + len(Streamer._PNG_IEND_MARKER)
|
||||
|
||||
# Skip any padding bytes (null or 0xFF) between PNG and actual content
|
||||
while content_start < len(chunk) and chunk[content_start] in (0x00, 0xFF):
|
||||
content_start += 1
|
||||
|
||||
stripped_bytes = content_start
|
||||
logger.debug(f"Stripped {stripped_bytes} bytes of fake PNG wrapper from stream")
|
||||
|
||||
return chunk[content_start:]
|
||||
|
||||
async def stream_content(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
Streams the content from the response.
|
||||
"""
|
||||
if not self.response:
|
||||
raise RuntimeError("No response available for streaming")
|
||||
|
||||
is_first_chunk = True
|
||||
|
||||
try:
|
||||
self.parse_content_range()
|
||||
|
||||
@@ -154,15 +193,19 @@ class Streamer:
|
||||
mininterval=1,
|
||||
) as self.progress_bar:
|
||||
async for chunk in self.response.aiter_bytes():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
chunk = self._strip_fake_png_wrapper(chunk)
|
||||
|
||||
yield chunk
|
||||
chunk_size = len(chunk)
|
||||
self.bytes_transferred += chunk_size
|
||||
self.progress_bar.set_postfix_str(
|
||||
f"📥 : {self.format_bytes(self.bytes_transferred)}", refresh=False
|
||||
)
|
||||
self.progress_bar.update(chunk_size)
|
||||
self.bytes_transferred += len(chunk)
|
||||
self.progress_bar.update(len(chunk))
|
||||
else:
|
||||
async for chunk in self.response.aiter_bytes():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
chunk = self._strip_fake_png_wrapper(chunk)
|
||||
|
||||
yield chunk
|
||||
self.bytes_transferred += len(chunk)
|
||||
|
||||
@@ -187,10 +230,19 @@ class Streamer:
|
||||
raise DownloadError(502, f"Protocol error while streaming: {e}")
|
||||
except GeneratorExit:
|
||||
logger.info("Streaming session stopped by the user")
|
||||
except httpx.ReadError as e:
|
||||
# Handle network read errors gracefully - these occur when upstream connection drops
|
||||
logger.warning(f"ReadError while streaming: {e}")
|
||||
if self.bytes_transferred > 0:
|
||||
logger.info(f"Partial content received ({self.bytes_transferred} bytes) before ReadError. Graceful termination.")
|
||||
return
|
||||
else:
|
||||
raise DownloadError(502, f"ReadError while streaming: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming content: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def format_bytes(size) -> str:
|
||||
power = 2**10
|
||||
@@ -490,6 +542,23 @@ def get_proxy_headers(request: Request) -> ProxyRequestHeaders:
|
||||
"""
|
||||
request_headers = {k: v for k, v in request.headers.items() if k in SUPPORTED_REQUEST_HEADERS}
|
||||
request_headers.update({k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("h_")})
|
||||
request_headers.setdefault("user-agent", settings.user_agent)
|
||||
|
||||
# Handle common misspelling of referer
|
||||
if "referrer" in request_headers:
|
||||
if "referer" not in request_headers:
|
||||
request_headers["referer"] = request_headers.pop("referrer")
|
||||
|
||||
dest = request.query_params.get("d", "")
|
||||
host = urlparse(dest).netloc.lower()
|
||||
|
||||
if "vidoza" in host or "videzz" in host:
|
||||
# Remove ALL empty headers
|
||||
for h in list(request_headers.keys()):
|
||||
v = request_headers[h]
|
||||
if v is None or v.strip() == "":
|
||||
request_headers.pop(h, None)
|
||||
|
||||
response_headers = {k[2:].lower(): v for k, v in request.query_params.items() if k.startswith("r_")}
|
||||
return ProxyRequestHeaders(request_headers, response_headers)
|
||||
|
||||
@@ -527,21 +596,14 @@ class EnhancedStreamingResponse(Response):
|
||||
logger.error(f"Error in listen_for_disconnect: {str(e)}")
|
||||
|
||||
async def stream_response(self, send: Send) -> None:
|
||||
# Track if response headers have been sent to prevent duplicate headers
|
||||
response_started = False
|
||||
# Track if response finalization (more_body: False) has been sent to prevent ASGI protocol violation
|
||||
finalization_sent = False
|
||||
try:
|
||||
# Initialize headers
|
||||
headers = list(self.raw_headers)
|
||||
|
||||
# Set the transfer-encoding to chunked for streamed responses with content-length
|
||||
# when content-length is present. This ensures we don't hit protocol errors
|
||||
# if the upstream connection is closed prematurely.
|
||||
for i, (name, _) in enumerate(headers):
|
||||
if name.lower() == b"content-length":
|
||||
# Replace content-length with transfer-encoding: chunked for streaming
|
||||
headers[i] = (b"transfer-encoding", b"chunked")
|
||||
headers = [h for h in headers if h[0].lower() != b"content-length"]
|
||||
logger.debug("Switched from content-length to chunked transfer-encoding for streaming")
|
||||
break
|
||||
|
||||
# Start the response
|
||||
await send(
|
||||
{
|
||||
@@ -550,6 +612,7 @@ class EnhancedStreamingResponse(Response):
|
||||
"headers": headers,
|
||||
}
|
||||
)
|
||||
response_started = True
|
||||
|
||||
# Track if we've sent any data
|
||||
data_sent = False
|
||||
@@ -568,27 +631,29 @@ class EnhancedStreamingResponse(Response):
|
||||
|
||||
# Successfully streamed all content
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
except (httpx.RemoteProtocolError, h11._util.LocalProtocolError) as e:
|
||||
# Handle connection closed errors
|
||||
finalization_sent = True
|
||||
except (httpx.RemoteProtocolError, httpx.ReadError, h11._util.LocalProtocolError) as e:
|
||||
# Handle connection closed / read errors gracefully
|
||||
if data_sent:
|
||||
# We've sent some data to the client, so try to complete the response
|
||||
logger.warning(f"Remote protocol error after partial streaming: {e}")
|
||||
logger.warning(f"Upstream connection error after partial streaming: {e}")
|
||||
try:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
finalization_sent = True
|
||||
logger.info(
|
||||
f"Response finalized after partial content ({self.actual_content_length} bytes transferred)"
|
||||
)
|
||||
except Exception as close_err:
|
||||
logger.warning(f"Could not finalize response after remote error: {close_err}")
|
||||
logger.warning(f"Could not finalize response after upstream error: {close_err}")
|
||||
else:
|
||||
# No data was sent, re-raise the error
|
||||
logger.error(f"Protocol error before any data was streamed: {e}")
|
||||
logger.error(f"Upstream error before any data was streamed: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in stream_response: {str(e)}")
|
||||
if not isinstance(e, (ConnectionResetError, anyio.BrokenResourceError)):
|
||||
if not isinstance(e, (ConnectionResetError, anyio.BrokenResourceError)) and not response_started:
|
||||
# Only attempt to send error response if headers haven't been sent yet
|
||||
try:
|
||||
# Try to send an error response if client is still connected
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
@@ -598,36 +663,39 @@ class EnhancedStreamingResponse(Response):
|
||||
)
|
||||
error_message = f"Streaming error: {str(e)}".encode("utf-8")
|
||||
await send({"type": "http.response.body", "body": error_message, "more_body": False})
|
||||
finalization_sent = True
|
||||
except Exception:
|
||||
# If we can't send an error response, just log it
|
||||
pass
|
||||
elif response_started and not finalization_sent:
|
||||
# Response already started but not finalized - gracefully close the stream
|
||||
try:
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
finalization_sent = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
async with anyio.create_task_group() as task_group:
|
||||
streaming_completed = False
|
||||
stream_func = partial(self.stream_response, send)
|
||||
listen_func = partial(self.listen_for_disconnect, receive)
|
||||
|
||||
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
|
||||
try:
|
||||
await func()
|
||||
# If this is the stream_response function and it completes successfully, mark as done
|
||||
if func == stream_func:
|
||||
nonlocal streaming_completed
|
||||
streaming_completed = True
|
||||
except Exception as e:
|
||||
if isinstance(e, (httpx.RemoteProtocolError, h11._util.LocalProtocolError)):
|
||||
# Handle protocol errors more gracefully
|
||||
logger.warning(f"Protocol error during streaming: {e}")
|
||||
elif not isinstance(e, anyio.get_cancelled_exc_class()):
|
||||
logger.exception("Error in streaming task")
|
||||
# Only re-raise if it's not a protocol error or cancellation
|
||||
# Note: stream_response and listen_for_disconnect handle their own exceptions
|
||||
# internally. This is a safety net for any unexpected exceptions that might
|
||||
# escape due to future code changes.
|
||||
if not isinstance(e, anyio.get_cancelled_exc_class()):
|
||||
logger.exception(f"Unexpected error in streaming task: {type(e).__name__}: {e}")
|
||||
# Re-raise unexpected errors to surface bugs rather than silently swallowing them
|
||||
raise
|
||||
finally:
|
||||
# Only cancel the task group if we're in disconnect listener or
|
||||
# if streaming_completed is True (meaning we finished normally)
|
||||
if func == listen_func or streaming_completed:
|
||||
task_group.cancel_scope.cancel()
|
||||
# Cancel task group when either task completes or fails:
|
||||
# - stream_func finished (success or failure) -> stop listening for disconnect
|
||||
# - listen_func finished (client disconnected) -> stop streaming
|
||||
task_group.cancel_scope.cancel()
|
||||
|
||||
# Start the streaming response in a separate task
|
||||
task_group.start_soon(wrap, stream_func)
|
||||
|
||||
@@ -11,7 +11,7 @@ from mediaflow_proxy.utils.hls_prebuffer import hls_prebuffer
|
||||
|
||||
|
||||
class M3U8Processor:
|
||||
def __init__(self, request, key_url: str = None, force_playlist_proxy: bool = None):
|
||||
def __init__(self, request, key_url: str = None, force_playlist_proxy: bool = None, key_only_proxy: bool = False, no_proxy: bool = False):
|
||||
"""
|
||||
Initializes the M3U8Processor with the request and URL prefix.
|
||||
|
||||
@@ -19,9 +19,13 @@ class M3U8Processor:
|
||||
request (Request): The incoming HTTP request.
|
||||
key_url (HttpUrl, optional): The URL of the key server. Defaults to None.
|
||||
force_playlist_proxy (bool, optional): Force all playlist URLs to be proxied through MediaFlow. Defaults to None.
|
||||
key_only_proxy (bool, optional): Only proxy the key URL, leaving segment URLs direct. Defaults to False.
|
||||
no_proxy (bool, optional): If True, returns the manifest without proxying any URLs. Defaults to False.
|
||||
"""
|
||||
self.request = request
|
||||
self.key_url = parse.urlparse(key_url) if key_url else None
|
||||
self.key_only_proxy = key_only_proxy
|
||||
self.no_proxy = no_proxy
|
||||
self.force_playlist_proxy = force_playlist_proxy
|
||||
self.mediaflow_proxy_url = str(
|
||||
request.url_for("hls_manifest_proxy").replace(scheme=get_original_scheme(request))
|
||||
@@ -174,6 +178,15 @@ class M3U8Processor:
|
||||
Returns:
|
||||
str: The processed key line.
|
||||
"""
|
||||
# If no_proxy is enabled, just resolve relative URLs without proxying
|
||||
if self.no_proxy:
|
||||
uri_match = re.search(r'URI="([^"]+)"', line)
|
||||
if uri_match:
|
||||
original_uri = uri_match.group(1)
|
||||
full_url = parse.urljoin(base_url, original_uri)
|
||||
line = line.replace(f'URI="{original_uri}"', f'URI="{full_url}"')
|
||||
return line
|
||||
|
||||
uri_match = re.search(r'URI="([^"]+)"', line)
|
||||
if uri_match:
|
||||
original_uri = uri_match.group(1)
|
||||
@@ -197,6 +210,14 @@ class M3U8Processor:
|
||||
"""
|
||||
full_url = parse.urljoin(base_url, url)
|
||||
|
||||
# If no_proxy is enabled, return the direct URL without any proxying
|
||||
if self.no_proxy:
|
||||
return full_url
|
||||
|
||||
# If key_only_proxy is enabled, return the direct URL for segments
|
||||
if self.key_only_proxy and not url.endswith((".m3u", ".m3u8")):
|
||||
return full_url
|
||||
|
||||
# Determine routing strategy based on configuration
|
||||
routing_strategy = settings.m3u8_content_routing
|
||||
|
||||
|
||||
@@ -270,7 +270,7 @@ def parse_representation(
|
||||
if item:
|
||||
profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
|
||||
else:
|
||||
profile["segments"] = parse_segment_base(representation, source)
|
||||
profile["segments"] = parse_segment_base(representation, profile, source)
|
||||
|
||||
return profile
|
||||
|
||||
@@ -547,7 +547,7 @@ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, t
|
||||
return segment_data
|
||||
|
||||
|
||||
def parse_segment_base(representation: dict, source: str) -> List[Dict]:
|
||||
def parse_segment_base(representation: dict, profile: dict, source: str) -> List[Dict]:
|
||||
"""
|
||||
Parses segment base information and extracts segment data. This is used for single-segment representations.
|
||||
|
||||
@@ -562,6 +562,12 @@ def parse_segment_base(representation: dict, source: str) -> List[Dict]:
|
||||
start, end = map(int, segment["@indexRange"].split("-"))
|
||||
if "Initialization" in segment:
|
||||
start, _ = map(int, segment["Initialization"]["@range"].split("-"))
|
||||
|
||||
# Set initUrl for SegmentBase
|
||||
if not representation['BaseURL'].startswith("http"):
|
||||
profile["initUrl"] = f"{source}/{representation['BaseURL']}"
|
||||
else:
|
||||
profile["initUrl"] = representation['BaseURL']
|
||||
|
||||
return [
|
||||
{
|
||||
|
||||
122
mediaflow_proxy/utils/python_aes.py
Normal file
122
mediaflow_proxy/utils/python_aes.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Author: Trevor Perrin
|
||||
# See the LICENSE file for legal information regarding use of this file.
|
||||
|
||||
"""Pure-Python AES implementation."""
|
||||
|
||||
import sys
|
||||
from .aes import AES
|
||||
from .rijndael import Rijndael
|
||||
from .cryptomath import bytesToNumber, numberToByteArray
|
||||
|
||||
__all__ = ['new', 'Python_AES']
|
||||
|
||||
|
||||
def new(key, mode, IV):
|
||||
# IV argument name is a part of the interface
|
||||
# pylint: disable=invalid-name
|
||||
if mode == 2:
|
||||
return Python_AES(key, mode, IV)
|
||||
elif mode == 6:
|
||||
return Python_AES_CTR(key, mode, IV)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Python_AES(AES):
|
||||
def __init__(self, key, mode, IV):
|
||||
# IV argument/field names are a part of the interface
|
||||
# pylint: disable=invalid-name
|
||||
key, IV = bytearray(key), bytearray(IV)
|
||||
super(Python_AES, self).__init__(key, mode, IV, "python")
|
||||
self.rijndael = Rijndael(key, 16)
|
||||
self.IV = IV
|
||||
|
||||
def encrypt(self, plaintext):
|
||||
super(Python_AES, self).encrypt(plaintext)
|
||||
|
||||
plaintextBytes = bytearray(plaintext)
|
||||
chainBytes = self.IV[:]
|
||||
|
||||
#CBC Mode: For each block...
|
||||
for x in range(len(plaintextBytes)//16):
|
||||
|
||||
#XOR with the chaining block
|
||||
blockBytes = plaintextBytes[x*16 : (x*16)+16]
|
||||
for y in range(16):
|
||||
blockBytes[y] ^= chainBytes[y]
|
||||
|
||||
#Encrypt it
|
||||
encryptedBytes = self.rijndael.encrypt(blockBytes)
|
||||
|
||||
#Overwrite the input with the output
|
||||
for y in range(16):
|
||||
plaintextBytes[(x*16)+y] = encryptedBytes[y]
|
||||
|
||||
#Set the next chaining block
|
||||
chainBytes = encryptedBytes
|
||||
|
||||
self.IV = chainBytes[:]
|
||||
return plaintextBytes
|
||||
|
||||
def decrypt(self, ciphertext):
|
||||
super(Python_AES, self).decrypt(ciphertext)
|
||||
|
||||
ciphertextBytes = ciphertext[:]
|
||||
chainBytes = self.IV[:]
|
||||
|
||||
#CBC Mode: For each block...
|
||||
for x in range(len(ciphertextBytes)//16):
|
||||
|
||||
#Decrypt it
|
||||
blockBytes = ciphertextBytes[x*16 : (x*16)+16]
|
||||
decryptedBytes = self.rijndael.decrypt(blockBytes)
|
||||
|
||||
#XOR with the chaining block and overwrite the input with output
|
||||
for y in range(16):
|
||||
decryptedBytes[y] ^= chainBytes[y]
|
||||
ciphertextBytes[(x*16)+y] = decryptedBytes[y]
|
||||
|
||||
#Set the next chaining block
|
||||
chainBytes = blockBytes
|
||||
|
||||
self.IV = chainBytes[:]
|
||||
return ciphertextBytes
|
||||
|
||||
|
||||
class Python_AES_CTR(AES):
|
||||
def __init__(self, key, mode, IV):
|
||||
super(Python_AES_CTR, self).__init__(key, mode, IV, "python")
|
||||
self.rijndael = Rijndael(key, 16)
|
||||
self.IV = IV
|
||||
self._counter_bytes = 16 - len(self.IV)
|
||||
self._counter = self.IV + bytearray(b'\x00' * self._counter_bytes)
|
||||
|
||||
@property
|
||||
def counter(self):
|
||||
return self._counter
|
||||
|
||||
@counter.setter
|
||||
def counter(self, ctr):
|
||||
self._counter = ctr
|
||||
|
||||
def _counter_update(self):
|
||||
counter_int = bytesToNumber(self._counter) + 1
|
||||
self._counter = numberToByteArray(counter_int, 16)
|
||||
if self._counter_bytes > 0 and \
|
||||
self._counter[-self._counter_bytes:] == \
|
||||
bytearray(b'\xff' * self._counter_bytes):
|
||||
raise OverflowError("CTR counter overflowed")
|
||||
|
||||
def encrypt(self, plaintext):
|
||||
mask = bytearray()
|
||||
while len(mask) < len(plaintext):
|
||||
mask += self.rijndael.encrypt(self._counter)
|
||||
self._counter_update()
|
||||
if sys.version_info < (3, 0):
|
||||
inp_bytes = bytearray(ord(i) ^ j for i, j in zip(plaintext, mask))
|
||||
else:
|
||||
inp_bytes = bytearray(i ^ j for i, j in zip(plaintext, mask))
|
||||
return inp_bytes
|
||||
|
||||
def decrypt(self, ciphertext):
|
||||
return self.encrypt(ciphertext)
|
||||
12
mediaflow_proxy/utils/python_aesgcm.py
Normal file
12
mediaflow_proxy/utils/python_aesgcm.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# mediaflow_proxy/utils/python_aesgcm.py
|
||||
|
||||
from .aesgcm import AESGCM
|
||||
from .rijndael import Rijndael
|
||||
|
||||
|
||||
def new(key: bytes) -> AESGCM:
|
||||
"""
|
||||
Mirror ResolveURL's python_aesgcm.new(key) API:
|
||||
returns an AESGCM instance with pure-Python Rijndael backend.
|
||||
"""
|
||||
return AESGCM(key, "python", Rijndael(key, 16).encrypt)
|
||||
1118
mediaflow_proxy/utils/rijndael.py
Normal file
1118
mediaflow_proxy/utils/rijndael.py
Normal file
File diff suppressed because it is too large
Load Diff
32
mediaflow_proxy/utils/tlshashlib.py
Normal file
32
mediaflow_proxy/utils/tlshashlib.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Author: Hubert Kario (c) 2015
|
||||
# see LICENCE file for legal information regarding use of this file
|
||||
|
||||
"""hashlib that handles FIPS mode."""
|
||||
|
||||
# Because we are extending the hashlib module, we need to import all its
|
||||
# fields to suppport the same uses
|
||||
# pylint: disable=unused-wildcard-import, wildcard-import
|
||||
from hashlib import *
|
||||
# pylint: enable=unused-wildcard-import, wildcard-import
|
||||
import hashlib
|
||||
|
||||
|
||||
def _fipsFunction(func, *args, **kwargs):
|
||||
"""Make hash function support FIPS mode."""
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except ValueError:
|
||||
return func(*args, usedforsecurity=False, **kwargs)
|
||||
|
||||
|
||||
# redefining the function is exactly what we intend to do
|
||||
# pylint: disable=function-redefined
|
||||
def md5(*args, **kwargs):
|
||||
"""MD5 constructor that works in FIPS mode."""
|
||||
return _fipsFunction(hashlib.md5, *args, **kwargs)
|
||||
|
||||
|
||||
def new(*args, **kwargs):
|
||||
"""General constructor that works in FIPS mode."""
|
||||
return _fipsFunction(hashlib.new, *args, **kwargs)
|
||||
# pylint: enable=function-redefined
|
||||
88
mediaflow_proxy/utils/tlshmac.py
Normal file
88
mediaflow_proxy/utils/tlshmac.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Author: Hubert Kario (c) 2019
|
||||
# see LICENCE file for legal information regarding use of this file
|
||||
|
||||
"""
|
||||
HMAC module that works in FIPS mode.
|
||||
|
||||
Note that this makes this code FIPS non-compliant!
|
||||
"""
|
||||
|
||||
# Because we are extending the hashlib module, we need to import all its
|
||||
# fields to suppport the same uses
|
||||
from . import tlshashlib
|
||||
from .compat import compatHMAC
|
||||
try:
|
||||
from hmac import compare_digest
|
||||
__all__ = ["new", "compare_digest", "HMAC"]
|
||||
except ImportError:
|
||||
__all__ = ["new", "HMAC"]
|
||||
|
||||
try:
|
||||
from hmac import HMAC, new
|
||||
# if we can calculate HMAC on MD5, then use the built-in HMAC
|
||||
# implementation
|
||||
_val = HMAC(b'some key', b'msg', 'md5')
|
||||
_val.digest()
|
||||
del _val
|
||||
except Exception:
|
||||
# fallback only when MD5 doesn't work
|
||||
class HMAC(object):
|
||||
"""Hacked version of HMAC that works in FIPS mode even with MD5."""
|
||||
|
||||
def __init__(self, key, msg=None, digestmod=None):
|
||||
"""
|
||||
Initialise the HMAC and hash first portion of data.
|
||||
|
||||
msg: data to hash
|
||||
digestmod: name of hash or object that be used as a hash and be cloned
|
||||
"""
|
||||
self.key = key
|
||||
if digestmod is None:
|
||||
digestmod = 'md5'
|
||||
if callable(digestmod):
|
||||
digestmod = digestmod()
|
||||
if not hasattr(digestmod, 'digest_size'):
|
||||
digestmod = tlshashlib.new(digestmod)
|
||||
self.block_size = digestmod.block_size
|
||||
self.digest_size = digestmod.digest_size
|
||||
self.digestmod = digestmod
|
||||
if len(key) > self.block_size:
|
||||
k_hash = digestmod.copy()
|
||||
k_hash.update(compatHMAC(key))
|
||||
key = k_hash.digest()
|
||||
if len(key) < self.block_size:
|
||||
key = key + b'\x00' * (self.block_size - len(key))
|
||||
key = bytearray(key)
|
||||
ipad = bytearray(b'\x36' * self.block_size)
|
||||
opad = bytearray(b'\x5c' * self.block_size)
|
||||
i_key = bytearray(i ^ j for i, j in zip(key, ipad))
|
||||
self._o_key = bytearray(i ^ j for i, j in zip(key, opad))
|
||||
self._context = digestmod.copy()
|
||||
self._context.update(compatHMAC(i_key))
|
||||
if msg:
|
||||
self._context.update(compatHMAC(msg))
|
||||
|
||||
def update(self, msg):
|
||||
self._context.update(compatHMAC(msg))
|
||||
|
||||
def digest(self):
|
||||
i_digest = self._context.digest()
|
||||
o_hash = self.digestmod.copy()
|
||||
o_hash.update(compatHMAC(self._o_key))
|
||||
o_hash.update(compatHMAC(i_digest))
|
||||
return o_hash.digest()
|
||||
|
||||
def copy(self):
|
||||
new = HMAC.__new__(HMAC)
|
||||
new.key = self.key
|
||||
new.digestmod = self.digestmod
|
||||
new.block_size = self.block_size
|
||||
new.digest_size = self.digest_size
|
||||
new._o_key = self._o_key
|
||||
new._context = self._context.copy()
|
||||
return new
|
||||
|
||||
|
||||
def new(*args, **kwargs):
|
||||
"""General constructor that works in FIPS mode."""
|
||||
return HMAC(*args, **kwargs)
|
||||
Reference in New Issue
Block a user