feat(runtime): adapt android_migration shell to upstream mtproto core
This commit is contained in:
parent
9e2c8c16ff
commit
1599b1126c
|
|
@ -4,6 +4,7 @@ import asyncio as _asyncio
|
|||
import json
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
|
@ -14,8 +15,9 @@ import proxy.tg_ws_proxy as tg_ws_proxy
|
|||
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"port": 1080,
|
||||
"port": 1443,
|
||||
"host": "127.0.0.1",
|
||||
"secret": os.urandom(16).hex(),
|
||||
"dc_ip": ["2:149.154.167.220", "4:149.154.167.220"],
|
||||
"log_max_mb": 5,
|
||||
"buf_kb": 256,
|
||||
|
|
@ -48,6 +50,27 @@ class ProxyAppRuntime:
|
|||
self._proxy_thread = None
|
||||
self._async_stop = None
|
||||
|
||||
def _build_core_config(self, active_cfg: dict, dc_opt: Dict[int, str]):
|
||||
port = int(active_cfg.get("port", self.default_config["port"]))
|
||||
host = str(active_cfg.get("host", self.default_config["host"]))
|
||||
secret = str(active_cfg.get("secret") or "").strip()
|
||||
if not secret:
|
||||
secret = os.urandom(16).hex()
|
||||
active_cfg["secret"] = secret
|
||||
|
||||
buf_kb = int(active_cfg.get("buf_kb", self.default_config["buf_kb"]))
|
||||
pool_size = int(active_cfg.get(
|
||||
"pool_size", self.default_config["pool_size"]))
|
||||
|
||||
return tg_ws_proxy.ProxyConfig(
|
||||
port=port,
|
||||
host=host,
|
||||
secret=secret,
|
||||
dc_redirects=dc_opt,
|
||||
buffer_size=max(4, buf_kb) * 1024,
|
||||
pool_size=max(0, pool_size),
|
||||
)
|
||||
|
||||
def ensure_dirs(self):
|
||||
self.app_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -132,8 +155,7 @@ class ProxyAppRuntime:
|
|||
self._async_stop = (loop, stop_ev)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
self.run_proxy(port, dc_opt, stop_event=stop_ev, host=host))
|
||||
loop.run_until_complete(self.run_proxy(stop_event=stop_ev))
|
||||
except Exception as exc:
|
||||
self.log.error("Proxy thread crashed: %s", exc)
|
||||
if ("10048" in str(exc) or
|
||||
|
|
@ -143,6 +165,8 @@ class ProxyAppRuntime:
|
|||
"Порт уже используется другим приложением.\n\n"
|
||||
"Закройте приложение, использующее этот порт, "
|
||||
"или измените порт в настройках прокси и перезапустите.")
|
||||
else:
|
||||
self._emit_error(str(exc) or exc.__class__.__name__)
|
||||
finally:
|
||||
loop.close()
|
||||
self._async_stop = None
|
||||
|
|
@ -168,6 +192,9 @@ class ProxyAppRuntime:
|
|||
self._emit_error("Ошибка конфигурации:\n%s" % exc)
|
||||
return False
|
||||
|
||||
tg_ws_proxy.proxy_config = self._build_core_config(active_cfg, dc_opt)
|
||||
self.save_config(active_cfg)
|
||||
|
||||
self.log.info("Starting proxy on %s:%d ...", host, port)
|
||||
tg_ws_proxy._RECV_BUF = max(4, buf_kb) * 1024
|
||||
tg_ws_proxy._SEND_BUF = tg_ws_proxy._RECV_BUF
|
||||
|
|
|
|||
|
|
@ -493,6 +493,7 @@ class Stats:
|
|||
self.bytes_down = 0
|
||||
self.pool_hits = 0
|
||||
self.pool_misses = 0
|
||||
self.last_transport_route: Optional[str] = None
|
||||
|
||||
def summary(self) -> str:
|
||||
pool_total = self.pool_hits + self.pool_misses
|
||||
|
|
@ -511,6 +512,27 @@ class Stats:
|
|||
_stats = Stats()
|
||||
|
||||
|
||||
def reset_stats() -> None:
|
||||
global _stats
|
||||
_stats = Stats()
|
||||
|
||||
|
||||
def get_stats_snapshot() -> Dict[str, object]:
|
||||
return {
|
||||
"connections_total": _stats.connections_total,
|
||||
"connections_active": _stats.connections_active,
|
||||
"connections_ws": _stats.connections_ws,
|
||||
"connections_tcp_fallback": _stats.connections_tcp_fallback,
|
||||
"connections_bad": _stats.connections_bad,
|
||||
"ws_errors": _stats.ws_errors,
|
||||
"bytes_up": _stats.bytes_up,
|
||||
"bytes_down": _stats.bytes_down,
|
||||
"pool_hits": _stats.pool_hits,
|
||||
"pool_misses": _stats.pool_misses,
|
||||
"last_transport_route": _stats.last_transport_route,
|
||||
}
|
||||
|
||||
|
||||
class _WsPool:
|
||||
WS_POOL_MAX_AGE = 120.0
|
||||
|
||||
|
|
@ -769,6 +791,7 @@ async def _tcp_fallback(reader, writer, dst, port, relay_init, label,
|
|||
return False
|
||||
|
||||
_stats.connections_tcp_fallback += 1
|
||||
_stats.last_transport_route = "tcp_fallback"
|
||||
rw.write(relay_init)
|
||||
await rw.drain()
|
||||
await _bridge_tcp_reencrypt(reader, writer, rr, rw, label,
|
||||
|
|
@ -965,6 +988,7 @@ async def _handle_client(reader, writer, secret: bytes):
|
|||
|
||||
dc_fail_until.pop(dc_key, None)
|
||||
_stats.connections_ws += 1
|
||||
_stats.last_transport_route = "telegram_ws_direct"
|
||||
|
||||
splitter = None
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -116,6 +116,23 @@ class ProxyAppRuntimeTests(unittest.TestCase):
|
|||
self.assertFalse(started)
|
||||
self.assertEqual(errors, ["Ошибка конфигурации:\nbad dc mapping"])
|
||||
|
||||
def test_run_proxy_thread_reports_generic_runtime_error(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
errors = []
|
||||
|
||||
async def fake_run_proxy(stop_event=None):
|
||||
raise RuntimeError("proxy boom")
|
||||
|
||||
runtime = ProxyAppRuntime(
|
||||
Path(tmpdir),
|
||||
on_error=errors.append,
|
||||
run_proxy=fake_run_proxy,
|
||||
)
|
||||
|
||||
runtime._run_proxy_thread(1443, {2: "149.154.167.220"}, "127.0.0.1")
|
||||
|
||||
self.assertEqual(errors, ["proxy boom"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,38 +1,55 @@
|
|||
import hashlib
|
||||
import struct
|
||||
import unittest
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from proxy.crypto_backend import create_aes_ctr_transform
|
||||
from proxy.tg_ws_proxy import _MsgSplitter, _dc_from_init, _patch_init_dc
|
||||
from proxy.tg_ws_proxy import (
|
||||
PROTO_ABRIDGED_INT,
|
||||
PROTO_TAG_ABRIDGED,
|
||||
_MsgSplitter,
|
||||
_generate_relay_init,
|
||||
_try_handshake,
|
||||
)
|
||||
|
||||
|
||||
KEY = bytes(range(32))
|
||||
IV = bytes(range(16))
|
||||
PROTO_TAG = 0xEFEFEFEF
|
||||
SECRET = bytes.fromhex("0123456789abcdef0123456789abcdef")
|
||||
|
||||
|
||||
def _xor(left: bytes, right: bytes) -> bytes:
|
||||
return bytes(a ^ b for a, b in zip(left, right))
|
||||
|
||||
|
||||
def _keystream(size: int) -> bytes:
|
||||
transform = create_aes_ctr_transform(KEY, IV)
|
||||
return transform.update(b"\x00" * size) + transform.finalize()
|
||||
def _keystream(size: int, key: bytes, iv: bytes) -> bytes:
|
||||
transform = Cipher(algorithms.AES(key), modes.CTR(iv)).encryptor()
|
||||
return transform.update(b"\x00" * size)
|
||||
|
||||
|
||||
def _build_init_packet(dc_raw: int, proto: int = PROTO_TAG) -> bytes:
|
||||
def _build_client_handshake(
|
||||
dc_raw: int,
|
||||
proto_tag: bytes = PROTO_TAG_ABRIDGED,
|
||||
secret: bytes = SECRET,
|
||||
) -> bytes:
|
||||
packet = bytearray(64)
|
||||
packet[8:40] = KEY
|
||||
packet[40:56] = IV
|
||||
|
||||
plain_tail = struct.pack("<Ih", proto, dc_raw) + b"\x00\x00"
|
||||
packet[56:64] = _xor(plain_tail, _keystream(64)[56:64])
|
||||
dec_key = hashlib.sha256(KEY + secret).digest()
|
||||
plain_tail = proto_tag + struct.pack("<h", dc_raw) + b"\x00\x00"
|
||||
packet[56:64] = _xor(plain_tail, _keystream(64, dec_key, IV)[56:64])
|
||||
return bytes(packet)
|
||||
|
||||
|
||||
def _encrypt_after_init(init_packet: bytes, plaintext: bytes) -> bytes:
|
||||
transform = create_aes_ctr_transform(init_packet[8:40], init_packet[40:56])
|
||||
def _encrypt_after_init(relay_init: bytes, plaintext: bytes) -> bytes:
|
||||
transform = Cipher(
|
||||
algorithms.AES(relay_init[8:40]),
|
||||
modes.CTR(relay_init[40:56]),
|
||||
).encryptor()
|
||||
transform.update(b"\x00" * 64)
|
||||
return transform.update(plaintext) + transform.finalize()
|
||||
return transform.update(plaintext)
|
||||
|
||||
|
||||
class CryptoBackendTests(unittest.TestCase):
|
||||
|
|
@ -63,42 +80,60 @@ class CryptoBackendTests(unittest.TestCase):
|
|||
create_aes_ctr_transform(KEY, IV, backend="missing")
|
||||
|
||||
|
||||
class MtProtoInitTests(unittest.TestCase):
|
||||
def test_dc_from_init_reads_non_media_dc(self):
|
||||
init_packet = _build_init_packet(dc_raw=2)
|
||||
class MtProtoHandshakeTests(unittest.TestCase):
|
||||
def test_try_handshake_reads_non_media_dc(self):
|
||||
handshake = _build_client_handshake(dc_raw=2)
|
||||
|
||||
self.assertEqual(_dc_from_init(init_packet), (2, False))
|
||||
result = _try_handshake(handshake, SECRET)
|
||||
|
||||
def test_dc_from_init_reads_media_dc(self):
|
||||
init_packet = _build_init_packet(dc_raw=-4)
|
||||
self.assertEqual(result[:3], (2, False, PROTO_TAG_ABRIDGED))
|
||||
|
||||
self.assertEqual(_dc_from_init(init_packet), (4, True))
|
||||
def test_try_handshake_reads_media_dc(self):
|
||||
handshake = _build_client_handshake(dc_raw=-4)
|
||||
|
||||
def test_patch_init_dc_updates_signed_dc_and_preserves_tail(self):
|
||||
original = _build_init_packet(dc_raw=99) + b"tail"
|
||||
result = _try_handshake(handshake, SECRET)
|
||||
|
||||
patched = _patch_init_dc(original, -3)
|
||||
self.assertEqual(result[:3], (4, True, PROTO_TAG_ABRIDGED))
|
||||
|
||||
self.assertEqual(_dc_from_init(patched[:64]), (3, True))
|
||||
self.assertEqual(patched[64:], b"tail")
|
||||
def test_try_handshake_rejects_wrong_secret(self):
|
||||
handshake = _build_client_handshake(dc_raw=2)
|
||||
|
||||
result = _try_handshake(
|
||||
handshake,
|
||||
bytes.fromhex("fedcba9876543210fedcba9876543210"),
|
||||
)
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_generate_relay_init_encodes_proto_and_signed_dc(self):
|
||||
relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, -3)
|
||||
decryptor = Cipher(
|
||||
algorithms.AES(relay_init[8:40]),
|
||||
modes.CTR(relay_init[40:56]),
|
||||
).encryptor()
|
||||
|
||||
decrypted = decryptor.update(relay_init)
|
||||
|
||||
self.assertEqual(decrypted[56:60], PROTO_TAG_ABRIDGED)
|
||||
self.assertEqual(struct.unpack("<h", decrypted[60:62])[0], -3)
|
||||
|
||||
|
||||
class MsgSplitterTests(unittest.TestCase):
|
||||
def test_splitter_splits_multiple_abridged_messages(self):
|
||||
init_packet = _build_init_packet(dc_raw=-2)
|
||||
relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, -2)
|
||||
plain_chunk = b"\x01abcd\x02EFGH1234"
|
||||
encrypted_chunk = _encrypt_after_init(init_packet, plain_chunk)
|
||||
encrypted_chunk = _encrypt_after_init(relay_init, plain_chunk)
|
||||
|
||||
parts = _MsgSplitter(init_packet).split(encrypted_chunk)
|
||||
parts = _MsgSplitter(relay_init, PROTO_ABRIDGED_INT).split(encrypted_chunk)
|
||||
|
||||
self.assertEqual(parts, [encrypted_chunk[:5], encrypted_chunk[5:14]])
|
||||
|
||||
def test_splitter_leaves_single_message_intact(self):
|
||||
init_packet = _build_init_packet(dc_raw=2)
|
||||
relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, 2)
|
||||
plain_chunk = b"\x02abcdefgh"
|
||||
encrypted_chunk = _encrypt_after_init(init_packet, plain_chunk)
|
||||
encrypted_chunk = _encrypt_after_init(relay_init, plain_chunk)
|
||||
|
||||
parts = _MsgSplitter(init_packet).split(encrypted_chunk)
|
||||
parts = _MsgSplitter(relay_init, PROTO_ABRIDGED_INT).split(encrypted_chunk)
|
||||
|
||||
self.assertEqual(parts, [encrypted_chunk])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,128 +1,61 @@
|
|||
import asyncio
|
||||
import socket
|
||||
import hashlib
|
||||
import struct
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from proxy.tg_ws_proxy import _handle_client, _socks5_reply
|
||||
from proxy.tg_ws_proxy import (
|
||||
PROTO_TAG_ABRIDGED,
|
||||
PROTO_TAG_INTERMEDIATE,
|
||||
_generate_relay_init,
|
||||
_try_handshake,
|
||||
)
|
||||
|
||||
|
||||
class _FakeTransport:
|
||||
def get_extra_info(self, name):
|
||||
return None
|
||||
|
||||
def get_write_buffer_size(self):
|
||||
return 0
|
||||
KEY = bytes(range(32))
|
||||
IV = bytes(range(16))
|
||||
SECRET = bytes.fromhex("0123456789abcdef0123456789abcdef")
|
||||
|
||||
|
||||
class _FakeReader:
|
||||
def __init__(self, payload: bytes):
|
||||
self._payload = payload
|
||||
self._offset = 0
|
||||
|
||||
async def readexactly(self, n: int) -> bytes:
|
||||
end = self._offset + n
|
||||
if end > len(self._payload):
|
||||
partial = self._payload[self._offset:]
|
||||
self._offset = len(self._payload)
|
||||
raise asyncio.IncompleteReadError(partial, n)
|
||||
chunk = self._payload[self._offset:end]
|
||||
self._offset = end
|
||||
return chunk
|
||||
def _xor(left: bytes, right: bytes) -> bytes:
|
||||
return bytes(a ^ b for a, b in zip(left, right))
|
||||
|
||||
|
||||
class _FakeWriter:
|
||||
def __init__(self):
|
||||
self.transport = _FakeTransport()
|
||||
self.writes = []
|
||||
self.closed = False
|
||||
self.close_calls = 0
|
||||
def _build_client_handshake(dc_raw: int, proto_tag: bytes) -> bytes:
|
||||
packet = bytearray(64)
|
||||
packet[8:40] = KEY
|
||||
packet[40:56] = IV
|
||||
|
||||
def get_extra_info(self, name):
|
||||
if name == "peername":
|
||||
return ("127.0.0.1", 50000)
|
||||
return None
|
||||
dec_key = hashlib.sha256(KEY + SECRET).digest()
|
||||
decryptor = Cipher(algorithms.AES(dec_key), modes.CTR(IV)).encryptor()
|
||||
keystream = decryptor.update(b"\x00" * 64)
|
||||
|
||||
def write(self, data: bytes):
|
||||
self.writes.append(data)
|
||||
|
||||
async def drain(self):
|
||||
return None
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
self.close_calls += 1
|
||||
|
||||
async def wait_closed(self):
|
||||
return None
|
||||
plain_tail = proto_tag + struct.pack("<h", dc_raw) + b"\x00\x00"
|
||||
packet[56:64] = _xor(plain_tail, keystream[56:64])
|
||||
return bytes(packet)
|
||||
|
||||
|
||||
def _ipv4_connect_request(ip: str, port: int, cmd: int = 1) -> bytes:
|
||||
return bytes([0x05, cmd, 0x00, 0x01]) + socket.inet_aton(ip) + port.to_bytes(2, "big")
|
||||
class MtProtoProtocolTests(unittest.TestCase):
|
||||
def test_try_handshake_accepts_abridged_proto(self):
|
||||
handshake = _build_client_handshake(2, PROTO_TAG_ABRIDGED)
|
||||
|
||||
result = _try_handshake(handshake, SECRET)
|
||||
|
||||
def _domain_connect_request(domain: str, port: int, cmd: int = 1) -> bytes:
|
||||
encoded = domain.encode("utf-8")
|
||||
return (
|
||||
bytes([0x05, cmd, 0x00, 0x03, len(encoded)])
|
||||
+ encoded
|
||||
+ port.to_bytes(2, "big")
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[:3], (2, False, PROTO_TAG_ABRIDGED))
|
||||
|
||||
def test_try_handshake_accepts_intermediate_proto(self):
|
||||
handshake = _build_client_handshake(-4, PROTO_TAG_INTERMEDIATE)
|
||||
|
||||
def _ipv6_connect_request(ip: str, port: int) -> bytes:
|
||||
return (
|
||||
bytes([0x05, 0x01, 0x00, 0x04])
|
||||
+ socket.inet_pton(socket.AF_INET6, ip)
|
||||
+ port.to_bytes(2, "big")
|
||||
)
|
||||
result = _try_handshake(handshake, SECRET)
|
||||
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result[:3], (4, True, PROTO_TAG_INTERMEDIATE))
|
||||
|
||||
class Socks5ProtocolTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_rejects_non_socks5_greeting(self):
|
||||
reader = _FakeReader(b"\x04\x01")
|
||||
writer = _FakeWriter()
|
||||
def test_generate_relay_init_produces_handshake_sized_packet(self):
|
||||
relay_init = _generate_relay_init(PROTO_TAG_ABRIDGED, -2)
|
||||
|
||||
await _handle_client(reader, writer)
|
||||
|
||||
self.assertEqual(writer.writes, [])
|
||||
self.assertTrue(writer.closed)
|
||||
|
||||
async def test_rejects_unsupported_command(self):
|
||||
reader = _FakeReader(b"\x05\x01\x00" + _ipv4_connect_request("1.1.1.1", 443, cmd=2))
|
||||
writer = _FakeWriter()
|
||||
|
||||
await _handle_client(reader, writer)
|
||||
|
||||
self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x07)])
|
||||
self.assertTrue(writer.closed)
|
||||
|
||||
async def test_rejects_unsupported_address_type(self):
|
||||
reader = _FakeReader(b"\x05\x01\x00" + b"\x05\x01\x00\x02")
|
||||
writer = _FakeWriter()
|
||||
|
||||
await _handle_client(reader, writer)
|
||||
|
||||
self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x08)])
|
||||
self.assertTrue(writer.closed)
|
||||
|
||||
async def test_rejects_ipv6_destinations(self):
|
||||
reader = _FakeReader(b"\x05\x01\x00" + _ipv6_connect_request("2001:db8::1", 443))
|
||||
writer = _FakeWriter()
|
||||
|
||||
await _handle_client(reader, writer)
|
||||
|
||||
self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x05)])
|
||||
self.assertTrue(writer.closed)
|
||||
|
||||
async def test_passthrough_connect_failure_returns_error(self):
|
||||
reader = _FakeReader(b"\x05\x01\x00" + _domain_connect_request("example.com", 443))
|
||||
writer = _FakeWriter()
|
||||
|
||||
with patch("proxy.tg_ws_proxy.asyncio.open_connection", side_effect=OSError("boom")):
|
||||
await _handle_client(reader, writer)
|
||||
|
||||
self.assertEqual(writer.writes, [b"\x05\x00", _socks5_reply(0x05)])
|
||||
self.assertTrue(writer.closed)
|
||||
self.assertEqual(len(relay_init), 64)
|
||||
self.assertEqual(relay_init[0], relay_init[0] & 0xFF)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -13,8 +13,11 @@ class UpdateCheckTests(unittest.TestCase):
|
|||
update_check._state.update(self._orig_state)
|
||||
|
||||
def test_apply_release_tag_marks_update_available(self):
|
||||
version_parts = [int(part) for part in __version__.split(".")]
|
||||
version_parts[-1] += 1
|
||||
next_version = ".".join(str(part) for part in version_parts)
|
||||
update_check._apply_release_tag(
|
||||
tag="v1.3.1",
|
||||
tag=f"v{next_version}",
|
||||
html_url="https://example.com/release",
|
||||
current_version=__version__,
|
||||
)
|
||||
|
|
@ -22,7 +25,7 @@ class UpdateCheckTests(unittest.TestCase):
|
|||
status = update_check.get_status()
|
||||
self.assertTrue(status["has_update"])
|
||||
self.assertFalse(status["ahead_of_release"])
|
||||
self.assertEqual(status["latest"], "1.3.1")
|
||||
self.assertEqual(status["latest"], next_version)
|
||||
self.assertEqual(status["html_url"], "https://example.com/release")
|
||||
|
||||
def test_apply_release_tag_marks_ahead_of_release(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue