diff --git a/proxy/crypto_backend.py b/proxy/crypto_backend.py new file mode 100644 index 0000000..d7a99a2 --- /dev/null +++ b/proxy/crypto_backend.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import os +from typing import Protocol + + +class AesCtrTransform(Protocol): + def update(self, data: bytes) -> bytes: + ... + + def finalize(self) -> bytes: + ... + + +def _create_cryptography_transform(key: bytes, + iv: bytes) -> AesCtrTransform: + from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + + cipher = Cipher(algorithms.AES(key), modes.CTR(iv)) + return cipher.encryptor() + + +def create_aes_ctr_transform(key: bytes, iv: bytes, + backend: str | None = None) -> AesCtrTransform: + """ + Create a stateful AES-CTR transform. + + The backend name is configurable so Android can supply an alternative + implementation later without touching proxy logic. + """ + selected = backend or os.environ.get( + 'TG_WS_PROXY_CRYPTO_BACKEND', 'cryptography') + + if selected == 'cryptography': + return _create_cryptography_transform(key, iv) + + raise ValueError(f"Unsupported AES-CTR backend: {selected}") diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 7912fd8..35cc3e7 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -11,7 +11,8 @@ import struct import sys import time from typing import Dict, List, Optional, Set, Tuple -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +from proxy.crypto_backend import create_aes_ctr_transform DEFAULT_PORT = 1080 @@ -365,8 +366,7 @@ def _dc_from_init(data: bytes) -> Tuple[Optional[int], bool]: try: key = bytes(data[8:40]) iv = bytes(data[40:56]) - cipher = Cipher(algorithms.AES(key), modes.CTR(iv)) - encryptor = cipher.encryptor() + encryptor = create_aes_ctr_transform(key, iv) keystream = encryptor.update(b'\x00' * 64) + encryptor.finalize() plain = bytes(a ^ b for a, b in zip(data[56:64], keystream[56:64])) proto = struct.unpack(' bytes: try: key_raw = bytes(data[8:40]) iv = bytes(data[40:56]) - cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv)) - enc = cipher.encryptor() + enc = create_aes_ctr_transform(key_raw, iv) ks = enc.update(b'\x00' * 64) + enc.finalize() patched = bytearray(data[:64]) patched[60] = ks[60] ^ new_dc[0] @@ -424,8 +423,7 @@ class _MsgSplitter: def __init__(self, init_data: bytes): key_raw = bytes(init_data[8:40]) iv = bytes(init_data[40:56]) - cipher = Cipher(algorithms.AES(key_raw), modes.CTR(iv)) - self._dec = cipher.encryptor() + self._dec = create_aes_ctr_transform(key_raw, iv) self._dec.update(b'\x00' * 64) # skip init packet def split(self, chunk: bytes) -> List[bytes]: