Pool for cloudflare worker

This commit is contained in:
Flowseal
2026-05-17 17:28:38 +03:00
parent 23f0e4d426
commit 1c4b103df2
5 changed files with 264 additions and 136 deletions

View File

@@ -3,7 +3,7 @@ import logging
import struct
from ._aes import Cipher, algorithms, modes
from typing import Dict, List, Optional
from typing import List, Optional
from urllib.parse import urlencode
from .utils import *
@@ -11,20 +11,13 @@ from .stats import stats
from .balancer import balancer
from .config import proxy_config
from .raw_websocket import RawWebSocket
from .pool import cf_worker_pool
log = logging.getLogger('tg-mtproto-proxy')
_st_I_le = struct.Struct('<I')
ZERO_64 = b'\x00' * 64
DC_DEFAULT_IPS: Dict[int, str] = {
1: '149.154.175.50',
2: '149.154.167.51',
3: '149.154.175.100',
4: '149.154.167.91',
5: '149.154.171.5',
203: '91.105.192.100'
}
class CryptoCtx:
@@ -135,7 +128,6 @@ class MsgSplitter:
return packet_len
async def do_fallback(reader, writer, relay_init, label,
dc: int, is_media: bool, media_tag: str,
ctx: CryptoCtx, splitter=None):
@@ -188,23 +180,27 @@ async def _cfproxy_worker_fallback(reader, writer, relay_init, label,
if not worker_domain:
return False
query = urlencode({
'dst': fallback_dst,
'dc': str(dc),
'media': '1' if is_media else '0',
})
path = f'/apiws?{query}'
ws = await cf_worker_pool.get(dc, worker_domain, fallback_dst)
if ws:
log.info("[%s] DC%d%s -> CF worker pool hit for %s",
label, dc, media_tag, fallback_dst)
else:
query = urlencode({
'dst': fallback_dst,
'dc': str(dc),
})
path = f'/apiws?{query}'
log.info("[%s] DC%d%s -> trying CF worker for %s",
label, dc, media_tag, fallback_dst)
log.info("[%s] DC%d%s -> trying CF worker for %s",
label, dc, media_tag, fallback_dst)
try:
ws = await RawWebSocket.connect(worker_domain, worker_domain,
timeout=10.0, path=path)
except Exception as exc:
log.warning("[%s] DC%d%s CF worker failed: %s",
label, dc, media_tag, repr(exc))
return False
try:
ws = await RawWebSocket.connect(worker_domain, worker_domain,
timeout=10.0, path=path)
except Exception as exc:
log.warning("[%s] DC%d%s CF worker failed: %s",
label, dc, media_tag, repr(exc))
return False
stats.connections_cfproxy += 1
await ws.send(relay_init)

211
proxy/pool.py Normal file
View File

@@ -0,0 +1,211 @@
import asyncio
import logging
import time
from collections import deque
from urllib.parse import urlencode
from typing import Dict, List, Optional, Tuple, Set
from .raw_websocket import RawWebSocket, WsHandshakeError
from .stats import stats
from .config import proxy_config
from .utils import ws_domains, DC_DEFAULT_IPS
log = logging.getLogger('tg-mtproto-proxy')
class _WsPool:
WS_POOL_MAX_AGE = 120.0
def __init__(self):
self._idle: Dict[Tuple[int, bool], deque] = {}
self._refilling: Set[Tuple[int, bool]] = set()
async def get(self, dc: int, is_media: bool,
target_ip: str, domains: List[str]
) -> Optional[RawWebSocket]:
key = (dc, is_media)
now = time.monotonic()
bucket = self._idle.get(key)
if bucket is None:
bucket = deque()
self._idle[key] = bucket
while bucket:
ws, created = bucket.popleft()
age = now - created
if (age > self.WS_POOL_MAX_AGE or ws._closed
or ws.writer.transport.is_closing()):
asyncio.create_task(self._quiet_close(ws))
continue
stats.pool_hits += 1
log.debug("WS pool hit DC%d%s (age=%.1fs, left=%d)",
dc, 'm' if is_media else '', age, len(bucket))
self._schedule_refill(key, target_ip, domains)
return ws
stats.pool_misses += 1
self._schedule_refill(key, target_ip, domains)
return None
def _schedule_refill(self, key, target_ip, domains):
if key in self._refilling:
return
self._refilling.add(key)
asyncio.create_task(self._refill(key, target_ip, domains))
async def _refill(self, key, target_ip, domains):
dc, is_media = key
try:
bucket = self._idle.setdefault(key, deque())
needed = proxy_config.pool_size - len(bucket)
if needed <= 0:
return
tasks = [asyncio.create_task(
self._connect_one(target_ip, domains))
for _ in range(needed)]
for t in tasks:
try:
ws = await t
if ws:
bucket.append((ws, time.monotonic()))
except Exception:
pass
log.debug("WS pool refilled DC%d%s: %d ready",
dc, 'm' if is_media else '', len(bucket))
finally:
self._refilling.discard(key)
@staticmethod
async def _connect_one(target_ip, domains) -> Optional[RawWebSocket]:
for domain in domains:
try:
return await RawWebSocket.connect(
target_ip, domain, timeout=8)
except WsHandshakeError as exc:
if exc.is_redirect:
continue
return None
except Exception:
return None
return None
@staticmethod
async def _quiet_close(ws):
try:
await ws.close()
except Exception:
pass
async def warmup(self):
for dc, target_ip in proxy_config.dc_redirects.items():
if target_ip is None:
continue
for is_media in (False, True):
domains = ws_domains(dc, is_media)
self._schedule_refill((dc, is_media), target_ip, domains)
log.info("WS pool warmup started for %d DC(s)", len(proxy_config.dc_redirects))
def reset(self):
self._idle.clear()
self._refilling.clear()
class _CfWorkerPool:
WS_POOL_MAX_AGE = 120.0
def __init__(self):
self._idle: Dict[int, deque] = {}
self._refilling: Set[int] = set()
async def get(self, dc: int, worker_domain: str, fallback_dst: str) -> Optional[RawWebSocket]:
now = time.monotonic()
bucket = self._idle.get(dc)
if bucket is None:
bucket = deque()
self._idle[dc] = bucket
while bucket:
ws, created = bucket.popleft()
age = now - created
if (age > self.WS_POOL_MAX_AGE or ws._closed
or ws.writer.transport.is_closing()):
asyncio.create_task(self._quiet_close(ws))
continue
stats.cf_pool_hits += 1
log.debug("CF worker pool hit DC%d (age=%.1fs, left=%d)",
dc, age, len(bucket))
self._schedule_refill(dc, worker_domain, fallback_dst)
return ws
stats.cf_pool_misses += 1
self._schedule_refill(dc, worker_domain, fallback_dst)
return None
def _schedule_refill(self, dc, worker_domain, fallback_dst):
if dc in self._refilling:
return
self._refilling.add(dc)
asyncio.create_task(self._refill(dc, worker_domain, fallback_dst))
async def _refill(self, dc, worker_domain, fallback_dst):
try:
bucket = self._idle.setdefault(dc, deque())
needed = proxy_config.pool_size - len(bucket)
if needed <= 0:
return
tasks = [asyncio.create_task(
self._connect_one(worker_domain, fallback_dst, dc))
for _ in range(needed)]
for t in tasks:
try:
ws = await t
if ws:
bucket.append((ws, time.monotonic()))
except Exception:
pass
log.debug("CF worker pool refilled DC%d: %d ready",
dc, len(bucket))
finally:
self._refilling.discard(dc)
@staticmethod
async def _connect_one(worker_domain, fallback_dst, dc) -> Optional[RawWebSocket]:
query = urlencode({
'dst': fallback_dst,
'dc': str(dc),
})
path = f'/apiws?{query}'
try:
return await RawWebSocket.connect(
worker_domain, worker_domain, timeout=8, path=path)
except Exception:
return None
@staticmethod
async def _quiet_close(ws):
try:
await ws.close()
except Exception:
pass
async def warmup(self):
cf_fallbacks = {
dc: ip for dc, ip in DC_DEFAULT_IPS.items()
if dc not in proxy_config.dc_redirects
}
if not cf_fallbacks or not proxy_config.cfproxy_worker_domain:
return
for dc, fallback_dst in cf_fallbacks.items():
self._schedule_refill(dc, proxy_config.cfproxy_worker_domain, fallback_dst)
log.info("CF worker pool warmup started for %d DC(s)", len(cf_fallbacks))
def reset(self):
self._idle.clear()
self._refilling.clear()
ws_pool = _WsPool()
cf_worker_pool = _CfWorkerPool()

View File

@@ -14,11 +14,16 @@ class _Stats:
self.bytes_down = 0
self.pool_hits = 0
self.pool_misses = 0
self.cf_pool_hits = 0
self.cf_pool_misses = 0
def summary(self) -> str:
pool_total = self.pool_hits + self.pool_misses
pool_s = (f"{self.pool_hits}/{pool_total}"
if pool_total else "n/a")
cf_pool_total = self.cf_pool_hits + self.cf_pool_misses
cf_pool_s = (f"{self.cf_pool_hits}/{cf_pool_total}"
if cf_pool_total else "n/a")
return (f"total={self.connections_total} "
f"active={self.connections_active} "
f"ws={self.connections_ws} "
@@ -28,6 +33,7 @@ class _Stats:
f"masked={self.connections_masked} "
f"err={self.ws_errors} "
f"pool={pool_s} "
f"cf_pool={cf_pool_s} "
f"up={human_bytes(self.bytes_up)} "
f"down={human_bytes(self.bytes_down)}")

View File

@@ -11,8 +11,7 @@ import logging
import logging.handlers
import socket as _socket
from collections import deque
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, Optional, Set, Tuple
from ._aes import Cipher, algorithms, modes
@@ -29,6 +28,7 @@ from .bridge import MsgSplitter, CryptoCtx, do_fallback, bridge_ws_reencrypt
from .raw_websocket import RawWebSocket, WsHandshakeError, set_sock_opts
from .fake_tls import proxy_to_masking_domain, verify_client_hello, build_server_hello, FakeTlsStream, TLS_RECORD_HANDSHAKE
from .balancer import balancer
from .pool import ws_pool, cf_worker_pool
log = logging.getLogger('tg-mtproto-proxy')
@@ -100,112 +100,8 @@ def _generate_relay_init(proto_tag: bytes, dc_idx: int) -> bytes:
return bytes(result)
def _ws_domains(dc: int, is_media) -> List[str]:
if dc == 203:
dc = 2
if is_media is None or is_media:
return [f'kws{dc}-1.web.telegram.org', f'kws{dc}.web.telegram.org']
return [f'kws{dc}.web.telegram.org', f'kws{dc}-1.web.telegram.org']
class _WsPool:
WS_POOL_MAX_AGE = 120.0
def __init__(self):
self._idle: Dict[Tuple[int, bool], deque] = {}
self._refilling: Set[Tuple[int, bool]] = set()
async def get(self, dc: int, is_media: bool,
target_ip: str, domains: List[str]
) -> Optional[RawWebSocket]:
key = (dc, is_media)
now = time.monotonic()
bucket = self._idle.get(key)
if bucket is None:
bucket = deque()
self._idle[key] = bucket
while bucket:
ws, created = bucket.popleft()
age = now - created
if (age > self.WS_POOL_MAX_AGE or ws._closed
or ws.writer.transport.is_closing()):
asyncio.create_task(self._quiet_close(ws))
continue
stats.pool_hits += 1
log.debug("WS pool hit DC%d%s (age=%.1fs, left=%d)",
dc, 'm' if is_media else '', age, len(bucket))
self._schedule_refill(key, target_ip, domains)
return ws
stats.pool_misses += 1
self._schedule_refill(key, target_ip, domains)
return None
def _schedule_refill(self, key, target_ip, domains):
if key in self._refilling:
return
self._refilling.add(key)
asyncio.create_task(self._refill(key, target_ip, domains))
async def _refill(self, key, target_ip, domains):
dc, is_media = key
try:
bucket = self._idle.setdefault(key, deque())
needed = proxy_config.pool_size - len(bucket)
if needed <= 0:
return
tasks = [asyncio.create_task(
self._connect_one(target_ip, domains))
for _ in range(needed)]
for t in tasks:
try:
ws = await t
if ws:
bucket.append((ws, time.monotonic()))
except Exception:
pass
log.debug("WS pool refilled DC%d%s: %d ready",
dc, 'm' if is_media else '', len(bucket))
finally:
self._refilling.discard(key)
@staticmethod
async def _connect_one(target_ip, domains) -> Optional[RawWebSocket]:
for domain in domains:
try:
return await RawWebSocket.connect(
target_ip, domain, timeout=8)
except WsHandshakeError as exc:
if exc.is_redirect:
continue
return None
except Exception:
return None
return None
@staticmethod
async def _quiet_close(ws):
try:
await ws.close()
except Exception:
pass
async def warmup(self, dc_redirects: Dict[int, str]):
for dc, target_ip in dc_redirects.items():
if target_ip is None:
continue
for is_media in (False, True):
domains = _ws_domains(dc, is_media)
self._schedule_refill((dc, is_media), target_ip, domains)
log.info("WS pool warmup started for %d DC(s)", len(dc_redirects))
def reset(self):
self._idle.clear()
self._refilling.clear()
_ws_pool = _WsPool()
async def _read_client_init(reader, writer, secret, label, masking):
if proxy_config.proxy_protocol:
@@ -420,13 +316,13 @@ async def _handle_client(reader, writer, secret: bytes):
fail_until = dc_fail_until.get(dc_key, 0)
ws_timeout = WS_FAIL_TIMEOUT if now < fail_until else 10.0
domains = _ws_domains(dc, is_media)
domains = ws_domains(dc, is_media)
target = proxy_config.dc_redirects[dc]
ws = None
ws_failed_redirect = False
all_redirects = True
ws = await _ws_pool.get(dc, is_media, target, domains)
ws = await ws_pool.get(dc, is_media, target, domains)
if ws:
log.info("[%s] DC%d%s -> pool hit via %s",
label, dc, media_tag, target)
@@ -536,7 +432,8 @@ async def _run(stop_event: Optional[asyncio.Event] = None):
global _server_instance, _server_stop_event
_server_stop_event = stop_event
_ws_pool.reset()
ws_pool.reset()
cf_worker_pool.reset()
ws_blacklist.clear()
dc_fail_until.clear()
_client_tasks.clear()
@@ -611,7 +508,8 @@ async def _run(stop_event: Optional[asyncio.Event] = None):
log_stats_task = asyncio.create_task(log_stats())
await _ws_pool.warmup(proxy_config.dc_redirects)
await ws_pool.warmup()
await cf_worker_pool.warmup()
try:
async with server:

View File

@@ -2,7 +2,7 @@ import socket as _socket
import urllib.request
import http.client
from typing import Optional, Dict
from typing import Optional, Dict, List
from urllib.request import Request
@@ -34,6 +34,23 @@ _GITHUB_IPS: Dict[str, str] = {
"raw.githubusercontent.com": "185.199.109.133",
}
DC_DEFAULT_IPS: Dict[int, str] = {
1: '149.154.175.50',
2: '149.154.167.51',
3: '149.154.175.100',
4: '149.154.167.91',
5: '149.154.171.5',
203: '91.105.192.100'
}
def ws_domains(dc: int, is_media) -> List[str]:
if dc == 203:
dc = 2
if is_media is None or is_media:
return [f'kws{dc}-1.web.telegram.org', f'kws{dc}.web.telegram.org']
return [f'kws{dc}.web.telegram.org', f'kws{dc}-1.web.telegram.org']
def human_bytes(n: int) -> str:
for unit in ('B', 'KB', 'MB', 'GB'):