Optimizations
This commit is contained in:
parent
4ae7cb92f7
commit
6a80ca85e3
|
|
@ -85,6 +85,8 @@ _dc_fail_until: Dict[Tuple[int, bool], float] = {}
|
||||||
_DC_FAIL_COOLDOWN = 30.0 # seconds to keep reduced WS timeout after failure
|
_DC_FAIL_COOLDOWN = 30.0 # seconds to keep reduced WS timeout after failure
|
||||||
_WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure
|
_WS_FAIL_TIMEOUT = 2.0 # quick-retry timeout after a recent WS failure
|
||||||
|
|
||||||
|
_ZERO_64 = b'\x00' * 64
|
||||||
|
|
||||||
|
|
||||||
_ssl_ctx = ssl.create_default_context()
|
_ssl_ctx = ssl.create_default_context()
|
||||||
_ssl_ctx.check_hostname = False
|
_ssl_ctx.check_hostname = False
|
||||||
|
|
@ -129,6 +131,15 @@ def _xor_mask(data: bytes, mask: bytes) -> bytes:
|
||||||
return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big')
|
return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big')
|
||||||
|
|
||||||
|
|
||||||
|
# Pre-compiled struct formats for WS frame building
|
||||||
|
_st_BB = struct.Struct('>BB')
|
||||||
|
_st_BBH = struct.Struct('>BBH')
|
||||||
|
_st_BBQ = struct.Struct('>BBQ')
|
||||||
|
_st_BB4s = struct.Struct('>BB4s')
|
||||||
|
_st_BBH4s = struct.Struct('>BBH4s')
|
||||||
|
_st_BBQ4s = struct.Struct('>BBQ4s')
|
||||||
|
|
||||||
|
|
||||||
class RawWebSocket:
|
class RawWebSocket:
|
||||||
"""
|
"""
|
||||||
Lightweight WebSocket client over asyncio reader/writer streams.
|
Lightweight WebSocket client over asyncio reader/writer streams.
|
||||||
|
|
@ -302,25 +313,23 @@ class RawWebSocket:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_frame(opcode: int, data: bytes,
|
def _build_frame(opcode: int, data: bytes,
|
||||||
mask: bool = False) -> bytes:
|
mask: bool = False) -> bytes:
|
||||||
header = bytearray()
|
|
||||||
header.append(0x80 | opcode) # FIN=1 + opcode
|
|
||||||
length = len(data)
|
length = len(data)
|
||||||
mask_bit = 0x80 if mask else 0x00
|
fb = 0x80 | opcode
|
||||||
|
|
||||||
|
if not mask:
|
||||||
|
if length < 126:
|
||||||
|
return _st_BB.pack(fb, length) + data
|
||||||
|
if length < 65536:
|
||||||
|
return _st_BBH.pack(fb, 126, length) + data
|
||||||
|
return _st_BBQ.pack(fb, 127, length) + data
|
||||||
|
|
||||||
|
mask_key = os.urandom(4)
|
||||||
|
masked = _xor_mask(data, mask_key)
|
||||||
if length < 126:
|
if length < 126:
|
||||||
header.append(mask_bit | length)
|
return _st_BB4s.pack(fb, 0x80 | length, mask_key) + masked
|
||||||
elif length < 65536:
|
if length < 65536:
|
||||||
header.append(mask_bit | 126)
|
return _st_BBH4s.pack(fb, 0x80 | 126, length, mask_key) + masked
|
||||||
header.extend(struct.pack('>H', length))
|
return _st_BBQ4s.pack(fb, 0x80 | 127, length, mask_key) + masked
|
||||||
else:
|
|
||||||
header.append(mask_bit | 127)
|
|
||||||
header.extend(struct.pack('>Q', length))
|
|
||||||
|
|
||||||
if mask:
|
|
||||||
mask_key = os.urandom(4)
|
|
||||||
header.extend(mask_key)
|
|
||||||
return bytes(header) + _xor_mask(data, mask_key)
|
|
||||||
return bytes(header) + data
|
|
||||||
|
|
||||||
async def _read_frame(self) -> Tuple[int, bytes]:
|
async def _read_frame(self) -> Tuple[int, bytes]:
|
||||||
hdr = await self.reader.readexactly(2)
|
hdr = await self.reader.readexactly(2)
|
||||||
|
|
@ -375,8 +384,8 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
|
||||||
iv = bytes(data[40:56])
|
iv = bytes(data[40:56])
|
||||||
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
|
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
|
||||||
encryptor = cipher.encryptor()
|
encryptor = cipher.encryptor()
|
||||||
keystream = encryptor.update(b'\x00' * 64) + encryptor.finalize()
|
keystream = encryptor.update(_ZERO_64) + encryptor.finalize()
|
||||||
plain = bytes(a ^ b for a, b in zip(data[56:64], keystream[56:64]))
|
plain = (int.from_bytes(data[56:64], 'big') ^ int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big')
|
||||||
proto = struct.unpack('<I', plain[0:4])[0]
|
proto = struct.unpack('<I', plain[0:4])[0]
|
||||||
dc_raw = struct.unpack('<h', plain[4:6])[0]
|
dc_raw = struct.unpack('<h', plain[4:6])[0]
|
||||||
log.debug("dc_from_init: proto=0x%08X dc_raw=%d plain=%s",
|
log.debug("dc_from_init: proto=0x%08X dc_raw=%d plain=%s",
|
||||||
|
|
@ -406,7 +415,7 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
|
||||||
iv = bytes(data[40:56])
|
iv = bytes(data[40:56])
|
||||||
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
|
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
|
||||||
enc = cipher.encryptor()
|
enc = cipher.encryptor()
|
||||||
ks = enc.update(b'\x00' * 64) + enc.finalize()
|
ks = enc.update(_ZERO_64) + enc.finalize()
|
||||||
patched = bytearray(data[:64])
|
patched = bytearray(data[:64])
|
||||||
patched[60] = ks[60] ^ new_dc[0]
|
patched[60] = ks[60] ^ new_dc[0]
|
||||||
patched[61] = ks[61] ^ new_dc[1]
|
patched[61] = ks[61] ^ new_dc[1]
|
||||||
|
|
@ -434,7 +443,7 @@ class _MsgSplitter:
|
||||||
iv = bytes(init_data[40:56])
|
iv = bytes(init_data[40:56])
|
||||||
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
|
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
|
||||||
self._dec = cipher.encryptor()
|
self._dec = cipher.encryptor()
|
||||||
self._dec.update(b'\x00' * 64) # skip init packet
|
self._dec.update(_ZERO_64) # skip init packet
|
||||||
|
|
||||||
def split(self, chunk: bytes) -> List[bytes]:
|
def split(self, chunk: bytes) -> List[bytes]:
|
||||||
"""Decrypt to find message boundaries, return split ciphertext."""
|
"""Decrypt to find message boundaries, return split ciphertext."""
|
||||||
|
|
@ -617,8 +626,9 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||||
chunk = await reader.read(65536)
|
chunk = await reader.read(65536)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
_stats.bytes_up += len(chunk)
|
n = len(chunk)
|
||||||
up_bytes += len(chunk)
|
_stats.bytes_up += n
|
||||||
|
up_bytes += n
|
||||||
up_packets += 1
|
up_packets += 1
|
||||||
if splitter:
|
if splitter:
|
||||||
parts = splitter.split(chunk)
|
parts = splitter.split(chunk)
|
||||||
|
|
@ -640,8 +650,9 @@ async def _bridge_ws(reader, writer, ws: RawWebSocket, label,
|
||||||
data = await ws.recv()
|
data = await ws.recv()
|
||||||
if data is None:
|
if data is None:
|
||||||
break
|
break
|
||||||
_stats.bytes_down += len(data)
|
n = len(data)
|
||||||
down_bytes += len(data)
|
_stats.bytes_down += n
|
||||||
|
down_bytes += n
|
||||||
down_packets += 1
|
down_packets += 1
|
||||||
writer.write(data)
|
writer.write(data)
|
||||||
# drain only when kernel buffer is filling up
|
# drain only when kernel buffer is filling up
|
||||||
|
|
@ -687,26 +698,27 @@ async def _bridge_tcp(reader, writer, remote_reader, remote_writer,
|
||||||
label, dc=None, dst=None, port=None,
|
label, dc=None, dst=None, port=None,
|
||||||
is_media=False):
|
is_media=False):
|
||||||
"""Bidirectional TCP <-> TCP forwarding (for fallback)."""
|
"""Bidirectional TCP <-> TCP forwarding (for fallback)."""
|
||||||
async def forward(src, dst_w, tag):
|
async def forward(src, dst_w, is_up):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await src.read(65536)
|
data = await src.read(65536)
|
||||||
if not data:
|
if not data:
|
||||||
break
|
break
|
||||||
if 'up' in tag:
|
n = len(data)
|
||||||
_stats.bytes_up += len(data)
|
if is_up:
|
||||||
|
_stats.bytes_up += n
|
||||||
else:
|
else:
|
||||||
_stats.bytes_down += len(data)
|
_stats.bytes_down += n
|
||||||
dst_w.write(data)
|
dst_w.write(data)
|
||||||
await dst_w.drain()
|
await dst_w.drain()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug("[%s] %s ended: %s", label, tag, e)
|
log.debug("[%s] forward ended: %s", label, e)
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
asyncio.create_task(forward(reader, remote_writer, 'up')),
|
asyncio.create_task(forward(reader, remote_writer, True)),
|
||||||
asyncio.create_task(forward(remote_reader, writer, 'down')),
|
asyncio.create_task(forward(remote_reader, writer, False)),
|
||||||
]
|
]
|
||||||
try:
|
try:
|
||||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
@ -747,8 +759,12 @@ async def _pipe(r, w):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_SOCKS5_REPLIES = {s: bytes([0x05, s, 0x00, 0x01, 0, 0, 0, 0, 0, 0])
|
||||||
|
for s in (0x00, 0x05, 0x07, 0x08)}
|
||||||
|
|
||||||
|
|
||||||
def _socks5_reply(status):
|
def _socks5_reply(status):
|
||||||
return bytes([0x05, status, 0x00, 0x01]) + b'\x00' * 6
|
return _SOCKS5_REPLIES[status]
|
||||||
|
|
||||||
|
|
||||||
async def _tcp_fallback(reader, writer, dst, port, init, label,
|
async def _tcp_fallback(reader, writer, dst, port, init, label,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue