diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 21ba025..713c6ab 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -234,8 +234,15 @@ class RawWebSocket: await self.writer.drain() async def recv(self) -> Optional[bytes]: + """ + Return one complete WebSocket data message (RFC 6455), reassembling + fragmented BINARY/TEXT frames. Control frames may appear between + fragments and are handled without breaking reassembly. + """ + fragments: Optional[List[bytes]] = None + while not self._closed: - opcode, payload = await self._read_frame() + opcode, payload, fin = await self._read_frame() if opcode == self.OP_CLOSE: self._closed = True @@ -260,8 +267,38 @@ class RawWebSocket: if opcode == self.OP_PONG: continue + # Continuation — only valid inside a fragmented message + if opcode == 0: + if fragments is None: + self._closed = True + try: + self.writer.close() + await self.writer.wait_closed() + except Exception: + pass + return None + fragments.append(payload) + if fin: + out = b''.join(fragments) + fragments = None + return out + continue + if opcode in (0x1, 0x2): - return payload + if fragments is not None: + self._closed = True + try: + self.writer.close() + await self.writer.wait_closed() + except Exception: + pass + return None + if fin: + return payload + fragments = [payload] + continue + + # Reserved / unknown data opcodes — skip frame, keep reading continue return None @@ -300,8 +337,9 @@ class RawWebSocket: return _st_BBH4s.pack(fb, 0x80 | 126, length, mask_key) + masked return _st_BBQ4s.pack(fb, 0x80 | 127, length, mask_key) + masked - async def _read_frame(self) -> Tuple[int, bytes]: + async def _read_frame(self) -> Tuple[int, bytes, bool]: hdr = await self.reader.readexactly(2) + fin = bool(hdr[0] & 0x80) opcode = hdr[0] & 0x0F length = hdr[1] & 0x7F if length == 126: @@ -311,9 +349,9 @@ class RawWebSocket: if hdr[1] & 0x80: mask_key = await self.reader.readexactly(4) payload = await self.reader.readexactly(length) - return opcode, _xor_mask(payload, mask_key) + return opcode, _xor_mask(payload, mask_key), fin payload = await self.reader.readexactly(length) - return opcode, payload + return opcode, payload, fin def _human_bytes(n: int) -> str: