From 1599b1126c34b2617f5c3faeb953d0cda548a113 Mon Sep 17 00:00:00 2001 From: Dark_Avery Date: Mon, 30 Mar 2026 16:14:42 +0300 Subject: [PATCH] feat(runtime): adapt android_migration shell to upstream mtproto core --- proxy/app_runtime.py | 33 +++++++- proxy/tg_ws_proxy.py | 24 ++++++ tests/test_app_runtime.py | 17 ++++ tests/test_crypto_mtproto.py | 93 +++++++++++++++------- tests/test_socks5_protocol.py | 145 +++++++++------------------------- tests/test_update_check.py | 7 +- 6 files changed, 179 insertions(+), 140 deletions(-) diff --git a/proxy/app_runtime.py b/proxy/app_runtime.py index 9cac781..5fb7ba8 100644 --- a/proxy/app_runtime.py +++ b/proxy/app_runtime.py @@ -4,6 +4,7 @@ import asyncio as _asyncio import json import logging import logging.handlers +import os import sys import threading import time @@ -14,8 +15,9 @@ import proxy.tg_ws_proxy as tg_ws_proxy DEFAULT_CONFIG = { - "port": 1080, + "port": 1443, "host": "127.0.0.1", + "secret": os.urandom(16).hex(), "dc_ip": ["2:149.154.167.220", "4:149.154.167.220"], "log_max_mb": 5, "buf_kb": 256, @@ -48,6 +50,27 @@ class ProxyAppRuntime: self._proxy_thread = None self._async_stop = None + def _build_core_config(self, active_cfg: dict, dc_opt: Dict[int, str]): + port = int(active_cfg.get("port", self.default_config["port"])) + host = str(active_cfg.get("host", self.default_config["host"])) + secret = str(active_cfg.get("secret") or "").strip() + if not secret: + secret = os.urandom(16).hex() + active_cfg["secret"] = secret + + buf_kb = int(active_cfg.get("buf_kb", self.default_config["buf_kb"])) + pool_size = int(active_cfg.get( + "pool_size", self.default_config["pool_size"])) + + return tg_ws_proxy.ProxyConfig( + port=port, + host=host, + secret=secret, + dc_redirects=dc_opt, + buffer_size=max(4, buf_kb) * 1024, + pool_size=max(0, pool_size), + ) + def ensure_dirs(self): self.app_dir.mkdir(parents=True, exist_ok=True) @@ -132,8 +155,7 @@ class ProxyAppRuntime: self._async_stop = (loop, stop_ev) try: - loop.run_until_complete( - self.run_proxy(port, dc_opt, stop_event=stop_ev, host=host)) + loop.run_until_complete(self.run_proxy(stop_event=stop_ev)) except Exception as exc: self.log.error("Proxy thread crashed: %s", exc) if ("10048" in str(exc) or @@ -143,6 +165,8 @@ class ProxyAppRuntime: "Порт уже используется другим приложением.\n\n" "Закройте приложение, использующее этот порт, " "или измените порт в настройках прокси и перезапустите.") + else: + self._emit_error(str(exc) or exc.__class__.__name__) finally: loop.close() self._async_stop = None @@ -168,6 +192,9 @@ class ProxyAppRuntime: self._emit_error("Ошибка конфигурации:\n%s" % exc) return False + tg_ws_proxy.proxy_config = self._build_core_config(active_cfg, dc_opt) + self.save_config(active_cfg) + self.log.info("Starting proxy on %s:%d ...", host, port) tg_ws_proxy._RECV_BUF = max(4, buf_kb) * 1024 tg_ws_proxy._SEND_BUF = tg_ws_proxy._RECV_BUF diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 21ba025..f293f94 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -493,6 +493,7 @@ class Stats: self.bytes_down = 0 self.pool_hits = 0 self.pool_misses = 0 + self.last_transport_route: Optional[str] = None def summary(self) -> str: pool_total = self.pool_hits + self.pool_misses @@ -511,6 +512,27 @@ class Stats: _stats = Stats() +def reset_stats() -> None: + global _stats + _stats = Stats() + + +def get_stats_snapshot() -> Dict[str, object]: + return { + "connections_total": _stats.connections_total, + "connections_active": _stats.connections_active, + "connections_ws": _stats.connections_ws, + "connections_tcp_fallback": _stats.connections_tcp_fallback, + "connections_bad": _stats.connections_bad, + "ws_errors": _stats.ws_errors, + "bytes_up": _stats.bytes_up, + "bytes_down": _stats.bytes_down, + "pool_hits": _stats.pool_hits, + "pool_misses": _stats.pool_misses, + "last_transport_route": _stats.last_transport_route, + } + + class _WsPool: WS_POOL_MAX_AGE = 120.0 @@ -769,6 +791,7 @@ async def _tcp_fallback(reader, writer, dst, port, relay_init, label, return False _stats.connections_tcp_fallback += 1 + _stats.last_transport_route = "tcp_fallback" rw.write(relay_init) await rw.drain() await _bridge_tcp_reencrypt(reader, writer, rr, rw, label, @@ -965,6 +988,7 @@ async def _handle_client(reader, writer, secret: bytes): dc_fail_until.pop(dc_key, None) _stats.connections_ws += 1 + _stats.last_transport_route = "telegram_ws_direct" splitter = None try: diff --git a/tests/test_app_runtime.py b/tests/test_app_runtime.py index b6026f9..bf3dfc7 100644 --- a/tests/test_app_runtime.py +++ b/tests/test_app_runtime.py @@ -116,6 +116,23 @@ class ProxyAppRuntimeTests(unittest.TestCase): self.assertFalse(started) self.assertEqual(errors, ["Ошибка конфигурации:\nbad dc mapping"]) + def test_run_proxy_thread_reports_generic_runtime_error(self): + with tempfile.TemporaryDirectory() as tmpdir: + errors = [] + + async def fake_run_proxy(stop_event=None): + raise RuntimeError("proxy boom") + + runtime = ProxyAppRuntime( + Path(tmpdir), + on_error=errors.append, + run_proxy=fake_run_proxy, + ) + + runtime._run_proxy_thread(1443, {2: "149.154.167.220"}, "127.0.0.1") + + self.assertEqual(errors, ["proxy boom"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_crypto_mtproto.py b/tests/test_crypto_mtproto.py index 4c5a18a..2bd0864 100644 --- a/tests/test_crypto_mtproto.py +++ b/tests/test_crypto_mtproto.py @@ -1,38 +1,55 @@ +import hashlib import struct import unittest +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + from proxy.crypto_backend import create_aes_ctr_transform -from proxy.tg_ws_proxy import _MsgSplitter, _dc_from_init, _patch_init_dc +from proxy.tg_ws_proxy import ( + PROTO_ABRIDGED_INT, + PROTO_TAG_ABRIDGED, + _MsgSplitter, + _generate_relay_init, + _try_handshake, +) KEY = bytes(range(32)) IV = bytes(range(16)) -PROTO_TAG = 0xEFEFEFEF +SECRET = bytes.fromhex("0123456789abcdef0123456789abcdef") def _xor(left: bytes, right: bytes) -> bytes: return bytes(a ^ b for a, b in zip(left, right)) -def _keystream(size: int) -> bytes: - transform = create_aes_ctr_transform(KEY, IV) - return transform.update(b"\x00" * size) + transform.finalize() +def _keystream(size: int, key: bytes, iv: bytes) -> bytes: + transform = Cipher(algorithms.AES(key), modes.CTR(iv)).encryptor() + return transform.update(b"\x00" * size) -def _build_init_packet(dc_raw: int, proto: int = PROTO_TAG) -> bytes: +def _build_client_handshake( + dc_raw: int, + proto_tag: bytes = PROTO_TAG_ABRIDGED, + secret: bytes = SECRET, +) -> bytes: packet = bytearray(64) packet[8:40] = KEY packet[40:56] = IV - plain_tail = struct.pack(" bytes: - transform = create_aes_ctr_transform(init_packet[8:40], init_packet[40:56]) +def _encrypt_after_init(relay_init: bytes, plaintext: bytes) -> bytes: + transform = Cipher( + algorithms.AES(relay_init[8:40]), + modes.CTR(relay_init[40:56]), + ).encryptor() transform.update(b"\x00" * 64) - return transform.update(plaintext) + transform.finalize() + return transform.update(plaintext) class CryptoBackendTests(unittest.TestCase): @@ -63,42 +80,60 @@ class CryptoBackendTests(unittest.TestCase): create_aes_ctr_transform(KEY, IV, backend="missing") -class MtProtoInitTests(unittest.TestCase): - def test_dc_from_init_reads_non_media_dc(self): - init_packet = _build_init_packet(dc_raw=2) +class MtProtoHandshakeTests(unittest.TestCase): + def test_try_handshake_reads_non_media_dc(self): + handshake = _build_client_handshake(dc_raw=2) - self.assertEqual(_dc_from_init(init_packet), (2, False)) + result = _try_handshake(handshake, SECRET) - def test_dc_from_init_reads_media_dc(self): - init_packet = _build_init_packet(dc_raw=-4) + self.assertEqual(result[:3], (2, False, PROTO_TAG_ABRIDGED)) - self.assertEqual(_dc_from_init(init_packet), (4, True)) + def test_try_handshake_reads_media_dc(self): + handshake = _build_client_handshake(dc_raw=-4) - def test_patch_init_dc_updates_signed_dc_and_preserves_tail(self): - original = _build_init_packet(dc_raw=99) + b"tail" + result = _try_handshake(handshake, SECRET) - patched = _patch_init_dc(original, -3) + self.assertEqual(result[:3], (4, True, PROTO_TAG_ABRIDGED)) - self.assertEqual(_dc_from_init(patched[:64]), (3, True)) - self.assertEqual(patched[64:], b"tail") + def test_try_handshake_rejects_wrong_secret(self): + handshake = _build_client_handshake(dc_raw=2) + + result = _try_handshake( + handshake, + bytes.fromhex("fedcba9876543210fedcba9876543210"), + ) + + self.assertIsNone(result) + + def test_generate_relay_init_encodes_proto_and_signed_dc(self): + relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, -3) + decryptor = Cipher( + algorithms.AES(relay_init[8:40]), + modes.CTR(relay_init[40:56]), + ).encryptor() + + decrypted = decryptor.update(relay_init) + + self.assertEqual(decrypted[56:60], PROTO_TAG_ABRIDGED) + self.assertEqual(struct.unpack(" bytes: - end = self._offset + n - if end > len(self._payload): - partial = self._payload[self._offset:] - self._offset = len(self._payload) - raise asyncio.IncompleteReadError(partial, n) - chunk = self._payload[self._offset:end] - self._offset = end - return chunk +def _xor(left: bytes, right: bytes) -> bytes: + return bytes(a ^ b for a, b in zip(left, right)) -class _FakeWriter: - def __init__(self): - self.transport = _FakeTransport() - self.writes = [] - self.closed = False - self.close_calls = 0 +def _build_client_handshake(dc_raw: int, proto_tag: bytes) -> bytes: + packet = bytearray(64) + packet[8:40] = KEY + packet[40:56] = IV - def get_extra_info(self, name): - if name == "peername": - return ("127.0.0.1", 50000) - return None + dec_key = hashlib.sha256(KEY + SECRET).digest() + decryptor = Cipher(algorithms.AES(dec_key), modes.CTR(IV)).encryptor() + keystream = decryptor.update(b"\x00" * 64) - def write(self, data: bytes): - self.writes.append(data) - - async def drain(self): - return None - - def close(self): - self.closed = True - self.close_calls += 1 - - async def wait_closed(self): - return None + plain_tail = proto_tag + struct.pack(" bytes: - return bytes([0x05, cmd, 0x00, 0x01]) + socket.inet_aton(ip) + port.to_bytes(2, "big") +class MtProtoProtocolTests(unittest.TestCase): + def test_try_handshake_accepts_abridged_proto(self): + handshake = _build_client_handshake(2, PROTO_TAG_ABRIDGED) + result = _try_handshake(handshake, SECRET) -def _domain_connect_request(domain: str, port: int, cmd: int = 1) -> bytes: - encoded = domain.encode("utf-8") - return ( - bytes([0x05, cmd, 0x00, 0x03, len(encoded)]) - + encoded - + port.to_bytes(2, "big") - ) + self.assertIsNotNone(result) + self.assertEqual(result[:3], (2, False, PROTO_TAG_ABRIDGED)) + def test_try_handshake_accepts_intermediate_proto(self): + handshake = _build_client_handshake(-4, PROTO_TAG_INTERMEDIATE) -def _ipv6_connect_request(ip: str, port: int) -> bytes: - return ( - bytes([0x05, 0x01, 0x00, 0x04]) - + socket.inet_pton(socket.AF_INET6, ip) - + port.to_bytes(2, "big") - ) + result = _try_handshake(handshake, SECRET) + self.assertIsNotNone(result) + self.assertEqual(result[:3], (4, True, PROTO_TAG_INTERMEDIATE)) -class Socks5ProtocolTests(unittest.IsolatedAsyncioTestCase): - async def test_rejects_non_socks5_greeting(self): - reader = _FakeReader(b"\x04\x01") - writer = _FakeWriter() + def test_generate_relay_init_produces_handshake_sized_packet(self): + relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, -2) - await _handle_client(reader, writer) - - self.assertEqual(writer.writes, []) - self.assertTrue(writer.closed) - - async def test_rejects_unsupported_command(self): - reader = _FakeReader(b"\x05\x01\x00" + _ipv4_connect_request("1.1.1.1", 443, cmd=2)) - writer = _FakeWriter() - - await _handle_client(reader, writer) - - self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x07)]) - self.assertTrue(writer.closed) - - async def test_rejects_unsupported_address_type(self): - reader = _FakeReader(b"\x05\x01\x00" + b"\x05\x01\x00\x02") - writer = _FakeWriter() - - await _handle_client(reader, writer) - - self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x08)]) - self.assertTrue(writer.closed) - - async def test_rejects_ipv6_destinations(self): - reader = _FakeReader(b"\x05\x01\x00" + _ipv6_connect_request("2001:db8::1", 443)) - writer = _FakeWriter() - - await _handle_client(reader, writer) - - self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x05)]) - self.assertTrue(writer.closed) - - async def test_passthrough_connect_failure_returns_error(self): - reader = _FakeReader(b"\x05\x01\x00" + _domain_connect_request("example.com", 443)) - writer = _FakeWriter() - - with patch("proxy.tg_ws_proxy.asyncio.open_connection", side_effect=OSError("boom")): - await _handle_client(reader, writer) - - self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x05)]) - self.assertTrue(writer.closed) + self.assertEqual(len(relay_init), 64) + self.assertEqual(relay_init[0], relay_init[0] & 0xFF) if __name__ == "__main__": diff --git a/tests/test_update_check.py b/tests/test_update_check.py index bee30b9..fe3f23c 100644 --- a/tests/test_update_check.py +++ b/tests/test_update_check.py @@ -13,8 +13,11 @@ class UpdateCheckTests(unittest.TestCase): update_check._state.update(self._orig_state) def test_apply_release_tag_marks_update_available(self): + version_parts = [int(part) for part in __version__.split(".")] + version_parts[-1] += 1 + next_version = ".".join(str(part) for part in version_parts) update_check._apply_release_tag( - tag="v1.3.1", + tag=f"v{next_version}", html_url="https://example.com/release", current_version=__version__, ) @@ -22,7 +25,7 @@ class UpdateCheckTests(unittest.TestCase): status = update_check.get_status() self.assertTrue(status["has_update"]) self.assertFalse(status["ahead_of_release"]) - self.assertEqual(status["latest"], "1.3.1") + self.assertEqual(status["latest"], next_version) self.assertEqual(status["html_url"], "https://example.com/release") def test_apply_release_tag_marks_ahead_of_release(self):