diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 0d63920..f25738c 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -131,13 +131,19 @@ def _xor_mask(data: bytes, mask: bytes) -> bytes: return (int.from_bytes(data, 'big') ^ int.from_bytes(mask_rep, 'big')).to_bytes(n, 'big') -# Pre-compiled struct formats for WS frame building +# Pre-compiled struct formats _st_BB = struct.Struct('>BB') _st_BBH = struct.Struct('>BBH') _st_BBQ = struct.Struct('>BBQ') _st_BB4s = struct.Struct('>BB4s') _st_BBH4s = struct.Struct('>BBH4s') _st_BBQ4s = struct.Struct('>BBQ4s') +_st_H = struct.Struct('>H') +_st_Q = struct.Struct('>Q') +_st_I_net = struct.Struct('!I') +_st_Ih = struct.Struct(' Tuple[int, bytes]: hdr = await self.reader.readexactly(2) opcode = hdr[0] & 0x0F - is_masked = bool(hdr[1] & 0x80) length = hdr[1] & 0x7F if length == 126: - length = struct.unpack('>H', - await self.reader.readexactly(2))[0] + length = _st_H.unpack( + await self.reader.readexactly(2))[0] elif length == 127: - length = struct.unpack('>Q', - await self.reader.readexactly(8))[0] + length = _st_Q.unpack( + await self.reader.readexactly(8))[0] - if is_masked: + if hdr[1] & 0x80: mask_key = await self.reader.readexactly(4) payload = await self.reader.readexactly(length) return opcode, _xor_mask(payload, mask_key) @@ -363,7 +368,7 @@ def _human_bytes(n: int) -> str: def _is_telegram_ip(ip: str) -> bool: try: - n = struct.unpack('!I', _socket.inet_aton(ip))[0] + n = _st_I_net.unpack(_socket.inet_aton(ip))[0] return any(lo <= n <= hi for lo, hi in _TG_RANGES) except OSError: return False @@ -380,17 +385,14 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]: Returns (dc_id, is_media). """ try: - key = bytes(data[8:40]) - iv = bytes(data[40:56]) - cipher = Cipher(algorithms.AES(key), modes.CTR(iv)) + cipher = Cipher(algorithms.AES(data[8:40]), modes.CTR(data[40:56])) encryptor = cipher.encryptor() - keystream = encryptor.update(_ZERO_64) + encryptor.finalize() + keystream = encryptor.update(_ZERO_64) plain = (int.from_bytes(data[56:64], 'big') ^ int.from_bytes(keystream[56:64], 'big')).to_bytes(8, 'big') - proto = struct.unpack(' bytes: new_dc = struct.pack(' len(plain): + if pos + 4 > plain_len: break msg_len = ( - struct.unpack_from(' len(plain): + if msg_len == 0 or pos + msg_len > plain_len: break pos += msg_len boundaries.append(pos) @@ -832,7 +832,7 @@ async def _handle_client(reader, writer): writer.close() return - port = struct.unpack('!H', await reader.readexactly(2))[0] + port = _st_H.unpack(await reader.readexactly(2))[0] if ':' in dst: log.error(