use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU32, AtomicU64, AtomicUsize, Ordering}; use bytes::BytesMut; use rand::Rng; use rand::seq::SliceRandom; use tokio::sync::{Mutex, RwLock, mpsc, Notify}; use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::network::probe::NetworkDecision; use crate::network::IpFamily; use crate::protocol::constants::*; use super::ConnRegistry; use super::registry::BoundConn; use super::codec::{RpcWriter, WriterCommand}; use super::reader::reader_loop; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; #[derive(Clone)] pub struct MeWriter { pub id: u64, pub addr: SocketAddr, pub generation: u64, pub tx: mpsc::Sender, pub cancel: CancellationToken, pub degraded: Arc, pub draining: Arc, pub draining_started_at_epoch_secs: Arc, pub allow_drain_fallback: Arc, } #[allow(dead_code)] pub struct MePool { pub(super) registry: Arc, pub(super) writers: Arc>>, pub(super) rr: AtomicU64, pub(super) decision: NetworkDecision, pub(super) rng: Arc, pub(super) proxy_tag: Option>, pub(super) proxy_secret: Arc>>, pub(super) nat_ip_cfg: Option, pub(super) nat_ip_detected: Arc>>, pub(super) nat_probe: bool, pub(super) nat_stun: Option, pub(super) nat_stun_servers: Vec, pub(super) detected_ipv6: Option, pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8, pub(super) nat_probe_disabled: std::sync::atomic::AtomicBool, pub(super) stun_backoff_until: Arc>>, pub(super) me_one_retry: u8, pub(super) me_one_timeout: Duration, pub(super) me_keepalive_enabled: bool, pub(super) me_keepalive_interval: Duration, pub(super) me_keepalive_jitter: Duration, pub(super) me_keepalive_payload_random: bool, pub(super) me_warmup_stagger_enabled: bool, pub(super) me_warmup_step_delay: Duration, pub(super) me_warmup_step_jitter: Duration, pub(super) me_reconnect_max_concurrent_per_dc: u32, pub(super) me_reconnect_backoff_base: Duration, pub(super) me_reconnect_backoff_cap: Duration, pub(super) me_reconnect_fast_retry_count: u32, pub(super) proxy_map_v4: Arc>>>, pub(super) proxy_map_v6: Arc>>>, pub(super) default_dc: AtomicI32, pub(super) next_writer_id: AtomicU64, pub(super) ping_tracker: Arc>>, pub(super) rtt_stats: Arc>>, pub(super) nat_reflection_cache: Arc>, pub(super) writer_available: Arc, pub(super) refill_inflight: Arc>>, pub(super) conn_count: AtomicUsize, pub(super) stats: Arc, pub(super) generation: AtomicU64, pub(super) hardswap: AtomicBool, pub(super) me_pool_drain_ttl_secs: AtomicU64, pub(super) me_pool_force_close_secs: AtomicU64, pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, pub(super) me_hardswap_warmup_delay_min_ms: AtomicU64, pub(super) me_hardswap_warmup_delay_max_ms: AtomicU64, pub(super) me_hardswap_warmup_extra_passes: AtomicU32, pub(super) me_hardswap_warmup_pass_backoff_base_ms: AtomicU64, pool_size: usize, } #[derive(Debug, Default)] pub struct NatReflectionCache { pub v4: Option<(std::time::Instant, std::net::SocketAddr)>, pub v6: Option<(std::time::Instant, std::net::SocketAddr)>, } impl MePool { fn ratio_to_permille(ratio: f32) -> u32 { let clamped = ratio.clamp(0.0, 1.0); (clamped * 1000.0).round() as u32 } fn permille_to_ratio(permille: u32) -> f32 { (permille.min(1000) as f32) / 1000.0 } fn now_epoch_secs() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or_default() .as_secs() } pub fn new( proxy_tag: Option>, proxy_secret: Vec, nat_ip: Option, nat_probe: bool, nat_stun: Option, nat_stun_servers: Vec, detected_ipv6: Option, me_one_retry: u8, me_one_timeout_ms: u64, proxy_map_v4: HashMap>, proxy_map_v6: HashMap>, default_dc: Option, decision: NetworkDecision, rng: Arc, stats: Arc, me_keepalive_enabled: bool, me_keepalive_interval_secs: u64, me_keepalive_jitter_secs: u64, me_keepalive_payload_random: bool, me_warmup_stagger_enabled: bool, me_warmup_step_delay_ms: u64, me_warmup_step_jitter_ms: u64, me_reconnect_max_concurrent_per_dc: u32, me_reconnect_backoff_base_ms: u64, me_reconnect_backoff_cap_ms: u64, me_reconnect_fast_retry_count: u32, hardswap: bool, me_pool_drain_ttl_secs: u64, me_pool_force_close_secs: u64, me_pool_min_fresh_ratio: f32, me_hardswap_warmup_delay_min_ms: u64, me_hardswap_warmup_delay_max_ms: u64, me_hardswap_warmup_extra_passes: u8, me_hardswap_warmup_pass_backoff_base_ms: u64, ) -> Arc { Arc::new(Self { registry: Arc::new(ConnRegistry::new()), writers: Arc::new(RwLock::new(Vec::new())), rr: AtomicU64::new(0), decision, rng, proxy_tag, proxy_secret: Arc::new(RwLock::new(proxy_secret)), nat_ip_cfg: nat_ip, nat_ip_detected: Arc::new(RwLock::new(None)), nat_probe, nat_stun, nat_stun_servers, detected_ipv6, nat_probe_attempts: std::sync::atomic::AtomicU8::new(0), nat_probe_disabled: std::sync::atomic::AtomicBool::new(false), stun_backoff_until: Arc::new(RwLock::new(None)), me_one_retry, me_one_timeout: Duration::from_millis(me_one_timeout_ms), stats, me_keepalive_enabled, me_keepalive_interval: Duration::from_secs(me_keepalive_interval_secs), me_keepalive_jitter: Duration::from_secs(me_keepalive_jitter_secs), me_keepalive_payload_random, me_warmup_stagger_enabled, me_warmup_step_delay: Duration::from_millis(me_warmup_step_delay_ms), me_warmup_step_jitter: Duration::from_millis(me_warmup_step_jitter_ms), me_reconnect_max_concurrent_per_dc, me_reconnect_backoff_base: Duration::from_millis(me_reconnect_backoff_base_ms), me_reconnect_backoff_cap: Duration::from_millis(me_reconnect_backoff_cap_ms), me_reconnect_fast_retry_count, pool_size: 2, proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)), proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)), default_dc: AtomicI32::new(default_dc.unwrap_or(0)), next_writer_id: AtomicU64::new(1), ping_tracker: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), writer_available: Arc::new(Notify::new()), refill_inflight: Arc::new(Mutex::new(HashSet::new())), conn_count: AtomicUsize::new(0), generation: AtomicU64::new(1), hardswap: AtomicBool::new(hardswap), me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs), me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille(me_pool_min_fresh_ratio)), me_hardswap_warmup_delay_min_ms: AtomicU64::new(me_hardswap_warmup_delay_min_ms), me_hardswap_warmup_delay_max_ms: AtomicU64::new(me_hardswap_warmup_delay_max_ms), me_hardswap_warmup_extra_passes: AtomicU32::new(me_hardswap_warmup_extra_passes as u32), me_hardswap_warmup_pass_backoff_base_ms: AtomicU64::new(me_hardswap_warmup_pass_backoff_base_ms), }) } pub fn has_proxy_tag(&self) -> bool { self.proxy_tag.is_some() } pub fn current_generation(&self) -> u64 { self.generation.load(Ordering::Relaxed) } pub fn update_runtime_reinit_policy( &self, hardswap: bool, drain_ttl_secs: u64, force_close_secs: u64, min_fresh_ratio: f32, hardswap_warmup_delay_min_ms: u64, hardswap_warmup_delay_max_ms: u64, hardswap_warmup_extra_passes: u8, hardswap_warmup_pass_backoff_base_ms: u64, ) { self.hardswap.store(hardswap, Ordering::Relaxed); self.me_pool_drain_ttl_secs.store(drain_ttl_secs, Ordering::Relaxed); self.me_pool_force_close_secs .store(force_close_secs, Ordering::Relaxed); self.me_pool_min_fresh_ratio_permille .store(Self::ratio_to_permille(min_fresh_ratio), Ordering::Relaxed); self.me_hardswap_warmup_delay_min_ms .store(hardswap_warmup_delay_min_ms, Ordering::Relaxed); self.me_hardswap_warmup_delay_max_ms .store(hardswap_warmup_delay_max_ms, Ordering::Relaxed); self.me_hardswap_warmup_extra_passes .store(hardswap_warmup_extra_passes as u32, Ordering::Relaxed); self.me_hardswap_warmup_pass_backoff_base_ms .store(hardswap_warmup_pass_backoff_base_ms, Ordering::Relaxed); } pub fn reset_stun_state(&self) { self.nat_probe_attempts.store(0, Ordering::Relaxed); self.nat_probe_disabled.store(false, Ordering::Relaxed); } pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr { let ip = self.translate_ip_for_nat(addr.ip()); SocketAddr::new(ip, addr.port()) } pub fn registry(&self) -> &Arc { &self.registry } fn writers_arc(&self) -> Arc>> { self.writers.clone() } fn force_close_timeout(&self) -> Option { let secs = self.me_pool_force_close_secs.load(Ordering::Relaxed); if secs == 0 { None } else { Some(Duration::from_secs(secs)) } } fn coverage_ratio( desired_by_dc: &HashMap>, active_writer_addrs: &HashSet, ) -> (f32, Vec) { if desired_by_dc.is_empty() { return (1.0, Vec::new()); } let mut missing_dc = Vec::::new(); let mut covered = 0usize; for (dc, endpoints) in desired_by_dc { if endpoints.is_empty() { continue; } if endpoints.iter().any(|addr| active_writer_addrs.contains(addr)) { covered += 1; } else { missing_dc.push(*dc); } } missing_dc.sort_unstable(); let total = desired_by_dc.len().max(1); let ratio = (covered as f32) / (total as f32); (ratio, missing_dc) } pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { let writers = self.writers.read().await; let current: HashSet = writers .iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) .map(|w| w.addr) .collect(); drop(writers); for family in self.family_order() { let map = self.proxy_map_for_family(family).await; for (_dc, addrs) in map.iter() { let dc_addrs: Vec = addrs .iter() .map(|(ip, port)| SocketAddr::new(*ip, *port)) .collect(); if !dc_addrs.iter().any(|a| current.contains(a)) { let mut shuffled = dc_addrs.clone(); shuffled.shuffle(&mut rand::rng()); for addr in shuffled { if self.connect_one(addr, rng).await.is_ok() { break; } } } } if !self.decision.effective_multipath && !current.is_empty() { break; } } } async fn desired_dc_endpoints(&self) -> HashMap> { let mut out: HashMap> = HashMap::new(); if self.decision.ipv4_me { let map_v4 = self.proxy_map_v4.read().await.clone(); for (dc, addrs) in map_v4 { let entry = out.entry(dc.abs()).or_default(); for (ip, port) in addrs { entry.insert(SocketAddr::new(ip, port)); } } } if self.decision.ipv6_me { let map_v6 = self.proxy_map_v6.read().await.clone(); for (dc, addrs) in map_v6 { let entry = out.entry(dc.abs()).or_default(); for (ip, port) in addrs { entry.insert(SocketAddr::new(ip, port)); } } } out } pub(super) fn required_writers_for_dc(endpoint_count: usize) -> usize { endpoint_count.max(3) } fn hardswap_warmup_connect_delay_ms(&self) -> u64 { let min_ms = self .me_hardswap_warmup_delay_min_ms .load(Ordering::Relaxed); let max_ms = self .me_hardswap_warmup_delay_max_ms .load(Ordering::Relaxed); let (min_ms, max_ms) = if min_ms <= max_ms { (min_ms, max_ms) } else { (max_ms, min_ms) }; if min_ms == max_ms { return min_ms; } rand::rng().random_range(min_ms..=max_ms) } fn hardswap_warmup_backoff_ms(&self, pass_idx: usize) -> u64 { let base_ms = self .me_hardswap_warmup_pass_backoff_base_ms .load(Ordering::Relaxed); let cap_ms = (self.me_reconnect_backoff_cap.as_millis() as u64).max(base_ms); let shift = (pass_idx as u32).min(20); let scaled = base_ms.saturating_mul(1u64 << shift); let core = scaled.min(cap_ms); let jitter = (core / 2).max(1); core.saturating_add(rand::rng().random_range(0..=jitter)) } async fn fresh_writer_count_for_endpoints( &self, generation: u64, endpoints: &HashSet, ) -> usize { let ws = self.writers.read().await; ws.iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) .filter(|w| w.generation == generation) .filter(|w| endpoints.contains(&w.addr)) .count() } pub(super) async fn connect_endpoints_round_robin( self: &Arc, endpoints: &[SocketAddr], rng: &SecureRandom, ) -> bool { if endpoints.is_empty() { return false; } let start = (self.rr.fetch_add(1, Ordering::Relaxed) as usize) % endpoints.len(); for offset in 0..endpoints.len() { let idx = (start + offset) % endpoints.len(); let addr = endpoints[idx]; match self.connect_one(addr, rng).await { Ok(()) => return true, Err(e) => debug!(%addr, error = %e, "ME connect failed during round-robin warmup"), } } false } async fn warmup_generation_for_all_dcs( self: &Arc, rng: &SecureRandom, generation: u64, desired_by_dc: &HashMap>, ) { let extra_passes = self .me_hardswap_warmup_extra_passes .load(Ordering::Relaxed) .min(10) as usize; let total_passes = 1 + extra_passes; for (dc, endpoints) in desired_by_dc { if endpoints.is_empty() { continue; } let mut endpoint_list: Vec = endpoints.iter().copied().collect(); endpoint_list.sort_unstable(); let required = Self::required_writers_for_dc(endpoint_list.len()); let mut completed = false; let mut last_fresh_count = self .fresh_writer_count_for_endpoints(generation, endpoints) .await; for pass_idx in 0..total_passes { if last_fresh_count >= required { completed = true; break; } let missing = required.saturating_sub(last_fresh_count); debug!( dc = *dc, pass = pass_idx + 1, total_passes, fresh_count = last_fresh_count, required, missing, endpoint_count = endpoint_list.len(), "ME hardswap warmup pass started" ); for attempt_idx in 0..missing { let delay_ms = self.hardswap_warmup_connect_delay_ms(); tokio::time::sleep(Duration::from_millis(delay_ms)).await; let connected = self.connect_endpoints_round_robin(&endpoint_list, rng).await; debug!( dc = *dc, pass = pass_idx + 1, total_passes, attempt = attempt_idx + 1, delay_ms, connected, "ME hardswap warmup connect attempt finished" ); } last_fresh_count = self .fresh_writer_count_for_endpoints(generation, endpoints) .await; if last_fresh_count >= required { completed = true; info!( dc = *dc, pass = pass_idx + 1, total_passes, fresh_count = last_fresh_count, required, "ME hardswap warmup floor reached for DC" ); break; } if pass_idx + 1 < total_passes { let backoff_ms = self.hardswap_warmup_backoff_ms(pass_idx); debug!( dc = *dc, pass = pass_idx + 1, total_passes, fresh_count = last_fresh_count, required, backoff_ms, "ME hardswap warmup pass incomplete, delaying next pass" ); tokio::time::sleep(Duration::from_millis(backoff_ms)).await; } } if !completed { warn!( dc = *dc, fresh_count = last_fresh_count, required, endpoint_count = endpoint_list.len(), total_passes, "ME warmup stopped: unable to reach required writer floor for DC" ); } } } pub async fn zero_downtime_reinit_after_map_change( self: &Arc, rng: &SecureRandom, ) { let desired_by_dc = self.desired_dc_endpoints().await; if desired_by_dc.is_empty() { warn!("ME endpoint map is empty; skipping stale writer drain"); return; } let previous_generation = self.current_generation(); let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; let hardswap = self.hardswap.load(Ordering::Relaxed); if hardswap { self.warmup_generation_for_all_dcs(rng, generation, &desired_by_dc) .await; } else { self.reconcile_connections(rng).await; } let writers = self.writers.read().await; let active_writer_addrs: HashSet = writers .iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) .map(|w| w.addr) .collect(); let min_ratio = Self::permille_to_ratio( self.me_pool_min_fresh_ratio_permille .load(Ordering::Relaxed), ); let (coverage_ratio, missing_dc) = Self::coverage_ratio(&desired_by_dc, &active_writer_addrs); if !hardswap && coverage_ratio < min_ratio { warn!( previous_generation, generation, coverage_ratio = format_args!("{coverage_ratio:.3}"), min_ratio = format_args!("{min_ratio:.3}"), missing_dc = ?missing_dc, "ME reinit coverage below threshold; keeping stale writers" ); return; } if hardswap { let mut fresh_missing_dc = Vec::<(i32, usize, usize)>::new(); for (dc, endpoints) in &desired_by_dc { if endpoints.is_empty() { continue; } let required = Self::required_writers_for_dc(endpoints.len()); let fresh_count = writers .iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) .filter(|w| w.generation == generation) .filter(|w| endpoints.contains(&w.addr)) .count(); if fresh_count < required { fresh_missing_dc.push((*dc, fresh_count, required)); } } if !fresh_missing_dc.is_empty() { warn!( previous_generation, generation, missing_dc = ?fresh_missing_dc, "ME hardswap pending: fresh generation coverage incomplete" ); return; } } else if !missing_dc.is_empty() { warn!( missing_dc = ?missing_dc, // Keep stale writers alive when fresh coverage is incomplete. "ME reinit coverage incomplete; keeping stale writers" ); return; } let desired_addrs: HashSet = desired_by_dc .values() .flat_map(|set| set.iter().copied()) .collect(); let stale_writer_ids: Vec = writers .iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) .filter(|w| { if hardswap { w.generation < generation } else { !desired_addrs.contains(&w.addr) } }) .map(|w| w.id) .collect(); drop(writers); if stale_writer_ids.is_empty() { debug!("ME reinit cycle completed with no stale writers"); return; } let drain_timeout = self.force_close_timeout(); let drain_timeout_secs = drain_timeout.map(|d| d.as_secs()).unwrap_or(0); info!( stale_writers = stale_writer_ids.len(), previous_generation, generation, hardswap, coverage_ratio = format_args!("{coverage_ratio:.3}"), min_ratio = format_args!("{min_ratio:.3}"), drain_timeout_secs, "ME reinit cycle covered; draining stale writers" ); self.stats.increment_pool_swap_total(); for writer_id in stale_writer_ids { self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap) .await; } } pub async fn zero_downtime_reinit_periodic( self: &Arc, rng: &SecureRandom, ) { self.zero_downtime_reinit_after_map_change(rng).await; } async fn endpoints_for_same_dc(&self, addr: SocketAddr) -> Vec { let mut target_dc = HashSet::::new(); let mut endpoints = HashSet::::new(); if self.decision.ipv4_me { let map = self.proxy_map_v4.read().await.clone(); for (dc, addrs) in &map { if addrs .iter() .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) { target_dc.insert(dc.abs()); } } for dc in &target_dc { for key in [*dc, -*dc] { if let Some(addrs) = map.get(&key) { for (ip, port) in addrs { endpoints.insert(SocketAddr::new(*ip, *port)); } } } } } if self.decision.ipv6_me { let map = self.proxy_map_v6.read().await.clone(); for (dc, addrs) in &map { if addrs .iter() .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) { target_dc.insert(dc.abs()); } } for dc in &target_dc { for key in [*dc, -*dc] { if let Some(addrs) = map.get(&key) { for (ip, port) in addrs { endpoints.insert(SocketAddr::new(*ip, *port)); } } } } } let mut sorted: Vec = endpoints.into_iter().collect(); sorted.sort_unstable(); sorted } async fn refill_writer_after_loss(self: &Arc, addr: SocketAddr) -> bool { let fast_retries = self.me_reconnect_fast_retry_count.max(1); for attempt in 0..fast_retries { self.stats.increment_me_reconnect_attempt(); match self.connect_one(addr, self.rng.as_ref()).await { Ok(()) => { self.stats.increment_me_reconnect_success(); info!( %addr, attempt = attempt + 1, "ME writer restored on the same endpoint" ); return true; } Err(e) => { debug!( %addr, attempt = attempt + 1, error = %e, "ME immediate same-endpoint reconnect failed" ); } } } let dc_endpoints = self.endpoints_for_same_dc(addr).await; if dc_endpoints.is_empty() { return false; } for attempt in 0..fast_retries { self.stats.increment_me_reconnect_attempt(); if self .connect_endpoints_round_robin(&dc_endpoints, self.rng.as_ref()) .await { self.stats.increment_me_reconnect_success(); info!( %addr, attempt = attempt + 1, "ME writer restored via DC fallback endpoint" ); return true; } } false } pub(crate) fn trigger_immediate_refill(self: &Arc, addr: SocketAddr) { let pool = Arc::clone(self); tokio::spawn(async move { { let mut guard = pool.refill_inflight.lock().await; if !guard.insert(addr) { return; } } let restored = pool.refill_writer_after_loss(addr).await; if !restored { warn!(%addr, "ME immediate refill failed"); } let mut guard = pool.refill_inflight.lock().await; guard.remove(&addr); }); } pub async fn update_proxy_maps( &self, new_v4: HashMap>, new_v6: Option>>, ) -> bool { let mut changed = false; { let mut guard = self.proxy_map_v4.write().await; if !new_v4.is_empty() && *guard != new_v4 { *guard = new_v4; changed = true; } } if let Some(v6) = new_v6 { let mut guard = self.proxy_map_v6.write().await; if !v6.is_empty() && *guard != v6 { *guard = v6; changed = true; } } // Ensure negative DC entries mirror positives when absent (Telegram convention). { let mut guard = self.proxy_map_v4.write().await; let keys: Vec = guard.keys().cloned().collect(); for k in keys.iter().cloned().filter(|k| *k > 0) { if !guard.contains_key(&-k) && let Some(addrs) = guard.get(&k).cloned() { guard.insert(-k, addrs); } } } { let mut guard = self.proxy_map_v6.write().await; let keys: Vec = guard.keys().cloned().collect(); for k in keys.iter().cloned().filter(|k| *k > 0) { if !guard.contains_key(&-k) && let Some(addrs) = guard.get(&k).cloned() { guard.insert(-k, addrs); } } } changed } pub async fn update_secret(self: &Arc, new_secret: Vec) -> bool { if new_secret.len() < 32 { warn!(len = new_secret.len(), "proxy-secret update ignored (too short)"); return false; } let mut guard = self.proxy_secret.write().await; if *guard != new_secret { *guard = new_secret; drop(guard); self.reconnect_all().await; return true; } false } pub async fn reconnect_all(self: &Arc) { let ws = self.writers.read().await.clone(); for w in ws { if let Ok(()) = self.connect_one(w.addr, self.rng.as_ref()).await { self.mark_writer_draining(w.id).await; tokio::time::sleep(Duration::from_secs(2)).await; } } } pub(super) async fn key_selector(&self) -> u32 { let secret = self.proxy_secret.read().await; if secret.len() >= 4 { u32::from_le_bytes([secret[0], secret[1], secret[2], secret[3]]) } else { 0 } } pub(super) fn family_order(&self) -> Vec { let mut order = Vec::new(); if self.decision.prefer_ipv6() { if self.decision.ipv6_me { order.push(IpFamily::V6); } if self.decision.ipv4_me { order.push(IpFamily::V4); } } else { if self.decision.ipv4_me { order.push(IpFamily::V4); } if self.decision.ipv6_me { order.push(IpFamily::V6); } } order } async fn proxy_map_for_family(&self, family: IpFamily) -> HashMap> { match family { IpFamily::V4 => self.proxy_map_v4.read().await.clone(), IpFamily::V6 => self.proxy_map_v6.read().await.clone(), } } pub async fn init(self: &Arc, pool_size: usize, rng: &Arc) -> Result<()> { let family_order = self.family_order(); let ks = self.key_selector().await; info!( me_servers = self.proxy_map_v4.read().await.len(), pool_size, key_selector = format_args!("0x{ks:08x}"), secret_len = self.proxy_secret.read().await.len(), "Initializing ME pool" ); for family in family_order { let map = self.proxy_map_for_family(family).await; let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map .iter() .map(|(dc, addrs)| (*dc, addrs.clone())) .collect(); // Ensure at least one connection per DC; run DCs in parallel. let mut join = tokio::task::JoinSet::new(); let mut dc_failures = 0usize; for (dc, addrs) in dc_addrs.iter().cloned() { if addrs.is_empty() { continue; } let pool = Arc::clone(self); let rng_clone = Arc::clone(rng); join.spawn(async move { pool.connect_primary_for_dc(dc, addrs, rng_clone).await }); } while let Some(res) = join.join_next().await { if let Ok(false) = res { dc_failures += 1; } } if dc_failures > 2 { return Err(ProxyError::Proxy("Too many ME DC init failures, falling back to direct".into())); } // Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles. if self.me_warmup_stagger_enabled { for (dc, addrs) in dc_addrs.iter() { for (ip, port) in addrs { if self.connection_count() >= pool_size { break; } let addr = SocketAddr::new(*ip, *port); let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64); let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter; tokio::time::sleep(Duration::from_millis(delay_ms)).await; if let Err(e) = self.connect_one(addr, rng.as_ref()).await { debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); } } } } else { for (dc, addrs) in dc_addrs.iter() { for (ip, port) in addrs { if self.connection_count() >= pool_size { break; } let addr = SocketAddr::new(*ip, *port); if let Err(e) = self.connect_one(addr, rng.as_ref()).await { debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); } } if self.connection_count() >= pool_size { break; } } } if !self.decision.effective_multipath && self.connection_count() > 0 { break; } } if self.writers.read().await.is_empty() { return Err(ProxyError::Proxy("No ME connections".into())); } Ok(()) } pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { let secret_len = self.proxy_secret.read().await.len(); if secret_len < 32 { return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); } let (stream, _connect_ms) = self.connect_tcp(addr).await?; let hs = self.handshake_only(stream, addr, rng).await?; let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); let generation = self.current_generation(); let cancel = CancellationToken::new(); let degraded = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false)); let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0)); let allow_drain_fallback = Arc::new(AtomicBool::new(false)); let (tx, mut rx) = mpsc::channel::(4096); let mut rpc_writer = RpcWriter { writer: hs.wr, key: hs.write_key, iv: hs.write_iv, seq_no: 0, crc_mode: hs.crc_mode, }; let cancel_wr = cancel.clone(); tokio::spawn(async move { loop { tokio::select! { cmd = rx.recv() => { match cmd { Some(WriterCommand::Data(payload)) => { if rpc_writer.send(&payload).await.is_err() { break; } } Some(WriterCommand::DataAndFlush(payload)) => { if rpc_writer.send_and_flush(&payload).await.is_err() { break; } } Some(WriterCommand::Close) | None => break, } } _ = cancel_wr.cancelled() => break, } } }); let writer = MeWriter { id: writer_id, addr, generation, tx: tx.clone(), cancel: cancel.clone(), degraded: degraded.clone(), draining: draining.clone(), draining_started_at_epoch_secs: draining_started_at_epoch_secs.clone(), allow_drain_fallback: allow_drain_fallback.clone(), }; self.writers.write().await.push(writer.clone()); self.conn_count.fetch_add(1, Ordering::Relaxed); self.writer_available.notify_one(); let reg = self.registry.clone(); let writers_arc = self.writers_arc(); let ping_tracker = self.ping_tracker.clone(); let ping_tracker_reader = ping_tracker.clone(); let rtt_stats = self.rtt_stats.clone(); let stats_reader = self.stats.clone(); let stats_ping = self.stats.clone(); let pool = Arc::downgrade(self); let cancel_ping = cancel.clone(); let tx_ping = tx.clone(); let ping_tracker_ping = ping_tracker.clone(); let cleanup_done = Arc::new(AtomicBool::new(false)); let cleanup_for_reader = cleanup_done.clone(); let cleanup_for_ping = cleanup_done.clone(); let keepalive_enabled = self.me_keepalive_enabled; let keepalive_interval = self.me_keepalive_interval; let keepalive_jitter = self.me_keepalive_jitter; let cancel_reader_token = cancel.clone(); let cancel_ping_token = cancel_ping.clone(); tokio::spawn(async move { let res = reader_loop( hs.rd, hs.read_key, hs.read_iv, hs.crc_mode, reg.clone(), BytesMut::new(), BytesMut::new(), tx.clone(), ping_tracker_reader, rtt_stats.clone(), stats_reader, writer_id, degraded.clone(), cancel_reader_token.clone(), ) .await; if let Some(pool) = pool.upgrade() && cleanup_for_reader .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) .is_ok() { pool.remove_writer_and_close_clients(writer_id).await; } if let Err(e) = res { warn!(error = %e, "ME reader ended"); } let mut ws = writers_arc.write().await; ws.retain(|w| w.id != writer_id); info!(remaining = ws.len(), "Dead ME writer removed from pool"); }); let pool_ping = Arc::downgrade(self); tokio::spawn(async move { let mut ping_id: i64 = rand::random::(); // Per-writer jittered start to avoid phase sync. let startup_jitter = if keepalive_enabled { let jitter_cap_ms = keepalive_interval.as_millis() / 2; let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) } else { let jitter = rand::rng() .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; Duration::from_secs(wait) }; tokio::select! { _ = cancel_ping_token.cancelled() => return, _ = tokio::time::sleep(startup_jitter) => {} } loop { let wait = if keepalive_enabled { let jitter_cap_ms = keepalive_interval.as_millis() / 2; let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); keepalive_interval + Duration::from_millis( rand::rng().random_range(0..=effective_jitter_ms as u64) ) } else { let jitter = rand::rng() .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; Duration::from_secs(secs) }; tokio::select! { _ = cancel_ping_token.cancelled() => { break; } _ = tokio::time::sleep(wait) => {} } let sent_id = ping_id; let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); p.extend_from_slice(&sent_id.to_le_bytes()); { let mut tracker = ping_tracker_ping.lock().await; let before = tracker.len(); tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); let expired = before.saturating_sub(tracker.len()); if expired > 0 { stats_ping.increment_me_keepalive_timeout_by(expired as u64); } tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); stats_ping.increment_me_keepalive_sent(); if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { stats_ping.increment_me_keepalive_failed(); debug!("ME ping failed, removing dead writer"); cancel_ping.cancel(); if let Some(pool) = pool_ping.upgrade() && cleanup_for_ping .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) .is_ok() { pool.remove_writer_and_close_clients(writer_id).await; } break; } } }); Ok(()) } async fn connect_primary_for_dc( self: Arc, dc: i32, mut addrs: Vec<(IpAddr, u16)>, rng: Arc, ) -> bool { if addrs.is_empty() { return false; } addrs.shuffle(&mut rand::rng()); for (ip, port) in addrs { let addr = SocketAddr::new(ip, port); match self.connect_one(addr, rng.as_ref()).await { Ok(()) => { info!(%addr, dc = %dc, "ME connected"); return true; } Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"), } } warn!(dc = %dc, "All ME servers for DC failed at init"); false } pub(crate) async fn remove_writer_and_close_clients(self: &Arc, writer_id: u64) { let conns = self.remove_writer_only(writer_id).await; for bound in conns { let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; let _ = self.registry.unregister(bound.conn_id).await; } } async fn remove_writer_only(self: &Arc, writer_id: u64) -> Vec { let mut close_tx: Option> = None; let mut removed_addr: Option = None; let mut trigger_refill = false; { let mut ws = self.writers.write().await; if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { let w = ws.remove(pos); let was_draining = w.draining.load(Ordering::Relaxed); if was_draining { self.stats.decrement_pool_drain_active(); } w.cancel.cancel(); removed_addr = Some(w.addr); trigger_refill = !was_draining; close_tx = Some(w.tx.clone()); self.conn_count.fetch_sub(1, Ordering::Relaxed); } } if let Some(tx) = close_tx { let _ = tx.send(WriterCommand::Close).await; } if trigger_refill && let Some(addr) = removed_addr { self.trigger_immediate_refill(addr); } self.rtt_stats.lock().await.remove(&writer_id); self.registry.writer_lost(writer_id).await } pub(crate) async fn mark_writer_draining_with_timeout( self: &Arc, writer_id: u64, timeout: Option, allow_drain_fallback: bool, ) { let timeout = timeout.filter(|d| !d.is_zero()); let found = { let mut ws = self.writers.write().await; if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) { let already_draining = w.draining.swap(true, Ordering::Relaxed); w.allow_drain_fallback .store(allow_drain_fallback, Ordering::Relaxed); w.draining_started_at_epoch_secs .store(Self::now_epoch_secs(), Ordering::Relaxed); if !already_draining { self.stats.increment_pool_drain_active(); } w.draining.store(true, Ordering::Relaxed); true } else { false } }; if !found { return; } let timeout_secs = timeout.map(|d| d.as_secs()).unwrap_or(0); debug!( writer_id, timeout_secs, allow_drain_fallback, "ME writer marked draining" ); let pool = Arc::downgrade(self); tokio::spawn(async move { let deadline = timeout.map(|t| Instant::now() + t); while let Some(p) = pool.upgrade() { if let Some(deadline_at) = deadline && Instant::now() >= deadline_at { warn!(writer_id, "Drain timeout, force-closing"); p.stats.increment_pool_force_close_total(); let _ = p.remove_writer_and_close_clients(writer_id).await; break; } if p.registry.is_writer_empty(writer_id).await { let _ = p.remove_writer_only(writer_id).await; break; } tokio::time::sleep(Duration::from_secs(1)).await; } }); } pub(crate) async fn mark_writer_draining(self: &Arc, writer_id: u64) { self.mark_writer_draining_with_timeout(writer_id, Some(Duration::from_secs(300)), false) .await; } pub(super) fn writer_accepts_new_binding(&self, writer: &MeWriter) -> bool { if !writer.draining.load(Ordering::Relaxed) { return true; } if !writer.allow_drain_fallback.load(Ordering::Relaxed) { return false; } let ttl_secs = self.me_pool_drain_ttl_secs.load(Ordering::Relaxed); if ttl_secs == 0 { return true; } let started = writer.draining_started_at_epoch_secs.load(Ordering::Relaxed); if started == 0 { return false; } Self::now_epoch_secs().saturating_sub(started) <= ttl_secs } } #[allow(dead_code)] fn hex_dump(data: &[u8]) -> String { const MAX: usize = 64; let mut out = String::with_capacity(data.len() * 2 + 3); for (i, b) in data.iter().take(MAX).enumerate() { if i > 0 { out.push(' '); } out.push_str(&format!("{b:02x}")); } if data.len() > MAX { out.push_str(" …"); } out }