This commit is contained in:
Flowseal 2026-03-29 15:21:45 +03:00
parent af74009b11
commit c4a044542c
1 changed files with 25 additions and 16 deletions

View File

@ -484,6 +484,7 @@ def _ws_domains(dc: int, is_media) -> List[str]:
class Stats: class Stats:
def __init__(self): def __init__(self):
self.connections_total = 0 self.connections_total = 0
self.connections_active = 0
self.connections_ws = 0 self.connections_ws = 0
self.connections_tcp_fallback = 0 self.connections_tcp_fallback = 0
self.connections_bad = 0 self.connections_bad = 0
@ -497,7 +498,9 @@ class Stats:
pool_total = self.pool_hits + self.pool_misses pool_total = self.pool_hits + self.pool_misses
pool_s = (f"{self.pool_hits}/{pool_total}" pool_s = (f"{self.pool_hits}/{pool_total}"
if pool_total else "n/a") if pool_total else "n/a")
return (f"total={self.connections_total} ws={self.connections_ws} " return (f"total={self.connections_total} "
f"active={self.connections_active} "
f"ws={self.connections_ws} "
f"tcp_fb={self.connections_tcp_fallback} " f"tcp_fb={self.connections_tcp_fallback} "
f"bad={self.connections_bad} " f"bad={self.connections_bad} "
f"err={self.ws_errors} " f"err={self.ws_errors} "
@ -528,7 +531,8 @@ class _WsPool:
while bucket: while bucket:
ws, created = bucket.popleft() ws, created = bucket.popleft()
age = now - created age = now - created
if age > self.WS_POOL_MAX_AGE or ws._closed: if (age > self.WS_POOL_MAX_AGE or ws._closed
or ws.writer.transport.is_closing()):
asyncio.create_task(self._quiet_close(ws)) asyncio.create_task(self._quiet_close(ws))
continue continue
_stats.pool_hits += 1 _stats.pool_hits += 1
@ -618,7 +622,7 @@ async def _bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label,
down_bytes = 0 down_bytes = 0
up_packets = 0 up_packets = 0
down_packets = 0 down_packets = 0
start_time = asyncio.get_event_loop().time() start_time = asyncio.get_running_loop().time()
async def tcp_to_ws(): async def tcp_to_ws():
nonlocal up_bytes, up_packets nonlocal up_bytes, up_packets
@ -684,7 +688,7 @@ async def _bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label,
await t await t
except BaseException: except BaseException:
pass pass
elapsed = asyncio.get_event_loop().time() - start_time elapsed = asyncio.get_running_loop().time() - start_time
log.info("[%s] %s WS session closed: " log.info("[%s] %s WS session closed: "
"^%s (%d pkts) v%s (%d pkts) in %.1fs", "^%s (%d pkts) v%s (%d pkts) in %.1fs",
label, dc_tag, label, dc_tag,
@ -782,6 +786,7 @@ def _fallback_ip(dc: int) -> Optional[str]:
async def _handle_client(reader, writer, secret: bytes): async def _handle_client(reader, writer, secret: bytes):
_stats.connections_total += 1 _stats.connections_total += 1
_stats.connections_active += 1
peer = writer.get_extra_info('peername') peer = writer.get_extra_info('peername')
label = f"{peer[0]}:{peer[1]}" if peer else "?" label = f"{peer[0]}:{peer[1]}" if peer else "?"
@ -869,10 +874,12 @@ async def _handle_client(reader, writer, secret: bytes):
if dc not in proxy_config.dc_redirects or dc_key in ws_blacklist: if dc not in proxy_config.dc_redirects or dc_key in ws_blacklist:
fallback_dst = _fallback_ip(dc) fallback_dst = _fallback_ip(dc)
if fallback_dst: if fallback_dst:
log.info("[%s] DC%d not in config -> TCP fallback %s:443" if dc not in proxy_config.dc_redirects:
if dc not in proxy_config.dc_redirects else log.info("[%s] DC%d not in config -> TCP fallback %s:443",
"[%s] DC%d%s WS blacklisted -> TCP fallback %s:443", label, dc, fallback_dst)
label, dc, fallback_dst) else:
log.info("[%s] DC%d%s WS blacklisted -> TCP fallback %s:443",
label, dc, media_tag, fallback_dst)
await _tcp_fallback(reader, writer, fallback_dst, 443, await _tcp_fallback(reader, writer, fallback_dst, 443,
relay_init, label, dc=dc, relay_init, label, dc=dc,
is_media=is_media, is_media=is_media,
@ -991,8 +998,9 @@ async def _handle_client(reader, writer, secret: bytes):
else: else:
log.error("[%s] unexpected OS error: %s", label, exc) log.error("[%s] unexpected OS error: %s", label, exc)
except Exception as exc: except Exception as exc:
log.error("[%s] unexpected: %s", label, exc.with_traceback()) log.error("[%s] unexpected: %s", label, exc, exc_info=True)
finally: finally:
_stats.connections_active -= 1
try: try:
writer.close() writer.close()
except BaseException: except BaseException:
@ -1007,9 +1015,10 @@ async def _run(stop_event: Optional[asyncio.Event] = None):
global _server_instance, _server_stop_event global _server_instance, _server_stop_event
_server_stop_event = stop_event _server_stop_event = stop_event
print(proxy_config.secret) secret_bytes = bytes.fromhex(proxy_config.secret)
def client_cb(r, w): def client_cb(r, w):
asyncio.create_task(_handle_client(r, w, bytes.fromhex(proxy_config.secret))) asyncio.create_task(_handle_client(r, w, secret_bytes))
server = await asyncio.start_server(client_cb, proxy_config.host, proxy_config.port) server = await asyncio.start_server(client_cb, proxy_config.host, proxy_config.port)
_server_instance = server _server_instance = server
@ -1155,19 +1164,19 @@ def main():
log.error("Secret must be exactly 32 hex characters") log.error("Secret must be exactly 32 hex characters")
sys.exit(1) sys.exit(1)
try: try:
secret = bytes.fromhex(secret_hex) bytes.fromhex(secret_hex)
except ValueError: except ValueError:
log.error("Secret must be valid hex") log.error("Secret must be valid hex")
sys.exit(1) sys.exit(1)
else: else:
secret = os.urandom(16).hex() secret_hex = os.urandom(16).hex()
log.info("Generated secret: %s", secret.hex()) log.info("Generated secret: %s", secret_hex)
global proxy_config global proxy_config
proxy_config = ProxyConfig( proxy_config = ProxyConfig(
port=args.port, port=args.port,
host=args.host, host=args.host,
secret=secret, secret=secret_hex,
dc_redirects=dc_redirects, dc_redirects=dc_redirects,
buffer_size=max(4, args.buf_kb) * 1024, buffer_size=max(4, args.buf_kb) * 1024,
pool_size=max(0, args.pool_size) pool_size=max(0, args.pool_size)
@ -1194,7 +1203,7 @@ def main():
root.addHandler(fh) root.addHandler(fh)
try: try:
asyncio.run(_run(args.port, dc_redirects, secret, host=args.host)) asyncio.run(_run())
except KeyboardInterrupt: except KeyboardInterrupt:
log.info("Shutting down. Final stats: %s", _stats.summary()) log.info("Shutting down. Final stats: %s", _stats.summary())