diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index e5f4260..a2e107d 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -115,59 +115,109 @@ async fn reap_draining_writers( pool: &Arc, warn_next_allowed: &mut HashMap, ) { + if pool.draining_active_runtime() == 0 { + return; + } + let now_epoch_secs = MePool::now_epoch_secs(); let now = Instant::now(); let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed); let drain_threshold = pool .me_pool_drain_threshold .load(std::sync::atomic::Ordering::Relaxed); - let writers = pool.writers.read().await.clone(); - let mut draining_writers = Vec::new(); - for writer in writers { - if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { - continue; + let mut draining_writers = { + let writers = pool.writers.read().await; + let mut draining_writers = Vec::::new(); + for writer in writers.iter() { + if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { + continue; + } + draining_writers.push(DrainingWriterSnapshot { + id: writer.id, + writer_dc: writer.writer_dc, + addr: writer.addr, + generation: writer.generation, + created_at: writer.created_at, + draining_started_at_epoch_secs: writer + .draining_started_at_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + drain_deadline_epoch_secs: writer + .drain_deadline_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback: writer + .allow_drain_fallback + .load(std::sync::atomic::Ordering::Relaxed), + }); } - let is_empty = pool.registry.is_writer_empty(writer.id).await; - if is_empty { - pool.remove_writer_and_close_clients(writer.id).await; - continue; - } - draining_writers.push(writer); + draining_writers + }; + + if draining_writers.is_empty() { + return; } - if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { - draining_writers.sort_by(|left, right| { - let left_started = left - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - let right_started = right - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - left_started - .cmp(&right_started) - .then_with(|| left.created_at.cmp(&right.created_at)) - .then_with(|| left.id.cmp(&right.id)) - }); - let overflow = draining_writers.len().saturating_sub(drain_threshold as usize); - warn!( - draining_writers = draining_writers.len(), - me_pool_drain_threshold = drain_threshold, - removing_writers = overflow, - "ME draining writer threshold exceeded, force-closing oldest draining writers" - ); - for writer in draining_writers.drain(..overflow) { - pool.stats.increment_pool_force_close_total(); + let draining_ids: Vec = draining_writers.iter().map(|writer| writer.id).collect(); + let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await; + let mut non_empty_draining_writers = + Vec::::with_capacity(draining_writers.len()); + for writer in draining_writers.drain(..) { + if non_empty_writer_ids.contains(&writer.id) { + non_empty_draining_writers.push(writer); + } else { pool.remove_writer_and_close_clients(writer.id).await; } } + draining_writers = non_empty_draining_writers; + if draining_writers.is_empty() { + return; + } + + let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { + draining_writers.len().saturating_sub(drain_threshold as usize) + } else { + 0 + }; + let has_deadline_expired = draining_writers.iter().any(|writer| { + writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs + }); + let can_drop_with_replacement = if overflow > 0 || has_deadline_expired { + pool.has_non_draining_writer_per_desired_dc_group().await + } else { + false + }; + + if overflow > 0 { + if can_drop_with_replacement { + draining_writers.sort_by(|left, right| { + left.draining_started_at_epoch_secs + .cmp(&right.draining_started_at_epoch_secs) + .then_with(|| left.created_at.cmp(&right.created_at)) + .then_with(|| left.id.cmp(&right.id)) + }); + warn!( + draining_writers = draining_writers.len(), + me_pool_drain_threshold = drain_threshold, + removing_writers = overflow, + "ME draining writer threshold exceeded, force-closing oldest draining writers" + ); + for writer in draining_writers.drain(..overflow) { + pool.stats.increment_pool_force_close_total(); + pool.remove_writer_and_close_clients(writer.id).await; + } + } else { + warn!( + draining_writers = draining_writers.len(), + me_pool_drain_threshold = drain_threshold, + overflow, + "ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers" + ); + } + } for writer in draining_writers { - let drain_started_at_epoch_secs = writer - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); if drain_ttl_secs > 0 - && drain_started_at_epoch_secs != 0 - && now_epoch_secs.saturating_sub(drain_started_at_epoch_secs) > drain_ttl_secs + && writer.draining_started_at_epoch_secs != 0 + && now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs && should_emit_writer_warn( warn_next_allowed, writer.id, @@ -182,21 +232,45 @@ async fn reap_draining_writers( generation = writer.generation, drain_ttl_secs, force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed), - allow_drain_fallback = writer.allow_drain_fallback.load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback = writer.allow_drain_fallback, "ME draining writer remains non-empty past drain TTL" ); } - let deadline_epoch_secs = writer - .drain_deadline_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs { - warn!(writer_id = writer.id, "Drain timeout, force-closing"); - pool.stats.increment_pool_force_close_total(); - pool.remove_writer_and_close_clients(writer.id).await; + if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs + { + if can_drop_with_replacement { + warn!(writer_id = writer.id, "Drain timeout, force-closing"); + pool.stats.increment_pool_force_close_total(); + pool.remove_writer_and_close_clients(writer.id).await; + } else if should_emit_writer_warn( + warn_next_allowed, + writer.id, + now, + pool.warn_rate_limit_duration(), + ) { + warn!( + writer_id = writer.id, + writer_dc = writer.writer_dc, + endpoint = %writer.addr, + "Drain timeout reached, but replacement coverage is incomplete; keeping draining writer" + ); + } } } } +#[derive(Debug, Clone)] +struct DrainingWriterSnapshot { + id: u64, + writer_dc: i32, + addr: SocketAddr, + generation: u64, + created_at: Instant, + draining_started_at_epoch_secs: u64, + drain_deadline_epoch_secs: u64, + allow_drain_fallback: bool, +} + fn should_emit_writer_warn( next_allowed: &mut HashMap, writer_id: u64, @@ -1330,6 +1404,15 @@ mod tests { me_pool_drain_threshold, ..GeneralConfig::default() }; + let mut proxy_map_v4 = HashMap::new(); + proxy_map_v4.insert( + 2, + vec![(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 443)], + ); + let decision = NetworkDecision { + ipv4_me: true, + ..NetworkDecision::default() + }; MePool::new( None, vec![1u8; 32], @@ -1341,10 +1424,10 @@ mod tests { None, 12, 1200, - HashMap::new(), + proxy_map_v4, HashMap::new(), None, - NetworkDecision::default(), + decision, None, Arc::new(SecureRandom::new()), Arc::new(Stats::default()), @@ -1438,6 +1521,7 @@ mod tests { pool.writers.write().await.push(writer); pool.registry.register_writer(writer_id, tx).await; pool.conn_count.fetch_add(1, Ordering::Relaxed); + pool.increment_draining_active_runtime(); assert!( pool.registry .bind_writer( @@ -1455,8 +1539,56 @@ mod tests { conn_id } + async fn insert_live_writer(pool: &Arc, writer_id: u64, writer_dc: i32) { + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (writer_id as u8).saturating_add(1))), + 4000 + writer_id as u16, + ), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc, + generation: 2, + contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())), + created_at: Instant::now(), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(false)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + } + #[tokio::test] async fn reap_draining_writers_force_closes_oldest_over_threshold() { + let pool = make_pool(2).await; + insert_live_writer(&pool, 1, 2).await; + assert!(pool.has_non_draining_writer_per_desired_dc_group().await); + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; + let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await; + let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + writer_ids.sort_unstable(); + assert_eq!(writer_ids, vec![1, 20, 30]); + assert!(pool.registry.get_writer(conn_a).await.is_none()); + assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); + assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30); + } + + #[tokio::test] + async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() { let pool = make_pool(2).await; let now_epoch_secs = MePool::now_epoch_secs(); let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; @@ -1466,9 +1598,10 @@ mod tests { reap_draining_writers(&pool, &mut warn_next_allowed).await; - let writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); - assert_eq!(writer_ids, vec![20, 30]); - assert!(pool.registry.get_writer(conn_a).await.is_none()); + let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + writer_ids.sort_unstable(); + assert_eq!(writer_ids, vec![10, 20, 30]); + assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10); assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30); } diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 2a65160..56f3fbf 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -160,6 +160,7 @@ pub struct MePool { pub(super) refill_inflight: Arc>>, pub(super) refill_inflight_dc: Arc>>, pub(super) conn_count: AtomicUsize, + pub(super) draining_active_runtime: AtomicU64, pub(super) stats: Arc, pub(super) generation: AtomicU64, pub(super) active_generation: AtomicU64, @@ -438,6 +439,7 @@ impl MePool { refill_inflight: Arc::new(Mutex::new(HashSet::new())), refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())), conn_count: AtomicUsize::new(0), + draining_active_runtime: AtomicU64::new(0), generation: AtomicU64::new(1), active_generation: AtomicU64::new(1), warm_generation: AtomicU64::new(0), @@ -690,6 +692,32 @@ impl MePool { } } + pub(super) fn draining_active_runtime(&self) -> u64 { + self.draining_active_runtime.load(Ordering::Relaxed) + } + + pub(super) fn increment_draining_active_runtime(&self) { + self.draining_active_runtime.fetch_add(1, Ordering::Relaxed); + } + + pub(super) fn decrement_draining_active_runtime(&self) { + let mut current = self.draining_active_runtime.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match self.draining_active_runtime.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + pub(super) async fn key_selector(&self) -> u32 { self.proxy_secret.read().await.key_selector } diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs index 3d9d679..3cfc834 100644 --- a/src/transport/middle_proxy/pool_reinit.rs +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -141,6 +141,38 @@ impl MePool { out } + pub(super) async fn has_non_draining_writer_per_desired_dc_group(&self) -> bool { + let desired_by_dc = self.desired_dc_endpoints().await; + let required_dcs: HashSet = desired_by_dc + .iter() + .filter_map(|(dc, endpoints)| { + if endpoints.is_empty() { + None + } else { + Some(*dc) + } + }) + .collect(); + if required_dcs.is_empty() { + return true; + } + + let ws = self.writers.read().await; + let mut covered_dcs = HashSet::::with_capacity(required_dcs.len()); + for writer in ws.iter() { + if writer.draining.load(Ordering::Relaxed) { + continue; + } + if required_dcs.contains(&writer.writer_dc) { + covered_dcs.insert(writer.writer_dc); + if covered_dcs.len() == required_dcs.len() { + return true; + } + } + } + false + } + fn hardswap_warmup_connect_delay_ms(&self) -> u64 { let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed); let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed); @@ -475,12 +507,30 @@ impl MePool { coverage_ratio = format_args!("{coverage_ratio:.3}"), min_ratio = format_args!("{min_ratio:.3}"), drain_timeout_secs, - "ME reinit cycle covered; draining stale writers" + "ME reinit cycle covered; processing stale writers" ); self.stats.increment_pool_swap_total(); + let can_drop_with_replacement = self + .has_non_draining_writer_per_desired_dc_group() + .await; + if can_drop_with_replacement { + info!( + stale_writers = stale_writer_ids.len(), + "ME reinit stale writers: replacement coverage ready, force-closing clients for fast rebind" + ); + } else { + warn!( + stale_writers = stale_writer_ids.len(), + "ME reinit stale writers: replacement coverage incomplete, keeping draining fallback" + ); + } for writer_id in stale_writer_ids { self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap) .await; + if can_drop_with_replacement { + self.stats.increment_pool_force_close_total(); + self.remove_writer_and_close_clients(writer_id).await; + } } if hardswap { self.clear_pending_hardswap_state(); diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 8ce3de3..7490a98 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -514,6 +514,7 @@ impl MePool { let was_draining = w.draining.load(Ordering::Relaxed); if was_draining { self.stats.decrement_pool_drain_active(); + self.decrement_draining_active_runtime(); } self.stats.increment_me_writer_removed_total(); w.cancel.cancel(); @@ -572,6 +573,7 @@ impl MePool { .store(drain_deadline_epoch_secs, Ordering::Relaxed); if !already_draining { self.stats.increment_pool_drain_active(); + self.increment_draining_active_runtime(); } w.contour .store(WriterContour::Draining.as_u8(), Ordering::Relaxed); diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index cc3028b..cbe1d9a 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -436,6 +436,19 @@ impl ConnRegistry { .map(|s| s.is_empty()) .unwrap_or(true) } + + pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet { + let inner = self.inner.read().await; + let mut out = HashSet::::with_capacity(writer_ids.len()); + for writer_id in writer_ids { + if let Some(conns) = inner.conns_for_writer.get(writer_id) + && !conns.is_empty() + { + out.insert(*writer_id); + } + } + out + } } #[cfg(test)] @@ -634,4 +647,35 @@ mod tests { ); assert!(registry.get_writer(conn_id).await.is_none()); } + + #[tokio::test] + async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() { + let registry = ConnRegistry::new(); + let (conn_id, _rx) = registry.register().await; + let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); + let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); + registry.register_writer(10, writer_tx_a).await; + registry.register_writer(20, writer_tx_b).await; + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + assert!( + registry + .bind_writer( + conn_id, + 10, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + + let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await; + assert!(non_empty.contains(&10)); + assert!(!non_empty.contains(&20)); + assert!(!non_empty.contains(&30)); + } }