diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 92c5c12..3a9c79a 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -35,6 +35,8 @@ log = logging.getLogger('tg-mtproto-proxy') DC_FAIL_COOLDOWN = 30.0 WS_FAIL_TIMEOUT = 2.0 +LISTENER_CHECK_INTERVAL = 5.0 +LISTENER_RESTART_DELAY = 0.5 ws_blacklist: Set[str] = set() dc_fail_until: Dict[str, float] = {} @@ -511,43 +513,84 @@ async def _run(stop_event: Optional[asyncio.Event] = None): await ws_pool.warmup() await cf_worker_pool.warmup() + async def _quiet_cancel(t): + if not t.done(): + t.cancel() + try: + await t + except (asyncio.CancelledError, Exception): + pass + try: - async with server: - if stop_event: - serve_task = asyncio.create_task(server.serve_forever()) - stop_task = asyncio.create_task(stop_event.wait()) - done, _ = await asyncio.wait( - (serve_task, stop_task), - return_when=asyncio.FIRST_COMPLETED, - ) - if stop_task in done: - for task in list(_client_tasks): - task.cancel() - if _client_tasks: - await asyncio.gather( - *_client_tasks, return_exceptions=True) - if not serve_task.done(): - serve_task.cancel() - try: - await serve_task - except asyncio.CancelledError: - pass - server.close() - await server.wait_closed() - else: - stop_task.cancel() - try: - await stop_task - except asyncio.CancelledError: - pass - else: - await server.serve_forever() + while True: + serve_task = asyncio.create_task(server.serve_forever()) + stop_task = (asyncio.create_task(stop_event.wait()) + if stop_event else None) + + async def _listener_watchdog(): + while True: + await asyncio.sleep(LISTENER_CHECK_INTERVAL) + socks = server.sockets + if not socks or all(s.fileno() < 0 for s in socks): + return + + watchdog_task = asyncio.create_task(_listener_watchdog()) + waiters = [serve_task, watchdog_task] + if stop_task is not None: + waiters.append(stop_task) + + done, _ = await asyncio.wait( + waiters, return_when=asyncio.FIRST_COMPLETED) + + if stop_task is not None and stop_task in done: + for task in list(_client_tasks): + task.cancel() + if _client_tasks: + await asyncio.gather( + *_client_tasks, return_exceptions=True) + await _quiet_cancel(watchdog_task) + await _quiet_cancel(serve_task) + server.close() + await server.wait_closed() + break + + await _quiet_cancel(watchdog_task) + await _quiet_cancel(serve_task) + log.warning( + "Listening socket died (OS accept error, e.g. WinError 64); " + "restarting server") + server.close() + try: + await server.wait_closed() + except Exception: + pass + await asyncio.sleep(LISTENER_RESTART_DELAY) + try: + server = await asyncio.start_server( + client_cb, proxy_config.host, proxy_config.port) + except OSError as exc: + log.error("Failed to restart server: %s", repr(exc)) + break + _server_instance = server + for sock in server.sockets: + try: + sock.setsockopt( + _socket.IPPROTO_TCP, _socket.TCP_NODELAY, 1) + except (OSError, AttributeError): + pass + log.warning("Server restored, listening on %s:%d", + proxy_config.host, proxy_config.port) finally: log_stats_task.cancel() try: await log_stats_task except asyncio.CancelledError: pass + try: + server.close() + await server.wait_closed() + except Exception: + pass _server_instance = None