Optimizations

This commit is contained in:
Flowseal 2026-03-19 22:07:47 +03:00
parent 4ae7cb92f7
commit 6a80ca85e3
1 changed files with 48 additions and 32 deletions

View File

@ -85,6 +85,8 @@ _dc_fail_until: Dict[Tuple[int, bool], float] = {}
_DC_FAIL_COOLDOWN = 30.0 # seconds to keep reduced WS timeout after failure _DC_FAIL_COOLDOWN = 30.0 # seconds to keep reduced WS timeout after failure
_WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure _WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure
_ZERO_64 = b'\x00' * 64
_ssl_ctx = ssl.create_default_context() _ssl_ctx = ssl.create_default_context()
_ssl_ctx.check_hostname = False _ssl_ctx.check_hostname = False
@ -129,6 +131,15 @@ def _xor_mask(data: bytes, mask: bytes) -> bytes:
return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big') return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big')
# Pre-compiled struct formats for WS frame building
_st_BB = struct.Struct('>BB')
_st_BBH = struct.Struct('>BBH')
_st_BBQ = struct.Struct('>BBQ')
_st_BB4s = struct.Struct('>BB4s')
_st_BBH4s = struct.Struct('>BBH4s')
_st_BBQ4s = struct.Struct('>BBQ4s')
class RawWebSocket: class RawWebSocket:
""" """
Lightweight WebSocket client over asyncio reader/writer streams. Lightweight WebSocket client over asyncio reader/writer streams.
@ -302,25 +313,23 @@ class RawWebSocket:
@staticmethod @staticmethod
def _build_frame(opcode: int, data: bytes, def _build_frame(opcode: int, data: bytes,
mask: bool = False) -> bytes: mask: bool = False) -> bytes:
header = bytearray()
header.append(0x80 | opcode) # FIN=1 + opcode
length = len(data) length = len(data)
mask_bit = 0x80 if mask else 0x00 fb = 0x80 | opcode
if not mask:
if length < 126:
return _st_BB.pack(fb, length) + data
if length < 65536:
return _st_BBH.pack(fb, 126, length) + data
return _st_BBQ.pack(fb, 127, length) + data
mask_key = os.urandom(4)
masked = _xor_mask(data, mask_key)
if length < 126: if length < 126:
header.append(mask_bit | length) return _st_BB4s.pack(fb, 0x80 | length, mask_key) + masked
elif length < 65536: if length < 65536:
header.append(mask_bit | 126) return _st_BBH4s.pack(fb, 0x80 | 126, length, mask_key) + masked
header.extend(struct.pack('>H', length)) return _st_BBQ4s.pack(fb, 0x80 | 127, length, mask_key) + masked
else:
header.append(mask_bit | 127)
header.extend(struct.pack('>Q', length))
if mask:
mask_key = os.urandom(4)
header.extend(mask_key)
return bytes(header) + _xor_mask(data, mask_key)
return bytes(header) + data
async def _read_frame(self) -> Tuple[int, bytes]: async def _read_frame(self) -> Tuple[int, bytes]:
hdr = await self.reader.readexactly(2) hdr = await self.reader.readexactly(2)
@ -375,8 +384,8 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
iv = bytes(data[40:56]) iv = bytes(data[40:56])
cipher = Cipher(algorithms.AES(key), modes.CTR(iv)) cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
keystream = encryptor.update(b'\x00' * 64) + encryptor.finalize() keystream = encryptor.update(_ZERO_64) + encryptor.finalize()
plain = bytes(a ^ b for a, b in zip(data[56:64], keystream[56:64])) plain = (int.from_bytes(data[56:64], 'big') ^ int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big')
proto = struct.unpack('<I', plain[0:4])[0] proto = struct.unpack('<I', plain[0:4])[0]
dc_raw = struct.unpack('<h', plain[4:6])[0] dc_raw = struct.unpack('<h', plain[4:6])[0]
log.debug("dc_from_init: proto=0x%08X dc_raw=%d plain=%s", log.debug("dc_from_init: proto=0x%08X dc_raw=%d plain=%s",
@ -406,7 +415,7 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
iv = bytes(data[40:56]) iv = bytes(data[40:56])
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv)) cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
enc = cipher.encryptor() enc = cipher.encryptor()
ks = enc.update(b'\x00' * 64) + enc.finalize() ks = enc.update(_ZERO_64) + enc.finalize()
patched = bytearray(data[:64]) patched = bytearray(data[:64])
patched[60] = ks[60] ^ new_dc[0] patched[60] = ks[60] ^ new_dc[0]
patched[61] = ks[61] ^ new_dc[1] patched[61] = ks[61] ^ new_dc[1]
@ -434,7 +443,7 @@ class _MsgSplitter:
iv = bytes(init_data[40:56]) iv = bytes(init_data[40:56])
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv)) cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
self._dec = cipher.encryptor() self._dec = cipher.encryptor()
self._dec.update(b'\x00' * 64) # skip init packet self._dec.update(_ZERO_64) # skip init packet
def split(self, chunk: bytes) -> List[bytes]: def split(self, chunk: bytes) -> List[bytes]:
"""Decrypt to find message boundaries, return split ciphertext.""" """Decrypt to find message boundaries, return split ciphertext."""
@ -617,8 +626,9 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
chunk = await reader.read(65536) chunk = await reader.read(65536)
if not chunk: if not chunk:
break break
_stats.bytes_up += len(chunk) n = len(chunk)
up_bytes += len(chunk) _stats.bytes_up += n
up_bytes += n
up_packets += 1 up_packets += 1
if splitter: if splitter:
parts = splitter.split(chunk) parts = splitter.split(chunk)
@ -640,8 +650,9 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
data = await ws.recv() data = await ws.recv()
if data is None: if data is None:
break break
_stats.bytes_down += len(data) n = len(data)
down_bytes += len(data) _stats.bytes_down += n
down_bytes += n
down_packets += 1 down_packets += 1
writer.write(data) writer.write(data)
# drain only when kernel buffer is filling up # drain only when kernel buffer is filling up
@ -687,26 +698,27 @@ async def _bridge_tcp(reader, writer, remote_reader, remote_writer,
label, dc=None, dst=None, port=None, label, dc=None, dst=None, port=None,
is_media=False): is_media=False):
"""Bidirectional TCP <-> TCP forwarding (for fallback).""" """Bidirectional TCP <-> TCP forwarding (for fallback)."""
async def forward(src, dst_w, tag): async def forward(src, dst_w, is_up):
try: try:
while True: while True:
data = await src.read(65536) data = await src.read(65536)
if not data: if not data:
break break
if 'up' in tag: n = len(data)
_stats.bytes_up += len(data) if is_up:
_stats.bytes_up += n
else: else:
_stats.bytes_down += len(data) _stats.bytes_down += n
dst_w.write(data) dst_w.write(data)
await dst_w.drain() await dst_w.drain()
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
log.debug("[%s] %s ended: %s", label, tag, e) log.debug("[%s] forward ended: %s", label, e)
tasks = [ tasks = [
asyncio.create_task(forward(reader, remote_writer, 'up')), asyncio.create_task(forward(reader, remote_writer, True)),
asyncio.create_task(forward(remote_reader, writer, 'down')), asyncio.create_task(forward(remote_reader, writer, False)),
] ]
try: try:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
@ -747,8 +759,12 @@ async def _pipe(r, w):
pass pass
_SOCKS5_REPLIES = {s: bytes([0x05, s, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
for s in (0x00, 0x05, 0x07, 0x08)}
def _socks5_reply(status): def _socks5_reply(status):
return bytes([0x05, status, 0x00, 0x01]) + b'\x00' * 6 return _SOCKS5_REPLIES[status]
async def _tcp_fallback(reader, writer, dst, port, init, label, async def _tcp_fallback(reader, writer, dst, port, init, label,