Optimization, connections pool
This commit is contained in:
parent
72e5040e6d
commit
1c227b924a
|
|
@ -17,6 +17,12 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
DEFAULT_PORT = 1080
|
DEFAULT_PORT = 1080
|
||||||
log = logging.getLogger('tg-ws-proxy')
|
log = logging.getLogger('tg-ws-proxy')
|
||||||
|
|
||||||
|
_TCP_NODELAY = True
|
||||||
|
_RECV_BUF = 131072
|
||||||
|
_SEND_BUF = 131072
|
||||||
|
_WS_POOL_SIZE = 4
|
||||||
|
_WS_POOL_MAX_AGE = 120.0
|
||||||
|
|
||||||
_TG_RANGES = [
|
_TG_RANGES = [
|
||||||
# 185.76.151.0/24
|
# 185.76.151.0/24
|
||||||
(struct.unpack('!I', _socket.inet_aton('185.76.151.0'))[0],
|
(struct.unpack('!I', _socket.inet_aton('185.76.151.0'))[0],
|
||||||
|
|
@ -43,7 +49,7 @@ _IP_TO_DC: Dict[str, Tuple[int, bool]] = {
|
||||||
'149.154.167.51': (2, False), '149.154.167.220': (2, False),
|
'149.154.167.51': (2, False), '149.154.167.220': (2, False),
|
||||||
'95.161.76.100': (2, False),
|
'95.161.76.100': (2, False),
|
||||||
'149.154.167.151': (2, True), '149.154.167.222': (2, True),
|
'149.154.167.151': (2, True), '149.154.167.222': (2, True),
|
||||||
'149.154.167.223': (2, True),
|
'149.154.167.223': (2, True), '149.154.162.123': (2, True),
|
||||||
# DC3
|
# DC3
|
||||||
'149.154.175.100': (3, False), '149.154.175.101': (3, False),
|
'149.154.175.100': (3, False), '149.154.175.101': (3, False),
|
||||||
'149.154.175.102': (3, True),
|
'149.154.175.102': (3, True),
|
||||||
|
|
@ -79,6 +85,22 @@ _ssl_ctx.check_hostname = False
|
||||||
_ssl_ctx.verify_mode = ssl.CERT_NONE
|
_ssl_ctx.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
|
||||||
|
def _set_sock_opts(transport):
|
||||||
|
sock = transport.get_extra_info('socket')
|
||||||
|
if sock is None:
|
||||||
|
return
|
||||||
|
if _TCP_NODELAY:
|
||||||
|
try:
|
||||||
|
sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_NODELAY, 1)
|
||||||
|
except (OSError, AttributeError):
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_RCVBUF, _RECV_BUF)
|
||||||
|
sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_SNDBUF, _SEND_BUF)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class WsHandshakeError(Exception):
|
class WsHandshakeError(Exception):
|
||||||
def __init__(self, status_code: int, status_line: str,
|
def __init__(self, status_code: int, status_line: str,
|
||||||
headers: dict = None, location: str = None):
|
headers: dict = None, location: str = None):
|
||||||
|
|
@ -136,6 +158,7 @@ class RawWebSocket:
|
||||||
asyncio.open_connection(ip, 443, ssl=_ssl_ctx,
|
asyncio.open_connection(ip, 443, ssl=_ssl_ctx,
|
||||||
server_hostname=domain),
|
server_hostname=domain),
|
||||||
timeout=min(timeout, 10))
|
timeout=min(timeout, 10))
|
||||||
|
_set_sock_opts(writer.transport)
|
||||||
|
|
||||||
ws_key = base64.b64encode(os.urandom(16)).decode()
|
ws_key = base64.b64encode(os.urandom(16)).decode()
|
||||||
req = (
|
req = (
|
||||||
|
|
@ -463,6 +486,8 @@ class Stats:
|
||||||
self.ws_errors = 0
|
self.ws_errors = 0
|
||||||
self.bytes_up = 0
|
self.bytes_up = 0
|
||||||
self.bytes_down = 0
|
self.bytes_down = 0
|
||||||
|
self.pool_hits = 0
|
||||||
|
self.pool_misses = 0
|
||||||
|
|
||||||
def summary(self) -> str:
|
def summary(self) -> str:
|
||||||
return (f"total={self.connections_total} ws={self.connections_ws} "
|
return (f"total={self.connections_total} ws={self.connections_ws} "
|
||||||
|
|
@ -470,6 +495,7 @@ class Stats:
|
||||||
f"http_skip={self.connections_http_rejected} "
|
f"http_skip={self.connections_http_rejected} "
|
||||||
f"pass={self.connections_passthrough} "
|
f"pass={self.connections_passthrough} "
|
||||||
f"err={self.ws_errors} "
|
f"err={self.ws_errors} "
|
||||||
|
f"pool={self.pool_hits}/{self.pool_hits+self.pool_misses} "
|
||||||
f"up={_human_bytes(self.bytes_up)} "
|
f"up={_human_bytes(self.bytes_up)} "
|
||||||
f"down={_human_bytes(self.bytes_down)}")
|
f"down={_human_bytes(self.bytes_down)}")
|
||||||
|
|
||||||
|
|
@ -477,6 +503,100 @@ class Stats:
|
||||||
_stats = Stats()
|
_stats = Stats()
|
||||||
|
|
||||||
|
|
||||||
|
class _WsPool:
|
||||||
|
def __init__(self):
|
||||||
|
self._idle: Dict[Tuple[int, bool], list] = {}
|
||||||
|
self._refilling: Set[Tuple[int, bool]] = set()
|
||||||
|
|
||||||
|
async def get(self, dc: int, is_media: bool,
|
||||||
|
target_ip: str, domains: List[str]
|
||||||
|
) -> Optional[RawWebSocket]:
|
||||||
|
key = (dc, is_media)
|
||||||
|
now = time.monotonic()
|
||||||
|
|
||||||
|
bucket = self._idle.get(key, [])
|
||||||
|
while bucket:
|
||||||
|
ws, created = bucket.pop(0)
|
||||||
|
age = now - created
|
||||||
|
if age > _WS_POOL_MAX_AGE or ws._closed:
|
||||||
|
asyncio.create_task(self._quiet_close(ws))
|
||||||
|
continue
|
||||||
|
_stats.pool_hits += 1
|
||||||
|
log.debug("WS pool hit for DC%d%s (age=%.1fs, left=%d)",
|
||||||
|
dc, 'm' if is_media else '', age, len(bucket))
|
||||||
|
self._schedule_refill(key, target_ip, domains)
|
||||||
|
return ws
|
||||||
|
|
||||||
|
_stats.pool_misses += 1
|
||||||
|
self._schedule_refill(key, target_ip, domains)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _schedule_refill(self, key, target_ip, domains):
|
||||||
|
if key in self._refilling:
|
||||||
|
return
|
||||||
|
self._refilling.add(key)
|
||||||
|
asyncio.create_task(self._refill(key, target_ip, domains))
|
||||||
|
|
||||||
|
async def _refill(self, key, target_ip, domains):
|
||||||
|
dc, is_media = key
|
||||||
|
try:
|
||||||
|
bucket = self._idle.setdefault(key, [])
|
||||||
|
needed = _WS_POOL_SIZE - len(bucket)
|
||||||
|
if needed <= 0:
|
||||||
|
return
|
||||||
|
tasks = []
|
||||||
|
for _ in range(needed):
|
||||||
|
tasks.append(asyncio.create_task(
|
||||||
|
self._connect_one(target_ip, domains)))
|
||||||
|
for t in tasks:
|
||||||
|
try:
|
||||||
|
ws = await t
|
||||||
|
if ws:
|
||||||
|
bucket.append((ws, time.monotonic()))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
log.debug("WS pool refilled DC%d%s: %d ready",
|
||||||
|
dc, 'm' if is_media else '', len(bucket))
|
||||||
|
finally:
|
||||||
|
self._refilling.discard(key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _connect_one(target_ip, domains) -> Optional[RawWebSocket]:
|
||||||
|
for domain in domains:
|
||||||
|
try:
|
||||||
|
ws = await RawWebSocket.connect(
|
||||||
|
target_ip, domain, timeout=8)
|
||||||
|
return ws
|
||||||
|
except WsHandshakeError as exc:
|
||||||
|
if exc.is_redirect:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _quiet_close(ws):
|
||||||
|
try:
|
||||||
|
await ws.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def warmup(self, dc_opt: Dict[int, Optional[str]]):
|
||||||
|
"""Pre-fill pool for all configured DCs on startup."""
|
||||||
|
for dc, target_ip in dc_opt.items():
|
||||||
|
if target_ip is None:
|
||||||
|
continue
|
||||||
|
for is_media in (False, True):
|
||||||
|
domains = _ws_domains(dc, is_media)
|
||||||
|
key = (dc, is_media)
|
||||||
|
self._schedule_refill(key, target_ip, domains)
|
||||||
|
log.info("WS pool warmup started for %d DC(s)", len(dc_opt))
|
||||||
|
|
||||||
|
|
||||||
|
_ws_pool = _WsPool()
|
||||||
|
|
||||||
|
|
||||||
async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||||
dc=None, dst=None, port=None, is_media=False,
|
dc=None, dst=None, port=None, is_media=False,
|
||||||
splitter: _MsgSplitter = None):
|
splitter: _MsgSplitter = None):
|
||||||
|
|
@ -526,7 +646,7 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||||
writer.write(data)
|
writer.write(data)
|
||||||
# drain only when kernel buffer is filling up
|
# drain only when kernel buffer is filling up
|
||||||
buf = writer.transport.get_write_buffer_size()
|
buf = writer.transport.get_write_buffer_size()
|
||||||
if buf > 262144:
|
if buf > _SEND_BUF:
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
except (asyncio.CancelledError, ConnectionError, OSError):
|
except (asyncio.CancelledError, ConnectionError, OSError):
|
||||||
return
|
return
|
||||||
|
|
@ -658,6 +778,8 @@ async def _handle_client(reader, writer):
|
||||||
peer = writer.get_extra_info('peername')
|
peer = writer.get_extra_info('peername')
|
||||||
label = f"{peer[0]}:{peer[1]}" if peer else "?"
|
label = f"{peer[0]}:{peer[1]}" if peer else "?"
|
||||||
|
|
||||||
|
_set_sock_opts(writer.transport)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# -- SOCKS5 greeting --
|
# -- SOCKS5 greeting --
|
||||||
hdr = await asyncio.wait_for(reader.readexactly(2), timeout=10)
|
hdr = await asyncio.wait_for(reader.readexactly(2), timeout=10)
|
||||||
|
|
@ -798,6 +920,11 @@ async def _handle_client(reader, writer):
|
||||||
ws_failed_redirect = False
|
ws_failed_redirect = False
|
||||||
all_redirects = True
|
all_redirects = True
|
||||||
|
|
||||||
|
ws = await _ws_pool.get(dc, is_media, target, domains)
|
||||||
|
if ws:
|
||||||
|
log.info("[%s] DC%d%s (%s:%d) -> pool hit via %s",
|
||||||
|
label, dc, media_tag, dst, port, target)
|
||||||
|
else:
|
||||||
for domain in domains:
|
for domain in domains:
|
||||||
url = f'wss://{domain}/apiws'
|
url = f'wss://{domain}/apiws'
|
||||||
log.info("[%s] DC%d%s (%s:%d) -> %s via %s",
|
log.info("[%s] DC%d%s (%s:%d) -> %s via %s",
|
||||||
|
|
@ -906,6 +1033,12 @@ async def _run(port: int, dc_opt: Dict[int, Optional[str]],
|
||||||
_handle_client, host, port)
|
_handle_client, host, port)
|
||||||
_server_instance = server
|
_server_instance = server
|
||||||
|
|
||||||
|
for sock in server.sockets:
|
||||||
|
try:
|
||||||
|
sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_NODELAY, 1)
|
||||||
|
except (OSError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
log.info("=" * 60)
|
log.info("=" * 60)
|
||||||
log.info(" Telegram WS Bridge Proxy")
|
log.info(" Telegram WS Bridge Proxy")
|
||||||
log.info(" Listening on %s:%d", host, port)
|
log.info(" Listening on %s:%d", host, port)
|
||||||
|
|
@ -928,6 +1061,8 @@ async def _run(port: int, dc_opt: Dict[int, Optional[str]],
|
||||||
|
|
||||||
asyncio.create_task(log_stats())
|
asyncio.create_task(log_stats())
|
||||||
|
|
||||||
|
await _ws_pool.warmup(dc_opt)
|
||||||
|
|
||||||
if stop_event:
|
if stop_event:
|
||||||
async def wait_stop():
|
async def wait_stop():
|
||||||
await stop_event.wait()
|
await stop_event.wait()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue