From 7782336264368cbcab383e03a286cc689bbf445b Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 26 Feb 2026 17:56:22 +0300 Subject: [PATCH 1/7] ME Probe parallelized --- src/config/defaults.rs | 4 + src/config/hot_reload.rs | 3 + src/config/load.rs | 26 ++++ src/config/types.rs | 5 + src/main.rs | 3 + src/network/probe.rs | 171 +++++++++++++++++++++++-- src/transport/middle_proxy/pool.rs | 102 ++++++++++----- src/transport/middle_proxy/pool_nat.rs | 170 +++++++++++++++++++----- 8 files changed, 411 insertions(+), 73 deletions(-) diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 3fb8c3d..4f0a53d 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -132,6 +132,10 @@ pub(crate) fn default_middle_proxy_nat_stun_servers() -> Vec { ] } +pub(crate) fn default_stun_nat_probe_concurrency() -> usize { + 8 +} + pub(crate) fn default_middle_proxy_warm_standby() -> usize { DEFAULT_MIDDLE_PROXY_WARM_STANDBY } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 7f121f6..c949104 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -96,6 +96,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig) { if old.general.use_middle_proxy != new.general.use_middle_proxy { warn!("config reload: use_middle_proxy changed; restart required"); } + if old.general.stun_nat_probe_concurrency != new.general.stun_nat_probe_concurrency { + warn!("config reload: general.stun_nat_probe_concurrency changed; restart required"); + } } /// Resolve the public host for link generation — mirrors the logic in main.rs. 
diff --git a/src/config/load.rs b/src/config/load.rs index 35099be..31a8b5d 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -166,6 +166,12 @@ impl ProxyConfig { } } + if config.general.stun_nat_probe_concurrency == 0 { + return Err(ProxyError::Config( + "general.stun_nat_probe_concurrency must be > 0".to_string(), + )); + } + if config.general.me_reinit_every_secs == 0 { return Err(ProxyError::Config( "general.me_reinit_every_secs must be > 0".to_string(), @@ -607,6 +613,26 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn stun_nat_probe_concurrency_zero_is_rejected() { + let toml = r#" + [general] + stun_nat_probe_concurrency = 0 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_stun_nat_probe_concurrency_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.stun_nat_probe_concurrency must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn me_reinit_every_default_is_set() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index e827088..58a3a3e 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -168,6 +168,10 @@ pub struct GeneralConfig { #[serde(default = "default_middle_proxy_nat_stun_servers")] pub middle_proxy_nat_stun_servers: Vec, + /// Maximum number of concurrent STUN probes during NAT detection. + #[serde(default = "default_stun_nat_probe_concurrency")] + pub stun_nat_probe_concurrency: usize, + /// Desired size of active Middle-Proxy writer pool. 
#[serde(default = "default_pool_size")] pub middle_proxy_pool_size: usize, @@ -378,6 +382,7 @@ impl Default for GeneralConfig { middle_proxy_nat_probe: default_true(), middle_proxy_nat_stun: default_middle_proxy_nat_stun(), middle_proxy_nat_stun_servers: default_middle_proxy_nat_stun_servers(), + stun_nat_probe_concurrency: default_stun_nat_probe_concurrency(), middle_proxy_pool_size: default_pool_size(), middle_proxy_warm_standby: default_middle_proxy_warm_standby(), me_keepalive_enabled: default_true(), diff --git a/src/main.rs b/src/main.rs index e4f7a79..dd4feb8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -257,7 +257,9 @@ async fn main() -> std::result::Result<(), Box> { let probe = run_probe( &config.network, config.general.middle_proxy_nat_stun.clone(), + config.general.middle_proxy_nat_stun_servers.clone(), config.general.middle_proxy_nat_probe, + config.general.stun_nat_probe_concurrency, ) .await?; let decision = decide_network_capabilities(&config.network, &probe); @@ -360,6 +362,7 @@ async fn main() -> std::result::Result<(), Box> { config.general.middle_proxy_nat_probe, config.general.middle_proxy_nat_stun.clone(), config.general.middle_proxy_nat_stun_servers.clone(), + config.general.stun_nat_probe_concurrency, probe.detected_ipv6, config.timeouts.me_one_retry, config.timeouts.me_one_timeout_ms, diff --git a/src/network/probe.rs b/src/network/probe.rs index c52b340..378faa5 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -1,12 +1,16 @@ #![allow(dead_code)] +use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; +use std::time::Duration; -use tracing::{info, warn}; +use tokio::task::JoinSet; +use tokio::time::timeout; +use tracing::{debug, info, warn}; use crate::config::NetworkConfig; use crate::error::Result; -use crate::network::stun::{stun_probe_dual, DualStunResult, IpFamily}; +use crate::network::stun::{stun_probe_dual, DualStunResult, IpFamily, StunProbeResult}; #[derive(Debug, Clone, 
Default)] pub struct NetworkProbe { @@ -49,7 +53,15 @@ impl NetworkDecision { } } -pub async fn run_probe(config: &NetworkConfig, stun_addr: Option, nat_probe: bool) -> Result { +const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); + +pub async fn run_probe( + config: &NetworkConfig, + stun_addr: Option, + stun_servers: Vec, + nat_probe: bool, + stun_nat_probe_concurrency: usize, +) -> Result { let mut probe = NetworkProbe::default(); probe.detected_ipv4 = detect_local_ip_v4(); @@ -58,21 +70,30 @@ pub async fn run_probe(config: &NetworkConfig, stun_addr: Option, nat_pr probe.ipv4_is_bogon = probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false); probe.ipv6_is_bogon = probe.detected_ipv6.map(is_bogon_v6).unwrap_or(false); - let stun_server = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); let stun_res = if nat_probe { - match stun_probe_dual(&stun_server).await { - Ok(res) => res, - Err(e) => { - warn!(error = %e, "STUN probe failed, continuing without reflection"); - DualStunResult::default() - } - } + let servers = collect_stun_servers(config, stun_addr, stun_servers); + probe_stun_servers_parallel( + &servers, + stun_nat_probe_concurrency.max(1), + ) + .await } else { DualStunResult::default() }; probe.reflected_ipv4 = stun_res.v4.map(|r| r.reflected_addr); probe.reflected_ipv6 = stun_res.v6.map(|r| r.reflected_addr); + // If STUN is blocked but IPv4 is private, try HTTP public-IP fallback. 
+ if nat_probe + && probe.reflected_ipv4.is_none() + && probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false) + { + if let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await { + probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0)); + info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback"); + } + } + probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) { (Some(det), Some(reflected)) => det != reflected.ip(), _ => false, @@ -94,6 +115,134 @@ pub async fn run_probe(config: &NetworkConfig, stun_addr: Option, nat_pr Ok(probe) } +async fn detect_public_ipv4_http(urls: &[String]) -> Option { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(3)) + .build() + .ok()?; + + for url in urls { + let response = match client.get(url).send().await { + Ok(response) => response, + Err(_) => continue, + }; + + let body = match response.text().await { + Ok(body) => body, + Err(_) => continue, + }; + + let Ok(ip) = body.trim().parse::() else { + continue; + }; + if !is_bogon_v4(ip) { + return Some(ip); + } + } + + None +} + +fn collect_stun_servers( + config: &NetworkConfig, + stun_addr: Option, + stun_servers: Vec, +) -> Vec { + let mut out = Vec::new(); + if !stun_servers.is_empty() { + for s in stun_servers { + if !s.is_empty() && !out.contains(&s) { + out.push(s); + } + } + } else if let Some(s) = stun_addr + && !s.is_empty() + { + out.push(s); + } + + if out.is_empty() { + for s in &config.stun_servers { + if !s.is_empty() && !out.contains(s) { + out.push(s.clone()); + } + } + } + + if out.is_empty() { + out.push("stun.l.google.com:19302".to_string()); + } + + out +} + +async fn probe_stun_servers_parallel( + servers: &[String], + concurrency: usize, +) -> DualStunResult { + let mut join_set = JoinSet::new(); + let mut next_idx = 0usize; + let mut best_v4_by_ip: HashMap = HashMap::new(); + let mut best_v6_by_ip: HashMap = HashMap::new(); + + while next_idx < 
servers.len() || !join_set.is_empty() { + while next_idx < servers.len() && join_set.len() < concurrency { + let stun_addr = servers[next_idx].clone(); + next_idx += 1; + join_set.spawn(async move { + let res = timeout(STUN_BATCH_TIMEOUT, stun_probe_dual(&stun_addr)).await; + (stun_addr, res) + }); + } + + let Some(task) = join_set.join_next().await else { + break; + }; + + match task { + Ok((stun_addr, Ok(Ok(result)))) => { + if let Some(v4) = result.v4 { + let entry = best_v4_by_ip.entry(v4.reflected_addr.ip()).or_insert((0, v4)); + entry.0 += 1; + } + if let Some(v6) = result.v6 { + let entry = best_v6_by_ip.entry(v6.reflected_addr.ip()).or_insert((0, v6)); + entry.0 += 1; + } + if result.v4.is_some() || result.v6.is_some() { + debug!(stun = %stun_addr, "STUN server responded within probe timeout"); + } + } + Ok((stun_addr, Ok(Err(e)))) => { + debug!(error = %e, stun = %stun_addr, "STUN probe failed"); + } + Ok((stun_addr, Err(_))) => { + debug!(stun = %stun_addr, "STUN probe timeout"); + } + Err(e) => { + debug!(error = %e, "STUN probe task join failed"); + } + } + } + + let mut out = DualStunResult::default(); + if let Some((_, best)) = best_v4_by_ip + .into_values() + .max_by_key(|(count, _)| *count) + { + info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); + out.v4 = Some(best); + } + if let Some((_, best)) = best_v6_by_ip + .into_values() + .max_by_key(|(count, _)| *count) + { + info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); + out.v6 = Some(best); + } + out +} + pub fn decide_network_capabilities(config: &NetworkConfig, probe: &NetworkProbe) -> NetworkDecision { let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some(); let ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some(); diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index e5aebe4..c95457b 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -50,6 +50,8 
@@ pub struct MePool { pub(super) nat_probe: bool, pub(super) nat_stun: Option, pub(super) nat_stun_servers: Vec, + pub(super) nat_stun_live_servers: Arc>>, + pub(super) nat_probe_concurrency: usize, pub(super) detected_ipv6: Option, pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8, pub(super) nat_probe_disabled: std::sync::atomic::AtomicBool, @@ -120,6 +122,7 @@ impl MePool { nat_probe: bool, nat_stun: Option, nat_stun_servers: Vec, + nat_probe_concurrency: usize, detected_ipv6: Option, me_one_retry: u8, me_one_timeout_ms: u64, @@ -162,6 +165,8 @@ impl MePool { nat_probe, nat_stun, nat_stun_servers, + nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())), + nat_probe_concurrency: nat_probe_concurrency.max(1), detected_ipv6, nat_probe_attempts: std::sync::atomic::AtomicU8::new(0), nat_probe_disabled: std::sync::atomic::AtomicBool::new(false), @@ -241,6 +246,9 @@ impl MePool { pub fn reset_stun_state(&self) { self.nat_probe_attempts.store(0, Ordering::Relaxed); self.nat_probe_disabled.store(false, Ordering::Relaxed); + if let Ok(mut live) = self.nat_stun_live_servers.try_write() { + live.clear(); + } } pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr { @@ -896,10 +904,25 @@ impl MePool { for family in family_order { let map = self.proxy_map_for_family(family).await; - let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map - .iter() - .map(|(dc, addrs)| (*dc, addrs.clone())) + let mut grouped_dc_addrs: HashMap> = HashMap::new(); + for (dc, addrs) in map { + if addrs.is_empty() { + continue; + } + grouped_dc_addrs + .entry(dc.abs()) + .or_default() + .extend(addrs); + } + let mut dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = grouped_dc_addrs + .into_iter() + .map(|(dc, mut addrs)| { + addrs.sort_unstable(); + addrs.dedup(); + (dc, addrs) + }) .collect(); + dc_addrs.sort_unstable_by_key(|(dc, _)| *dc); // Ensure at least one connection per DC; run DCs in parallel. 
let mut join = tokio::task::JoinSet::new(); @@ -923,38 +946,49 @@ impl MePool { return Err(ProxyError::Proxy("Too many ME DC init failures, falling back to direct".into())); } - // Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles. - if self.me_warmup_stagger_enabled { - for (dc, addrs) in dc_addrs.iter() { - for (ip, port) in addrs { - if self.connection_count() >= pool_size { - break; + // Warm reserve writers asynchronously so startup does not block after first working pool is ready. + let pool = Arc::clone(self); + let rng_clone = Arc::clone(rng); + let dc_addrs_bg = dc_addrs.clone(); + tokio::spawn(async move { + if pool.me_warmup_stagger_enabled { + for (dc, addrs) in dc_addrs_bg.iter() { + for (ip, port) in addrs { + if pool.connection_count() >= pool_size { + break; + } + let addr = SocketAddr::new(*ip, *port); + let jitter = rand::rng() + .random_range(0..=pool.me_warmup_step_jitter.as_millis() as u64); + let delay_ms = pool.me_warmup_step_delay.as_millis() as u64 + jitter; + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { + debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); + } + } } - let addr = SocketAddr::new(*ip, *port); - let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64); - let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter; - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - if let Err(e) = self.connect_one(addr, rng.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); + } else { + for (dc, addrs) in dc_addrs_bg.iter() { + for (ip, port) in addrs { + if pool.connection_count() >= pool_size { + break; + } + let addr = SocketAddr::new(*ip, *port); + if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { + debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); + } + 
} + if pool.connection_count() >= pool_size { + break; + } } } - } - } else { - for (dc, addrs) in dc_addrs.iter() { - for (ip, port) in addrs { - if self.connection_count() >= pool_size { - break; - } - let addr = SocketAddr::new(*ip, *port); - if let Err(e) = self.connect_one(addr, rng.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); - } - } - if self.connection_count() >= pool_size { - break; - } - } - } + debug!( + target_pool_size = pool_size, + current_pool_size = pool.connection_count(), + "Background ME reserve warmup finished" + ); + }); if !self.decision.effective_multipath && self.connection_count() > 0 { break; @@ -964,6 +998,10 @@ impl MePool { if self.writers.read().await.is_empty() { return Err(ProxyError::Proxy("No ME connections".into())); } + info!( + active_writers = self.connection_count(), + "ME primary pool ready; reserve warmup continues in background" + ); Ok(()) } diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index 9936707..37c0d5b 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -1,7 +1,10 @@ +use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr}; use std::time::Duration; -use tracing::{info, warn}; +use tokio::task::JoinSet; +use tokio::time::timeout; +use tracing::{debug, info, warn}; use crate::error::{ProxyError, Result}; use crate::network::probe::is_bogon; @@ -10,6 +13,8 @@ use crate::network::stun::{stun_probe_dual, IpFamily, StunProbeResult}; use super::MePool; use std::time::Instant; +const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); + #[allow(dead_code)] pub async fn stun_probe(stun_addr: Option) -> Result { let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); @@ -22,6 +27,99 @@ pub async fn detect_public_ip() -> Option { } impl MePool { + fn configured_stun_servers(&self) -> Vec { + if !self.nat_stun_servers.is_empty() { + return 
self.nat_stun_servers.clone(); + } + if let Some(s) = &self.nat_stun { + return vec![s.clone()]; + } + vec!["stun.l.google.com:19302".to_string()] + } + + async fn probe_stun_batch_for_family( + &self, + servers: &[String], + family: IpFamily, + attempt: u8, + ) -> (Vec, Option) { + let mut join_set = JoinSet::new(); + let mut next_idx = 0usize; + let mut live_servers = Vec::new(); + let mut best_by_ip: HashMap = HashMap::new(); + let concurrency = self.nat_probe_concurrency.max(1); + + while next_idx < servers.len() || !join_set.is_empty() { + while next_idx < servers.len() && join_set.len() < concurrency { + let stun_addr = servers[next_idx].clone(); + next_idx += 1; + join_set.spawn(async move { + let res = timeout(STUN_BATCH_TIMEOUT, stun_probe_dual(&stun_addr)).await; + (stun_addr, res) + }); + } + + let Some(task) = join_set.join_next().await else { + break; + }; + + match task { + Ok((stun_addr, Ok(Ok(res)))) => { + let picked: Option = match family { + IpFamily::V4 => res.v4, + IpFamily::V6 => res.v6, + }; + + if let Some(result) = picked { + live_servers.push(stun_addr.clone()); + let entry = best_by_ip + .entry(result.reflected_addr.ip()) + .or_insert((0, result.reflected_addr)); + entry.0 += 1; + debug!( + local = %result.local_addr, + reflected = %result.reflected_addr, + family = ?family, + stun = %stun_addr, + "NAT probe: reflected address" + ); + } + } + Ok((stun_addr, Ok(Err(e)))) => { + debug!( + error = %e, + stun = %stun_addr, + attempt = attempt + 1, + "NAT probe failed, trying next server" + ); + } + Ok((stun_addr, Err(_))) => { + debug!( + stun = %stun_addr, + attempt = attempt + 1, + "NAT probe timeout, trying next server" + ); + } + Err(e) => { + debug!( + error = %e, + attempt = attempt + 1, + "NAT probe task join failed" + ); + } + } + } + + live_servers.sort_unstable(); + live_servers.dedup(); + let best_reflected = best_by_ip + .into_values() + .max_by_key(|(count, _)| *count) + .map(|(_, addr)| addr); + + (live_servers, best_reflected) 
+ } + pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { let nat_ip = self .nat_ip_cfg @@ -128,39 +226,51 @@ impl MePool { } let attempt = self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let servers = if !self.nat_stun_servers.is_empty() { - self.nat_stun_servers.clone() - } else if let Some(s) = &self.nat_stun { - vec![s.clone()] + let configured_servers = self.configured_stun_servers(); + let live_snapshot = self.nat_stun_live_servers.read().await.clone(); + let primary_servers = if live_snapshot.is_empty() { + configured_servers.clone() } else { - vec!["stun.l.google.com:19302".to_string()] + live_snapshot }; - for stun_addr in servers { - match stun_probe_dual(&stun_addr).await { - Ok(res) => { - let picked: Option = match family { - IpFamily::V4 => res.v4, - IpFamily::V6 => res.v6, - }; - if let Some(result) = picked { - info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, stun = %stun_addr, "NAT probe: reflected address"); - self.nat_probe_attempts.store(0, std::sync::atomic::Ordering::Relaxed); - if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { - let slot = match family { - IpFamily::V4 => &mut cache.v4, - IpFamily::V6 => &mut cache.v6, - }; - *slot = Some((Instant::now(), result.reflected_addr)); - } - return Some(result.reflected_addr); - } - } - Err(e) => { - warn!(error = %e, stun = %stun_addr, attempt = attempt + 1, "NAT probe failed, trying next server"); - } - } + let (mut live_servers, mut selected_reflected) = self + .probe_stun_batch_for_family(&primary_servers, family, attempt) + .await; + + if selected_reflected.is_none() && !configured_servers.is_empty() && primary_servers != configured_servers { + let (rediscovered_live, rediscovered_reflected) = self + .probe_stun_batch_for_family(&configured_servers, family, attempt) + .await; + live_servers = rediscovered_live; + selected_reflected = rediscovered_reflected; } + + let live_server_count = 
live_servers.len(); + if !live_servers.is_empty() { + *self.nat_stun_live_servers.write().await = live_servers; + } else { + self.nat_stun_live_servers.write().await.clear(); + } + + if let Some(reflected_addr) = selected_reflected { + self.nat_probe_attempts.store(0, std::sync::atomic::Ordering::Relaxed); + info!( + family = ?family, + live_servers = live_server_count, + "STUN-Quorum reached, IP: {}", + reflected_addr.ip() + ); + if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + let slot = match family { + IpFamily::V4 => &mut cache.v4, + IpFamily::V6 => &mut cache.v6, + }; + *slot = Some((Instant::now(), reflected_addr)); + } + return Some(reflected_addr); + } + let backoff = Duration::from_secs(60 * 2u64.pow((attempt as u32).min(6))); *self.stun_backoff_until.write().await = Some(Instant::now() + backoff); None From 9d2ff25bf577cd328dca4686226b39d2a620aeba Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 26 Feb 2026 18:18:24 +0300 Subject: [PATCH 2/7] Unified STUN + ME Primary parallelized - Unified STUN server source-of-truth - parallelize per-DC primary ME init for multi-endpoint DCs --- src/config/defaults.rs | 18 ++-------- src/config/load.rs | 34 ++++++++++++++++++ src/config/types.rs | 6 ++-- src/main.rs | 6 ++-- src/network/probe.rs | 50 ++++++++------------------ src/transport/middle_proxy/pool.rs | 29 +++++++++++++++ src/transport/middle_proxy/pool_nat.rs | 16 +++++++-- 7 files changed, 99 insertions(+), 60 deletions(-) diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 4f0a53d..d82f8ed 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -111,25 +111,11 @@ pub(crate) fn default_proxy_secret_path() -> Option { } pub(crate) fn default_middle_proxy_nat_stun() -> Option { - Some("stun.l.google.com:19302".to_string()) + None } pub(crate) fn default_middle_proxy_nat_stun_servers() -> Vec { - vec![ - "stun.l.google.com:5349".to_string(), - 
"stun1.l.google.com:3478".to_string(), - "stun.gmx.net:3478".to_string(), - "stun.l.google.com:19302".to_string(), - "stun.1und1.de:3478".to_string(), - "stun1.l.google.com:19302".to_string(), - "stun2.l.google.com:19302".to_string(), - "stun3.l.google.com:19302".to_string(), - "stun4.l.google.com:19302".to_string(), - "stun.services.mozilla.com:3478".to_string(), - "stun.stunprotocol.org:3478".to_string(), - "stun.nextcloud.com:3478".to_string(), - "stun.voip.eutelia.it:3478".to_string(), - ] + Vec::new() } pub(crate) fn default_stun_nat_probe_concurrency() -> usize { diff --git a/src/config/load.rs b/src/config/load.rs index 31a8b5d..0c1e629 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -65,6 +65,16 @@ fn validate_network_cfg(net: &mut NetworkConfig) -> Result<()> { Ok(()) } +fn push_unique_nonempty(target: &mut Vec, value: String) { + let trimmed = value.trim(); + if trimmed.is_empty() { + return; + } + if !target.iter().any(|existing| existing == trimmed) { + target.push(trimmed.to_string()); + } +} + // ============= Main Config ============= #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -138,6 +148,30 @@ impl ProxyConfig { config.general.update_every = None; } + let legacy_nat_stun = config.general.middle_proxy_nat_stun.take(); + let legacy_nat_stun_servers = std::mem::take(&mut config.general.middle_proxy_nat_stun_servers); + let legacy_nat_stun_used = legacy_nat_stun.is_some() || !legacy_nat_stun_servers.is_empty(); + + let mut unified_stun_servers = Vec::new(); + for stun in std::mem::take(&mut config.network.stun_servers) { + push_unique_nonempty(&mut unified_stun_servers, stun); + } + if let Some(stun) = legacy_nat_stun { + push_unique_nonempty(&mut unified_stun_servers, stun); + } + for stun in legacy_nat_stun_servers { + push_unique_nonempty(&mut unified_stun_servers, stun); + } + + if unified_stun_servers.is_empty() { + unified_stun_servers = default_stun_servers(); + } + config.network.stun_servers = unified_stun_servers; 
+ + if legacy_nat_stun_used { + warn!("general.middle_proxy_nat_stun and general.middle_proxy_nat_stun_servers are deprecated; use network.stun_servers"); + } + if let Some(update_every) = config.general.update_every { if update_every == 0 { return Err(ProxyError::Config( diff --git a/src/config/types.rs b/src/config/types.rs index 58a3a3e..68086be 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -160,11 +160,13 @@ pub struct GeneralConfig { #[serde(default = "default_true")] pub middle_proxy_nat_probe: bool, - /// Optional STUN server address (host:port) for NAT probing. + /// Deprecated legacy single STUN server for NAT probing. + /// Use `network.stun_servers` instead. #[serde(default = "default_middle_proxy_nat_stun")] pub middle_proxy_nat_stun: Option, - /// Optional list of STUN servers for NAT probing fallback. + /// Deprecated legacy STUN list for NAT probing fallback. + /// Use `network.stun_servers` instead. #[serde(default = "default_middle_proxy_nat_stun_servers")] pub middle_proxy_nat_stun_servers: Vec, diff --git a/src/main.rs b/src/main.rs index dd4feb8..db46a6d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -256,8 +256,6 @@ async fn main() -> std::result::Result<(), Box> { let probe = run_probe( &config.network, - config.general.middle_proxy_nat_stun.clone(), - config.general.middle_proxy_nat_stun_servers.clone(), config.general.middle_proxy_nat_probe, config.general.stun_nat_probe_concurrency, ) @@ -360,8 +358,8 @@ async fn main() -> std::result::Result<(), Box> { proxy_secret, config.general.middle_proxy_nat_ip, config.general.middle_proxy_nat_probe, - config.general.middle_proxy_nat_stun.clone(), - config.general.middle_proxy_nat_stun_servers.clone(), + None, + config.network.stun_servers.clone(), config.general.stun_nat_probe_concurrency, probe.detected_ipv6, config.timeouts.me_one_retry, diff --git a/src/network/probe.rs b/src/network/probe.rs index 378faa5..6e84682 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ 
-57,8 +57,6 @@ const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); pub async fn run_probe( config: &NetworkConfig, - stun_addr: Option, - stun_servers: Vec, nat_probe: bool, stun_nat_probe_concurrency: usize, ) -> Result { @@ -71,12 +69,17 @@ pub async fn run_probe( probe.ipv6_is_bogon = probe.detected_ipv6.map(is_bogon_v6).unwrap_or(false); let stun_res = if nat_probe { - let servers = collect_stun_servers(config, stun_addr, stun_servers); - probe_stun_servers_parallel( - &servers, - stun_nat_probe_concurrency.max(1), - ) - .await + let servers = collect_stun_servers(config); + if servers.is_empty() { + warn!("STUN probe is enabled but network.stun_servers is empty"); + DualStunResult::default() + } else { + probe_stun_servers_parallel( + &servers, + stun_nat_probe_concurrency.max(1), + ) + .await + } } else { DualStunResult::default() }; @@ -143,36 +146,13 @@ async fn detect_public_ipv4_http(urls: &[String]) -> Option { None } -fn collect_stun_servers( - config: &NetworkConfig, - stun_addr: Option, - stun_servers: Vec, -) -> Vec { +fn collect_stun_servers(config: &NetworkConfig) -> Vec { let mut out = Vec::new(); - if !stun_servers.is_empty() { - for s in stun_servers { - if !s.is_empty() && !out.contains(&s) { - out.push(s); - } - } - } else if let Some(s) = stun_addr - && !s.is_empty() - { - out.push(s); - } - - if out.is_empty() { - for s in &config.stun_servers { - if !s.is_empty() && !out.contains(s) { - out.push(s.clone()); - } + for s in &config.stun_servers { + if !s.is_empty() && !out.contains(s) { + out.push(s.clone()); } } - - if out.is_empty() { - out.push("stun.l.google.com:19302".to_string()); - } - out } diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index c95457b..21c2b87 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -1199,6 +1199,35 @@ impl MePool { return false; } addrs.shuffle(&mut rand::rng()); + if addrs.len() > 1 { + let mut join = 
tokio::task::JoinSet::new(); + for (ip, port) in addrs { + let addr = SocketAddr::new(ip, port); + let pool = Arc::clone(&self); + let rng_clone = Arc::clone(&rng); + join.spawn(async move { (addr, pool.connect_one(addr, rng_clone.as_ref()).await) }); + } + + while let Some(res) = join.join_next().await { + match res { + Ok((addr, Ok(()))) => { + info!(%addr, dc = %dc, "ME connected"); + join.abort_all(); + while join.join_next().await.is_some() {} + return true; + } + Ok((addr, Err(e))) => { + warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"); + } + Err(e) => { + warn!(dc = %dc, error = %e, "ME connect task failed"); + } + } + } + warn!(dc = %dc, "All ME servers for DC failed at init"); + return false; + } + for (ip, port) in addrs { let addr = SocketAddr::new(ip, port); match self.connect_one(addr, rng.as_ref()).await { diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index 37c0d5b..7141236 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -17,7 +17,15 @@ const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5); #[allow(dead_code)] pub async fn stun_probe(stun_addr: Option) -> Result { - let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); + let stun_addr = stun_addr.unwrap_or_else(|| { + crate::config::defaults::default_stun_servers() + .into_iter() + .next() + .unwrap_or_default() + }); + if stun_addr.is_empty() { + return Err(ProxyError::Proxy("STUN server is not configured".to_string())); + } stun_probe_dual(&stun_addr).await } @@ -31,10 +39,12 @@ impl MePool { if !self.nat_stun_servers.is_empty() { return self.nat_stun_servers.clone(); } - if let Some(s) = &self.nat_stun { + if let Some(s) = &self.nat_stun + && !s.trim().is_empty() + { return vec![s.clone()]; } - vec!["stun.l.google.com:19302".to_string()] + Vec::new() } async fn probe_stun_batch_for_family( From 1f255d0aa493460d7a1d68041bf40fa2ae3a7316 Mon Sep 17 
00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 26 Feb 2026 18:41:11 +0300 Subject: [PATCH 3/7] ME Probe + STUN Legacy --- src/config/load.rs | 44 +++++++++++++++++++----------- src/transport/middle_proxy/pool.rs | 22 ++++++++++----- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/src/config/load.rs b/src/config/load.rs index 0c1e629..4e0e104 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -131,6 +131,9 @@ impl ProxyConfig { let general_table = parsed_toml .get("general") .and_then(|value| value.as_table()); + let network_table = parsed_toml + .get("network") + .and_then(|value| value.as_table()); let update_every_is_explicit = general_table .map(|table| table.contains_key("update_every")) .unwrap_or(false); @@ -140,6 +143,9 @@ impl ProxyConfig { let legacy_config_is_explicit = general_table .map(|table| table.contains_key("proxy_config_auto_reload_secs")) .unwrap_or(false); + let stun_servers_is_explicit = network_table + .map(|table| table.contains_key("stun_servers")) + .unwrap_or(false); let mut config: ProxyConfig = parsed_toml.try_into().map_err(|e| ProxyError::Config(e.to_string()))?; @@ -151,25 +157,31 @@ impl ProxyConfig { let legacy_nat_stun = config.general.middle_proxy_nat_stun.take(); let legacy_nat_stun_servers = std::mem::take(&mut config.general.middle_proxy_nat_stun_servers); let legacy_nat_stun_used = legacy_nat_stun.is_some() || !legacy_nat_stun_servers.is_empty(); + if stun_servers_is_explicit { + let mut explicit_stun_servers = Vec::new(); + for stun in std::mem::take(&mut config.network.stun_servers) { + push_unique_nonempty(&mut explicit_stun_servers, stun); + } + config.network.stun_servers = explicit_stun_servers; - let mut unified_stun_servers = Vec::new(); - for stun in std::mem::take(&mut config.network.stun_servers) { - push_unique_nonempty(&mut unified_stun_servers, stun); - } - if let Some(stun) = legacy_nat_stun { - push_unique_nonempty(&mut unified_stun_servers, stun); 
- } - for stun in legacy_nat_stun_servers { - push_unique_nonempty(&mut unified_stun_servers, stun); - } + if legacy_nat_stun_used { + warn!("general.middle_proxy_nat_stun and general.middle_proxy_nat_stun_servers are ignored because network.stun_servers is explicitly set"); + } + } else { + // Keep the default STUN pool unless network.stun_servers is explicitly overridden. + let mut unified_stun_servers = default_stun_servers(); + if let Some(stun) = legacy_nat_stun { + push_unique_nonempty(&mut unified_stun_servers, stun); + } + for stun in legacy_nat_stun_servers { + push_unique_nonempty(&mut unified_stun_servers, stun); + } - if unified_stun_servers.is_empty() { - unified_stun_servers = default_stun_servers(); - } - config.network.stun_servers = unified_stun_servers; + config.network.stun_servers = unified_stun_servers; - if legacy_nat_stun_used { - warn!("general.middle_proxy_nat_stun and general.middle_proxy_nat_stun_servers are deprecated; use network.stun_servers"); + if legacy_nat_stun_used { + warn!("general.middle_proxy_nat_stun and general.middle_proxy_nat_stun_servers are deprecated; use network.stun_servers"); + } } if let Some(update_every) = config.general.update_every { diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 21c2b87..a90899d 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -1200,15 +1200,23 @@ impl MePool { } addrs.shuffle(&mut rand::rng()); if addrs.len() > 1 { + let concurrency = 2usize; let mut join = tokio::task::JoinSet::new(); - for (ip, port) in addrs { - let addr = SocketAddr::new(ip, port); - let pool = Arc::clone(&self); - let rng_clone = Arc::clone(&rng); - join.spawn(async move { (addr, pool.connect_one(addr, rng_clone.as_ref()).await) }); - } + let mut next_idx = 0usize; - while let Some(res) = join.join_next().await { + while next_idx < addrs.len() || !join.is_empty() { + while next_idx < addrs.len() && join.len() < concurrency { + let (ip, 
port) = addrs[next_idx]; + next_idx += 1; + let addr = SocketAddr::new(ip, port); + let pool = Arc::clone(&self); + let rng_clone = Arc::clone(&rng); + join.spawn(async move { (addr, pool.connect_one(addr, rng_clone.as_ref()).await) }); + } + + let Some(res) = join.join_next().await else { + break; + }; match res { Ok((addr, Ok(()))) => { info!(%addr, dc = %dc, "ME connected"); From 4eebb4feb2342e27d0b52df1fc92065aa22aa19a Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:01:24 +0300 Subject: [PATCH 4/7] ME Pool Refactoring --- src/main.rs | 43 +- src/transport/middle_proxy/health.rs | 1 + src/transport/middle_proxy/mod.rs | 9 +- src/transport/middle_proxy/pool.rs | 1128 +-------------------- src/transport/middle_proxy/pool_config.rs | 81 ++ src/transport/middle_proxy/pool_init.rs | 201 ++++ src/transport/middle_proxy/pool_refill.rs | 159 +++ src/transport/middle_proxy/pool_reinit.rs | 383 +++++++ src/transport/middle_proxy/pool_writer.rs | 348 +++++++ 9 files changed, 1226 insertions(+), 1127 deletions(-) create mode 100644 src/transport/middle_proxy/pool_config.rs create mode 100644 src/transport/middle_proxy/pool_init.rs create mode 100644 src/transport/middle_proxy/pool_refill.rs create mode 100644 src/transport/middle_proxy/pool_reinit.rs create mode 100644 src/transport/middle_proxy/pool_writer.rs diff --git a/src/main.rs b/src/main.rs index db46a6d..da88fe3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -392,26 +392,33 @@ async fn main() -> std::result::Result<(), Box> { ); let pool_size = config.general.middle_proxy_pool_size.max(1); - match pool.init(pool_size, &rng).await { - Ok(()) => { - info!("Middle-End pool initialized successfully"); + loop { + match pool.init(pool_size, &rng).await { + Ok(()) => { + info!("Middle-End pool initialized successfully"); - // Phase 4: Start health monitor - let pool_clone = pool.clone(); - let rng_clone = rng.clone(); - let min_conns = pool_size; - 
tokio::spawn(async move { - crate::transport::middle_proxy::me_health_monitor( - pool_clone, rng_clone, min_conns, - ) - .await; - }); + // Phase 4: Start health monitor + let pool_clone = pool.clone(); + let rng_clone = rng.clone(); + let min_conns = pool_size; + tokio::spawn(async move { + crate::transport::middle_proxy::me_health_monitor( + pool_clone, rng_clone, min_conns, + ) + .await; + }); - Some(pool) - } - Err(e) => { - error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode."); - None + break Some(pool); + } + Err(e) => { + warn!( + error = %e, + retry_in_secs = 2, + "ME pool is not ready yet; retrying startup initialization" + ); + pool.reset_stun_state(); + tokio::time::sleep(Duration::from_secs(2)).await; + } } } } diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index dde3354..06cca03 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -22,6 +22,7 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c let mut inflight: HashMap<(i32, IpFamily), usize> = HashMap::new(); loop { tokio::time::sleep(Duration::from_secs(HEALTH_INTERVAL_SECS)).await; + pool.prune_closed_writers().await; check_family( IpFamily::V4, &pool, diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index f9f8c85..3a4ff16 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -1,17 +1,22 @@ //! Middle Proxy RPC transport. 
mod codec; +mod config_updater; mod handshake; mod health; mod pool; +mod pool_config; +mod pool_init; mod pool_nat; +mod pool_refill; +mod pool_reinit; +mod pool_writer; mod ping; mod reader; mod registry; +mod rotation; mod send; mod secret; -mod rotation; -mod config_updater; mod wire; use bytes::Bytes; diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index a90899d..1e43628 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -2,26 +2,17 @@ use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU32, AtomicU64, AtomicUsize, Ordering}; -use bytes::BytesMut; -use rand::Rng; -use rand::seq::SliceRandom; -use tokio::sync::{Mutex, RwLock, mpsc, Notify}; -use tokio_util::sync::CancellationToken; -use tracing::{debug, info, warn}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; +use tokio::sync::{Mutex, Notify, RwLock, mpsc}; +use tokio_util::sync::CancellationToken; + use crate::crypto::SecureRandom; -use crate::error::{ProxyError, Result}; -use crate::network::probe::NetworkDecision; use crate::network::IpFamily; -use crate::protocol::constants::*; +use crate::network::probe::NetworkDecision; use super::ConnRegistry; -use super::registry::BoundConn; -use super::codec::{RpcWriter, WriterCommand}; -use super::reader::reader_loop; -const ME_ACTIVE_PING_SECS: u64 = 25; -const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; +use super::codec::WriterCommand; #[derive(Clone)] pub struct MeWriter { @@ -104,11 +95,11 @@ impl MePool { (clamped * 1000.0).round() as u32 } - fn permille_to_ratio(permille: u32) -> f32 { + pub(super) fn permille_to_ratio(permille: u32) -> f32 { (permille.min(1000) as f32) / 1000.0 } - fn now_epoch_secs() -> u64 { + pub(super) fn now_epoch_secs() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or_default() @@ -200,11 +191,15 @@ impl MePool { hardswap: 
AtomicBool::new(hardswap), me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs), - me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille(me_pool_min_fresh_ratio)), + me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille( + me_pool_min_fresh_ratio, + )), me_hardswap_warmup_delay_min_ms: AtomicU64::new(me_hardswap_warmup_delay_min_ms), me_hardswap_warmup_delay_max_ms: AtomicU64::new(me_hardswap_warmup_delay_max_ms), me_hardswap_warmup_extra_passes: AtomicU32::new(me_hardswap_warmup_extra_passes as u32), - me_hardswap_warmup_pass_backoff_base_ms: AtomicU64::new(me_hardswap_warmup_pass_backoff_base_ms), + me_hardswap_warmup_pass_backoff_base_ms: AtomicU64::new( + me_hardswap_warmup_pass_backoff_base_ms, + ), }) } @@ -228,7 +223,8 @@ impl MePool { hardswap_warmup_pass_backoff_base_ms: u64, ) { self.hardswap.store(hardswap, Ordering::Relaxed); - self.me_pool_drain_ttl_secs.store(drain_ttl_secs, Ordering::Relaxed); + self.me_pool_drain_ttl_secs + .store(drain_ttl_secs, Ordering::Relaxed); self.me_pool_force_close_secs .store(force_close_secs, Ordering::Relaxed); self.me_pool_min_fresh_ratio_permille @@ -260,11 +256,11 @@ impl MePool { &self.registry } - fn writers_arc(&self) -> Arc>> { + pub(super) fn writers_arc(&self) -> Arc>> { self.writers.clone() } - fn force_close_timeout(&self) -> Option { + pub(super) fn force_close_timeout(&self) -> Option { let secs = self.me_pool_force_close_secs.load(Ordering::Relaxed); if secs == 0 { None @@ -273,588 +269,6 @@ impl MePool { } } - fn coverage_ratio( - desired_by_dc: &HashMap>, - active_writer_addrs: &HashSet, - ) -> (f32, Vec) { - if desired_by_dc.is_empty() { - return (1.0, Vec::new()); - } - - let mut missing_dc = Vec::::new(); - let mut covered = 0usize; - for (dc, endpoints) in desired_by_dc { - if endpoints.is_empty() { - continue; - } - if endpoints.iter().any(|addr| active_writer_addrs.contains(addr)) { - 
covered += 1; - } else { - missing_dc.push(*dc); - } - } - - missing_dc.sort_unstable(); - let total = desired_by_dc.len().max(1); - let ratio = (covered as f32) / (total as f32); - (ratio, missing_dc) - } - - pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { - let writers = self.writers.read().await; - let current: HashSet = writers - .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) - .map(|w| w.addr) - .collect(); - drop(writers); - - for family in self.family_order() { - let map = self.proxy_map_for_family(family).await; - for (_dc, addrs) in map.iter() { - let dc_addrs: Vec = addrs - .iter() - .map(|(ip, port)| SocketAddr::new(*ip, *port)) - .collect(); - if !dc_addrs.iter().any(|a| current.contains(a)) { - let mut shuffled = dc_addrs.clone(); - shuffled.shuffle(&mut rand::rng()); - for addr in shuffled { - if self.connect_one(addr, rng).await.is_ok() { - break; - } - } - } - } - if !self.decision.effective_multipath && !current.is_empty() { - break; - } - } - } - - async fn desired_dc_endpoints(&self) -> HashMap> { - let mut out: HashMap> = HashMap::new(); - - if self.decision.ipv4_me { - let map_v4 = self.proxy_map_v4.read().await.clone(); - for (dc, addrs) in map_v4 { - let entry = out.entry(dc.abs()).or_default(); - for (ip, port) in addrs { - entry.insert(SocketAddr::new(ip, port)); - } - } - } - - if self.decision.ipv6_me { - let map_v6 = self.proxy_map_v6.read().await.clone(); - for (dc, addrs) in map_v6 { - let entry = out.entry(dc.abs()).or_default(); - for (ip, port) in addrs { - entry.insert(SocketAddr::new(ip, port)); - } - } - } - - out - } - - pub(super) fn required_writers_for_dc(endpoint_count: usize) -> usize { - endpoint_count.max(3) - } - - fn hardswap_warmup_connect_delay_ms(&self) -> u64 { - let min_ms = self - .me_hardswap_warmup_delay_min_ms - .load(Ordering::Relaxed); - let max_ms = self - .me_hardswap_warmup_delay_max_ms - .load(Ordering::Relaxed); - let (min_ms, max_ms) = if min_ms <= max_ms { - (min_ms, 
max_ms) - } else { - (max_ms, min_ms) - }; - if min_ms == max_ms { - return min_ms; - } - rand::rng().random_range(min_ms..=max_ms) - } - - fn hardswap_warmup_backoff_ms(&self, pass_idx: usize) -> u64 { - let base_ms = self - .me_hardswap_warmup_pass_backoff_base_ms - .load(Ordering::Relaxed); - let cap_ms = (self.me_reconnect_backoff_cap.as_millis() as u64).max(base_ms); - let shift = (pass_idx as u32).min(20); - let scaled = base_ms.saturating_mul(1u64 << shift); - let core = scaled.min(cap_ms); - let jitter = (core / 2).max(1); - core.saturating_add(rand::rng().random_range(0..=jitter)) - } - - async fn fresh_writer_count_for_endpoints( - &self, - generation: u64, - endpoints: &HashSet, - ) -> usize { - let ws = self.writers.read().await; - ws.iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) - .filter(|w| w.generation == generation) - .filter(|w| endpoints.contains(&w.addr)) - .count() - } - - pub(super) async fn connect_endpoints_round_robin( - self: &Arc, - endpoints: &[SocketAddr], - rng: &SecureRandom, - ) -> bool { - if endpoints.is_empty() { - return false; - } - let start = (self.rr.fetch_add(1, Ordering::Relaxed) as usize) % endpoints.len(); - for offset in 0..endpoints.len() { - let idx = (start + offset) % endpoints.len(); - let addr = endpoints[idx]; - match self.connect_one(addr, rng).await { - Ok(()) => return true, - Err(e) => debug!(%addr, error = %e, "ME connect failed during round-robin warmup"), - } - } - false - } - - async fn warmup_generation_for_all_dcs( - self: &Arc, - rng: &SecureRandom, - generation: u64, - desired_by_dc: &HashMap>, - ) { - let extra_passes = self - .me_hardswap_warmup_extra_passes - .load(Ordering::Relaxed) - .min(10) as usize; - let total_passes = 1 + extra_passes; - - for (dc, endpoints) in desired_by_dc { - if endpoints.is_empty() { - continue; - } - - let mut endpoint_list: Vec = endpoints.iter().copied().collect(); - endpoint_list.sort_unstable(); - let required = 
Self::required_writers_for_dc(endpoint_list.len()); - let mut completed = false; - let mut last_fresh_count = self - .fresh_writer_count_for_endpoints(generation, endpoints) - .await; - - for pass_idx in 0..total_passes { - if last_fresh_count >= required { - completed = true; - break; - } - - let missing = required.saturating_sub(last_fresh_count); - debug!( - dc = *dc, - pass = pass_idx + 1, - total_passes, - fresh_count = last_fresh_count, - required, - missing, - endpoint_count = endpoint_list.len(), - "ME hardswap warmup pass started" - ); - - for attempt_idx in 0..missing { - let delay_ms = self.hardswap_warmup_connect_delay_ms(); - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - - let connected = self.connect_endpoints_round_robin(&endpoint_list, rng).await; - debug!( - dc = *dc, - pass = pass_idx + 1, - total_passes, - attempt = attempt_idx + 1, - delay_ms, - connected, - "ME hardswap warmup connect attempt finished" - ); - } - - last_fresh_count = self - .fresh_writer_count_for_endpoints(generation, endpoints) - .await; - if last_fresh_count >= required { - completed = true; - info!( - dc = *dc, - pass = pass_idx + 1, - total_passes, - fresh_count = last_fresh_count, - required, - "ME hardswap warmup floor reached for DC" - ); - break; - } - - if pass_idx + 1 < total_passes { - let backoff_ms = self.hardswap_warmup_backoff_ms(pass_idx); - debug!( - dc = *dc, - pass = pass_idx + 1, - total_passes, - fresh_count = last_fresh_count, - required, - backoff_ms, - "ME hardswap warmup pass incomplete, delaying next pass" - ); - tokio::time::sleep(Duration::from_millis(backoff_ms)).await; - } - } - - if !completed { - warn!( - dc = *dc, - fresh_count = last_fresh_count, - required, - endpoint_count = endpoint_list.len(), - total_passes, - "ME warmup stopped: unable to reach required writer floor for DC" - ); - } - } - } - - pub async fn zero_downtime_reinit_after_map_change( - self: &Arc, - rng: &SecureRandom, - ) { - let desired_by_dc = 
self.desired_dc_endpoints().await; - if desired_by_dc.is_empty() { - warn!("ME endpoint map is empty; skipping stale writer drain"); - return; - } - - let previous_generation = self.current_generation(); - let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; - let hardswap = self.hardswap.load(Ordering::Relaxed); - - if hardswap { - self.warmup_generation_for_all_dcs(rng, generation, &desired_by_dc) - .await; - } else { - self.reconcile_connections(rng).await; - } - - let writers = self.writers.read().await; - let active_writer_addrs: HashSet = writers - .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) - .map(|w| w.addr) - .collect(); - let min_ratio = Self::permille_to_ratio( - self.me_pool_min_fresh_ratio_permille - .load(Ordering::Relaxed), - ); - let (coverage_ratio, missing_dc) = Self::coverage_ratio(&desired_by_dc, &active_writer_addrs); - if !hardswap && coverage_ratio < min_ratio { - warn!( - previous_generation, - generation, - coverage_ratio = format_args!("{coverage_ratio:.3}"), - min_ratio = format_args!("{min_ratio:.3}"), - missing_dc = ?missing_dc, - "ME reinit coverage below threshold; keeping stale writers" - ); - return; - } - - if hardswap { - let mut fresh_missing_dc = Vec::<(i32, usize, usize)>::new(); - for (dc, endpoints) in &desired_by_dc { - if endpoints.is_empty() { - continue; - } - let required = Self::required_writers_for_dc(endpoints.len()); - let fresh_count = writers - .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) - .filter(|w| w.generation == generation) - .filter(|w| endpoints.contains(&w.addr)) - .count(); - if fresh_count < required { - fresh_missing_dc.push((*dc, fresh_count, required)); - } - } - if !fresh_missing_dc.is_empty() { - warn!( - previous_generation, - generation, - missing_dc = ?fresh_missing_dc, - "ME hardswap pending: fresh generation coverage incomplete" - ); - return; - } - } else if !missing_dc.is_empty() { - warn!( - missing_dc = ?missing_dc, - // Keep stale writers alive 
when fresh coverage is incomplete. - "ME reinit coverage incomplete; keeping stale writers" - ); - return; - } - - let desired_addrs: HashSet = desired_by_dc - .values() - .flat_map(|set| set.iter().copied()) - .collect(); - - let stale_writer_ids: Vec = writers - .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) - .filter(|w| { - if hardswap { - w.generation < generation - } else { - !desired_addrs.contains(&w.addr) - } - }) - .map(|w| w.id) - .collect(); - drop(writers); - - if stale_writer_ids.is_empty() { - debug!("ME reinit cycle completed with no stale writers"); - return; - } - - let drain_timeout = self.force_close_timeout(); - let drain_timeout_secs = drain_timeout.map(|d| d.as_secs()).unwrap_or(0); - info!( - stale_writers = stale_writer_ids.len(), - previous_generation, - generation, - hardswap, - coverage_ratio = format_args!("{coverage_ratio:.3}"), - min_ratio = format_args!("{min_ratio:.3}"), - drain_timeout_secs, - "ME reinit cycle covered; draining stale writers" - ); - self.stats.increment_pool_swap_total(); - for writer_id in stale_writer_ids { - self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap) - .await; - } - } - - pub async fn zero_downtime_reinit_periodic( - self: &Arc, - rng: &SecureRandom, - ) { - self.zero_downtime_reinit_after_map_change(rng).await; - } - - async fn endpoints_for_same_dc(&self, addr: SocketAddr) -> Vec { - let mut target_dc = HashSet::::new(); - let mut endpoints = HashSet::::new(); - - if self.decision.ipv4_me { - let map = self.proxy_map_v4.read().await.clone(); - for (dc, addrs) in &map { - if addrs - .iter() - .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) - { - target_dc.insert(dc.abs()); - } - } - for dc in &target_dc { - for key in [*dc, -*dc] { - if let Some(addrs) = map.get(&key) { - for (ip, port) in addrs { - endpoints.insert(SocketAddr::new(*ip, *port)); - } - } - } - } - } - - if self.decision.ipv6_me { - let map = self.proxy_map_v6.read().await.clone(); - for (dc, 
addrs) in &map { - if addrs - .iter() - .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) - { - target_dc.insert(dc.abs()); - } - } - for dc in &target_dc { - for key in [*dc, -*dc] { - if let Some(addrs) = map.get(&key) { - for (ip, port) in addrs { - endpoints.insert(SocketAddr::new(*ip, *port)); - } - } - } - } - } - - let mut sorted: Vec = endpoints.into_iter().collect(); - sorted.sort_unstable(); - sorted - } - - async fn refill_writer_after_loss(self: &Arc, addr: SocketAddr) -> bool { - let fast_retries = self.me_reconnect_fast_retry_count.max(1); - - for attempt in 0..fast_retries { - self.stats.increment_me_reconnect_attempt(); - match self.connect_one(addr, self.rng.as_ref()).await { - Ok(()) => { - self.stats.increment_me_reconnect_success(); - self.stats.increment_me_writer_restored_same_endpoint_total(); - info!( - %addr, - attempt = attempt + 1, - "ME writer restored on the same endpoint" - ); - return true; - } - Err(e) => { - debug!( - %addr, - attempt = attempt + 1, - error = %e, - "ME immediate same-endpoint reconnect failed" - ); - } - } - } - - let dc_endpoints = self.endpoints_for_same_dc(addr).await; - if dc_endpoints.is_empty() { - self.stats.increment_me_refill_failed_total(); - return false; - } - - for attempt in 0..fast_retries { - self.stats.increment_me_reconnect_attempt(); - if self - .connect_endpoints_round_robin(&dc_endpoints, self.rng.as_ref()) - .await - { - self.stats.increment_me_reconnect_success(); - self.stats.increment_me_writer_restored_fallback_total(); - info!( - %addr, - attempt = attempt + 1, - "ME writer restored via DC fallback endpoint" - ); - return true; - } - } - - self.stats.increment_me_refill_failed_total(); - false - } - - pub(crate) fn trigger_immediate_refill(self: &Arc, addr: SocketAddr) { - let pool = Arc::clone(self); - tokio::spawn(async move { - { - let mut guard = pool.refill_inflight.lock().await; - if !guard.insert(addr) { - pool.stats.increment_me_refill_skipped_inflight_total(); - return; - } - 
} - pool.stats.increment_me_refill_triggered_total(); - - let restored = pool.refill_writer_after_loss(addr).await; - if !restored { - warn!(%addr, "ME immediate refill failed"); - } - - let mut guard = pool.refill_inflight.lock().await; - guard.remove(&addr); - }); - } - - pub async fn update_proxy_maps( - &self, - new_v4: HashMap>, - new_v6: Option>>, - ) -> bool { - let mut changed = false; - { - let mut guard = self.proxy_map_v4.write().await; - if !new_v4.is_empty() && *guard != new_v4 { - *guard = new_v4; - changed = true; - } - } - if let Some(v6) = new_v6 { - let mut guard = self.proxy_map_v6.write().await; - if !v6.is_empty() && *guard != v6 { - *guard = v6; - changed = true; - } - } - // Ensure negative DC entries mirror positives when absent (Telegram convention). - { - let mut guard = self.proxy_map_v4.write().await; - let keys: Vec = guard.keys().cloned().collect(); - for k in keys.iter().cloned().filter(|k| *k > 0) { - if !guard.contains_key(&-k) - && let Some(addrs) = guard.get(&k).cloned() - { - guard.insert(-k, addrs); - } - } - } - { - let mut guard = self.proxy_map_v6.write().await; - let keys: Vec = guard.keys().cloned().collect(); - for k in keys.iter().cloned().filter(|k| *k > 0) { - if !guard.contains_key(&-k) - && let Some(addrs) = guard.get(&k).cloned() - { - guard.insert(-k, addrs); - } - } - } - changed - } - - pub async fn update_secret(self: &Arc, new_secret: Vec) -> bool { - if new_secret.len() < 32 { - warn!(len = new_secret.len(), "proxy-secret update ignored (too short)"); - return false; - } - let mut guard = self.proxy_secret.write().await; - if *guard != new_secret { - *guard = new_secret; - drop(guard); - self.reconnect_all().await; - return true; - } - false - } - - pub async fn reconnect_all(self: &Arc) { - let ws = self.writers.read().await.clone(); - for w in ws { - if let Ok(()) = self.connect_one(w.addr, self.rng.as_ref()).await { - self.mark_writer_draining(w.id).await; - tokio::time::sleep(Duration::from_secs(2)).await; 
- } - } - } - pub(super) async fn key_selector(&self) -> u32 { let secret = self.proxy_secret.read().await; if secret.len() >= 4 { @@ -884,513 +298,13 @@ impl MePool { order } - async fn proxy_map_for_family(&self, family: IpFamily) -> HashMap> { + pub(super) async fn proxy_map_for_family( + &self, + family: IpFamily, + ) -> HashMap> { match family { IpFamily::V4 => self.proxy_map_v4.read().await.clone(), IpFamily::V6 => self.proxy_map_v6.read().await.clone(), } } - - pub async fn init(self: &Arc, pool_size: usize, rng: &Arc) -> Result<()> { - let family_order = self.family_order(); - let ks = self.key_selector().await; - info!( - me_servers = self.proxy_map_v4.read().await.len(), - pool_size, - key_selector = format_args!("0x{ks:08x}"), - secret_len = self.proxy_secret.read().await.len(), - "Initializing ME pool" - ); - - for family in family_order { - let map = self.proxy_map_for_family(family).await; - let mut grouped_dc_addrs: HashMap> = HashMap::new(); - for (dc, addrs) in map { - if addrs.is_empty() { - continue; - } - grouped_dc_addrs - .entry(dc.abs()) - .or_default() - .extend(addrs); - } - let mut dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = grouped_dc_addrs - .into_iter() - .map(|(dc, mut addrs)| { - addrs.sort_unstable(); - addrs.dedup(); - (dc, addrs) - }) - .collect(); - dc_addrs.sort_unstable_by_key(|(dc, _)| *dc); - - // Ensure at least one connection per DC; run DCs in parallel. 
- let mut join = tokio::task::JoinSet::new(); - let mut dc_failures = 0usize; - for (dc, addrs) in dc_addrs.iter().cloned() { - if addrs.is_empty() { - continue; - } - let pool = Arc::clone(self); - let rng_clone = Arc::clone(rng); - join.spawn(async move { - pool.connect_primary_for_dc(dc, addrs, rng_clone).await - }); - } - while let Some(res) = join.join_next().await { - if let Ok(false) = res { - dc_failures += 1; - } - } - if dc_failures > 2 { - return Err(ProxyError::Proxy("Too many ME DC init failures, falling back to direct".into())); - } - - // Warm reserve writers asynchronously so startup does not block after first working pool is ready. - let pool = Arc::clone(self); - let rng_clone = Arc::clone(rng); - let dc_addrs_bg = dc_addrs.clone(); - tokio::spawn(async move { - if pool.me_warmup_stagger_enabled { - for (dc, addrs) in dc_addrs_bg.iter() { - for (ip, port) in addrs { - if pool.connection_count() >= pool_size { - break; - } - let addr = SocketAddr::new(*ip, *port); - let jitter = rand::rng() - .random_range(0..=pool.me_warmup_step_jitter.as_millis() as u64); - let delay_ms = pool.me_warmup_step_delay.as_millis() as u64 + jitter; - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); - } - } - } - } else { - for (dc, addrs) in dc_addrs_bg.iter() { - for (ip, port) in addrs { - if pool.connection_count() >= pool_size { - break; - } - let addr = SocketAddr::new(*ip, *port); - if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); - } - } - if pool.connection_count() >= pool_size { - break; - } - } - } - debug!( - target_pool_size = pool_size, - current_pool_size = pool.connection_count(), - "Background ME reserve warmup finished" - ); - }); - - if !self.decision.effective_multipath && self.connection_count() > 0 { - break; - 
} - } - - if self.writers.read().await.is_empty() { - return Err(ProxyError::Proxy("No ME connections".into())); - } - info!( - active_writers = self.connection_count(), - "ME primary pool ready; reserve warmup continues in background" - ); - Ok(()) - } - - pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { - let secret_len = self.proxy_secret.read().await.len(); - if secret_len < 32 { - return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); - } - - let (stream, _connect_ms) = self.connect_tcp(addr).await?; - let hs = self.handshake_only(stream, addr, rng).await?; - - let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); - let generation = self.current_generation(); - let cancel = CancellationToken::new(); - let degraded = Arc::new(AtomicBool::new(false)); - let draining = Arc::new(AtomicBool::new(false)); - let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0)); - let allow_drain_fallback = Arc::new(AtomicBool::new(false)); - let (tx, mut rx) = mpsc::channel::(4096); - let mut rpc_writer = RpcWriter { - writer: hs.wr, - key: hs.write_key, - iv: hs.write_iv, - seq_no: 0, - crc_mode: hs.crc_mode, - }; - let cancel_wr = cancel.clone(); - tokio::spawn(async move { - loop { - tokio::select! 
{ - cmd = rx.recv() => { - match cmd { - Some(WriterCommand::Data(payload)) => { - if rpc_writer.send(&payload).await.is_err() { break; } - } - Some(WriterCommand::DataAndFlush(payload)) => { - if rpc_writer.send_and_flush(&payload).await.is_err() { break; } - } - Some(WriterCommand::Close) | None => break, - } - } - _ = cancel_wr.cancelled() => break, - } - } - }); - let writer = MeWriter { - id: writer_id, - addr, - generation, - tx: tx.clone(), - cancel: cancel.clone(), - degraded: degraded.clone(), - draining: draining.clone(), - draining_started_at_epoch_secs: draining_started_at_epoch_secs.clone(), - allow_drain_fallback: allow_drain_fallback.clone(), - }; - self.writers.write().await.push(writer.clone()); - self.conn_count.fetch_add(1, Ordering::Relaxed); - self.writer_available.notify_one(); - - let reg = self.registry.clone(); - let writers_arc = self.writers_arc(); - let ping_tracker = self.ping_tracker.clone(); - let ping_tracker_reader = ping_tracker.clone(); - let rtt_stats = self.rtt_stats.clone(); - let stats_reader = self.stats.clone(); - let stats_ping = self.stats.clone(); - let pool = Arc::downgrade(self); - let cancel_ping = cancel.clone(); - let tx_ping = tx.clone(); - let ping_tracker_ping = ping_tracker.clone(); - let cleanup_done = Arc::new(AtomicBool::new(false)); - let cleanup_for_reader = cleanup_done.clone(); - let cleanup_for_ping = cleanup_done.clone(); - let keepalive_enabled = self.me_keepalive_enabled; - let keepalive_interval = self.me_keepalive_interval; - let keepalive_jitter = self.me_keepalive_jitter; - let cancel_reader_token = cancel.clone(); - let cancel_ping_token = cancel_ping.clone(); - - tokio::spawn(async move { - let res = reader_loop( - hs.rd, - hs.read_key, - hs.read_iv, - hs.crc_mode, - reg.clone(), - BytesMut::new(), - BytesMut::new(), - tx.clone(), - ping_tracker_reader, - rtt_stats.clone(), - stats_reader, - writer_id, - degraded.clone(), - cancel_reader_token.clone(), - ) - .await; - if let Some(pool) = 
pool.upgrade() - && cleanup_for_reader - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - pool.remove_writer_and_close_clients(writer_id).await; - } - if let Err(e) = res { - warn!(error = %e, "ME reader ended"); - } - let mut ws = writers_arc.write().await; - ws.retain(|w| w.id != writer_id); - info!(remaining = ws.len(), "Dead ME writer removed from pool"); - }); - - let pool_ping = Arc::downgrade(self); - tokio::spawn(async move { - let mut ping_id: i64 = rand::random::(); - // Per-writer jittered start to avoid phase sync. - let startup_jitter = if keepalive_enabled { - let jitter_cap_ms = keepalive_interval.as_millis() / 2; - let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); - Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) - } else { - let jitter = rand::rng() - .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); - let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; - Duration::from_secs(wait) - }; - tokio::select! { - _ = cancel_ping_token.cancelled() => return, - _ = tokio::time::sleep(startup_jitter) => {} - } - loop { - let wait = if keepalive_enabled { - let jitter_cap_ms = keepalive_interval.as_millis() / 2; - let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); - keepalive_interval - + Duration::from_millis( - rand::rng().random_range(0..=effective_jitter_ms as u64) - ) - } else { - let jitter = rand::rng() - .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); - let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; - Duration::from_secs(secs) - }; - tokio::select! 
{ - _ = cancel_ping_token.cancelled() => { - break; - } - _ = tokio::time::sleep(wait) => {} - } - let sent_id = ping_id; - let mut p = Vec::with_capacity(12); - p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); - p.extend_from_slice(&sent_id.to_le_bytes()); - { - let mut tracker = ping_tracker_ping.lock().await; - let before = tracker.len(); - tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); - let expired = before.saturating_sub(tracker.len()); - if expired > 0 { - stats_ping.increment_me_keepalive_timeout_by(expired as u64); - } - tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); - } - ping_id = ping_id.wrapping_add(1); - stats_ping.increment_me_keepalive_sent(); - if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { - stats_ping.increment_me_keepalive_failed(); - debug!("ME ping failed, removing dead writer"); - cancel_ping.cancel(); - if let Some(pool) = pool_ping.upgrade() - && cleanup_for_ping - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - pool.remove_writer_and_close_clients(writer_id).await; - } - break; - } - } - }); - - Ok(()) - } - - async fn connect_primary_for_dc( - self: Arc, - dc: i32, - mut addrs: Vec<(IpAddr, u16)>, - rng: Arc, - ) -> bool { - if addrs.is_empty() { - return false; - } - addrs.shuffle(&mut rand::rng()); - if addrs.len() > 1 { - let concurrency = 2usize; - let mut join = tokio::task::JoinSet::new(); - let mut next_idx = 0usize; - - while next_idx < addrs.len() || !join.is_empty() { - while next_idx < addrs.len() && join.len() < concurrency { - let (ip, port) = addrs[next_idx]; - next_idx += 1; - let addr = SocketAddr::new(ip, port); - let pool = Arc::clone(&self); - let rng_clone = Arc::clone(&rng); - join.spawn(async move { (addr, pool.connect_one(addr, rng_clone.as_ref()).await) }); - } - - let Some(res) = join.join_next().await else { - break; - }; - match res { - Ok((addr, Ok(()))) => { - info!(%addr, dc = %dc, "ME connected"); - 
join.abort_all(); - while join.join_next().await.is_some() {} - return true; - } - Ok((addr, Err(e))) => { - warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"); - } - Err(e) => { - warn!(dc = %dc, error = %e, "ME connect task failed"); - } - } - } - warn!(dc = %dc, "All ME servers for DC failed at init"); - return false; - } - - for (ip, port) in addrs { - let addr = SocketAddr::new(ip, port); - match self.connect_one(addr, rng.as_ref()).await { - Ok(()) => { - info!(%addr, dc = %dc, "ME connected"); - return true; - } - Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"), - } - } - warn!(dc = %dc, "All ME servers for DC failed at init"); - false - } - - pub(crate) async fn remove_writer_and_close_clients(self: &Arc, writer_id: u64) { - let conns = self.remove_writer_only(writer_id).await; - for bound in conns { - let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; - let _ = self.registry.unregister(bound.conn_id).await; - } - } - - async fn remove_writer_only(self: &Arc, writer_id: u64) -> Vec { - let mut close_tx: Option> = None; - let mut removed_addr: Option = None; - let mut trigger_refill = false; - { - let mut ws = self.writers.write().await; - if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { - let w = ws.remove(pos); - let was_draining = w.draining.load(Ordering::Relaxed); - if was_draining { - self.stats.decrement_pool_drain_active(); - } - self.stats.increment_me_writer_removed_total(); - w.cancel.cancel(); - removed_addr = Some(w.addr); - trigger_refill = !was_draining; - if trigger_refill { - self.stats.increment_me_writer_removed_unexpected_total(); - } - close_tx = Some(w.tx.clone()); - self.conn_count.fetch_sub(1, Ordering::Relaxed); - } - } - if let Some(tx) = close_tx { - let _ = tx.send(WriterCommand::Close).await; - } - if trigger_refill - && let Some(addr) = removed_addr - { - self.trigger_immediate_refill(addr); - } - 
self.rtt_stats.lock().await.remove(&writer_id); - self.registry.writer_lost(writer_id).await - } - - pub(crate) async fn mark_writer_draining_with_timeout( - self: &Arc, - writer_id: u64, - timeout: Option, - allow_drain_fallback: bool, - ) { - let timeout = timeout.filter(|d| !d.is_zero()); - let found = { - let mut ws = self.writers.write().await; - if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) { - let already_draining = w.draining.swap(true, Ordering::Relaxed); - w.allow_drain_fallback - .store(allow_drain_fallback, Ordering::Relaxed); - w.draining_started_at_epoch_secs - .store(Self::now_epoch_secs(), Ordering::Relaxed); - if !already_draining { - self.stats.increment_pool_drain_active(); - } - w.draining.store(true, Ordering::Relaxed); - true - } else { - false - } - }; - - if !found { - return; - } - - let timeout_secs = timeout.map(|d| d.as_secs()).unwrap_or(0); - debug!( - writer_id, - timeout_secs, - allow_drain_fallback, - "ME writer marked draining" - ); - - let pool = Arc::downgrade(self); - tokio::spawn(async move { - let deadline = timeout.map(|t| Instant::now() + t); - while let Some(p) = pool.upgrade() { - if let Some(deadline_at) = deadline - && Instant::now() >= deadline_at - { - warn!(writer_id, "Drain timeout, force-closing"); - p.stats.increment_pool_force_close_total(); - let _ = p.remove_writer_and_close_clients(writer_id).await; - break; - } - if p.registry.is_writer_empty(writer_id).await { - let _ = p.remove_writer_only(writer_id).await; - break; - } - tokio::time::sleep(Duration::from_secs(1)).await; - } - }); - } - - pub(crate) async fn mark_writer_draining(self: &Arc, writer_id: u64) { - self.mark_writer_draining_with_timeout(writer_id, Some(Duration::from_secs(300)), false) - .await; - } - - pub(super) fn writer_accepts_new_binding(&self, writer: &MeWriter) -> bool { - if !writer.draining.load(Ordering::Relaxed) { - return true; - } - if !writer.allow_drain_fallback.load(Ordering::Relaxed) { - return false; - } - - let 
ttl_secs = self.me_pool_drain_ttl_secs.load(Ordering::Relaxed); - if ttl_secs == 0 { - return true; - } - - let started = writer.draining_started_at_epoch_secs.load(Ordering::Relaxed); - if started == 0 { - return false; - } - - Self::now_epoch_secs().saturating_sub(started) <= ttl_secs - } - -} - -#[allow(dead_code)] -fn hex_dump(data: &[u8]) -> String { - const MAX: usize = 64; - let mut out = String::with_capacity(data.len() * 2 + 3); - for (i, b) in data.iter().take(MAX).enumerate() { - if i > 0 { - out.push(' '); - } - out.push_str(&format!("{b:02x}")); - } - if data.len() > MAX { - out.push_str(" …"); - } - out } diff --git a/src/transport/middle_proxy/pool_config.rs b/src/transport/middle_proxy/pool_config.rs new file mode 100644 index 0000000..fe2aad8 --- /dev/null +++ b/src/transport/middle_proxy/pool_config.rs @@ -0,0 +1,81 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::Duration; + +use tracing::warn; + +use super::pool::MePool; + +impl MePool { + pub async fn update_proxy_maps( + &self, + new_v4: HashMap>, + new_v6: Option>>, + ) -> bool { + let mut changed = false; + { + let mut guard = self.proxy_map_v4.write().await; + if !new_v4.is_empty() && *guard != new_v4 { + *guard = new_v4; + changed = true; + } + } + if let Some(v6) = new_v6 { + let mut guard = self.proxy_map_v6.write().await; + if !v6.is_empty() && *guard != v6 { + *guard = v6; + changed = true; + } + } + // Ensure negative DC entries mirror positives when absent (Telegram convention). 
+ { + let mut guard = self.proxy_map_v4.write().await; + let keys: Vec = guard.keys().cloned().collect(); + for k in keys.iter().cloned().filter(|k| *k > 0) { + if !guard.contains_key(&-k) + && let Some(addrs) = guard.get(&k).cloned() + { + guard.insert(-k, addrs); + } + } + } + { + let mut guard = self.proxy_map_v6.write().await; + let keys: Vec = guard.keys().cloned().collect(); + for k in keys.iter().cloned().filter(|k| *k > 0) { + if !guard.contains_key(&-k) + && let Some(addrs) = guard.get(&k).cloned() + { + guard.insert(-k, addrs); + } + } + } + changed + } + + pub async fn update_secret(self: &Arc, new_secret: Vec) -> bool { + if new_secret.len() < 32 { + warn!(len = new_secret.len(), "proxy-secret update ignored (too short)"); + return false; + } + let mut guard = self.proxy_secret.write().await; + if *guard != new_secret { + *guard = new_secret; + drop(guard); + self.reconnect_all().await; + return true; + } + false + } + + pub async fn reconnect_all(self: &Arc) { + let ws = self.writers.read().await.clone(); + for w in ws { + if let Ok(()) = self.connect_one(w.addr, self.rng.as_ref()).await { + self.mark_writer_draining(w.id).await; + tokio::time::sleep(Duration::from_secs(2)).await; + } + } + } +} diff --git a/src/transport/middle_proxy/pool_init.rs b/src/transport/middle_proxy/pool_init.rs new file mode 100644 index 0000000..623be7f --- /dev/null +++ b/src/transport/middle_proxy/pool_init.rs @@ -0,0 +1,201 @@ +use std::collections::{HashMap, HashSet}; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; + +use rand::Rng; +use rand::seq::SliceRandom; +use tracing::{debug, info, warn}; + +use crate::crypto::SecureRandom; +use crate::error::{ProxyError, Result}; + +use super::pool::MePool; + +impl MePool { + pub async fn init(self: &Arc, pool_size: usize, rng: &Arc) -> Result<()> { + let family_order = self.family_order(); + let ks = self.key_selector().await; + info!( + me_servers = self.proxy_map_v4.read().await.len(), + pool_size, + key_selector = 
format_args!("0x{ks:08x}"), + secret_len = self.proxy_secret.read().await.len(), + "Initializing ME pool" + ); + + for family in family_order { + let map = self.proxy_map_for_family(family).await; + let mut grouped_dc_addrs: HashMap> = HashMap::new(); + for (dc, addrs) in map { + if addrs.is_empty() { + continue; + } + grouped_dc_addrs.entry(dc.abs()).or_default().extend(addrs); + } + let mut dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = grouped_dc_addrs + .into_iter() + .map(|(dc, mut addrs)| { + addrs.sort_unstable(); + addrs.dedup(); + (dc, addrs) + }) + .collect(); + dc_addrs.sort_unstable_by_key(|(dc, _)| *dc); + + // Ensure at least one live writer per DC group; run missing DCs in parallel. + let mut join = tokio::task::JoinSet::new(); + for (dc, addrs) in dc_addrs.iter().cloned() { + if addrs.is_empty() { + continue; + } + let endpoints: HashSet = addrs + .iter() + .map(|(ip, port)| SocketAddr::new(*ip, *port)) + .collect(); + if self.active_writer_count_for_endpoints(&endpoints).await > 0 { + continue; + } + let pool = Arc::clone(self); + let rng_clone = Arc::clone(rng); + join.spawn(async move { pool.connect_primary_for_dc(dc, addrs, rng_clone).await }); + } + while join.join_next().await.is_some() {} + + let mut missing_dcs = Vec::new(); + for (dc, addrs) in &dc_addrs { + let endpoints: HashSet = addrs + .iter() + .map(|(ip, port)| SocketAddr::new(*ip, *port)) + .collect(); + if self.active_writer_count_for_endpoints(&endpoints).await == 0 { + missing_dcs.push(*dc); + } + } + if !missing_dcs.is_empty() { + return Err(ProxyError::Proxy(format!( + "ME init incomplete: no live writers for DC groups {missing_dcs:?}" + ))); + } + + // Warm reserve writers asynchronously so startup does not block after first working pool is ready. 
+ let pool = Arc::clone(self); + let rng_clone = Arc::clone(rng); + let dc_addrs_bg = dc_addrs.clone(); + tokio::spawn(async move { + if pool.me_warmup_stagger_enabled { + for (dc, addrs) in &dc_addrs_bg { + for (ip, port) in addrs { + if pool.connection_count() >= pool_size { + break; + } + let addr = SocketAddr::new(*ip, *port); + let jitter = rand::rng() + .random_range(0..=pool.me_warmup_step_jitter.as_millis() as u64); + let delay_ms = pool.me_warmup_step_delay.as_millis() as u64 + jitter; + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { + debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); + } + } + } + } else { + for (dc, addrs) in &dc_addrs_bg { + for (ip, port) in addrs { + if pool.connection_count() >= pool_size { + break; + } + let addr = SocketAddr::new(*ip, *port); + if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { + debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); + } + } + if pool.connection_count() >= pool_size { + break; + } + } + } + debug!( + target_pool_size = pool_size, + current_pool_size = pool.connection_count(), + "Background ME reserve warmup finished" + ); + }); + + if !self.decision.effective_multipath && self.connection_count() > 0 { + break; + } + } + + if self.writers.read().await.is_empty() { + return Err(ProxyError::Proxy("No ME connections".into())); + } + info!( + active_writers = self.connection_count(), + "ME primary pool ready; reserve warmup continues in background" + ); + Ok(()) + } + + async fn connect_primary_for_dc( + self: Arc, + dc: i32, + mut addrs: Vec<(IpAddr, u16)>, + rng: Arc, + ) -> bool { + if addrs.is_empty() { + return false; + } + addrs.shuffle(&mut rand::rng()); + if addrs.len() > 1 { + let concurrency = 2usize; + let mut join = tokio::task::JoinSet::new(); + let mut next_idx = 0usize; + + while next_idx < addrs.len() || !join.is_empty() { + while next_idx < 
addrs.len() && join.len() < concurrency { + let (ip, port) = addrs[next_idx]; + next_idx += 1; + let addr = SocketAddr::new(ip, port); + let pool = Arc::clone(&self); + let rng_clone = Arc::clone(&rng); + join.spawn(async move { + (addr, pool.connect_one(addr, rng_clone.as_ref()).await) + }); + } + + let Some(res) = join.join_next().await else { + break; + }; + match res { + Ok((addr, Ok(()))) => { + info!(%addr, dc = %dc, "ME connected"); + join.abort_all(); + while join.join_next().await.is_some() {} + return true; + } + Ok((addr, Err(e))) => { + warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"); + } + Err(e) => { + warn!(dc = %dc, error = %e, "ME connect task failed"); + } + } + } + warn!(dc = %dc, "All ME servers for DC failed at init"); + return false; + } + + for (ip, port) in addrs { + let addr = SocketAddr::new(ip, port); + match self.connect_one(addr, rng.as_ref()).await { + Ok(()) => { + info!(%addr, dc = %dc, "ME connected"); + return true; + } + Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"), + } + } + warn!(dc = %dc, "All ME servers for DC failed at init"); + false + } +} diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs new file mode 100644 index 0000000..6dea6c9 --- /dev/null +++ b/src/transport/middle_proxy/pool_refill.rs @@ -0,0 +1,159 @@ +use std::collections::HashSet; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::Ordering; + +use tracing::{debug, info, warn}; + +use crate::crypto::SecureRandom; + +use super::pool::MePool; + +impl MePool { + pub(super) async fn connect_endpoints_round_robin( + self: &Arc, + endpoints: &[SocketAddr], + rng: &SecureRandom, + ) -> bool { + if endpoints.is_empty() { + return false; + } + let start = (self.rr.fetch_add(1, Ordering::Relaxed) as usize) % endpoints.len(); + for offset in 0..endpoints.len() { + let idx = (start + offset) % endpoints.len(); + let addr = endpoints[idx]; + match 
self.connect_one(addr, rng).await { + Ok(()) => return true, + Err(e) => debug!(%addr, error = %e, "ME connect failed during round-robin warmup"), + } + } + false + } + + async fn endpoints_for_same_dc(&self, addr: SocketAddr) -> Vec { + let mut target_dc = HashSet::::new(); + let mut endpoints = HashSet::::new(); + + if self.decision.ipv4_me { + let map = self.proxy_map_v4.read().await.clone(); + for (dc, addrs) in &map { + if addrs + .iter() + .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) + { + target_dc.insert(dc.abs()); + } + } + for dc in &target_dc { + for key in [*dc, -*dc] { + if let Some(addrs) = map.get(&key) { + for (ip, port) in addrs { + endpoints.insert(SocketAddr::new(*ip, *port)); + } + } + } + } + } + + if self.decision.ipv6_me { + let map = self.proxy_map_v6.read().await.clone(); + for (dc, addrs) in &map { + if addrs + .iter() + .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) + { + target_dc.insert(dc.abs()); + } + } + for dc in &target_dc { + for key in [*dc, -*dc] { + if let Some(addrs) = map.get(&key) { + for (ip, port) in addrs { + endpoints.insert(SocketAddr::new(*ip, *port)); + } + } + } + } + } + + let mut sorted: Vec = endpoints.into_iter().collect(); + sorted.sort_unstable(); + sorted + } + + async fn refill_writer_after_loss(self: &Arc, addr: SocketAddr) -> bool { + let fast_retries = self.me_reconnect_fast_retry_count.max(1); + + for attempt in 0..fast_retries { + self.stats.increment_me_reconnect_attempt(); + match self.connect_one(addr, self.rng.as_ref()).await { + Ok(()) => { + self.stats.increment_me_reconnect_success(); + self.stats.increment_me_writer_restored_same_endpoint_total(); + info!( + %addr, + attempt = attempt + 1, + "ME writer restored on the same endpoint" + ); + return true; + } + Err(e) => { + debug!( + %addr, + attempt = attempt + 1, + error = %e, + "ME immediate same-endpoint reconnect failed" + ); + } + } + } + + let dc_endpoints = self.endpoints_for_same_dc(addr).await; + if 
dc_endpoints.is_empty() { + self.stats.increment_me_refill_failed_total(); + return false; + } + + for attempt in 0..fast_retries { + self.stats.increment_me_reconnect_attempt(); + if self + .connect_endpoints_round_robin(&dc_endpoints, self.rng.as_ref()) + .await + { + self.stats.increment_me_reconnect_success(); + self.stats.increment_me_writer_restored_fallback_total(); + info!( + %addr, + attempt = attempt + 1, + "ME writer restored via DC fallback endpoint" + ); + return true; + } + } + + self.stats.increment_me_refill_failed_total(); + false + } + + pub(crate) fn trigger_immediate_refill(self: &Arc, addr: SocketAddr) { + let pool = Arc::clone(self); + tokio::spawn(async move { + { + let mut guard = pool.refill_inflight.lock().await; + if !guard.insert(addr) { + pool.stats.increment_me_refill_skipped_inflight_total(); + return; + } + } + pool.stats.increment_me_refill_triggered_total(); + + let restored = pool.refill_writer_after_loss(addr).await; + if !restored { + warn!(%addr, "ME immediate refill failed"); + } + + let mut guard = pool.refill_inflight.lock().await; + guard.remove(&addr); + }); + } +} diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs new file mode 100644 index 0000000..261ac02 --- /dev/null +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -0,0 +1,383 @@ +use std::collections::{HashMap, HashSet}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use rand::Rng; +use rand::seq::SliceRandom; +use tracing::{debug, info, warn}; + +use crate::crypto::SecureRandom; + +use super::pool::MePool; + +impl MePool { + fn coverage_ratio( + desired_by_dc: &HashMap>, + active_writer_addrs: &HashSet, + ) -> (f32, Vec) { + if desired_by_dc.is_empty() { + return (1.0, Vec::new()); + } + + let mut missing_dc = Vec::::new(); + let mut covered = 0usize; + for (dc, endpoints) in desired_by_dc { + if endpoints.is_empty() { + continue; + } + if endpoints + 
.iter() + .any(|addr| active_writer_addrs.contains(addr)) + { + covered += 1; + } else { + missing_dc.push(*dc); + } + } + + missing_dc.sort_unstable(); + let total = desired_by_dc.len().max(1); + let ratio = (covered as f32) / (total as f32); + (ratio, missing_dc) + } + + pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { + let writers = self.writers.read().await; + let current: HashSet = writers + .iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .map(|w| w.addr) + .collect(); + drop(writers); + + for family in self.family_order() { + let map = self.proxy_map_for_family(family).await; + for (_dc, addrs) in &map { + let dc_addrs: Vec = addrs + .iter() + .map(|(ip, port)| SocketAddr::new(*ip, *port)) + .collect(); + if !dc_addrs.iter().any(|a| current.contains(a)) { + let mut shuffled = dc_addrs.clone(); + shuffled.shuffle(&mut rand::rng()); + for addr in shuffled { + if self.connect_one(addr, rng).await.is_ok() { + break; + } + } + } + } + if !self.decision.effective_multipath && !current.is_empty() { + break; + } + } + } + + async fn desired_dc_endpoints(&self) -> HashMap> { + let mut out: HashMap> = HashMap::new(); + + if self.decision.ipv4_me { + let map_v4 = self.proxy_map_v4.read().await.clone(); + for (dc, addrs) in map_v4 { + let entry = out.entry(dc.abs()).or_default(); + for (ip, port) in addrs { + entry.insert(SocketAddr::new(ip, port)); + } + } + } + + if self.decision.ipv6_me { + let map_v6 = self.proxy_map_v6.read().await.clone(); + for (dc, addrs) in map_v6 { + let entry = out.entry(dc.abs()).or_default(); + for (ip, port) in addrs { + entry.insert(SocketAddr::new(ip, port)); + } + } + } + + out + } + + pub(super) fn required_writers_for_dc(endpoint_count: usize) -> usize { + endpoint_count.max(3) + } + + fn hardswap_warmup_connect_delay_ms(&self) -> u64 { + let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed); + let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed); + let (min_ms, 
max_ms) = if min_ms <= max_ms { + (min_ms, max_ms) + } else { + (max_ms, min_ms) + }; + if min_ms == max_ms { + return min_ms; + } + rand::rng().random_range(min_ms..=max_ms) + } + + fn hardswap_warmup_backoff_ms(&self, pass_idx: usize) -> u64 { + let base_ms = self + .me_hardswap_warmup_pass_backoff_base_ms + .load(Ordering::Relaxed); + let cap_ms = (self.me_reconnect_backoff_cap.as_millis() as u64).max(base_ms); + let shift = (pass_idx as u32).min(20); + let scaled = base_ms.saturating_mul(1u64 << shift); + let core = scaled.min(cap_ms); + let jitter = (core / 2).max(1); + core.saturating_add(rand::rng().random_range(0..=jitter)) + } + + async fn fresh_writer_count_for_endpoints( + &self, + generation: u64, + endpoints: &HashSet, + ) -> usize { + let ws = self.writers.read().await; + ws.iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| w.generation == generation) + .filter(|w| endpoints.contains(&w.addr)) + .count() + } + + pub(super) async fn active_writer_count_for_endpoints( + &self, + endpoints: &HashSet, + ) -> usize { + let ws = self.writers.read().await; + ws.iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| endpoints.contains(&w.addr)) + .count() + } + + async fn warmup_generation_for_all_dcs( + self: &Arc, + rng: &SecureRandom, + generation: u64, + desired_by_dc: &HashMap>, + ) { + let extra_passes = self + .me_hardswap_warmup_extra_passes + .load(Ordering::Relaxed) + .min(10) as usize; + let total_passes = 1 + extra_passes; + + for (dc, endpoints) in desired_by_dc { + if endpoints.is_empty() { + continue; + } + + let mut endpoint_list: Vec = endpoints.iter().copied().collect(); + endpoint_list.sort_unstable(); + let required = Self::required_writers_for_dc(endpoint_list.len()); + let mut completed = false; + let mut last_fresh_count = self + .fresh_writer_count_for_endpoints(generation, endpoints) + .await; + + for pass_idx in 0..total_passes { + if last_fresh_count >= required { + completed = true; + break; + 
} + + let missing = required.saturating_sub(last_fresh_count); + debug!( + dc = *dc, + pass = pass_idx + 1, + total_passes, + fresh_count = last_fresh_count, + required, + missing, + endpoint_count = endpoint_list.len(), + "ME hardswap warmup pass started" + ); + + for attempt_idx in 0..missing { + let delay_ms = self.hardswap_warmup_connect_delay_ms(); + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + + let connected = self.connect_endpoints_round_robin(&endpoint_list, rng).await; + debug!( + dc = *dc, + pass = pass_idx + 1, + total_passes, + attempt = attempt_idx + 1, + delay_ms, + connected, + "ME hardswap warmup connect attempt finished" + ); + } + + last_fresh_count = self + .fresh_writer_count_for_endpoints(generation, endpoints) + .await; + if last_fresh_count >= required { + completed = true; + info!( + dc = *dc, + pass = pass_idx + 1, + total_passes, + fresh_count = last_fresh_count, + required, + "ME hardswap warmup floor reached for DC" + ); + break; + } + + if pass_idx + 1 < total_passes { + let backoff_ms = self.hardswap_warmup_backoff_ms(pass_idx); + debug!( + dc = *dc, + pass = pass_idx + 1, + total_passes, + fresh_count = last_fresh_count, + required, + backoff_ms, + "ME hardswap warmup pass incomplete, delaying next pass" + ); + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + } + } + + if !completed { + warn!( + dc = *dc, + fresh_count = last_fresh_count, + required, + endpoint_count = endpoint_list.len(), + total_passes, + "ME warmup stopped: unable to reach required writer floor for DC" + ); + } + } + } + + pub async fn zero_downtime_reinit_after_map_change(self: &Arc, rng: &SecureRandom) { + let desired_by_dc = self.desired_dc_endpoints().await; + if desired_by_dc.is_empty() { + warn!("ME endpoint map is empty; skipping stale writer drain"); + return; + } + + let previous_generation = self.current_generation(); + let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; + let hardswap = 
self.hardswap.load(Ordering::Relaxed); + + if hardswap { + self.warmup_generation_for_all_dcs(rng, generation, &desired_by_dc) + .await; + } else { + self.reconcile_connections(rng).await; + } + + let writers = self.writers.read().await; + let active_writer_addrs: HashSet = writers + .iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .map(|w| w.addr) + .collect(); + let min_ratio = Self::permille_to_ratio( + self.me_pool_min_fresh_ratio_permille + .load(Ordering::Relaxed), + ); + let (coverage_ratio, missing_dc) = Self::coverage_ratio(&desired_by_dc, &active_writer_addrs); + if !hardswap && coverage_ratio < min_ratio { + warn!( + previous_generation, + generation, + coverage_ratio = format_args!("{coverage_ratio:.3}"), + min_ratio = format_args!("{min_ratio:.3}"), + missing_dc = ?missing_dc, + "ME reinit coverage below threshold; keeping stale writers" + ); + return; + } + + if hardswap { + let mut fresh_missing_dc = Vec::<(i32, usize, usize)>::new(); + for (dc, endpoints) in &desired_by_dc { + if endpoints.is_empty() { + continue; + } + let required = Self::required_writers_for_dc(endpoints.len()); + let fresh_count = writers + .iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| w.generation == generation) + .filter(|w| endpoints.contains(&w.addr)) + .count(); + if fresh_count < required { + fresh_missing_dc.push((*dc, fresh_count, required)); + } + } + if !fresh_missing_dc.is_empty() { + warn!( + previous_generation, + generation, + missing_dc = ?fresh_missing_dc, + "ME hardswap pending: fresh generation coverage incomplete" + ); + return; + } + } else if !missing_dc.is_empty() { + warn!( + missing_dc = ?missing_dc, + // Keep stale writers alive when fresh coverage is incomplete. 
+ "ME reinit coverage incomplete; keeping stale writers" + ); + return; + } + + let desired_addrs: HashSet = desired_by_dc + .values() + .flat_map(|set| set.iter().copied()) + .collect(); + + let stale_writer_ids: Vec = writers + .iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| { + if hardswap { + w.generation < generation + } else { + !desired_addrs.contains(&w.addr) + } + }) + .map(|w| w.id) + .collect(); + drop(writers); + + if stale_writer_ids.is_empty() { + debug!("ME reinit cycle completed with no stale writers"); + return; + } + + let drain_timeout = self.force_close_timeout(); + let drain_timeout_secs = drain_timeout.map(|d| d.as_secs()).unwrap_or(0); + info!( + stale_writers = stale_writer_ids.len(), + previous_generation, + generation, + hardswap, + coverage_ratio = format_args!("{coverage_ratio:.3}"), + min_ratio = format_args!("{min_ratio:.3}"), + drain_timeout_secs, + "ME reinit cycle covered; draining stale writers" + ); + self.stats.increment_pool_swap_total(); + for writer_id in stale_writer_ids { + self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap) + .await; + } + } + + pub async fn zero_downtime_reinit_periodic(self: &Arc, rng: &SecureRandom) { + self.zero_downtime_reinit_after_map_change(rng).await; + } +} diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs new file mode 100644 index 0000000..a7d2960 --- /dev/null +++ b/src/transport/middle_proxy/pool_writer.rs @@ -0,0 +1,348 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use bytes::BytesMut; +use rand::Rng; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use tracing::{debug, info, warn}; + +use crate::crypto::SecureRandom; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::RPC_PING_U32; + +use super::codec::{RpcWriter, WriterCommand}; +use 
super::pool::{MePool, MeWriter}; +use super::reader::reader_loop; +use super::registry::BoundConn; + +const ME_ACTIVE_PING_SECS: u64 = 25; +const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; + +impl MePool { + pub(crate) async fn prune_closed_writers(self: &Arc) { + let closed_writer_ids: Vec = { + let ws = self.writers.read().await; + ws.iter().filter(|w| w.tx.is_closed()).map(|w| w.id).collect() + }; + if closed_writer_ids.is_empty() { + return; + } + + for writer_id in closed_writer_ids { + if self.registry.is_writer_empty(writer_id).await { + let _ = self.remove_writer_only(writer_id).await; + } else { + let _ = self.remove_writer_and_close_clients(writer_id).await; + } + } + } + + pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { + let secret_len = self.proxy_secret.read().await.len(); + if secret_len < 32 { + return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); + } + + let (stream, _connect_ms) = self.connect_tcp(addr).await?; + let hs = self.handshake_only(stream, addr, rng).await?; + + let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); + let generation = self.current_generation(); + let cancel = CancellationToken::new(); + let degraded = Arc::new(AtomicBool::new(false)); + let draining = Arc::new(AtomicBool::new(false)); + let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0)); + let allow_drain_fallback = Arc::new(AtomicBool::new(false)); + let (tx, mut rx) = mpsc::channel::(4096); + let mut rpc_writer = RpcWriter { + writer: hs.wr, + key: hs.write_key, + iv: hs.write_iv, + seq_no: 0, + crc_mode: hs.crc_mode, + }; + let cancel_wr = cancel.clone(); + tokio::spawn(async move { + loop { + tokio::select! 
{ + cmd = rx.recv() => { + match cmd { + Some(WriterCommand::Data(payload)) => { + if rpc_writer.send(&payload).await.is_err() { break; } + } + Some(WriterCommand::DataAndFlush(payload)) => { + if rpc_writer.send_and_flush(&payload).await.is_err() { break; } + } + Some(WriterCommand::Close) | None => break, + } + } + _ = cancel_wr.cancelled() => break, + } + } + }); + let writer = MeWriter { + id: writer_id, + addr, + generation, + tx: tx.clone(), + cancel: cancel.clone(), + degraded: degraded.clone(), + draining: draining.clone(), + draining_started_at_epoch_secs: draining_started_at_epoch_secs.clone(), + allow_drain_fallback: allow_drain_fallback.clone(), + }; + self.writers.write().await.push(writer.clone()); + self.conn_count.fetch_add(1, Ordering::Relaxed); + self.writer_available.notify_one(); + + let reg = self.registry.clone(); + let writers_arc = self.writers_arc(); + let ping_tracker = self.ping_tracker.clone(); + let ping_tracker_reader = ping_tracker.clone(); + let rtt_stats = self.rtt_stats.clone(); + let stats_reader = self.stats.clone(); + let stats_ping = self.stats.clone(); + let pool = Arc::downgrade(self); + let cancel_ping = cancel.clone(); + let tx_ping = tx.clone(); + let ping_tracker_ping = ping_tracker.clone(); + let cleanup_done = Arc::new(AtomicBool::new(false)); + let cleanup_for_reader = cleanup_done.clone(); + let cleanup_for_ping = cleanup_done.clone(); + let keepalive_enabled = self.me_keepalive_enabled; + let keepalive_interval = self.me_keepalive_interval; + let keepalive_jitter = self.me_keepalive_jitter; + let cancel_reader_token = cancel.clone(); + let cancel_ping_token = cancel_ping.clone(); + + tokio::spawn(async move { + let res = reader_loop( + hs.rd, + hs.read_key, + hs.read_iv, + hs.crc_mode, + reg.clone(), + BytesMut::new(), + BytesMut::new(), + tx.clone(), + ping_tracker_reader, + rtt_stats.clone(), + stats_reader, + writer_id, + degraded.clone(), + cancel_reader_token.clone(), + ) + .await; + if let Some(pool) = 
pool.upgrade() + && cleanup_for_reader + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + pool.remove_writer_and_close_clients(writer_id).await; + } + if let Err(e) = res { + warn!(error = %e, "ME reader ended"); + } + let mut ws = writers_arc.write().await; + ws.retain(|w| w.id != writer_id); + info!(remaining = ws.len(), "Dead ME writer removed from pool"); + }); + + let pool_ping = Arc::downgrade(self); + tokio::spawn(async move { + let mut ping_id: i64 = rand::random::(); + // Per-writer jittered start to avoid phase sync. + let startup_jitter = if keepalive_enabled { + let jitter_cap_ms = keepalive_interval.as_millis() / 2; + let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) + } else { + let jitter = rand::rng().random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); + let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; + Duration::from_secs(wait) + }; + tokio::select! { + _ = cancel_ping_token.cancelled() => return, + _ = tokio::time::sleep(startup_jitter) => {} + } + loop { + let wait = if keepalive_enabled { + let jitter_cap_ms = keepalive_interval.as_millis() / 2; + let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); + keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) + } else { + let jitter = rand::rng().random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); + let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; + Duration::from_secs(secs) + }; + tokio::select! 
{ + _ = cancel_ping_token.cancelled() => { + break; + } + _ = tokio::time::sleep(wait) => {} + } + let sent_id = ping_id; + let mut p = Vec::with_capacity(12); + p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); + p.extend_from_slice(&sent_id.to_le_bytes()); + { + let mut tracker = ping_tracker_ping.lock().await; + let before = tracker.len(); + tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); + let expired = before.saturating_sub(tracker.len()); + if expired > 0 { + stats_ping.increment_me_keepalive_timeout_by(expired as u64); + } + tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); + } + ping_id = ping_id.wrapping_add(1); + stats_ping.increment_me_keepalive_sent(); + if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { + stats_ping.increment_me_keepalive_failed(); + debug!("ME ping failed, removing dead writer"); + cancel_ping.cancel(); + if let Some(pool) = pool_ping.upgrade() + && cleanup_for_ping + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + pool.remove_writer_and_close_clients(writer_id).await; + } + break; + } + } + }); + + Ok(()) + } + + pub(crate) async fn remove_writer_and_close_clients(self: &Arc<Self>, writer_id: u64) { + let conns = self.remove_writer_only(writer_id).await; + for bound in conns { + let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; + let _ = self.registry.unregister(bound.conn_id).await; + } + } + + async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> { + let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None; + let mut removed_addr: Option<SocketAddr> = None; + let mut trigger_refill = false; + { + let mut ws = self.writers.write().await; + if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { + let w = ws.remove(pos); + let was_draining = w.draining.load(Ordering::Relaxed); + if was_draining { + self.stats.decrement_pool_drain_active(); + } + self.stats.increment_me_writer_removed_total(); + w.cancel.cancel(); + removed_addr = Some(w.addr); + 
trigger_refill = !was_draining; + if trigger_refill { + self.stats.increment_me_writer_removed_unexpected_total(); + } + close_tx = Some(w.tx.clone()); + self.conn_count.fetch_sub(1, Ordering::Relaxed); + } + } + if let Some(tx) = close_tx { + let _ = tx.send(WriterCommand::Close).await; + } + if trigger_refill + && let Some(addr) = removed_addr + { + self.trigger_immediate_refill(addr); + } + self.rtt_stats.lock().await.remove(&writer_id); + self.registry.writer_lost(writer_id).await + } + + pub(crate) async fn mark_writer_draining_with_timeout( + self: &Arc<Self>, + writer_id: u64, + timeout: Option<Duration>, + allow_drain_fallback: bool, + ) { + let timeout = timeout.filter(|d| !d.is_zero()); + let found = { + let mut ws = self.writers.write().await; + if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) { + let already_draining = w.draining.swap(true, Ordering::Relaxed); + w.allow_drain_fallback + .store(allow_drain_fallback, Ordering::Relaxed); + w.draining_started_at_epoch_secs + .store(Self::now_epoch_secs(), Ordering::Relaxed); + if !already_draining { + self.stats.increment_pool_drain_active(); + } + w.draining.store(true, Ordering::Relaxed); + true + } else { + false + } + }; + + if !found { + return; + } + + let timeout_secs = timeout.map(|d| d.as_secs()).unwrap_or(0); + debug!( + writer_id, + timeout_secs, + allow_drain_fallback, + "ME writer marked draining" + ); + + let pool = Arc::downgrade(self); + tokio::spawn(async move { + let deadline = timeout.map(|t| Instant::now() + t); + while let Some(p) = pool.upgrade() { + if let Some(deadline_at) = deadline + && Instant::now() >= deadline_at + { + warn!(writer_id, "Drain timeout, force-closing"); + p.stats.increment_pool_force_close_total(); + let _ = p.remove_writer_and_close_clients(writer_id).await; + break; + } + if p.registry.is_writer_empty(writer_id).await { + let _ = p.remove_writer_only(writer_id).await; + break; + } + tokio::time::sleep(Duration::from_secs(1)).await; + } + }); + } + + pub(crate) async 
fn mark_writer_draining(self: &Arc<Self>, writer_id: u64) { + self.mark_writer_draining_with_timeout(writer_id, Some(Duration::from_secs(300)), false) + .await; + } + + pub(super) fn writer_accepts_new_binding(&self, writer: &MeWriter) -> bool { + if !writer.draining.load(Ordering::Relaxed) { + return true; + } + if !writer.allow_drain_fallback.load(Ordering::Relaxed) { + return false; + } + + let ttl_secs = self.me_pool_drain_ttl_secs.load(Ordering::Relaxed); + if ttl_secs == 0 { + return true; + } + + let started = writer.draining_started_at_epoch_secs.load(Ordering::Relaxed); + if started == 0 { + return false; + } + + Self::now_epoch_secs().saturating_sub(started) <= ttl_secs + } +} From 04e6135935446b9ba081ef9deb38e2806e16d602 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:35:34 +0300 Subject: [PATCH 5/7] TLS-F Fetching Optimization --- src/main.rs | 199 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 127 insertions(+), 72 deletions(-) diff --git a/src/main.rs b/src/main.rs index da88fe3..95f7e5a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -254,6 +254,133 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { warn!("Using default tls_domain. Consider setting a custom domain."); + let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); + + let mut tls_domains = Vec::with_capacity(1 + config.censorship.tls_domains.len()); + tls_domains.push(config.censorship.tls_domain.clone()); + for d in &config.censorship.tls_domains { + if !tls_domains.contains(d) { + tls_domains.push(d.clone()); + } + } + + // Start TLS front fetching in background immediately, in parallel with STUN probing. 
+ let tls_cache: Option<Arc<TlsFrontCache>> = if config.censorship.tls_emulation { + let cache = Arc::new(TlsFrontCache::new( + &tls_domains, + config.censorship.fake_cert_len, + &config.censorship.tls_front_dir, + )); + cache.load_from_disk().await; + + let port = config.censorship.mask_port; + let proxy_protocol = config.censorship.mask_proxy_protocol; + let mask_host = config + .censorship + .mask_host + .clone() + .unwrap_or_else(|| config.censorship.tls_domain.clone()); + let fetch_timeout = Duration::from_secs(5); + + let cache_initial = cache.clone(); + let domains_initial = tls_domains.clone(); + let host_initial = mask_host.clone(); + let upstream_initial = upstream_manager.clone(); + tokio::spawn(async move { + let mut join = tokio::task::JoinSet::new(); + for domain in domains_initial { + let cache_domain = cache_initial.clone(); + let host_domain = host_initial.clone(); + let upstream_domain = upstream_initial.clone(); + join.spawn(async move { + match crate::tls_front::fetcher::fetch_real_tls( + &host_domain, + port, + &domain, + fetch_timeout, + Some(upstream_domain), + proxy_protocol, + ) + .await + { + Ok(res) => cache_domain.update_from_fetch(&domain, res).await, + Err(e) => { + warn!(domain = %domain, error = %e, "TLS emulation initial fetch failed") + } + } + }); + } + while let Some(res) = join.join_next().await { + if let Err(e) = res { + warn!(error = %e, "TLS emulation initial fetch task join failed"); + } + } + }); + + let cache_timeout = cache.clone(); + let domains_timeout = tls_domains.clone(); + let fake_cert_len = config.censorship.fake_cert_len; + tokio::spawn(async move { + tokio::time::sleep(fetch_timeout).await; + for domain in domains_timeout { + let cached = cache_timeout.get(&domain).await; + if cached.domain == "default" { + warn!( + domain = %domain, + timeout_secs = fetch_timeout.as_secs(), + fake_cert_len, + "TLS-front fetch not ready within timeout; using cache/default fake cert fallback" + ); + } + } + }); + + // Periodic refresh with jitter. 
+ let cache_refresh = cache.clone(); + let domains_refresh = tls_domains.clone(); + let host_refresh = mask_host.clone(); + let upstream_refresh = upstream_manager.clone(); + tokio::spawn(async move { + loop { + let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600); + let jitter_secs = rand::rng().random_range(0..=7200); + tokio::time::sleep(Duration::from_secs(base_secs + jitter_secs)).await; + + let mut join = tokio::task::JoinSet::new(); + for domain in domains_refresh.clone() { + let cache_domain = cache_refresh.clone(); + let host_domain = host_refresh.clone(); + let upstream_domain = upstream_refresh.clone(); + join.spawn(async move { + match crate::tls_front::fetcher::fetch_real_tls( + &host_domain, + port, + &domain, + fetch_timeout, + Some(upstream_domain), + proxy_protocol, + ) + .await + { + Ok(res) => cache_domain.update_from_fetch(&domain, res).await, + Err(e) => warn!(domain = %domain, error = %e, "TLS emulation refresh failed"), + } + }); + } + + while let Some(res) = join.join_next().await { + if let Err(e) = res { + warn!(error = %e, "TLS emulation refresh task join failed"); + } + } + } + }); + + Some(cache) + } else { + None + }; + let probe = run_probe( &config.network, config.general.middle_proxy_nat_probe, @@ -450,80 +577,8 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { Duration::from_secs(config.access.replay_window_secs), )); - let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096)); - // TLS front cache (optional emulation) - let mut tls_domains = Vec::with_capacity(1 + config.censorship.tls_domains.len()); - tls_domains.push(config.censorship.tls_domain.clone()); - for d in &config.censorship.tls_domains { - if !tls_domains.contains(d) { - tls_domains.push(d.clone()); - } - } - - let tls_cache: Option<Arc<TlsFrontCache>> = if config.censorship.tls_emulation { - let cache = Arc::new(TlsFrontCache::new( - &tls_domains, - config.censorship.fake_cert_len, 
- &config.censorship.tls_front_dir, - )); - - cache.load_from_disk().await; - - let port = config.censorship.mask_port; - let mask_host = config.censorship.mask_host.clone() - .unwrap_or_else(|| config.censorship.tls_domain.clone()); - // Initial synchronous fetch to warm cache before serving clients. - for domain in tls_domains.clone() { - match crate::tls_front::fetcher::fetch_real_tls( - &mask_host, - port, - &domain, - Duration::from_secs(5), - Some(upstream_manager.clone()), - config.censorship.mask_proxy_protocol, - ) - .await - { - Ok(res) => cache.update_from_fetch(&domain, res).await, - Err(e) => warn!(domain = %domain, error = %e, "TLS emulation fetch failed"), - } - } - - // Periodic refresh with jitter. - let cache_clone = cache.clone(); - let domains = tls_domains.clone(); - let upstream_for_task = upstream_manager.clone(); - let proxy_protocol = config.censorship.mask_proxy_protocol; - tokio::spawn(async move { - loop { - let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600); - let jitter_secs = rand::rng().random_range(0..=7200); - tokio::time::sleep(Duration::from_secs(base_secs + jitter_secs)).await; - for domain in &domains { - match crate::tls_front::fetcher::fetch_real_tls( - &mask_host, - port, - domain, - Duration::from_secs(5), - Some(upstream_for_task.clone()), - proxy_protocol, - ) - .await - { - Ok(res) => cache_clone.update_from_fetch(domain, res).await, - Err(e) => warn!(domain = %domain, error = %e, "TLS emulation refresh failed"), - } - } - } - }); - - Some(cache) - } else { - None - }; - // Middle-End ping before DC connectivity if let Some(ref pool) = me_pool { let me_results = run_me_ping(pool, &rng).await; From 144f81c4730203ab550c958245e3fcb0ae2f9d03 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:37:17 +0300 Subject: [PATCH 6/7] ME Dead Writer w/o dead-lock on timeout --- src/transport/middle_proxy/pool_writer.rs | 24 ++++++++++++++++++++--- 1 file changed, 21 
insertions(+), 3 deletions(-) diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index a7d2960..942ddaf 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -20,6 +20,7 @@ use super::registry::BoundConn; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; +const ME_IDLE_KEEPALIVE_MAX_SECS: u64 = 5; impl MePool { pub(crate) async fn prune_closed_writers(self: &Arc<Self>) { @@ -154,9 +155,18 @@ impl MePool { let pool_ping = Arc::downgrade(self); tokio::spawn(async move { let mut ping_id: i64 = rand::random::<i64>(); + let idle_interval_cap = Duration::from_secs(ME_IDLE_KEEPALIVE_MAX_SECS); // Per-writer jittered start to avoid phase sync. let startup_jitter = if keepalive_enabled { - let jitter_cap_ms = keepalive_interval.as_millis() / 2; + let mut interval = keepalive_interval; + if let Some(pool) = pool_ping.upgrade() { + if pool.registry.is_writer_empty(writer_id).await { + interval = interval.min(idle_interval_cap); + } + } else { + return; + } + let jitter_cap_ms = interval.as_millis() / 2; let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) } else { @@ -170,9 +180,17 @@ impl MePool { } loop { let wait = if keepalive_enabled { - let jitter_cap_ms = keepalive_interval.as_millis() / 2; + let mut interval = keepalive_interval; + if let Some(pool) = pool_ping.upgrade() { + if pool.registry.is_writer_empty(writer_id).await { + interval = interval.min(idle_interval_cap); + } + } else { + break; + } + let jitter_cap_ms = interval.as_millis() / 2; let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); - keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) + interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) } else { let jitter = 
rand::rng().random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; From 60231224ac50eb7390090b805e9a179d99b2bd81 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:41:37 +0300 Subject: [PATCH 7/7] Update Cargo.toml --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 994e11f..1d135f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.1.0" +version = "3.1.2" edition = "2024" [dependencies]