From 349e9c9cdad84f0c0e129117d824209a83748cf5 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 19 Mar 2026 20:55:50 +0300 Subject: [PATCH] Arc-swap for ME Writer Snapshots Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/transport/middle_proxy/health.rs | 49 +++++---- .../middle_proxy/health_adversarial_tests.rs | 6 +- .../middle_proxy/health_integration_tests.rs | 10 +- .../middle_proxy/health_regression_tests.rs | 28 ++--- src/transport/middle_proxy/pool.rs | 15 +-- src/transport/middle_proxy/pool_config.rs | 4 +- src/transport/middle_proxy/pool_init.rs | 2 +- src/transport/middle_proxy/pool_refill.rs | 7 +- src/transport/middle_proxy/pool_reinit.rs | 44 ++++---- src/transport/middle_proxy/pool_status.rs | 22 ++-- src/transport/middle_proxy/pool_writer.rs | 78 ++++++++------ src/transport/middle_proxy/send.rs | 100 ++++++++++-------- 12 files changed, 199 insertions(+), 166 deletions(-) diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index ed69526..c0dd08a 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -189,14 +189,14 @@ pub(super) async fn reap_draining_writers( let drain_threshold = pool .me_pool_drain_threshold .load(std::sync::atomic::Ordering::Relaxed); - let writers = pool.writers.read().await.clone(); + let writers = pool.writers.load_full(); let activity = pool.registry.writer_activity_snapshot().await; let mut draining_writers = Vec::new(); let mut empty_writer_ids = Vec::::new(); let mut timeout_expired_writer_ids = Vec::::new(); let mut force_close_writer_ids = Vec::::new(); - for writer in writers { - if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { + for writer in writers.iter().cloned() { + if !writer.draining.load(std::sync::atomic::Ordering::Acquire) { continue; } if draining_writer_timeout_expired(pool, &writer, now_epoch_secs, drain_ttl_secs) { @@ -497,12 +497,13 @@ async fn check_family( let mut live_addr_counts = HashMap::<(i32, SocketAddr), usize>::new(); let mut live_writer_ids_by_addr = HashMap::<(i32, SocketAddr), Vec>::new(); - for writer in pool.writers.read().await.iter().filter(|w| { - !w.draining.load(std::sync::atomic::Ordering::Relaxed) + let writers_snapshot = pool.writers.load_full(); + for writer in writers_snapshot.iter().filter(|w| { + !w.draining.load(std::sync::atomic::Ordering::Acquire) }) { if !matches!( super::pool::WriterContour::from_u8( - writer.contour.load(std::sync::atomic::Ordering::Relaxed), + writer.contour.load(std::sync::atomic::Ordering::Acquire), ), super::pool::WriterContour::Active ) { @@ -1566,20 +1567,20 @@ async fn maybe_rotate_single_endpoint_shadow( let now_epoch_secs = MePool::now_epoch_secs(); // Collect zombie IDs under a short read-lock. - let zombie_ids: Vec = { - let ws = pool.writers.read().await; - ws.iter() - .filter(|w| w.draining.load(std::sync::atomic::Ordering::Relaxed)) - .filter(|w| { - let deadline = w - .drain_deadline_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - deadline != 0 - && now_epoch_secs.saturating_sub(deadline) > ZOMBIE_THRESHOLD_SECS - }) - .map(|w| w.id) - .collect() - }; + let zombie_ids: Vec = pool + .writers + .load_full() + .iter() + .filter(|w| w.draining.load(std::sync::atomic::Ordering::Acquire)) + .filter(|w| { + let deadline = w + .drain_deadline_epoch_secs + .load(std::sync::atomic::Ordering::Acquire); + deadline != 0 + && now_epoch_secs.saturating_sub(deadline) > ZOMBIE_THRESHOLD_SECS + }) + .map(|w| w.id) + .collect(); if zombie_ids.is_empty() { continue; @@ -1737,7 +1738,9 @@ mod tests { drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), allow_drain_fallback: Arc::new(AtomicBool::new(false)), }; - pool.writers.write().await.push(writer); + let mut writers = (*pool.writers.load_full()).clone(); + writers.push(writer); + pool.writers.store(Arc::new(writers)); pool.registry.register_writer(writer_id, tx).await; pool.conn_count.fetch_add(1, Ordering::Relaxed); assert!( @@ -1769,7 +1772,7 @@ mod tests { reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - let writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + let writer_ids: Vec = pool.writers.load_full().iter().map(|writer| writer.id).collect(); assert_eq!(writer_ids, vec![20, 30]); assert!(pool.registry.get_writer(conn_a).await.is_none()); assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); @@ -1788,7 +1791,7 @@ mod tests { reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - let writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + let writer_ids: Vec = pool.writers.load_full().iter().map(|writer| writer.id).collect(); assert_eq!(writer_ids, vec![10, 20, 30]); assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10); assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); diff --git a/src/transport/middle_proxy/health_adversarial_tests.rs b/src/transport/middle_proxy/health_adversarial_tests.rs index ae517b3..4005b41 100644 --- a/src/transport/middle_proxy/health_adversarial_tests.rs +++ b/src/transport/middle_proxy/health_adversarial_tests.rs @@ -147,7 +147,9 @@ async fn insert_draining_writer( allow_drain_fallback: Arc::new(AtomicBool::new(false)), }; - pool.writers.write().await.push(writer); + let mut writers = (*pool.writers.load_full()).clone(); + writers.push(writer); + pool.writers.store(Arc::new(writers)); pool.registry.register_writer(writer_id, tx).await; pool.conn_count.fetch_add(1, Ordering::Relaxed); @@ -174,7 +176,7 @@ async fn insert_draining_writer( } async fn writer_count(pool: &Arc) -> usize { - pool.writers.read().await.len() + pool.writers.load_full().len() } async fn sorted_writer_ids(pool: &Arc) -> Vec { diff --git a/src/transport/middle_proxy/health_integration_tests.rs b/src/transport/middle_proxy/health_integration_tests.rs index fbbffce..37c2ea1 100644 --- a/src/transport/middle_proxy/health_integration_tests.rs +++ b/src/transport/middle_proxy/health_integration_tests.rs @@ -144,7 +144,9 @@ async fn insert_draining_writer( drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)), allow_drain_fallback: Arc::new(AtomicBool::new(false)), }; - pool.writers.write().await.push(writer); + let mut writers = (*pool.writers.load_full()).clone(); + writers.push(writer); + pool.writers.store(Arc::new(writers)); pool.registry.register_writer(writer_id, tx).await; pool.conn_count.fetch_add(1, Ordering::Relaxed); for idx in 0..bound_clients { @@ -190,7 +192,7 @@ async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() { monitor.abort(); let _ = monitor.await; - assert!(pool.writers.read().await.is_empty()); + assert!(pool.writers.load_full().is_empty()); } #[tokio::test] @@ -206,7 +208,7 @@ async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() { monitor.abort(); let _ = monitor.await; - assert!(pool.writers.read().await.is_empty()); + assert!(pool.writers.load_full().is_empty()); } #[tokio::test] @@ -231,5 +233,5 @@ async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() { monitor.abort(); let _ = monitor.await; - assert!(pool.writers.read().await.is_empty()); + assert!(pool.writers.load_full().is_empty()); } diff --git a/src/transport/middle_proxy/health_regression_tests.rs b/src/transport/middle_proxy/health_regression_tests.rs index bcdaf2e..b6b0e8f 100644 --- a/src/transport/middle_proxy/health_regression_tests.rs +++ b/src/transport/middle_proxy/health_regression_tests.rs @@ -138,7 +138,9 @@ async fn insert_draining_writer( drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)), allow_drain_fallback: Arc::new(AtomicBool::new(false)), }; - pool.writers.write().await.push(writer); + let mut writers = (*pool.writers.load_full()).clone(); + writers.push(writer); + pool.writers.store(Arc::new(writers)); pool.registry.register_writer(writer_id, tx).await; pool.conn_count.fetch_add(1, Ordering::Relaxed); for idx in 0..bound_clients { @@ -256,7 +258,9 @@ async fn reap_draining_writers_does_not_block_on_stuck_writer_close_signal() { drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), allow_drain_fallback: Arc::new(AtomicBool::new(false)), }; - pool.writers.write().await.push(blocked_writer); + let mut writers = (*pool.writers.load_full()).clone(); + writers.push(blocked_writer); + pool.writers.store(Arc::new(writers)); pool.registry .register_writer(blocked_writer_id, blocked_tx) .await; @@ -357,7 +361,7 @@ async fn reap_draining_writers_limits_closes_per_health_tick() { reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - assert_eq!(pool.writers.read().await.len(), writer_total - close_budget); + assert_eq!(pool.writers.load_full().len(), writer_total - close_budget); } #[tokio::test] @@ -380,13 +384,13 @@ async fn reap_draining_writers_backlog_drains_across_ticks() { let mut soft_evict_next_allowed = HashMap::new(); for _ in 0..8 { - if pool.writers.read().await.is_empty() { + if pool.writers.load_full().is_empty() { break; } reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; } - assert!(pool.writers.read().await.is_empty()); + assert!(pool.writers.load_full().is_empty()); } #[tokio::test] @@ -411,12 +415,12 @@ async fn reap_draining_writers_threshold_backlog_converges_to_threshold() { for _ in 0..16 { reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - if pool.writers.read().await.len() <= threshold as usize { + if pool.writers.load_full().len() <= threshold as usize { break; } } - assert_eq!(pool.writers.read().await.len(), threshold as usize); + assert_eq!(pool.writers.load_full().len(), threshold as usize); } #[tokio::test] @@ -521,14 +525,14 @@ async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population .await; } reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); + assert!(warn_next_allowed.len() <= pool.writers.load_full().len()); let existing_writer_ids = current_writer_ids(&pool).await; for writer_id in existing_writer_ids.into_iter().take(4) { let _ = pool.remove_writer_and_close_clients(writer_id).await; } reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); + assert!(warn_next_allowed.len() <= pool.writers.load_full().len()); } } @@ -558,13 +562,13 @@ async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_stat for _ in 0..16 { reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - if pool.writers.read().await.len() <= 6 { + if pool.writers.load_full().len() <= 6 { break; } } - assert!(pool.writers.read().await.len() <= 6); - assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); + assert!(pool.writers.load_full().len() <= 6); + assert!(warn_next_allowed.len() <= pool.writers.load_full().len()); } #[tokio::test] diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 441d41d..5c71108 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -5,6 +5,7 @@ use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU8, AtomicU32, AtomicU64, A use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use tokio::sync::{Mutex, Notify, RwLock, mpsc}; +use arc_swap::ArcSwap; use tokio_util::sync::CancellationToken; use crate::config::{ @@ -82,7 +83,7 @@ pub struct SecretSnapshot { #[allow(dead_code)] pub struct MePool { pub(super) registry: Arc, - pub(super) writers: Arc>>, + pub(super) writers: Arc>>>, pub(super) rr: AtomicU64, pub(super) decision: NetworkDecision, pub(super) upstream: Option>, @@ -329,7 +330,7 @@ impl MePool { ); Arc::new(Self { registry, - writers: Arc::new(RwLock::new(Vec::new())), + writers: Arc::new(ArcSwap::from_pointee(Vec::new())), rr: AtomicU64::new(0), decision, upstream, @@ -512,7 +513,7 @@ impl MePool { } pub fn current_generation(&self) -> u64 { - self.active_generation.load(Ordering::Relaxed) + self.active_generation.load(Ordering::Acquire) } pub fn set_runtime_ready(&self, ready: bool) { @@ -728,7 +729,7 @@ impl MePool { MeSocksKdfPolicy::from_u8(self.me_socks_kdf_policy.load(Ordering::Relaxed)) } - pub(super) fn writers_arc(&self) -> Arc>> { + pub(super) fn writers_arc(&self) -> Arc>>> { self.writers.clone() } @@ -776,14 +777,14 @@ impl MePool { } pub(super) async fn non_draining_writer_counts_by_contour(&self) -> (usize, usize, usize) { - let ws = self.writers.read().await; + let ws = self.writers.load_full(); let mut active = 0usize; let mut warm = 0usize; for writer in ws.iter() { - if writer.draining.load(Ordering::Relaxed) { + if writer.draining.load(Ordering::Acquire) { continue; } - match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + match WriterContour::from_u8(writer.contour.load(Ordering::Acquire)) { WriterContour::Active => active = active.saturating_add(1), WriterContour::Warm => warm = warm.saturating_add(1), WriterContour::Draining => {} diff --git a/src/transport/middle_proxy/pool_config.rs b/src/transport/middle_proxy/pool_config.rs index 66752bf..3e0b4d7 100644 --- a/src/transport/middle_proxy/pool_config.rs +++ b/src/transport/middle_proxy/pool_config.rs @@ -108,8 +108,8 @@ impl MePool { } pub async fn reconnect_all(self: &Arc) { - let ws = self.writers.read().await.clone(); - for w in ws { + let ws = self.writers.load_full(); + for w in ws.iter().cloned() { if let Ok(()) = self .connect_one_for_dc(w.addr, w.writer_dc, self.rng.as_ref()) .await diff --git a/src/transport/middle_proxy/pool_init.rs b/src/transport/middle_proxy/pool_init.rs index 29a70c5..227b69b 100644 --- a/src/transport/middle_proxy/pool_init.rs +++ b/src/transport/middle_proxy/pool_init.rs @@ -132,7 +132,7 @@ impl MePool { } } - if self.writers.read().await.is_empty() { + if self.writers.load_full().is_empty() { return Err(ProxyError::Proxy("No ME connections".into())); } info!( diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs index fc916f4..d01551d 100644 --- a/src/transport/middle_proxy/pool_refill.rs +++ b/src/transport/middle_proxy/pool_refill.rs @@ -119,9 +119,9 @@ impl MePool { } if candidates.len() > 1 { let mut active_by_endpoint = HashMap::::new(); - let ws = self.writers.read().await; + let ws = self.writers.load_full(); for writer in ws.iter() { - if writer.draining.load(Ordering::Relaxed) { + if writer.draining.load(Ordering::Acquire) { continue; } if writer.writer_dc != dc { @@ -129,7 +129,7 @@ impl MePool { } if !matches!( super::pool::WriterContour::from_u8( - writer.contour.load(Ordering::Relaxed), + writer.contour.load(Ordering::Acquire), ), super::pool::WriterContour::Active ) { @@ -139,7 +139,6 @@ impl MePool { *active_by_endpoint.entry(writer.addr).or_insert(0) += 1; } } - drop(ws); candidates.sort_by_key(|addr| (active_by_endpoint.get(addr).copied().unwrap_or(0), *addr)); } let start = (self.rr.fetch_add(1, Ordering::Relaxed) as usize) % candidates.len(); diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs index 0d5c6f4..7bdaebb 100644 --- a/src/transport/middle_proxy/pool_reinit.rs +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -36,26 +36,26 @@ impl MePool { } fn clear_pending_hardswap_state(&self) { - self.pending_hardswap_generation.store(0, Ordering::Relaxed); + self.pending_hardswap_generation.store(0, Ordering::Release); self.pending_hardswap_started_at_epoch_secs - .store(0, Ordering::Relaxed); - self.pending_hardswap_map_hash.store(0, Ordering::Relaxed); - self.warm_generation.store(0, Ordering::Relaxed); + .store(0, Ordering::Release); + self.pending_hardswap_map_hash.store(0, Ordering::Release); + self.warm_generation.store(0, Ordering::Release); } async fn promote_warm_generation_to_active(&self, generation: u64) { - self.active_generation.store(generation, Ordering::Relaxed); - self.warm_generation.store(0, Ordering::Relaxed); + self.active_generation.store(generation, Ordering::Release); + self.warm_generation.store(0, Ordering::Release); - let ws = self.writers.read().await; + let ws = self.writers.load_full(); for writer in ws.iter() { - if writer.draining.load(Ordering::Relaxed) { + if writer.draining.load(Ordering::Acquire) { continue; } if writer.generation == generation { writer .contour - .store(WriterContour::Active.as_u8(), Ordering::Relaxed); + .store(WriterContour::Active.as_u8(), Ordering::Release); } } } @@ -177,9 +177,9 @@ impl MePool { dc: i32, endpoints: &HashSet, ) -> usize { - let ws = self.writers.read().await; + let ws = self.writers.load_full(); ws.iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| !w.draining.load(Ordering::Acquire)) .filter(|w| w.generation == generation) .filter(|w| w.writer_dc == dc) .filter(|w| endpoints.contains(&w.addr)) @@ -191,9 +191,9 @@ impl MePool { dc: i32, endpoints: &HashSet, ) -> usize { - let ws = self.writers.read().await; + let ws = self.writers.load_full(); ws.iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| !w.draining.load(Ordering::Acquire)) .filter(|w| w.writer_dc == dc) .filter(|w| endpoints.contains(&w.addr)) .count() @@ -358,12 +358,12 @@ impl MePool { } let next_generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; self.pending_hardswap_generation - .store(next_generation, Ordering::Relaxed); + .store(next_generation, Ordering::Release); self.pending_hardswap_started_at_epoch_secs - .store(now_epoch_secs, Ordering::Relaxed); + .store(now_epoch_secs, Ordering::Release); self.pending_hardswap_map_hash - .store(desired_map_hash, Ordering::Relaxed); - self.warm_generation.store(next_generation, Ordering::Relaxed); + .store(desired_map_hash, Ordering::Release); + self.warm_generation.store(next_generation, Ordering::Release); next_generation } } else { @@ -372,17 +372,17 @@ impl MePool { }; if hardswap { - self.warm_generation.store(generation, Ordering::Relaxed); + self.warm_generation.store(generation, Ordering::Release); self.warmup_generation_for_all_dcs(rng, generation, &desired_by_dc) .await; } else { self.reconcile_connections(rng).await; } - let writers = self.writers.read().await; + let writers = self.writers.load_full(); let active_writer_addrs: HashSet<(i32, SocketAddr)> = writers .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| !w.draining.load(Ordering::Acquire)) .map(|w| (w.writer_dc, w.addr)) .collect(); let min_ratio = Self::permille_to_ratio( @@ -405,7 +405,7 @@ impl MePool { if hardswap { let fresh_writer_addrs: HashSet<(i32, SocketAddr)> = writers .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| !w.draining.load(Ordering::Acquire)) .filter(|w| w.generation == generation) .map(|w| (w.writer_dc, w.addr)) .collect(); @@ -441,7 +441,7 @@ impl MePool { let stale_writer_ids: Vec = writers .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| !w.draining.load(Ordering::Acquire)) .filter(|w| { if hardswap { w.generation < generation diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 5fe45cb..b7f42fa 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -166,10 +166,10 @@ impl MePool { return false; } - let writers = self.writers.read().await.clone(); + let writers = self.writers.load_full(); let mut live_writers_by_dc = HashMap::::new(); - for writer in writers { - if writer.draining.load(Ordering::Relaxed) { + for writer in writers.iter() { + if writer.draining.load(Ordering::Acquire) { continue; } if let Ok(dc) = i16::try_from(writer.writer_dc) { @@ -203,10 +203,10 @@ impl MePool { return false; } - let writers = self.writers.read().await.clone(); + let writers = self.writers.load_full(); let mut live_writers_by_dc = HashMap::::new(); - for writer in writers { - if writer.draining.load(Ordering::Relaxed) { + for writer in writers.iter() { + if writer.draining.load(Ordering::Acquire) { continue; } if let Ok(dc) = i16::try_from(writer.writer_dc) { @@ -255,7 +255,7 @@ impl MePool { let idle_since = self.registry.writer_idle_since_snapshot().await; let activity = self.registry.writer_activity_snapshot().await; let rtt = self.rtt_stats.lock().await.clone(); - let writers = self.writers.read().await.clone(); + let writers = self.writers.load_full(); let mut live_writers_by_dc_endpoint = HashMap::<(i16, SocketAddr), usize>::new(); let mut live_writers_by_dc = HashMap::::new(); @@ -263,10 +263,10 @@ impl MePool { let mut dc_rtt_agg = HashMap::::new(); let mut writer_rows = Vec::::with_capacity(writers.len()); - for writer in writers { + for writer in writers.iter() { let endpoint = writer.addr; let dc = i16::try_from(writer.writer_dc).ok(); - let draining = writer.draining.load(Ordering::Relaxed); + let draining = writer.draining.load(Ordering::Acquire); let degraded = writer.degraded.load(Ordering::Relaxed); let matches_active_generation = writer.generation == active_generation; let in_desired_map = dc @@ -296,7 +296,7 @@ impl MePool { && drain_ttl_secs > 0 && drain_started_at_epoch_secs .is_some_and(|started| now_epoch_secs.saturating_sub(started) > drain_ttl_secs); - let state = match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + let state = match WriterContour::from_u8(writer.contour.load(Ordering::Acquire)) { WriterContour::Warm => "warm", WriterContour::Active => "active", WriterContour::Draining => "draining", @@ -501,7 +501,7 @@ impl MePool { } MeApiRuntimeSnapshot { - active_generation: self.active_generation.load(Ordering::Relaxed), + active_generation: self.active_generation.load(Ordering::Acquire), warm_generation: self.warm_generation.load(Ordering::Relaxed), pending_hardswap_generation: self.pending_hardswap_generation.load(Ordering::Relaxed), pending_hardswap_age_secs, diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 7d78b84..61915ac 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -34,10 +34,13 @@ fn is_me_peer_closed_error(error: &ProxyError) -> bool { impl MePool { pub(crate) async fn prune_closed_writers(self: &Arc) { - let closed_writer_ids: Vec = { - let ws = self.writers.read().await; - ws.iter().filter(|w| w.tx.is_closed()).map(|w| w.id).collect() - }; + let closed_writer_ids: Vec = self + .writers + .load_full() + .iter() + .filter(|w| w.tx.is_closed()) + .map(|w| w.id) + .collect(); if closed_writer_ids.is_empty() { return; } @@ -178,7 +181,9 @@ impl MePool { drain_deadline_epoch_secs: drain_deadline_epoch_secs.clone(), allow_drain_fallback: allow_drain_fallback.clone(), }; - self.writers.write().await.push(writer.clone()); + let mut new_writers = (*self.writers.load_full()).clone(); + new_writers.push(writer.clone()); + self.writers.store(Arc::new(new_writers)); self.registry.register_writer(writer_id, tx.clone()).await; self.registry.mark_writer_idle(writer_id).await; self.conn_count.fetch_add(1, Ordering::Relaxed); @@ -254,8 +259,9 @@ impl MePool { warn!(error = %e, "ME reader ended"); } } - let mut ws = writers_arc.write().await; + let mut ws = (*writers_arc.load_full()).clone(); ws.retain(|w| w.id != writer_id); + writers_arc.store(Arc::new(ws.clone())); info!(remaining = ws.len(), "Dead ME writer removed from pool"); }); @@ -503,26 +509,30 @@ impl MePool { let mut removed_dc: Option = None; let mut removed_uptime: Option = None; let mut trigger_refill = false; + if let Some(pos) = self + .writers + .load_full() + .iter() + .position(|w| w.id == writer_id) { - let mut ws = self.writers.write().await; - if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { - let w = ws.remove(pos); - let was_draining = w.draining.load(Ordering::Relaxed); - if was_draining { - self.stats.decrement_pool_drain_active(); - } - self.stats.increment_me_writer_removed_total(); - w.cancel.cancel(); - removed_addr = Some(w.addr); - removed_dc = Some(w.writer_dc); - removed_uptime = Some(w.created_at.elapsed()); - trigger_refill = !was_draining; - if trigger_refill { - self.stats.increment_me_writer_removed_unexpected_total(); - } - close_tx = Some(w.tx.clone()); - self.conn_count.fetch_sub(1, Ordering::Relaxed); + let mut ws = (*self.writers.load_full()).clone(); + let w = ws.remove(pos); + self.writers.store(Arc::new(ws)); + let was_draining = w.draining.load(Ordering::Acquire); + if was_draining { + self.stats.decrement_pool_drain_active(); } + self.stats.increment_me_writer_removed_total(); + w.cancel.cancel(); + removed_addr = Some(w.addr); + removed_dc = Some(w.writer_dc); + removed_uptime = Some(w.created_at.elapsed()); + trigger_refill = !was_draining; + if trigger_refill { + self.stats.increment_me_writer_removed_unexpected_total(); + } + close_tx = Some(w.tx.clone()); + self.conn_count.fetch_sub(1, Ordering::Relaxed); } // State invariant: // - writer is removed from `self.writers` (pool visibility), @@ -576,25 +586,27 @@ impl MePool { ) { let timeout = timeout.filter(|d| !d.is_zero()); let found = { - let mut ws = self.writers.write().await; + let current = self.writers.load_full(); + let mut ws = (*current).clone(); if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) { - let already_draining = w.draining.swap(true, Ordering::Relaxed); + let already_draining = w.draining.swap(true, Ordering::Acquire); w.allow_drain_fallback - .store(allow_drain_fallback, Ordering::Relaxed); + .store(allow_drain_fallback, Ordering::Release); let now_epoch_secs = Self::now_epoch_secs(); w.draining_started_at_epoch_secs - .store(now_epoch_secs, Ordering::Relaxed); + .store(now_epoch_secs, Ordering::Release); let drain_deadline_epoch_secs = timeout .map(|duration| now_epoch_secs.saturating_add(duration.as_secs())) .unwrap_or(0); w.drain_deadline_epoch_secs - .store(drain_deadline_epoch_secs, Ordering::Relaxed); + .store(drain_deadline_epoch_secs, Ordering::Release); if !already_draining { self.stats.increment_pool_drain_active(); } w.contour - .store(WriterContour::Draining.as_u8(), Ordering::Relaxed); - w.draining.store(true, Ordering::Relaxed); + .store(WriterContour::Draining.as_u8(), Ordering::Release); + w.draining.store(true, Ordering::Release); + self.writers.store(Arc::new(ws)); true } else { false @@ -620,10 +632,10 @@ impl MePool { } pub(super) fn writer_accepts_new_binding(&self, writer: &MeWriter) -> bool { - if !writer.draining.load(Ordering::Relaxed) { + if !writer.draining.load(Ordering::Acquire) { return true; } - if !writer.allow_drain_fallback.load(Ordering::Relaxed) { + if !writer.allow_drain_fallback.load(Ordering::Acquire) { return false; } diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 6791064..b0fc950 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -158,9 +158,8 @@ impl MePool { } let mut writers_snapshot = { - let ws = self.writers.read().await; + let ws = self.writers.load_full(); if ws.is_empty() { - drop(ws); match no_writer_mode { MeRouteNoWriterMode::AsyncRecoveryFailfast => { let deadline = *no_writer_deadline.get_or_insert_with(|| { @@ -200,19 +199,19 @@ impl MePool { } } } - if !self.writers.read().await.is_empty() { + if !self.writers.load_full().is_empty() { break; } } } - if !self.writers.read().await.is_empty() { + if !self.writers.load_full().is_empty() { continue; } let deadline = *no_writer_deadline .get_or_insert_with(|| Instant::now() + self.me_route_inline_recovery_wait); if !self.wait_for_writer_until(deadline).await { - if !self.writers.read().await.is_empty() { + if !self.writers.load_full().is_empty() { continue; } self.stats.increment_me_no_writer_failfast_total(); @@ -241,15 +240,15 @@ impl MePool { } } } - ws.clone() + ws }; let mut candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, routed_dc, false) + .candidate_indices_for_dc(writers_snapshot.as_ref(), routed_dc, false) .await; if candidate_indices.is_empty() { candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) + .candidate_indices_for_dc(writers_snapshot.as_ref(), routed_dc, true) .await; } if let Some(skip_writer_id) = skip_writer_id { @@ -304,15 +303,13 @@ impl MePool { } } tokio::time::sleep(Duration::from_millis(100 * emergency_attempts as u64)).await; - let ws2 = self.writers.read().await; - writers_snapshot = ws2.clone(); - drop(ws2); + writers_snapshot = self.writers.load_full(); candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, routed_dc, false) + .candidate_indices_for_dc(writers_snapshot.as_ref(), routed_dc, false) .await; if candidate_indices.is_empty() { candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) + .candidate_indices_for_dc(writers_snapshot.as_ref(), routed_dc, true) .await; } if candidate_indices.is_empty() { @@ -354,7 +351,7 @@ impl MePool { let ordered_candidate_indices = if pick_mode == MeWriterPickMode::P2c { self.p2c_ordered_candidate_indices( &candidate_indices, - &writers_snapshot, + writers_snapshot.as_ref(), &writer_idle_since, now_epoch_secs, start, @@ -427,17 +424,17 @@ impl MePool { } let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port()); let (payload, meta) = build_routed_payload(effective_our_addr); + if !self.registry.bind_writer(conn_id, w.id, meta.clone()).await { + debug!( + conn_id, + writer_id = w.id, + "ME writer disappeared before bind commit, retrying" + ); + continue; + } match w.tx.try_send(WriterCommand::Data(payload.clone())) { Ok(()) => { self.stats.increment_me_writer_pick_success_try_total(pick_mode); - if !self.registry.bind_writer(conn_id, w.id, meta).await { - debug!( - conn_id, - writer_id = w.id, - "ME writer disappeared before bind commit, retrying" - ); - continue; - } if w.generation < self.current_generation() { self.stats.increment_pool_stale_pick_total(); debug!( @@ -451,11 +448,19 @@ impl MePool { return Ok(()); } Err(TrySendError::Full(_)) => { + let _ = self + .registry + .evict_bound_conn_if_writer(conn_id, w.id) + .await; if fallback_blocking_idx.is_none() { fallback_blocking_idx = Some(idx); } } Err(TrySendError::Closed(_)) => { + let _ = self + .registry + .evict_bound_conn_if_writer(conn_id, w.id) + .await; self.stats.increment_me_writer_pick_closed_total(pick_mode); warn!(writer_id = w.id, "ME writer channel closed"); self.remove_writer_and_close_clients(w.id).await; @@ -477,6 +482,14 @@ impl MePool { self.stats.increment_me_writer_pick_blocking_fallback_total(); let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port()); let (payload, meta) = build_routed_payload(effective_our_addr); + if !self.registry.bind_writer(conn_id, w.id, meta.clone()).await { + debug!( + conn_id, + writer_id = w.id, + "ME writer disappeared before fallback bind commit, retrying" + ); + continue; + } match send_writer_command_with_timeout( &w.tx, WriterCommand::Data(payload.clone()), @@ -487,25 +500,25 @@ impl MePool { Ok(()) => { self.stats .increment_me_writer_pick_success_fallback_total(pick_mode); - if !self.registry.bind_writer(conn_id, w.id, meta).await { - debug!( - conn_id, - writer_id = w.id, - "ME writer disappeared before fallback bind commit, retrying" - ); - continue; - } if w.generation < self.current_generation() { self.stats.increment_pool_stale_pick_total(); } return Ok(()); } Err(TimedSendError::Closed(_)) => { + let _ = self + .registry + .evict_bound_conn_if_writer(conn_id, w.id) + .await; self.stats.increment_me_writer_pick_closed_total(pick_mode); warn!(writer_id = w.id, "ME writer channel closed (blocking)"); self.remove_writer_and_close_clients(w.id).await; } Err(TimedSendError::Timeout(_)) => { + let _ = self + .registry + .evict_bound_conn_if_writer(conn_id, w.id) + .await; self.stats.increment_me_writer_pick_full_total(pick_mode); debug!( conn_id, @@ -520,18 +533,18 @@ impl MePool { async fn wait_for_writer_until(&self, deadline: Instant) -> bool { let waiter = self.writer_available.notified(); - if !self.writers.read().await.is_empty() { + if !self.writers.load_full().is_empty() { return true; } let now = Instant::now(); if now >= deadline { - return !self.writers.read().await.is_empty(); + return !self.writers.load_full().is_empty(); } let timeout = deadline.saturating_duration_since(now); if tokio::time::timeout(timeout, waiter).await.is_ok() { return true; } - !self.writers.read().await.is_empty() + !self.writers.load_full().is_empty() } async fn wait_for_candidate_until(&self, routed_dc: i32, deadline: Instant) -> bool { @@ -560,19 +573,16 @@ impl MePool { } async fn has_candidate_for_target_dc(&self, routed_dc: i32) -> bool { - let writers_snapshot = { - let ws = self.writers.read().await; - if ws.is_empty() { - return false; - } - ws.clone() - }; + let writers_snapshot = self.writers.load_full(); + if writers_snapshot.is_empty() { + return false; + } let mut candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, routed_dc, false) + .candidate_indices_for_dc(writers_snapshot.as_ref(), routed_dc, false) .await; if candidate_indices.is_empty() { candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) + .candidate_indices_for_dc(writers_snapshot.as_ref(), routed_dc, true) .await; } !candidate_indices.is_empty() @@ -736,7 +746,7 @@ impl MePool { return false; } - match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + match WriterContour::from_u8(writer.contour.load(Ordering::Acquire)) { WriterContour::Active => true, WriterContour::Warm => include_warm, WriterContour::Draining => true, @@ -744,7 +754,7 @@ impl MePool { } fn writer_contour_rank_for_selection(&self, writer: &super::pool::MeWriter) -> usize { - match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + match WriterContour::from_u8(writer.contour.load(Ordering::Acquire)) { WriterContour::Active => 0, WriterContour::Warm => 1, WriterContour::Draining => 2, @@ -776,7 +786,7 @@ impl MePool { idle_since_by_writer: &HashMap, now_epoch_secs: u64, ) -> u64 { - let contour_penalty = match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + let contour_penalty = match WriterContour::from_u8(writer.contour.load(Ordering::Acquire)) { WriterContour::Active => 0, WriterContour::Warm => PICK_PENALTY_WARM, WriterContour::Draining => PICK_PENALTY_DRAINING,