feat(runtime): adapt android_migration shell to upstream mtproto core

This commit is contained in:
Dark_Avery 2026-03-30 16:14:42 +03:00
parent 9e2c8c16ff
commit 1599b1126c
6 changed files with 179 additions and 140 deletions

View File

@ -4,6 +4,7 @@ import asyncio as _asyncio
import json
import logging
import logging.handlers
import os
import sys
import threading
import time
@ -14,8 +15,9 @@ import proxy.tg_ws_proxy as tg_ws_proxy
DEFAULT_CONFIG = {
"port": 1080,
"port": 1443,
"host": "127.0.0.1",
"secret": os.urandom(16).hex(),
"dc_ip": ["2:149.154.167.220", "4:149.154.167.220"],
"log_max_mb": 5,
"buf_kb": 256,
@ -48,6 +50,27 @@ class ProxyAppRuntime:
self._proxy_thread = None
self._async_stop = None
def _build_core_config(self, active_cfg: dict, dc_opt: Dict[int, str]):
port = int(active_cfg.get("port", self.default_config["port"]))
host = str(active_cfg.get("host", self.default_config["host"]))
secret = str(active_cfg.get("secret") or "").strip()
if not secret:
secret = os.urandom(16).hex()
active_cfg["secret"] = secret
buf_kb = int(active_cfg.get("buf_kb", self.default_config["buf_kb"]))
pool_size = int(active_cfg.get(
"pool_size", self.default_config["pool_size"]))
return tg_ws_proxy.ProxyConfig(
port=port,
host=host,
secret=secret,
dc_redirects=dc_opt,
buffer_size=max(4, buf_kb) * 1024,
pool_size=max(0, pool_size),
)
def ensure_dirs(self):
self.app_dir.mkdir(parents=True, exist_ok=True)
@ -132,8 +155,7 @@ class ProxyAppRuntime:
self._async_stop = (loop, stop_ev)
try:
loop.run_until_complete(
self.run_proxy(port, dc_opt, stop_event=stop_ev, host=host))
loop.run_until_complete(self.run_proxy(stop_event=stop_ev))
except Exception as exc:
self.log.error("Proxy thread crashed: %s", exc)
if ("10048" in str(exc) or
@ -143,6 +165,8 @@ class ProxyAppRuntime:
"Порт уже используется другим приложением.\n\n"
"Закройте приложение, использующее этот порт, "
"или измените порт в настройках прокси и перезапустите.")
else:
self._emit_error(str(exc) or exc.__class__.__name__)
finally:
loop.close()
self._async_stop = None
@ -168,6 +192,9 @@ class ProxyAppRuntime:
self._emit_error("Ошибка конфигурации:\n%s" % exc)
return False
tg_ws_proxy.proxy_config = self._build_core_config(active_cfg, dc_opt)
self.save_config(active_cfg)
self.log.info("Starting proxy on %s:%d ...", host, port)
tg_ws_proxy._RECV_BUF = max(4, buf_kb) * 1024
tg_ws_proxy._SEND_BUF = tg_ws_proxy._RECV_BUF

View File

@ -493,6 +493,7 @@ class Stats:
self.bytes_down = 0
self.pool_hits = 0
self.pool_misses = 0
self.last_transport_route: Optional[str] = None
def summary(self) -> str:
pool_total = self.pool_hits + self.pool_misses
@ -511,6 +512,27 @@ class Stats:
_stats = Stats()
def reset_stats() -> None:
global _stats
_stats = Stats()
def get_stats_snapshot() -> Dict[str, object]:
return {
"connections_total": _stats.connections_total,
"connections_active": _stats.connections_active,
"connections_ws": _stats.connections_ws,
"connections_tcp_fallback": _stats.connections_tcp_fallback,
"connections_bad": _stats.connections_bad,
"ws_errors": _stats.ws_errors,
"bytes_up": _stats.bytes_up,
"bytes_down": _stats.bytes_down,
"pool_hits": _stats.pool_hits,
"pool_misses": _stats.pool_misses,
"last_transport_route": _stats.last_transport_route,
}
class _WsPool:
WS_POOL_MAX_AGE = 120.0
@ -769,6 +791,7 @@ async def _tcp_fallback(reader, writer, dst, port, relay_init, label,
return False
_stats.connections_tcp_fallback += 1
_stats.last_transport_route = "tcp_fallback"
rw.write(relay_init)
await rw.drain()
await _bridge_tcp_reencrypt(reader, writer, rr, rw, label,
@ -965,6 +988,7 @@ async def _handle_client(reader, writer, secret: bytes):
dc_fail_until.pop(dc_key, None)
_stats.connections_ws += 1
_stats.last_transport_route = "telegram_ws_direct"
splitter = None
try:

View File

@ -116,6 +116,23 @@ class ProxyAppRuntimeTests(unittest.TestCase):
self.assertFalse(started)
self.assertEqual(errors, ["Ошибка конфигурации:\nbad dc mapping"])
def test_run_proxy_thread_reports_generic_runtime_error(self):
with tempfile.TemporaryDirectory() as tmpdir:
errors = []
async def fake_run_proxy(stop_event=None):
raise RuntimeError("proxy boom")
runtime = ProxyAppRuntime(
Path(tmpdir),
on_error=errors.append,
run_proxy=fake_run_proxy,
)
runtime._run_proxy_thread(1443, {2: "149.154.167.220"}, "127.0.0.1")
self.assertEqual(errors, ["proxy boom"])
if __name__ == "__main__":
unittest.main()

View File

@ -1,38 +1,55 @@
import hashlib
import struct
import unittest
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from proxy.crypto_backend import create_aes_ctr_transform
from proxy.tg_ws_proxy import _MsgSplitter, _dc_from_init, _patch_init_dc
from proxy.tg_ws_proxy import (
PROTO_ABRIDGED_INT,
PROTO_TAG_ABRIDGED,
_MsgSplitter,
_generate_relay_init,
_try_handshake,
)
KEY = bytes(range(32))
IV = bytes(range(16))
PROTO_TAG = 0xEFEFEFEF
SECRET = bytes.fromhex("0123456789abcdef0123456789abcdef")
def _xor(left: bytes, right: bytes) -> bytes:
return bytes(a ^ b for a, b in zip(left, right))
def _keystream(size: int) -> bytes:
transform = create_aes_ctr_transform(KEY, IV)
return transform.update(b"\x00" * size) + transform.finalize()
def _keystream(size: int, key: bytes, iv: bytes) -> bytes:
transform = Cipher(algorithms.AES(key), modes.CTR(iv)).encryptor()
return transform.update(b"\x00" * size)
def _build_init_packet(dc_raw: int, proto: int = PROTO_TAG) -> bytes:
def _build_client_handshake(
dc_raw: int,
proto_tag: bytes = PROTO_TAG_ABRIDGED,
secret: bytes = SECRET,
) -> bytes:
packet = bytearray(64)
packet[8:40] = KEY
packet[40:56] = IV
plain_tail = struct.pack("<Ih", proto, dc_raw) + b"\x00\x00"
packet[56:64] = _xor(plain_tail, _keystream(64)[56:64])
dec_key = hashlib.sha256(KEY + secret).digest()
plain_tail = proto_tag + struct.pack("<h", dc_raw) + b"\x00\x00"
packet[56:64] = _xor(plain_tail, _keystream(64, dec_key, IV)[56:64])
return bytes(packet)
def _encrypt_after_init(init_packet: bytes, plaintext: bytes) -> bytes:
transform = create_aes_ctr_transform(init_packet[8:40], init_packet[40:56])
def _encrypt_after_init(relay_init: bytes, plaintext: bytes) -> bytes:
transform = Cipher(
algorithms.AES(relay_init[8:40]),
modes.CTR(relay_init[40:56]),
).encryptor()
transform.update(b"\x00" * 64)
return transform.update(plaintext) + transform.finalize()
return transform.update(plaintext)
class CryptoBackendTests(unittest.TestCase):
@ -63,42 +80,60 @@ class CryptoBackendTests(unittest.TestCase):
create_aes_ctr_transform(KEY, IV, backend="missing")
class MtProtoInitTests(unittest.TestCase):
def test_dc_from_init_reads_non_media_dc(self):
init_packet = _build_init_packet(dc_raw=2)
class MtProtoHandshakeTests(unittest.TestCase):
def test_try_handshake_reads_non_media_dc(self):
handshake = _build_client_handshake(dc_raw=2)
self.assertEqual(_dc_from_init(init_packet), (2, False))
result = _try_handshake(handshake, SECRET)
def test_dc_from_init_reads_media_dc(self):
init_packet = _build_init_packet(dc_raw=-4)
self.assertEqual(result[:3], (2, False, PROTO_TAG_ABRIDGED))
self.assertEqual(_dc_from_init(init_packet), (4, True))
def test_try_handshake_reads_media_dc(self):
handshake = _build_client_handshake(dc_raw=-4)
def test_patch_init_dc_updates_signed_dc_and_preserves_tail(self):
original = _build_init_packet(dc_raw=99) + b"tail"
result = _try_handshake(handshake, SECRET)
patched = _patch_init_dc(original, -3)
self.assertEqual(result[:3], (4, True, PROTO_TAG_ABRIDGED))
self.assertEqual(_dc_from_init(patched[:64]), (3, True))
self.assertEqual(patched[64:], b"tail")
def test_try_handshake_rejects_wrong_secret(self):
handshake = _build_client_handshake(dc_raw=2)
result = _try_handshake(
handshake,
bytes.fromhex("fedcba9876543210fedcba9876543210"),
)
self.assertIsNone(result)
def test_generate_relay_init_encodes_proto_and_signed_dc(self):
relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, -3)
decryptor = Cipher(
algorithms.AES(relay_init[8:40]),
modes.CTR(relay_init[40:56]),
).encryptor()
decrypted = decryptor.update(relay_init)
self.assertEqual(decrypted[56:60], PROTO_TAG_ABRIDGED)
self.assertEqual(struct.unpack("<h", decrypted[60:62])[0], -3)
class MsgSplitterTests(unittest.TestCase):
def test_splitter_splits_multiple_abridged_messages(self):
init_packet = _build_init_packet(dc_raw=-2)
relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, -2)
plain_chunk = b"\x01abcd\x02EFGH1234"
encrypted_chunk = _encrypt_after_init(init_packet, plain_chunk)
encrypted_chunk = _encrypt_after_init(relay_init, plain_chunk)
parts = _MsgSplitter(init_packet).split(encrypted_chunk)
parts = _MsgSplitter(relay_init, PROTO_ABRIDGED_INT).split(encrypted_chunk)
self.assertEqual(parts, [encrypted_chunk[:5], encrypted_chunk[5:14]])
def test_splitter_leaves_single_message_intact(self):
init_packet = _build_init_packet(dc_raw=2)
relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, 2)
plain_chunk = b"\x02abcdefgh"
encrypted_chunk = _encrypt_after_init(init_packet, plain_chunk)
encrypted_chunk = _encrypt_after_init(relay_init, plain_chunk)
parts = _MsgSplitter(init_packet).split(encrypted_chunk)
parts = _MsgSplitter(relay_init, PROTO_ABRIDGED_INT).split(encrypted_chunk)
self.assertEqual(parts, [encrypted_chunk])

