Merge upstream/main into android_migration

This commit is contained in:
Dark_Avery
2026-03-28 22:21:31 +03:00
16 changed files with 1710 additions and 672 deletions

View File

@@ -4,6 +4,7 @@ import argparse
import asyncio
import base64
import logging
from collections import deque
import logging.handlers
import os
import socket as _socket
@@ -145,7 +146,14 @@ _st_Q = struct.Struct('>Q')
_st_I_net = struct.Struct('!I')
_st_Ih = struct.Struct('<Ih')
_st_I_le = struct.Struct('<I')
_VALID_PROTOS = frozenset((0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD))
_PROTO_ABRIDGED = 0xEFEFEFEF
_PROTO_INTERMEDIATE = 0xEEEEEEEE
_PROTO_PADDED_INTERMEDIATE = 0xDDDDDDDD
_VALID_PROTOS = frozenset((
_PROTO_ABRIDGED,
_PROTO_INTERMEDIATE,
_PROTO_PADDED_INTERMEDIATE,
))
class RawWebSocket:
@@ -382,11 +390,7 @@ def _is_http_transport(data: bytes) -> bool:
data[:5] == b'HEAD ' or data[:8] == b'OPTIONS ')
def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
"""
Extract DC ID from the 64-byte MTProto obfuscation init packet.
Returns (dc_id, is_media).
"""
def _dc_from_init(data: bytes, *, return_proto: bool = False):
try:
key = bytes(data[8:40])
iv = bytes(data[40:56])
@@ -400,11 +404,14 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
if proto in _VALID_PROTOS:
dc = abs(dc_raw)
if 1 <= dc <= 5 or dc == 203:
return dc, (dc_raw < 0)
return (
(dc, (dc_raw < 0), proto)
if return_proto else (dc, (dc_raw < 0))
)
return (None, False, proto) if return_proto else (None, False)
except Exception as exc:
log.debug("DC extraction failed: %s", exc)
return None, False
return (None, False, None) if return_proto else (None, False)
def _patch_init_dc(data: bytes, dc: int) -> bytes:
"""
@@ -435,54 +442,103 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
class _MsgSplitter:
"""
Splits client TCP data into individual MTProto abridged-protocol
messages so each can be sent as a separate WebSocket frame.
Splits client TCP data into individual MTProto transport packets so
each can be sent as a separate WebSocket frame.
The Telegram WS relay processes one MTProto message per WS frame.
Mobile clients batches multiple messages in a single TCP write (e.g.
msgs_ack + req_DH_params). If sent as one WS frame, the relay
only processes the first message — DH handshake never completes.
Some mobile clients coalesce multiple MTProto packets into one TCP
write, and TCP reads may also cut a packet in half. Keep a rolling
buffer so incomplete packets are not forwarded as standalone frames.
"""
def __init__(self, init_data: bytes):
__slots__ = ('_dec', '_proto', '_cipher_buf', '_plain_buf', '_disabled')
def __init__(self, init_data: bytes, proto: Optional[int] = None):
if proto is None:
_, _, proto = _dc_from_init(init_data, return_proto=True)
key_raw = bytes(init_data[8:40])
iv = bytes(init_data[40:56])
self._dec = create_aes_ctr_transform(key_raw, iv)
self._dec.update(b'\x00' * 64) # skip init packet
self._proto = proto
self._cipher_buf = bytearray()
self._plain_buf = bytearray()
self._disabled = False
def split(self, chunk: bytes) -> List[bytes]:
"""Decrypt to find message boundaries, return split ciphertext."""
plain = self._dec.update(chunk)
boundaries = []
pos = 0
plain_len = len(plain)
while pos < plain_len:
first = plain[pos]
if first == 0x7f:
if pos + 4 > plain_len:
break
msg_len = (
_st_I_le.unpack_from(plain, pos + 1)[0] & 0xFFFFFF
) * 4
pos += 4
else:
msg_len = first * 4
pos += 1
if msg_len == 0 or pos + msg_len > plain_len:
break
pos += msg_len
boundaries.append(pos)
if len(boundaries) <= 1:
"""Decrypt to find packet boundaries, return complete ciphertext packets."""
if not chunk:
return []
if self._disabled:
return [chunk]
self._cipher_buf.extend(chunk)
self._plain_buf.extend(self._dec.update(chunk))
parts = []
prev = 0
for b in boundaries:
parts.append(chunk[prev:b])
prev = b
if prev < len(chunk):
parts.append(chunk[prev:])
while self._cipher_buf:
packet_len = self._next_packet_len()
if packet_len is None:
break
if packet_len <= 0:
parts.append(bytes(self._cipher_buf))
self._cipher_buf.clear()
self._plain_buf.clear()
self._disabled = True
break
parts.append(bytes(self._cipher_buf[:packet_len]))
del self._cipher_buf[:packet_len]
del self._plain_buf[:packet_len]
return parts
def flush(self) -> List[bytes]:
if not self._cipher_buf:
return []
tail = bytes(self._cipher_buf)
self._cipher_buf.clear()
self._plain_buf.clear()
return [tail]
def _next_packet_len(self) -> Optional[int]:
if not self._plain_buf:
return None
if self._proto == _PROTO_ABRIDGED:
return self._next_abridged_len()
if self._proto in (_PROTO_INTERMEDIATE, _PROTO_PADDED_INTERMEDIATE):
return self._next_intermediate_len()
return 0
def _next_abridged_len(self) -> Optional[int]:
first = self._plain_buf[0]
if first in (0x7F, 0xFF):
if len(self._plain_buf) < 4:
return None
payload_len = int.from_bytes(self._plain_buf[1:4], 'little') * 4
header_len = 4
else:
payload_len = (first & 0x7F) * 4
header_len = 1
if payload_len <= 0:
return 0
packet_len = header_len + payload_len
if len(self._plain_buf) < packet_len:
return None
return packet_len
def _next_intermediate_len(self) -> Optional[int]:
if len(self._plain_buf) < 4:
return None
payload_len = _st_I_le.unpack_from(self._plain_buf, 0)[0] & 0x7FFFFFFF
if payload_len <= 0:
return 0
packet_len = 4 + payload_len
if len(self._plain_buf) < packet_len:
return None
return packet_len
def _ws_domains(dc: int, is_media) -> List[str]:
dc = _DC_OVERRIDES.get(dc, dc)
@@ -505,12 +561,15 @@ class Stats:
self.pool_misses = 0
def summary(self) -> str:
pool_total = self.pool_hits + self.pool_misses
pool_s = (
f"{self.pool_hits}/{pool_total}" if pool_total else "n/a")
return (f"total={self.connections_total} ws={self.connections_ws} "
f"tcp_fb={self.connections_tcp_fallback} "
f"http_skip={self.connections_http_rejected} "
f"pass={self.connections_passthrough} "
f"err={self.ws_errors} "
f"pool={self.pool_hits}/{self.pool_hits+self.pool_misses} "
f"pool={pool_s} "
f"up={_human_bytes(self.bytes_up)} "
f"down={_human_bytes(self.bytes_down)}")
@@ -535,7 +594,7 @@ def get_stats_snapshot() -> Dict[str, int]:
class _WsPool:
def __init__(self):
self._idle: Dict[Tuple[int, bool], list] = {}
self._idle: Dict[Tuple[int, bool], deque] = {}
self._refilling: Set[Tuple[int, bool]] = set()
async def get(self, dc: int, is_media: bool,
@@ -544,9 +603,12 @@ class _WsPool:
key = (dc, is_media)
now = time.monotonic()
bucket = self._idle.get(key, [])
bucket = self._idle.get(key)
if bucket is None:
bucket = deque()
self._idle[key] = bucket
while bucket:
ws, created = bucket.pop(0)
ws, created = bucket.popleft()
age = now - created
if age > _WS_POOL_MAX_AGE or ws._closed:
asyncio.create_task(self._quiet_close(ws))
@@ -570,7 +632,7 @@ class _WsPool:
async def _refill(self, key, target_ip, domains):
dc, is_media = key
try:
bucket = self._idle.setdefault(key, [])
bucket = self._idle.setdefault(key, deque())
needed = _WS_POOL_SIZE - len(bucket)
if needed <= 0:
return
@@ -646,6 +708,10 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
while True:
chunk = await reader.read(65536)
if not chunk:
if splitter:
tail = splitter.flush()
if tail:
await ws.send(tail[0])
break
n = len(chunk)
_stats.bytes_up += n
@@ -653,6 +719,8 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
up_packets += 1
if splitter:
parts = splitter.split(chunk)
if not parts:
continue
if len(parts) > 1:
await ws.send_batch(parts)
else:
@@ -913,14 +981,14 @@ async def _handle_client(reader, writer):
return
# -- Extract DC ID --
dc, is_media = _dc_from_init(init)
init_patched = False
dc, is_media, proto = _dc_from_init(init)
init_patched = False
# Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it
if dc is None and dst in _IP_TO_DC:
dc, is_media = _IP_TO_DC.get(dst)
if dc in _dc_opt:
init = _patch_init_dc(init, dc if is_media else -dc)
init = _patch_init_dc(init, -dc if is_media else dc)
init_patched = True
if dc is None or dc not in _dc_opt:
@@ -1022,9 +1090,12 @@ async def _handle_client(reader, writer):
_stats.connections_ws += 1
splitter = None
if init_patched:
# Turning splitter on for mobile clients or media-connections, so as the big files don't get fragmented by the TCP socket.
if proto is not None and (init_patched or is_media or proto != _PROTO_INTERMEDIATE):
try:
splitter = _MsgSplitter(init)
splitter = _MsgSplitter(init, proto)
log.debug("[%s] MsgSplitter activated for proto 0x%08X", label, proto)
except Exception:
pass
@@ -1044,6 +1115,11 @@ async def _handle_client(reader, writer):
log.debug("[%s] cancelled", label)
except ConnectionResetError:
log.debug("[%s] connection reset", label)
except OSError as exc:
if getattr(exc, 'winerror', None) == 1236:
log.debug("[%s] connection aborted by local system", label)
else:
log.error("[%s] unexpected os error: %s", label, exc)
except Exception as exc:
log.error("[%s] unexpected: %s", label, exc)
finally:
@@ -1087,34 +1163,50 @@ async def _run(port: int, dc_opt: Dict[int, Optional[str]],
log.info("=" * 60)
async def log_stats():
while True:
await asyncio.sleep(60)
bl = ', '.join(
f'DC{d}{"m" if m else ""}'
for d, m in sorted(_ws_blacklist)) or 'none'
log.info("stats: %s | ws_bl: %s", _stats.summary(), bl)
try:
while True:
await asyncio.sleep(60)
bl = ', '.join(
f'DC{d}{"m" if m else ""}'
for d, m in sorted(_ws_blacklist)) or 'none'
log.info("stats: %s | ws_bl: %s", _stats.summary(), bl)
except asyncio.CancelledError:
raise
asyncio.create_task(log_stats())
log_stats_task = asyncio.create_task(log_stats())
await _ws_pool.warmup(dc_opt)
if stop_event:
async def wait_stop():
await stop_event.wait()
server.close()
me = asyncio.current_task()
for task in list(asyncio.all_tasks()):
if task is not me:
task.cancel()
try:
await server.wait_closed()
except asyncio.CancelledError:
pass
asyncio.create_task(wait_stop())
async with server:
try:
async with server:
if stop_event:
serve_task = asyncio.create_task(server.serve_forever())
stop_task = asyncio.create_task(stop_event.wait())
done, _pending = await asyncio.wait(
(serve_task, stop_task),
return_when=asyncio.FIRST_COMPLETED,
)
if stop_task in done:
server.close()
await server.wait_closed()
if not serve_task.done():
serve_task.cancel()
try:
await serve_task
except asyncio.CancelledError:
pass
else:
stop_task.cancel()
try:
await stop_task
except asyncio.CancelledError:
pass
else:
await server.serve_forever()
finally:
log_stats_task.cancel()
try:
await server.serve_forever()
await log_stats_task
except asyncio.CancelledError:
pass
_server_instance = None