Исправить обработку MTProto через WebSocket на iOS/iPadOS

This commit is contained in:
WillhamOlgren 2026-03-23 22:18:28 +03:00
parent 7a1e2f3f5b
commit ee3c2bd5ab
1 changed files with 120 additions and 40 deletions

View File

@ -144,7 +144,14 @@ _st_Q = struct.Struct('>Q')
_st_I_net = struct.Struct('!I')
_st_Ih = struct.Struct('<Ih')
_st_I_le = struct.Struct('<I')
_VALID_PROTOS = frozenset((0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD))
_PROTO_ABRIDGED = 0xEFEFEFEF
_PROTO_INTERMEDIATE = 0xEEEEEEEE
_PROTO_PADDED_INTERMEDIATE = 0xDDDDDDDD
_VALID_PROTOS = frozenset((
_PROTO_ABRIDGED,
_PROTO_INTERMEDIATE,
_PROTO_PADDED_INTERMEDIATE,
))
class RawWebSocket:
@ -403,6 +410,22 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
return None, False
def _proto_from_init(data: bytes) -> Optional[int]:
"""Extract MTProto transport marker from the obfuscated init packet."""
try:
cipher = Cipher(algorithms.AES(data[8:40]), modes.CTR(data[40:56]))
encryptor = cipher.encryptor()
keystream = encryptor.update(_ZERO_64)
plain = (int.from_bytes(data[56:64], 'big') ^
int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big')
proto = _st_I_le.unpack(plain[:4])[0]
if proto in _VALID_PROTOS:
return proto
except Exception as exc:
log.debug("Transport extraction failed: %s", exc)
return None
def _patch_init_dc(data: bytes, dc: int) -> bytes:
"""
Patch dc_id in the 64-byte MTProto init packet.
@ -431,54 +454,101 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
class _MsgSplitter:
"""
Splits client TCP data into individual MTProto abridged-protocol
messages so each can be sent as a separate WebSocket frame.
Splits client TCP data into individual MTProto transport packets so
each can be sent as a separate WebSocket frame.
The Telegram WS relay processes one MTProto message per WS frame.
Mobile clients batches multiple messages in a single TCP write (e.g.
msgs_ack + req_DH_params). If sent as one WS frame, the relay
only processes the first message DH handshake never completes.
Some mobile clients coalesce multiple MTProto packets into one TCP
write, and TCP reads may also cut a packet in half. Keep a rolling
buffer so incomplete packets are not forwarded as standalone frames.
"""
def __init__(self, init_data: bytes):
__slots__ = ('_dec', '_proto', '_cipher_buf', '_plain_buf', '_disabled')
def __init__(self, init_data: bytes, proto: int):
cipher = Cipher(algorithms.AES(init_data[8:40]),
modes.CTR(init_data[40:56]))
self._dec = cipher.encryptor()
self._dec.update(_ZERO_64) # skip init packet
self._proto = proto
self._cipher_buf = bytearray()
self._plain_buf = bytearray()
self._disabled = False
def split(self, chunk: bytes) -> List[bytes]:
"""Decrypt to find message boundaries, return split ciphertext."""
plain = self._dec.update(chunk)
boundaries = []
pos = 0
plain_len = len(plain)
while pos < plain_len:
first = plain[pos]
if first == 0x7f:
if pos + 4 > plain_len:
break
msg_len = (
_st_I_le.unpack_from(plain, pos + 1)[0] & 0xFFFFFF
) * 4
pos += 4
else:
msg_len = first * 4
pos += 1
if msg_len == 0 or pos + msg_len > plain_len:
break
pos += msg_len
boundaries.append(pos)
if len(boundaries) <= 1:
"""Decrypt to find packet boundaries, return complete ciphertext packets."""
if not chunk:
return []
if self._disabled:
return [chunk]
self._cipher_buf.extend(chunk)
self._plain_buf.extend(self._dec.update(chunk))
parts = []
prev = 0
for b in boundaries:
parts.append(chunk[prev:b])
prev = b
if prev < len(chunk):
parts.append(chunk[prev:])
while self._cipher_buf:
packet_len = self._next_packet_len()
if packet_len is None:
break
if packet_len <= 0:
parts.append(bytes(self._cipher_buf))
self._cipher_buf.clear()
self._plain_buf.clear()
self._disabled = True
break
parts.append(bytes(self._cipher_buf[:packet_len]))
del self._cipher_buf[:packet_len]
del self._plain_buf[:packet_len]
return parts
def flush(self) -> List[bytes]:
if not self._cipher_buf:
return []
tail = bytes(self._cipher_buf)
self._cipher_buf.clear()
self._plain_buf.clear()
return [tail]
def _next_packet_len(self) -> Optional[int]:
if not self._plain_buf:
return None
if self._proto == _PROTO_ABRIDGED:
return self._next_abridged_len()
if self._proto in (_PROTO_INTERMEDIATE, _PROTO_PADDED_INTERMEDIATE):
return self._next_intermediate_len()
return 0
def _next_abridged_len(self) -> Optional[int]:
first = self._plain_buf[0]
if first in (0x7F, 0xFF):
if len(self._plain_buf) < 4:
return None
payload_len = int.from_bytes(self._plain_buf[1:4], 'little') * 4
header_len = 4
else:
payload_len = (first & 0x7F) * 4
header_len = 1
if payload_len <= 0:
return 0
packet_len = header_len + payload_len
if len(self._plain_buf) < packet_len:
return None
return packet_len
def _next_intermediate_len(self) -> Optional[int]:
if len(self._plain_buf) < 4:
return None
payload_len = _st_I_le.unpack_from(self._plain_buf, 0)[0] & 0x7FFFFFFF
if payload_len <= 0:
return 0
packet_len = 4 + payload_len
if len(self._plain_buf) < packet_len:
return None
return packet_len
def _ws_domains(dc: int, is_media) -> List[str]:
dc = _DC_OVERRIDES.get(dc, dc)
@ -627,6 +697,10 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
while True:
chunk = await reader.read(65536)
if not chunk:
if splitter:
tail = splitter.flush()
if tail:
await ws.send(tail[0])
break
n = len(chunk)
_stats.bytes_up += n
@ -634,6 +708,8 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
up_packets += 1
if splitter:
parts = splitter.split(chunk)
if not parts:
continue
if len(parts) > 1:
await ws.send_batch(parts)
else:
@ -894,15 +970,14 @@ async def _handle_client(reader, writer):
return
# -- Extract DC ID --
proto = _proto_from_init(init)
dc, is_media = _dc_from_init(init)
init_patched = False
# Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it
if dc is None and dst in _IP_TO_DC:
dc, is_media = _IP_TO_DC.get(dst)
if dc in _dc_opt:
init = _patch_init_dc(init, dc if is_media else -dc)
init_patched = True
if dc is None or dc not in _dc_opt:
log.warning("[%s] unknown DC%s for %s:%d -> TCP passthrough",
@ -1003,9 +1078,9 @@ async def _handle_client(reader, writer):
_stats.connections_ws += 1
splitter = None
if init_patched:
if proto in _VALID_PROTOS:
try:
splitter = _MsgSplitter(init)
splitter = _MsgSplitter(init, proto)
except Exception:
pass
@ -1025,6 +1100,11 @@ async def _handle_client(reader, writer):
log.debug("[%s] cancelled", label)
except ConnectionResetError:
log.debug("[%s] connection reset", label)
except OSError as exc:
if getattr(exc, 'winerror', None) == 1236:
log.debug("[%s] connection aborted by local system", label)
else:
log.error("[%s] unexpected os error: %s", label, exc)
except Exception as exc:
log.error("[%s] unexpected: %s", label, exc)
finally: