From d414c73c9b294ad9a54de2e19cecd90294042aff Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sun, 14 Jun 2026 16:15:41 +0300 Subject: [PATCH] Hardened KDF-Tuple + NAT Probing + Paddings Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/maestro/me_startup.rs | 2 + src/network/probe.rs | 40 ++- src/network/stun.rs | 276 +++++++++++++++--- src/protocol/constants.rs | 27 +- src/transport/middle_proxy/handshake.rs | 95 +----- src/transport/middle_proxy/health.rs | 2 + src/transport/middle_proxy/pool.rs | 6 + src/transport/middle_proxy/pool_nat.rs | 87 +++--- .../tests/health_adversarial_tests.rs | 2 + .../tests/health_integration_tests.rs | 2 + .../tests/health_regression_tests.rs | 2 + .../tests/pool_refill_security_tests.rs | 2 + .../tests/pool_writer_security_tests.rs | 2 + .../tests/send_adversarial_tests.rs | 2 + 14 files changed, 355 insertions(+), 192 deletions(-) diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index 9dde7fa..21e1ccf 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -208,6 +208,8 @@ pub(crate) async fn initialize_me_pool( me_nat_probe, None, config.network.stun_servers.clone(), + config.network.stun_tcp_fallback, + config.network.http_ip_detect_urls.clone(), config.general.stun_nat_probe_concurrency, probe.detected_ipv6, config.timeouts.me_one_retry, diff --git a/src/network/probe.rs b/src/network/probe.rs index 90484b3..5c3cbb8 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -12,7 +12,7 @@ use tracing::{debug, info, warn}; use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType}; use crate::error::Result; use crate::network::stun::{ - DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind, + DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind_and_tcp_fallback, }; use crate::transport::UpstreamManager; @@ -58,6 +58,7 @@ impl NetworkDecision { } const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); +const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12); pub async fn run_probe( config: &NetworkConfig, @@ -81,8 +82,14 @@ pub async fn run_probe( warn!("STUN probe is enabled but network.stun_servers is empty"); DualStunResult::default() } else { - probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None) - .await + probe_stun_servers_parallel( + &servers, + stun_nat_probe_concurrency.max(1), + None, + None, + config.stun_tcp_fallback, + ) + .await } } else if nat_probe { info!("STUN probe is disabled by network.stun_use=false"); @@ -163,6 +170,7 @@ pub async fn run_probe( stun_nat_probe_concurrency.max(1), bind_v4, bind_v6, + config.stun_tcp_fallback, ) .await; if let Some(reflected) = direct_stun_res.v4.map(|r| r.reflected_addr) { @@ -234,7 +242,7 @@ pub async fn run_probe( Ok(probe) } -async fn detect_public_ipv4_http(urls: &[String]) -> Option { +pub(crate) async fn detect_public_ipv4_http(urls: &[String]) -> Option { let client = reqwest::Client::builder() .timeout(Duration::from_secs(3)) .build() @@ -277,6 +285,7 @@ async fn probe_stun_servers_parallel( concurrency: usize, bind_v4: Option, bind_v6: Option, + tcp_fallback: bool, ) -> DualStunResult { let mut join_set = JoinSet::new(); let mut next_idx = 0usize; @@ -288,9 +297,26 @@ async fn probe_stun_servers_parallel( let stun_addr = servers[next_idx].clone(); next_idx += 1; join_set.spawn(async move { - let res = timeout(STUN_BATCH_TIMEOUT, async { - let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?; - let v6 = stun_probe_family_with_bind(&stun_addr, IpFamily::V6, bind_v6).await?; + let batch_timeout = if tcp_fallback { + STUN_BATCH_TCP_FALLBACK_TIMEOUT + } else { + STUN_BATCH_TIMEOUT + }; + let res = timeout(batch_timeout, async { + let v4 = stun_probe_family_with_bind_and_tcp_fallback( + &stun_addr, + IpFamily::V4, + bind_v4, + tcp_fallback, + ) + .await?; + let v6 = stun_probe_family_with_bind_and_tcp_fallback( + &stun_addr, + IpFamily::V6, + bind_v6, + tcp_fallback, + ) + .await?; Ok::(DualStunResult { v4, v6 }) }) .await; diff --git a/src/network/stun.rs b/src/network/stun.rs index d1e088c..c2c1c86 100644 --- a/src/network/stun.rs +++ b/src/network/stun.rs @@ -4,7 +4,8 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::OnceLock; -use tokio::net::{UdpSocket, lookup_host}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpSocket, UdpSocket, lookup_host}; use tokio::time::{Duration, sleep, timeout}; use crate::crypto::SecureRandom; @@ -36,9 +37,16 @@ pub struct DualStunResult { } pub async fn stun_probe_dual(stun_addr: &str) -> Result { + stun_probe_dual_with_tcp_fallback(stun_addr, false).await +} + +pub async fn stun_probe_dual_with_tcp_fallback( + stun_addr: &str, + tcp_fallback: bool, +) -> Result { let (v4, v6) = tokio::join!( - stun_probe_family(stun_addr, IpFamily::V4), - stun_probe_family(stun_addr, IpFamily::V6), + stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V4, tcp_fallback), + stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V6, tcp_fallback), ); Ok(DualStunResult { v4: v4?, v6: v6? }) @@ -48,13 +56,44 @@ pub async fn stun_probe_family( stun_addr: &str, family: IpFamily, ) -> Result> { - stun_probe_family_with_bind(stun_addr, family, None).await + stun_probe_family_with_tcp_fallback(stun_addr, family, false).await +} + +pub async fn stun_probe_family_with_tcp_fallback( + stun_addr: &str, + family: IpFamily, + tcp_fallback: bool, +) -> Result> { + stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, None, tcp_fallback).await } pub async fn stun_probe_family_with_bind( stun_addr: &str, family: IpFamily, bind_ip: Option, +) -> Result> { + stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, bind_ip, false).await +} + +pub async fn stun_probe_family_with_bind_and_tcp_fallback( + stun_addr: &str, + family: IpFamily, + bind_ip: Option, + tcp_fallback: bool, +) -> Result> { + let udp_attempts = if tcp_fallback { 1 } else { 3 }; + let udp_result = stun_probe_family_udp(stun_addr, family, bind_ip, udp_attempts).await?; + if udp_result.is_some() || !tcp_fallback { + return Ok(udp_result); + } + stun_probe_family_tcp(stun_addr, family, bind_ip).await +} + +async fn stun_probe_family_udp( + stun_addr: &str, + family: IpFamily, + bind_ip: Option, + max_attempts: u8, ) -> Result> { let bind_addr = match (family, bind_ip) { (IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0), @@ -94,12 +133,7 @@ pub async fn stun_probe_family_with_bind( return Ok(None); } - let mut req = [0u8; 20]; - req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request - req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length - req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie - stun_rng().fill(&mut req[8..20]); // transaction ID - + let req = build_binding_request(); let mut buf = [0u8; 256]; let mut attempt = 0; let mut backoff = Duration::from_secs(1); @@ -115,7 +149,7 @@ pub async fn stun_probe_family_with_bind( Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN recv failed: {e}"))), Err(_) => { attempt += 1; - if attempt >= 3 { + if attempt >= max_attempts { return Ok(None); } sleep(backoff).await; @@ -128,19 +162,139 @@ pub async fn stun_probe_family_with_bind( return Ok(None); } - let magic = 0x2112A442u32.to_be_bytes(); let txid = &req[8..20]; - let mut idx = 20; - while idx + 4 <= n { - let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); - let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize; - idx += 4; - if idx + alen > n { - break; - } + if let Some(reflected_addr) = parse_reflected_addr(&buf[..n], txid) { + let local_addr = socket + .local_addr() + .map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?; + return Ok(Some(StunProbeResult { + local_addr, + reflected_addr, + family, + })); + } + } - match atype { - 0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => { + Ok(None) +} + +async fn stun_probe_family_tcp( + stun_addr: &str, + family: IpFamily, + bind_ip: Option, +) -> Result> { + let target_addr = match resolve_stun_addr(stun_addr, family).await? { + Some(addr) => addr, + None => return Ok(None), + }; + let socket = match family { + IpFamily::V4 => TcpSocket::new_v4(), + IpFamily::V6 => TcpSocket::new_v6(), + } + .map_err(|e| ProxyError::Proxy(format!("STUN TCP socket failed: {e}")))?; + match (family, bind_ip) { + (IpFamily::V4, Some(IpAddr::V4(ip))) => { + if socket.bind(SocketAddr::new(IpAddr::V4(ip), 0)).is_err() { + return Ok(None); + } + } + (IpFamily::V6, Some(IpAddr::V6(ip))) => { + if socket.bind(SocketAddr::new(IpAddr::V6(ip), 0)).is_err() { + return Ok(None); + } + } + (IpFamily::V4, Some(IpAddr::V6(_))) | (IpFamily::V6, Some(IpAddr::V4(_))) => { + return Ok(None); + } + (_, None) => {} + } + + let connect_res = timeout(Duration::from_secs(3), socket.connect(target_addr)).await; + let mut stream = match connect_res { + Ok(Ok(stream)) => stream, + Ok(Err(e)) + if family == IpFamily::V6 + && matches!( + e.kind(), + std::io::ErrorKind::NetworkUnreachable + | std::io::ErrorKind::HostUnreachable + | std::io::ErrorKind::Unsupported + | std::io::ErrorKind::NetworkDown + ) => + { + return Ok(None); + } + Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN TCP connect failed: {e}"))), + Err(_) => return Ok(None), + }; + + let req = build_binding_request(); + timeout(Duration::from_secs(3), stream.write_all(&req)) + .await + .map_err(|_| ProxyError::Proxy("STUN TCP send timeout".to_string()))? + .map_err(|e| ProxyError::Proxy(format!("STUN TCP send failed: {e}")))?; + + let mut header = [0u8; 20]; + timeout(Duration::from_secs(3), stream.read_exact(&mut header)) + .await + .map_err(|_| ProxyError::Proxy("STUN TCP header timeout".to_string()))? + .map_err(|e| ProxyError::Proxy(format!("STUN TCP header read failed: {e}")))?; + let body_len = u16::from_be_bytes([header[2], header[3]]) as usize; + if body_len > 236 { + return Ok(None); + } + let mut buf = [0u8; 256]; + buf[..20].copy_from_slice(&header); + if body_len > 0 { + timeout( + Duration::from_secs(3), + stream.read_exact(&mut buf[20..20 + body_len]), + ) + .await + .map_err(|_| ProxyError::Proxy("STUN TCP body timeout".to_string()))? + .map_err(|e| ProxyError::Proxy(format!("STUN TCP body read failed: {e}")))?; + } + + let txid = &req[8..20]; + let Some(reflected_addr) = parse_reflected_addr(&buf[..20 + body_len], txid) else { + return Ok(None); + }; + let local_addr = stream + .local_addr() + .map_err(|e| ProxyError::Proxy(format!("STUN TCP local_addr failed: {e}")))?; + Ok(Some(StunProbeResult { + local_addr, + reflected_addr, + family, + })) +} + +fn build_binding_request() -> [u8; 20] { + let mut req = [0u8; 20]; + req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); + req[2..4].copy_from_slice(&0u16.to_be_bytes()); + req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); + stun_rng().fill(&mut req[8..20]); + req +} + +fn parse_reflected_addr(buf: &[u8], txid: &[u8]) -> Option { + if buf.len() < 20 { + return None; + } + + let magic = 0x2112A442u32.to_be_bytes(); + let mut idx = 20; + while idx + 4 <= buf.len() { + let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().ok()?); + let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().ok()?) as usize; + idx += 4; + if idx + alen > buf.len() { + break; + } + + match atype { + 0x0020 | 0x0001 => { if alen < 8 { break; } @@ -157,7 +311,6 @@ pub async fn stun_probe_family_with_bind( let raw_ip = &buf[idx + 4..idx + 4 + len_check]; let mut port = u16::from_be_bytes(port_bytes); - let reflected_ip = if atype == 0x0020 { port ^= ((magic[0] as u16) << 8) | magic[1] as u16; match family_byte { @@ -172,7 +325,9 @@ pub async fn stun_probe_family_with_bind( } 0x02 => { let mut ip = [0u8; 16]; - let xor_key = [magic.as_slice(), txid].concat(); + let mut xor_key = [0u8; 16]; + xor_key[..4].copy_from_slice(&magic); + xor_key[4..].copy_from_slice(txid.get(..12)?); for (i, b) in raw_ip.iter().enumerate().take(16) { ip[i] = *b ^ xor_key[i]; } @@ -185,34 +340,24 @@ pub async fn stun_probe_family_with_bind( } } else { match family_byte { - 0x01 => IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])), - 0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).unwrap())), + 0x01 => { + IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])) + } + 0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).ok()?)), _ => { idx += (alen + 3) & !3; continue; } } }; - - let reflected_addr = SocketAddr::new(reflected_ip, port); - let local_addr = socket - .local_addr() - .map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?; - - return Ok(Some(StunProbeResult { - local_addr, - reflected_addr, - family, - })); + return Some(SocketAddr::new(reflected_ip, port)); } _ => {} } - idx += (alen + 3) & !3; - } + idx += (alen + 3) & !3; } - - Ok(None) + None } async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result> { @@ -245,3 +390,52 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result bool { } /// Compute Secure Intermediate payload length from wire length. -/// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary. +/// Secure mode cannot distinguish full-word padding from payload, so only the +/// non-aligned tail bytes are stripped. pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option { if wire_len < 4 { return None; @@ -245,13 +246,13 @@ pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option { } /// Generate padding length for Secure Intermediate protocol. -/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4. +/// Telegram Desktop uses a 4-bit random padding length for VersionD packets. pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { debug_assert!( is_valid_secure_payload_len(data_len), "Secure payload must be 4-byte aligned, got {data_len}" ); - rng.range(3) + 1 + rng.range(16) } // ============= Timeouts ============= @@ -424,21 +425,15 @@ mod tests { } #[test] - fn secure_padding_never_produces_aligned_total() { + fn secure_padding_matches_tdesktop_range() { let rng = SecureRandom::new(); for data_len in (0..1000).step_by(4) { for _ in 0..100 { let padding = secure_padding_len(data_len, &rng); assert!( - padding <= 3, + padding <= 15, "padding out of range: data_len={data_len}, padding={padding}" ); - assert_ne!( - (data_len + padding) % 4, - 0, - "invariant violated: data_len={data_len}, padding={padding}, total={}", - data_len + padding - ); } } } @@ -454,6 +449,16 @@ mod tests { } } + #[test] + fn secure_wire_len_preserves_full_word_tail() { + let payload_len = 64; + for padding in [4usize, 8, 12] { + let wire_len = payload_len + padding; + let recovered = secure_payload_len_from_wire_len(wire_len); + assert_eq!(recovered, Some(wire_len)); + } + } + #[test] fn secure_wire_len_rejects_too_short_frames() { assert_eq!(secure_payload_len_from_wire_len(0), None); diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 01206e2..b19dc84 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -18,7 +18,7 @@ use tokio::time::timeout; use tracing::{debug, info, warn}; use crate::config::MeSocksKdfPolicy; -use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; +use crate::crypto::{SecureRandom, derive_middleproxy_keys}; use crate::error::{ProxyError, Result}; use crate::network::IpFamily; use crate::network::probe::is_bogon; @@ -292,14 +292,15 @@ impl MePool { BndPortStatus::Error }; record_bnd_status(bnd_addr_status, bnd_port_status, raw_socks_bound_addr); - let reflected = if let Some(bound) = socks_bound_addr { + let socks_bound_kdf_addr = socks_bound_addr.filter(|bound| bound.port() != 0); + let reflected = if let Some(bound) = socks_bound_kdf_addr { Some(bound) } else if is_socks_route { match self.socks_kdf_policy() { MeSocksKdfPolicy::Strict => { self.stats.increment_me_socks_kdf_strict_reject(); return Err(ProxyError::InvalidHandshake( - "SOCKS route returned no valid BND.ADDR for ME KDF (strict policy)" + "SOCKS route returned no valid BND tuple for ME KDF (strict policy)" .to_string(), )); } @@ -323,16 +324,14 @@ impl MePool { let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); + let client_addr_for_kdf = socks_bound_kdf_addr.unwrap_or(local_addr_nat); if let Some(upstream_info) = upstream_egress { - let client_ip_for_kdf = socks_bound_addr - .map(|value| value.ip()) - .unwrap_or(local_addr_nat.ip()); record_upstream_bnd_status( upstream_info.upstream_id, bnd_addr_status, bnd_port_status, raw_socks_bound_addr, - Some(client_ip_for_kdf), + Some(client_addr_for_kdf.ip()), ); } let (mut rd, mut wr) = tokio::io::split(stream); @@ -409,6 +408,7 @@ impl MePool { info!( %local_addr, %local_addr_nat, + %client_addr_for_kdf, reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), %peer_addr, %transport_peer_addr, @@ -422,16 +422,14 @@ impl MePool { let ts_bytes = crypto_ts.to_le_bytes(); let server_port_bytes = peer_addr_nat.port().to_le_bytes(); - let socks_bound_port = socks_bound_addr - .map(|bound| bound.port()) - .filter(|port| *port != 0); - let client_port_for_kdf = socks_bound_port.unwrap_or(local_addr_nat.port()); + let socks_bound_port = socks_bound_kdf_addr.map(|bound| bound.port()); + let client_port_for_kdf = client_addr_for_kdf.port(); let client_port_source = KdfClientPortSource::from_socks_bound_port(socks_bound_port); let kdf_fingerprint = Self::kdf_material_fingerprint( - local_addr_nat.ip(), + client_addr_for_kdf.ip(), peer_addr_nat, reflected.map(|value| value.ip()), - socks_bound_addr.map(|value| value.ip()), + socks_bound_kdf_addr.map(|value| value.ip()), client_port_source, ); let previous_kdf_fingerprint = { @@ -473,7 +471,7 @@ impl MePool { let client_port_bytes = client_port_for_kdf.to_le_bytes(); let server_ip = extract_ip_material(peer_addr_nat); - let client_ip = extract_ip_material(local_addr_nat); + let client_ip = extract_ip_material(client_addr_for_kdf); let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = match (server_ip, client_ip) { @@ -494,38 +492,6 @@ impl MePool { } }; - let diag_level: u8 = std::env::var("ME_DIAG") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(0); - - let prekey_client = build_middleproxy_prekey( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"CLIENT", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - let prekey_server = build_middleproxy_prekey( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"SERVER", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - let (wk, wi) = derive_middleproxy_keys( &srv_nonce, &my_nonce, @@ -556,47 +522,14 @@ impl MePool { let requested_crc_mode = RpcChecksumMode::Crc32c; let hs_payload = build_handshake_payload( hs_our_ip, - local_addr.port(), + client_port_for_kdf, hs_peer_ip, - peer_addr.port(), + peer_addr_nat.port(), requested_crc_mode.advertised_flags(), ); let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32); - if diag_level >= 1 { - info!( - write_key = %hex_dump(&wk), - write_iv = %hex_dump(&wi), - read_key = %hex_dump(&rk), - read_iv = %hex_dump(&ri), - srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), - clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), - srv_port = %hex_dump(&server_port_bytes), - clt_port = %hex_dump(&client_port_bytes), - crypto_ts = %hex_dump(&ts_bytes), - nonce_srv = %hex_dump(&srv_nonce), - nonce_clt = %hex_dump(&my_nonce), - prekey_sha256_client = %hex_dump(&sha256(&prekey_client)), - prekey_sha256_server = %hex_dump(&sha256(&prekey_server)), - hs_plain = %hex_dump(&hs_frame), - proxy_secret_sha256 = %hex_dump(&sha256(&secret)), - "ME diag: derived keys and handshake plaintext" - ); - } - if diag_level >= 2 { - info!( - prekey_client = %hex_dump(&prekey_client), - prekey_server = %hex_dump(&prekey_server), - "ME diag: full prekey buffers" - ); - } let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; - if diag_level >= 1 { - info!( - hs_cipher = %hex_dump(&encrypted_hs), - "ME diag: handshake ciphertext" - ); - } wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; wr.flush().await.map_err(ProxyError::Io)?; diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 399fd13..0573167 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -1728,6 +1728,8 @@ mod tests { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 50861eb..077a014 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -336,6 +336,8 @@ pub(super) struct NatRuntimeCore { pub(super) nat_probe: bool, pub(super) nat_stun: Option, pub(super) nat_stun_servers: Vec, + pub(super) stun_tcp_fallback: bool, + pub(super) http_ip_detect_urls: Vec, pub(super) nat_stun_live_servers: Arc>>, pub(super) nat_probe_concurrency: usize, pub(super) detected_ipv6: Option, @@ -484,6 +486,8 @@ impl MePool { nat_probe: bool, nat_stun: Option, nat_stun_servers: Vec, + stun_tcp_fallback: bool, + http_ip_detect_urls: Vec, nat_probe_concurrency: usize, detected_ipv6: Option, me_one_retry: u8, @@ -706,6 +710,8 @@ impl MePool { nat_probe, nat_stun, nat_stun_servers, + stun_tcp_fallback, + http_ip_detect_urls, nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())), nat_probe_concurrency: nat_probe_concurrency.max(1), detected_ipv6, diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index be2d9df..28b8db8 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -1,19 +1,22 @@ use std::collections::HashMap; -use std::net::{IpAddr, Ipv4Addr}; +use std::net::IpAddr; use std::time::Duration; use tokio::task::JoinSet; use tokio::time::timeout; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; use crate::error::{ProxyError, Result}; -use crate::network::probe::is_bogon; -use crate::network::stun::{IpFamily, stun_probe_dual, stun_probe_family_with_bind}; +use crate::network::probe::{detect_public_ipv4_http, is_bogon}; +use crate::network::stun::{ + IpFamily, stun_probe_dual_with_tcp_fallback, stun_probe_family_with_bind_and_tcp_fallback, +}; use super::MePool; use std::time::Instant; const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); +const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12); #[allow(dead_code)] pub async fn stun_probe(stun_addr: Option) -> Result { @@ -28,15 +31,14 @@ pub async fn stun_probe(stun_addr: Option) -> Result Option { - fetch_public_ipv4_with_retry() + let urls = crate::config::defaults::default_http_ip_detect_urls(); + detect_public_ipv4_http(&urls) .await - .ok() - .flatten() .map(IpAddr::V4) } @@ -65,15 +67,26 @@ impl MePool { let mut live_servers = Vec::new(); let mut best_by_ip: HashMap = HashMap::new(); let concurrency = self.nat_runtime.nat_probe_concurrency.max(1); + let tcp_fallback = self.nat_runtime.stun_tcp_fallback; while next_idx < servers.len() || !join_set.is_empty() { while next_idx < servers.len() && join_set.len() < concurrency { let stun_addr = servers[next_idx].clone(); next_idx += 1; join_set.spawn(async move { + let batch_timeout = if tcp_fallback { + STUN_BATCH_TCP_FALLBACK_TIMEOUT + } else { + STUN_BATCH_TIMEOUT + }; let res = timeout( - STUN_BATCH_TIMEOUT, - stun_probe_family_with_bind(&stun_addr, family, bind_ip), + batch_timeout, + stun_probe_family_with_bind_and_tcp_fallback( + &stun_addr, + family, + bind_ip, + tcp_fallback, + ), ) .await; (stun_addr, res) @@ -193,6 +206,10 @@ impl MePool { return self.nat_runtime.nat_ip_cfg; } + if !self.nat_runtime.nat_probe { + return None; + } + if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) { return None; } @@ -201,21 +218,15 @@ impl MePool { return Some(ip); } - match fetch_public_ipv4_with_retry().await { - Ok(Some(ip)) => { - { - let mut guard = self.nat_runtime.nat_ip_detected.write().await; - *guard = Some(IpAddr::V4(ip)); - } - info!(public_ip = %ip, "Auto-detected public IP for NAT translation"); - Some(IpAddr::V4(ip)) - } - Ok(None) => None, - Err(e) => { - warn!(error = %e, "Failed to auto-detect public IP"); - None - } + let Some(ip) = detect_public_ipv4_http(&self.nat_runtime.http_ip_detect_urls).await else { + return None; + }; + { + let mut guard = self.nat_runtime.nat_ip_detected.write().await; + *guard = Some(IpAddr::V4(ip)); } + info!(public_ip = %ip, "Auto-detected public IP for NAT translation"); + Some(IpAddr::V4(ip)) } pub(super) async fn maybe_reflect_public_addr( @@ -365,31 +376,3 @@ impl MePool { None } } - -async fn fetch_public_ipv4_with_retry() -> Result> { - let providers = [ - "https://checkip.amazonaws.com", - "http://v4.ident.me", - "http://ipv4.icanhazip.com", - ]; - for url in providers { - if let Ok(Some(ip)) = fetch_public_ipv4_once(url).await { - return Ok(Some(ip)); - } - } - Ok(None) -} - -async fn fetch_public_ipv4_once(url: &str) -> Result> { - let res = reqwest::get(url) - .await - .map_err(|e| ProxyError::Proxy(format!("public IP detection request failed: {e}")))?; - - let text = res - .text() - .await - .map_err(|e| ProxyError::Proxy(format!("public IP detection read failed: {e}")))?; - - let ip = text.trim().parse().ok(); - Ok(ip) -} diff --git a/src/transport/middle_proxy/tests/health_adversarial_tests.rs b/src/transport/middle_proxy/tests/health_adversarial_tests.rs index ea88c67..fdbd5f9 100644 --- a/src/transport/middle_proxy/tests/health_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/health_adversarial_tests.rs @@ -38,6 +38,8 @@ async fn make_pool( false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/health_integration_tests.rs b/src/transport/middle_proxy/tests/health_integration_tests.rs index 9b3f93e..1dc6abf 100644 --- a/src/transport/middle_proxy/tests/health_integration_tests.rs +++ b/src/transport/middle_proxy/tests/health_integration_tests.rs @@ -36,6 +36,8 @@ async fn make_pool( false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/health_regression_tests.rs b/src/transport/middle_proxy/tests/health_regression_tests.rs index aa1f9ed..c2f5441 100644 --- a/src/transport/middle_proxy/tests/health_regression_tests.rs +++ b/src/transport/middle_proxy/tests/health_regression_tests.rs @@ -31,6 +31,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs index 6519c05..9125ad0 100644 --- a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs @@ -20,6 +20,8 @@ async fn make_pool() -> Arc { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs index 5f9f130..f0514a9 100644 --- a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs @@ -25,6 +25,8 @@ async fn make_pool() -> Arc { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/send_adversarial_tests.rs b/src/transport/middle_proxy/tests/send_adversarial_tests.rs index 4050fa1..7ed8d76 100644 --- a/src/transport/middle_proxy/tests/send_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/send_adversarial_tests.rs @@ -31,6 +31,8 @@ async fn make_pool() -> (Arc, Arc) { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12,