Исправить обработку MTProto через WebSocket на iOS/iPadOS
This commit is contained in:
parent
7a1e2f3f5b
commit
ee3c2bd5ab
|
|
@ -144,7 +144,14 @@ _st_Q = struct.Struct('>Q')
|
|||
_st_I_net = struct.Struct('!I')
|
||||
_st_Ih = struct.Struct('<Ih')
|
||||
_st_I_le = struct.Struct('<I')
|
||||
_VALID_PROTOS = frozenset((0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD))
|
||||
_PROTO_ABRIDGED = 0xEFEFEFEF
|
||||
_PROTO_INTERMEDIATE = 0xEEEEEEEE
|
||||
_PROTO_PADDED_INTERMEDIATE = 0xDDDDDDDD
|
||||
_VALID_PROTOS = frozenset((
|
||||
_PROTO_ABRIDGED,
|
||||
_PROTO_INTERMEDIATE,
|
||||
_PROTO_PADDED_INTERMEDIATE,
|
||||
))
|
||||
|
||||
|
||||
class RawWebSocket:
|
||||
|
|
@ -403,6 +410,22 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
|
|||
return None, False
|
||||
|
||||
|
||||
def _proto_from_init(data: bytes) -> Optional[int]:
|
||||
"""Extract MTProto transport marker from the obfuscated init packet."""
|
||||
try:
|
||||
cipher = Cipher(algorithms.AES(data[8:40]), modes.CTR(data[40:56]))
|
||||
encryptor = cipher.encryptor()
|
||||
keystream = encryptor.update(_ZERO_64)
|
||||
plain = (int.from_bytes(data[56:64], 'big') ^
|
||||
int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big')
|
||||
proto = _st_I_le.unpack(plain[:4])[0]
|
||||
if proto in _VALID_PROTOS:
|
||||
return proto
|
||||
except Exception as exc:
|
||||
log.debug("Transport extraction failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _patch_init_dc(data: bytes, dc: int) -> bytes:
|
||||
"""
|
||||
Patch dc_id in the 64-byte MTProto init packet.
|
||||
|
|
@ -431,54 +454,101 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
|
|||
|
||||
class _MsgSplitter:
|
||||
"""
|
||||
Splits client TCP data into individual MTProto abridged-protocol
|
||||
messages so each can be sent as a separate WebSocket frame.
|
||||
Splits client TCP data into individual MTProto transport packets so
|
||||
each can be sent as a separate WebSocket frame.
|
||||
|
||||
The Telegram WS relay processes one MTProto message per WS frame.
|
||||
Mobile clients batches multiple messages in a single TCP write (e.g.
|
||||
msgs_ack + req_DH_params). If sent as one WS frame, the relay
|
||||
only processes the first message — DH handshake never completes.
|
||||
Some mobile clients coalesce multiple MTProto packets into one TCP
|
||||
write, and TCP reads may also cut a packet in half. Keep a rolling
|
||||
buffer so incomplete packets are not forwarded as standalone frames.
|
||||
"""
|
||||
|
||||
def __init__(self, init_data: bytes):
|
||||
__slots__ = ('_dec', '_proto', '_cipher_buf', '_plain_buf', '_disabled')
|
||||
|
||||
def __init__(self, init_data: bytes, proto: int):
|
||||
cipher = Cipher(algorithms.AES(init_data[8:40]),
|
||||
modes.CTR(init_data[40:56]))
|
||||
self._dec = cipher.encryptor()
|
||||
self._dec.update(_ZERO_64) # skip init packet
|
||||
self._proto = proto
|
||||
self._cipher_buf = bytearray()
|
||||
self._plain_buf = bytearray()
|
||||
self._disabled = False
|
||||
|
||||
def split(self, chunk: bytes) -> List[bytes]:
|
||||
"""Decrypt to find message boundaries, return split ciphertext."""
|
||||
plain = self._dec.update(chunk)
|
||||
boundaries = []
|
||||
pos = 0
|
||||
plain_len = len(plain)
|
||||
while pos < plain_len:
|
||||
first = plain[pos]
|
||||
if first == 0x7f:
|
||||
if pos + 4 > plain_len:
|
||||
break
|
||||
msg_len = (
|
||||
_st_I_le.unpack_from(plain, pos + 1)[0] & 0xFFFFFF
|
||||
) * 4
|
||||
pos += 4
|
||||
else:
|
||||
msg_len = first * 4
|
||||
pos += 1
|
||||
if msg_len == 0 or pos + msg_len > plain_len:
|
||||
break
|
||||
pos += msg_len
|
||||
boundaries.append(pos)
|
||||
if len(boundaries) <= 1:
|
||||
"""Decrypt to find packet boundaries, return complete ciphertext packets."""
|
||||
if not chunk:
|
||||
return []
|
||||
if self._disabled:
|
||||
return [chunk]
|
||||
|
||||
self._cipher_buf.extend(chunk)
|
||||
self._plain_buf.extend(self._dec.update(chunk))
|
||||
|
||||
parts = []
|
||||
prev = 0
|
||||
for b in boundaries:
|
||||
parts.append(chunk[prev:b])
|
||||
prev = b
|
||||
if prev < len(chunk):
|
||||
parts.append(chunk[prev:])
|
||||
while self._cipher_buf:
|
||||
packet_len = self._next_packet_len()
|
||||
if packet_len is None:
|
||||
break
|
||||
if packet_len <= 0:
|
||||
parts.append(bytes(self._cipher_buf))
|
||||
self._cipher_buf.clear()
|
||||
self._plain_buf.clear()
|
||||
self._disabled = True
|
||||
break
|
||||
parts.append(bytes(self._cipher_buf[:packet_len]))
|
||||
del self._cipher_buf[:packet_len]
|
||||
del self._plain_buf[:packet_len]
|
||||
return parts
|
||||
|
||||
def flush(self) -> List[bytes]:
|
||||
if not self._cipher_buf:
|
||||
return []
|
||||
tail = bytes(self._cipher_buf)
|
||||
self._cipher_buf.clear()
|
||||
self._plain_buf.clear()
|
||||
return [tail]
|
||||
|
||||
def _next_packet_len(self) -> Optional[int]:
|
||||
if not self._plain_buf:
|
||||
return None
|
||||
if self._proto == _PROTO_ABRIDGED:
|
||||
return self._next_abridged_len()
|
||||
if self._proto in (_PROTO_INTERMEDIATE, _PROTO_PADDED_INTERMEDIATE):
|
||||
return self._next_intermediate_len()
|
||||
return 0
|
||||
|
||||
def _next_abridged_len(self) -> Optional[int]:
|
||||
first = self._plain_buf[0]
|
||||
if first in (0x7F, 0xFF):
|
||||
if len(self._plain_buf) < 4:
|
||||
return None
|
||||
payload_len = int.from_bytes(self._plain_buf[1:4], 'little') * 4
|
||||
header_len = 4
|
||||
else:
|
||||
payload_len = (first & 0x7F) * 4
|
||||
header_len = 1
|
||||
|
||||
if payload_len <= 0:
|
||||
return 0
|
||||
|
||||
packet_len = header_len + payload_len
|
||||
if len(self._plain_buf) < packet_len:
|
||||
return None
|
||||
return packet_len
|
||||
|
||||
def _next_intermediate_len(self) -> Optional[int]:
|
||||
if len(self._plain_buf) < 4:
|
||||
return None
|
||||
|
||||
payload_len = _st_I_le.unpack_from(self._plain_buf, 0)[0] & 0x7FFFFFFF
|
||||
if payload_len <= 0:
|
||||
return 0
|
||||
|
||||
packet_len = 4 + payload_len
|
||||
if len(self._plain_buf) < packet_len:
|
||||
return None
|
||||
return packet_len
|
||||
|
||||
|
||||
def _ws_domains(dc: int, is_media) -> List[str]:
|
||||
dc = _DC_OVERRIDES.get(dc, dc)
|
||||
|
|
@ -627,6 +697,10 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
|||
while True:
|
||||
chunk = await reader.read(65536)
|
||||
if not chunk:
|
||||
if splitter:
|
||||
tail = splitter.flush()
|
||||
if tail:
|
||||
await ws.send(tail[0])
|
||||
break
|
||||
n = len(chunk)
|
||||
_stats.bytes_up += n
|
||||
|
|
@ -634,6 +708,8 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
|||
up_packets += 1
|
||||
if splitter:
|
||||
parts = splitter.split(chunk)
|
||||
if not parts:
|
||||
continue
|
||||
if len(parts) > 1:
|
||||
await ws.send_batch(parts)
|
||||
else:
|
||||
|
|
@ -894,15 +970,14 @@ async def _handle_client(reader, writer):
|
|||
return
|
||||
|
||||
# -- Extract DC ID --
|
||||
proto = _proto_from_init(init)
|
||||
dc, is_media = _dc_from_init(init)
|
||||
init_patched = False
|
||||
|
||||
# Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it
|
||||
if dc is None and dst in _IP_TO_DC:
|
||||
dc, is_media = _IP_TO_DC.get(dst)
|
||||
if dc in _dc_opt:
|
||||
init = _patch_init_dc(init, dc if is_media else -dc)
|
||||
init_patched = True
|
||||
|
||||
if dc is None or dc not in _dc_opt:
|
||||
log.warning("[%s] unknown DC%s for %s:%d -> TCP passthrough",
|
||||
|
|
@ -1003,9 +1078,9 @@ async def _handle_client(reader, writer):
|
|||
_stats.connections_ws += 1
|
||||
|
||||
splitter = None
|
||||
if init_patched:
|
||||
if proto in _VALID_PROTOS:
|
||||
try:
|
||||
splitter = _MsgSplitter(init)
|
||||
splitter = _MsgSplitter(init, proto)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
@ -1025,6 +1100,11 @@ async def _handle_client(reader, writer):
|
|||
log.debug("[%s] cancelled", label)
|
||||
except ConnectionResetError:
|
||||
log.debug("[%s] connection reset", label)
|
||||
except OSError as exc:
|
||||
if getattr(exc, 'winerror', None) == 1236:
|
||||
log.debug("[%s] connection aborted by local system", label)
|
||||
else:
|
||||
log.error("[%s] unexpected os error: %s", label, exc)
|
||||
except Exception as exc:
|
||||
log.error("[%s] unexpected: %s", label, exc)
|
||||
finally:
|
||||
|
|
|
|||
Loading…
Reference in New Issue