diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 8bd0c45..26078ee 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -22,6 +22,9 @@ _RECV_BUF = 256 * 1024 _SEND_BUF = 256 * 1024 _WS_POOL_SIZE = 4 _WS_POOL_MAX_AGE = 120.0 +_TCP_ONLY_PORTS = {5222} +_WS_ONLY_PORTS = {443} +_DYN_IP_CACHE_MAX = 256 _TG_RANGES = [ # 185.76.151.0/24 @@ -43,13 +46,14 @@ _IP_TO_DC: Dict[str, Tuple[int, bool]] = { # DC1 '149.154.175.50': (1, False), '149.154.175.51': (1, False), '149.154.175.53': (1, False), '149.154.175.54': (1, False), - '149.154.175.52': (1, True), + '149.154.175.52': (1, True), '149.154.175.211': (1, False), # DC2 '149.154.167.41': (2, False), '149.154.167.50': (2, False), '149.154.167.51': (2, False), '149.154.167.220': (2, False), '95.161.76.100': (2, False), '149.154.167.151': (2, True), '149.154.167.222': (2, True), '149.154.167.223': (2, True), '149.154.162.123': (2, True), + '149.154.167.35': (2, False), '149.154.167.255': (2, True), # DC3 '149.154.175.100': (3, False), '149.154.175.101': (3, False), '149.154.175.102': (3, True), @@ -74,6 +78,7 @@ _DC_OVERRIDES: Dict[int, int] = { } _dc_opt: Dict[int, Optional[str]] = {} +_prefer_tcp_for_media = False # DCs where WS is known to fail (302 redirect) # Raw TCP fallback will be used instead @@ -82,7 +87,10 @@ _ws_blacklist: Set[Tuple[int, bool]] = set() # Rate-limit re-attempts per (dc, is_media) _dc_fail_until: Dict[Tuple[int, bool], float] = {} -_DC_FAIL_COOLDOWN = 30.0 # seconds to keep reduced WS timeout after failure +_dc_fail_count: Dict[Tuple[int, bool], int] = {} +_domain_success: Dict[Tuple[int, bool], str] = {} +_DC_FAIL_COOLDOWN = 15.0 # base seconds to keep reduced WS timeout after failure +_DC_FAIL_COOLDOWN_MAX = 120.0 _WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure @@ -151,6 +159,16 @@ class RawWebSocket: self.writer = writer self._closed = False + def is_usable(self) -> bool: + if self._closed: + return False + if self.writer.is_closing(): + return False + transport = self.writer.transport + if transport is None or transport.is_closing(): + return False + return True + @staticmethod async def connect(ip: str, domain: str, path: str = '/apiws', timeout: float = 10.0) -> 'RawWebSocket': @@ -472,8 +490,15 @@ class _MsgSplitter: def _ws_domains(dc: int, is_media) -> List[str]: dc = _DC_OVERRIDES.get(dc, dc) if is_media is None or is_media: - return [f'kws{dc}-1.web.telegram.org', f'kws{dc}.web.telegram.org'] - return [f'kws{dc}.web.telegram.org', f'kws{dc}-1.web.telegram.org'] + domains = [f'kws{dc}-1.web.telegram.org', f'kws{dc}.web.telegram.org'] + else: + domains = [f'kws{dc}.web.telegram.org', f'kws{dc}-1.web.telegram.org'] + key = (dc, bool(is_media)) + preferred = _domain_success.get(key) + if preferred in domains: + domains.remove(preferred) + domains.insert(0, preferred) + return domains class Stats: @@ -488,6 +513,13 @@ class Stats: self.bytes_down = 0 self.pool_hits = 0 self.pool_misses = 0 + self.media_ws_success = 0 + self.media_ws_fail = 0 + self.media_tcp_fallback = 0 + self.media_unknown_dc = 0 + self.media_init_patched = 0 + self.port_tcp_only_hits = 0 + self.port_ws_attempts = 0 def summary(self) -> str: return (f"total={self.connections_total} ws={self.connections_ws} " @@ -495,6 +527,11 @@ class Stats: f"http_skip={self.connections_http_rejected} " f"pass={self.connections_passthrough} " f"err={self.ws_errors} " + f"media(ws={self.media_ws_success},fail={self.media_ws_fail}," + f"tcp={self.media_tcp_fallback},unk={self.media_unknown_dc}," + f"patch={self.media_init_patched}) " + f"ports(tcp_only={self.port_tcp_only_hits}," + f"ws_try={self.port_ws_attempts}) " f"pool={self.pool_hits}/{self.pool_hits+self.pool_misses} " f"up={_human_bytes(self.bytes_up)} " f"down={_human_bytes(self.bytes_down)}") @@ -503,6 +540,95 @@ class Stats: _stats = Stats() +def _remember_ip_mapping(ip: str, dc: int, is_media: bool): + if not ip or ':' in ip: + return + current = _IP_TO_DC.get(ip) + if current == (dc, is_media): + return + if current is None and len(_IP_TO_DC) >= (128 + _DYN_IP_CACHE_MAX): + return + _IP_TO_DC[ip] = (dc, is_media) + log.debug("learned IP mapping %s -> DC%d%s", ip, dc, 'm' if is_media else '') + + +def _target_ip_for(dc: int, dst: str) -> Optional[str]: + target = _dc_opt.get(dc) + if target: + return target + if _is_telegram_ip(dst): + return dst + return None + + +def _guess_dc_candidates(dst: str) -> List[Tuple[int, bool, str]]: + candidates: List[Tuple[int, bool, str]] = [] + seen: Set[Tuple[int, bool, str]] = set() + + mapped = _IP_TO_DC.get(dst) + if mapped is not None: + dc, is_media = mapped + target = _target_ip_for(dc, dst) + if target: + item = (dc, bool(is_media), target) + seen.add(item) + candidates.append(item) + + prefixes = [] + parts = dst.split('.') + if len(parts) == 4: + prefixes = ['.'.join(parts[:3]) + '.', '.'.join(parts[:2]) + '.'] + for ip, (dc, is_media) in list(_IP_TO_DC.items()): + target = _target_ip_for(dc, dst) + if not target: + continue + if any(ip.startswith(prefix) for prefix in prefixes): + item = (dc, bool(is_media), target) + if item not in seen: + seen.add(item) + candidates.append(item) + + for dc, target in _dc_opt.items(): + if not target: + continue + for is_media in (False, True): + item = (dc, is_media, target) + if item not in seen: + seen.add(item) + candidates.append(item) + + return candidates + + +def _ws_mode_for(port: int, is_media: bool) -> str: + if port in _TCP_ONLY_PORTS: + return 'tcp' + if port in _WS_ONLY_PORTS: + if is_media and _prefer_tcp_for_media: + return 'tcp' + return 'ws' + if is_media and _prefer_tcp_for_media: + return 'tcp' + return 'ws' + + +def _register_ws_success(dc_key: Tuple[int, bool], domain: Optional[str] = None): + _dc_fail_until.pop(dc_key, None) + _dc_fail_count.pop(dc_key, None) + if domain: + _domain_success[dc_key] = domain + + +def _register_ws_failure(dc_key: Tuple[int, bool], redirect_only: bool): + fails = _dc_fail_count.get(dc_key, 0) + 1 + _dc_fail_count[dc_key] = fails + cooldown = min(_DC_FAIL_COOLDOWN * (2 ** (fails - 1)), _DC_FAIL_COOLDOWN_MAX) + _dc_fail_until[dc_key] = time.monotonic() + cooldown + if redirect_only: + _ws_blacklist.add(dc_key) + return cooldown + + class _WsPool: def __init__(self): self._idle: Dict[Tuple[int, bool], list] = {} @@ -518,7 +644,7 @@ class _WsPool: while bucket: ws, created = bucket.pop(0) age = now - created - if age > _WS_POOL_MAX_AGE or ws._closed: + if age > _WS_POOL_MAX_AGE or not ws.is_usable(): asyncio.create_task(self._quiet_close(ws)) continue _stats.pool_hits += 1 @@ -550,9 +676,11 @@ class _WsPool: self._connect_one(target_ip, domains))) for t in tasks: try: - ws = await t - if ws: + result = await t + if result: + ws, domain = result bucket.append((ws, time.monotonic())) + _domain_success[(dc, is_media)] = domain except Exception: pass log.debug("WS pool refilled DC%d%s: %d ready", @@ -561,12 +689,12 @@ class _WsPool: self._refilling.discard(key) @staticmethod - async def _connect_one(target_ip, domains) -> Optional[RawWebSocket]: + async def _connect_one(target_ip, domains) -> Optional[Tuple[RawWebSocket, str]]: for domain in domains: try: ws = await RawWebSocket.connect( target_ip, domain, timeout=8) - return ws + return ws, domain except WsHandshakeError as exc: if exc.is_redirect: continue @@ -631,6 +759,7 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label, except (asyncio.CancelledError, ConnectionError, OSError): return except Exception as e: + ws._closed = True log.debug("[%s] tcp->ws ended: %s", label, e) async def ws_to_tcp(): @@ -651,6 +780,7 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label, except (asyncio.CancelledError, ConnectionError, OSError): return except Exception as e: + ws._closed = True log.debug("[%s] ws->tcp ended: %s", label, e) tasks = [asyncio.create_task(tcp_to_ws()), @@ -698,7 +828,9 @@ async def _bridge_tcp(reader, writer, remote_reader, remote_writer, else: _stats.bytes_down += len(data) dst_w.write(data) - await dst_w.drain() + buf = dst_w.transport.get_write_buffer_size() + if buf > _SEND_BUF: + await dst_w.drain() except asyncio.CancelledError: pass except Exception as e: @@ -808,8 +940,13 @@ async def _handle_client(reader, writer): dlen = (await reader.readexactly(1))[0] dst = (await reader.readexactly(dlen)).decode() elif atyp == 4: # IPv6 - raw = await reader.readexactly(16) - dst = _socket.inet_ntop(_socket.AF_INET6, raw) + await reader.readexactly(16) + await reader.readexactly(2) + log.debug("[%s] IPv6 SOCKS request rejected", label) + writer.write(_socks5_reply(0x08)) + await writer.drain() + writer.close() + return else: writer.write(_socks5_reply(0x08)) await writer.drain() @@ -819,12 +956,8 @@ async def _handle_client(reader, writer): port = struct.unpack('!H', await reader.readexactly(2))[0] if ':' in dst: - log.error( - "[%s] IPv6 address detected: %s:%d — " - "IPv6 addresses are not supported; " - "disable IPv6 to continue using the proxy.", - label, dst, port) - writer.write(_socks5_reply(0x05)) + log.debug("[%s] rejected non-IPv4 destination %s:%d", label, dst, port) + writer.write(_socks5_reply(0x08)) await writer.drain() writer.close() return @@ -881,29 +1014,86 @@ async def _handle_client(reader, writer): # -- Extract DC ID -- dc, is_media = _dc_from_init(init) init_patched = False - + # Android (may be ios too) with useSecret=0 has random dc_id bytes — patch it if dc is None and dst in _IP_TO_DC: dc, is_media = _IP_TO_DC.get(dst) - if dc in _dc_opt: - init = _patch_init_dc(init, dc if is_media else -dc) - init_patched = True + signed_dc = -dc if is_media else dc + init = _patch_init_dc(init, signed_dc) + init_patched = True + if is_media: + _stats.media_init_patched += 1 - if dc is None or dc not in _dc_opt: - log.warning("[%s] unknown DC%s for %s:%d -> TCP passthrough", - label, dc, dst, port) - await _tcp_fallback(reader, writer, dst, port, init, label) - return + if dc is None: + guessed = _guess_dc_candidates(dst) + if guessed: + log.info("[%s] unknown DC for %s:%d -> trying guessed WS candidates: %s", + label, dst, port, ', '.join(f'DC{gdc}{'m' if gmedia else ""}@{gtarget}' for gdc, gmedia, gtarget in guessed[:6])) + last_media = False + for gdc, gis_media, gtarget in guessed: + dc = gdc + is_media = gis_media + last_media = gis_media + signed_dc = -gdc if gis_media else gdc + patched_init = _patch_init_dc(init, signed_dc) + _remember_ip_mapping(dst, gdc, gis_media) + target = _target_ip_for(gdc, dst) or gtarget + if target: + init = patched_init + init_patched = True + if gis_media: + _stats.media_init_patched += 1 + break + else: + dc = None + is_media = last_media + if dc is None: + if is_media: + _stats.media_unknown_dc += 1 + log.warning("[%s] unknown DC for %s:%d -> TCP passthrough", + label, dst, port) + await _tcp_fallback(reader, writer, dst, port, init, label) + return - dc_key = (dc, is_media if is_media is not None else True) + _remember_ip_mapping(dst, dc, bool(is_media)) + dc_key = (dc, bool(is_media)) now = time.monotonic() media_tag = (" media" if is_media else (" media?" if is_media is None else "")) + target = _target_ip_for(dc, dst) + mode = _ws_mode_for(port, bool(is_media)) + + if target is None: + if is_media: + _stats.media_unknown_dc += 1 + log.warning("[%s] DC%d%s has no target IP for %s:%d -> TCP passthrough", + label, dc, media_tag, dst, port) + await _tcp_fallback(reader, writer, dst, port, init, label, + dc=dc, is_media=is_media) + return + + if mode == 'tcp': + _stats.port_tcp_only_hits += 1 + if is_media: + _stats.media_tcp_fallback += 1 + log.info("[%s] DC%d%s port %d policy -> TCP %s:%d", + label, dc, media_tag, port, dst, port) + ok = await _tcp_fallback(reader, writer, dst, port, init, + label, dc=dc, is_media=is_media) + if ok: + log.info("[%s] DC%d%s TCP policy session closed", + label, dc, media_tag) + return + + _stats.port_ws_attempts += 1 # -- WS blacklist check -- if dc_key in _ws_blacklist: log.debug("[%s] DC%d%s WS blacklisted -> TCP %s:%d", label, dc, media_tag, dst, port) + if is_media: + _stats.media_ws_fail += 1 + _stats.media_tcp_fallback += 1 ok = await _tcp_fallback(reader, writer, dst, port, init, label, dc=dc, is_media=is_media) if ok: @@ -916,15 +1106,16 @@ async def _handle_client(reader, writer): ws_timeout = _WS_FAIL_TIMEOUT if now < fail_until else 10.0 domains = _ws_domains(dc, is_media) - target = _dc_opt[dc] ws = None ws_failed_redirect = False all_redirects = True + selected_domain = None - ws = await _ws_pool.get(dc, is_media, target, domains) + ws = await _ws_pool.get(dc, bool(is_media), target, domains) if ws: log.info("[%s] DC%d%s (%s:%d) -> pool hit via %s", label, dc, media_tag, dst, port, target) + selected_domain = _domain_success.get(dc_key) else: for domain in domains: url = f'wss://{domain}/apiws' @@ -934,6 +1125,7 @@ async def _handle_client(reader, writer): ws = await RawWebSocket.connect(target, domain, timeout=ws_timeout) all_redirects = False + selected_domain = domain break except WsHandshakeError as exc: _stats.ws_errors += 1 @@ -962,18 +1154,18 @@ async def _handle_client(reader, writer): # -- WS failed -> fallback -- if ws is None: + cooldown = _register_ws_failure(dc_key, ws_failed_redirect and all_redirects) if ws_failed_redirect and all_redirects: - _ws_blacklist.add(dc_key) log.warning( "[%s] DC%d%s blacklisted for WS (all 302)", label, dc, media_tag) - elif ws_failed_redirect: - _dc_fail_until[dc_key] = now + _DC_FAIL_COOLDOWN else: - _dc_fail_until[dc_key] = now + _DC_FAIL_COOLDOWN log.info("[%s] DC%d%s WS cooldown for %ds", - label, dc, media_tag, int(_DC_FAIL_COOLDOWN)) + label, dc, media_tag, int(cooldown)) + if is_media: + _stats.media_ws_fail += 1 + _stats.media_tcp_fallback += 1 log.info("[%s] DC%d%s -> TCP fallback to %s:%d", label, dc, media_tag, dst, port) ok = await _tcp_fallback(reader, writer, dst, port, init, @@ -984,8 +1176,10 @@ async def _handle_client(reader, writer): return # -- WS success -- - _dc_fail_until.pop(dc_key, None) + _register_ws_success(dc_key, selected_domain) _stats.connections_ws += 1 + if is_media: + _stats.media_ws_success += 1 splitter = None if init_patched: @@ -1023,6 +1217,28 @@ _server_instance = None _server_stop_event = None +async def _probe_startup(dc_opt: Dict[int, Optional[str]]): + for dc, target_ip in dc_opt.items(): + if not target_ip: + continue + for is_media in (False, True): + domains = _ws_domains(dc, is_media) + ok = False + used = None + for domain in domains: + try: + ws = await RawWebSocket.connect(target_ip, domain, timeout=4.0) + used = domain + await ws.close() + ok = True + break + except Exception: + continue + log.info("startup probe DC%d%s via %s: %s", + dc, 'm' if is_media else '', + target_ip, used if ok else 'FAIL') + + async def _run(port: int, dc_opt: Dict[int, Optional[str]], stop_event: Optional[asyncio.Event] = None, host: str = '127.0.0.1'): @@ -1063,6 +1279,7 @@ async def _run(port: int, dc_opt: Dict[int, Optional[str]], asyncio.create_task(log_stats()) await _ws_pool.warmup(dc_opt) + asyncio.create_task(_probe_startup(dc_opt)) if stop_event: async def wait_stop(): @@ -1109,6 +1326,19 @@ def run_proxy(port: int, dc_opt: Dict[int, str], asyncio.run(_run(port, dc_opt, stop_event, host)) +def _parse_port_set(value: str) -> Set[int]: + ports: Set[int] = set() + for part in value.split(','): + part = part.strip() + if not part: + continue + p = int(part) + if not (1 <= p <= 65535): + raise ValueError(f"Invalid port {p}") + ports.add(p) + return ports + + def main(): ap = argparse.ArgumentParser( description='Telegram Desktop WebSocket Bridge Proxy') @@ -1120,6 +1350,12 @@ def main(): default=[], help='Target IP for a DC, e.g. --dc-ip 1:149.154.175.205' ' --dc-ip 2:149.154.167.220') + ap.add_argument('--tcp-only-ports', type=str, default='5222', + help='Comma-separated Telegram destination ports that should always use direct TCP (default 5222)') + ap.add_argument('--ws-ports', type=str, default='443', + help='Comma-separated Telegram destination ports that should prefer WebSocket bridge (default 443)') + ap.add_argument('--prefer-tcp-for-media', action='store_true', + help='Route media Telegram sessions over direct TCP when possible') ap.add_argument('-v', '--verbose', action='store_true', help='Debug logging') args = ap.parse_args() @@ -1127,12 +1363,17 @@ def main(): if not args.dc_ip: args.dc_ip = ['2:149.154.167.220', '4:149.154.167.220'] + global _prefer_tcp_for_media, _TCP_ONLY_PORTS, _WS_ONLY_PORTS try: dc_opt = parse_dc_ip_list(args.dc_ip) + _TCP_ONLY_PORTS = _parse_port_set(args.tcp_only_ports) + _WS_ONLY_PORTS = _parse_port_set(args.ws_ports) except ValueError as e: log.error(str(e)) sys.exit(1) + _prefer_tcp_for_media = args.prefer_tcp_for_media + logging.basicConfig( level=logging.DEBUG if args.verbose else logging.INFO, format='%(asctime)s %(levelname)-5s %(message)s',