View File

@ -1,128 +1,61 @@
import asyncio
import socket
import hashlib
import struct
import unittest
from unittest.mock import patch
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from proxy.tg_ws_proxy import _handle_client, _socks5_reply
from proxy.tg_ws_proxy import (
PROTO_TAG_ABRIDGED,
PROTO_TAG_INTERMEDIATE,
_generate_relay_init,
_try_handshake,
)
class _FakeTransport:
def get_extra_info(self, name):
return None
def get_write_buffer_size(self):
return 0
KEY = bytes(range(32))
IV = bytes(range(16))
SECRET = bytes.fromhex("0123456789abcdef0123456789abcdef")
class _FakeReader:
def __init__(self, payload: bytes):
self._payload = payload
self._offset = 0
async def readexactly(self, n: int) -> bytes:
end = self._offset + n
if end > len(self._payload):
partial = self._payload[self._offset:]
self._offset = len(self._payload)
raise asyncio.IncompleteReadError(partial, n)
chunk = self._payload[self._offset:end]
self._offset = end
return chunk
def _xor(left: bytes, right: bytes) -> bytes:
return bytes(a ^ b for a, b in zip(left, right))
class _FakeWriter:
def __init__(self):
self.transport = _FakeTransport()
self.writes = []
self.closed = False
self.close_calls = 0
def _build_client_handshake(dc_raw: int, proto_tag: bytes) -> bytes:
packet = bytearray(64)
packet[8:40] = KEY
packet[40:56] = IV
def get_extra_info(self, name):
if name == "peername":
return ("127.0.0.1", 50000)
return None
dec_key = hashlib.sha256(KEY + SECRET).digest()
decryptor = Cipher(algorithms.AES(dec_key), modes.CTR(IV)).encryptor()
keystream = decryptor.update(b"\x00" * 64)
def write(self, data: bytes):
self.writes.append(data)
async def drain(self):
return None
def close(self):
self.closed = True
self.close_calls += 1
async def wait_closed(self):
return None
plain_tail = proto_tag + struct.pack("<h", dc_raw) + b"\x00\x00"
packet[56:64] = _xor(plain_tail, keystream[56:64])
return bytes(packet)
def _ipv4_connect_request(ip: str, port: int, cmd: int = 1) -> bytes:
return bytes([0x05, cmd, 0x00, 0x01]) + socket.inet_aton(ip) + port.to_bytes(2, "big")
class MtProtoProtocolTests(unittest.TestCase):
def test_try_handshake_accepts_abridged_proto(self):
handshake = _build_client_handshake(2, PROTO_TAG_ABRIDGED)
result = _try_handshake(handshake, SECRET)
def _domain_connect_request(domain: str, port: int, cmd: int = 1) -> bytes:
encoded = domain.encode("utf-8")
return (
bytes([0x05, cmd, 0x00, 0x03, len(encoded)])
+ encoded
+ port.to_bytes(2, "big")
)
self.assertIsNotNone(result)
self.assertEqual(result[:3], (2, False, PROTO_TAG_ABRIDGED))
def test_try_handshake_accepts_intermediate_proto(self):
handshake = _build_client_handshake(-4, PROTO_TAG_INTERMEDIATE)
def _ipv6_connect_request(ip: str, port: int) -> bytes:
return (
bytes([0x05, 0x01, 0x00, 0x04])
+ socket.inet_pton(socket.AF_INET6, ip)
+ port.to_bytes(2, "big")
)
result = _try_handshake(handshake, SECRET)
self.assertIsNotNone(result)
self.assertEqual(result[:3], (4, True, PROTO_TAG_INTERMEDIATE))
class Socks5ProtocolTests(unittest.IsolatedAsyncioTestCase):
async def test_rejects_non_socks5_greeting(self):
reader = _FakeReader(b"\x04\x01")
writer = _FakeWriter()
def test_generate_relay_init_produces_handshake_sized_packet(self):
relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, -2)
await _handle_client(reader, writer)
self.assertEqual(writer.writes, [])
self.assertTrue(writer.closed)
async def test_rejects_unsupported_command(self):
reader = _FakeReader(b"\x05\x01\x00" + _ipv4_connect_request("1.1.1.1", 443, cmd=2))
writer = _FakeWriter()
await _handle_client(reader, writer)
self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x07)])
self.assertTrue(writer.closed)
async def test_rejects_unsupported_address_type(self):
reader = _FakeReader(b"\x05\x01\x00" + b"\x05\x01\x00\x02")
writer = _FakeWriter()
await _handle_client(reader, writer)
self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x08)])
self.assertTrue(writer.closed)
async def test_rejects_ipv6_destinations(self):
reader = _FakeReader(b"\x05\x01\x00" + _ipv6_connect_request("2001:db8::1", 443))
writer = _FakeWriter()
await _handle_client(reader, writer)
self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x05)])
self.assertTrue(writer.closed)
async def test_passthrough_connect_failure_returns_error(self):
reader = _FakeReader(b"\x05\x01\x00" + _domain_connect_request("example.com", 443))
writer = _FakeWriter()
with patch("proxy.tg_ws_proxy.asyncio.open_connection", side_effect=OSError("boom")):
await _handle_client(reader, writer)
self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x05)])
self.assertTrue(writer.closed)
self.assertEqual(len(relay_init), 64)
self.assertEqual(relay_init[0], relay_init[0] & 0xFF)
if __name__ == "__main__":

View File

@ -13,8 +13,11 @@ class UpdateCheckTests(unittest.TestCase):
update_check._state.update(self._orig_state)
def test_apply_release_tag_marks_update_available(self):
version_parts = [int(part) for part in __version__.split(".")]
version_parts[-1] += 1
next_version = ".".join(str(part) for part in version_parts)
update_check._apply_release_tag(
tag="v1.3.1",
tag=f"v{next_version}",
html_url="https://example.com/release",
current_version=__version__,
)
@ -22,7 +25,7 @@ class UpdateCheckTests(unittest.TestCase):
status = update_check.get_status()
self.assertTrue(status["has_update"])
self.assertFalse(status["ahead_of_release"])
self.assertEqual(status["latest"], "1.3.1")
self.assertEqual(status["latest"], next_version)
self.assertEqual(status["html_url"], "https://example.com/release")
def test_apply_release_tag_marks_ahead_of_release(self):