diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 3fb8c3d..4f0a53d 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -132,6 +132,10 @@ pub(crate) fn default_middle_proxy_nat_stun_servers() -> Vec { ] } +pub(crate) fn default_stun_nat_probe_concurrency() -> usize { + 8 +} + pub(crate) fn default_middle_proxy_warm_standby() -> usize { DEFAULT_MIDDLE_PROXY_WARM_STANDBY } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 7f121f6..c949104 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -96,6 +96,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig) { if old.general.use_middle_proxy != new.general.use_middle_proxy { warn!("config reload: use_middle_proxy changed; restart required"); } + if old.general.stun_nat_probe_concurrency != new.general.stun_nat_probe_concurrency { + warn!("config reload: general.stun_nat_probe_concurrency changed; restart required"); + } } /// Resolve the public host for link generation — mirrors the logic in main.rs. diff --git a/src/config/load.rs b/src/config/load.rs index 35099be..31a8b5d 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -166,6 +166,12 @@ impl ProxyConfig { } } + if config.general.stun_nat_probe_concurrency == 0 { + return Err(ProxyError::Config( + "general.stun_nat_probe_concurrency must be > 0".to_string(), + )); + } + if config.general.me_reinit_every_secs == 0 { return Err(ProxyError::Config( "general.me_reinit_every_secs must be > 0".to_string(), @@ -607,6 +613,26 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn stun_nat_probe_concurrency_zero_is_rejected() { + let toml = r#" + [general] + stun_nat_probe_concurrency = 0 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_stun_nat_probe_concurrency_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.stun_nat_probe_concurrency must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn me_reinit_every_default_is_set() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index e827088..58a3a3e 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -168,6 +168,10 @@ pub struct GeneralConfig { #[serde(default = "default_middle_proxy_nat_stun_servers")] pub middle_proxy_nat_stun_servers: Vec, + /// Maximum number of concurrent STUN probes during NAT detection. + #[serde(default = "default_stun_nat_probe_concurrency")] + pub stun_nat_probe_concurrency: usize, + /// Desired size of active Middle-Proxy writer pool. #[serde(default = "default_pool_size")] pub middle_proxy_pool_size: usize, @@ -378,6 +382,7 @@ impl Default for GeneralConfig { middle_proxy_nat_probe: default_true(), middle_proxy_nat_stun: default_middle_proxy_nat_stun(), middle_proxy_nat_stun_servers: default_middle_proxy_nat_stun_servers(), + stun_nat_probe_concurrency: default_stun_nat_probe_concurrency(), middle_proxy_pool_size: default_pool_size(), middle_proxy_warm_standby: default_middle_proxy_warm_standby(), me_keepalive_enabled: default_true(), diff --git a/src/main.rs b/src/main.rs index e4f7a79..dd4feb8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -257,7 +257,9 @@ async fn main() -> std::result::Result<(), Box> { let probe = run_probe( &config.network, config.general.middle_proxy_nat_stun.clone(), + config.general.middle_proxy_nat_stun_servers.clone(), config.general.middle_proxy_nat_probe, + config.general.stun_nat_probe_concurrency, ) .await?; let decision = decide_network_capabilities(&config.network, &probe); @@ -360,6 +362,7 @@ async fn main() -> std::result::Result<(), Box> { config.general.middle_proxy_nat_probe, config.general.middle_proxy_nat_stun.clone(), config.general.middle_proxy_nat_stun_servers.clone(), + config.general.stun_nat_probe_concurrency, probe.detected_ipv6, config.timeouts.me_one_retry, config.timeouts.me_one_timeout_ms, diff --git a/src/network/probe.rs b/src/network/probe.rs index c52b340..378faa5 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -1,12 +1,16 @@ #![allow(dead_code)] +use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; +use std::time::Duration; -use tracing::{info, warn}; +use tokio::task::JoinSet; +use tokio::time::timeout; +use tracing::{debug, info, warn}; use crate::config::NetworkConfig; use crate::error::Result; -use crate::network::stun::{stun_probe_dual, DualStunResult, IpFamily}; +use crate::network::stun::{stun_probe_dual, DualStunResult, IpFamily, StunProbeResult}; #[derive(Debug, Clone, Default)] pub struct NetworkProbe { @@ -49,7 +53,15 @@ impl NetworkDecision { } } -pub async fn run_probe(config: &NetworkConfig, stun_addr: Option, nat_probe: bool) -> Result { +const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); + +pub async fn run_probe( + config: &NetworkConfig, + stun_addr: Option, + stun_servers: Vec, + nat_probe: bool, + stun_nat_probe_concurrency: usize, +) -> Result { let mut probe = NetworkProbe::default(); probe.detected_ipv4 = detect_local_ip_v4(); @@ -58,21 +70,30 @@ pub async fn run_probe(config: &NetworkConfig, stun_addr: Option, nat_pr probe.ipv4_is_bogon = probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false); probe.ipv6_is_bogon = probe.detected_ipv6.map(is_bogon_v6).unwrap_or(false); - let stun_server = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); let stun_res = if nat_probe { - match stun_probe_dual(&stun_server).await { - Ok(res) => res, - Err(e) => { - warn!(error = %e, "STUN probe failed, continuing without reflection"); - DualStunResult::default() - } - } + let servers = collect_stun_servers(config, stun_addr, stun_servers); + probe_stun_servers_parallel( + &servers, + stun_nat_probe_concurrency.max(1), + ) + .await } else { DualStunResult::default() }; probe.reflected_ipv4 = stun_res.v4.map(|r| r.reflected_addr); probe.reflected_ipv6 = stun_res.v6.map(|r| r.reflected_addr); + // If STUN is blocked but IPv4 is private, try HTTP public-IP fallback. + if nat_probe + && probe.reflected_ipv4.is_none() + && probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false) + { + if let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await { + probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0)); + info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback"); + } + } + probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) { (Some(det), Some(reflected)) => det != reflected.ip(), _ => false, @@ -94,6 +115,134 @@ pub async fn run_probe(config: &NetworkConfig, stun_addr: Option, nat_pr Ok(probe) } +async fn detect_public_ipv4_http(urls: &[String]) -> Option { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(3)) + .build() + .ok()?; + + for url in urls { + let response = match client.get(url).send().await { + Ok(response) => response, + Err(_) => continue, + }; + + let body = match response.text().await { + Ok(body) => body, + Err(_) => continue, + }; + + let Ok(ip) = body.trim().parse::() else { + continue; + }; + if !is_bogon_v4(ip) { + return Some(ip); + } + } + + None +} + +fn collect_stun_servers( + config: &NetworkConfig, + stun_addr: Option, + stun_servers: Vec, +) -> Vec { + let mut out = Vec::new(); + if !stun_servers.is_empty() { + for s in stun_servers { + if !s.is_empty() && !out.contains(&s) { + out.push(s); + } + } + } else if let Some(s) = stun_addr + && !s.is_empty() + { + out.push(s); + } + + if out.is_empty() { + for s in &config.stun_servers { + if !s.is_empty() && !out.contains(s) { + out.push(s.clone()); + } + } + } + + if out.is_empty() { + out.push("stun.l.google.com:19302".to_string()); + } + + out +} + +async fn probe_stun_servers_parallel( + servers: &[String], + concurrency: usize, +) -> DualStunResult { + let mut join_set = JoinSet::new(); + let mut next_idx = 0usize; + let mut best_v4_by_ip: HashMap = HashMap::new(); + let mut best_v6_by_ip: HashMap = HashMap::new(); + + while next_idx < servers.len() || !join_set.is_empty() { + while next_idx < servers.len() && join_set.len() < concurrency { + let stun_addr = servers[next_idx].clone(); + next_idx += 1; + join_set.spawn(async move { + let res = timeout(STUN_BATCH_TIMEOUT, stun_probe_dual(&stun_addr)).await; + (stun_addr, res) + }); + } + + let Some(task) = join_set.join_next().await else { + break; + }; + + match task { + Ok((stun_addr, Ok(Ok(result)))) => { + if let Some(v4) = result.v4 { + let entry = best_v4_by_ip.entry(v4.reflected_addr.ip()).or_insert((0, v4)); + entry.0 += 1; + } + if let Some(v6) = result.v6 { + let entry = best_v6_by_ip.entry(v6.reflected_addr.ip()).or_insert((0, v6)); + entry.0 += 1; + } + if result.v4.is_some() || result.v6.is_some() { + debug!(stun = %stun_addr, "STUN server responded within probe timeout"); + } + } + Ok((stun_addr, Ok(Err(e)))) => { + debug!(error = %e, stun = %stun_addr, "STUN probe failed"); + } + Ok((stun_addr, Err(_))) => { + debug!(stun = %stun_addr, "STUN probe timeout"); + } + Err(e) => { + debug!(error = %e, "STUN probe task join failed"); + } + } + } + + let mut out = DualStunResult::default(); + if let Some((_, best)) = best_v4_by_ip + .into_values() + .max_by_key(|(count, _)| *count) + { + info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); + out.v4 = Some(best); + } + if let Some((_, best)) = best_v6_by_ip + .into_values() + .max_by_key(|(count, _)| *count) + { + info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); + out.v6 = Some(best); + } + out +} + pub fn decide_network_capabilities(config: &NetworkConfig, probe: &NetworkProbe) -> NetworkDecision { let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some(); let ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some(); diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index e5aebe4..c95457b 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -50,6 +50,8 @@ pub struct MePool { pub(super) nat_probe: bool, pub(super) nat_stun: Option, pub(super) nat_stun_servers: Vec, + pub(super) nat_stun_live_servers: Arc>>, + pub(super) nat_probe_concurrency: usize, pub(super) detected_ipv6: Option, pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8, pub(super) nat_probe_disabled: std::sync::atomic::AtomicBool, @@ -120,6 +122,7 @@ impl MePool { nat_probe: bool, nat_stun: Option, nat_stun_servers: Vec, + nat_probe_concurrency: usize, detected_ipv6: Option, me_one_retry: u8, me_one_timeout_ms: u64, @@ -162,6 +165,8 @@ impl MePool { nat_probe, nat_stun, nat_stun_servers, + nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())), + nat_probe_concurrency: nat_probe_concurrency.max(1), detected_ipv6, nat_probe_attempts: std::sync::atomic::AtomicU8::new(0), nat_probe_disabled: std::sync::atomic::AtomicBool::new(false), @@ -241,6 +246,9 @@ impl MePool { pub fn reset_stun_state(&self) { self.nat_probe_attempts.store(0, Ordering::Relaxed); self.nat_probe_disabled.store(false, Ordering::Relaxed); + if let Ok(mut live) = self.nat_stun_live_servers.try_write() { + live.clear(); + } } pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr { @@ -896,10 +904,25 @@ impl MePool { for family in family_order { let map = self.proxy_map_for_family(family).await; - let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map - .iter() - .map(|(dc, addrs)| (*dc, addrs.clone())) + let mut grouped_dc_addrs: HashMap> = HashMap::new(); + for (dc, addrs) in map { + if addrs.is_empty() { + continue; + } + grouped_dc_addrs + .entry(dc.abs()) + .or_default() + .extend(addrs); + } + let mut dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = grouped_dc_addrs + .into_iter() + .map(|(dc, mut addrs)| { + addrs.sort_unstable(); + addrs.dedup(); + (dc, addrs) + }) .collect(); + dc_addrs.sort_unstable_by_key(|(dc, _)| *dc); // Ensure at least one connection per DC; run DCs in parallel. let mut join = tokio::task::JoinSet::new(); @@ -923,38 +946,49 @@ impl MePool { return Err(ProxyError::Proxy("Too many ME DC init failures, falling back to direct".into())); } - // Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles. - if self.me_warmup_stagger_enabled { - for (dc, addrs) in dc_addrs.iter() { - for (ip, port) in addrs { - if self.connection_count() >= pool_size { - break; + // Warm reserve writers asynchronously so startup does not block after first working pool is ready. + let pool = Arc::clone(self); + let rng_clone = Arc::clone(rng); + let dc_addrs_bg = dc_addrs.clone(); + tokio::spawn(async move { + if pool.me_warmup_stagger_enabled { + for (dc, addrs) in dc_addrs_bg.iter() { + for (ip, port) in addrs { + if pool.connection_count() >= pool_size { + break; + } + let addr = SocketAddr::new(*ip, *port); + let jitter = rand::rng() + .random_range(0..=pool.me_warmup_step_jitter.as_millis() as u64); + let delay_ms = pool.me_warmup_step_delay.as_millis() as u64 + jitter; + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { + debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); + } + } } - let addr = SocketAddr::new(*ip, *port); - let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64); - let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter; - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - if let Err(e) = self.connect_one(addr, rng.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); + } else { + for (dc, addrs) in dc_addrs_bg.iter() { + for (ip, port) in addrs { + if pool.connection_count() >= pool_size { + break; + } + let addr = SocketAddr::new(*ip, *port); + if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { + debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); + } + } + if pool.connection_count() >= pool_size { + break; + } } } - } - } else { - for (dc, addrs) in dc_addrs.iter() { - for (ip, port) in addrs { - if self.connection_count() >= pool_size { - break; - } - let addr = SocketAddr::new(*ip, *port); - if let Err(e) = self.connect_one(addr, rng.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); - } - } - if self.connection_count() >= pool_size { - break; - } - } - } + debug!( + target_pool_size = pool_size, + current_pool_size = pool.connection_count(), + "Background ME reserve warmup finished" + ); + }); if !self.decision.effective_multipath && self.connection_count() > 0 { break; @@ -964,6 +998,10 @@ impl MePool { if self.writers.read().await.is_empty() { return Err(ProxyError::Proxy("No ME connections".into())); } + info!( + active_writers = self.connection_count(), + "ME primary pool ready; reserve warmup continues in background" + ); Ok(()) } diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index 9936707..37c0d5b 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -1,7 +1,10 @@ +use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr}; use std::time::Duration; -use tracing::{info, warn}; +use tokio::task::JoinSet; +use tokio::time::timeout; +use tracing::{debug, info, warn}; use crate::error::{ProxyError, Result}; use crate::network::probe::is_bogon; @@ -10,6 +13,8 @@ use crate::network::stun::{stun_probe_dual, IpFamily, StunProbeResult}; use super::MePool; use std::time::Instant; +const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); + #[allow(dead_code)] pub async fn stun_probe(stun_addr: Option) -> Result { let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); @@ -22,6 +27,99 @@ pub async fn detect_public_ip() -> Option { } impl MePool { + fn configured_stun_servers(&self) -> Vec { + if !self.nat_stun_servers.is_empty() { + return self.nat_stun_servers.clone(); + } + if let Some(s) = &self.nat_stun { + return vec![s.clone()]; + } + vec!["stun.l.google.com:19302".to_string()] + } + + async fn probe_stun_batch_for_family( + &self, + servers: &[String], + family: IpFamily, + attempt: u8, + ) -> (Vec, Option) { + let mut join_set = JoinSet::new(); + let mut next_idx = 0usize; + let mut live_servers = Vec::new(); + let mut best_by_ip: HashMap = HashMap::new(); + let concurrency = self.nat_probe_concurrency.max(1); + + while next_idx < servers.len() || !join_set.is_empty() { + while next_idx < servers.len() && join_set.len() < concurrency { + let stun_addr = servers[next_idx].clone(); + next_idx += 1; + join_set.spawn(async move { + let res = timeout(STUN_BATCH_TIMEOUT, stun_probe_dual(&stun_addr)).await; + (stun_addr, res) + }); + } + + let Some(task) = join_set.join_next().await else { + break; + }; + + match task { + Ok((stun_addr, Ok(Ok(res)))) => { + let picked: Option = match family { + IpFamily::V4 => res.v4, + IpFamily::V6 => res.v6, + }; + + if let Some(result) = picked { + live_servers.push(stun_addr.clone()); + let entry = best_by_ip + .entry(result.reflected_addr.ip()) + .or_insert((0, result.reflected_addr)); + entry.0 += 1; + debug!( + local = %result.local_addr, + reflected = %result.reflected_addr, + family = ?family, + stun = %stun_addr, + "NAT probe: reflected address" + ); + } + } + Ok((stun_addr, Ok(Err(e)))) => { + debug!( + error = %e, + stun = %stun_addr, + attempt = attempt + 1, + "NAT probe failed, trying next server" + ); + } + Ok((stun_addr, Err(_))) => { + debug!( + stun = %stun_addr, + attempt = attempt + 1, + "NAT probe timeout, trying next server" + ); + } + Err(e) => { + debug!( + error = %e, + attempt = attempt + 1, + "NAT probe task join failed" + ); + } + } + } + + live_servers.sort_unstable(); + live_servers.dedup(); + let best_reflected = best_by_ip + .into_values() + .max_by_key(|(count, _)| *count) + .map(|(_, addr)| addr); + + (live_servers, best_reflected) + } + pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { let nat_ip = self .nat_ip_cfg @@ -128,39 +226,51 @@ impl MePool { } let attempt = self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let servers = if !self.nat_stun_servers.is_empty() { - self.nat_stun_servers.clone() - } else if let Some(s) = &self.nat_stun { - vec![s.clone()] + let configured_servers = self.configured_stun_servers(); + let live_snapshot = self.nat_stun_live_servers.read().await.clone(); + let primary_servers = if live_snapshot.is_empty() { + configured_servers.clone() } else { - vec!["stun.l.google.com:19302".to_string()] + live_snapshot }; - for stun_addr in servers { - match stun_probe_dual(&stun_addr).await { - Ok(res) => { - let picked: Option = match family { - IpFamily::V4 => res.v4, - IpFamily::V6 => res.v6, - }; - if let Some(result) = picked { - info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, stun = %stun_addr, "NAT probe: reflected address"); - self.nat_probe_attempts.store(0, std::sync::atomic::Ordering::Relaxed); - if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { - let slot = match family { - IpFamily::V4 => &mut cache.v4, - IpFamily::V6 => &mut cache.v6, - }; - *slot = Some((Instant::now(), result.reflected_addr)); - } - return Some(result.reflected_addr); - } - } - Err(e) => { - warn!(error = %e, stun = %stun_addr, attempt = attempt + 1, "NAT probe failed, trying next server"); - } - } + let (mut live_servers, mut selected_reflected) = self + .probe_stun_batch_for_family(&primary_servers, family, attempt) + .await; + + if selected_reflected.is_none() && !configured_servers.is_empty() && primary_servers != configured_servers { + let (rediscovered_live, rediscovered_reflected) = self + .probe_stun_batch_for_family(&configured_servers, family, attempt) + .await; + live_servers = rediscovered_live; + selected_reflected = rediscovered_reflected; } + + let live_server_count = live_servers.len(); + if !live_servers.is_empty() { + *self.nat_stun_live_servers.write().await = live_servers; + } else { + self.nat_stun_live_servers.write().await.clear(); + } + + if let Some(reflected_addr) = selected_reflected { + self.nat_probe_attempts.store(0, std::sync::atomic::Ordering::Relaxed); + info!( + family = ?family, + live_servers = live_server_count, + "STUN-Quorum reached, IP: {}", + reflected_addr.ip() + ); + if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + let slot = match family { + IpFamily::V4 => &mut cache.v4, + IpFamily::V6 => &mut cache.v6, + }; + *slot = Some((Instant::now(), reflected_addr)); + } + return Some(reflected_addr); + } + let backoff = Duration::from_secs(60 * 2u64.pow((attempt as u32).min(6))); *self.stun_backoff_until.write().await = Some(Instant::now() + backoff); None