Files
UnHided/mediaflow_proxy/remuxer/mkv_demuxer.py
UrloMythus cfc6bbabc9 update
2026-02-19 20:15:03 +01:00

470 lines
17 KiB
Python

"""
Streaming MKV demuxer.
Reads an MKV byte stream via an async iterator and yields individual media
frames (MKVFrame) with absolute timestamps. Designed for on-the-fly remuxing
without buffering the entire file.
Architecture:
AsyncIterator[bytes] -> StreamBuffer -> EBML parsing -> MKVFrame yields
The demuxer works in two phases:
1. read_header(): Consume bytes until Tracks is fully parsed, returning
a list of MKVTrack with codec metadata.
2. iter_frames(): Yield MKVFrame objects from Cluster/SimpleBlock data
as clusters arrive.
"""
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from mediaflow_proxy.remuxer.ebml_parser import (
CLUSTER,
CLUSTER_TIMESTAMP,
EBML_HEADER,
INFO,
MKVFrame,
MKVTrack,
SEGMENT,
SIMPLE_BLOCK,
BLOCK_GROUP,
TRACKS,
TIMESTAMP_SCALE,
DURATION,
UNKNOWN_SIZE,
extract_block_frames,
parse_tracks,
read_element_id,
read_element_size,
read_float,
read_uint,
_parse_block_group,
iter_elements,
)
logger = logging.getLogger(__name__)
class StreamBuffer:
    """
    Accumulating byte buffer for streaming EBML parsing.

    Collects chunks from an async byte source and provides read-ahead
    capabilities for EBML element parsing. Supports consuming parsed
    bytes to keep memory usage bounded.

    Invariant: ``_chunks`` never contains an empty chunk and ``_total`` is
    always the sum of the chunk lengths.
    """

    def __init__(self) -> None:
        self._chunks: list[bytes] = []
        self._total: int = 0
        self._consumed: int = 0  # Logical bytes consumed (for offset tracking)

    @property
    def available(self) -> int:
        """Number of buffered bytes available for reading."""
        return self._total

    @property
    def consumed(self) -> int:
        """Total bytes consumed so far (for absolute offset tracking)."""
        return self._consumed

    def append(self, data: bytes) -> None:
        """Add bytes to the buffer. Empty input is ignored."""
        if data:
            self._chunks.append(data)
            self._total += len(data)

    def peek(self, size: int) -> bytes:
        """Return up to ``size`` bytes from the front without consuming them."""
        if size <= 0:
            return b""
        parts = []
        remaining = size
        for chunk in self._chunks:
            if remaining <= 0:
                break
            piece = chunk[:remaining]
            parts.append(piece)
            remaining -= len(piece)
        return b"".join(parts)

    def get_all(self) -> bytes:
        """
        Get all buffered data as a single bytes object (without consuming).

        Multiple chunks are coalesced in place so that repeated calls do not
        re-join the same data. An empty buffer returns ``b""`` and is left
        untouched (no empty chunk is stored).
        """
        if not self._chunks:
            return b""
        if len(self._chunks) > 1:
            self._chunks = [b"".join(self._chunks)]
        return self._chunks[0]

    def consume(self, size: int) -> bytes:
        """
        Remove and return up to ``size`` bytes from the front of the buffer.

        Fewer bytes are returned when the buffer holds less than ``size``.
        """
        data = self.peek(size)
        # skip() performs the chunk bookkeeping and updates both counters.
        self.skip(len(data))
        return data

    def skip(self, size: int) -> int:
        """Discard size bytes from the front. Returns actual bytes skipped."""
        if size <= 0:
            return 0
        actual = min(size, self._total)
        remaining = actual
        while remaining > 0 and self._chunks:
            chunk = self._chunks[0]
            if len(chunk) <= remaining:
                remaining -= len(chunk)
                self._chunks.pop(0)
            else:
                # Partial chunk: keep only the unread tail.
                self._chunks[0] = chunk[remaining:]
                remaining = 0
        self._total -= actual
        self._consumed += actual
        return actual
@dataclass
class MKVHeader:
    """Header metadata collected by MKVDemuxer.read_header()."""

    # Tracks returned by parse_tracks() for the Tracks element.
    tracks: list[MKVTrack] = field(default_factory=list)
    # Nanoseconds per timestamp tick; the default of 1_000_000 is 1 ms.
    timestamp_scale_ns: int = 1_000_000
    # Duration in milliseconds (raw Duration value multiplied by the scale).
    duration_ms: float = 0.0
    # Absolute byte offset of the first Segment child element.
    segment_data_offset: int = 0
