diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index f293f94..2a30cc6 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -539,6 +539,7 @@ class _WsPool: def __init__(self): self._idle: Dict[Tuple[int, bool], deque] = {} self._refilling: Set[Tuple[int, bool]] = set() + self._refill_tasks: Dict[Tuple[int, bool], asyncio.Task] = {} async def get(self, dc: int, is_media: bool, target_ip: str, domains: List[str] @@ -571,10 +572,13 @@ class _WsPool: if key in self._refilling: return self._refilling.add(key) - asyncio.create_task(self._refill(key, target_ip, domains)) + task = asyncio.create_task(self._refill(key, target_ip, domains)) + self._refill_tasks[key] = task + task.add_done_callback(lambda _t, refill_key=key: self._refill_tasks.pop(refill_key, None)) async def _refill(self, key, target_ip, domains): dc, is_media = key + tasks: List[asyncio.Task] = [] try: bucket = self._idle.setdefault(key, deque()) needed = proxy_config.pool_size - len(bucket) @@ -588,10 +592,19 @@ class _WsPool: ws = await t if ws: bucket.append((ws, time.monotonic())) + except asyncio.CancelledError: + raise except Exception: pass log.debug("WS pool refilled DC%d%s: %d ready", dc, 'm' if is_media else '', len(bucket)) + except asyncio.CancelledError: + for task in tasks: + if not task.done(): + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + raise finally: self._refilling.discard(key) @@ -625,6 +638,29 @@ class _WsPool: self._schedule_refill((dc, is_media), target_ip, domains) log.info("WS pool warmup started for %d DC(s)", len(dc_redirects)) + async def close(self): + refill_tasks = list(self._refill_tasks.values()) + self._refill_tasks.clear() + for task in refill_tasks: + if not task.done(): + task.cancel() + if refill_tasks: + await asyncio.gather(*refill_tasks, return_exceptions=True) + + idle_sockets = [] + for bucket in self._idle.values(): + while bucket: + ws, _created = bucket.popleft() + idle_sockets.append(ws) + self._idle.clear() + self._refilling.clear() + + if idle_sockets: + await asyncio.gather( + *(self._quiet_close(ws) for ws in idle_sockets), + return_exceptions=True, + ) + _ws_pool = _WsPool() @@ -1111,6 +1147,7 @@ async def _run(stop_event: Optional[asyncio.Event] = None): else: await server.serve_forever() finally: + await _ws_pool.close() log_stats_task.cancel() try: await log_stats_task