diff --git a/src/config/defaults.rs b/src/config/defaults.rs
index 6b80ede..4f563ba 100644
--- a/src/config/defaults.rs
+++ b/src/config/defaults.rs
@@ -182,6 +182,10 @@ pub(crate) fn default_update_every_secs() -> u64 {
     30 * 60
 }
 
+pub(crate) fn default_me_reinit_every_secs() -> u64 {
+    15 * 60
+}
+
 pub(crate) fn default_me_config_stable_snapshots() -> u8 {
     2
 }
diff --git a/src/config/load.rs b/src/config/load.rs
index be34efa..c18c84f 100644
--- a/src/config/load.rs
+++ b/src/config/load.rs
@@ -147,6 +147,12 @@ impl ProxyConfig {
             }
         }
 
+        if config.general.me_reinit_every_secs == 0 {
+            return Err(ProxyError::Config(
+                "general.me_reinit_every_secs must be > 0".to_string(),
+            ));
+        }
+
         if config.general.me_config_stable_snapshots == 0 {
             return Err(ProxyError::Config(
                 "general.me_config_stable_snapshots must be > 0".to_string(),
@@ -480,6 +486,46 @@ mod tests {
         let _ = std::fs::remove_file(path);
     }
 
+    #[test]
+    fn me_reinit_every_default_is_set() {
+        let toml = r#"
+            [censorship]
+            tls_domain = "example.com"
+
+            [access.users]
+            user = "00000000000000000000000000000000"
+        "#;
+        let dir = std::env::temp_dir();
+        let path = dir.join("telemt_me_reinit_every_default_test.toml");
+        std::fs::write(&path, toml).unwrap();
+        let cfg = ProxyConfig::load(&path).unwrap();
+        assert_eq!(
+            cfg.general.me_reinit_every_secs,
+            default_me_reinit_every_secs()
+        );
+        let _ = std::fs::remove_file(path);
+    }
+
+    #[test]
+    fn me_reinit_every_zero_is_rejected() {
+        let toml = r#"
+            [general]
+            me_reinit_every_secs = 0
+
+            [censorship]
+            tls_domain = "example.com"
+
+            [access.users]
+            user = "00000000000000000000000000000000"
+        "#;
+        let dir = std::env::temp_dir();
+        let path = dir.join("telemt_me_reinit_every_zero_test.toml");
+        std::fs::write(&path, toml).unwrap();
+        let err = ProxyConfig::load(&path).unwrap_err().to_string();
+        assert!(err.contains("general.me_reinit_every_secs must be > 0"));
+        let _ = std::fs::remove_file(path);
+    }
+
     #[test]
     fn me_config_stable_snapshots_zero_is_rejected() {
         let toml = r#"
diff --git a/src/config/types.rs b/src/config/types.rs
index bd9697e..03417c5 100644
--- a/src/config/types.rs
+++ b/src/config/types.rs
@@ -267,6 +267,10 @@ pub struct GeneralConfig {
     #[serde(default)]
     pub update_every: Option<u64>,
 
+    /// Periodic ME pool reinitialization interval in seconds.
+    #[serde(default = "default_me_reinit_every_secs")]
+    pub me_reinit_every_secs: u64,
+
     /// Number of identical getProxyConfig snapshots required before applying ME map updates.
     #[serde(default = "default_me_config_stable_snapshots")]
     pub me_config_stable_snapshots: u8,
@@ -366,6 +370,7 @@ impl Default for GeneralConfig {
             hardswap: default_hardswap(),
             fast_mode_min_tls_record: default_fast_mode_min_tls_record(),
             update_every: Some(default_update_every_secs()),
+            me_reinit_every_secs: default_me_reinit_every_secs(),
             me_config_stable_snapshots: default_me_config_stable_snapshots(),
             me_config_apply_cooldown_secs: default_me_config_apply_cooldown_secs(),
             proxy_secret_stable_snapshots: default_proxy_secret_stable_snapshots(),
@@ -392,6 +397,11 @@ impl GeneralConfig {
             .unwrap_or_else(|| self.proxy_secret_auto_reload_secs.min(self.proxy_config_auto_reload_secs))
     }
 
+    /// Resolve periodic zero-downtime reinit interval for ME writers.
+    pub fn effective_me_reinit_every_secs(&self) -> u64 {
+        self.me_reinit_every_secs
+    }
+
     /// Resolve force-close timeout for stale writers.
     /// `me_reinit_drain_timeout_secs` remains backward-compatible alias.
     pub fn effective_me_pool_force_close_secs(&self) -> u64 {
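Aside from the patch itself: the pattern above — a serde default wired to a function, plus an explicit zero check at load time — can be exercised in isolation. A minimal sketch, assuming the `serde` and `toml` crates; `General` here is a toy stand-in for the real `GeneralConfig`, not the proxy's type:

```rust
use serde::Deserialize;

// Mirrors the default in src/config/defaults.rs: 15 minutes.
fn default_me_reinit_every_secs() -> u64 {
    15 * 60
}

#[derive(Debug, Deserialize)]
struct General {
    // Missing key -> default; present key -> parsed value, validated below.
    #[serde(default = "default_me_reinit_every_secs")]
    me_reinit_every_secs: u64,
}

fn validate(g: &General) -> Result<(), String> {
    if g.me_reinit_every_secs == 0 {
        return Err("general.me_reinit_every_secs must be > 0".to_string());
    }
    Ok(())
}

fn main() {
    // A document without the key falls back to the 15-minute default.
    let g: General = toml::from_str("").unwrap();
    assert_eq!(g.me_reinit_every_secs, 900);

    // An explicit zero parses fine but is rejected by validation,
    // matching the me_reinit_every_zero_is_rejected test above.
    let g: General = toml::from_str("me_reinit_every_secs = 0").unwrap();
    assert!(validate(&g).is_err());
}
```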
diff --git a/src/main.rs b/src/main.rs
index 1c7b39c..d9a692d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -391,18 +391,6 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
                     .await;
                 });
 
-                // Periodic ME connection rotation
-                let pool_clone_rot = pool.clone();
-                let rng_clone_rot = rng.clone();
-                tokio::spawn(async move {
-                    crate::transport::middle_proxy::me_rotation_task(
-                        pool_clone_rot,
-                        rng_clone_rot,
-                        std::time::Duration::from_secs(1800),
-                    )
-                    .await;
-                });
-
                 Some(pool)
             }
             Err(e) => {
@@ -712,6 +700,18 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
             )
             .await;
         });
+
+        let pool_clone_rot = pool.clone();
+        let rng_clone_rot = rng.clone();
+        let config_rx_clone_rot = config_rx.clone();
+        tokio::spawn(async move {
+            crate::transport::middle_proxy::me_rotation_task(
+                pool_clone_rot,
+                rng_clone_rot,
+                config_rx_clone_rot,
+            )
+            .await;
+        });
     }
 
     let mut listeners = Vec::new();
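A note on the wiring above: the rotation task now takes a `watch::Receiver` clone instead of a fixed `Duration`, so a config reload reaches it like every other long-lived task. A self-contained sketch of that clone-per-task pattern, with illustrative types rather than the proxy's own:

```rust
use std::sync::Arc;
use tokio::sync::watch;

struct Config {
    me_reinit_every_secs: u64,
}

#[tokio::main]
async fn main() {
    let (tx, rx) = watch::channel(Arc::new(Config { me_reinit_every_secs: 900 }));

    // Each spawned task owns its own receiver clone.
    let mut task_rx = rx.clone();
    let handle = tokio::spawn(async move {
        // borrow() yields the latest snapshot without waiting.
        println!("initial: {}", task_rx.borrow().me_reinit_every_secs);
        if task_rx.changed().await.is_ok() {
            println!("reloaded: {}", task_rx.borrow().me_reinit_every_secs);
        }
    });

    // Simulate a config reload; all receiver clones observe it.
    tx.send(Arc::new(Config { me_reinit_every_secs: 300 })).unwrap();
    handle.await.unwrap();
}
```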
diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs
index 4bb7e64..dde3354 100644
--- a/src/transport/middle_proxy/health.rs
+++ b/src/transport/middle_proxy/health.rs
@@ -1,10 +1,9 @@
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 
 use tracing::{debug, info, warn};
-use rand::seq::SliceRandom;
 use rand::Rng;
 
 use crate::crypto::SecureRandom;
@@ -64,31 +63,43 @@ async fn check_family(
         IpFamily::V4 => pool.proxy_map_v4.read().await.clone(),
         IpFamily::V6 => pool.proxy_map_v6.read().await.clone(),
     };
-    let writer_addrs: HashSet<SocketAddr> = pool
+
+    let mut dc_endpoints = HashMap::<i32, Vec<SocketAddr>>::new();
+    for (dc, addrs) in map {
+        let entry = dc_endpoints.entry(dc.abs()).or_default();
+        for (ip, port) in addrs {
+            entry.push(SocketAddr::new(ip, port));
+        }
+    }
+    for endpoints in dc_endpoints.values_mut() {
+        endpoints.sort_unstable();
+        endpoints.dedup();
+    }
+
+    let mut live_addr_counts = HashMap::<SocketAddr, usize>::new();
+    for writer in pool
         .writers
         .read()
         .await
         .iter()
         .filter(|w| !w.draining.load(std::sync::atomic::Ordering::Relaxed))
-        .map(|w| w.addr)
-        .collect();
+    {
+        *live_addr_counts.entry(writer.addr).or_insert(0) += 1;
+    }
 
-    let entries: Vec<(i32, Vec<SocketAddr>)> = map
-        .iter()
-        .map(|(dc, addrs)| {
-            let list = addrs
-                .iter()
-                .map(|(ip, port)| SocketAddr::new(*ip, *port))
-                .collect::<Vec<_>>();
-            (*dc, list)
-        })
-        .collect();
-
-    for (dc, dc_addrs) in entries {
-        let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a));
-        if has_coverage {
+    for (dc, endpoints) in dc_endpoints {
+        if endpoints.is_empty() {
             continue;
         }
+        let required = MePool::required_writers_for_dc(endpoints.len());
+        let alive = endpoints
+            .iter()
+            .map(|addr| *live_addr_counts.get(addr).unwrap_or(&0))
+            .sum::<usize>();
+        if alive >= required {
+            continue;
+        }
+        let missing = required - alive;
 
         let key = (dc, family);
         let now = Instant::now();
@@ -104,32 +115,45 @@ async fn check_family(
         }
         *inflight.entry(key).or_insert(0) += 1;
 
-        let mut shuffled = dc_addrs.clone();
-        shuffled.shuffle(&mut rand::rng());
-        let mut success = false;
-        for addr in shuffled {
-            let res = tokio::time::timeout(pool.me_one_timeout, pool.connect_one(addr, rng.as_ref())).await;
+        let mut restored = 0usize;
+        for _ in 0..missing {
+            let res = tokio::time::timeout(
+                pool.me_one_timeout,
+                pool.connect_endpoints_round_robin(&endpoints, rng.as_ref()),
+            )
+            .await;
             match res {
-                Ok(Ok(())) => {
-                    info!(%addr, dc = %dc, ?family, "ME reconnected for DC coverage");
+                Ok(true) => {
+                    restored += 1;
                     pool.stats.increment_me_reconnect_success();
-                    backoff.insert(key, pool.me_reconnect_backoff_base.as_millis() as u64);
-                    let jitter = pool.me_reconnect_backoff_base.as_millis() as u64 / JITTER_FRAC_NUM;
-                    let wait = pool.me_reconnect_backoff_base
-                        + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1)));
-                    next_attempt.insert(key, now + wait);
-                    success = true;
-                    break;
                 }
-                Ok(Err(e)) => {
+                Ok(false) => {
                     pool.stats.increment_me_reconnect_attempt();
-                    debug!(%addr, dc = %dc, error = %e, ?family, "ME reconnect failed")
+                    debug!(dc = %dc, ?family, "ME round-robin reconnect failed")
+                }
+                Err(_) => {
+                    pool.stats.increment_me_reconnect_attempt();
+                    debug!(dc = %dc, ?family, "ME reconnect timed out");
                 }
-                Err(_) => debug!(%addr, dc = %dc, ?family, "ME reconnect timed out"),
             }
         }
-        if !success {
-            pool.stats.increment_me_reconnect_attempt();
+
+        let now_alive = alive + restored;
+        if now_alive >= required {
+            info!(
+                dc = %dc,
+                ?family,
+                alive = now_alive,
+                required,
+                endpoint_count = endpoints.len(),
+                "ME writer floor restored for DC"
+            );
+            backoff.insert(key, pool.me_reconnect_backoff_base.as_millis() as u64);
+            let jitter = pool.me_reconnect_backoff_base.as_millis() as u64 / JITTER_FRAC_NUM;
+            let wait = pool.me_reconnect_backoff_base
+                + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1)));
+            next_attempt.insert(key, now + wait);
+        } else {
             let curr = *backoff.get(&key).unwrap_or(&(pool.me_reconnect_backoff_base.as_millis() as u64));
             let next_ms = (curr.saturating_mul(2)).min(pool.me_reconnect_backoff_cap.as_millis() as u64);
             backoff.insert(key, next_ms);
@@ -137,7 +161,15 @@ async fn check_family(
             let wait = Duration::from_millis(next_ms)
                 + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1)));
             next_attempt.insert(key, now + wait);
-            warn!(dc = %dc, backoff_ms = next_ms, ?family, "DC has no ME coverage, scheduled reconnect");
+            warn!(
+                dc = %dc,
+                ?family,
+                alive = now_alive,
+                required,
+                endpoint_count = endpoints.len(),
+                backoff_ms = next_ms,
+                "DC writer floor is below required level, scheduled reconnect"
+            );
         }
         if let Some(v) = inflight.get_mut(&key) {
             *v = v.saturating_sub(1);
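The backoff arithmetic in `check_family` distills to: double the current delay, clamp it at a cap, then add uniform jitter derived from the delay. A standalone sketch of that math; `JITTER_FRAC_NUM = 4` is an assumed value for illustration, the real constant lives elsewhere in health.rs:

```rust
use rand::Rng;
use std::time::Duration;

// Assumed divisor for illustration; health.rs defines the real JITTER_FRAC_NUM.
const JITTER_FRAC_NUM: u64 = 4;

fn next_backoff(curr_ms: u64, cap_ms: u64) -> (u64, Duration) {
    // Exponential doubling, saturating and clamped to the cap.
    let next_ms = curr_ms.saturating_mul(2).min(cap_ms);
    // Uniform jitter in 0..=next_ms/JITTER_FRAC_NUM keeps retries from aligning.
    let jitter = next_ms / JITTER_FRAC_NUM;
    let wait = Duration::from_millis(next_ms)
        + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1)));
    (next_ms, wait)
}

fn main() {
    let (mut curr_ms, cap_ms) = (1_000u64, 60_000u64);
    for _ in 0..8 {
        let (next_ms, wait) = next_backoff(curr_ms, cap_ms);
        println!("{curr_ms}ms -> {next_ms}ms (sleep {wait:?})");
        curr_ms = next_ms;
    }
}
```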
diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs
index 06fdc96..223d488 100644
--- a/src/transport/middle_proxy/pool.rs
+++ b/src/transport/middle_proxy/pool.rs
@@ -75,6 +75,7 @@ pub struct MePool {
    pub(super) rtt_stats: Arc>>,
    pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
    pub(super) writer_available: Arc<Notify>,
+   pub(super) refill_inflight: Arc<Mutex<HashSet<SocketAddr>>>,
    pub(super) conn_count: AtomicUsize,
    pub(super) stats: Arc,
    pub(super) generation: AtomicU64,
@@ -180,6 +181,7 @@ impl MePool {
             rtt_stats: Arc::new(Mutex::new(HashMap::new())),
             nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
             writer_available: Arc::new(Notify::new()),
+            refill_inflight: Arc::new(Mutex::new(HashSet::new())),
             conn_count: AtomicUsize::new(0),
             generation: AtomicU64::new(1),
             hardswap: AtomicBool::new(hardswap),
@@ -324,34 +326,66 @@ impl MePool {
         out
     }
 
+    pub(super) fn required_writers_for_dc(endpoint_count: usize) -> usize {
+        endpoint_count.max(3)
+    }
+
+    pub(super) async fn connect_endpoints_round_robin(
+        self: &Arc<Self>,
+        endpoints: &[SocketAddr],
+        rng: &SecureRandom,
+    ) -> bool {
+        if endpoints.is_empty() {
+            return false;
+        }
+        let start = (self.rr.fetch_add(1, Ordering::Relaxed) as usize) % endpoints.len();
+        for offset in 0..endpoints.len() {
+            let idx = (start + offset) % endpoints.len();
+            let addr = endpoints[idx];
+            match self.connect_one(addr, rng).await {
+                Ok(()) => return true,
+                Err(e) => debug!(%addr, error = %e, "ME connect failed during round-robin warmup"),
+            }
+        }
+        false
+    }
+
     async fn warmup_generation_for_all_dcs(
         self: &Arc<Self>,
         rng: &SecureRandom,
         generation: u64,
         desired_by_dc: &HashMap<i32, HashSet<SocketAddr>>,
     ) {
-        for endpoints in desired_by_dc.values() {
+        for (dc, endpoints) in desired_by_dc {
             if endpoints.is_empty() {
                 continue;
             }
-            let has_fresh = {
-                let ws = self.writers.read().await;
-                ws.iter().any(|w| {
-                    !w.draining.load(Ordering::Relaxed)
-                        && w.generation == generation
-                        && endpoints.contains(&w.addr)
-                })
-            };
+            let mut endpoint_list: Vec<SocketAddr> = endpoints.iter().copied().collect();
+            endpoint_list.sort_unstable();
+            let required = Self::required_writers_for_dc(endpoint_list.len());
 
-            if has_fresh {
-                continue;
-            }
+            loop {
+                let fresh_count = {
+                    let ws = self.writers.read().await;
+                    ws.iter()
+                        .filter(|w| !w.draining.load(Ordering::Relaxed))
+                        .filter(|w| w.generation == generation)
+                        .filter(|w| endpoints.contains(&w.addr))
+                        .count()
+                };
+                if fresh_count >= required {
+                    break;
+                }
 
-            let mut shuffled: Vec<SocketAddr> = endpoints.iter().copied().collect();
-            shuffled.shuffle(&mut rand::rng());
-            for addr in shuffled {
-                if self.connect_one(addr, rng).await.is_ok() {
+                if !self.connect_endpoints_round_robin(&endpoint_list, rng).await {
+                    warn!(
+                        dc = *dc,
+                        fresh_count,
+                        required,
+                        endpoint_count = endpoint_list.len(),
+                        "ME warmup stopped: unable to reach required writer floor for DC"
+                    );
                     break;
                 }
             }
         }
@@ -364,7 +398,7 @@ impl MePool {
     ) {
         let desired_by_dc = self.desired_dc_endpoints().await;
         if desired_by_dc.is_empty() {
-            warn!("ME endpoint map is empty after update; skipping stale writer drain");
+            warn!("ME endpoint map is empty; skipping stale writer drain");
             return;
         }
 
@@ -403,19 +437,26 @@ impl MePool {
         }
 
         if hardswap {
-            let fresh_writer_addrs: HashSet<SocketAddr> = writers
-                .iter()
-                .filter(|w| !w.draining.load(Ordering::Relaxed))
-                .filter(|w| w.generation == generation)
-                .map(|w| w.addr)
-                .collect();
-            let (fresh_ratio, fresh_missing_dc) =
-                Self::coverage_ratio(&desired_by_dc, &fresh_writer_addrs);
+            let mut fresh_missing_dc = Vec::<(i32, usize, usize)>::new();
+            for (dc, endpoints) in &desired_by_dc {
+                if endpoints.is_empty() {
+                    continue;
+                }
+                let required = Self::required_writers_for_dc(endpoints.len());
+                let fresh_count = writers
+                    .iter()
+                    .filter(|w| !w.draining.load(Ordering::Relaxed))
+                    .filter(|w| w.generation == generation)
+                    .filter(|w| endpoints.contains(&w.addr))
+                    .count();
+                if fresh_count < required {
+                    fresh_missing_dc.push((*dc, fresh_count, required));
+                }
+            }
             if !fresh_missing_dc.is_empty() {
                 warn!(
                     previous_generation,
                     generation,
-                    fresh_ratio = format_args!("{fresh_ratio:.3}"),
                     missing_dc = ?fresh_missing_dc,
                     "ME hardswap pending: fresh generation coverage incomplete"
                 );
@@ -425,7 +466,7 @@ impl MePool {
             warn!(
                 missing_dc = ?missing_dc,
                 // Keep stale writers alive when fresh coverage is incomplete.
-                "ME reinit coverage incomplete after map update; keeping stale writers"
+                "ME reinit coverage incomplete; keeping stale writers"
            );
            return;
        }
@@ -450,7 +491,7 @@ impl MePool {
         drop(writers);
 
         if stale_writer_ids.is_empty() {
-            debug!("ME map update completed with no stale writers");
+            debug!("ME reinit cycle completed with no stale writers");
             return;
         }
 
@@ -464,7 +505,7 @@ impl MePool {
             coverage_ratio = format_args!("{coverage_ratio:.3}"),
             min_ratio = format_args!("{min_ratio:.3}"),
             drain_timeout_secs,
-            "ME map update covered; draining stale writers"
+            "ME reinit cycle covered; draining stale writers"
         );
         self.stats.increment_pool_swap_total();
         for writer_id in stale_writer_ids {
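The `connect_endpoints_round_robin` hunk above leans on one shared atomic counter: each call starts one slot further along the endpoint list, while the offset scan still visits every endpoint once. A minimal sketch of just that index trick, detached from the connection logic:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Stand-in for the pool's `rr` counter.
static RR: AtomicU64 = AtomicU64::new(0);

fn pick_order<T: Copy>(endpoints: &[T]) -> Vec<T> {
    if endpoints.is_empty() {
        return Vec::new();
    }
    // fetch_add gives every caller a distinct starting slot without a lock.
    let start = (RR.fetch_add(1, Ordering::Relaxed) as usize) % endpoints.len();
    (0..endpoints.len())
        .map(|offset| endpoints[(start + offset) % endpoints.len()])
        .collect()
}

fn main() {
    let eps = ["10.0.0.1:8888", "10.0.0.2:8888", "10.0.0.3:8888"];
    // Successive calls rotate the starting endpoint, spreading load evenly.
    println!("{:?}", pick_order(&eps));
    println!("{:?}", pick_order(&eps));
    println!("{:?}", pick_order(&eps));
}
```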
- "ME reinit coverage incomplete after map update; keeping stale writers" + "ME reinit coverage incomplete; keeping stale writers" ); return; } @@ -450,7 +491,7 @@ impl MePool { drop(writers); if stale_writer_ids.is_empty() { - debug!("ME map update completed with no stale writers"); + debug!("ME reinit cycle completed with no stale writers"); return; } @@ -464,7 +505,7 @@ impl MePool { coverage_ratio = format_args!("{coverage_ratio:.3}"), min_ratio = format_args!("{min_ratio:.3}"), drain_timeout_secs, - "ME map update covered; draining stale writers" + "ME reinit cycle covered; draining stale writers" ); self.stats.increment_pool_swap_total(); for writer_id in stale_writer_ids { @@ -473,6 +514,134 @@ impl MePool { } } + pub async fn zero_downtime_reinit_periodic( + self: &Arc, + rng: &SecureRandom, + ) { + self.zero_downtime_reinit_after_map_change(rng).await; + } + + async fn endpoints_for_same_dc(&self, addr: SocketAddr) -> Vec { + let mut target_dc = HashSet::::new(); + let mut endpoints = HashSet::::new(); + + if self.decision.ipv4_me { + let map = self.proxy_map_v4.read().await.clone(); + for (dc, addrs) in &map { + if addrs + .iter() + .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) + { + target_dc.insert(dc.abs()); + } + } + for dc in &target_dc { + for key in [*dc, -*dc] { + if let Some(addrs) = map.get(&key) { + for (ip, port) in addrs { + endpoints.insert(SocketAddr::new(*ip, *port)); + } + } + } + } + } + + if self.decision.ipv6_me { + let map = self.proxy_map_v6.read().await.clone(); + for (dc, addrs) in &map { + if addrs + .iter() + .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) + { + target_dc.insert(dc.abs()); + } + } + for dc in &target_dc { + for key in [*dc, -*dc] { + if let Some(addrs) = map.get(&key) { + for (ip, port) in addrs { + endpoints.insert(SocketAddr::new(*ip, *port)); + } + } + } + } + } + + let mut sorted: Vec = endpoints.into_iter().collect(); + sorted.sort_unstable(); + sorted + } + + async fn refill_writer_after_loss(self: &Arc, addr: SocketAddr) -> bool { + let fast_retries = self.me_reconnect_fast_retry_count.max(1); + + for attempt in 0..fast_retries { + self.stats.increment_me_reconnect_attempt(); + match self.connect_one(addr, self.rng.as_ref()).await { + Ok(()) => { + self.stats.increment_me_reconnect_success(); + info!( + %addr, + attempt = attempt + 1, + "ME writer restored on the same endpoint" + ); + return true; + } + Err(e) => { + debug!( + %addr, + attempt = attempt + 1, + error = %e, + "ME immediate same-endpoint reconnect failed" + ); + } + } + } + + let dc_endpoints = self.endpoints_for_same_dc(addr).await; + if dc_endpoints.is_empty() { + return false; + } + + for attempt in 0..fast_retries { + self.stats.increment_me_reconnect_attempt(); + if self + .connect_endpoints_round_robin(&dc_endpoints, self.rng.as_ref()) + .await + { + self.stats.increment_me_reconnect_success(); + info!( + %addr, + attempt = attempt + 1, + "ME writer restored via DC fallback endpoint" + ); + return true; + } + } + + false + } + + pub(crate) fn trigger_immediate_refill(self: &Arc, addr: SocketAddr) { + let pool = Arc::clone(self); + tokio::spawn(async move { + { + let mut guard = pool.refill_inflight.lock().await; + if !guard.insert(addr) { + return; + } + } + + let restored = pool.refill_writer_after_loss(addr).await; + if !restored { + warn!(%addr, "ME immediate refill failed"); + } + + let mut guard = pool.refill_inflight.lock().await; + guard.remove(&addr); + }); + } + pub async fn update_proxy_maps( &self, new_v4: HashMap>, @@ -880,16 
@@ -880,16 +1049,21 @@ impl MePool {
         }
     }
 
-    async fn remove_writer_only(&self, writer_id: u64) -> Vec {
+    async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec {
         let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
+        let mut removed_addr: Option<SocketAddr> = None;
+        let mut trigger_refill = false;
         {
             let mut ws = self.writers.write().await;
             if let Some(pos) = ws.iter().position(|w| w.id == writer_id) {
                 let w = ws.remove(pos);
-                if w.draining.load(Ordering::Relaxed) {
+                let was_draining = w.draining.load(Ordering::Relaxed);
+                if was_draining {
                     self.stats.decrement_pool_drain_active();
                 }
                 w.cancel.cancel();
+                removed_addr = Some(w.addr);
+                trigger_refill = !was_draining;
                 close_tx = Some(w.tx.clone());
                 self.conn_count.fetch_sub(1, Ordering::Relaxed);
             }
@@ -897,6 +1071,11 @@ impl MePool {
         if let Some(tx) = close_tx {
             let _ = tx.send(WriterCommand::Close).await;
         }
+        if trigger_refill
+            && let Some(addr) = removed_addr
+        {
+            self.trigger_immediate_refill(addr);
+        }
         self.rtt_stats.lock().await.remove(&writer_id);
         self.registry.writer_lost(writer_id).await
     }
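The ordering in `remove_writer_only` is deliberate: the writer list is mutated under the write lock, but the `Close` command is sent only after the lock is dropped, so a slow command channel cannot stall other pool users. A reduced sketch of that shape, with simplified stand-in types:

```rust
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};

enum WriterCommand {
    Close,
}

struct Writer {
    id: u64,
    tx: mpsc::Sender<WriterCommand>,
}

async fn remove_writer(writers: &Arc<RwLock<Vec<Writer>>>, writer_id: u64) {
    // Detach the writer while holding the lock...
    let close_tx = {
        let mut ws = writers.write().await;
        ws.iter()
            .position(|w| w.id == writer_id)
            .map(|pos| ws.remove(pos).tx)
    }; // ...and release it here, before awaiting on the channel.

    if let Some(tx) = close_tx {
        let _ = tx.send(WriterCommand::Close).await;
    }
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel(1);
    let writers = Arc::new(RwLock::new(vec![Writer { id: 7, tx }]));
    remove_writer(&writers, 7).await;
    assert!(matches!(rx.recv().await, Some(WriterCommand::Close)));
}
```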
diff --git a/src/transport/middle_proxy/rotation.rs b/src/transport/middle_proxy/rotation.rs
index e141fc4..cf5f70d 100644
--- a/src/transport/middle_proxy/rotation.rs
+++ b/src/transport/middle_proxy/rotation.rs
@@ -1,50 +1,87 @@
 use std::sync::Arc;
-use std::sync::atomic::Ordering;
 use std::time::Duration;
 
+use tokio::sync::watch;
 use tracing::{info, warn};
 
+use crate::config::ProxyConfig;
 use crate::crypto::SecureRandom;
 
 use super::MePool;
 
-/// Periodically refresh ME connections to avoid long-lived degradation.
-pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interval: Duration) {
-    let interval = interval.max(Duration::from_secs(600));
+/// Periodically reinitialize ME generations and swap them after full warmup.
+pub async fn me_rotation_task(
+    pool: Arc<MePool>,
+    rng: Arc<SecureRandom>,
+    mut config_rx: watch::Receiver<Arc<ProxyConfig>>,
+) {
+    let mut interval_secs = config_rx
+        .borrow()
+        .general
+        .effective_me_reinit_every_secs()
+        .max(1);
+    let mut interval = Duration::from_secs(interval_secs);
+    let mut next_tick = tokio::time::Instant::now() + interval;
+
+    info!(interval_secs, "ME periodic reinit task started");
+
     loop {
-        tokio::time::sleep(interval).await;
+        let sleep = tokio::time::sleep_until(next_tick);
+        tokio::pin!(sleep);
 
-        let candidate = {
-            let ws = pool.writers.read().await;
-            if ws.is_empty() {
-                None
-            } else {
-                let idx = (pool.rr.load(std::sync::atomic::Ordering::Relaxed) as usize) % ws.len();
-                ws.get(idx).cloned()
-            }
-        };
-
-        let Some(w) = candidate else {
-            continue;
-        };
-
-        info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection");
-        match pool.connect_one(w.addr, rng.as_ref()).await {
-            Ok(()) => {
-                tokio::time::sleep(Duration::from_secs(2)).await;
-                let ws = pool.writers.read().await;
-                let new_alive = ws.iter().any(|nw|
-                    nw.id != w.id && nw.addr == w.addr && !nw.degraded.load(Ordering::Relaxed) && !nw.draining.load(Ordering::Relaxed)
-                );
-                drop(ws);
-                if new_alive {
-                    pool.mark_writer_draining(w.id).await;
-                } else {
-                    warn!(addr = %w.addr, writer_id = w.id, "New writer died, keeping old");
+        tokio::select! {
+            _ = &mut sleep => {
+                pool.zero_downtime_reinit_periodic(rng.as_ref()).await;
+                let refreshed_secs = config_rx
+                    .borrow()
+                    .general
+                    .effective_me_reinit_every_secs()
+                    .max(1);
+                if refreshed_secs != interval_secs {
+                    info!(
+                        old_me_reinit_every_secs = interval_secs,
+                        new_me_reinit_every_secs = refreshed_secs,
+                        "ME periodic reinit interval changed"
+                    );
+                    interval_secs = refreshed_secs;
+                    interval = Duration::from_secs(interval_secs);
                 }
+                next_tick = tokio::time::Instant::now() + interval;
             }
-            Err(e) => {
-                warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed");
+            changed = config_rx.changed() => {
+                if changed.is_err() {
+                    warn!("ME periodic reinit task stopped: config channel closed");
+                    break;
+                }
+                let new_secs = config_rx
+                    .borrow()
+                    .general
+                    .effective_me_reinit_every_secs()
+                    .max(1);
+                if new_secs == interval_secs {
+                    continue;
+                }
+
+                if new_secs < interval_secs {
+                    info!(
+                        old_me_reinit_every_secs = interval_secs,
+                        new_me_reinit_every_secs = new_secs,
+                        "ME periodic reinit interval decreased, running immediate reinit"
+                    );
+                    interval_secs = new_secs;
+                    interval = Duration::from_secs(interval_secs);
+                    pool.zero_downtime_reinit_periodic(rng.as_ref()).await;
+                    next_tick = tokio::time::Instant::now() + interval;
+                } else {
+                    info!(
+                        old_me_reinit_every_secs = interval_secs,
+                        new_me_reinit_every_secs = new_secs,
+                        "ME periodic reinit interval increased"
+                    );
+                    interval_secs = new_secs;
+                    interval = Duration::from_secs(interval_secs);
+                    next_tick = tokio::time::Instant::now() + interval;
+                }
             }
         }
     }
 }
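Stripped of the pool specifics, the loop above is a `sleep_until`-based ticker whose period can change at runtime through a watch channel, running one immediate pass when the period shrinks. A distilled, generic sketch of that timing skeleton; the channel payload and work closure are placeholders, not the proxy's types:

```rust
use std::time::Duration;
use tokio::sync::watch;
use tokio::time::Instant;

async fn resettable_ticker(mut period_rx: watch::Receiver<u64>, mut work: impl FnMut()) {
    let mut period = Duration::from_secs((*period_rx.borrow()).max(1));
    let mut next_tick = Instant::now() + period;
    loop {
        let sleep = tokio::time::sleep_until(next_tick);
        tokio::pin!(sleep);
        tokio::select! {
            _ = &mut sleep => {
                work();
                next_tick = Instant::now() + period;
            }
            changed = period_rx.changed() => {
                if changed.is_err() {
                    break; // sender dropped: stop the ticker
                }
                let new = Duration::from_secs((*period_rx.borrow()).max(1));
                if new < period {
                    work(); // shorter period: run immediately, as the reinit task does
                }
                period = new;
                next_tick = Instant::now() + period;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = watch::channel(1u64);
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(1_500)).await;
        let _ = tx.send(2); // widen the period at runtime
    });
    resettable_ticker(rx, || println!("tick")).await;
}
```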