From b1537825972654c4772b8bde938725ca30b391ca Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 13 Jun 2026 23:22:50 +0300 Subject: [PATCH 01/15] More efficient Relay Mode Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/proxy/masking.rs | 75 ++++++++---- .../tests/masking_relay_manual_perf_tests.rs | 111 ++++++++++++++++++ 2 files changed, 163 insertions(+), 23 deletions(-) create mode 100644 src/proxy/tests/masking_relay_manual_perf_tests.rs diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 7e73eb8..fa29529 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -5,6 +5,7 @@ use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::tls; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; +use crate::transport::socket::configure_tcp_socket; #[cfg(unix)] use nix::ifaddrs::getifaddrs; use rand::rngs::StdRng; @@ -36,6 +37,8 @@ const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200); #[cfg(test)] const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +const MASK_BUFFER_GROW_AFTER_BYTES: usize = 256 * 1024; +const MASK_BUFFER_MAX_SIZE: usize = 64 * 1024; #[cfg(unix)] #[cfg(not(test))] const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300); @@ -53,6 +56,27 @@ struct MaskTcpTarget<'a> { port: u16, } +fn mask_copy_read_len(total: usize, byte_cap: usize) -> usize { + // Keep short scanner probes on the small baseline buffer and grow only + // after the session has proven to be sustained masking relay traffic. + let active_buffer_size = if total >= MASK_BUFFER_GROW_AFTER_BYTES { + MASK_BUFFER_MAX_SIZE + } else { + MASK_BUFFER_SIZE + }; + + if byte_cap == 0 { + return active_buffer_size; + } + + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + return 0; + } + + remaining_budget.min(active_buffer_size) +} + async fn copy_with_idle_timeout( reader: &mut R, writer: &mut W, @@ -64,21 +88,18 @@ where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, { - let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); + let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut total = 0usize; let mut ended_by_eof = false; - let unlimited = byte_cap == 0; loop { - let read_len = if unlimited { - MASK_BUFFER_SIZE - } else { - let remaining_budget = byte_cap.saturating_sub(total); - if remaining_budget == 0 { - break; - } - remaining_budget.min(MASK_BUFFER_SIZE) - }; + let read_len = mask_copy_read_len(total, byte_cap); + if read_len == 0 { + break; + } + if buf.len() < read_len { + buf.resize(read_len, 0); + } let read_res = timeout(idle_timeout, reader.read(&mut buf[..read_len])).await; let n = match read_res { Ok(Ok(n)) => n, @@ -877,6 +898,12 @@ fn build_mask_proxy_header( } } +fn configure_mask_backend_socket(stream: &TcpStream) { + if let Err(e) = configure_tcp_socket(stream, false, Duration::from_secs(0)) { + debug!(error = %e, "Failed to configure mask backend socket"); + } +} + /// Handle a bad client by forwarding to mask host pub async fn handle_bad_client( reader: R, @@ -1047,6 +1074,7 @@ pub async fn handle_bad_client( let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { Ok(Ok(stream)) => { + configure_mask_backend_socket(&stream); let proxy_header = build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); @@ -1190,20 +1218,17 @@ async fn consume_client_data( idle_timeout: Duration, ) { // Keep drain path fail-closed under slow-loris stalls. - let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); + let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut total = 0usize; - let unlimited = byte_cap == 0; loop { - let read_len = if unlimited { - MASK_BUFFER_SIZE - } else { - let remaining_budget = byte_cap.saturating_sub(total); - if remaining_budget == 0 { - break; - } - remaining_budget.min(MASK_BUFFER_SIZE) - }; + let read_len = mask_copy_read_len(total, byte_cap); + if read_len == 0 { + break; + } + if buf.len() < read_len { + buf.resize(read_len, 0); + } let n = match timeout(idle_timeout, reader.read(&mut buf[..read_len])).await { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, @@ -1214,7 +1239,7 @@ async fn consume_client_data( } total = total.saturating_add(n); - if !unlimited && total >= byte_cap { + if byte_cap != 0 && total >= byte_cap { break; } } @@ -1332,6 +1357,10 @@ mod masking_interface_cache_concurrency_security_tests; #[path = "tests/masking_production_cap_regression_security_tests.rs"] mod masking_production_cap_regression_security_tests; +#[cfg(test)] +#[path = "tests/masking_relay_manual_perf_tests.rs"] +mod masking_relay_manual_perf_tests; + #[cfg(test)] #[path = "tests/masking_extended_attack_surface_security_tests.rs"] mod masking_extended_attack_surface_security_tests; diff --git a/src/proxy/tests/masking_relay_manual_perf_tests.rs b/src/proxy/tests/masking_relay_manual_perf_tests.rs new file mode 100644 index 0000000..f10bd8a --- /dev/null +++ b/src/proxy/tests/masking_relay_manual_perf_tests.rs @@ -0,0 +1,111 @@ +use super::*; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::time::{Duration, Instant}; + +const PERF_TOTAL_BYTES: usize = 64 * 1024 * 1024; + +struct PatternReader { + remaining: usize, + chunk: usize, + read_calls: AtomicUsize, +} + +impl PatternReader { + fn new(total: usize, chunk: usize) -> Self { + Self { + remaining: total, + chunk, + read_calls: AtomicUsize::new(0), + } + } + + fn read_calls(&self) -> usize { + self.read_calls.load(Ordering::Relaxed) + } +} + +impl AsyncRead for PatternReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.read_calls.fetch_add(1, Ordering::Relaxed); + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(self.chunk).min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + static PATTERN: [u8; MASK_BUFFER_MAX_SIZE] = [0xA5; MASK_BUFFER_MAX_SIZE]; + buf.put_slice(&PATTERN[..take]); + self.remaining -= take; + Poll::Ready(Ok(())) + } +} + +#[derive(Default)] +struct CountingWriter { + written: usize, +} + +impl AsyncWrite for CountingWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.written = self.written.saturating_add(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +#[ignore = "manual benchmark: throughput-sensitive and host-dependent"] +async fn masking_copy_with_idle_timeout_manual_throughput() { + let mut reader = PatternReader::new(PERF_TOTAL_BYTES, MASK_BUFFER_MAX_SIZE); + let mut writer = CountingWriter::default(); + let started = Instant::now(); + + let outcome = copy_with_idle_timeout( + &mut reader, + &mut writer, + PERF_TOTAL_BYTES, + true, + Duration::from_secs(30), + ) + .await; + + let elapsed = started.elapsed(); + let mb = PERF_TOTAL_BYTES as f64 / (1024.0 * 1024.0); + let mbps = mb / elapsed.as_secs_f64(); + + assert_eq!(outcome.total, PERF_TOTAL_BYTES); + assert_eq!(writer.written, PERF_TOTAL_BYTES); + assert!( + !outcome.ended_by_eof, + "manual throughput run should terminate at byte cap" + ); + + eprintln!( + "masking manual throughput: bytes={} elapsed_ms={} mib_per_sec={:.2} read_calls={}", + PERF_TOTAL_BYTES, + elapsed.as_millis(), + mbps, + reader.read_calls() + ); +} From d414c73c9b294ad9a54de2e19cecd90294042aff Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sun, 14 Jun 2026 16:15:41 +0300 Subject: [PATCH 02/15] Hardened KDF-Tuple + NAT Probing + Paddings Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/maestro/me_startup.rs | 2 + src/network/probe.rs | 40 ++- src/network/stun.rs | 276 +++++++++++++++--- src/protocol/constants.rs | 27 +- src/transport/middle_proxy/handshake.rs | 95 +----- src/transport/middle_proxy/health.rs | 2 + src/transport/middle_proxy/pool.rs | 6 + src/transport/middle_proxy/pool_nat.rs | 87 +++--- .../tests/health_adversarial_tests.rs | 2 + .../tests/health_integration_tests.rs | 2 + .../tests/health_regression_tests.rs | 2 + .../tests/pool_refill_security_tests.rs | 2 + .../tests/pool_writer_security_tests.rs | 2 + .../tests/send_adversarial_tests.rs | 2 + 14 files changed, 355 insertions(+), 192 deletions(-) diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index 9dde7fa..21e1ccf 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -208,6 +208,8 @@ pub(crate) async fn initialize_me_pool( me_nat_probe, None, config.network.stun_servers.clone(), + config.network.stun_tcp_fallback, + config.network.http_ip_detect_urls.clone(), config.general.stun_nat_probe_concurrency, probe.detected_ipv6, config.timeouts.me_one_retry, diff --git a/src/network/probe.rs b/src/network/probe.rs index 90484b3..5c3cbb8 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -12,7 +12,7 @@ use tracing::{debug, info, warn}; use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType}; use crate::error::Result; use crate::network::stun::{ - DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind, + DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind_and_tcp_fallback, }; use crate::transport::UpstreamManager; @@ -58,6 +58,7 @@ impl NetworkDecision { } const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); +const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12); pub async fn run_probe( config: &NetworkConfig, @@ -81,8 +82,14 @@ pub async fn run_probe( warn!("STUN probe is enabled but network.stun_servers is empty"); DualStunResult::default() } else { - probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None) - .await + probe_stun_servers_parallel( + &servers, + stun_nat_probe_concurrency.max(1), + None, + None, + config.stun_tcp_fallback, + ) + .await } } else if nat_probe { info!("STUN probe is disabled by network.stun_use=false"); @@ -163,6 +170,7 @@ pub async fn run_probe( stun_nat_probe_concurrency.max(1), bind_v4, bind_v6, + config.stun_tcp_fallback, ) .await; if let Some(reflected) = direct_stun_res.v4.map(|r| r.reflected_addr) { @@ -234,7 +242,7 @@ pub async fn run_probe( Ok(probe) } -async fn detect_public_ipv4_http(urls: &[String]) -> Option { +pub(crate) async fn detect_public_ipv4_http(urls: &[String]) -> Option { let client = reqwest::Client::builder() .timeout(Duration::from_secs(3)) .build() @@ -277,6 +285,7 @@ async fn probe_stun_servers_parallel( concurrency: usize, bind_v4: Option, bind_v6: Option, + tcp_fallback: bool, ) -> DualStunResult { let mut join_set = JoinSet::new(); let mut next_idx = 0usize; @@ -288,9 +297,26 @@ async fn probe_stun_servers_parallel( let stun_addr = servers[next_idx].clone(); next_idx += 1; join_set.spawn(async move { - let res = timeout(STUN_BATCH_TIMEOUT, async { - let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?; - let v6 = stun_probe_family_with_bind(&stun_addr, IpFamily::V6, bind_v6).await?; + let batch_timeout = if tcp_fallback { + STUN_BATCH_TCP_FALLBACK_TIMEOUT + } else { + STUN_BATCH_TIMEOUT + }; + let res = timeout(batch_timeout, async { + let v4 = stun_probe_family_with_bind_and_tcp_fallback( + &stun_addr, + IpFamily::V4, + bind_v4, + tcp_fallback, + ) + .await?; + let v6 = stun_probe_family_with_bind_and_tcp_fallback( + &stun_addr, + IpFamily::V6, + bind_v6, + tcp_fallback, + ) + .await?; Ok::(DualStunResult { v4, v6 }) }) .await; diff --git a/src/network/stun.rs b/src/network/stun.rs index d1e088c..c2c1c86 100644 --- a/src/network/stun.rs +++ b/src/network/stun.rs @@ -4,7 +4,8 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::OnceLock; -use tokio::net::{UdpSocket, lookup_host}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpSocket, UdpSocket, lookup_host}; use tokio::time::{Duration, sleep, timeout}; use crate::crypto::SecureRandom; @@ -36,9 +37,16 @@ pub struct DualStunResult { } pub async fn stun_probe_dual(stun_addr: &str) -> Result { + stun_probe_dual_with_tcp_fallback(stun_addr, false).await +} + +pub async fn stun_probe_dual_with_tcp_fallback( + stun_addr: &str, + tcp_fallback: bool, +) -> Result { let (v4, v6) = tokio::join!( - stun_probe_family(stun_addr, IpFamily::V4), - stun_probe_family(stun_addr, IpFamily::V6), + stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V4, tcp_fallback), + stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V6, tcp_fallback), ); Ok(DualStunResult { v4: v4?, v6: v6? }) @@ -48,13 +56,44 @@ pub async fn stun_probe_family( stun_addr: &str, family: IpFamily, ) -> Result> { - stun_probe_family_with_bind(stun_addr, family, None).await + stun_probe_family_with_tcp_fallback(stun_addr, family, false).await +} + +pub async fn stun_probe_family_with_tcp_fallback( + stun_addr: &str, + family: IpFamily, + tcp_fallback: bool, +) -> Result> { + stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, None, tcp_fallback).await } pub async fn stun_probe_family_with_bind( stun_addr: &str, family: IpFamily, bind_ip: Option, +) -> Result> { + stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, bind_ip, false).await +} + +pub async fn stun_probe_family_with_bind_and_tcp_fallback( + stun_addr: &str, + family: IpFamily, + bind_ip: Option, + tcp_fallback: bool, +) -> Result> { + let udp_attempts = if tcp_fallback { 1 } else { 3 }; + let udp_result = stun_probe_family_udp(stun_addr, family, bind_ip, udp_attempts).await?; + if udp_result.is_some() || !tcp_fallback { + return Ok(udp_result); + } + stun_probe_family_tcp(stun_addr, family, bind_ip).await +} + +async fn stun_probe_family_udp( + stun_addr: &str, + family: IpFamily, + bind_ip: Option, + max_attempts: u8, ) -> Result> { let bind_addr = match (family, bind_ip) { (IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0), @@ -94,12 +133,7 @@ pub async fn stun_probe_family_with_bind( return Ok(None); } - let mut req = [0u8; 20]; - req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request - req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length - req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie - stun_rng().fill(&mut req[8..20]); // transaction ID - + let req = build_binding_request(); let mut buf = [0u8; 256]; let mut attempt = 0; let mut backoff = Duration::from_secs(1); @@ -115,7 +149,7 @@ pub async fn stun_probe_family_with_bind( Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN recv failed: {e}"))), Err(_) => { attempt += 1; - if attempt >= 3 { + if attempt >= max_attempts { return Ok(None); } sleep(backoff).await; @@ -128,19 +162,139 @@ pub async fn stun_probe_family_with_bind( return Ok(None); } - let magic = 0x2112A442u32.to_be_bytes(); let txid = &req[8..20]; - let mut idx = 20; - while idx + 4 <= n { - let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); - let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize; - idx += 4; - if idx + alen > n { - break; - } + if let Some(reflected_addr) = parse_reflected_addr(&buf[..n], txid) { + let local_addr = socket + .local_addr() + .map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?; + return Ok(Some(StunProbeResult { + local_addr, + reflected_addr, + family, + })); + } + } - match atype { - 0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => { + Ok(None) +} + +async fn stun_probe_family_tcp( + stun_addr: &str, + family: IpFamily, + bind_ip: Option, +) -> Result> { + let target_addr = match resolve_stun_addr(stun_addr, family).await? { + Some(addr) => addr, + None => return Ok(None), + }; + let socket = match family { + IpFamily::V4 => TcpSocket::new_v4(), + IpFamily::V6 => TcpSocket::new_v6(), + } + .map_err(|e| ProxyError::Proxy(format!("STUN TCP socket failed: {e}")))?; + match (family, bind_ip) { + (IpFamily::V4, Some(IpAddr::V4(ip))) => { + if socket.bind(SocketAddr::new(IpAddr::V4(ip), 0)).is_err() { + return Ok(None); + } + } + (IpFamily::V6, Some(IpAddr::V6(ip))) => { + if socket.bind(SocketAddr::new(IpAddr::V6(ip), 0)).is_err() { + return Ok(None); + } + } + (IpFamily::V4, Some(IpAddr::V6(_))) | (IpFamily::V6, Some(IpAddr::V4(_))) => { + return Ok(None); + } + (_, None) => {} + } + + let connect_res = timeout(Duration::from_secs(3), socket.connect(target_addr)).await; + let mut stream = match connect_res { + Ok(Ok(stream)) => stream, + Ok(Err(e)) + if family == IpFamily::V6 + && matches!( + e.kind(), + std::io::ErrorKind::NetworkUnreachable + | std::io::ErrorKind::HostUnreachable + | std::io::ErrorKind::Unsupported + | std::io::ErrorKind::NetworkDown + ) => + { + return Ok(None); + } + Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN TCP connect failed: {e}"))), + Err(_) => return Ok(None), + }; + + let req = build_binding_request(); + timeout(Duration::from_secs(3), stream.write_all(&req)) + .await + .map_err(|_| ProxyError::Proxy("STUN TCP send timeout".to_string()))? + .map_err(|e| ProxyError::Proxy(format!("STUN TCP send failed: {e}")))?; + + let mut header = [0u8; 20]; + timeout(Duration::from_secs(3), stream.read_exact(&mut header)) + .await + .map_err(|_| ProxyError::Proxy("STUN TCP header timeout".to_string()))? + .map_err(|e| ProxyError::Proxy(format!("STUN TCP header read failed: {e}")))?; + let body_len = u16::from_be_bytes([header[2], header[3]]) as usize; + if body_len > 236 { + return Ok(None); + } + let mut buf = [0u8; 256]; + buf[..20].copy_from_slice(&header); + if body_len > 0 { + timeout( + Duration::from_secs(3), + stream.read_exact(&mut buf[20..20 + body_len]), + ) + .await + .map_err(|_| ProxyError::Proxy("STUN TCP body timeout".to_string()))? + .map_err(|e| ProxyError::Proxy(format!("STUN TCP body read failed: {e}")))?; + } + + let txid = &req[8..20]; + let Some(reflected_addr) = parse_reflected_addr(&buf[..20 + body_len], txid) else { + return Ok(None); + }; + let local_addr = stream + .local_addr() + .map_err(|e| ProxyError::Proxy(format!("STUN TCP local_addr failed: {e}")))?; + Ok(Some(StunProbeResult { + local_addr, + reflected_addr, + family, + })) +} + +fn build_binding_request() -> [u8; 20] { + let mut req = [0u8; 20]; + req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); + req[2..4].copy_from_slice(&0u16.to_be_bytes()); + req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); + stun_rng().fill(&mut req[8..20]); + req +} + +fn parse_reflected_addr(buf: &[u8], txid: &[u8]) -> Option { + if buf.len() < 20 { + return None; + } + + let magic = 0x2112A442u32.to_be_bytes(); + let mut idx = 20; + while idx + 4 <= buf.len() { + let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().ok()?); + let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().ok()?) as usize; + idx += 4; + if idx + alen > buf.len() { + break; + } + + match atype { + 0x0020 | 0x0001 => { if alen < 8 { break; } @@ -157,7 +311,6 @@ pub async fn stun_probe_family_with_bind( let raw_ip = &buf[idx + 4..idx + 4 + len_check]; let mut port = u16::from_be_bytes(port_bytes); - let reflected_ip = if atype == 0x0020 { port ^= ((magic[0] as u16) << 8) | magic[1] as u16; match family_byte { @@ -172,7 +325,9 @@ pub async fn stun_probe_family_with_bind( } 0x02 => { let mut ip = [0u8; 16]; - let xor_key = [magic.as_slice(), txid].concat(); + let mut xor_key = [0u8; 16]; + xor_key[..4].copy_from_slice(&magic); + xor_key[4..].copy_from_slice(txid.get(..12)?); for (i, b) in raw_ip.iter().enumerate().take(16) { ip[i] = *b ^ xor_key[i]; } @@ -185,34 +340,24 @@ pub async fn stun_probe_family_with_bind( } } else { match family_byte { - 0x01 => IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])), - 0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).unwrap())), + 0x01 => { + IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])) + } + 0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).ok()?)), _ => { idx += (alen + 3) & !3; continue; } } }; - - let reflected_addr = SocketAddr::new(reflected_ip, port); - let local_addr = socket - .local_addr() - .map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?; - - return Ok(Some(StunProbeResult { - local_addr, - reflected_addr, - family, - })); + return Some(SocketAddr::new(reflected_ip, port)); } _ => {} } - idx += (alen + 3) & !3; - } + idx += (alen + 3) & !3; } - - Ok(None) + None } async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result> { @@ -245,3 +390,52 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result bool { } /// Compute Secure Intermediate payload length from wire length. -/// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary. +/// Secure mode cannot distinguish full-word padding from payload, so only the +/// non-aligned tail bytes are stripped. pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option { if wire_len < 4 { return None; @@ -245,13 +246,13 @@ pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option { } /// Generate padding length for Secure Intermediate protocol. -/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4. +/// Telegram Desktop uses a 4-bit random padding length for VersionD packets. pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { debug_assert!( is_valid_secure_payload_len(data_len), "Secure payload must be 4-byte aligned, got {data_len}" ); - rng.range(3) + 1 + rng.range(16) } // ============= Timeouts ============= @@ -424,21 +425,15 @@ mod tests { } #[test] - fn secure_padding_never_produces_aligned_total() { + fn secure_padding_matches_tdesktop_range() { let rng = SecureRandom::new(); for data_len in (0..1000).step_by(4) { for _ in 0..100 { let padding = secure_padding_len(data_len, &rng); assert!( - padding <= 3, + padding <= 15, "padding out of range: data_len={data_len}, padding={padding}" ); - assert_ne!( - (data_len + padding) % 4, - 0, - "invariant violated: data_len={data_len}, padding={padding}, total={}", - data_len + padding - ); } } } @@ -454,6 +449,16 @@ mod tests { } } + #[test] + fn secure_wire_len_preserves_full_word_tail() { + let payload_len = 64; + for padding in [4usize, 8, 12] { + let wire_len = payload_len + padding; + let recovered = secure_payload_len_from_wire_len(wire_len); + assert_eq!(recovered, Some(wire_len)); + } + } + #[test] fn secure_wire_len_rejects_too_short_frames() { assert_eq!(secure_payload_len_from_wire_len(0), None); diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 01206e2..b19dc84 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -18,7 +18,7 @@ use tokio::time::timeout; use tracing::{debug, info, warn}; use crate::config::MeSocksKdfPolicy; -use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; +use crate::crypto::{SecureRandom, derive_middleproxy_keys}; use crate::error::{ProxyError, Result}; use crate::network::IpFamily; use crate::network::probe::is_bogon; @@ -292,14 +292,15 @@ impl MePool { BndPortStatus::Error }; record_bnd_status(bnd_addr_status, bnd_port_status, raw_socks_bound_addr); - let reflected = if let Some(bound) = socks_bound_addr { + let socks_bound_kdf_addr = socks_bound_addr.filter(|bound| bound.port() != 0); + let reflected = if let Some(bound) = socks_bound_kdf_addr { Some(bound) } else if is_socks_route { match self.socks_kdf_policy() { MeSocksKdfPolicy::Strict => { self.stats.increment_me_socks_kdf_strict_reject(); return Err(ProxyError::InvalidHandshake( - "SOCKS route returned no valid BND.ADDR for ME KDF (strict policy)" + "SOCKS route returned no valid BND tuple for ME KDF (strict policy)" .to_string(), )); } @@ -323,16 +324,14 @@ impl MePool { let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); + let client_addr_for_kdf = socks_bound_kdf_addr.unwrap_or(local_addr_nat); if let Some(upstream_info) = upstream_egress { - let client_ip_for_kdf = socks_bound_addr - .map(|value| value.ip()) - .unwrap_or(local_addr_nat.ip()); record_upstream_bnd_status( upstream_info.upstream_id, bnd_addr_status, bnd_port_status, raw_socks_bound_addr, - Some(client_ip_for_kdf), + Some(client_addr_for_kdf.ip()), ); } let (mut rd, mut wr) = tokio::io::split(stream); @@ -409,6 +408,7 @@ impl MePool { info!( %local_addr, %local_addr_nat, + %client_addr_for_kdf, reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), %peer_addr, %transport_peer_addr, @@ -422,16 +422,14 @@ impl MePool { let ts_bytes = crypto_ts.to_le_bytes(); let server_port_bytes = peer_addr_nat.port().to_le_bytes(); - let socks_bound_port = socks_bound_addr - .map(|bound| bound.port()) - .filter(|port| *port != 0); - let client_port_for_kdf = socks_bound_port.unwrap_or(local_addr_nat.port()); + let socks_bound_port = socks_bound_kdf_addr.map(|bound| bound.port()); + let client_port_for_kdf = client_addr_for_kdf.port(); let client_port_source = KdfClientPortSource::from_socks_bound_port(socks_bound_port); let kdf_fingerprint = Self::kdf_material_fingerprint( - local_addr_nat.ip(), + client_addr_for_kdf.ip(), peer_addr_nat, reflected.map(|value| value.ip()), - socks_bound_addr.map(|value| value.ip()), + socks_bound_kdf_addr.map(|value| value.ip()), client_port_source, ); let previous_kdf_fingerprint = { @@ -473,7 +471,7 @@ impl MePool { let client_port_bytes = client_port_for_kdf.to_le_bytes(); let server_ip = extract_ip_material(peer_addr_nat); - let client_ip = extract_ip_material(local_addr_nat); + let client_ip = extract_ip_material(client_addr_for_kdf); let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = match (server_ip, client_ip) { @@ -494,38 +492,6 @@ impl MePool { } }; - let diag_level: u8 = std::env::var("ME_DIAG") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(0); - - let prekey_client = build_middleproxy_prekey( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"CLIENT", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - let prekey_server = build_middleproxy_prekey( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"SERVER", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - let (wk, wi) = derive_middleproxy_keys( &srv_nonce, &my_nonce, @@ -556,47 +522,14 @@ impl MePool { let requested_crc_mode = RpcChecksumMode::Crc32c; let hs_payload = build_handshake_payload( hs_our_ip, - local_addr.port(), + client_port_for_kdf, hs_peer_ip, - peer_addr.port(), + peer_addr_nat.port(), requested_crc_mode.advertised_flags(), ); let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32); - if diag_level >= 1 { - info!( - write_key = %hex_dump(&wk), - write_iv = %hex_dump(&wi), - read_key = %hex_dump(&rk), - read_iv = %hex_dump(&ri), - srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), - clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), - srv_port = %hex_dump(&server_port_bytes), - clt_port = %hex_dump(&client_port_bytes), - crypto_ts = %hex_dump(&ts_bytes), - nonce_srv = %hex_dump(&srv_nonce), - nonce_clt = %hex_dump(&my_nonce), - prekey_sha256_client = %hex_dump(&sha256(&prekey_client)), - prekey_sha256_server = %hex_dump(&sha256(&prekey_server)), - hs_plain = %hex_dump(&hs_frame), - proxy_secret_sha256 = %hex_dump(&sha256(&secret)), - "ME diag: derived keys and handshake plaintext" - ); - } - if diag_level >= 2 { - info!( - prekey_client = %hex_dump(&prekey_client), - prekey_server = %hex_dump(&prekey_server), - "ME diag: full prekey buffers" - ); - } let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; - if diag_level >= 1 { - info!( - hs_cipher = %hex_dump(&encrypted_hs), - "ME diag: handshake ciphertext" - ); - } wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; wr.flush().await.map_err(ProxyError::Io)?; diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 399fd13..0573167 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -1728,6 +1728,8 @@ mod tests { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 50861eb..077a014 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -336,6 +336,8 @@ pub(super) struct NatRuntimeCore { pub(super) nat_probe: bool, pub(super) nat_stun: Option, pub(super) nat_stun_servers: Vec, + pub(super) stun_tcp_fallback: bool, + pub(super) http_ip_detect_urls: Vec, pub(super) nat_stun_live_servers: Arc>>, pub(super) nat_probe_concurrency: usize, pub(super) detected_ipv6: Option, @@ -484,6 +486,8 @@ impl MePool { nat_probe: bool, nat_stun: Option, nat_stun_servers: Vec, + stun_tcp_fallback: bool, + http_ip_detect_urls: Vec, nat_probe_concurrency: usize, detected_ipv6: Option, me_one_retry: u8, @@ -706,6 +710,8 @@ impl MePool { nat_probe, nat_stun, nat_stun_servers, + stun_tcp_fallback, + http_ip_detect_urls, nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())), nat_probe_concurrency: nat_probe_concurrency.max(1), detected_ipv6, diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index be2d9df..28b8db8 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -1,19 +1,22 @@ use std::collections::HashMap; -use std::net::{IpAddr, Ipv4Addr}; +use std::net::IpAddr; use std::time::Duration; use tokio::task::JoinSet; use tokio::time::timeout; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; use crate::error::{ProxyError, Result}; -use crate::network::probe::is_bogon; -use crate::network::stun::{IpFamily, stun_probe_dual, stun_probe_family_with_bind}; +use crate::network::probe::{detect_public_ipv4_http, is_bogon}; +use crate::network::stun::{ + IpFamily, stun_probe_dual_with_tcp_fallback, stun_probe_family_with_bind_and_tcp_fallback, +}; use super::MePool; use std::time::Instant; const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); +const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12); #[allow(dead_code)] pub async fn stun_probe(stun_addr: Option) -> Result { @@ -28,15 +31,14 @@ pub async fn stun_probe(stun_addr: Option) -> Result Option { - fetch_public_ipv4_with_retry() + let urls = crate::config::defaults::default_http_ip_detect_urls(); + detect_public_ipv4_http(&urls) .await - .ok() - .flatten() .map(IpAddr::V4) } @@ -65,15 +67,26 @@ impl MePool { let mut live_servers = Vec::new(); let mut best_by_ip: HashMap = HashMap::new(); let concurrency = self.nat_runtime.nat_probe_concurrency.max(1); + let tcp_fallback = self.nat_runtime.stun_tcp_fallback; while next_idx < servers.len() || !join_set.is_empty() { while next_idx < servers.len() && join_set.len() < concurrency { let stun_addr = servers[next_idx].clone(); next_idx += 1; join_set.spawn(async move { + let batch_timeout = if tcp_fallback { + STUN_BATCH_TCP_FALLBACK_TIMEOUT + } else { + STUN_BATCH_TIMEOUT + }; let res = timeout( - STUN_BATCH_TIMEOUT, - stun_probe_family_with_bind(&stun_addr, family, bind_ip), + batch_timeout, + stun_probe_family_with_bind_and_tcp_fallback( + &stun_addr, + family, + bind_ip, + tcp_fallback, + ), ) .await; (stun_addr, res) @@ -193,6 +206,10 @@ impl MePool { return self.nat_runtime.nat_ip_cfg; } + if !self.nat_runtime.nat_probe { + return None; + } + if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) { return None; } @@ -201,21 +218,15 @@ impl MePool { return Some(ip); } - match fetch_public_ipv4_with_retry().await { - Ok(Some(ip)) => { - { - let mut guard = self.nat_runtime.nat_ip_detected.write().await; - *guard = Some(IpAddr::V4(ip)); - } - info!(public_ip = %ip, "Auto-detected public IP for NAT translation"); - Some(IpAddr::V4(ip)) - } - Ok(None) => None, - Err(e) => { - warn!(error = %e, "Failed to auto-detect public IP"); - None - } + let Some(ip) = detect_public_ipv4_http(&self.nat_runtime.http_ip_detect_urls).await else { + return None; + }; + { + let mut guard = self.nat_runtime.nat_ip_detected.write().await; + *guard = Some(IpAddr::V4(ip)); } + info!(public_ip = %ip, "Auto-detected public IP for NAT translation"); + Some(IpAddr::V4(ip)) } pub(super) async fn maybe_reflect_public_addr( @@ -365,31 +376,3 @@ impl MePool { None } } - -async fn fetch_public_ipv4_with_retry() -> Result> { - let providers = [ - "https://checkip.amazonaws.com", - "http://v4.ident.me", - "http://ipv4.icanhazip.com", - ]; - for url in providers { - if let Ok(Some(ip)) = fetch_public_ipv4_once(url).await { - return Ok(Some(ip)); - } - } - Ok(None) -} - -async fn fetch_public_ipv4_once(url: &str) -> Result> { - let res = reqwest::get(url) - .await - .map_err(|e| ProxyError::Proxy(format!("public IP detection request failed: {e}")))?; - - let text = res - .text() - .await - .map_err(|e| ProxyError::Proxy(format!("public IP detection read failed: {e}")))?; - - let ip = text.trim().parse().ok(); - Ok(ip) -} diff --git a/src/transport/middle_proxy/tests/health_adversarial_tests.rs b/src/transport/middle_proxy/tests/health_adversarial_tests.rs index ea88c67..fdbd5f9 100644 --- a/src/transport/middle_proxy/tests/health_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/health_adversarial_tests.rs @@ -38,6 +38,8 @@ async fn make_pool( false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/health_integration_tests.rs b/src/transport/middle_proxy/tests/health_integration_tests.rs index 9b3f93e..1dc6abf 100644 --- a/src/transport/middle_proxy/tests/health_integration_tests.rs +++ b/src/transport/middle_proxy/tests/health_integration_tests.rs @@ -36,6 +36,8 @@ async fn make_pool( false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/health_regression_tests.rs b/src/transport/middle_proxy/tests/health_regression_tests.rs index aa1f9ed..c2f5441 100644 --- a/src/transport/middle_proxy/tests/health_regression_tests.rs +++ b/src/transport/middle_proxy/tests/health_regression_tests.rs @@ -31,6 +31,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs index 6519c05..9125ad0 100644 --- a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs @@ -20,6 +20,8 @@ async fn make_pool() -> Arc { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs index 5f9f130..f0514a9 100644 --- a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs @@ -25,6 +25,8 @@ async fn make_pool() -> Arc { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, diff --git a/src/transport/middle_proxy/tests/send_adversarial_tests.rs b/src/transport/middle_proxy/tests/send_adversarial_tests.rs index 4050fa1..7ed8d76 100644 --- a/src/transport/middle_proxy/tests/send_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/send_adversarial_tests.rs @@ -31,6 +31,8 @@ async fn make_pool() -> (Arc, Arc) { false, None, Vec::new(), + false, + Vec::new(), 1, None, 12, From 2e26bfb86ec0a7a290cd9ec8895b90757c654bd9 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sun, 14 Jun 2026 16:33:41 +0300 Subject: [PATCH 03/15] Updated secure padding expectations for VersionD Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/stream/frame_codec.rs | 11 +++++------ src/stream/frame_stream.rs | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 2542e37..d0d11b5 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -317,7 +317,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R )); } - // Generate padding that keeps total length non-divisible by 4. + // Telegram Desktop VersionD uses a 4-bit random padding length. let padding_len = secure_padding_len(data.len(), rng); let total_len = data.len() + padding_len; @@ -642,7 +642,7 @@ mod tests { } #[test] - fn secure_codec_always_adds_padding_and_jitters_wire_length() { + fn secure_codec_uses_tdesktop_padding_range_and_jitters_wire_length() { let codec = SecureCodec::new(Arc::new(SecureRandom::new())); let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let mut wire_lens = HashSet::new(); @@ -652,13 +652,12 @@ mod tests { let mut out = BytesMut::new(); codec.encode(&frame, &mut out).unwrap(); - assert!(out.len() > 4 + payload.len()); let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize; + assert_eq!(out.len(), 4 + wire_len); assert!( - (payload.len() + 1..=payload.len() + 3).contains(&wire_len), - "Secure wire length must be payload+1..3, got {wire_len}" + (payload.len()..=payload.len() + 15).contains(&wire_len), + "Secure wire length must be payload+0..15, got {wire_len}" ); - assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned"); wire_lens.insert(wire_len); } diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs index e9f1d3e..ed84645 100644 --- a/src/stream/frame_stream.rs +++ b/src/stream/frame_stream.rs @@ -311,7 +311,7 @@ impl SecureIntermediateFrameWriter { )); } - // Add padding so total length is never divisible by 4 (MTProto Secure) + // Telegram Desktop VersionD uses a 4-bit random padding length. let padding_len = secure_padding_len(data.len(), &self.rng); let padding = self.rng.bytes(padding_len); From 04b8d8365cc3d6b898a5d3d9a1a93450c9311076 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sun, 14 Jun 2026 19:38:54 +0300 Subject: [PATCH 04/15] Account for full-word paddings in roundtrip tests --- src/proxy/middle_relay/idle/read.rs | 3 ++- src/stream/frame_codec.rs | 18 ++++++++++++++++-- src/stream/frame_stream.rs | 12 +++++++++++- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/proxy/middle_relay/idle/read.rs b/src/proxy/middle_relay/idle/read.rs index 270f104..80ca0cc 100644 --- a/src/proxy/middle_relay/idle/read.rs +++ b/src/proxy/middle_relay/idle/read.rs @@ -331,7 +331,8 @@ where ) .await?; - // Secure Intermediate: strip validated trailing padding bytes. + // Secure Intermediate strips only non-aligned tail padding; full-word + // padding is indistinguishable from payload in VersionD framing. if proto_tag == ProtoTag::Secure { payload.truncate(secure_payload_len); } diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index d0d11b5..cbec951 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -523,6 +523,16 @@ mod tests { use tokio::io::duplex; use tokio_util::codec::{FramedRead, FramedWrite}; + fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) { + assert!(decoded.starts_with(original)); + assert!( + (original.len()..=original.len() + 12).contains(&decoded.len()), + "Secure decoded payload may retain up to 12 bytes of full-word padding, got {}", + decoded.len() + ); + assert_eq!(decoded.len() % 4, 0); + } + #[tokio::test] async fn test_framed_abridged() { let (client, server) = duplex(4096); @@ -565,7 +575,7 @@ mod tests { writer.send(frame).await.unwrap(); let received = reader.next().await.unwrap().unwrap(); - assert_eq!(&received.data[..], &original[..]); + assert_secure_decoded_payload(&received.data, &original); } #[tokio::test] @@ -588,7 +598,11 @@ mod tests { writer.send(frame).await.unwrap(); let received = reader.next().await.unwrap().unwrap(); - assert_eq!(received.data.len(), 8); + if proto_tag == ProtoTag::Secure { + assert_secure_decoded_payload(&received.data, &original); + } else { + assert_eq!(received.data.len(), original.len()); + } } } diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs index ed84645..61742de 100644 --- a/src/stream/frame_stream.rs +++ b/src/stream/frame_stream.rs @@ -559,6 +559,16 @@ mod tests { use std::sync::Arc; use tokio::io::duplex; + fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) { + assert!(decoded.starts_with(original)); + assert!( + (original.len()..=original.len() + 12).contains(&decoded.len()), + "Secure decoded payload may retain up to 12 bytes of full-word padding, got {}", + decoded.len() + ); + assert_eq!(decoded.len() % 4, 0); + } + #[tokio::test] async fn test_abridged_roundtrip() { let (client, server) = duplex(1024); @@ -625,7 +635,7 @@ mod tests { writer.flush().await.unwrap(); let (received, _meta) = reader.read_frame().await.unwrap(); - assert_eq!(received.len(), data.len()); + assert_secure_decoded_payload(&received, &data); } #[tokio::test] From d81d7dba6263bb1b156405af2664c6d388bdc3e0 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sun, 14 Jun 2026 19:58:39 +0300 Subject: [PATCH 05/15] Rustfmt --- src/network/stun.rs | 10 ++++++++-- src/transport/middle_proxy/pool_nat.rs | 4 +--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/network/stun.rs b/src/network/stun.rs index c2c1c86..ca4a8cb 100644 --- a/src/network/stun.rs +++ b/src/network/stun.rs @@ -409,7 +409,10 @@ mod tests { response[28..32].copy_from_slice(&[203, 0, 113, 9]); let reflected = parse_reflected_addr(&response, &txid).unwrap(); - assert_eq!(reflected, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443)); + assert_eq!( + reflected, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443) + ); } #[test] @@ -436,6 +439,9 @@ mod tests { response[28..32].copy_from_slice(&xip); let reflected = parse_reflected_addr(&response, &txid).unwrap(); - assert_eq!(reflected, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443)); + assert_eq!( + reflected, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443) + ); } } diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index 28b8db8..d9b7973 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -37,9 +37,7 @@ pub async fn stun_probe(stun_addr: Option) -> Result Option { let urls = crate::config::defaults::default_http_ip_detect_urls(); - detect_public_ipv4_http(&urls) - .await - .map(IpAddr::V4) + detect_public_ipv4_http(&urls).await.map(IpAddr::V4) } impl MePool { From 37d0184a0bd99c9f22b45898b057fde408a6929b Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 15 Jun 2026 08:50:08 +0300 Subject: [PATCH 06/15] Implement shared MTProto framing and ME address role separation Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/types.rs | 2 +- src/protocol/constants.rs | 10 +- src/protocol/framing.rs | 92 ++++++++++++ src/protocol/mod.rs | 1 + src/proxy/handshake.rs | 135 ++---------------- src/proxy/handshake/tls_auth.rs | 126 ++++++++++++++++ src/proxy/middle_relay/d2c.rs | 19 ++- src/proxy/middle_relay/idle/read.rs | 6 +- src/stream/frame_codec.rs | 41 +++--- src/stream/frame_stream.rs | 83 ++++++++--- src/transport/middle_proxy/handshake.rs | 3 + src/transport/middle_proxy/send.rs | 6 +- .../tests/send_adversarial_tests.rs | 75 ++++++++++ 13 files changed, 415 insertions(+), 184 deletions(-) create mode 100644 src/protocol/framing.rs create mode 100644 src/proxy/handshake/tls_auth.rs diff --git a/src/config/types.rs b/src/config/types.rs index e0f7b04..6b95260 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -429,7 +429,7 @@ pub struct GeneralConfig { pub ad_tag: Option, /// Public IP override for middle-proxy NAT environments. - /// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr". + /// When set, this IP is used in ME key derivation and local address translation. #[serde(default)] pub middle_proxy_nat_ip: Option, diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index 6f6c3ae..19246aa 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -5,6 +5,9 @@ use std::net::{IpAddr, Ipv4Addr}; use crate::crypto::SecureRandom; +use crate::protocol::framing::{ + secure_version_d_body_len_from_wire_len, secure_version_d_padding_len, +}; use std::sync::LazyLock; // ============= Telegram Datacenters ============= @@ -239,10 +242,7 @@ pub fn is_valid_secure_payload_len(data_len: usize) -> bool { /// Secure mode cannot distinguish full-word padding from payload, so only the /// non-aligned tail bytes are stripped. pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option { - if wire_len < 4 { - return None; - } - Some(wire_len - (wire_len % 4)) + secure_version_d_body_len_from_wire_len(wire_len) } /// Generate padding length for Secure Intermediate protocol. @@ -252,7 +252,7 @@ pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { is_valid_secure_payload_len(data_len), "Secure payload must be 4-byte aligned, got {data_len}" ); - rng.range(16) + secure_version_d_padding_len(rng) } // ============= Timeouts ============= diff --git a/src/protocol/framing.rs b/src/protocol/framing.rs new file mode 100644 index 0000000..dd63e89 --- /dev/null +++ b/src/protocol/framing.rs @@ -0,0 +1,92 @@ +//! Shared MTProto transport framing helpers. + +use crate::crypto::SecureRandom; + +/// QuickACK marker bit used by Intermediate and Secure Intermediate headers. +pub(crate) const INTERMEDIATE_QUICKACK_FLAG: u32 = 0x8000_0000; + +/// Payload length mask used by Intermediate and Secure Intermediate headers. +pub(crate) const INTERMEDIATE_WIRE_LEN_MASK: u32 = 0x7fff_ffff; + +/// Maximum random tail length used by Telegram Desktop VersionD packets. +pub(crate) const SECURE_VERSION_D_PADDING_MAX: usize = 15; + +/// Parsed Intermediate/Secure Intermediate length header. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) struct IntermediateHeader { + /// Payload length on the wire, excluding the four-byte header. + pub(crate) wire_len: usize, + /// Whether the QuickACK marker bit was set in the length header. + pub(crate) quickack: bool, +} + +/// Parse an Intermediate/Secure Intermediate length header. +pub(crate) fn parse_intermediate_header(header: [u8; 4]) -> IntermediateHeader { + let raw = u32::from_le_bytes(header); + IntermediateHeader { + wire_len: (raw & INTERMEDIATE_WIRE_LEN_MASK) as usize, + quickack: (raw & INTERMEDIATE_QUICKACK_FLAG) != 0, + } +} + +/// Encode an Intermediate/Secure Intermediate length header. +pub(crate) fn encode_intermediate_header(wire_len: usize, quickack: bool) -> Option { + if wire_len > INTERMEDIATE_WIRE_LEN_MASK as usize { + return None; + } + + let mut raw = u32::try_from(wire_len).ok()?; + if quickack { + raw |= INTERMEDIATE_QUICKACK_FLAG; + } + Some(raw) +} + +/// Recover the VersionD body length visible to MTProto from the encrypted wire length. +pub(crate) fn secure_version_d_body_len_from_wire_len(wire_len: usize) -> Option { + if wire_len < 4 { + return None; + } + + Some(wire_len - (wire_len % 4)) +} + +/// Generate Telegram Desktop-compatible VersionD random tail length. +pub(crate) fn secure_version_d_padding_len(rng: &SecureRandom) -> usize { + rng.range(SECURE_VERSION_D_PADDING_MAX + 1) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn intermediate_header_roundtrip_preserves_quickack_zero_length() { + let encoded = encode_intermediate_header(0, true).unwrap(); + assert_eq!(encoded, INTERMEDIATE_QUICKACK_FLAG); + + let parsed = parse_intermediate_header(encoded.to_le_bytes()); + assert_eq!(parsed.wire_len, 0); + assert!(parsed.quickack); + } + + #[test] + fn intermediate_header_rejects_lengths_above_31_bits() { + assert_eq!( + encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize, false), + Some(INTERMEDIATE_WIRE_LEN_MASK) + ); + assert_eq!( + encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize + 1, false), + None + ); + } + + #[test] + fn secure_version_d_body_len_strips_only_non_word_tail() { + assert_eq!(secure_version_d_body_len_from_wire_len(3), None); + assert_eq!(secure_version_d_body_len_from_wire_len(8), Some(8)); + assert_eq!(secure_version_d_body_len_from_wire_len(11), Some(8)); + assert_eq!(secure_version_d_body_len_from_wire_len(12), Some(12)); + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 9ffff7c..63e75b7 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -2,6 +2,7 @@ pub mod constants; pub mod frame; +pub(crate) mod framing; pub mod obfuscation; pub mod tls; pub mod tls_fingerprint; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 084fadc..f9f55de 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -4,7 +4,6 @@ use dashmap::DashMap; use dashmap::mapref::entry::Entry; -use hmac::{Hmac, Mac}; #[cfg(test)] use std::collections::HashSet; use std::collections::hash_map::DefaultHasher; @@ -33,8 +32,10 @@ use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::tls_front::{TlsFrontCache, emulator}; #[cfg(test)] use rand::RngExt; -use sha2::Sha256; -use subtle::ConstantTimeEq; + +mod tls_auth; + +use self::tls_auth::{parse_tls_auth_material, validate_tls_secret_candidate}; const ACCESS_SECRET_BYTES: usize = 16; const UNKNOWN_SNI_WARN_COOLDOWN_SECS: u64 = 5; @@ -58,8 +59,6 @@ const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8; const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64; const RECENT_USER_RING_SCAN_LIMIT: usize = 32; -type HmacSha256 = Hmac; - #[cfg(test)] const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1; #[cfg(not(test))] @@ -104,23 +103,6 @@ fn should_emit_unknown_sni_warn_in(shared: &ProxySharedState, now: Instant) -> b true } -#[derive(Clone, Copy)] -struct ParsedTlsAuthMaterial { - digest: [u8; tls::TLS_DIGEST_LEN], - session_id: [u8; 32], - session_id_len: usize, - now: i64, - ignore_time_skew: bool, - boot_time_cap_secs: u32, -} - -#[derive(Clone, Copy)] -struct TlsCandidateValidation { - digest: [u8; tls::TLS_DIGEST_LEN], - session_id: [u8; 32], - session_id_len: usize, -} - struct MtprotoCandidateValidation { proto_tag: ProtoTag, dc_idx: i16, @@ -251,104 +233,6 @@ fn budget_for_validation(total_users: usize, overload: bool, has_hint: bool) -> total_users.min(cap.max(1)) } -fn parse_tls_auth_material( - handshake: &[u8], - ignore_time_skew: bool, - replay_window_secs: u64, -) -> Option { - if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { - return None; - } - - let digest: [u8; tls::TLS_DIGEST_LEN] = handshake - [tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] - .try_into() - .ok()?; - - let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN; - let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?); - if session_id_len > 32 { - return None; - } - let session_id_start = session_id_len_pos + 1; - if handshake.len() < session_id_start + session_id_len { - return None; - } - - let mut session_id = [0u8; 32]; - session_id[..session_id_len] - .copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]); - - let now = if !ignore_time_skew { - let d = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .ok()?; - i64::try_from(d.as_secs()).ok()? - } else { - 0_i64 - }; - - let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); - let boot_time_cap_secs = if ignore_time_skew { - 0 - } else { - tls::BOOT_TIME_MAX_SECS - .min(replay_window_u32) - .min(tls::BOOT_TIME_COMPAT_MAX_SECS) - }; - - Some(ParsedTlsAuthMaterial { - digest, - session_id, - session_id_len, - now, - ignore_time_skew, - boot_time_cap_secs, - }) -} - -fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> [u8; 32] { - let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length"); - mac.update(&handshake[..tls::TLS_DIGEST_POS]); - mac.update(&[0u8; tls::TLS_DIGEST_LEN]); - mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]); - mac.finalize().into_bytes().into() -} - -fn validate_tls_secret_candidate( - parsed: &ParsedTlsAuthMaterial, - handshake: &[u8], - secret: &[u8], -) -> Option { - let computed = compute_tls_hmac_zeroed_digest(secret, handshake); - if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) { - return None; - } - - let timestamp = u32::from_le_bytes([ - parsed.digest[28] ^ computed[28], - parsed.digest[29] ^ computed[29], - parsed.digest[30] ^ computed[30], - parsed.digest[31] ^ computed[31], - ]); - - if !parsed.ignore_time_skew { - let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs; - if !is_boot_time { - let time_diff = parsed.now - i64::from(timestamp); - if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) { - return None; - } - } - } - - Some(TlsCandidateValidation { - digest: parsed.digest, - session_id: parsed.session_id, - session_id_len: parsed.session_id_len, - }) -} - fn validate_mtproto_secret_candidate( handshake: &[u8; HANDSHAKE_LEN], dec_prekey: &[u8; PREKEY_LEN], @@ -1857,7 +1741,16 @@ where return HandshakeResult::BadClient { reader, writer }; } - let validation = matched_validation.expect("validation must exist when matched"); + let Some(validation) = matched_validation else { + auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!( + peer = %peer, + user = %matched_user, + "MTProto handshake matched user without validation material" + ); + return HandshakeResult::BadClient { reader, writer }; + }; if config .access diff --git a/src/proxy/handshake/tls_auth.rs b/src/proxy/handshake/tls_auth.rs new file mode 100644 index 0000000..2feb666 --- /dev/null +++ b/src/proxy/handshake/tls_auth.rs @@ -0,0 +1,126 @@ +use hmac::{Hmac, Mac}; +use sha2::Sha256; +use subtle::ConstantTimeEq; + +use crate::protocol::tls; + +type HmacSha256 = Hmac; + +/// Parsed TLS authentication material extracted from a ClientHello candidate. +#[derive(Clone, Copy)] +pub(super) struct ParsedTlsAuthMaterial { + digest: [u8; tls::TLS_DIGEST_LEN], + session_id: [u8; 32], + session_id_len: usize, + now: i64, + ignore_time_skew: bool, + boot_time_cap_secs: u32, +} + +/// Successful TLS secret validation output used by the handshake state machine. +#[derive(Clone, Copy)] +pub(super) struct TlsCandidateValidation { + pub(super) digest: [u8; tls::TLS_DIGEST_LEN], + pub(super) session_id: [u8; 32], + pub(super) session_id_len: usize, +} + +/// Parse TLS auth digest and session-id material from a candidate handshake. +pub(super) fn parse_tls_auth_material( + handshake: &[u8], + ignore_time_skew: bool, + replay_window_secs: u64, +) -> Option { + if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { + return None; + } + + let digest: [u8; tls::TLS_DIGEST_LEN] = handshake + [tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .try_into() + .ok()?; + + let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN; + let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?); + if session_id_len > 32 { + return None; + } + let session_id_start = session_id_len_pos + 1; + if handshake.len() < session_id_start + session_id_len { + return None; + } + + let mut session_id = [0u8; 32]; + session_id[..session_id_len] + .copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]); + + let now = if !ignore_time_skew { + let d = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .ok()?; + i64::try_from(d.as_secs()).ok()? + } else { + 0_i64 + }; + + let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); + let boot_time_cap_secs = if ignore_time_skew { + 0 + } else { + tls::BOOT_TIME_MAX_SECS + .min(replay_window_u32) + .min(tls::BOOT_TIME_COMPAT_MAX_SECS) + }; + + Some(ParsedTlsAuthMaterial { + digest, + session_id, + session_id_len, + now, + ignore_time_skew, + boot_time_cap_secs, + }) +} + +fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> Option<[u8; 32]> { + let mut mac = HmacSha256::new_from_slice(secret).ok()?; + mac.update(&handshake[..tls::TLS_DIGEST_POS]); + mac.update(&[0u8; tls::TLS_DIGEST_LEN]); + mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]); + Some(mac.finalize().into_bytes().into()) +} + +/// Validate a candidate secret against parsed TLS authentication material. +pub(super) fn validate_tls_secret_candidate( + parsed: &ParsedTlsAuthMaterial, + handshake: &[u8], + secret: &[u8], +) -> Option { + let computed = compute_tls_hmac_zeroed_digest(secret, handshake)?; + if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) { + return None; + } + + let timestamp = u32::from_le_bytes([ + parsed.digest[28] ^ computed[28], + parsed.digest[29] ^ computed[29], + parsed.digest[30] ^ computed[30], + parsed.digest[31] ^ computed[31], + ]); + + if !parsed.ignore_time_skew { + let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs; + if !is_boot_time { + let time_diff = parsed.now - i64::from(timestamp); + if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) { + return None; + } + } + } + + Some(TlsCandidateValidation { + digest: parsed.digest, + session_id: parsed.session_id, + session_id_len: parsed.session_id_len, + }) +} diff --git a/src/proxy/middle_relay/d2c.rs b/src/proxy/middle_relay/d2c.rs index 92fe3c1..d227aa9 100644 --- a/src/proxy/middle_relay/d2c.rs +++ b/src/proxy/middle_relay/d2c.rs @@ -276,20 +276,17 @@ pub(in crate::proxy::middle_relay) fn compute_intermediate_secure_wire_len( let wire_len = data_len .checked_add(padding_len) .ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?; - if wire_len > 0x7fff_ffffusize { - return Err(ProxyError::Proxy(format!( - "Intermediate/Secure frame too large: {wire_len}" - ))); - } - + let len_val = + crate::protocol::framing::encode_intermediate_header(wire_len, quickack).ok_or_else( + || { + ProxyError::Proxy(format!( + "Intermediate/Secure frame too large: {wire_len}" + )) + }, + )?; let total = 4usize .checked_add(wire_len) .ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?; - let mut len_val = u32::try_from(wire_len) - .map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?; - if quickack { - len_val |= 0x8000_0000; - } Ok((len_val, total)) } diff --git a/src/proxy/middle_relay/idle/read.rs b/src/proxy/middle_relay/idle/read.rs index 80ca0cc..652041c 100644 --- a/src/proxy/middle_relay/idle/read.rs +++ b/src/proxy/middle_relay/idle/read.rs @@ -236,10 +236,10 @@ where } Err(e) => return Err(e), } - let quickack = (len_buf[3] & 0x80) != 0; + let header = crate::protocol::framing::parse_intermediate_header(len_buf); ( - (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, - quickack, + header.wire_len, + header.quickack, Some(len_buf), ) } diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index cbec951..ddf4bde 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -15,6 +15,7 @@ use crate::crypto::SecureRandom; use crate::protocol::constants::{ ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len, }; +use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header}; // ============= Unified Codec ============= @@ -197,13 +198,9 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result