Optimizations
This commit is contained in:
parent
6a80ca85e3
commit
c1452c23da
|
|
@ -131,13 +131,19 @@ def _xor_mask(data: bytes, mask: bytes) -> bytes:
|
|||
return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big')
|
||||
|
||||
|
||||
# Pre-compiled struct formats for WS frame building
|
||||
# Pre-compiled struct formats
|
||||
_st_BB = struct.Struct('>BB')
|
||||
_st_BBH = struct.Struct('>BBH')
|
||||
_st_BBQ = struct.Struct('>BBQ')
|
||||
_st_BB4s = struct.Struct('>BB4s')
|
||||
_st_BBH4s = struct.Struct('>BBH4s')
|
||||
_st_BBQ4s = struct.Struct('>BBQ4s')
|
||||
_st_H = struct.Struct('>H')
|
||||
_st_Q = struct.Struct('>Q')
|
||||
_st_I_net = struct.Struct('!I')
|
||||
_st_Ih = struct.Struct('<Ih')
|
||||
_st_I_le = struct.Struct('<I')
|
||||
_VALID_PROTOS = frozenset((0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD))
|
||||
|
||||
|
||||
class RawWebSocket:
|
||||
|
|
@ -334,17 +340,16 @@ class RawWebSocket:
|
|||
async def _read_frame(self) -> Tuple[int, bytes]:
|
||||
hdr = await self.reader.readexactly(2)
|
||||
opcode = hdr[0] & 0x0F
|
||||
is_masked = bool(hdr[1] & 0x80)
|
||||
length = hdr[1] & 0x7F
|
||||
|
||||
if length == 126:
|
||||
length = struct.unpack('>H',
|
||||
await self.reader.readexactly(2))[0]
|
||||
length = _st_H.unpack(
|
||||
await self.reader.readexactly(2))[0]
|
||||
elif length == 127:
|
||||
length = struct.unpack('>Q',
|
||||
await self.reader.readexactly(8))[0]
|
||||
length = _st_Q.unpack(
|
||||
await self.reader.readexactly(8))[0]
|
||||
|
||||
if is_masked:
|
||||
if hdr[1] & 0x80:
|
||||
mask_key = await self.reader.readexactly(4)
|
||||
payload = await self.reader.readexactly(length)
|
||||
return opcode, _xor_mask(payload, mask_key)
|
||||
|
|
@ -363,7 +368,7 @@ def _human_bytes(n: int) -> str:
|
|||
|
||||
def _is_telegram_ip(ip: str) -> bool:
|
||||
try:
|
||||
n = struct.unpack('!I', _socket.inet_aton(ip))[0]
|
||||
n = _st_I_net.unpack(_socket.inet_aton(ip))[0]
|
||||
return any(lo <= n <= hi for lo, hi in _TG_RANGES)
|
||||
except OSError:
|
||||
return False
|
||||
|
|
@ -380,17 +385,14 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]:
|
|||
Returns (dc_id, is_media).
|
||||
"""
|
||||
try:
|
||||
key = bytes(data[8:40])
|
||||
iv = bytes(data[40:56])
|
||||
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
|
||||
cipher = Cipher(algorithms.AES(data[8:40]), modes.CTR(data[40:56]))
|
||||
encryptor = cipher.encryptor()
|
||||
keystream = encryptor.update(_ZERO_64) + encryptor.finalize()
|
||||
keystream = encryptor.update(_ZERO_64)
|
||||
plain = (int.from_bytes(data[56:64], 'big') ^ int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big')
|
||||
proto = struct.unpack('<I', plain[0:4])[0]
|
||||
dc_raw = struct.unpack('<h', plain[4:6])[0]
|
||||
proto, dc_raw = _st_Ih.unpack(plain[:6])
|
||||
log.debug("dc_from_init: proto=0x%08X dc_raw=%d plain=%s",
|
||||
proto, dc_raw, plain.hex())
|
||||
if proto in (0xEFEFEFEF, 0xEEEEEEEE, 0xDDDDDDDD):
|
||||
if proto in _VALID_PROTOS:
|
||||
dc = abs(dc_raw)
|
||||
if 1 <= dc <= 5 or dc == 203:
|
||||
return dc, (dc_raw < 0)
|
||||
|
|
@ -411,11 +413,9 @@ def _patch_init_dc(data: bytes, dc: int) -> bytes:
|
|||
|
||||
new_dc = struct.pack('<h', dc)
|
||||
try:
|
||||
key_raw = bytes(data[8:40])
|
||||
iv = bytes(data[40:56])
|
||||
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
|
||||
cipher = Cipher(algorithms.AES(data[8:40]), modes.CTR(data[40:56]))
|
||||
enc = cipher.encryptor()
|
||||
ks = enc.update(_ZERO_64) + enc.finalize()
|
||||
ks = enc.update(_ZERO_64)
|
||||
patched = bytearray(data[:64])
|
||||
patched[60] = ks[60] ^ new_dc[0]
|
||||
patched[61] = ks[61] ^ new_dc[1]
|
||||
|
|
@ -439,9 +439,8 @@ class _MsgSplitter:
|
|||
"""
|
||||
|
||||
def __init__(self, init_data: bytes):
|
||||
key_raw = bytes(init_data[8:40])
|
||||
iv = bytes(init_data[40:56])
|
||||
cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv))
|
||||
cipher = Cipher(algorithms.AES(init_data[8:40]),
|
||||
modes.CTR(init_data[40:56]))
|
||||
self._dec = cipher.encryptor()
|
||||
self._dec.update(_ZERO_64) # skip init packet
|
||||
|
||||
|
|
@ -450,19 +449,20 @@ class _MsgSplitter:
|
|||
plain = self._dec.update(chunk)
|
||||
boundaries = []
|
||||
pos = 0
|
||||
while pos < len(plain):
|
||||
plain_len = len(plain)
|
||||
while pos < plain_len:
|
||||
first = plain[pos]
|
||||
if first == 0x7f:
|
||||
if pos + 4 > len(plain):
|
||||
if pos + 4 > plain_len:
|
||||
break
|
||||
msg_len = (
|
||||
struct.unpack_from('<I', plain, pos + 1)[0] & 0xFFFFFF
|
||||
_st_I_le.unpack_from(plain, pos + 1)[0] & 0xFFFFFF
|
||||
) * 4
|
||||
pos += 4
|
||||
else:
|
||||
msg_len = first * 4
|
||||
pos += 1
|
||||
if msg_len == 0 or pos + msg_len > len(plain):
|
||||
if msg_len == 0 or pos + msg_len > plain_len:
|
||||
break
|
||||
pos += msg_len
|
||||
boundaries.append(pos)
|
||||
|
|
@ -832,7 +832,7 @@ async def _handle_client(reader, writer):
|
|||
writer.close()
|
||||
return
|
||||
|
||||
port = struct.unpack('!H', await reader.readexactly(2))[0]
|
||||
port = _st_H.unpack(await reader.readexactly(2))[0]
|
||||
|
||||
if ':' in dst:
|
||||
log.error(
|
||||
|
|
|
|||
Loading…
Reference in New Issue