Optimizations

This commit is contained in:
Flowseal 2026-03-20 22:57:15 +03:00
parent 6a80ca85e3
commit c1452c23da
1 changed files with 27 additions and 27 deletions

View File

@ -131,13 +131,19 @@ def _xor_mask(data: bytes, mask: bytes) -> bytes:
return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big')
# Pre-compiled struct formats for WS frame building
# Pre-compiled struct formats
_st_BB = struct.Struct('>BB')
_st_BBH = struct.Struct('>BBH')
_st_BBQ = struct.Struct('>BBQ')
_st_BB4s = struct.Struct('>BB4s')
_st_BBH4s = struct.Struct('>BBH4s')
_st_BBQ4s = struct.Struct('>BBQ4s')
_st_H = struct.Struct('>H')
_st_Q = struct.Struct('>Q')
_st_I_net = struct.Struct('!I')
_st_Ih = struct.Struct('<Ih')
_st_I_le = struct.Struct('<I')
_VALID_PROTOS = frozenset((0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD))
class RawWebSocket:
@ -334,17 +340,16 @@ class RawWebSocket:
async def _read_frame(self) -> Tuple[int, bytes]:
hdr = await self.reader.readexactly(2)
opcode = hdr[0] & 0x0F
is_masked = bool(hdr[1] & 0x80)
length = hdr[1] & 0x7F
if length == 126:
length = struct.unpack('>H',
await self.reader.readexactly(2))[0]
length = _st_H.unpack(
await self.reader.readexactly(2))[0]
elif length == 127:
length = struct.unpack('>Q',
await self.reader.readexactly(8))[0]
length = _st_Q.unpack(
await self.reader.readexactly(8))[0]
if is_masked:
if hdr[1] & 0x80:
mask_key = await self.reader.readexactly(4)
payload = await self.reader.readexactly(length)
return opcode, _xor_mask(payload, mask_key)
@ -363,7 +368,7 @@ def _human_bytes(n: int) -> str:
def _is_telegram_ip(ip: str) -> bool:
try:
n = struct.unpack('!I', _socket.inet_aton(ip))[0]
n = _st_I_net.unpack(_socket.inet_aton(ip))[0]
return any(lo <= n <= hi for lo, hi in _TG_RANGES)
except OSError:
return False
@ -380,17 +385,14 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
Returns (dc_id, is_media).
"""
try:
key = bytes(data[8:40])
iv = bytes(data[40:56])
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
cipher = Cipher(algorithms.AES(data[8:40]), modes.CTR(data[40:56]))
encryptor = cipher.encryptor()
keystream = encryptor.update(_ZERO_64) + encryptor.finalize()
keystream = encryptor.update(_ZERO_64)
plain = (int.from_bytes(data[56:64], 'big') ^ int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big')
proto = struct.unpack('<I', plain[0:4])[0]
dc_raw = struct.unpack('<h', plain[4:6])[0]
proto, dc_raw = _st_Ih.unpack(plain[:6])
log.debug("dc_from_init: proto=0x%08X dc_raw=%d plain=%s",
proto, dc_raw, plain.hex())
if proto in (0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD):
if proto in _VALID_PROTOS:
dc = abs(dc_raw)
if 1 <= dc <= 5 or dc == 203:
return dc, (dc_raw < 0)
@ -411,11 +413,9 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
new_dc = struct.pack('<h', dc)
try:
key_raw = bytes(data[8:40])
iv = bytes(data[40:56])
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
cipher = Cipher(algorithms.AES(data[8:40]), modes.CTR(data[40:56]))
enc = cipher.encryptor()
ks = enc.update(_ZERO_64) + enc.finalize()
ks = enc.update(_ZERO_64)
patched = bytearray(data[:64])
patched[60] = ks[60] ^ new_dc[0]
patched[61] = ks[61] ^ new_dc[1]
@ -439,9 +439,8 @@ class _MsgSplitter:
"""
def __init__(self, init_data: bytes):
key_raw = bytes(init_data[8:40])
iv = bytes(init_data[40:56])
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
cipher = Cipher(algorithms.AES(init_data[8:40]),
modes.CTR(init_data[40:56]))
self._dec = cipher.encryptor()
self._dec.update(_ZERO_64) # skip init packet
@ -450,19 +449,20 @@ class _MsgSplitter:
plain = self._dec.update(chunk)
boundaries = []
pos = 0
while pos < len(plain):
plain_len = len(plain)
while pos < plain_len:
first = plain[pos]
if first == 0x7f:
if pos + 4 > len(plain):
if pos + 4 > plain_len:
break
msg_len = (
struct.unpack_from('<I', plain, pos + 1)[0] & 0xFFFFFF
_st_I_le.unpack_from(plain, pos + 1)[0] & 0xFFFFFF
) * 4
pos += 4
else:
msg_len = first * 4
pos += 1
if msg_len == 0 or pos + msg_len > len(plain):
if msg_len == 0 or pos + msg_len > plain_len:
break
pos += msg_len
boundaries.append(pos)
@ -832,7 +832,7 @@ async def _handle_client(reader, writer):
writer.close()
return
port = struct.unpack('!H', await reader.readexactly(2))[0]
port = _st_H.unpack(await reader.readexactly(2))[0]
if ':' in dst:
log.error(