Исправить обработку MTProto через WebSocket на iOS/iPadOS

This commit is contained in:
WillhamOlgren 2026-03-23 22:18:28 +03:00
parent 7a1e2f3f5b
commit ee3c2bd5ab
1 changed files with 120 additions and 40 deletions

View File

@ -144,7 +144,14 @@ _st_Q = struct.Struct('>Q')
_st_I_net = struct.Struct('!I') _st_I_net = struct.Struct('!I')
_st_Ih = struct.Struct('<Ih') _st_Ih = struct.Struct('<Ih')
_st_I_le = struct.Struct('<I') _st_I_le = struct.Struct('<I')
_VALID_PROTOS = frozenset((0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD)) _PROTO_ABRIDGED = 0xEFEFEFEF
_PROTO_INTERMEDIATE = 0xEEEEEEEE
_PROTO_PADDED_INTERMEDIATE = 0xDDDDDDDD
_VALID_PROTOS = frozenset((
_PROTO_ABRIDGED,
_PROTO_INTERMEDIATE,
_PROTO_PADDED_INTERMEDIATE,
))
class RawWebSocket: class RawWebSocket:
@ -403,6 +410,22 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
return None, False return None, False
def _proto_from_init(data: bytes) -> Optional[int]:
"""Extract MTProto transport marker from the obfuscated init packet."""
try:
cipher = Cipher(algorithms.AES(data[8:40]), modes.CTR(data[40:56]))
encryptor = cipher.encryptor()
keystream = encryptor.update(_ZERO_64)
plain = (int.from_bytes(data[56:64], 'big') ^
int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big')
proto = _st_I_le.unpack(plain[:4])[0]
if proto in _VALID_PROTOS:
return proto
except Exception as exc:
log.debug("Transport extraction failed: %s", exc)
return None
def _patch_init_dc(data: bytes, dc: int) -> bytes: def _patch_init_dc(data: bytes, dc: int) -> bytes:
""" """
Patch dc_id in the 64-byte MTProto init packet. Patch dc_id in the 64-byte MTProto init packet.
@ -431,54 +454,101 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
class _MsgSplitter: class _MsgSplitter:
""" """
Splits client TCP data into individual MTProto abridged-protocol Splits client TCP data into individual MTProto transport packets so
messages so each can be sent as a separate WebSocket frame. each can be sent as a separate WebSocket frame.
The Telegram WS relay processes one MTProto message per WS frame. Some mobile clients coalesce multiple MTProto packets into one TCP
Mobile clients batches multiple messages in a single TCP write (e.g. write, and TCP reads may also cut a packet in half. Keep a rolling
msgs_ack + req_DH_params). If sent as one WS frame, the relay buffer so incomplete packets are not forwarded as standalone frames.
only processes the first message DH handshake never completes.
""" """
def __init__(self, init_data: bytes): __slots__ = ('_dec', '_proto', '_cipher_buf', '_plain_buf', '_disabled')
def __init__(self, init_data: bytes, proto: int):
cipher = Cipher(algorithms.AES(init_data[8:40]), cipher = Cipher(algorithms.AES(init_data[8:40]),
modes.CTR(init_data[40:56])) modes.CTR(init_data[40:56]))
self._dec = cipher.encryptor() self._dec = cipher.encryptor()
self._dec.update(_ZERO_64) # skip init packet self._dec.update(_ZERO_64) # skip init packet
self._proto = proto
self._cipher_buf = bytearray()
self._plain_buf = bytearray()
self._disabled = False
def split(self, chunk: bytes) -> List[bytes]: def split(self, chunk: bytes) -> List[bytes]:
"""Decrypt to find message boundaries, return split ciphertext.""" """Decrypt to find packet boundaries, return complete ciphertext packets."""
plain = self._dec.update(chunk) if not chunk:
boundaries = [] return []
pos = 0 if self._disabled:
plain_len = len(plain)
while pos < plain_len:
first = plain[pos]
if first == 0x7f:
if pos + 4 > plain_len:
break
msg_len = (
_st_I_le.unpack_from(plain, pos + 1)[0] & 0xFFFFFF
) * 4
pos += 4
else:
msg_len = first * 4
pos += 1
if msg_len == 0 or pos + msg_len > plain_len:
break
pos += msg_len
boundaries.append(pos)
if len(boundaries) <= 1:
return [chunk] return [chunk]
self._cipher_buf.extend(chunk)
self._plain_buf.extend(self._dec.update(chunk))
parts = [] parts = []
prev = 0 while self._cipher_buf:
for b in boundaries: packet_len = self._next_packet_len()
parts.append(chunk[prev:b]) if packet_len is None:
prev = b break
if prev < len(chunk): if packet_len <= 0:
parts.append(chunk[prev:]) parts.append(bytes(self._cipher_buf))
self._cipher_buf.clear()
self._plain_buf.clear()
self._disabled = True
break
parts.append(bytes(self._cipher_buf[:packet_len]))
del self._cipher_buf[:packet_len]
del self._plain_buf[:packet_len]
return parts return parts
def flush(self) -> List[bytes]:
if not self._cipher_buf:
return []
tail = bytes(self._cipher_buf)
self._cipher_buf.clear()
self._plain_buf.clear()
return [tail]
def _next_packet_len(self) -> Optional[int]:
if not self._plain_buf:
return None
if self._proto == _PROTO_ABRIDGED:
return self._next_abridged_len()
if self._proto in (_PROTO_INTERMEDIATE, _PROTO_PADDED_INTERMEDIATE):
return self._next_intermediate_len()
return 0
def _next_abridged_len(self) -> Optional[int]:
first = self._plain_buf[0]
if first in (0x7F, 0xFF):
if len(self._plain_buf) < 4:
return None
payload_len = int.from_bytes(self._plain_buf[1:4], 'little') * 4
header_len = 4
else:
payload_len = (first & 0x7F) * 4
header_len = 1
if payload_len <= 0:
return 0
packet_len = header_len + payload_len
if len(self._plain_buf) < packet_len:
return None
return packet_len
def _next_intermediate_len(self) -> Optional[int]:
if len(self._plain_buf) < 4:
return None
payload_len = _st_I_le.unpack_from(self._plain_buf, 0)[0] & 0x7FFFFFFF
if payload_len <= 0:
return 0
packet_len = 4 + payload_len
if len(self._plain_buf) < packet_len:
return None
return packet_len
def _ws_domains(dc: int, is_media) -> List[str]: def _ws_domains(dc: int, is_media) -> List[str]:
dc = _DC_OVERRIDES.get(dc, dc) dc = _DC_OVERRIDES.get(dc, dc)
@ -627,6 +697,10 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
while True: while True:
chunk = await reader.read(65536) chunk = await reader.read(65536)
if not chunk: if not chunk:
if splitter:
tail = splitter.flush()
if tail:
await ws.send(tail[0])
break break
n = len(chunk) n = len(chunk)
_stats.bytes_up += n _stats.bytes_up += n
@ -634,6 +708,8 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
up_packets += 1 up_packets += 1
if splitter: if splitter:
parts = splitter.split(chunk) parts = splitter.split(chunk)
if not parts:
continue
if len(parts) > 1: if len(parts) > 1:
await ws.send_batch(parts) await ws.send_batch(parts)
else: else:
@ -894,15 +970,14 @@ async def _handle_client(reader, writer):
return return
# -- Extract DC ID -- # -- Extract DC ID --
proto = _proto_from_init(init)
dc, is_media = _dc_from_init(init) dc, is_media = _dc_from_init(init)
init_patched = False
# Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it # Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it
if dc is None and dst in _IP_TO_DC: if dc is None and dst in _IP_TO_DC:
dc, is_media = _IP_TO_DC.get(dst) dc, is_media = _IP_TO_DC.get(dst)
if dc in _dc_opt: if dc in _dc_opt:
init = _patch_init_dc(init, dc if is_media else -dc) init = _patch_init_dc(init, dc if is_media else -dc)
init_patched = True
if dc is None or dc not in _dc_opt: if dc is None or dc not in _dc_opt:
log.warning("[%s] unknown DC%s for %s:%d -> TCP passthrough", log.warning("[%s] unknown DC%s for %s:%d -> TCP passthrough",
@ -1003,9 +1078,9 @@ async def _handle_client(reader, writer):
_stats.connections_ws += 1 _stats.connections_ws += 1
splitter = None splitter = None
if init_patched: if proto in _VALID_PROTOS:
try: try:
splitter = _MsgSplitter(init) splitter = _MsgSplitter(init, proto)
except Exception: except Exception:
pass pass
@ -1025,6 +1100,11 @@ async def _handle_client(reader, writer):
log.debug("[%s] cancelled", label) log.debug("[%s] cancelled", label)
except ConnectionResetError: except ConnectionResetError:
log.debug("[%s] connection reset", label) log.debug("[%s] connection reset", label)
except OSError as exc:
if getattr(exc, 'winerror', None) == 1236:
log.debug("[%s] connection aborted by local system", label)
else:
log.error("[%s] unexpected os error: %s", label, exc)
except Exception as exc: except Exception as exc:
log.error("[%s] unexpected: %s", label, exc) log.error("[%s] unexpected: %s", label, exc)
finally: finally: