mirror of
https://github.com/Flowseal/tg-ws-proxy.git
synced 2026-06-18 20:48:28 +03:00
Pool for cloudflare worker
This commit is contained in:
+8
-110
@@ -11,8 +11,7 @@ import logging
|
||||
import logging.handlers
|
||||
import socket as _socket
|
||||
|
||||
from collections import deque
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Dict, Optional, Set, Tuple
|
||||
|
||||
from ._aes import Cipher, algorithms, modes
|
||||
|
||||
@@ -29,6 +28,7 @@ from .bridge import MsgSplitter, CryptoCtx, do_fallback, bridge_ws_reencrypt
|
||||
from .raw_websocket import RawWebSocket, WsHandshakeError, set_sock_opts
|
||||
from .fake_tls import proxy_to_masking_domain, verify_client_hello, build_server_hello, FakeTlsStream, TLS_RECORD_HANDSHAKE
|
||||
from .balancer import balancer
|
||||
from .pool import ws_pool, cf_worker_pool
|
||||
|
||||
|
||||
log = logging.getLogger('tg-mtproto-proxy')
|
||||
@@ -100,112 +100,8 @@ def _generate_relay_init(proto_tag: bytes, dc_idx: int) -> bytes:
|
||||
return bytes(result)
|
||||
|
||||
|
||||
def _ws_domains(dc: int, is_media) -> List[str]:
|
||||
if dc == 203:
|
||||
dc = 2
|
||||
if is_media is None or is_media:
|
||||
return [f'kws{dc}-1.web.telegram.org', f'kws{dc}.web.telegram.org']
|
||||
return [f'kws{dc}.web.telegram.org', f'kws{dc}-1.web.telegram.org']
|
||||
|
||||
|
||||
class _WsPool:
|
||||
WS_POOL_MAX_AGE = 120.0
|
||||
|
||||
def __init__(self):
|
||||
self._idle: Dict[Tuple[int, bool], deque] = {}
|
||||
self._refilling: Set[Tuple[int, bool]] = set()
|
||||
|
||||
async def get(self, dc: int, is_media: bool,
|
||||
target_ip: str, domains: List[str]
|
||||
) -> Optional[RawWebSocket]:
|
||||
key = (dc, is_media)
|
||||
now = time.monotonic()
|
||||
|
||||
bucket = self._idle.get(key)
|
||||
if bucket is None:
|
||||
bucket = deque()
|
||||
self._idle[key] = bucket
|
||||
while bucket:
|
||||
ws, created = bucket.popleft()
|
||||
age = now - created
|
||||
if (age > self.WS_POOL_MAX_AGE or ws._closed
|
||||
or ws.writer.transport.is_closing()):
|
||||
asyncio.create_task(self._quiet_close(ws))
|
||||
continue
|
||||
stats.pool_hits += 1
|
||||
log.debug("WS pool hit DC%d%s (age=%.1fs, left=%d)",
|
||||
dc, 'm' if is_media else '', age, len(bucket))
|
||||
self._schedule_refill(key, target_ip, domains)
|
||||
return ws
|
||||
|
||||
stats.pool_misses += 1
|
||||
self._schedule_refill(key, target_ip, domains)
|
||||
return None
|
||||
|
||||
def _schedule_refill(self, key, target_ip, domains):
|
||||
if key in self._refilling:
|
||||
return
|
||||
self._refilling.add(key)
|
||||
asyncio.create_task(self._refill(key, target_ip, domains))
|
||||
|
||||
async def _refill(self, key, target_ip, domains):
|
||||
dc, is_media = key
|
||||
try:
|
||||
bucket = self._idle.setdefault(key, deque())
|
||||
needed = proxy_config.pool_size - len(bucket)
|
||||
if needed <= 0:
|
||||
return
|
||||
tasks = [asyncio.create_task(
|
||||
self._connect_one(target_ip, domains))
|
||||
for _ in range(needed)]
|
||||
for t in tasks:
|
||||
try:
|
||||
ws = await t
|
||||
if ws:
|
||||
bucket.append((ws, time.monotonic()))
|
||||
except Exception:
|
||||
pass
|
||||
log.debug("WS pool refilled DC%d%s: %d ready",
|
||||
dc, 'm' if is_media else '', len(bucket))
|
||||
finally:
|
||||
self._refilling.discard(key)
|
||||
|
||||
@staticmethod
|
||||
async def _connect_one(target_ip, domains) -> Optional[RawWebSocket]:
|
||||
for domain in domains:
|
||||
try:
|
||||
return await RawWebSocket.connect(
|
||||
target_ip, domain, timeout=8)
|
||||
except WsHandshakeError as exc:
|
||||
if exc.is_redirect:
|
||||
continue
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _quiet_close(ws):
|
||||
try:
|
||||
await ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def warmup(self, dc_redirects: Dict[int, str]):
|
||||
for dc, target_ip in dc_redirects.items():
|
||||
if target_ip is None:
|
||||
continue
|
||||
for is_media in (False, True):
|
||||
domains = _ws_domains(dc, is_media)
|
||||
self._schedule_refill((dc, is_media), target_ip, domains)
|
||||
log.info("WS pool warmup started for %d DC(s)", len(dc_redirects))
|
||||
|
||||
def reset(self):
|
||||
self._idle.clear()
|
||||
self._refilling.clear()
|
||||
|
||||
_ws_pool = _WsPool()
|
||||
|
||||
|
||||
async def _read_client_init(reader, writer, secret, label, masking):
|
||||
if proxy_config.proxy_protocol:
|
||||
@@ -420,13 +316,13 @@ async def _handle_client(reader, writer, secret: bytes):
|
||||
fail_until = dc_fail_until.get(dc_key, 0)
|
||||
ws_timeout = WS_FAIL_TIMEOUT if now < fail_until else 10.0
|
||||
|
||||
domains = _ws_domains(dc, is_media)
|
||||
domains = ws_domains(dc, is_media)
|
||||
target = proxy_config.dc_redirects[dc]
|
||||
ws = None
|
||||
ws_failed_redirect = False
|
||||
all_redirects = True
|
||||
|
||||
ws = await _ws_pool.get(dc, is_media, target, domains)
|
||||
ws = await ws_pool.get(dc, is_media, target, domains)
|
||||
if ws:
|
||||
log.info("[%s] DC%d%s -> pool hit via %s",
|
||||
label, dc, media_tag, target)
|
||||
@@ -536,7 +432,8 @@ async def _run(stop_event: Optional[asyncio.Event] = None):
|
||||
global _server_instance, _server_stop_event
|
||||
_server_stop_event = stop_event
|
||||
|
||||
_ws_pool.reset()
|
||||
ws_pool.reset()
|
||||
cf_worker_pool.reset()
|
||||
ws_blacklist.clear()
|
||||
dc_fail_until.clear()
|
||||
_client_tasks.clear()
|
||||
@@ -611,7 +508,8 @@ async def _run(stop_event: Optional[asyncio.Event] = None):
|
||||
|
||||
log_stats_task = asyncio.create_task(log_stats())
|
||||
|
||||
await _ws_pool.warmup(proxy_config.dc_redirects)
|
||||
await ws_pool.warmup()
|
||||
await cf_worker_pool.warmup()
|
||||
|
||||
try:
|
||||
async with server:
|
||||
|
||||
Reference in New Issue
Block a user