fix(proxy): reassemble fragmented WebSocket messages in recv()

RFC 6455 allows large BINARY payloads to be split across multiple
frames with FIN=0 and CONTINUATION opcodes. The previous recv() only
returned the first fragment, corrupting the MTProto stream and breaking
some media downloads. Control frames between fragments are still
handled as before.

Made-with: Cursor
This commit is contained in:
Sceef 2026-04-02 14:49:37 +07:00
parent da4b521aba
commit f28e4e3384
1 changed files with 43 additions and 5 deletions

View File

@ -234,8 +234,15 @@ class RawWebSocket:
await self.writer.drain()
async def recv(self) -> Optional[bytes]:
"""
Return one complete WebSocket data message (RFC 6455), reassembling
fragmented BINARY/TEXT frames. Control frames may appear between
fragments and are handled without breaking reassembly.
"""
fragments: Optional[List[bytes]] = None
while not self._closed:
opcode, payload = await self._read_frame()
opcode, payload, fin = await self._read_frame()
if opcode == self.OP_CLOSE:
self._closed = True
@ -260,8 +267,38 @@ class RawWebSocket:
if opcode == self.OP_PONG:
continue
# Continuation — only valid inside a fragmented message
if opcode == 0:
if fragments is None:
self._closed = True
try:
self.writer.close()
await self.writer.wait_closed()
except Exception:
pass
return None
fragments.append(payload)
if fin:
out = b''.join(fragments)
fragments = None
return out
continue
if opcode in (0x1, 0x2):
return payload
if fragments is not None:
self._closed = True
try:
self.writer.close()
await self.writer.wait_closed()
except Exception:
pass
return None
if fin:
return payload
fragments = [payload]
continue
# Reserved / unknown data opcodes — skip frame, keep reading
continue
return None
@ -300,8 +337,9 @@ class RawWebSocket:
return _st_BBH4s.pack(fb, 0x80 | 126, length, mask_key) + masked
return _st_BBQ4s.pack(fb, 0x80 | 127, length, mask_key) + masked
async def _read_frame(self) -> Tuple[int, bytes]:
async def _read_frame(self) -> Tuple[int, bytes, bool]:
hdr = await self.reader.readexactly(2)
fin = bool(hdr[0] & 0x80)
opcode = hdr[0] & 0x0F
length = hdr[1] & 0x7F
if length == 126:
@ -311,9 +349,9 @@ class RawWebSocket:
if hdr[1] & 0x80:
mask_key = await self.reader.readexactly(4)
payload = await self.reader.readexactly(length)
return opcode, _xor_mask(payload, mask_key)
return opcode, _xor_mask(payload, mask_key), fin
payload = await self.reader.readexactly(length)
return opcode, payload
return opcode, payload, fin
def _human_bytes(n: int) -> str: