Optimizations
This commit is contained in:
parent
4ae7cb92f7
commit
6a80ca85e3
|
|
@ -85,6 +85,8 @@ _dc_fail_until: Dict[Tuple[int, bool], float] = {}
|
|||
_DC_FAIL_COOLDOWN = 30.0 # seconds to keep reduced WS timeout after failure
|
||||
_WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure
|
||||
|
||||
_ZERO_64 = b'\x00' * 64
|
||||
|
||||
|
||||
_ssl_ctx = ssl.create_default_context()
|
||||
_ssl_ctx.check_hostname = False
|
||||
|
|
@ -129,6 +131,15 @@ def _xor_mask(data: bytes, mask: bytes) -> bytes:
|
|||
return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big')
|
||||
|
||||
|
||||
# Pre-compiled struct formats for WS frame building
|
||||
_st_BB = struct.Struct('>BB')
|
||||
_st_BBH = struct.Struct('>BBH')
|
||||
_st_BBQ = struct.Struct('>BBQ')
|
||||
_st_BB4s = struct.Struct('>BB4s')
|
||||
_st_BBH4s = struct.Struct('>BBH4s')
|
||||
_st_BBQ4s = struct.Struct('>BBQ4s')
|
||||
|
||||
|
||||
class RawWebSocket:
|
||||
"""
|
||||
Lightweight WebSocket client over asyncio reader/writer streams.
|
||||
|
|
@ -302,25 +313,23 @@ class RawWebSocket:
|
|||
@staticmethod
|
||||
def _build_frame(opcode: int, data: bytes,
|
||||
mask: bool = False) -> bytes:
|
||||
header = bytearray()
|
||||
header.append(0x80 | opcode) # FIN=1 + opcode
|
||||
length = len(data)
|
||||
mask_bit = 0x80 if mask else 0x00
|
||||
fb = 0x80 | opcode
|
||||
|
||||
if not mask:
|
||||
if length < 126:
|
||||
header.append(mask_bit | length)
|
||||
elif length < 65536:
|
||||
header.append(mask_bit | 126)
|
||||
header.extend(struct.pack('>H', length))
|
||||
else:
|
||||
header.append(mask_bit | 127)
|
||||
header.extend(struct.pack('>Q', length))
|
||||
return _st_BB.pack(fb, length) + data
|
||||
if length < 65536:
|
||||
return _st_BBH.pack(fb, 126, length) + data
|
||||
return _st_BBQ.pack(fb, 127, length) + data
|
||||
|
||||
if mask:
|
||||
mask_key = os.urandom(4)
|
||||
header.extend(mask_key)
|
||||
return bytes(header) + _xor_mask(data, mask_key)
|
||||
return bytes(header) + data
|
||||
masked = _xor_mask(data, mask_key)
|
||||
if length < 126:
|
||||
return _st_BB4s.pack(fb, 0x80 | length, mask_key) + masked
|
||||
if length < 65536:
|
||||
return _st_BBH4s.pack(fb, 0x80 | 126, length, mask_key) + masked
|
||||
return _st_BBQ4s.pack(fb, 0x80 | 127, length, mask_key) + masked
|
||||
|
||||
async def _read_frame(self) -> Tuple[int, bytes]:
|
||||
hdr = await self.reader.readexactly(2)
|
||||
|
|
@ -375,8 +384,8 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
|
|||
iv = bytes(data[40:56])
|
||||
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
|
||||
encryptor = cipher.encryptor()
|
||||
keystream = encryptor.update(b'\x00' * 64) + encryptor.finalize()
|
||||
plain = bytes(a ^ b for a, b in zip(data[56:64], keystream[56:64]))
|
||||
keystream = encryptor.update(_ZERO_64) + encryptor.finalize()
|
||||
plain = (int.from_bytes(data[56:64], 'big') ^ int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big')
|
||||
proto = struct.unpack('<I', plain[0:4])[0]
|
||||
dc_raw = struct.unpack('<h', plain[4:6])[0]
|
||||
log.debug("dc_from_init: proto=0x%08X dc_raw=%d plain=%s",
|
||||
|
|
@ -406,7 +415,7 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
|
|||
iv = bytes(data[40:56])
|
||||
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
|
||||
enc = cipher.encryptor()
|
||||
ks = enc.update(b'\x00' * 64) + enc.finalize()
|
||||
ks = enc.update(_ZERO_64) + enc.finalize()
|
||||
patched = bytearray(data[:64])
|
||||
patched[60] = ks[60] ^ new_dc[0]
|
||||
patched[61] = ks[61] ^ new_dc[1]
|
||||
|
|
@ -434,7 +443,7 @@ class _MsgSplitter:
|
|||
iv = bytes(init_data[40:56])
|
||||
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
|
||||
self._dec = cipher.encryptor()
|
||||
self._dec.update(b'\x00' * 64) # skip init packet
|
||||
self._dec.update(_ZERO_64) # skip init packet
|
||||
|
||||
def split(self, chunk: bytes) -> List[bytes]:
|
||||
"""Decrypt to find message boundaries, return split ciphertext."""
|
||||
|
|
@ -617,8 +626,9 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
|||
chunk = await reader.read(65536)
|
||||
if not chunk:
|
||||
break
|
||||
_stats.bytes_up += len(chunk)
|
||||
up_bytes += len(chunk)
|
||||
n = len(chunk)
|
||||
_stats.bytes_up += n
|
||||
up_bytes += n
|
||||
up_packets += 1
|
||||
if splitter:
|
||||
parts = splitter.split(chunk)
|
||||
|
|
@ -640,8 +650,9 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
|||
data = await ws.recv()
|
||||
if data is None:
|
||||
break
|
||||
_stats.bytes_down += len(data)
|
||||
down_bytes += len(data)
|
||||
n = len(data)
|
||||
_stats.bytes_down += n
|
||||
down_bytes += n
|
||||
down_packets += 1
|
||||
writer.write(data)
|
||||
# drain only when kernel buffer is filling up
|
||||
|
|
@ -687,26 +698,27 @@ async def _bridge_tcp(reader, writer, remote_reader, remote_writer,
|
|||
label, dc=None, dst=None, port=None,
|
||||
is_media=False):
|
||||
"""Bidirectional TCP <-> TCP forwarding (for fallback)."""
|
||||
async def forward(src, dst_w, tag):
|
||||
async def forward(src, dst_w, is_up):
|
||||
try:
|
||||
while True:
|
||||
data = await src.read(65536)
|
||||
if not data:
|
||||
break
|
||||
if 'up' in tag:
|
||||
_stats.bytes_up += len(data)
|
||||
n = len(data)
|
||||
if is_up:
|
||||
_stats.bytes_up += n
|
||||
else:
|
||||
_stats.bytes_down += len(data)
|
||||
_stats.bytes_down += n
|
||||
dst_w.write(data)
|
||||
await dst_w.drain()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.debug("[%s] %s ended: %s", label, tag, e)
|
||||
log.debug("[%s] forward ended: %s", label, e)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(forward(reader, remote_writer, 'up')),
|
||||
asyncio.create_task(forward(remote_reader, writer, 'down')),
|
||||
asyncio.create_task(forward(reader, remote_writer, True)),
|
||||
asyncio.create_task(forward(remote_reader, writer, False)),
|
||||
]
|
||||
try:
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
|
@ -747,8 +759,12 @@ async def _pipe(r, w):
|
|||
pass
|
||||
|
||||
|
||||
_SOCKS5_REPLIES = {s: bytes([0x05, s, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||
for s in (0x00, 0x05, 0x07, 0x08)}
|
||||
|
||||
|
||||
def _socks5_reply(status):
|
||||
return bytes([0x05, status, 0x00, 0x01]) + b'\x00' * 6
|
||||
return _SOCKS5_REPLIES[status]
|
||||
|
||||
|
||||
async def _tcp_fallback(reader, writer, dst, port, init, label,
|
||||
|
|
|
|||
Loading…
Reference in New Issue