mirror of
https://github.com/Flowseal/tg-ws-proxy.git
synced 2026-05-25 08:51:43 +03:00
Merge upstream/main into android_migration
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio as _asyncio
|
||||
import json
|
||||
import logging
|
||||
import logging.handlers
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -17,6 +18,9 @@ DEFAULT_CONFIG = {
|
||||
"host": "127.0.0.1",
|
||||
"dc_ip": ["2:149.154.167.220", "4:149.154.167.220"],
|
||||
"verbose": False,
|
||||
"log_max_mb": 5,
|
||||
"buf_kb": 256,
|
||||
"pool_size": 4,
|
||||
}
|
||||
|
||||
|
||||
@@ -76,7 +80,7 @@ class ProxyAppRuntime:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def setup_logging(self, verbose: bool = False):
|
||||
def setup_logging(self, verbose: bool = False, log_max_mb: float = 5):
|
||||
self.ensure_dirs()
|
||||
root = logging.getLogger()
|
||||
root.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
@@ -89,7 +93,12 @@ class ProxyAppRuntime:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
fh = logging.FileHandler(str(self.log_file), encoding="utf-8")
|
||||
fh = logging.handlers.RotatingFileHandler(
|
||||
str(self.log_file),
|
||||
maxBytes=max(32 * 1024, log_max_mb * 1024 * 1024),
|
||||
backupCount=0,
|
||||
encoding="utf-8",
|
||||
)
|
||||
fh.setLevel(logging.DEBUG)
|
||||
fh.setFormatter(logging.Formatter(
|
||||
"%(asctime)s %(levelname)-5s %(name)s %(message)s",
|
||||
@@ -148,6 +157,9 @@ class ProxyAppRuntime:
|
||||
port = active_cfg.get("port", self.default_config["port"])
|
||||
host = active_cfg.get("host", self.default_config["host"])
|
||||
dc_ip_list = active_cfg.get("dc_ip", self.default_config["dc_ip"])
|
||||
buf_kb = active_cfg.get("buf_kb", self.default_config["buf_kb"])
|
||||
pool_size = active_cfg.get(
|
||||
"pool_size", self.default_config["pool_size"])
|
||||
|
||||
try:
|
||||
dc_opt = self.parse_dc_ip_list(dc_ip_list)
|
||||
@@ -157,6 +169,9 @@ class ProxyAppRuntime:
|
||||
return False
|
||||
|
||||
self.log.info("Starting proxy on %s:%d ...", host, port)
|
||||
tg_ws_proxy._RECV_BUF = max(4, buf_kb) * 1024
|
||||
tg_ws_proxy._SEND_BUF = tg_ws_proxy._RECV_BUF
|
||||
tg_ws_proxy._WS_POOL_SIZE = max(0, pool_size)
|
||||
self._proxy_thread = self.thread_factory(
|
||||
target=self._run_proxy_thread,
|
||||
args=(port, dc_opt, host),
|
||||
|
||||
@@ -4,6 +4,7 @@ import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import socket as _socket
|
||||
import ssl
|
||||
@@ -86,6 +87,8 @@ _dc_fail_until: Dict[Tuple[int, bool], float] = {}
|
||||
_DC_FAIL_COOLDOWN = 30.0 # seconds to keep reduced WS timeout after failure
|
||||
_WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure
|
||||
|
||||
_ZERO_64 = b'\x00' * 64
|
||||
|
||||
|
||||
_ssl_ctx = ssl.create_default_context()
|
||||
_ssl_ctx.check_hostname = False
|
||||
@@ -130,6 +133,21 @@ def _xor_mask(data: bytes, mask: bytes) -> bytes:
|
||||
return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big')
|
||||
|
||||
|
||||
# Pre-compiled struct formats
|
||||
_st_BB = struct.Struct('>BB')
|
||||
_st_BBH = struct.Struct('>BBH')
|
||||
_st_BBQ = struct.Struct('>BBQ')
|
||||
_st_BB4s = struct.Struct('>BB4s')
|
||||
_st_BBH4s = struct.Struct('>BBH4s')
|
||||
_st_BBQ4s = struct.Struct('>BBQ4s')
|
||||
_st_H = struct.Struct('>H')
|
||||
_st_Q = struct.Struct('>Q')
|
||||
_st_I_net = struct.Struct('!I')
|
||||
_st_Ih = struct.Struct('<Ih')
|
||||
_st_I_le = struct.Struct('<I')
|
||||
_VALID_PROTOS = frozenset((0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD))
|
||||
|
||||
|
||||
class RawWebSocket:
|
||||
"""
|
||||
Lightweight WebSocket client over asyncio reader/writer streams.
|
||||
@@ -138,6 +156,7 @@ class RawWebSocket:
|
||||
proxy), performs the HTTP Upgrade handshake, and provides send/recv
|
||||
for binary frames with proper masking, ping/pong, and close handling.
|
||||
"""
|
||||
__slots__ = ('reader', 'writer', '_closed')
|
||||
|
||||
OP_CONTINUATION = 0x0
|
||||
OP_TEXT = 0x1
|
||||
@@ -303,40 +322,37 @@ class RawWebSocket:
|
||||
@staticmethod
|
||||
def _build_frame(opcode: int, data: bytes,
|
||||
mask: bool = False) -> bytes:
|
||||
header = bytearray()
|
||||
header.append(0x80 | opcode) # FIN=1 + opcode
|
||||
length = len(data)
|
||||
mask_bit = 0x80 if mask else 0x00
|
||||
fb = 0x80 | opcode
|
||||
|
||||
if not mask:
|
||||
if length < 126:
|
||||
return _st_BB.pack(fb, length) + data
|
||||
if length < 65536:
|
||||
return _st_BBH.pack(fb, 126, length) + data
|
||||
return _st_BBQ.pack(fb, 127, length) + data
|
||||
|
||||
mask_key = os.urandom(4)
|
||||
masked = _xor_mask(data, mask_key)
|
||||
if length < 126:
|
||||
header.append(mask_bit | length)
|
||||
elif length < 65536:
|
||||
header.append(mask_bit | 126)
|
||||
header.extend(struct.pack('>H', length))
|
||||
else:
|
||||
header.append(mask_bit | 127)
|
||||
header.extend(struct.pack('>Q', length))
|
||||
|
||||
if mask:
|
||||
mask_key = os.urandom(4)
|
||||
header.extend(mask_key)
|
||||
return bytes(header) + _xor_mask(data, mask_key)
|
||||
return bytes(header) + data
|
||||
return _st_BB4s.pack(fb, 0x80 | length, mask_key) + masked
|
||||
if length < 65536:
|
||||
return _st_BBH4s.pack(fb, 0x80 | 126, length, mask_key) + masked
|
||||
return _st_BBQ4s.pack(fb, 0x80 | 127, length, mask_key) + masked
|
||||
|
||||
async def _read_frame(self) -> Tuple[int, bytes]:
|
||||
hdr = await self.reader.readexactly(2)
|
||||
opcode = hdr[0] & 0x0F
|
||||
is_masked = bool(hdr[1] & 0x80)
|
||||
length = hdr[1] & 0x7F
|
||||
|
||||
if length == 126:
|
||||
length = struct.unpack('>H',
|
||||
await self.reader.readexactly(2))[0]
|
||||
length = _st_H.unpack(
|
||||
await self.reader.readexactly(2))[0]
|
||||
elif length == 127:
|
||||
length = struct.unpack('>Q',
|
||||
await self.reader.readexactly(8))[0]
|
||||
length = _st_Q.unpack(
|
||||
await self.reader.readexactly(8))[0]
|
||||
|
||||
if is_masked:
|
||||
if hdr[1] & 0x80:
|
||||
mask_key = await self.reader.readexactly(4)
|
||||
payload = await self.reader.readexactly(length)
|
||||
return opcode, _xor_mask(payload, mask_key)
|
||||
@@ -355,7 +371,7 @@ def _human_bytes(n: int) -> str:
|
||||
|
||||
def _is_telegram_ip(ip: str) -> bool:
|
||||
try:
|
||||
n = struct.unpack('!I', _socket.inet_aton(ip))[0]
|
||||
n = _st_I_net.unpack(_socket.inet_aton(ip))[0]
|
||||
return any(lo <= n <= hi for lo, hi in _TG_RANGES)
|
||||
except OSError:
|
||||
return False
|
||||
@@ -381,7 +397,7 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
|
||||
dc_raw = struct.unpack('<h', plain[4:6])[0]
|
||||
log.debug("dc_from_init: proto=0x%08X dc_raw=%d plain=%s",
|
||||
proto, dc_raw, plain.hex())
|
||||
if proto in (0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD):
|
||||
if proto in _VALID_PROTOS:
|
||||
dc = abs(dc_raw)
|
||||
if 1 <= dc <= 5 or dc == 203:
|
||||
return dc, (dc_raw < 0)
|
||||
@@ -439,19 +455,20 @@ class _MsgSplitter:
|
||||
plain = self._dec.update(chunk)
|
||||
boundaries = []
|
||||
pos = 0
|
||||
while pos < len(plain):
|
||||
plain_len = len(plain)
|
||||
while pos < plain_len:
|
||||
first = plain[pos]
|
||||
if first == 0x7f:
|
||||
if pos + 4 > len(plain):
|
||||
if pos + 4 > plain_len:
|
||||
break
|
||||
msg_len = (
|
||||
struct.unpack_from('<I', plain, pos + 1)[0] & 0xFFFFFF
|
||||
_st_I_le.unpack_from(plain, pos + 1)[0] & 0xFFFFFF
|
||||
) * 4
|
||||
pos += 4
|
||||
else:
|
||||
msg_len = first * 4
|
||||
pos += 1
|
||||
if msg_len == 0 or pos + msg_len > len(plain):
|
||||
if msg_len == 0 or pos + msg_len > plain_len:
|
||||
break
|
||||
pos += msg_len
|
||||
boundaries.append(pos)
|
||||
@@ -630,8 +647,9 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||
chunk = await reader.read(65536)
|
||||
if not chunk:
|
||||
break
|
||||
_stats.bytes_up += len(chunk)
|
||||
up_bytes += len(chunk)
|
||||
n = len(chunk)
|
||||
_stats.bytes_up += n
|
||||
up_bytes += n
|
||||
up_packets += 1
|
||||
if splitter:
|
||||
parts = splitter.split(chunk)
|
||||
@@ -653,14 +671,12 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||
data = await ws.recv()
|
||||
if data is None:
|
||||
break
|
||||
_stats.bytes_down += len(data)
|
||||
down_bytes += len(data)
|
||||
n = len(data)
|
||||
_stats.bytes_down += n
|
||||
down_bytes += n
|
||||
down_packets += 1
|
||||
writer.write(data)
|
||||
# drain only when kernel buffer is filling up
|
||||
buf = writer.transport.get_write_buffer_size()
|
||||
if buf > _SEND_BUF:
|
||||
await writer.drain()
|
||||
await writer.drain()
|
||||
except (asyncio.CancelledError, ConnectionError, OSError):
|
||||
return
|
||||
except Exception as e:
|
||||
@@ -700,26 +716,27 @@ async def _bridge_tcp(reader, writer, remote_reader, remote_writer,
|
||||
label, dc=None, dst=None, port=None,
|
||||
is_media=False):
|
||||
"""Bidirectional TCP <-> TCP forwarding (for fallback)."""
|
||||
async def forward(src, dst_w, tag):
|
||||
async def forward(src, dst_w, is_up):
|
||||
try:
|
||||
while True:
|
||||
data = await src.read(65536)
|
||||
if not data:
|
||||
break
|
||||
if 'up' in tag:
|
||||
_stats.bytes_up += len(data)
|
||||
n = len(data)
|
||||
if is_up:
|
||||
_stats.bytes_up += n
|
||||
else:
|
||||
_stats.bytes_down += len(data)
|
||||
_stats.bytes_down += n
|
||||
dst_w.write(data)
|
||||
await dst_w.drain()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.debug("[%s] %s ended: %s", label, tag, e)
|
||||
log.debug("[%s] forward ended: %s", label, e)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(forward(reader, remote_writer, 'up')),
|
||||
asyncio.create_task(forward(remote_reader, writer, 'down')),
|
||||
asyncio.create_task(forward(reader, remote_writer, True)),
|
||||
asyncio.create_task(forward(remote_reader, writer, False)),
|
||||
]
|
||||
try:
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
@@ -760,8 +777,12 @@ async def _pipe(r, w):
|
||||
pass
|
||||
|
||||
|
||||
_SOCKS5_REPLIES = {s: bytes([0x05, s, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
for s in (0x00, 0x05, 0x07, 0x08)}
|
||||
|
||||
|
||||
def _socks5_reply(status):
|
||||
return bytes([0x05, status, 0x00, 0x01]) + b'\x00' * 6
|
||||
return _SOCKS5_REPLIES[status]
|
||||
|
||||
|
||||
async def _tcp_fallback(reader, writer, dst, port, init, label,
|
||||
@@ -829,7 +850,7 @@ async def _handle_client(reader, writer):
|
||||
writer.close()
|
||||
return
|
||||
|
||||
port = struct.unpack('!H', await reader.readexactly(2))[0]
|
||||
port = _st_H.unpack(await reader.readexactly(2))[0]
|
||||
|
||||
if ':' in dst:
|
||||
log.error(
|
||||
@@ -1135,6 +1156,16 @@ def main():
|
||||
' --dc-ip 2:149.154.167.220')
|
||||
ap.add_argument('-v', '--verbose', action='store_true',
|
||||
help='Debug logging')
|
||||
ap.add_argument('--log-file', type=str, default=None, metavar='PATH',
|
||||
help='Log to file with rotation (default: stderr only)')
|
||||
ap.add_argument('--log-max-mb', type=float, default=5, metavar='MB',
|
||||
help='Max log file size in MB before rotation (default 5)')
|
||||
ap.add_argument('--log-backups', type=int, default=0, metavar='N',
|
||||
help='Number of rotated log files to keep (default 0)')
|
||||
ap.add_argument('--buf-kb', type=int, default=256, metavar='KB',
|
||||
help='Socket send/recv buffer size in KB (default 256)')
|
||||
ap.add_argument('--pool-size', type=int, default=4, metavar='N',
|
||||
help='WS connection pool size per DC (default 4, min 0)')
|
||||
args = ap.parse_args()
|
||||
|
||||
if not args.dc_ip:
|
||||
@@ -1146,11 +1177,30 @@ def main():
|
||||
log.error(str(e))
|
||||
sys.exit(1)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||
format='%(asctime)s %(levelname)-5s %(message)s',
|
||||
datefmt='%H:%M:%S',
|
||||
)
|
||||
log_level = logging.DEBUG if args.verbose else logging.INFO
|
||||
log_fmt = logging.Formatter('%(asctime)s %(levelname)-5s %(message)s',
|
||||
datefmt='%H:%M:%S')
|
||||
root = logging.getLogger()
|
||||
root.setLevel(log_level)
|
||||
|
||||
console = logging.StreamHandler()
|
||||
console.setFormatter(log_fmt)
|
||||
root.addHandler(console)
|
||||
|
||||
if args.log_file:
|
||||
fh = logging.handlers.RotatingFileHandler(
|
||||
args.log_file,
|
||||
maxBytes=max(32 * 1024, args.log_max_mb * 1024 * 1024),
|
||||
backupCount=max(0, args.log_backups),
|
||||
encoding='utf-8',
|
||||
)
|
||||
fh.setFormatter(log_fmt)
|
||||
root.addHandler(fh)
|
||||
|
||||
global _RECV_BUF, _SEND_BUF, _WS_POOL_SIZE
|
||||
_RECV_BUF = max(4, args.buf_kb) * 1024
|
||||
_SEND_BUF = _RECV_BUF
|
||||
_WS_POOL_SIZE = max(0, args.pool_size)
|
||||
|
||||
try:
|
||||
asyncio.run(_run(args.port, dc_opt, host=args.host))
|
||||
|
||||
Reference in New Issue
Block a user