mirror of
https://github.com/UrloMythus/UnHided.git
synced 2026-04-11 11:50:51 +00:00
update
This commit is contained in:
469
mediaflow_proxy/remuxer/mkv_demuxer.py
Normal file
469
mediaflow_proxy/remuxer/mkv_demuxer.py
Normal file
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Streaming MKV demuxer.
|
||||
|
||||
Reads an MKV byte stream via an async iterator and yields individual media
|
||||
frames (MKVFrame) with absolute timestamps. Designed for on-the-fly remuxing
|
||||
without buffering the entire file.
|
||||
|
||||
Architecture:
|
||||
AsyncIterator[bytes] -> StreamBuffer -> EBML parsing -> MKVFrame yields
|
||||
|
||||
The demuxer works in two phases:
|
||||
1. read_header(): Consume bytes until Tracks is fully parsed, returning
|
||||
a list of MKVTrack with codec metadata.
|
||||
2. iter_frames(): Yield MKVFrame objects from Cluster/SimpleBlock data
|
||||
as clusters arrive.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from mediaflow_proxy.remuxer.ebml_parser import (
|
||||
CLUSTER,
|
||||
CLUSTER_TIMESTAMP,
|
||||
EBML_HEADER,
|
||||
INFO,
|
||||
MKVFrame,
|
||||
MKVTrack,
|
||||
SEGMENT,
|
||||
SIMPLE_BLOCK,
|
||||
BLOCK_GROUP,
|
||||
TRACKS,
|
||||
TIMESTAMP_SCALE,
|
||||
DURATION,
|
||||
UNKNOWN_SIZE,
|
||||
extract_block_frames,
|
||||
parse_tracks,
|
||||
read_element_id,
|
||||
read_element_size,
|
||||
read_float,
|
||||
read_uint,
|
||||
_parse_block_group,
|
||||
iter_elements,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamBuffer:
    """
    Accumulating byte buffer for streaming EBML parsing.

    Chunks handed to append() are kept in arrival order. Callers can look
    at the front of the buffer (peek / get_all) and later drop the bytes
    they have finished with (consume / skip), which keeps memory bounded.
    The running count of dropped bytes is exposed as ``consumed`` so the
    caller can translate buffer positions into absolute stream offsets.
    """

    def __init__(self) -> None:
        self._chunks: list[bytes] = []  # Pending chunks, oldest first; never holds empty chunks
        self._total: int = 0  # Sum of len() over self._chunks
        self._consumed: int = 0  # Logical bytes consumed (for offset tracking)

    @property
    def available(self) -> int:
        """Number of buffered bytes available for reading."""
        return self._total

    @property
    def consumed(self) -> int:
        """Total bytes consumed so far (for absolute offset tracking)."""
        return self._consumed

    def append(self, data: bytes) -> None:
        """Add bytes to the end of the buffer. Empty input is ignored."""
        if data:
            self._chunks.append(data)
            self._total += len(data)

    def peek(self, size: int) -> bytes:
        """Return up to ``size`` bytes from the front without consuming them."""
        if size <= 0:
            return b""
        parts: list[bytes] = []
        remaining = size
        for chunk in self._chunks:
            if remaining <= 0:
                break
            piece = chunk[:remaining]
            parts.append(piece)
            remaining -= len(piece)
        return b"".join(parts)

    def get_all(self) -> bytes:
        """
        Return all buffered data as a single bytes object (without consuming).

        As a side effect the internal chunk list is collapsed to one chunk,
        so repeated calls with no append() in between do not copy again.
        """
        if not self._chunks:
            # Nothing buffered: do not store an empty chunk in the list.
            return b""
        if len(self._chunks) > 1:
            self._chunks = [b"".join(self._chunks)]
        return self._chunks[0]

    def consume(self, size: int) -> bytes:
        """
        Remove and return ``size`` bytes from the front of the buffer.

        If fewer than ``size`` bytes are buffered, everything available is
        returned. A non-positive ``size`` returns ``b""``.
        """
        parts: list[bytes] = []
        self._advance(size, parts)
        return b"".join(parts)

    def skip(self, size: int) -> int:
        """Discard ``size`` bytes from the front. Returns actual bytes skipped."""
        return self._advance(size, None)

    def _advance(self, size: int, sink: "list[bytes] | None") -> int:
        """
        Drop up to ``size`` bytes from the front of the buffer.

        Shared implementation of consume() and skip(). When ``sink`` is a
        list, the dropped pieces are appended to it in order.

        Returns:
            The number of bytes actually dropped (clamped to ``available``).
        """
        if size <= 0:
            return 0
        actual = min(size, self._total)
        remaining = actual
        chunks = self._chunks
        whole = 0  # Count of leading chunks that are dropped entirely
        while remaining > 0 and whole < len(chunks):
            chunk = chunks[whole]
            if len(chunk) <= remaining:
                if sink is not None:
                    sink.append(chunk)
                remaining -= len(chunk)
                whole += 1
            else:
                # Only the head of this chunk is dropped; keep its tail.
                if sink is not None:
                    sink.append(chunk[:remaining])
                chunks[whole] = chunk[remaining:]
                remaining = 0
        # One slice deletion instead of a pop(0) per chunk.
        del chunks[:whole]
        self._total -= actual
        self._consumed += actual
        return actual
|
||||
|
||||
|
||||
@dataclass
class MKVHeader:
    """
    Metadata gathered while parsing the start of an MKV stream.

    Every field has a default, so an instance can be created empty and
    filled in as the corresponding elements are parsed.
    """

    # Track descriptions; a fresh empty list per instance.
    tracks: list[MKVTrack] = field(default_factory=list)
    # Timestamp scale in nanoseconds; the default corresponds to 1 ms.
    timestamp_scale_ns: int = 1_000_000
    # Duration in milliseconds; 0.0 until a duration has been parsed.
    duration_ms: float = 0.0
    # Absolute byte offset of Segment children.
    segment_data_offset: int = 0
|
||||
|
||||
|
||||
class MKVDemuxer:
    """
    Streaming async MKV demuxer.

    Reads an MKV byte stream from an async iterator and provides:
    - read_header(): Parse EBML header + Segment metadata + Tracks
    - iter_frames(): Yield MKVFrame objects from Clusters

    Both methods pull from the same ``source`` iterator and share one
    internal StreamBuffer, so bytes read past the header stay available
    for frame iteration.

    Usage:
        demuxer = MKVDemuxer()
        header = await demuxer.read_header(source)
        async for frame in demuxer.iter_frames(source):
            process(frame)
    """

    # Minimum bytes to try parsing an element header (ID + size)
    _MIN_ELEMENT_HEADER = 12

    def __init__(self) -> None:
        self._buf = StreamBuffer()
        self._header: MKVHeader | None = None
        self._scale_ms: float = 1.0  # timestamp_scale / 1_000_000

    @property
    def header(self) -> MKVHeader | None:
        """The header produced by read_header(), or None before it has run."""
        return self._header

    async def read_header(self, source: AsyncIterator[bytes]) -> MKVHeader:
        """
        Read and parse the MKV header (EBML header, Segment, Info, Tracks).

        Consumes bytes from source until Tracks is fully parsed. Any leftover
        bytes (start of first Cluster) remain in the internal buffer for
        iter_frames().

        Parsing also stops early, without raising, when a Cluster is reached,
        when an unknown-size element is met, when the source ends, or when an
        element header cannot be decoded; ``tracks`` may then be empty.

        Returns:
            MKVHeader with track info and timing metadata.

        Raises:
            ValueError: If fewer than 4 bytes are available, the stream does
                not start with an EBML header, the EBML header has unknown
                size, or the element after it is not a Segment.
        """
        header = MKVHeader()

        # Phase 1: Accumulate enough data for EBML header + Segment header
        await self._ensure_bytes(source, 64)

        data = self._buf.get_all()
        if len(data) < 4:
            raise ValueError(
                f"Source ended prematurely: got {len(data)} bytes, need at least an EBML header (source disconnected?)"
            )
        pos = 0

        # Parse EBML Header
        eid, pos = read_element_id(data, pos)
        if eid != EBML_HEADER:
            raise ValueError(f"Not an MKV file: expected EBML header, got 0x{eid:X}")
        size, pos = read_element_size(data, pos)
        if size == UNKNOWN_SIZE:
            raise ValueError("EBML header has unknown size")
        pos += size  # Skip EBML header content

        # The fixed 64-byte prefetch may not reach past a long EBML header,
        # so top up the buffer before decoding the Segment element header.
        await self._ensure_bytes(source, pos + self._MIN_ELEMENT_HEADER)
        data = self._buf.get_all()

        # Parse Segment element header
        eid, pos = read_element_id(data, pos)
        if eid != SEGMENT:
            raise ValueError(f"Expected Segment, got 0x{eid:X}")
        _seg_size, pos = read_element_size(data, pos)
        header.segment_data_offset = self._buf.consumed + pos

        # Phase 2: Parse Segment children until we have Tracks.
        # Nothing is consumed inside the loop; ``pos`` is an offset into the
        # (growing) buffer and everything before it is dropped afterwards.
        tracks_found = False

        while not tracks_found:
            # Ensure we have enough for element header
            await self._ensure_bytes(source, pos + self._MIN_ELEMENT_HEADER)
            if pos >= self._buf.available:
                break

            parsed = await self._read_element_header(source, pos, pos + 32, soft=False)
            if parsed is None:
                break
            eid, size, payload_start = parsed

            if eid == CLUSTER:
                # Reached media data; header parsing is done.
                # Don't consume the Cluster -- leave it for iter_frames.
                break

            if size == UNKNOWN_SIZE:
                # Can't handle unknown-size elements in header
                logger.warning("[mkv_demuxer] Unknown-size element 0x%X in header at pos %d", eid, pos)
                break

            # Ensure we have the full element
            elem_end = payload_start + size
            await self._ensure_bytes(source, elem_end)
            data = self._buf.get_all()

            if eid == INFO:
                self._parse_info_element(data, payload_start, elem_end, header)
            elif eid == TRACKS:
                header.tracks = parse_tracks(data, payload_start, elem_end)
                tracks_found = True
                logger.info(
                    "[mkv_demuxer] Parsed %d tracks: %s",
                    len(header.tracks),
                    ", ".join(f"#{t.track_number}={t.codec_id}" for t in header.tracks),
                )

            pos = elem_end

        # Consume everything up to the current position (Cluster boundary)
        self._buf.consume(pos)

        # Set timing scale
        self._scale_ms = header.timestamp_scale_ns / 1_000_000.0
        self._header = header
        return header

    async def iter_frames(self, source: AsyncIterator[bytes]) -> AsyncIterator[MKVFrame]:
        """
        Yield MKVFrame objects from Cluster/SimpleBlock data.

        Must be called after read_header(). Continues consuming bytes from
        source, parsing Clusters and yielding individual frames. Non-Cluster
        elements of known size are skipped. Iteration ends when the source
        is exhausted, an element header cannot be decoded, or a non-Cluster
        element of unknown size is met.

        Raises:
            RuntimeError: If read_header() has not been called yet.
        """
        if self._header is None:
            raise RuntimeError("read_header() must be called before iter_frames()")

        while True:
            # Try to read the next element header
            if not await self._ensure_bytes_soft(source, self._MIN_ELEMENT_HEADER):
                break

            parsed = await self._read_element_header(source, 0, self._buf.available + 4096, soft=True)
            if parsed is None:
                break
            eid, size, payload_start = parsed

            if eid == CLUSTER:
                if size == UNKNOWN_SIZE:
                    # Unknown-size Cluster: parse children until we hit the next
                    # Cluster or run out of data
                    self._buf.consume(payload_start)  # consume Cluster header
                    async for frame in self._parse_unknown_size_cluster(source):
                        yield frame
                else:
                    # Known-size Cluster: ensure we have all data
                    elem_end = payload_start + size
                    await self._ensure_bytes(source, elem_end)
                    data = self._buf.get_all()

                    for frame in self._parse_cluster_data(data, payload_start, elem_end):
                        yield frame

                    self._buf.consume(elem_end)
            else:
                # Skip non-Cluster top-level elements
                if size == UNKNOWN_SIZE:
                    break
                elem_end = payload_start + size
                buffered = self._buf.available
                if elem_end > buffered:
                    # Drop what is buffered, then discard the rest of the
                    # element straight from the source.
                    self._buf.consume(buffered)
                    await self._skip_bytes(source, elem_end - buffered)
                else:
                    self._buf.consume(elem_end)

    async def _read_element_header(
        self,
        source: AsyncIterator[bytes],
        pos: int,
        retry_needed: int,
        *,
        soft: bool,
    ) -> "tuple[int, int, int] | None":
        """
        Decode the element ID and size found at buffer offset ``pos``.

        If decoding raises ValueError or IndexError, the buffer is topped up
        to ``retry_needed`` bytes (via _ensure_bytes_soft when ``soft`` is
        true, otherwise via _ensure_bytes) and decoding is tried once more.

        Returns:
            ``(element_id, size, payload_start)`` where ``payload_start`` is
            the buffer offset just past the size field, or None if the
            header could not be decoded.
        """
        data = self._buf.get_all()
        try:
            eid, size_pos = read_element_id(data, pos)
            size, payload_start = read_element_size(data, size_pos)
        except (ValueError, IndexError):
            # Try to get more data
            if soft:
                if not await self._ensure_bytes_soft(source, retry_needed):
                    return None
            else:
                await self._ensure_bytes(source, retry_needed)
            data = self._buf.get_all()
            try:
                eid, size_pos = read_element_id(data, pos)
                size, payload_start = read_element_size(data, size_pos)
            except (ValueError, IndexError):
                return None
        return eid, size, payload_start

    def _parse_info_element(self, data: bytes, start: int, end: int, header: MKVHeader) -> None:
        """
        Parse Info element children for timestamp scale and duration.

        Updates ``header.timestamp_scale_ns`` and ``header.duration_ms`` in
        place. The raw duration is converted to milliseconds only after all
        children have been read, so the result does not depend on whether
        the scale or the duration appears first.
        """
        raw_duration = None
        for eid, off, size, _ in iter_elements(data, start, end):
            if eid == TIMESTAMP_SCALE:
                header.timestamp_scale_ns = read_uint(data, off, size)
            elif eid == DURATION:
                raw_duration = read_float(data, off, size)

        if raw_duration is not None:
            scale = header.timestamp_scale_ns / 1_000_000.0
            header.duration_ms = raw_duration * scale

    def _simple_block_frames(self, data: bytes, offset: int, size: int, cluster_timecode: int):
        """
        Generate one MKVFrame per frame in the SimpleBlock at ``offset``.

        Each frame's timestamp is ``(cluster_timecode + relative timecode)``
        multiplied by the demuxer's millisecond scale. The keyframe flag is
        taken from bit 0x80 of the block flags.
        """
        for track_num, rel_tc, flags, frame_list in extract_block_frames(data, offset, size):
            is_kf = bool(flags & 0x80)
            abs_ts_ms = (cluster_timecode + rel_tc) * self._scale_ms
            for frame_data in frame_list:
                yield MKVFrame(
                    track_number=track_num,
                    timestamp_ms=abs_ts_ms,
                    is_keyframe=is_kf,
                    data=frame_data,
                )

    def _parse_cluster_data(self, data: bytes, start: int, end: int) -> list[MKVFrame]:
        """
        Parse a known-size Cluster and return its frames.

        Walks the Cluster children in ``data[start:end]``. A cluster
        timestamp element sets the base timecode for blocks that follow it.
        SimpleBlock and BlockGroup children contribute frames in stream
        order; other children are ignored.
        """
        cluster_timecode = 0
        frames: list[MKVFrame] = []

        for eid, data_off, size, _ in iter_elements(data, start, end):
            if eid == CLUSTER_TIMESTAMP:
                cluster_timecode = read_uint(data, data_off, size)
            elif eid == SIMPLE_BLOCK:
                frames.extend(self._simple_block_frames(data, data_off, size, cluster_timecode))
            elif eid == BLOCK_GROUP:
                # NOTE(review): the helper is expected to append to ``frames``.
                _parse_block_group(data, data_off, data_off + size, cluster_timecode, self._scale_ms, frames)

        return frames

    async def _parse_unknown_size_cluster(self, source: AsyncIterator[bytes]) -> AsyncIterator[MKVFrame]:
        """
        Parse an unknown-size Cluster by reading children until next Cluster.

        The Cluster's own header must already be consumed. Children are
        read one at a time from the front of the buffer and consumed after
        processing. Parsing stops, leaving the current element in the
        buffer, at a Cluster or Segment ID, at an unknown-size child, on an
        undecodable header, or when the source is exhausted.
        """
        cluster_timecode = 0

        while True:
            if not await self._ensure_bytes_soft(source, self._MIN_ELEMENT_HEADER):
                break

            parsed = await self._read_element_header(source, 0, self._buf.available + 4096, soft=True)
            if parsed is None:
                break
            eid, size, payload_start = parsed

            # A new Cluster (or a Segment) signals the end of the current Cluster
            if eid == CLUSTER or eid == SEGMENT:
                break

            if size == UNKNOWN_SIZE:
                break

            elem_end = payload_start + size
            await self._ensure_bytes(source, elem_end)
            data = self._buf.get_all()

            if eid == CLUSTER_TIMESTAMP:
                cluster_timecode = read_uint(data, payload_start, size)
            elif eid == SIMPLE_BLOCK:
                for frame in self._simple_block_frames(data, payload_start, size, cluster_timecode):
                    yield frame
            elif eid == BLOCK_GROUP:
                bg_frames: list[MKVFrame] = []
                _parse_block_group(data, payload_start, elem_end, cluster_timecode, self._scale_ms, bg_frames)
                for frame in bg_frames:
                    yield frame

            self._buf.consume(elem_end)

    async def _ensure_bytes(self, source: AsyncIterator[bytes], needed: int) -> None:
        """
        Pull chunks from source until the buffer holds at least ``needed`` bytes.

        Best effort: if the source is exhausted first, this returns normally
        with fewer bytes buffered, so callers must check what is actually
        available. Empty chunks are ignored and reading continues.
        """
        while self._buf.available < needed:
            try:
                chunk = await source.__anext__()
            except StopAsyncIteration:
                return
            self._buf.append(chunk)

    async def _ensure_bytes_soft(self, source: AsyncIterator[bytes], needed: int) -> bool:
        """
        Like _ensure_bytes, but report whether any data is left to parse.

        Returns:
            True once ``needed`` bytes are buffered. If the source ends or
            yields an empty chunk first, True only when the buffer is
            non-empty (possibly with fewer than ``needed`` bytes).
        """
        while self._buf.available < needed:
            try:
                chunk = await source.__anext__()
            except StopAsyncIteration:
                return self._buf.available > 0
            if not chunk:
                # An empty chunk is treated the same as end of stream here.
                return self._buf.available > 0
            self._buf.append(chunk)
        return True

    async def _skip_bytes(self, source: AsyncIterator[bytes], count: int) -> None:
        """
        Skip count bytes from the source without buffering.

        Bytes read past the skipped region are appended to the internal
        buffer. Stops early if the source is exhausted.
        """
        remaining = count
        while remaining > 0:
            try:
                chunk = await source.__anext__()
            except StopAsyncIteration:
                break
            if len(chunk) <= remaining:
                remaining -= len(chunk)
            else:
                # Put the excess back
                self._buf.append(chunk[remaining:])
                remaining = 0
|
||||
Reference in New Issue
Block a user