fix(runtime): close ws pool tasks before loop shutdown

This commit is contained in:
Dark_Avery 2026-03-30 16:32:34 +03:00
parent 7ad377c12c
commit 76b375bd03
1 changed files with 38 additions and 1 deletions

View File

@ -539,6 +539,7 @@ class _WsPool:
def __init__(self):
self._idle: Dict[Tuple[int, bool], deque] = {}
self._refilling: Set[Tuple[int, bool]] = set()
self._refill_tasks: Dict[Tuple[int, bool], asyncio.Task] = {}
async def get(self, dc: int, is_media: bool,
target_ip: str, domains: List[str]
@ -571,10 +572,13 @@ class _WsPool:
if key in self._refilling:
return
self._refilling.add(key)
asyncio.create_task(self._refill(key, target_ip, domains))
task = asyncio.create_task(self._refill(key, target_ip, domains))
self._refill_tasks[key] = task
task.add_done_callback(lambda _t, refill_key=key: self._refill_tasks.pop(refill_key, None))
async def _refill(self, key, target_ip, domains):
dc, is_media = key
tasks: List[asyncio.Task] = []
try:
bucket = self._idle.setdefault(key, deque())
needed = proxy_config.pool_size - len(bucket)
@ -588,10 +592,19 @@ class _WsPool:
ws = await t
if ws:
bucket.append((ws, time.monotonic()))
except asyncio.CancelledError:
raise
except Exception:
pass
log.debug("WS pool refilled DC%d%s: %d ready",
dc, 'm' if is_media else '', len(bucket))
except asyncio.CancelledError:
for task in tasks:
if not task.done():
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
raise
finally:
self._refilling.discard(key)
@ -625,6 +638,29 @@ class _WsPool:
self._schedule_refill((dc, is_media), target_ip, domains)
log.info("WS pool warmup started for %d DC(s)", len(dc_redirects))
async def close(self):
refill_tasks = list(self._refill_tasks.values())
self._refill_tasks.clear()
for task in refill_tasks:
if not task.done():
task.cancel()
if refill_tasks:
await asyncio.gather(*refill_tasks, return_exceptions=True)
idle_sockets = []
for bucket in self._idle.values():
while bucket:
ws, _created = bucket.popleft()
idle_sockets.append(ws)
self._idle.clear()
self._refilling.clear()
if idle_sockets:
await asyncio.gather(
*(self._quiet_close(ws) for ws in idle_sockets),
return_exceptions=True,
)
_ws_pool = _WsPool()
@ -1111,6 +1147,7 @@ async def _run(stop_event: Optional[asyncio.Event] = None):
else:
await server.serve_forever()
finally:
await _ws_pool.close()
log_stats_task.cancel()
try:
await log_stats_task