diff --git a/src/config/load.rs b/src/config/load.rs index 9abab30..dcca2a0 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -548,6 +548,12 @@ impl ProxyConfig { config.general.middle_proxy_nat_probe = true; warn!("Auto-enabled middle_proxy_nat_probe for middle proxy mode"); } + if config.general.use_middle_proxy && !config.general.me_secret_atomic_snapshot { + config.general.me_secret_atomic_snapshot = true; + warn!( + "Auto-enabled me_secret_atomic_snapshot for middle proxy mode to keep KDF key_selector/secret coherent" + ); + } validate_network_cfg(&mut config.network)?; crate::network::dns_overrides::validate_entries(&config.network.dns_overrides)?; diff --git a/src/main.rs b/src/main.rs index 9f81edf..064df16 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1339,7 +1339,7 @@ async fn main() -> std::result::Result<(), Box> { let (admission_tx, admission_rx) = watch::channel(true); if config.general.use_middle_proxy { if let Some(pool) = me_pool.as_ref() { - let initial_open = pool.admission_ready_full_floor().await; + let initial_open = pool.admission_ready_conditional_cast().await; admission_tx.send_replace(initial_open); if initial_open { info!("Conditional-admission gate: open (ME pool ready)"); @@ -1354,7 +1354,7 @@ async fn main() -> std::result::Result<(), Box> { let mut open_streak = if initial_open { 1u32 } else { 0u32 }; let mut close_streak = if initial_open { 0u32 } else { 1u32 }; loop { - let ready = pool_for_gate.admission_ready_full_floor().await; + let ready = pool_for_gate.admission_ready_conditional_cast().await; if ready { open_streak = open_streak.saturating_add(1); close_streak = 0; @@ -1374,7 +1374,7 @@ async fn main() -> std::result::Result<(), Box> { admission_tx_gate.send_replace(false); warn!( close_streak, - "Conditional-admission gate closed (ME pool below required floor)" + "Conditional-admission gate closed (ME pool has uncovered DC groups)" ); } } diff --git a/src/transport/middle_proxy/handshake.rs 
b/src/transport/middle_proxy/handshake.rs index 1016c6b..77634a6 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -387,9 +387,11 @@ impl MePool { socks_bound_addr.map(|value| value.ip()), client_port_source, ); - let mut kdf_fingerprint_guard = self.kdf_material_fingerprint.lock().await; - if let Some((prev_fingerprint, prev_client_port)) = + let previous_kdf_fingerprint = { + let kdf_fingerprint_guard = self.kdf_material_fingerprint.read().await; kdf_fingerprint_guard.get(&peer_addr_nat).copied() + }; + if let Some((prev_fingerprint, prev_client_port)) = previous_kdf_fingerprint { if prev_fingerprint != kdf_fingerprint { self.stats.increment_me_kdf_drift_total(); @@ -416,6 +418,9 @@ impl MePool { ); } } + // Keep fingerprint updates eventually consistent for diagnostics while avoiding + // serializing all concurrent handshakes on a single async mutex. + let mut kdf_fingerprint_guard = self.kdf_material_fingerprint.write().await; kdf_fingerprint_guard.insert(peer_addr_nat, (kdf_fingerprint, client_port_for_kdf)); drop(kdf_fingerprint_guard); diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 8cc078e..1dab2f4 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -132,7 +132,7 @@ pub struct MePool { pub(super) pending_hardswap_map_hash: AtomicU64, pub(super) hardswap: AtomicBool, pub(super) endpoint_quarantine: Arc>>, - pub(super) kdf_material_fingerprint: Arc>>, + pub(super) kdf_material_fingerprint: Arc>>, pub(super) me_pool_drain_ttl_secs: AtomicU64, pub(super) me_pool_force_close_secs: AtomicU64, pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, @@ -335,7 +335,7 @@ impl MePool { pending_hardswap_map_hash: AtomicU64::new(0), hardswap: AtomicBool::new(hardswap), endpoint_quarantine: Arc::new(Mutex::new(HashMap::new())), - kdf_material_fingerprint: Arc::new(Mutex::new(HashMap::new())), + kdf_material_fingerprint: 
Arc::new(RwLock::new(HashMap::new())), me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs), me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille( diff --git a/src/transport/middle_proxy/pool_init.rs b/src/transport/middle_proxy/pool_init.rs index fef1553..fbb5c64 100644 --- a/src/transport/middle_proxy/pool_init.rs +++ b/src/transport/middle_proxy/pool_init.rs @@ -14,10 +14,12 @@ use super::pool::MePool; impl MePool { pub async fn init(self: &Arc, pool_size: usize, rng: &Arc) -> Result<()> { let family_order = self.family_order(); + let connect_concurrency = self.me_reconnect_max_concurrent_per_dc.max(1) as usize; let ks = self.key_selector().await; info!( me_servers = self.proxy_map_v4.read().await.len(), pool_size, + connect_concurrency, key_selector = format_args!("0x{ks:08x}"), secret_len = self.proxy_secret.read().await.secret.len(), "Initializing ME pool" @@ -41,23 +43,39 @@ impl MePool { }) .collect(); dc_addrs.sort_unstable_by_key(|(dc, _)| *dc); + dc_addrs.sort_by_key(|(_, addrs)| (addrs.len() != 1, addrs.len())); - // Ensure at least one live writer per DC group; run missing DCs in parallel. + // Stage 1: build base coverage for conditional-cast. + // Single-endpoint DCs are prefilled first; multi-endpoint DCs require one live writer. 
let mut join = tokio::task::JoinSet::new(); for (dc, addrs) in dc_addrs.iter().cloned() { if addrs.is_empty() { continue; } + let target_writers = if addrs.len() == 1 { + self.required_writers_for_dc_with_floor_mode(addrs.len(), false) + } else { + 1usize + }; let endpoints: HashSet = addrs .iter() .map(|(ip, port)| SocketAddr::new(*ip, *port)) .collect(); - if self.active_writer_count_for_endpoints(&endpoints).await > 0 { + if self.active_writer_count_for_endpoints(&endpoints).await >= target_writers { continue; } let pool = Arc::clone(self); let rng_clone = Arc::clone(rng); - join.spawn(async move { pool.connect_primary_for_dc(dc, addrs, rng_clone).await }); + join.spawn(async move { + pool.connect_primary_for_dc( + dc, + addrs, + target_writers, + rng_clone, + connect_concurrency, + ) + .await + }); } while join.join_next().await.is_some() {} @@ -77,47 +95,35 @@ impl MePool { ))); } - // Warm reserve writers asynchronously so startup does not block after first working pool is ready. + // Stage 2: continue saturating multi-endpoint DC groups in background. 
let pool = Arc::clone(self); let rng_clone = Arc::clone(rng); let dc_addrs_bg = dc_addrs.clone(); tokio::spawn(async move { - if pool.me_warmup_stagger_enabled { - for (dc, addrs) in &dc_addrs_bg { - for (ip, port) in addrs { - if pool.connection_count() >= pool_size { - break; - } - let addr = SocketAddr::new(*ip, *port); - let jitter = rand::rng() - .random_range(0..=pool.me_warmup_step_jitter.as_millis() as u64); - let delay_ms = pool.me_warmup_step_delay.as_millis() as u64 + jitter; - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); - } - } - } - } else { - for (dc, addrs) in &dc_addrs_bg { - for (ip, port) in addrs { - if pool.connection_count() >= pool_size { - break; - } - let addr = SocketAddr::new(*ip, *port); - if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); - } - } - if pool.connection_count() >= pool_size { - break; - } + let mut join_bg = tokio::task::JoinSet::new(); + for (dc, addrs) in dc_addrs_bg { + if addrs.len() <= 1 { + continue; } + let target_writers = pool.required_writers_for_dc_with_floor_mode(addrs.len(), false); + let pool_clone = Arc::clone(&pool); + let rng_clone_local = Arc::clone(&rng_clone); + join_bg.spawn(async move { + pool_clone + .connect_primary_for_dc( + dc, + addrs, + target_writers, + rng_clone_local, + connect_concurrency, + ) + .await + }); } + while join_bg.join_next().await.is_some() {} debug!( - target_pool_size = pool_size, current_pool_size = pool.connection_count(), - "Background ME reserve warmup finished" + "Background ME saturation warmup finished" ); }); @@ -140,62 +146,85 @@ impl MePool { self: Arc, dc: i32, mut addrs: Vec<(IpAddr, u16)>, + target_writers: usize, rng: Arc, + connect_concurrency: usize, ) -> bool { if addrs.is_empty() { return false; } + let 
target_writers = target_writers.max(1); addrs.shuffle(&mut rand::rng()); - if addrs.len() > 1 { - let concurrency = 2usize; + let endpoints: Vec = addrs + .iter() + .map(|(ip, port)| SocketAddr::new(*ip, *port)) + .collect(); + let endpoint_set: HashSet = endpoints.iter().copied().collect(); + + loop { + let alive = self.active_writer_count_for_endpoints(&endpoint_set).await; + if alive >= target_writers { + info!( + dc = %dc, + alive, + target_writers, + "ME connected" + ); + return true; + } + + let missing = target_writers.saturating_sub(alive).max(1); + let concurrency = connect_concurrency.max(1).min(missing); let mut join = tokio::task::JoinSet::new(); - let mut next_idx = 0usize; + for _ in 0..concurrency { + let pool = Arc::clone(&self); + let rng_clone = Arc::clone(&rng); + let endpoints_clone = endpoints.clone(); + join.spawn(async move { + pool.connect_endpoints_round_robin(&endpoints_clone, rng_clone.as_ref()) + .await + }); + } - while next_idx < addrs.len() || !join.is_empty() { - while next_idx < addrs.len() && join.len() < concurrency { - let (ip, port) = addrs[next_idx]; - next_idx += 1; - let addr = SocketAddr::new(ip, port); - let pool = Arc::clone(&self); - let rng_clone = Arc::clone(&rng); - join.spawn(async move { - (addr, pool.connect_one(addr, rng_clone.as_ref()).await) - }); - } - - let Some(res) = join.join_next().await else { - break; - }; + let mut progress = false; + while let Some(res) = join.join_next().await { match res { - Ok((addr, Ok(()))) => { - info!(%addr, dc = %dc, "ME connected"); - join.abort_all(); - while join.join_next().await.is_some() {} - return true; - } - Ok((addr, Err(e))) => { - warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"); + Ok(true) => { + progress = true; } + Ok(false) => {} Err(e) => { warn!(dc = %dc, error = %e, "ME connect task failed"); } } } - warn!(dc = %dc, "All ME servers for DC failed at init"); - return false; - } - for (ip, port) in addrs { - let addr = SocketAddr::new(ip, 
port); - match self.connect_one(addr, rng.as_ref()).await { - Ok(()) => { - info!(%addr, dc = %dc, "ME connected"); - return true; - } - Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"), + let alive_after = self.active_writer_count_for_endpoints(&endpoint_set).await; + if alive_after >= target_writers { + info!( + dc = %dc, + alive = alive_after, + target_writers, + "ME connected" + ); + return true; + } + if !progress { + warn!( + dc = %dc, + alive = alive_after, + target_writers, + "All ME servers for DC failed at init" + ); + return false; + } + + if self.me_warmup_stagger_enabled { + let jitter = rand::rng() + .random_range(0..=self.me_warmup_step_jitter.as_millis() as u64); + let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter; + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + } } } - warn!(dc = %dc, "All ME servers for DC failed at init"); - false } } diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 9dd3d07..17a418c 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -100,6 +100,68 @@ pub(crate) struct MeApiRuntimeSnapshot { } impl MePool { + pub(crate) async fn admission_ready_conditional_cast(&self) -> bool { + let mut endpoints_by_dc = BTreeMap::<i16, HashSet<SocketAddr>>::new(); + if self.decision.ipv4_me { + let map = self.proxy_map_v4.read().await.clone(); + for (dc, addrs) in map { + let abs_dc = dc.abs(); + if abs_dc == 0 { + continue; + } + let Ok(dc_idx) = i16::try_from(abs_dc) else { + continue; + }; + let entry = endpoints_by_dc.entry(dc_idx).or_default(); + for (ip, port) in addrs { + entry.insert(SocketAddr::new(ip, port)); + } + } + } + if self.decision.ipv6_me { + let map = self.proxy_map_v6.read().await.clone(); + for (dc, addrs) in map { + let abs_dc = dc.abs(); + if abs_dc == 0 { + continue; + } + let Ok(dc_idx) = i16::try_from(abs_dc) else { + continue; + }; + let entry = 
endpoints_by_dc.entry(dc_idx).or_default(); + for (ip, port) in addrs { + entry.insert(SocketAddr::new(ip, port)); + } + } + } + + if endpoints_by_dc.is_empty() { + return false; + } + + let writers = self.writers.read().await.clone(); + let mut live_writers_by_endpoint = HashMap::<SocketAddr, usize>::new(); + for writer in writers { + if writer.draining.load(Ordering::Relaxed) { + continue; + } + *live_writers_by_endpoint.entry(writer.addr).or_insert(0) += 1; + } + + for endpoints in endpoints_by_dc.values() { + let alive: usize = endpoints + .iter() + .map(|endpoint| live_writers_by_endpoint.get(endpoint).copied().unwrap_or(0)) + .sum(); + if alive == 0 { + return false; + } + } + + true + } + + #[allow(dead_code)] pub(crate) async fn admission_ready_full_floor(&self) -> bool { let mut endpoints_by_dc = BTreeMap::<i16, HashSet<SocketAddr>>::new(); if self.decision.ipv4_me {