From 96e5b4b639e55858a761699eb26a00364e330787 Mon Sep 17 00:00:00 2001 From: Konukhov Yaroslav Date: Wed, 17 Jun 2026 00:13:55 -0700 Subject: [PATCH] fix: add WebSocket keepalive pings to prevent idle disconnects (#646) (#925) --- proxy/bridge.py | 22 +++++++++++++++++++++- proxy/config.py | 1 + proxy/raw_websocket.py | 7 +++++++ proxy/tg_ws_proxy.py | 5 +++++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/proxy/bridge.py b/proxy/bridge.py index d66c3db..6133476 100644 --- a/proxy/bridge.py +++ b/proxy/bridge.py @@ -266,6 +266,23 @@ async def _tcp_fallback(reader, writer, dst, port, relay_init, label, ctx: Crypt return True +async def _ws_keepalive(ws, interval: float): + """Send periodic WS PING frames to keep the upstream flow warm. + + A non-positive interval disables keepalive. The loop exits on send + failure so a dead upstream is detected promptly instead of lingering + until the next client packet (see issue #646). + """ + if interval <= 0: + return + try: + while True: + await asyncio.sleep(interval) + await ws.send_ping() + except (asyncio.CancelledError, ConnectionError, OSError): + return + + async def bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, ctx: CryptoCtx, dc=None, is_media=False, @@ -337,12 +354,15 @@ async def bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, tasks = [asyncio.create_task(tcp_to_ws()), asyncio.create_task(ws_to_tcp())] + keepalive = asyncio.ensure_future( + _ws_keepalive(ws, proxy_config.ws_keepalive_interval)) try: await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) finally: + keepalive.cancel() for t in tasks: t.cancel() - for t in tasks: + for t in (*tasks, keepalive): try: await t except BaseException: diff --git a/proxy/config.py b/proxy/config.py index 2d246b5..0d9821a 100644 --- a/proxy/config.py +++ b/proxy/config.py @@ -67,6 +67,7 @@ class ProxyConfig: cfproxy_worker_domains: List[str] = field(default_factory=list) fake_tls_domain: str = '' proxy_protocol: bool = False + ws_keepalive_interval: float = 30.0 proxy_config = ProxyConfig() diff --git a/proxy/raw_websocket.py b/proxy/raw_websocket.py index 30d07e9..9e7a248 100644 --- a/proxy/raw_websocket.py +++ b/proxy/raw_websocket.py @@ -154,6 +154,13 @@ class RawWebSocket: self._build_frame(self.OP_BINARY, part, mask=True)) await self.writer.drain() + async def send_ping(self, payload: bytes = b''): + if self._closed: + raise ConnectionError("WebSocket closed") + frame = self._build_frame(self.OP_PING, payload, mask=True) + self.writer.write(frame) + await self.writer.drain() + async def recv(self) -> Optional[bytes]: while not self._closed: opcode, payload = await self._read_frame() diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 3869749..5589c8e 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -593,6 +593,10 @@ def main(): ap.add_argument('--proxy-protocol', action='store_true', help='Accept PROXY protocol v1 header ' '(for use behind nginx/haproxy with proxy_protocol on)') + ap.add_argument('--ws-keepalive', type=float, default=30.0, metavar='SEC', + help='Seconds between WebSocket keepalive PINGs to the ' + 'upstream (default 30, 0 to disable). Keeps idle ' + 'sessions alive through NAT/firewall timeouts.') args = ap.parse_args() if not args.dc_ip: @@ -629,6 +633,7 @@ def main(): proxy_config.cfproxy_worker_domains = coerce_domain_list(args.cfproxy_worker_domain) proxy_config.fake_tls_domain = args.fake_tls_domain.strip() proxy_config.proxy_protocol = args.proxy_protocol + proxy_config.ws_keepalive_interval = max(0.0, args.ws_keepalive) log_level = logging.DEBUG if args.verbose else logging.INFO log_fmt = logging.Formatter('%(asctime)s %(levelname)-5s %(message)s',