class MKVDemuxer:
    """
    Streaming async MKV demuxer.

    Reads an MKV byte stream from an async iterator and provides:
        - read_header(): Parse EBML header + Segment metadata + Tracks
        - iter_frames(): Yield MKVFrame objects from Clusters

    Usage:
        demuxer = MKVDemuxer()
        header = await demuxer.read_header(source)
        async for frame in demuxer.iter_frames(source):
            process(frame)
    """

    # Minimum bytes to try parsing an element header (ID + size)
    _MIN_ELEMENT_HEADER = 12

    def __init__(self) -> None:
        self._buf = StreamBuffer()
        self._header: MKVHeader | None = None
        self._scale_ms: float = 1.0  # timestamp_scale / 1_000_000

    @property
    def header(self) -> MKVHeader | None:
        """The header parsed by read_header(), or None before it has run."""
        return self._header

    async def read_header(self, source: AsyncIterator[bytes]) -> MKVHeader:
        """
        Read and parse the MKV header (EBML header, Segment, Info, Tracks).

        Consumes bytes from source until Tracks is fully parsed. Any leftover
        bytes (start of first Cluster) remain in the internal buffer for
        iter_frames().

        Returns:
            MKVHeader with track info and timing metadata. ``tracks`` is empty
            when a Cluster, an unknown-size element or the end of the source
            is reached before a Tracks element.

        Raises:
            ValueError: If the source ends before the EBML/Segment headers,
                the stream does not start with an EBML header, the EBML
                header has unknown size, or no Segment follows it.
        """
        header = MKVHeader()

        # Phase 1: Accumulate enough data for EBML header + Segment header
        await self._ensure_bytes(source, 64)
        data = self._buf.get_all()
        if len(data) < 4:
            raise ValueError(
                f"Source ended prematurely: got {len(data)} bytes, need at least an EBML header (source disconnected?)"
            )

        pos = 0

        # Parse EBML Header
        eid, pos = read_element_id(data, pos)
        if eid != EBML_HEADER:
            raise ValueError(f"Not an MKV file: expected EBML header, got 0x{eid:X}")
        size, pos = read_element_size(data, pos)
        if size == UNKNOWN_SIZE:
            raise ValueError("EBML header has unknown size")
        pos += size  # Skip EBML header content

        # The EBML header may be longer than the initial 64-byte read, so
        # make sure the Segment element header is actually buffered.
        await self._ensure_bytes(source, pos + self._MIN_ELEMENT_HEADER)
        data = self._buf.get_all()
        if pos >= len(data):
            raise ValueError(
                f"Source ended prematurely: got {len(data)} bytes, expected a Segment header at offset {pos}"
            )

        # Parse Segment element header
        eid, pos = read_element_id(data, pos)
        if eid != SEGMENT:
            raise ValueError(f"Expected Segment, got 0x{eid:X}")
        _seg_size, pos = read_element_size(data, pos)
        header.segment_data_offset = self._buf.consumed + pos

        # Phase 2: Parse Segment children until we have Tracks
        # We need to iterate top-level Segment children: SeekHead, Info, Tracks
        # Stop when we hit the first Cluster (media data).
        tracks_found = False
        while not tracks_found:
            # Ensure we have enough for element header
            await self._ensure_bytes(source, pos + self._MIN_ELEMENT_HEADER)
            data = self._buf.get_all()
            if pos >= len(data):
                break

            try:
                eid, size_pos = read_element_id(data, pos)
                size, body_start = read_element_size(data, size_pos)
            except (ValueError, IndexError):
                # Retry once with a little more data before giving up.
                await self._ensure_bytes(source, pos + 32)
                data = self._buf.get_all()
                try:
                    eid, size_pos = read_element_id(data, pos)
                    size, body_start = read_element_size(data, size_pos)
                except (ValueError, IndexError):
                    break

            if eid == CLUSTER:
                # Reached media data; header parsing is done.
                # Don't consume the Cluster -- leave it for iter_frames.
                break

            if size == UNKNOWN_SIZE:
                # Can't handle unknown-size elements in header
                logger.warning("[mkv_demuxer] Unknown-size element 0x%X in header at pos %d", eid, pos)
                break

            # Ensure we have the full element. _ensure_bytes returns quietly
            # if the source ends first, in which case the parsers below run
            # on whatever was buffered.
            elem_end = body_start + size
            await self._ensure_bytes(source, elem_end)
            data = self._buf.get_all()

            if eid == INFO:
                self._parse_info_element(data, body_start, elem_end, header)
            elif eid == TRACKS:
                header.tracks = parse_tracks(data, body_start, elem_end)
                tracks_found = True
                logger.info(
                    "[mkv_demuxer] Parsed %d tracks: %s",
                    len(header.tracks),
                    ", ".join(f"#{t.track_number}={t.codec_id}" for t in header.tracks),
                )

            pos = elem_end

        # Drop everything up to the current position (Cluster boundary).
        # skip() discards without building a copy of the header bytes.
        self._buf.skip(pos)

        # Set timing scale
        self._scale_ms = header.timestamp_scale_ns / 1_000_000.0
        self._header = header
        return header

    async def iter_frames(self, source: AsyncIterator[bytes]) -> AsyncIterator[MKVFrame]:
        """
        Yield MKVFrame objects from Cluster/SimpleBlock data.

        Must be called after read_header(). Continues consuming bytes from
        source, parsing Clusters and yielding individual frames. Iteration
        ends when the source is exhausted, an element header cannot be
        parsed, or a non-Cluster element of unknown size is met.

        Raises:
            RuntimeError: If read_header() has not been called.
        """
        if self._header is None:
            raise RuntimeError("read_header() must be called before iter_frames()")

        while True:
            parsed = await self._read_element_header(source)
            if parsed is None:
                break
            eid, size, body_start = parsed

            if eid == CLUSTER:
                if size == UNKNOWN_SIZE:
                    # Unknown-size Cluster: parse children until we hit the
                    # next Cluster or run out of data
                    self._buf.skip(body_start)  # drop the Cluster header
                    async for frame in self._parse_unknown_size_cluster(source):
                        yield frame
                else:
                    # Known-size Cluster: buffer all of it, then parse.
                    # If the source ends early the parse runs on the bytes
                    # that did arrive.
                    elem_end = body_start + size
                    await self._ensure_bytes(source, elem_end)
                    data = self._buf.get_all()
                    for frame in self._parse_cluster_data(data, body_start, elem_end):
                        yield frame
                    self._buf.skip(elem_end)
                continue

            # Skip non-Cluster top-level elements
            if size == UNKNOWN_SIZE:
                break
            elem_end = body_start + size
            buffered = self._buf.available
            if elem_end > buffered:
                # Drop what is buffered, then discard the rest straight from
                # the source so the element is never held in memory.
                self._buf.skip(buffered)
                await self._skip_bytes(source, elem_end - buffered)
            else:
                self._buf.skip(elem_end)

    def _parse_info_element(self, data: bytes, start: int, end: int, header: MKVHeader) -> None:
        """
        Parse Info element children for timestamp scale and duration.

        The raw Duration is scaled only after every child has been read, so
        the result does not depend on whether TimestampScale appears before
        or after Duration inside Info. ``header`` is updated in place.
        """
        raw_duration: float | None = None
        for eid, off, size, _ in iter_elements(data, start, end):
            if eid == TIMESTAMP_SCALE:
                header.timestamp_scale_ns = read_uint(data, off, size)
            elif eid == DURATION:
                raw_duration = read_float(data, off, size)

        if raw_duration is not None:
            scale = header.timestamp_scale_ns / 1_000_000.0
            header.duration_ms = raw_duration * scale

    def _iter_simple_block_frames(self, data: bytes, offset: int, size: int, cluster_timecode: int):
        """Yield an MKVFrame for every frame of the SimpleBlock at ``offset``."""
        for track_num, rel_tc, flags, frame_list in extract_block_frames(data, offset, size):
            is_kf = bool(flags & 0x80)
            abs_ts_ms = (cluster_timecode + rel_tc) * self._scale_ms
            for frame_data in frame_list:
                yield MKVFrame(
                    track_number=track_num,
                    timestamp_ms=abs_ts_ms,
                    is_keyframe=is_kf,
                    data=frame_data,
                )

    def _parse_cluster_data(self, data: bytes, start: int, end: int) -> list[MKVFrame]:
        """
        Parse a known-size Cluster and return its frames.

        ``start``/``end`` delimit the Cluster's children inside ``data``.
        Blocks seen before a Cluster timestamp element use a timecode of 0.
        """
        cluster_timecode = 0
        frames: list[MKVFrame] = []
        for eid, data_off, size, _ in iter_elements(data, start, end):
            if eid == CLUSTER_TIMESTAMP:
                cluster_timecode = read_uint(data, data_off, size)
            elif eid == SIMPLE_BLOCK:
                frames.extend(self._iter_simple_block_frames(data, data_off, size, cluster_timecode))
            elif eid == BLOCK_GROUP:
                # _parse_block_group appends its frames to ``frames``.
                _parse_block_group(data, data_off, data_off + size, cluster_timecode, self._scale_ms, frames)
        return frames

    async def _parse_unknown_size_cluster(self, source: AsyncIterator[bytes]) -> AsyncIterator[MKVFrame]:
        """
        Parse an unknown-size Cluster by reading children until next Cluster.

        Expects the Cluster's own header to have been removed from the buffer
        already. Each child is buffered in full, handled and then dropped.
        The terminating element (a Cluster, a Segment, or any unknown-size
        element) is left in the buffer for the caller.
        """
        cluster_timecode = 0
        while True:
            parsed = await self._read_element_header(source)
            if parsed is None:
                break
            eid, size, body_start = parsed

            # A new Cluster or top-level element signals end of current Cluster
            if eid == CLUSTER or eid == SEGMENT:
                break
            if size == UNKNOWN_SIZE:
                break

            elem_end = body_start + size
            await self._ensure_bytes(source, elem_end)
            data = self._buf.get_all()

            if eid == CLUSTER_TIMESTAMP:
                cluster_timecode = read_uint(data, body_start, size)
            elif eid == SIMPLE_BLOCK:
                for frame in self._iter_simple_block_frames(data, body_start, size, cluster_timecode):
                    yield frame
            elif eid == BLOCK_GROUP:
                bg_frames: list[MKVFrame] = []
                _parse_block_group(data, body_start, elem_end, cluster_timecode, self._scale_ms, bg_frames)
                for frame in bg_frames:
                    yield frame

            self._buf.skip(elem_end)

    async def _read_element_header(self, source: AsyncIterator[bytes]) -> tuple[int, int, int] | None:
        """
        Parse the element header at the front of the buffer without consuming it.

        Returns:
            ``(element_id, size, body_start)`` where ``body_start`` is the
            offset of the element payload relative to the buffer front, or
            None when no data is left or the header cannot be parsed even
            after one attempt to buffer more data.
        """
        if not await self._ensure_bytes_soft(source, self._MIN_ELEMENT_HEADER):
            return None
        data = self._buf.get_all()
        try:
            eid, size_pos = read_element_id(data, 0)
            size, body_start = read_element_size(data, size_pos)
        except (ValueError, IndexError):
            # Try to get more data, then retry exactly once.
            if not await self._ensure_bytes_soft(source, len(data) + 4096):
                return None
            data = self._buf.get_all()
            try:
                eid, size_pos = read_element_id(data, 0)
                size, body_start = read_element_size(data, size_pos)
            except (ValueError, IndexError):
                return None
        return eid, size, body_start

    async def _ensure_bytes(self, source: AsyncIterator[bytes], needed: int) -> None:
        """
        Pull chunks from source until the buffer holds at least ``needed`` bytes.

        Does not raise when the source is exhausted: it returns with fewer
        bytes buffered, so callers must check the buffer length themselves
        if they depend on the full amount.
        """
        while self._buf.available < needed:
            try:
                chunk = await source.__anext__()
            except StopAsyncIteration:
                return
            self._buf.append(chunk)

    async def _ensure_bytes_soft(self, source: AsyncIterator[bytes], needed: int) -> bool:
        """
        Like _ensure_bytes, but report whether any data is left to parse.

        Returns True once ``needed`` bytes are buffered. When the source ends
        or yields an empty chunk (treated here as end of stream), returns
        True only if the buffer is non-empty, so True may be returned with
        fewer than ``needed`` bytes buffered.
        """
        while self._buf.available < needed:
            try:
                chunk = await source.__anext__()
            except StopAsyncIteration:
                return self._buf.available > 0
            if not chunk:
                return self._buf.available > 0
            self._buf.append(chunk)
        return True

    async def _skip_bytes(self, source: AsyncIterator[bytes], count: int) -> None:
        """
        Skip count bytes from the source without buffering.

        Bytes beyond ``count`` in the last chunk read are appended to the
        buffer. Stops early if the source is exhausted.
        """
        remaining = count
        while remaining > 0:
            try:
                chunk = await source.__anext__()
            except StopAsyncIteration:
                break
            if len(chunk) <= remaining:
                remaining -= len(chunk)
            else:
                # Put the excess back
                self._buf.append(chunk[remaining:])
                remaining = 0