diff --git a/proxy/balancer.py b/proxy/balancer.py index 977433a..f658b83 100644 --- a/proxy/balancer.py +++ b/proxy/balancer.py @@ -29,7 +29,8 @@ class _Balancer: def get_domains_for_dc(self, dc_id: int) -> Iterator[str]: current_domain = self._dc_to_domain.get(dc_id) - yield current_domain + if current_domain is not None: + yield current_domain shuffled_domains = self.domains[:] random.shuffle(shuffled_domains) diff --git a/proxy/bridge.py b/proxy/bridge.py index c2ce787..160278e 100644 --- a/proxy/bridge.py +++ b/proxy/bridge.py @@ -127,7 +127,7 @@ class MsgSplitter: async def do_fallback(reader, writer, relay_init, label, - dc, is_media, media_tag, + dc: int, is_media: bool, media_tag: str, ctx: CryptoCtx, splitter=None): fallback_dst = DC_DEFAULT_IPS.get(dc) use_cf = proxy_config.fallback_cfproxy @@ -141,9 +141,9 @@ async def do_fallback(reader, writer, relay_init, label, for method in methods: if method == 'cf': ok = await _cfproxy_fallback( - reader, writer, relay_init, label, + reader, writer, relay_init, label, ctx, dc=dc, is_media=is_media, - ctx=ctx, splitter=splitter) + splitter=splitter) if ok: return True elif method == 'tcp' and fallback_dst: @@ -151,15 +151,16 @@ async def do_fallback(reader, writer, relay_init, label, label, dc, media_tag, fallback_dst) ok = await _tcp_fallback( reader, writer, fallback_dst, 443, - relay_init, label, dc=dc, is_media=is_media, ctx=ctx) + relay_init, label, ctx) if ok: return True return False async def _cfproxy_fallback(reader, writer, relay_init, label, - dc=None, is_media=False, - ctx: CryptoCtx = None, splitter=None): + ctx: CryptoCtx, + dc: int, is_media: bool, + splitter=None): media_tag = ' media' if is_media else '' ws = None chosen_domain = None @@ -185,14 +186,13 @@ async def _cfproxy_fallback(reader, writer, relay_init, label, stats.connections_cfproxy += 1 await ws.send(relay_init) - await bridge_ws_reencrypt(reader, writer, ws, label, + await bridge_ws_reencrypt(reader, writer, ws, label, ctx, dc=dc, is_media=is_media, - ctx=ctx, splitter=splitter) + splitter=splitter) return True -async def _tcp_fallback(reader, writer, dst, port, relay_init, label, - dc=None, is_media=False, ctx: CryptoCtx = None): +async def _tcp_fallback(reader, writer, dst, port, relay_init, label, ctx: CryptoCtx): try: rr, rw = await asyncio.wait_for( asyncio.open_connection(dst, port), timeout=10) @@ -204,15 +204,14 @@ async def _tcp_fallback(reader, writer, dst, port, relay_init, label, stats.connections_tcp_fallback += 1 rw.write(relay_init) await rw.drain() - await _bridge_tcp_reencrypt(reader, writer, rr, rw, label, - dc=dc, is_media=is_media, ctx=ctx) + await _bridge_tcp_reencrypt(reader, writer, rr, rw, label, ctx) return True async def bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, + ctx: CryptoCtx, dc=None, is_media=False, - ctx: CryptoCtx = None, - splitter: MsgSplitter = None): + splitter: Optional[MsgSplitter] = None): """ Bidirectional TCP(client) <-> WS(telegram) with re-encryption. client ciphertext → decrypt(clt_key) → encrypt(tg_key) → WS @@ -309,8 +308,7 @@ async def bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, async def _bridge_tcp_reencrypt(reader, writer, remote_reader, remote_writer, - label, dc=None, is_media=False, - ctx: CryptoCtx = None): + label, ctx: CryptoCtx): """Bidirectional TCP <-> TCP with re-encryption.""" async def forward(src, dst_w, is_up): diff --git a/proxy/raw_websocket.py b/proxy/raw_websocket.py index ce79213..1f0bf5e 100644 --- a/proxy/raw_websocket.py +++ b/proxy/raw_websocket.py @@ -25,7 +25,7 @@ _ssl_ctx.verify_mode = ssl.CERT_NONE class WsHandshakeError(Exception): def __init__(self, status_code: int, status_line: str, - headers: dict = None, location: str = None): + headers: Optional[dict] = None, location: Optional[str] = None): self.status_code = status_code self.status_line = status_line self.headers = headers or {} diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index 8039546..c8025eb 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -36,7 +36,7 @@ log = logging.getLogger('tg-mtproto-proxy') DC_FAIL_COOLDOWN = 30.0 WS_FAIL_TIMEOUT = 2.0 ws_blacklist: Set[str] = set() -dc_fail_until: Dict[Tuple[int, bool], float] = {} +dc_fail_until: Dict[str, float] = {} def _try_handshake(handshake: bytes, secret: bytes) -> Optional[Tuple[int, bool, bytes, bytes]]: @@ -191,7 +191,7 @@ class _WsPool: except Exception: pass - async def warmup(self, dc_redirects: Dict[int, Optional[str]]): + async def warmup(self, dc_redirects: Dict[int, str]): for dc, target_ip in dc_redirects.items(): if target_ip is None: continue @@ -207,6 +207,146 @@ class _WsPool: _ws_pool = _WsPool() +async def _read_client_init(reader, writer, secret, label, masking): + if proxy_config.proxy_protocol: + try: + pp_line = await asyncio.wait_for( + reader.readline(), timeout=10) + except asyncio.IncompleteReadError: + log.debug("[%s] disconnected during PROXY header", label) + return None + pp_text = pp_line.decode('ascii', errors='replace').strip() + if pp_text.startswith('PROXY '): + parts = pp_text.split() + if len(parts) >= 6: + label = f"{parts[2]}:{parts[4]}" + log.debug("[%s] PROXY protocol: %s", label, pp_text) + else: + log.debug("[%s] expected PROXY header, got: %r", label, + pp_text[:60]) + + try: + first_byte = await asyncio.wait_for( + reader.readexactly(1), timeout=10) + except asyncio.IncompleteReadError: + log.debug("[%s] client disconnected before handshake", label) + return None + + if first_byte[0] == TLS_RECORD_HANDSHAKE and masking: + try: + hdr_rest = await asyncio.wait_for( + reader.readexactly(4), timeout=10) + except asyncio.IncompleteReadError: + log.debug("[%s] incomplete TLS record header", label) + return None + + tls_header = first_byte + hdr_rest + record_len = struct.unpack('>H', tls_header[3:5])[0] + + try: + record_body = await asyncio.wait_for( + reader.readexactly(record_len), timeout=10) + except asyncio.IncompleteReadError: + log.debug("[%s] incomplete TLS record body", label) + return None + + client_hello = tls_header + record_body + + tls_result = verify_client_hello(client_hello, secret) + + if tls_result is None: + log.debug("[%s] Fake TLS verify failed (size=%d rec=%d) " + "-> masking", + label, len(client_hello), record_len) + await proxy_to_masking_domain( + reader, writer, client_hello, masking, label) + return None + + client_random, session_id, ts = tls_result + log.debug("[%s] Fake TLS handshake ok (ts=%d)", label, ts) + + server_hello = build_server_hello(secret, client_random, session_id) + writer.write(server_hello) + await writer.drain() + + tls_stream = FakeTlsStream(reader, writer) + + try: + handshake = await asyncio.wait_for( + tls_stream.readexactly(HANDSHAKE_LEN), timeout=10) + except asyncio.IncompleteReadError: + log.debug("[%s] incomplete obfs2 init inside TLS", label) + return None + + return handshake, tls_stream, tls_stream, label + + elif masking: + log.debug("[%s] non-TLS byte 0x%02X -> HTTP redirect", label, + first_byte[0]) + redirect = ( + f"HTTP/1.1 301 Moved Permanently\r\n" + f"Location: https://{masking}/\r\n" + f"Content-Length: 0\r\n" + f"Connection: close\r\n\r\n" + ).encode() + writer.write(redirect) + await writer.drain() + return None + + else: + try: + rest = await asyncio.wait_for( + reader.readexactly(HANDSHAKE_LEN - 1), timeout=10) + except asyncio.IncompleteReadError: + log.debug("[%s] client disconnected before handshake", label) + return None + return first_byte + rest, reader, writer, label + + +def _build_crypto_ctx(client_dec_prekey_iv, secret, relay_init): + # key = SHA256(prekey + secret), iv from handshake + # "dec" = decrypt data from client; "enc" = encrypt data to client + clt_dec_prekey = client_dec_prekey_iv[:PREKEY_LEN] + clt_dec_iv = client_dec_prekey_iv[PREKEY_LEN:] + clt_dec_key = hashlib.sha256(clt_dec_prekey + secret).digest() + + clt_enc_prekey_iv = client_dec_prekey_iv[::-1] + clt_enc_key = hashlib.sha256( + clt_enc_prekey_iv[:PREKEY_LEN] + secret).digest() + clt_enc_iv = clt_enc_prekey_iv[PREKEY_LEN:] + + clt_decryptor = Cipher( + algorithms.AES(clt_dec_key), modes.CTR(clt_dec_iv) + ).encryptor() + clt_encryptor = Cipher( + algorithms.AES(clt_enc_key), modes.CTR(clt_enc_iv) + ).encryptor() + + # fast-forward client decryptor past the 64-byte init + clt_decryptor.update(ZERO_64) + + # relay side: standard obfuscation (no secret hash, raw key) + relay_enc_key = relay_init[SKIP_LEN:SKIP_LEN + PREKEY_LEN] + relay_enc_iv = relay_init[SKIP_LEN + PREKEY_LEN: + SKIP_LEN + PREKEY_LEN + IV_LEN] + + relay_dec_prekey_iv = relay_init[SKIP_LEN: + SKIP_LEN + PREKEY_LEN + IV_LEN][::-1] + relay_dec_key = relay_dec_prekey_iv[:KEY_LEN] + relay_dec_iv = relay_dec_prekey_iv[KEY_LEN:] + + tg_encryptor = Cipher( + algorithms.AES(relay_enc_key), modes.CTR(relay_enc_iv) + ).encryptor() + tg_decryptor = Cipher( + algorithms.AES(relay_dec_key), modes.CTR(relay_dec_iv) + ).encryptor() + + tg_encryptor.update(ZERO_64) + + return CryptoCtx(clt_decryptor, clt_encryptor, tg_encryptor, tg_decryptor) + + async def _handle_client(reader, writer, secret: bytes): stats.connections_total += 1 stats.connections_active += 1 @@ -215,115 +355,25 @@ async def _handle_client(reader, writer, secret: bytes): set_sock_opts(writer.transport, proxy_config.buffer_size) - tls_stream = None - masking = proxy_config.fake_tls_domain - try: - if proxy_config.proxy_protocol: - try: - pp_line = await asyncio.wait_for( - reader.readline(), timeout=10) - except asyncio.IncompleteReadError: - log.debug("[%s] disconnected during PROXY header", label) - return - pp_text = pp_line.decode('ascii', errors='replace').strip() - if pp_text.startswith('PROXY '): - parts = pp_text.split() - if len(parts) >= 6: - label = f"{parts[2]}:{parts[4]}" - log.debug("[%s] PROXY protocol: %s", label, pp_text) - else: - log.debug("[%s] expected PROXY header, got: %r", label, - pp_text[:60]) - - try: - first_byte = await asyncio.wait_for( - reader.readexactly(1), timeout=10) - except asyncio.IncompleteReadError: - log.debug("[%s] client disconnected before handshake", label) + init = await _read_client_init( + reader, writer, secret, label, proxy_config.fake_tls_domain) + if init is None: return - if first_byte[0] == TLS_RECORD_HANDSHAKE and masking: - try: - hdr_rest = await asyncio.wait_for( - reader.readexactly(4), timeout=10) - except asyncio.IncompleteReadError: - log.debug("[%s] incomplete TLS record header", label) - return - - tls_header = first_byte + hdr_rest - record_len = struct.unpack('>H', tls_header[3:5])[0] - - try: - record_body = await asyncio.wait_for( - reader.readexactly(record_len), timeout=10) - except asyncio.IncompleteReadError: - log.debug("[%s] incomplete TLS record body", label) - return - - client_hello = tls_header + record_body - - tls_result = verify_client_hello(client_hello, secret) - - if tls_result is None: - log.debug("[%s] Fake TLS verify failed (size=%d rec=%d) " - "-> masking", - label, len(client_hello), record_len) - await proxy_to_masking_domain( - reader, writer, client_hello, masking, label) - return - - client_random, session_id, ts = tls_result - log.debug("[%s] Fake TLS handshake ok (ts=%d)", label, ts) - - server_hello = build_server_hello(secret, client_random, session_id) - writer.write(server_hello) - await writer.drain() - - tls_stream = FakeTlsStream(reader, writer) - - try: - handshake = await asyncio.wait_for( - tls_stream.readexactly(HANDSHAKE_LEN), timeout=10) - except asyncio.IncompleteReadError: - log.debug("[%s] incomplete obfs2 init inside TLS", label) - return - elif masking: - log.debug("[%s] non-TLS byte 0x%02X -> HTTP redirect", label, - first_byte[0]) - redirect = ( - f"HTTP/1.1 301 Moved Permanently\r\n" - f"Location: https://{masking}/\r\n" - f"Content-Length: 0\r\n" - f"Connection: close\r\n\r\n" - ).encode() - writer.write(redirect) - await writer.drain() - return - else: - try: - rest = await asyncio.wait_for( - reader.readexactly(HANDSHAKE_LEN - 1), timeout=10) - except asyncio.IncompleteReadError: - log.debug("[%s] client disconnected before handshake", label) - return - handshake = first_byte + rest + handshake, clt_reader, clt_writer, label = init result = _try_handshake(handshake, secret) if result is None: stats.connections_bad += 1 log.warning("[%s] bad handshake (wrong secret or proto)", label) try: - drain_src = tls_stream or reader - while await drain_src.read(4096): + while await clt_reader.read(4096): pass except Exception: pass return - clt_reader = tls_stream or reader - clt_writer = tls_stream or writer - dc, is_media, proto_tag, client_dec_prekey_iv = result if proto_tag == PROTO_TAG_ABRIDGED: @@ -339,48 +389,7 @@ async def _handle_client(reader, writer, secret: bytes): label, dc, ' media' if is_media else '', proto_int) relay_init = _generate_relay_init(proto_tag, dc_idx) - - # key = SHA256(prekey + secret), iv from handshake - # "dec" = decrypt data from client; "enc" = encrypt data to client - clt_dec_prekey = client_dec_prekey_iv[:PREKEY_LEN] - clt_dec_iv = client_dec_prekey_iv[PREKEY_LEN:] - clt_dec_key = hashlib.sha256(clt_dec_prekey + secret).digest() - - clt_enc_prekey_iv = client_dec_prekey_iv[::-1] - clt_enc_key = hashlib.sha256( - clt_enc_prekey_iv[:PREKEY_LEN] + secret).digest() - clt_enc_iv = clt_enc_prekey_iv[PREKEY_LEN:] - - clt_decryptor = Cipher( - algorithms.AES(clt_dec_key), modes.CTR(clt_dec_iv) - ).encryptor() - clt_encryptor = Cipher( - algorithms.AES(clt_enc_key), modes.CTR(clt_enc_iv) - ).encryptor() - - # fast-forward client decryptor past the 64-byte init - clt_decryptor.update(ZERO_64) - - # relay side: standard obfuscation (no secret hash, raw key) - relay_enc_key = relay_init[SKIP_LEN:SKIP_LEN + PREKEY_LEN] - relay_enc_iv = relay_init[SKIP_LEN + PREKEY_LEN: - SKIP_LEN + PREKEY_LEN + IV_LEN] - - relay_dec_prekey_iv = relay_init[SKIP_LEN: - SKIP_LEN + PREKEY_LEN + IV_LEN][::-1] - relay_dec_key = relay_dec_prekey_iv[:KEY_LEN] - relay_dec_iv = relay_dec_prekey_iv[KEY_LEN:] - - tg_encryptor = Cipher( - algorithms.AES(relay_enc_key), modes.CTR(relay_enc_iv) - ).encryptor() - tg_decryptor = Cipher( - algorithms.AES(relay_dec_key), modes.CTR(relay_dec_iv) - ).encryptor() - - tg_encryptor.update(ZERO_64) - - ctx = CryptoCtx(clt_decryptor, clt_encryptor, tg_encryptor, tg_decryptor) + ctx = _build_crypto_ctx(client_dec_prekey_iv, secret, relay_init) dc_key = f'{dc}{"m" if is_media else ""}' media_tag = " media" if is_media else "" @@ -490,9 +499,9 @@ async def _handle_client(reader, writer, secret: bytes): await ws.send(relay_init) - await bridge_ws_reencrypt(clt_reader, clt_writer, ws, label, + await bridge_ws_reencrypt(clt_reader, clt_writer, ws, label, ctx, dc=dc, is_media=is_media, - ctx=ctx, splitter=splitter) + splitter=splitter) except asyncio.TimeoutError: log.warning("[%s] timeout during handshake", label) diff --git a/proxy/utils.py b/proxy/utils.py index 10d5f57..3b51aa6 100644 --- a/proxy/utils.py +++ b/proxy/utils.py @@ -31,7 +31,7 @@ def human_bytes(n: int) -> str: for unit in ('B', 'KB', 'MB', 'GB'): if abs(n) < 1024: return f"{n:.1f}{unit}" - n /= 1024 + n /= 1024 # type: ignore return f"{n:.1f}TB"