From 1c227b924a7ebef0c3fe7976cb5868236bfd9f86 Mon Sep 17 00:00:00 2001 From: Flowseal Date: Sun, 15 Mar 2026 04:34:05 +0300 Subject: [PATCH] Optimization, connections pool --- proxy/tg_ws_proxy.py | 203 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 169 insertions(+), 34 deletions(-) diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 6a425b2..52f33a2 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -17,6 +17,12 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes DEFAULT_PORT = 1080 log = logging.getLogger('tg-ws-proxy') +_TCP_NODELAY = True +_RECV_BUF = 131072 +_SEND_BUF = 131072 +_WS_POOL_SIZE = 4 +_WS_POOL_MAX_AGE = 120.0 + _TG_RANGES = [ # 185.76.151.0/24 (struct.unpack('!I', _socket.inet_aton('185.76.151.0'))[0], @@ -43,7 +49,7 @@ _IP_TO_DC: Dict[str, Tuple[int, bool]] = { '149.154.167.51': (2, False), '149.154.167.220': (2, False), '95.161.76.100': (2, False), '149.154.167.151': (2, True), '149.154.167.222': (2, True), - '149.154.167.223': (2, True), + '149.154.167.223': (2, True), '149.154.162.123': (2, True), # DC3 '149.154.175.100': (3, False), '149.154.175.101': (3, False), '149.154.175.102': (3, True), @@ -79,6 +85,22 @@ _ssl_ctx.check_hostname = False _ssl_ctx.verify_mode = ssl.CERT_NONE +def _set_sock_opts(transport): + sock = transport.get_extra_info('socket') + if sock is None: + return + if _TCP_NODELAY: + try: + sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_NODELAY, 1) + except (OSError, AttributeError): + pass + try: + sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_RCVBUF, _RECV_BUF) + sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_SNDBUF, _SEND_BUF) + except OSError: + pass + + class WsHandshakeError(Exception): def __init__(self, status_code: int, status_line: str, headers: dict = None, location: str = None): @@ -136,6 +158,7 @@ class RawWebSocket: asyncio.open_connection(ip, 443, ssl=_ssl_ctx, server_hostname=domain), timeout=min(timeout, 10)) + _set_sock_opts(writer.transport) ws_key = base64.b64encode(os.urandom(16)).decode() req = ( @@ -463,6 +486,8 @@ class Stats: self.ws_errors = 0 self.bytes_up = 0 self.bytes_down = 0 + self.pool_hits = 0 + self.pool_misses = 0 def summary(self) -> str: return (f"total={self.connections_total} ws={self.connections_ws} " @@ -470,6 +495,7 @@ class Stats: f"http_skip={self.connections_http_rejected} " f"pass={self.connections_passthrough} " f"err={self.ws_errors} " + f"pool={self.pool_hits}/{self.pool_hits+self.pool_misses} " f"up={_human_bytes(self.bytes_up)} " f"down={_human_bytes(self.bytes_down)}") @@ -477,6 +503,100 @@ class Stats: _stats = Stats() +class _WsPool: + def __init__(self): + self._idle: Dict[Tuple[int, bool], list] = {} + self._refilling: Set[Tuple[int, bool]] = set() + + async def get(self, dc: int, is_media: bool, + target_ip: str, domains: List[str] + ) -> Optional[RawWebSocket]: + key = (dc, is_media) + now = time.monotonic() + + bucket = self._idle.get(key, []) + while bucket: + ws, created = bucket.pop(0) + age = now - created + if age > _WS_POOL_MAX_AGE or ws._closed: + asyncio.create_task(self._quiet_close(ws)) + continue + _stats.pool_hits += 1 + log.debug("WS pool hit for DC%d%s (age=%.1fs, left=%d)", + dc, 'm' if is_media else '', age, len(bucket)) + self._schedule_refill(key, target_ip, domains) + return ws + + _stats.pool_misses += 1 + self._schedule_refill(key, target_ip, domains) + return None + + def _schedule_refill(self, key, target_ip, domains): + if key in self._refilling: + return + self._refilling.add(key) + asyncio.create_task(self._refill(key, target_ip, domains)) + + async def _refill(self, key, target_ip, domains): + dc, is_media = key + try: + bucket = self._idle.setdefault(key, []) + needed = _WS_POOL_SIZE - len(bucket) + if needed <= 0: + return + tasks = [] + for _ in range(needed): + tasks.append(asyncio.create_task( + self._connect_one(target_ip, domains))) + for t in tasks: + try: + ws = await t + if ws: + bucket.append((ws, time.monotonic())) + except Exception: + pass + log.debug("WS pool refilled DC%d%s: %d ready", + dc, 'm' if is_media else '', len(bucket)) + finally: + self._refilling.discard(key) + + @staticmethod + async def _connect_one(target_ip, domains) -> Optional[RawWebSocket]: + for domain in domains: + try: + ws = await RawWebSocket.connect( + target_ip, domain, timeout=8) + return ws + except WsHandshakeError as exc: + if exc.is_redirect: + continue + return None + except Exception: + return None + return None + + @staticmethod + async def _quiet_close(ws): + try: + await ws.close() + except Exception: + pass + + async def warmup(self, dc_opt: Dict[int, Optional[str]]): + """Pre-fill pool for all configured DCs on startup.""" + for dc, target_ip in dc_opt.items(): + if target_ip is None: + continue + for is_media in (False, True): + domains = _ws_domains(dc, is_media) + key = (dc, is_media) + self._schedule_refill(key, target_ip, domains) + log.info("WS pool warmup started for %d DC(s)", len(dc_opt)) + + +_ws_pool = _WsPool() + + async def _bridge_ws(reader, writer, ws: RawWebSocket, label, dc=None, dst=None, port=None, is_media=False, splitter: _MsgSplitter = None): @@ -526,7 +646,7 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label, writer.write(data) # drain only when kernel buffer is filling up buf = writer.transport.get_write_buffer_size() - if buf > 262144: + if buf > _SEND_BUF: await writer.drain() except (asyncio.CancelledError, ConnectionError, OSError): return @@ -658,6 +778,8 @@ async def _handle_client(reader, writer): peer = writer.get_extra_info('peername') label = f"{peer[0]}:{peer[1]}" if peer else "?" + _set_sock_opts(writer.transport) + try: # -- SOCKS5 greeting -- hdr = await asyncio.wait_for(reader.readexactly(2), timeout=10) @@ -798,39 +920,44 @@ async def _handle_client(reader, writer): ws_failed_redirect = False all_redirects = True - for domain in domains: - url = f'wss://{domain}/apiws' - log.info("[%s] DC%d%s (%s:%d) -> %s via %s", - label, dc, media_tag, dst, port, url, target) - try: - ws = await RawWebSocket.connect(target, domain, - timeout=10) - all_redirects = False - break - except WsHandshakeError as exc: - _stats.ws_errors += 1 - if exc.is_redirect: - ws_failed_redirect = True - log.warning("[%s] DC%d%s got %d from %s -> %s", - label, dc, media_tag, - exc.status_code, domain, - exc.location or '?') - continue - else: + ws = await _ws_pool.get(dc, is_media, target, domains) + if ws: + log.info("[%s] DC%d%s (%s:%d) -> pool hit via %s", + label, dc, media_tag, dst, port, target) + else: + for domain in domains: + url = f'wss://{domain}/apiws' + log.info("[%s] DC%d%s (%s:%d) -> %s via %s", + label, dc, media_tag, dst, port, url, target) + try: + ws = await RawWebSocket.connect(target, domain, + timeout=10) all_redirects = False - log.warning("[%s] DC%d%s WS handshake: %s", - label, dc, media_tag, exc.status_line) - except Exception as exc: - _stats.ws_errors += 1 - all_redirects = False - err_str = str(exc) - if ('CERTIFICATE_VERIFY_FAILED' in err_str or - 'Hostname mismatch' in err_str): - log.warning("[%s] DC%d%s SSL error: %s", - label, dc, media_tag, exc) - else: - log.warning("[%s] DC%d%s WS connect failed: %s", - label, dc, media_tag, exc) + break + except WsHandshakeError as exc: + _stats.ws_errors += 1 + if exc.is_redirect: + ws_failed_redirect = True + log.warning("[%s] DC%d%s got %d from %s -> %s", + label, dc, media_tag, + exc.status_code, domain, + exc.location or '?') + continue + else: + all_redirects = False + log.warning("[%s] DC%d%s WS handshake: %s", + label, dc, media_tag, exc.status_line) + except Exception as exc: + _stats.ws_errors += 1 + all_redirects = False + err_str = str(exc) + if ('CERTIFICATE_VERIFY_FAILED' in err_str or + 'Hostname mismatch' in err_str): + log.warning("[%s] DC%d%s SSL error: %s", + label, dc, media_tag, exc) + else: + log.warning("[%s] DC%d%s WS connect failed: %s", + label, dc, media_tag, exc) # -- WS failed -> fallback -- if ws is None: @@ -906,6 +1033,12 @@ async def _run(port: int, dc_opt: Dict[int, Optional[str]], _handle_client, host, port) _server_instance = server + for sock in server.sockets: + try: + sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_NODELAY, 1) + except (OSError, AttributeError): + pass + log.info("=" * 60) log.info(" Telegram WS Bridge Proxy") log.info(" Listening on %s:%d", host, port) @@ -928,6 +1061,8 @@ async def _run(port: int, dc_opt: Dict[int, Optional[str]], asyncio.create_task(log_stats()) + await _ws_pool.warmup(dc_opt) + if stop_event: async def wait_stop(): await stop_event.wait()