mirror of
https://github.com/Flowseal/tg-ws-proxy.git
synced 2026-06-16 03:28:26 +03:00
Merge upstream/main into android_migration
This commit is contained in:
@@ -4,6 +4,7 @@ import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from collections import deque
|
||||
import logging.handlers
|
||||
import os
|
||||
import socket as _socket
|
||||
@@ -145,7 +146,14 @@ _st_Q = struct.Struct('>Q')
|
||||
_st_I_net = struct.Struct('!I')
|
||||
_st_Ih = struct.Struct('<Ih')
|
||||
_st_I_le = struct.Struct('<I')
|
||||
_VALID_PROTOS = frozenset((0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD))
|
||||
_PROTO_ABRIDGED = 0xEFEFEFEF
|
||||
_PROTO_INTERMEDIATE = 0xEEEEEEEE
|
||||
_PROTO_PADDED_INTERMEDIATE = 0xDDDDDDDD
|
||||
_VALID_PROTOS = frozenset((
|
||||
_PROTO_ABRIDGED,
|
||||
_PROTO_INTERMEDIATE,
|
||||
_PROTO_PADDED_INTERMEDIATE,
|
||||
))
|
||||
|
||||
|
||||
class RawWebSocket:
|
||||
@@ -382,11 +390,7 @@ def _is_http_transport(data: bytes) -> bool:
|
||||
data[:5] == b'HEAD ' or data[:8] == b'OPTIONS ')
|
||||
|
||||
|
||||
def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
|
||||
"""
|
||||
Extract DC ID from the 64-byte MTProto obfuscation init packet.
|
||||
Returns (dc_id, is_media).
|
||||
"""
|
||||
def _dc_from_init(data: bytes, *, return_proto: bool = False):
|
||||
try:
|
||||
key = bytes(data[8:40])
|
||||
iv = bytes(data[40:56])
|
||||
@@ -400,11 +404,14 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
|
||||
if proto in _VALID_PROTOS:
|
||||
dc = abs(dc_raw)
|
||||
if 1 <= dc <= 5 or dc == 203:
|
||||
return dc, (dc_raw < 0)
|
||||
return (
|
||||
(dc, (dc_raw < 0), proto)
|
||||
if return_proto else (dc, (dc_raw < 0))
|
||||
)
|
||||
return (None, False, proto) if return_proto else (None, False)
|
||||
except Exception as exc:
|
||||
log.debug("DC extraction failed: %s", exc)
|
||||
return None, False
|
||||
|
||||
return (None, False, None) if return_proto else (None, False)
|
||||
|
||||
def _patch_init_dc(data: bytes, dc: int) -> bytes:
|
||||
"""
|
||||
@@ -435,54 +442,103 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
|
||||
|
||||
class _MsgSplitter:
|
||||
"""
|
||||
Splits client TCP data into individual MTProto abridged-protocol
|
||||
messages so each can be sent as a separate WebSocket frame.
|
||||
Splits client TCP data into individual MTProto transport packets so
|
||||
each can be sent as a separate WebSocket frame.
|
||||
|
||||
The Telegram WS relay processes one MTProto message per WS frame.
|
||||
Mobile clients batches multiple messages in a single TCP write (e.g.
|
||||
msgs_ack + req_DH_params). If sent as one WS frame, the relay
|
||||
only processes the first message — DH handshake never completes.
|
||||
Some mobile clients coalesce multiple MTProto packets into one TCP
|
||||
write, and TCP reads may also cut a packet in half. Keep a rolling
|
||||
buffer so incomplete packets are not forwarded as standalone frames.
|
||||
"""
|
||||
|
||||
def __init__(self, init_data: bytes):
|
||||
__slots__ = ('_dec', '_proto', '_cipher_buf', '_plain_buf', '_disabled')
|
||||
|
||||
def __init__(self, init_data: bytes, proto: Optional[int] = None):
|
||||
if proto is None:
|
||||
_, _, proto = _dc_from_init(init_data, return_proto=True)
|
||||
key_raw = bytes(init_data[8:40])
|
||||
iv = bytes(init_data[40:56])
|
||||
self._dec = create_aes_ctr_transform(key_raw, iv)
|
||||
self._dec.update(b'\x00' * 64) # skip init packet
|
||||
self._proto = proto
|
||||
self._cipher_buf = bytearray()
|
||||
self._plain_buf = bytearray()
|
||||
self._disabled = False
|
||||
|
||||
def split(self, chunk: bytes) -> List[bytes]:
|
||||
"""Decrypt to find message boundaries, return split ciphertext."""
|
||||
plain = self._dec.update(chunk)
|
||||
boundaries = []
|
||||
pos = 0
|
||||
plain_len = len(plain)
|
||||
while pos < plain_len:
|
||||
first = plain[pos]
|
||||
if first == 0x7f:
|
||||
if pos + 4 > plain_len:
|
||||
break
|
||||
msg_len = (
|
||||
_st_I_le.unpack_from(plain, pos + 1)[0] & 0xFFFFFF
|
||||
) * 4
|
||||
pos += 4
|
||||
else:
|
||||
msg_len = first * 4
|
||||
pos += 1
|
||||
if msg_len == 0 or pos + msg_len > plain_len:
|
||||
break
|
||||
pos += msg_len
|
||||
boundaries.append(pos)
|
||||
if len(boundaries) <= 1:
|
||||
"""Decrypt to find packet boundaries, return complete ciphertext packets."""
|
||||
if not chunk:
|
||||
return []
|
||||
if self._disabled:
|
||||
return [chunk]
|
||||
|
||||
self._cipher_buf.extend(chunk)
|
||||
self._plain_buf.extend(self._dec.update(chunk))
|
||||
|
||||
parts = []
|
||||
prev = 0
|
||||
for b in boundaries:
|
||||
parts.append(chunk[prev:b])
|
||||
prev = b
|
||||
if prev < len(chunk):
|
||||
parts.append(chunk[prev:])
|
||||
while self._cipher_buf:
|
||||
packet_len = self._next_packet_len()
|
||||
if packet_len is None:
|
||||
break
|
||||
if packet_len <= 0:
|
||||
parts.append(bytes(self._cipher_buf))
|
||||
self._cipher_buf.clear()
|
||||
self._plain_buf.clear()
|
||||
self._disabled = True
|
||||
break
|
||||
parts.append(bytes(self._cipher_buf[:packet_len]))
|
||||
del self._cipher_buf[:packet_len]
|
||||
del self._plain_buf[:packet_len]
|
||||
return parts
|
||||
|
||||
def flush(self) -> List[bytes]:
|
||||
if not self._cipher_buf:
|
||||
return []
|
||||
tail = bytes(self._cipher_buf)
|
||||
self._cipher_buf.clear()
|
||||
self._plain_buf.clear()
|
||||
return [tail]
|
||||
|
||||
def _next_packet_len(self) -> Optional[int]:
|
||||
if not self._plain_buf:
|
||||
return None
|
||||
if self._proto == _PROTO_ABRIDGED:
|
||||
return self._next_abridged_len()
|
||||
if self._proto in (_PROTO_INTERMEDIATE, _PROTO_PADDED_INTERMEDIATE):
|
||||
return self._next_intermediate_len()
|
||||
return 0
|
||||
|
||||
def _next_abridged_len(self) -> Optional[int]:
|
||||
first = self._plain_buf[0]
|
||||
if first in (0x7F, 0xFF):
|
||||
if len(self._plain_buf) < 4:
|
||||
return None
|
||||
payload_len = int.from_bytes(self._plain_buf[1:4], 'little') * 4
|
||||
header_len = 4
|
||||
else:
|
||||
payload_len = (first & 0x7F) * 4
|
||||
header_len = 1
|
||||
|
||||
if payload_len <= 0:
|
||||
return 0
|
||||
|
||||
packet_len = header_len + payload_len
|
||||
if len(self._plain_buf) < packet_len:
|
||||
return None
|
||||
return packet_len
|
||||
|
||||
def _next_intermediate_len(self) -> Optional[int]:
|
||||
if len(self._plain_buf) < 4:
|
||||
return None
|
||||
|
||||
payload_len = _st_I_le.unpack_from(self._plain_buf, 0)[0] & 0x7FFFFFFF
|
||||
if payload_len <= 0:
|
||||
return 0
|
||||
|
||||
packet_len = 4 + payload_len
|
||||
if len(self._plain_buf) < packet_len:
|
||||
return None
|
||||
return packet_len
|
||||
|
||||
|
||||
def _ws_domains(dc: int, is_media) -> List[str]:
|
||||
dc = _DC_OVERRIDES.get(dc, dc)
|
||||
@@ -505,12 +561,15 @@ class Stats:
|
||||
self.pool_misses = 0
|
||||
|
||||
def summary(self) -> str:
|
||||
pool_total = self.pool_hits + self.pool_misses
|
||||
pool_s = (
|
||||
f"{self.pool_hits}/{pool_total}" if pool_total else "n/a")
|
||||
return (f"total={self.connections_total} ws={self.connections_ws} "
|
||||
f"tcp_fb={self.connections_tcp_fallback} "
|
||||
f"http_skip={self.connections_http_rejected} "
|
||||
f"pass={self.connections_passthrough} "
|
||||
f"err={self.ws_errors} "
|
||||
f"pool={self.pool_hits}/{self.pool_hits+self.pool_misses} "
|
||||
f"pool={pool_s} "
|
||||
f"up={_human_bytes(self.bytes_up)} "
|
||||
f"down={_human_bytes(self.bytes_down)}")
|
||||
|
||||
@@ -535,7 +594,7 @@ def get_stats_snapshot() -> Dict[str, int]:
|
||||
|
||||
class _WsPool:
|
||||
def __init__(self):
|
||||
self._idle: Dict[Tuple[int, bool], list] = {}
|
||||
self._idle: Dict[Tuple[int, bool], deque] = {}
|
||||
self._refilling: Set[Tuple[int, bool]] = set()
|
||||
|
||||
async def get(self, dc: int, is_media: bool,
|
||||
@@ -544,9 +603,12 @@ class _WsPool:
|
||||
key = (dc, is_media)
|
||||
now = time.monotonic()
|
||||
|
||||
bucket = self._idle.get(key, [])
|
||||
bucket = self._idle.get(key)
|
||||
if bucket is None:
|
||||
bucket = deque()
|
||||
self._idle[key] = bucket
|
||||
while bucket:
|
||||
ws, created = bucket.pop(0)
|
||||
ws, created = bucket.popleft()
|
||||
age = now - created
|
||||
if age > _WS_POOL_MAX_AGE or ws._closed:
|
||||
asyncio.create_task(self._quiet_close(ws))
|
||||
@@ -570,7 +632,7 @@ class _WsPool:
|
||||
async def _refill(self, key, target_ip, domains):
|
||||
dc, is_media = key
|
||||
try:
|
||||
bucket = self._idle.setdefault(key, [])
|
||||
bucket = self._idle.setdefault(key, deque())
|
||||
needed = _WS_POOL_SIZE - len(bucket)
|
||||
if needed <= 0:
|
||||
return
|
||||
@@ -646,6 +708,10 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||
while True:
|
||||
chunk = await reader.read(65536)
|
||||
if not chunk:
|
||||
if splitter:
|
||||
tail = splitter.flush()
|
||||
if tail:
|
||||
await ws.send(tail[0])
|
||||
break
|
||||
n = len(chunk)
|
||||
_stats.bytes_up += n
|
||||
@@ -653,6 +719,8 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||
up_packets += 1
|
||||
if splitter:
|
||||
parts = splitter.split(chunk)
|
||||
if not parts:
|
||||
continue
|
||||
if len(parts) > 1:
|
||||
await ws.send_batch(parts)
|
||||
else:
|
||||
@@ -913,14 +981,14 @@ async def _handle_client(reader, writer):
|
||||
return
|
||||
|
||||
# -- Extract DC ID --
|
||||
dc, is_media = _dc_from_init(init)
|
||||
init_patched = False
|
||||
dc, is_media, proto = _dc_from_init(init)
|
||||
|
||||
init_patched = False
|
||||
# Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it
|
||||
if dc is None and dst in _IP_TO_DC:
|
||||
dc, is_media = _IP_TO_DC.get(dst)
|
||||
if dc in _dc_opt:
|
||||
init = _patch_init_dc(init, dc if is_media else -dc)
|
||||
init = _patch_init_dc(init, -dc if is_media else dc)
|
||||
init_patched = True
|
||||
|
||||
if dc is None or dc not in _dc_opt:
|
||||
@@ -1022,9 +1090,12 @@ async def _handle_client(reader, writer):
|
||||
_stats.connections_ws += 1
|
||||
|
||||
splitter = None
|
||||
if init_patched:
|
||||
|
||||
# Turning splitter on for mobile clients or media-connections, so as the big files don't get fragmented by the TCP socket.
|
||||
if proto is not None and (init_patched or is_media or proto != _PROTO_INTERMEDIATE):
|
||||
try:
|
||||
splitter = _MsgSplitter(init)
|
||||
splitter = _MsgSplitter(init, proto)
|
||||
log.debug("[%s] MsgSplitter activated for proto 0x%08X", label, proto)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1044,6 +1115,11 @@ async def _handle_client(reader, writer):
|
||||
log.debug("[%s] cancelled", label)
|
||||
except ConnectionResetError:
|
||||
log.debug("[%s] connection reset", label)
|
||||
except OSError as exc:
|
||||
if getattr(exc, 'winerror', None) == 1236:
|
||||
log.debug("[%s] connection aborted by local system", label)
|
||||
else:
|
||||
log.error("[%s] unexpected os error: %s", label, exc)
|
||||
except Exception as exc:
|
||||
log.error("[%s] unexpected: %s", label, exc)
|
||||
finally:
|
||||
@@ -1087,34 +1163,50 @@ async def _run(port: int, dc_opt: Dict[int, Optional[str]],
|
||||
log.info("=" * 60)
|
||||
|
||||
async def log_stats():
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
bl = ', '.join(
|
||||
f'DC{d}{"m" if m else ""}'
|
||||
for d, m in sorted(_ws_blacklist)) or 'none'
|
||||
log.info("stats: %s | ws_bl: %s", _stats.summary(), bl)
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
bl = ', '.join(
|
||||
f'DC{d}{"m" if m else ""}'
|
||||
for d, m in sorted(_ws_blacklist)) or 'none'
|
||||
log.info("stats: %s | ws_bl: %s", _stats.summary(), bl)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
asyncio.create_task(log_stats())
|
||||
log_stats_task = asyncio.create_task(log_stats())
|
||||
|
||||
await _ws_pool.warmup(dc_opt)
|
||||
|
||||
if stop_event:
|
||||
async def wait_stop():
|
||||
await stop_event.wait()
|
||||
server.close()
|
||||
me = asyncio.current_task()
|
||||
for task in list(asyncio.all_tasks()):
|
||||
if task is not me:
|
||||
task.cancel()
|
||||
try:
|
||||
await server.wait_closed()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
asyncio.create_task(wait_stop())
|
||||
|
||||
async with server:
|
||||
try:
|
||||
async with server:
|
||||
if stop_event:
|
||||
serve_task = asyncio.create_task(server.serve_forever())
|
||||
stop_task = asyncio.create_task(stop_event.wait())
|
||||
done, _pending = await asyncio.wait(
|
||||
(serve_task, stop_task),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
if stop_task in done:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
if not serve_task.done():
|
||||
serve_task.cancel()
|
||||
try:
|
||||
await serve_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
else:
|
||||
stop_task.cancel()
|
||||
try:
|
||||
await stop_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
else:
|
||||
await server.serve_forever()
|
||||
finally:
|
||||
log_stats_task.cancel()
|
||||
try:
|
||||
await server.serve_forever()
|
||||
await log_stats_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
_server_instance = None
|
||||
|
||||
Reference in New Issue
Block a user