diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 6b80ede..d43ace9 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -182,6 +182,26 @@ pub(crate) fn default_update_every_secs() -> u64 { 30 * 60 } +pub(crate) fn default_me_reinit_every_secs() -> u64 { + 15 * 60 +} + +pub(crate) fn default_me_hardswap_warmup_delay_min_ms() -> u64 { + 1000 +} + +pub(crate) fn default_me_hardswap_warmup_delay_max_ms() -> u64 { + 2000 +} + +pub(crate) fn default_me_hardswap_warmup_extra_passes() -> u8 { + 3 +} + +pub(crate) fn default_me_hardswap_warmup_pass_backoff_base_ms() -> u64 { + 500 +} + pub(crate) fn default_me_config_stable_snapshots() -> u8 { 2 } diff --git a/src/config/load.rs b/src/config/load.rs index be34efa..5698a71 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -147,6 +147,38 @@ impl ProxyConfig { } } + if config.general.me_reinit_every_secs == 0 { + return Err(ProxyError::Config( + "general.me_reinit_every_secs must be > 0".to_string(), + )); + } + + if config.general.me_hardswap_warmup_delay_max_ms == 0 { + return Err(ProxyError::Config( + "general.me_hardswap_warmup_delay_max_ms must be > 0".to_string(), + )); + } + + if config.general.me_hardswap_warmup_delay_min_ms + > config.general.me_hardswap_warmup_delay_max_ms + { + return Err(ProxyError::Config( + "general.me_hardswap_warmup_delay_min_ms must be <= general.me_hardswap_warmup_delay_max_ms".to_string(), + )); + } + + if config.general.me_hardswap_warmup_extra_passes > 10 { + return Err(ProxyError::Config( + "general.me_hardswap_warmup_extra_passes must be within [0, 10]".to_string(), + )); + } + + if config.general.me_hardswap_warmup_pass_backoff_base_ms == 0 { + return Err(ProxyError::Config( + "general.me_hardswap_warmup_pass_backoff_base_ms must be > 0".to_string(), + )); + } + if config.general.me_config_stable_snapshots == 0 { return Err(ProxyError::Config( "general.me_config_stable_snapshots must be > 0".to_string(), @@ -480,6 +512,161 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn me_reinit_every_default_is_set() { + let toml = r#" + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_reinit_every_default_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.general.me_reinit_every_secs, + default_me_reinit_every_secs() + ); + let _ = std::fs::remove_file(path); + } + + #[test] + fn me_reinit_every_zero_is_rejected() { + let toml = r#" + [general] + me_reinit_every_secs = 0 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_reinit_every_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.me_reinit_every_secs must be > 0")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn me_hardswap_warmup_defaults_are_set() { + let toml = r#" + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_hardswap_warmup_defaults_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.general.me_hardswap_warmup_delay_min_ms, + default_me_hardswap_warmup_delay_min_ms() + ); + assert_eq!( + cfg.general.me_hardswap_warmup_delay_max_ms, + default_me_hardswap_warmup_delay_max_ms() + ); + assert_eq!( + cfg.general.me_hardswap_warmup_extra_passes, + default_me_hardswap_warmup_extra_passes() + ); + assert_eq!( + cfg.general.me_hardswap_warmup_pass_backoff_base_ms, + default_me_hardswap_warmup_pass_backoff_base_ms() + ); + let _ = std::fs::remove_file(path); + } + + #[test] + fn me_hardswap_warmup_delay_range_is_validated() { + let toml = r#" + [general] + me_hardswap_warmup_delay_min_ms = 2001 + me_hardswap_warmup_delay_max_ms = 2000 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_hardswap_warmup_delay_range_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains( + "general.me_hardswap_warmup_delay_min_ms must be <= general.me_hardswap_warmup_delay_max_ms" + )); + let _ = std::fs::remove_file(path); + } + + #[test] + fn me_hardswap_warmup_delay_max_zero_is_rejected() { + let toml = r#" + [general] + me_hardswap_warmup_delay_max_ms = 0 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_hardswap_warmup_delay_max_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.me_hardswap_warmup_delay_max_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn me_hardswap_warmup_extra_passes_out_of_range_is_rejected() { + let toml = r#" + [general] + me_hardswap_warmup_extra_passes = 11 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_hardswap_warmup_extra_passes_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.me_hardswap_warmup_extra_passes must be within [0, 10]")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn me_hardswap_warmup_pass_backoff_zero_is_rejected() { + let toml = r#" + [general] + me_hardswap_warmup_pass_backoff_base_ms = 0 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_hardswap_warmup_backoff_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.me_hardswap_warmup_pass_backoff_base_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn me_config_stable_snapshots_zero_is_rejected() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index bd9697e..0cda9f4 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -267,6 +267,26 @@ pub struct GeneralConfig { #[serde(default)] pub update_every: Option, + /// Periodic ME pool reinitialization interval in seconds. + #[serde(default = "default_me_reinit_every_secs")] + pub me_reinit_every_secs: u64, + + /// Minimum delay in ms between hardswap warmup connect attempts. + #[serde(default = "default_me_hardswap_warmup_delay_min_ms")] + pub me_hardswap_warmup_delay_min_ms: u64, + + /// Maximum delay in ms between hardswap warmup connect attempts. + #[serde(default = "default_me_hardswap_warmup_delay_max_ms")] + pub me_hardswap_warmup_delay_max_ms: u64, + + /// Additional warmup passes in the same hardswap cycle after the base pass. + #[serde(default = "default_me_hardswap_warmup_extra_passes")] + pub me_hardswap_warmup_extra_passes: u8, + + /// Base backoff in ms between hardswap warmup passes when floor is still incomplete. + #[serde(default = "default_me_hardswap_warmup_pass_backoff_base_ms")] + pub me_hardswap_warmup_pass_backoff_base_ms: u64, + /// Number of identical getProxyConfig snapshots required before applying ME map updates. #[serde(default = "default_me_config_stable_snapshots")] pub me_config_stable_snapshots: u8, @@ -366,6 +386,11 @@ impl Default for GeneralConfig { hardswap: default_hardswap(), fast_mode_min_tls_record: default_fast_mode_min_tls_record(), update_every: Some(default_update_every_secs()), + me_reinit_every_secs: default_me_reinit_every_secs(), + me_hardswap_warmup_delay_min_ms: default_me_hardswap_warmup_delay_min_ms(), + me_hardswap_warmup_delay_max_ms: default_me_hardswap_warmup_delay_max_ms(), + me_hardswap_warmup_extra_passes: default_me_hardswap_warmup_extra_passes(), + me_hardswap_warmup_pass_backoff_base_ms: default_me_hardswap_warmup_pass_backoff_base_ms(), me_config_stable_snapshots: default_me_config_stable_snapshots(), me_config_apply_cooldown_secs: default_me_config_apply_cooldown_secs(), proxy_secret_stable_snapshots: default_proxy_secret_stable_snapshots(), @@ -392,6 +417,11 @@ impl GeneralConfig { .unwrap_or_else(|| self.proxy_secret_auto_reload_secs.min(self.proxy_config_auto_reload_secs)) } + /// Resolve periodic zero-downtime reinit interval for ME writers. + pub fn effective_me_reinit_every_secs(&self) -> u64 { + self.me_reinit_every_secs + } + /// Resolve force-close timeout for stale writers. /// `me_reinit_drain_timeout_secs` remains backward-compatible alias. pub fn effective_me_pool_force_close_secs(&self) -> u64 { diff --git a/src/main.rs b/src/main.rs index 1c7b39c..3bcbf3e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -373,6 +373,10 @@ async fn main() -> std::result::Result<(), Box> { config.general.me_pool_drain_ttl_secs, config.general.effective_me_pool_force_close_secs(), config.general.me_pool_min_fresh_ratio, + config.general.me_hardswap_warmup_delay_min_ms, + config.general.me_hardswap_warmup_delay_max_ms, + config.general.me_hardswap_warmup_extra_passes, + config.general.me_hardswap_warmup_pass_backoff_base_ms, ); let pool_size = config.general.middle_proxy_pool_size.max(1); @@ -391,18 +395,6 @@ async fn main() -> std::result::Result<(), Box> { .await; }); - // Periodic ME connection rotation - let pool_clone_rot = pool.clone(); - let rng_clone_rot = rng.clone(); - tokio::spawn(async move { - crate::transport::middle_proxy::me_rotation_task( - pool_clone_rot, - rng_clone_rot, - std::time::Duration::from_secs(1800), - ) - .await; - }); - Some(pool) } Err(e) => { @@ -712,6 +704,18 @@ async fn main() -> std::result::Result<(), Box> { ) .await; }); + + let pool_clone_rot = pool.clone(); + let rng_clone_rot = rng.clone(); + let config_rx_clone_rot = config_rx.clone(); + tokio::spawn(async move { + crate::transport::middle_proxy::me_rotation_task( + pool_clone_rot, + rng_clone_rot, + config_rx_clone_rot, + ) + .await; + }); } let mut listeners = Vec::new(); diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index fc9ed3d..4e8e63f 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -228,6 +228,10 @@ async fn run_update_cycle( cfg.general.me_pool_drain_ttl_secs, cfg.general.effective_me_pool_force_close_secs(), cfg.general.me_pool_min_fresh_ratio, + cfg.general.me_hardswap_warmup_delay_min_ms, + cfg.general.me_hardswap_warmup_delay_max_ms, + cfg.general.me_hardswap_warmup_extra_passes, + cfg.general.me_hardswap_warmup_pass_backoff_base_ms, ); let required_cfg_snapshots = cfg.general.me_config_stable_snapshots.max(1); @@ -407,6 +411,10 @@ pub async fn me_config_updater( cfg.general.me_pool_drain_ttl_secs, cfg.general.effective_me_pool_force_close_secs(), cfg.general.me_pool_min_fresh_ratio, + cfg.general.me_hardswap_warmup_delay_min_ms, + cfg.general.me_hardswap_warmup_delay_max_ms, + cfg.general.me_hardswap_warmup_extra_passes, + cfg.general.me_hardswap_warmup_pass_backoff_base_ms, ); let new_secs = cfg.general.effective_update_every_secs().max(1); if new_secs == update_every_secs { diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 4bb7e64..dde3354 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -1,10 +1,9 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use tracing::{debug, info, warn}; -use rand::seq::SliceRandom; use rand::Rng; use crate::crypto::SecureRandom; @@ -64,31 +63,43 @@ async fn check_family( IpFamily::V4 => pool.proxy_map_v4.read().await.clone(), IpFamily::V6 => pool.proxy_map_v6.read().await.clone(), }; - let writer_addrs: HashSet = pool + + let mut dc_endpoints = HashMap::>::new(); + for (dc, addrs) in map { + let entry = dc_endpoints.entry(dc.abs()).or_default(); + for (ip, port) in addrs { + entry.push(SocketAddr::new(ip, port)); + } + } + for endpoints in dc_endpoints.values_mut() { + endpoints.sort_unstable(); + endpoints.dedup(); + } + + let mut live_addr_counts = HashMap::::new(); + for writer in pool .writers .read() .await .iter() .filter(|w| !w.draining.load(std::sync::atomic::Ordering::Relaxed)) - .map(|w| w.addr) - .collect(); + { + *live_addr_counts.entry(writer.addr).or_insert(0) += 1; + } - let entries: Vec<(i32, Vec)> = map - .iter() - .map(|(dc, addrs)| { - let list = addrs - .iter() - .map(|(ip, port)| SocketAddr::new(*ip, *port)) - .collect::>(); - (*dc, list) - }) - .collect(); - - for (dc, dc_addrs) in entries { - let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a)); - if has_coverage { + for (dc, endpoints) in dc_endpoints { + if endpoints.is_empty() { continue; } + let required = MePool::required_writers_for_dc(endpoints.len()); + let alive = endpoints + .iter() + .map(|addr| *live_addr_counts.get(addr).unwrap_or(&0)) + .sum::(); + if alive >= required { + continue; + } + let missing = required - alive; let key = (dc, family); let now = Instant::now(); @@ -104,32 +115,45 @@ async fn check_family( } *inflight.entry(key).or_insert(0) += 1; - let mut shuffled = dc_addrs.clone(); - shuffled.shuffle(&mut rand::rng()); - let mut success = false; - for addr in shuffled { - let res = tokio::time::timeout(pool.me_one_timeout, pool.connect_one(addr, rng.as_ref())).await; + let mut restored = 0usize; + for _ in 0..missing { + let res = tokio::time::timeout( + pool.me_one_timeout, + pool.connect_endpoints_round_robin(&endpoints, rng.as_ref()), + ) + .await; match res { - Ok(Ok(())) => { - info!(%addr, dc = %dc, ?family, "ME reconnected for DC coverage"); + Ok(true) => { + restored += 1; pool.stats.increment_me_reconnect_success(); - backoff.insert(key, pool.me_reconnect_backoff_base.as_millis() as u64); - let jitter = pool.me_reconnect_backoff_base.as_millis() as u64 / JITTER_FRAC_NUM; - let wait = pool.me_reconnect_backoff_base - + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1))); - next_attempt.insert(key, now + wait); - success = true; - break; } - Ok(Err(e)) => { + Ok(false) => { pool.stats.increment_me_reconnect_attempt(); - debug!(%addr, dc = %dc, error = %e, ?family, "ME reconnect failed") + debug!(dc = %dc, ?family, "ME round-robin reconnect failed") + } + Err(_) => { + pool.stats.increment_me_reconnect_attempt(); + debug!(dc = %dc, ?family, "ME reconnect timed out"); } - Err(_) => debug!(%addr, dc = %dc, ?family, "ME reconnect timed out"), } } - if !success { - pool.stats.increment_me_reconnect_attempt(); + + let now_alive = alive + restored; + if now_alive >= required { + info!( + dc = %dc, + ?family, + alive = now_alive, + required, + endpoint_count = endpoints.len(), + "ME writer floor restored for DC" + ); + backoff.insert(key, pool.me_reconnect_backoff_base.as_millis() as u64); + let jitter = pool.me_reconnect_backoff_base.as_millis() as u64 / JITTER_FRAC_NUM; + let wait = pool.me_reconnect_backoff_base + + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1))); + next_attempt.insert(key, now + wait); + } else { let curr = *backoff.get(&key).unwrap_or(&(pool.me_reconnect_backoff_base.as_millis() as u64)); let next_ms = (curr.saturating_mul(2)).min(pool.me_reconnect_backoff_cap.as_millis() as u64); backoff.insert(key, next_ms); @@ -137,7 +161,15 @@ async fn check_family( let wait = Duration::from_millis(next_ms) + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1))); next_attempt.insert(key, now + wait); - warn!(dc = %dc, backoff_ms = next_ms, ?family, "DC has no ME coverage, scheduled reconnect"); + warn!( + dc = %dc, + ?family, + alive = now_alive, + required, + endpoint_count = endpoints.len(), + backoff_ms = next_ms, + "DC writer floor is below required level, scheduled reconnect" + ); } if let Some(v) = inflight.get_mut(&key) { *v = v.saturating_sub(1); diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 06fdc96..aa14e5b 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -75,6 +75,7 @@ pub struct MePool { pub(super) rtt_stats: Arc>>, pub(super) nat_reflection_cache: Arc>, pub(super) writer_available: Arc, + pub(super) refill_inflight: Arc>>, pub(super) conn_count: AtomicUsize, pub(super) stats: Arc, pub(super) generation: AtomicU64, @@ -82,6 +83,10 @@ pub struct MePool { pub(super) me_pool_drain_ttl_secs: AtomicU64, pub(super) me_pool_force_close_secs: AtomicU64, pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, + pub(super) me_hardswap_warmup_delay_min_ms: AtomicU64, + pub(super) me_hardswap_warmup_delay_max_ms: AtomicU64, + pub(super) me_hardswap_warmup_extra_passes: AtomicU32, + pub(super) me_hardswap_warmup_pass_backoff_base_ms: AtomicU64, pool_size: usize, } @@ -139,6 +144,10 @@ impl MePool { me_pool_drain_ttl_secs: u64, me_pool_force_close_secs: u64, me_pool_min_fresh_ratio: f32, + me_hardswap_warmup_delay_min_ms: u64, + me_hardswap_warmup_delay_max_ms: u64, + me_hardswap_warmup_extra_passes: u8, + me_hardswap_warmup_pass_backoff_base_ms: u64, ) -> Arc { Arc::new(Self { registry: Arc::new(ConnRegistry::new()), @@ -180,12 +189,17 @@ impl MePool { rtt_stats: Arc::new(Mutex::new(HashMap::new())), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), writer_available: Arc::new(Notify::new()), + refill_inflight: Arc::new(Mutex::new(HashSet::new())), conn_count: AtomicUsize::new(0), generation: AtomicU64::new(1), hardswap: AtomicBool::new(hardswap), me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs), me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille(me_pool_min_fresh_ratio)), + me_hardswap_warmup_delay_min_ms: AtomicU64::new(me_hardswap_warmup_delay_min_ms), + me_hardswap_warmup_delay_max_ms: AtomicU64::new(me_hardswap_warmup_delay_max_ms), + me_hardswap_warmup_extra_passes: AtomicU32::new(me_hardswap_warmup_extra_passes as u32), + me_hardswap_warmup_pass_backoff_base_ms: AtomicU64::new(me_hardswap_warmup_pass_backoff_base_ms), }) } @@ -203,6 +217,10 @@ impl MePool { drain_ttl_secs: u64, force_close_secs: u64, min_fresh_ratio: f32, + hardswap_warmup_delay_min_ms: u64, + hardswap_warmup_delay_max_ms: u64, + hardswap_warmup_extra_passes: u8, + hardswap_warmup_pass_backoff_base_ms: u64, ) { self.hardswap.store(hardswap, Ordering::Relaxed); self.me_pool_drain_ttl_secs.store(drain_ttl_secs, Ordering::Relaxed); @@ -210,6 +228,14 @@ impl MePool { .store(force_close_secs, Ordering::Relaxed); self.me_pool_min_fresh_ratio_permille .store(Self::ratio_to_permille(min_fresh_ratio), Ordering::Relaxed); + self.me_hardswap_warmup_delay_min_ms + .store(hardswap_warmup_delay_min_ms, Ordering::Relaxed); + self.me_hardswap_warmup_delay_max_ms + .store(hardswap_warmup_delay_max_ms, Ordering::Relaxed); + self.me_hardswap_warmup_extra_passes + .store(hardswap_warmup_extra_passes as u32, Ordering::Relaxed); + self.me_hardswap_warmup_pass_backoff_base_ms + .store(hardswap_warmup_pass_backoff_base_ms, Ordering::Relaxed); } pub fn reset_stun_state(&self) { @@ -324,36 +350,172 @@ impl MePool { out } + pub(super) fn required_writers_for_dc(endpoint_count: usize) -> usize { + endpoint_count.max(3) + } + + fn hardswap_warmup_connect_delay_ms(&self) -> u64 { + let min_ms = self + .me_hardswap_warmup_delay_min_ms + .load(Ordering::Relaxed); + let max_ms = self + .me_hardswap_warmup_delay_max_ms + .load(Ordering::Relaxed); + let (min_ms, max_ms) = if min_ms <= max_ms { + (min_ms, max_ms) + } else { + (max_ms, min_ms) + }; + if min_ms == max_ms { + return min_ms; + } + rand::rng().random_range(min_ms..=max_ms) + } + + fn hardswap_warmup_backoff_ms(&self, pass_idx: usize) -> u64 { + let base_ms = self + .me_hardswap_warmup_pass_backoff_base_ms + .load(Ordering::Relaxed); + let cap_ms = (self.me_reconnect_backoff_cap.as_millis() as u64).max(base_ms); + let shift = (pass_idx as u32).min(20); + let scaled = base_ms.saturating_mul(1u64 << shift); + let core = scaled.min(cap_ms); + let jitter = (core / 2).max(1); + core.saturating_add(rand::rng().random_range(0..=jitter)) + } + + async fn fresh_writer_count_for_endpoints( + &self, + generation: u64, + endpoints: &HashSet, + ) -> usize { + let ws = self.writers.read().await; + ws.iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| w.generation == generation) + .filter(|w| endpoints.contains(&w.addr)) + .count() + } + + pub(super) async fn connect_endpoints_round_robin( + self: &Arc, + endpoints: &[SocketAddr], + rng: &SecureRandom, + ) -> bool { + if endpoints.is_empty() { + return false; + } + let start = (self.rr.fetch_add(1, Ordering::Relaxed) as usize) % endpoints.len(); + for offset in 0..endpoints.len() { + let idx = (start + offset) % endpoints.len(); + let addr = endpoints[idx]; + match self.connect_one(addr, rng).await { + Ok(()) => return true, + Err(e) => debug!(%addr, error = %e, "ME connect failed during round-robin warmup"), + } + } + false + } + async fn warmup_generation_for_all_dcs( self: &Arc, rng: &SecureRandom, generation: u64, desired_by_dc: &HashMap>, ) { - for endpoints in desired_by_dc.values() { + let extra_passes = self + .me_hardswap_warmup_extra_passes + .load(Ordering::Relaxed) + .min(10) as usize; + let total_passes = 1 + extra_passes; + + for (dc, endpoints) in desired_by_dc { if endpoints.is_empty() { continue; } - let has_fresh = { - let ws = self.writers.read().await; - ws.iter().any(|w| { - !w.draining.load(Ordering::Relaxed) - && w.generation == generation - && endpoints.contains(&w.addr) - }) - }; + let mut endpoint_list: Vec = endpoints.iter().copied().collect(); + endpoint_list.sort_unstable(); + let required = Self::required_writers_for_dc(endpoint_list.len()); + let mut completed = false; + let mut last_fresh_count = self + .fresh_writer_count_for_endpoints(generation, endpoints) + .await; - if has_fresh { - continue; - } - - let mut shuffled: Vec = endpoints.iter().copied().collect(); - shuffled.shuffle(&mut rand::rng()); - for addr in shuffled { - if self.connect_one(addr, rng).await.is_ok() { + for pass_idx in 0..total_passes { + if last_fresh_count >= required { + completed = true; break; } + + let missing = required.saturating_sub(last_fresh_count); + debug!( + dc = *dc, + pass = pass_idx + 1, + total_passes, + fresh_count = last_fresh_count, + required, + missing, + endpoint_count = endpoint_list.len(), + "ME hardswap warmup pass started" + ); + + for attempt_idx in 0..missing { + let delay_ms = self.hardswap_warmup_connect_delay_ms(); + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + + let connected = self.connect_endpoints_round_robin(&endpoint_list, rng).await; + debug!( + dc = *dc, + pass = pass_idx + 1, + total_passes, + attempt = attempt_idx + 1, + delay_ms, + connected, + "ME hardswap warmup connect attempt finished" + ); + } + + last_fresh_count = self + .fresh_writer_count_for_endpoints(generation, endpoints) + .await; + if last_fresh_count >= required { + completed = true; + info!( + dc = *dc, + pass = pass_idx + 1, + total_passes, + fresh_count = last_fresh_count, + required, + "ME hardswap warmup floor reached for DC" + ); + break; + } + + if pass_idx + 1 < total_passes { + let backoff_ms = self.hardswap_warmup_backoff_ms(pass_idx); + debug!( + dc = *dc, + pass = pass_idx + 1, + total_passes, + fresh_count = last_fresh_count, + required, + backoff_ms, + "ME hardswap warmup pass incomplete, delaying next pass" + ); + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + } + } + + if !completed { + warn!( + dc = *dc, + fresh_count = last_fresh_count, + required, + endpoint_count = endpoint_list.len(), + total_passes, + "ME warmup stopped: unable to reach required writer floor for DC" + ); } } } @@ -364,7 +526,7 @@ impl MePool { ) { let desired_by_dc = self.desired_dc_endpoints().await; if desired_by_dc.is_empty() { - warn!("ME endpoint map is empty after update; skipping stale writer drain"); + warn!("ME endpoint map is empty; skipping stale writer drain"); return; } @@ -403,19 +565,26 @@ impl MePool { } if hardswap { - let fresh_writer_addrs: HashSet = writers - .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) - .filter(|w| w.generation == generation) - .map(|w| w.addr) - .collect(); - let (fresh_ratio, fresh_missing_dc) = - Self::coverage_ratio(&desired_by_dc, &fresh_writer_addrs); + let mut fresh_missing_dc = Vec::<(i32, usize, usize)>::new(); + for (dc, endpoints) in &desired_by_dc { + if endpoints.is_empty() { + continue; + } + let required = Self::required_writers_for_dc(endpoints.len()); + let fresh_count = writers + .iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| w.generation == generation) + .filter(|w| endpoints.contains(&w.addr)) + .count(); + if fresh_count < required { + fresh_missing_dc.push((*dc, fresh_count, required)); + } + } if !fresh_missing_dc.is_empty() { warn!( previous_generation, generation, - fresh_ratio = format_args!("{fresh_ratio:.3}"), missing_dc = ?fresh_missing_dc, "ME hardswap pending: fresh generation coverage incomplete" ); @@ -425,7 +594,7 @@ impl MePool { warn!( missing_dc = ?missing_dc, // Keep stale writers alive when fresh coverage is incomplete. - "ME reinit coverage incomplete after map update; keeping stale writers" + "ME reinit coverage incomplete; keeping stale writers" ); return; } @@ -450,7 +619,7 @@ impl MePool { drop(writers); if stale_writer_ids.is_empty() { - debug!("ME map update completed with no stale writers"); + debug!("ME reinit cycle completed with no stale writers"); return; } @@ -464,7 +633,7 @@ impl MePool { coverage_ratio = format_args!("{coverage_ratio:.3}"), min_ratio = format_args!("{min_ratio:.3}"), drain_timeout_secs, - "ME map update covered; draining stale writers" + "ME reinit cycle covered; draining stale writers" ); self.stats.increment_pool_swap_total(); for writer_id in stale_writer_ids { @@ -473,6 +642,134 @@ impl MePool { } } + pub async fn zero_downtime_reinit_periodic( + self: &Arc, + rng: &SecureRandom, + ) { + self.zero_downtime_reinit_after_map_change(rng).await; + } + + async fn endpoints_for_same_dc(&self, addr: SocketAddr) -> Vec { + let mut target_dc = HashSet::::new(); + let mut endpoints = HashSet::::new(); + + if self.decision.ipv4_me { + let map = self.proxy_map_v4.read().await.clone(); + for (dc, addrs) in &map { + if addrs + .iter() + .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) + { + target_dc.insert(dc.abs()); + } + } + for dc in &target_dc { + for key in [*dc, -*dc] { + if let Some(addrs) = map.get(&key) { + for (ip, port) in addrs { + endpoints.insert(SocketAddr::new(*ip, *port)); + } + } + } + } + } + + if self.decision.ipv6_me { + let map = self.proxy_map_v6.read().await.clone(); + for (dc, addrs) in &map { + if addrs + .iter() + .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) + { + target_dc.insert(dc.abs()); + } + } + for dc in &target_dc { + for key in [*dc, -*dc] { + if let Some(addrs) = map.get(&key) { + for (ip, port) in addrs { + endpoints.insert(SocketAddr::new(*ip, *port)); + } + } + } + } + } + + let mut sorted: Vec = endpoints.into_iter().collect(); + sorted.sort_unstable(); + sorted + } + + async fn refill_writer_after_loss(self: &Arc, addr: SocketAddr) -> bool { + let fast_retries = self.me_reconnect_fast_retry_count.max(1); + + for attempt in 0..fast_retries { + self.stats.increment_me_reconnect_attempt(); + match self.connect_one(addr, self.rng.as_ref()).await { + Ok(()) => { + self.stats.increment_me_reconnect_success(); + info!( + %addr, + attempt = attempt + 1, + "ME writer restored on the same endpoint" + ); + return true; + } + Err(e) => { + debug!( + %addr, + attempt = attempt + 1, + error = %e, + "ME immediate same-endpoint reconnect failed" + ); + } + } + } + + let dc_endpoints = self.endpoints_for_same_dc(addr).await; + if dc_endpoints.is_empty() { + return false; + } + + for attempt in 0..fast_retries { + self.stats.increment_me_reconnect_attempt(); + if self + .connect_endpoints_round_robin(&dc_endpoints, self.rng.as_ref()) + .await + { + self.stats.increment_me_reconnect_success(); + info!( + %addr, + attempt = attempt + 1, + "ME writer restored via DC fallback endpoint" + ); + return true; + } + } + + false + } + + pub(crate) fn trigger_immediate_refill(self: &Arc, addr: SocketAddr) { + let pool = Arc::clone(self); + tokio::spawn(async move { + { + let mut guard = pool.refill_inflight.lock().await; + if !guard.insert(addr) { + return; + } + } + + let restored = pool.refill_writer_after_loss(addr).await; + if !restored { + warn!(%addr, "ME immediate refill failed"); + } + + let mut guard = pool.refill_inflight.lock().await; + guard.remove(&addr); + }); + } + pub async fn update_proxy_maps( &self, new_v4: HashMap>, @@ -880,16 +1177,21 @@ impl MePool { } } - async fn remove_writer_only(&self, writer_id: u64) -> Vec { + async fn remove_writer_only(self: &Arc, writer_id: u64) -> Vec { let mut close_tx: Option> = None; + let mut removed_addr: Option = None; + let mut trigger_refill = false; { let mut ws = self.writers.write().await; if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { let w = ws.remove(pos); - if w.draining.load(Ordering::Relaxed) { + let was_draining = w.draining.load(Ordering::Relaxed); + if was_draining { self.stats.decrement_pool_drain_active(); } w.cancel.cancel(); + removed_addr = Some(w.addr); + trigger_refill = !was_draining; close_tx = Some(w.tx.clone()); self.conn_count.fetch_sub(1, Ordering::Relaxed); } @@ -897,6 +1199,11 @@ impl MePool { if let Some(tx) = close_tx { let _ = tx.send(WriterCommand::Close).await; } + if trigger_refill + && let Some(addr) = removed_addr + { + self.trigger_immediate_refill(addr); + } self.rtt_stats.lock().await.remove(&writer_id); self.registry.writer_lost(writer_id).await } diff --git a/src/transport/middle_proxy/rotation.rs b/src/transport/middle_proxy/rotation.rs index e141fc4..cf5f70d 100644 --- a/src/transport/middle_proxy/rotation.rs +++ b/src/transport/middle_proxy/rotation.rs @@ -1,50 +1,87 @@ use std::sync::Arc; -use std::sync::atomic::Ordering; use std::time::Duration; +use tokio::sync::watch; use tracing::{info, warn}; +use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use super::MePool; -/// Periodically refresh ME connections to avoid long-lived degradation. -pub async fn me_rotation_task(pool: Arc, rng: Arc, interval: Duration) { - let interval = interval.max(Duration::from_secs(600)); +/// Periodically reinitialize ME generations and swap them after full warmup. +pub async fn me_rotation_task( + pool: Arc, + rng: Arc, + mut config_rx: watch::Receiver>, +) { + let mut interval_secs = config_rx + .borrow() + .general + .effective_me_reinit_every_secs() + .max(1); + let mut interval = Duration::from_secs(interval_secs); + let mut next_tick = tokio::time::Instant::now() + interval; + + info!(interval_secs, "ME periodic reinit task started"); + loop { - tokio::time::sleep(interval).await; + let sleep = tokio::time::sleep_until(next_tick); + tokio::pin!(sleep); - let candidate = { - let ws = pool.writers.read().await; - if ws.is_empty() { - None - } else { - let idx = (pool.rr.load(std::sync::atomic::Ordering::Relaxed) as usize) % ws.len(); - ws.get(idx).cloned() - } - }; - - let Some(w) = candidate else { - continue; - }; - - info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection"); - match pool.connect_one(w.addr, rng.as_ref()).await { - Ok(()) => { - tokio::time::sleep(Duration::from_secs(2)).await; - let ws = pool.writers.read().await; - let new_alive = ws.iter().any(|nw| - nw.id != w.id && nw.addr == w.addr && !nw.degraded.load(Ordering::Relaxed) && !nw.draining.load(Ordering::Relaxed) - ); - drop(ws); - if new_alive { - pool.mark_writer_draining(w.id).await; - } else { - warn!(addr = %w.addr, writer_id = w.id, "New writer died, keeping old"); + tokio::select! { + _ = &mut sleep => { + pool.zero_downtime_reinit_periodic(rng.as_ref()).await; + let refreshed_secs = config_rx + .borrow() + .general + .effective_me_reinit_every_secs() + .max(1); + if refreshed_secs != interval_secs { + info!( + old_me_reinit_every_secs = interval_secs, + new_me_reinit_every_secs = refreshed_secs, + "ME periodic reinit interval changed" + ); + interval_secs = refreshed_secs; + interval = Duration::from_secs(interval_secs); } + next_tick = tokio::time::Instant::now() + interval; } - Err(e) => { - warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed"); + changed = config_rx.changed() => { + if changed.is_err() { + warn!("ME periodic reinit task stopped: config channel closed"); + break; + } + let new_secs = config_rx + .borrow() + .general + .effective_me_reinit_every_secs() + .max(1); + if new_secs == interval_secs { + continue; + } + + if new_secs < interval_secs { + info!( + old_me_reinit_every_secs = interval_secs, + new_me_reinit_every_secs = new_secs, + "ME periodic reinit interval decreased, running immediate reinit" + ); + interval_secs = new_secs; + interval = Duration::from_secs(interval_secs); + pool.zero_downtime_reinit_periodic(rng.as_ref()).await; + next_tick = tokio::time::Instant::now() + interval; + } else { + info!( + old_me_reinit_every_secs = interval_secs, + new_me_reinit_every_secs = new_secs, + "ME periodic reinit interval increased" + ); + interval_secs = new_secs; + interval = Duration::from_secs(interval_secs); + next_tick = tokio::time::Instant::now() + interval; + } } } }