diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs
index 3259597..ca32e6f 100644
--- a/src/proxy/middle_relay.rs
+++ b/src/proxy/middle_relay.rs
@@ -4,7 +4,7 @@
 use std::collections::{BTreeSet, HashMap};
 use std::future::Future;
 use std::hash::{BuildHasher, Hash};
 use std::net::{IpAddr, SocketAddr};
-use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::{Arc, Mutex, OnceLock};
 use std::time::{Duration, Instant};
@@ -36,7 +36,6 @@ enum C2MeCommand {
 const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
 const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
-const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
 const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000);
 const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
 const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
@@ -57,12 +56,18 @@ const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
 const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
 const QUOTA_RESERVE_SPIN_RETRIES: usize = 32;
 static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
+static DESYNC_DEDUP_PREVIOUS: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
 static DESYNC_HASHER: OnceLock = OnceLock::new();
 static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
-static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
+static DESYNC_DEDUP_ROTATION_STATE: OnceLock<Mutex<DesyncDedupRotationState>> = OnceLock::new();
 static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new();
 static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
 
+#[derive(Default)]
+struct DesyncDedupRotationState {
+    current_started_at: Option<Instant>,
+}
+
 struct RelayForensicsState {
     trace_id: u64,
     conn_id: u64,
@@ -312,64 +317,76 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
         return true;
     }
 
-    let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
-    let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
-    let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false));
-    if saturated_before {
-        ever_saturated.store(true, Ordering::Relaxed);
-    }
+    let dedup_current = DESYNC_DEDUP.get_or_init(DashMap::new);
+    let dedup_previous = DESYNC_DEDUP_PREVIOUS.get_or_init(DashMap::new);
+    let rotation_state = DESYNC_DEDUP_ROTATION_STATE
+        .get_or_init(|| Mutex::new(DesyncDedupRotationState::default()));
 
-    if let Some(mut seen_at) = dedup.get_mut(&key) {
-        if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
-            *seen_at = now;
-            return true;
+    let mut state = match rotation_state.lock() {
+        Ok(guard) => guard,
+        Err(poisoned) => {
+            let mut guard = poisoned.into_inner();
+            *guard = DesyncDedupRotationState::default();
+            rotation_state.clear_poison();
+            guard
         }
-        return false;
-    }
-
-    if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
-        let mut stale_keys = Vec::new();
-        let mut oldest_candidate: Option<(u64, Instant)> = None;
-        for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
-            let key = *entry.key();
-            let seen_at = *entry.value();
-
-            match oldest_candidate {
-                Some((_, oldest_seen)) if seen_at >= oldest_seen => {}
-                _ => oldest_candidate = Some((key, seen_at)),
-            }
-
-            if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW {
-                stale_keys.push(*entry.key());
-            }
-        }
-        for stale_key in stale_keys {
-            dedup.remove(&stale_key);
-        }
-        if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
-            let Some((evict_key, _)) = oldest_candidate else {
-                return false;
-            };
-            dedup.remove(&evict_key);
-            dedup.insert(key, now);
-            return should_emit_full_desync_full_cache(now);
-        }
-    }
-
-    dedup.insert(key, now);
-    let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
-    // Preserve the first sequential insert that reaches capacity as a normal
-    // emit, while still gating concurrent newcomer churn after the cache has
-    // ever been observed at saturation.
-    let was_ever_saturated = if saturated_after {
-        ever_saturated.swap(true, Ordering::Relaxed)
-    } else {
-        ever_saturated.load(Ordering::Relaxed)
     };
-    if saturated_before || (saturated_after && was_ever_saturated) {
+
+    let rotate_now = match state.current_started_at {
+        Some(current_started_at) => match now.checked_duration_since(current_started_at) {
+            Some(elapsed) => elapsed >= DESYNC_DEDUP_WINDOW,
+            None => true,
+        },
+        None => true,
+    };
+    if rotate_now {
+        dedup_previous.clear();
+        for entry in dedup_current.iter() {
+            dedup_previous.insert(*entry.key(), *entry.value());
+        }
+        dedup_current.clear();
+        state.current_started_at = Some(now);
+    }
+
+    if let Some(seen_at) = dedup_current.get(&key).map(|entry| *entry.value()) {
+        let within_window = match now.checked_duration_since(seen_at) {
+            Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW,
+            None => true,
+        };
+        if within_window {
+            return false;
+        }
+        dedup_current.insert(key, now);
+        return true;
+    }
+
+    if let Some(seen_at) = dedup_previous.get(&key).map(|entry| *entry.value()) {
+        let within_window = match now.checked_duration_since(seen_at) {
+            Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW,
+            None => true,
+        };
+        if within_window {
+            // Keep the original timestamp when promoting from previous bucket,
+            // so dedup expiry remains tied to first-seen time.
+            dedup_current.insert(key, seen_at);
+            return false;
+        }
+        dedup_previous.remove(&key);
+    }
+
+    if dedup_current.len() >= DESYNC_DEDUP_MAX_ENTRIES {
+        // Bounded eviction path: rotate buckets instead of scanning/evicting
+        // arbitrary entries from a saturated single map.
+        dedup_previous.clear();
+        for entry in dedup_current.iter() {
+            dedup_previous.insert(*entry.key(), *entry.value());
+        }
+        dedup_current.clear();
+        state.current_started_at = Some(now);
+        dedup_current.insert(key, now);
         should_emit_full_desync_full_cache(now)
     } else {
+        dedup_current.insert(key, now);
         true
     }
 }
@@ -405,8 +422,20 @@ fn clear_desync_dedup_for_testing() {
     if let Some(dedup) = DESYNC_DEDUP.get() {
         dedup.clear();
     }
-    if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() {
-        ever_saturated.store(false, Ordering::Relaxed);
+    if let Some(dedup_previous) = DESYNC_DEDUP_PREVIOUS.get() {
+        dedup_previous.clear();
+    }
+    if let Some(rotation_state) = DESYNC_DEDUP_ROTATION_STATE.get() {
+        match rotation_state.lock() {
+            Ok(mut guard) => {
+                *guard = DesyncDedupRotationState::default();
+            }
+            Err(poisoned) => {
+                let mut guard = poisoned.into_inner();
+                *guard = DesyncDedupRotationState::default();
+                rotation_state.clear_poison();
+            }
+        }
     }
     if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() {
         match last_emit_at.lock() {
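The middle_relay.rs change above replaces the old scan-and-evict pruning with two generations of dedup entries: a current bucket that accepts new keys and a previous bucket that holds the last rotated generation, so memory stays bounded at roughly two windows' worth of keys and the saturation path never scans for a victim. A minimal single-threaded sketch of the same idea, using plain HashMaps and illustrative names (RotatingDedup, should_emit) rather than the module's DashMap statics:

use std::collections::HashMap;
use std::time::{Duration, Instant};

const WINDOW: Duration = Duration::from_secs(60);
const MAX_ENTRIES: usize = 4; // tiny cap so the size-based rotation is easy to exercise

/// Two-generation dedup cache: `current` takes new keys, `previous` keeps the
/// generation that was live before the last rotation.
struct RotatingDedup {
    current: HashMap<u64, Instant>,
    previous: HashMap<u64, Instant>,
    current_started_at: Option<Instant>,
}

impl RotatingDedup {
    fn new() -> Self {
        Self {
            current: HashMap::new(),
            previous: HashMap::new(),
            current_started_at: None,
        }
    }

    fn rotate(&mut self, now: Instant) {
        // The old `previous` generation is dropped wholesale; `current` becomes
        // the new `previous` and an empty map takes its place.
        self.previous = std::mem::take(&mut self.current);
        self.current_started_at = Some(now);
    }

    /// Returns true when the caller should emit (key not seen within WINDOW).
    fn should_emit(&mut self, key: u64, now: Instant) -> bool {
        // Time-based rotation: once the current generation is a full window old,
        // everything still sitting in `previous` is at least one window stale.
        let rotate_now = self
            .current_started_at
            .map_or(true, |started| now.duration_since(started) >= WINDOW);
        if rotate_now {
            self.rotate(now);
        }

        if let Some(seen_at) = self.current.get(&key).copied() {
            if now.duration_since(seen_at) < WINDOW {
                return false;
            }
            self.current.insert(key, now);
            return true;
        }

        if let Some(seen_at) = self.previous.get(&key).copied() {
            if now.duration_since(seen_at) < WINDOW {
                // Promote with the original timestamp so expiry stays tied to
                // the first-seen time.
                self.current.insert(key, seen_at);
                return false;
            }
            self.previous.remove(&key);
        }

        // Size-based rotation bounds memory without scanning for victims.
        if self.current.len() >= MAX_ENTRIES {
            // The real path additionally rate-limits emits when this happens
            // (should_emit_full_desync_full_cache); the sketch just emits.
            self.rotate(now);
        }
        self.current.insert(key, now);
        true
    }
}

fn main() {
    let mut dedup = RotatingDedup::new();
    let now = Instant::now();
    assert!(dedup.should_emit(1, now)); // first sighting emits
    assert!(!dedup.should_emit(1, now)); // duplicate inside the window is suppressed
    for key in 2u64..=5 {
        dedup.should_emit(key, now); // fill past MAX_ENTRIES to force a rotation
    }
    assert!(!dedup.should_emit(1, now)); // still deduplicated via the previous bucket
}

The promote-with-original-timestamp branch mirrors the comment in the diff: a key found in the previous bucket keeps its first-seen time, so it still expires one full window after it was first recorded rather than being refreshed by every duplicate.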
diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs
index 5faa76d..07d4d19 100644
--- a/src/transport/middle_proxy/pool.rs
+++ b/src/transport/middle_proxy/pool.rs
@@ -2,6 +2,7 @@
 use std::collections::{HashMap, HashSet};
 use std::net::{IpAddr, Ipv6Addr, SocketAddr};
+use std::ops::{Deref, DerefMut};
 use std::sync::Arc;
 use std::sync::atomic::{
     AtomicBool, AtomicI32, AtomicU8, AtomicU32, AtomicU64, AtomicUsize, Ordering,
 };
@@ -56,6 +57,87 @@ pub struct MeWriter {
     pub allow_drain_fallback: Arc,
 }
 
+pub(super) struct WritersState {
+    // HARD INVARIANT:
+    // All writers.store() calls MUST be guarded by writers_write_guard.
+    writers: ArcSwap<Vec<MeWriter>>,
+    writers_write_guard: Mutex<()>,
+}
+
+impl WritersState {
+    pub(super) fn new() -> Self {
+        Self {
+            writers: ArcSwap::from_pointee(Vec::new()),
+            writers_write_guard: Mutex::new(()),
+        }
+    }
+
+    pub(super) fn snapshot(&self) -> Arc<Vec<MeWriter>> {
+        self.writers.load_full()
+    }
+
+    pub(super) async fn read(&self) -> Arc<Vec<MeWriter>> {
+        self.snapshot()
+    }
+
+    pub(super) async fn write(&self) -> WritersWriteGuard<'_> {
+        let guard = self.writers_write_guard.lock().await;
+        let writers = (*self.writers.load_full()).clone();
+        WritersWriteGuard {
+            state: self,
+            _guard: guard,
+            writers,
+        }
+    }
+
+    pub(super) async fn update<F, R>(&self, f: F) -> R
+    where
+        F: FnOnce(&mut Vec<MeWriter>) -> R,
+    {
+        let mut guard = self.write().await;
+        f(&mut guard)
+    }
+
+    fn debug_assert_store_guarded(&self) {
+        debug_assert!(
+            self.writers_write_guard.try_lock().is_err(),
+            "HARD INVARIANT violated: writers.store() without writers_write_guard"
+        );
+    }
+
+    fn store_guarded(&self, writers: Vec<MeWriter>) {
+        self.debug_assert_store_guarded();
+        self.writers.store(Arc::new(writers));
+    }
+}
+
+pub(super) struct WritersWriteGuard<'a> {
+    state: &'a WritersState,
+    _guard: tokio::sync::MutexGuard<'a, ()>,
+    writers: Vec<MeWriter>,
+}
+
+impl Deref for WritersWriteGuard<'_> {
+    type Target = Vec<MeWriter>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.writers
+    }
+}
+
+impl DerefMut for WritersWriteGuard<'_> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.writers
+    }
+}
+
+impl Drop for WritersWriteGuard<'_> {
+    fn drop(&mut self) {
+        let writers = std::mem::take(&mut self.writers);
+        self.state.store_guarded(writers);
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 #[repr(u8)]
 pub(super) enum WriterContour {
@@ -178,7 +260,7 @@ pub struct SecretSnapshot {
 #[allow(dead_code)]
 pub struct MePool {
     pub(super) registry: Arc,
-    pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
+    pub(super) writers: Arc<WritersState>,
     pub(super) rr: AtomicU64,
     pub(super) decision: NetworkDecision,
     pub(super) upstream: Option>,
@@ -307,7 +389,7 @@ pub struct MePool {
     pub(super) me_last_drain_gate_updated_at_epoch_secs: AtomicU64,
     pub(super) runtime_ready: AtomicBool,
     pool_size: usize,
-    pub(super) preferred_endpoints_by_dc: Arc>>>,
+    pub(super) preferred_endpoints_by_dc: ArcSwap>>,
 }
 
 #[derive(Debug, Default)]
@@ -443,7 +525,7 @@ impl MePool {
         let now_epoch_secs = Self::now_epoch_secs();
         Arc::new(Self {
             registry,
-            writers: Arc::new(RwLock::new(Vec::new())),
+            writers: Arc::new(WritersState::new()),
             rr: AtomicU64::new(0),
             decision,
             upstream,
@@ -649,7 +731,7 @@ impl MePool {
             me_last_drain_gate_block_reason: AtomicU8::new(MeDrainGateReason::Open as u8),
             me_last_drain_gate_updated_at_epoch_secs: AtomicU64::new(now_epoch_secs),
             runtime_ready: AtomicBool::new(false),
-            preferred_endpoints_by_dc: Arc::new(RwLock::new(preferred_endpoints_by_dc)),
+            preferred_endpoints_by_dc: ArcSwap::from_pointee(preferred_endpoints_by_dc),
         })
     }
 
@@ -1004,7 +1086,7 @@ impl MePool {
         MeSocksKdfPolicy::from_u8(self.me_socks_kdf_policy.load(Ordering::Relaxed))
     }
 
-    pub(super) fn writers_arc(&self) -> Arc<RwLock<Vec<MeWriter>>> {
+    pub(super) fn writers_arc(&self) -> Arc<WritersState> {
         self.writers.clone()
     }
 
@@ -1602,7 +1684,7 @@ impl MePool {
         let rebuilt = Self::build_endpoint_dc_map_from_maps(&map_v4, &map_v6);
         let preferred = Self::build_preferred_endpoints_by_dc(&self.decision, &map_v4, &map_v6);
         *self.endpoint_dc_map.write().await = rebuilt;
-        *self.preferred_endpoints_by_dc.write().await = preferred;
+        self.preferred_endpoints_by_dc.store(Arc::new(preferred));
         let configured_endpoints = self
             .endpoint_dc_map
             .read()
@@ -1622,7 +1704,7 @@ impl MePool {
     }
 
     pub(super) async fn preferred_endpoints_for_dc(&self, dc: i32) -> Vec {
-        let guard = self.preferred_endpoints_by_dc.read().await;
+        let guard = self.preferred_endpoints_by_dc.load();
         guard.get(&dc).cloned().unwrap_or_default()
     }
 
diff --git a/src/transport/middle_proxy/pool_config.rs b/src/transport/middle_proxy/pool_config.rs
index ebbadd2..6e29918 100644
--- a/src/transport/middle_proxy/pool_config.rs
+++ b/src/transport/middle_proxy/pool_config.rs
@@ -112,7 +112,7 @@ impl MePool {
     pub async fn reconnect_all(self: &Arc<Self>) {
         let ws = self.writers.read().await.clone();
-        for w in ws {
+        for w in ws.iter() {
            if let Ok(()) = self
                .connect_one_for_dc(w.addr, w.writer_dc, self.rng.as_ref())
                .await
diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs
index 1ef59e1..afb8efe 100644
--- a/src/transport/middle_proxy/pool_status.rs
+++ b/src/transport/middle_proxy/pool_status.rs
@@ -160,7 +160,7 @@ impl MePool {
         let writers = self.writers.read().await.clone();
 
         let mut live_writers_by_dc = HashMap::::new();
-        for writer in writers {
+        for writer in writers.iter() {
             if writer.draining.load(Ordering::Relaxed) {
                 continue;
             }
@@ -197,7 +197,7 @@ impl MePool {
         let writers = self.writers.read().await.clone();
 
         let mut live_writers_by_dc = HashMap::::new();
-        for writer in writers {
+        for writer in writers.iter() {
             if writer.draining.load(Ordering::Relaxed) {
                 continue;
             }
@@ -255,7 +255,7 @@ impl MePool {
         let mut dc_rtt_agg = HashMap::::new();
         let mut writer_rows = Vec::::with_capacity(writers.len());
-        for writer in writers {
+        for writer in writers.iter() {
             let endpoint = writer.addr;
             let dc = i16::try_from(writer.writer_dc).ok();
             let draining = writer.draining.load(Ordering::Relaxed);
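The pool.rs and pool_status.rs hunks above move the hot read path from an async RwLock to ArcSwap: readers take an immutable Arc snapshot (hence the `for writer in writers.iter()` adjustments), while every store goes through one mutex-guarded copy-modify-publish step, which is the invariant WritersState encodes. A minimal synchronous sketch of that pattern, using the arc-swap crate with a std::sync::Mutex standing in for the pool's tokio mutex; the Registry name and the string items are illustrative only:

use std::sync::{Arc, Mutex};

use arc_swap::ArcSwap;

/// Copy-on-write registry: reads are lock-free snapshots, writes are
/// serialized and publish a freshly built Vec.
struct Registry<T> {
    items: ArcSwap<Vec<T>>,
    write_guard: Mutex<()>,
}

impl<T: Clone> Registry<T> {
    fn new() -> Self {
        Self {
            items: ArcSwap::from_pointee(Vec::new()),
            write_guard: Mutex::new(()),
        }
    }

    /// Readers get a cheap Arc snapshot; writers never block them.
    fn snapshot(&self) -> Arc<Vec<T>> {
        self.items.load_full()
    }

    /// Every mutation copies the current Vec, applies the closure, and stores
    /// the result while holding the write mutex, so concurrent updates cannot
    /// interleave and lose writes.
    fn update<R>(&self, f: impl FnOnce(&mut Vec<T>) -> R) -> R {
        let _guard = self.write_guard.lock().unwrap();
        let mut next = (*self.items.load_full()).clone();
        let out = f(&mut next);
        self.items.store(Arc::new(next));
        out
    }
}

fn main() {
    let registry = Registry::new();
    registry.update(|items| items.push("writer-1"));
    let before = registry.snapshot();
    registry.update(|items| items.push("writer-2"));

    // Old snapshots stay valid and unchanged; new readers see the new Vec.
    assert_eq!(before.len(), 1);
    assert_eq!(registry.snapshot().len(), 2);
}

Because each update publishes a freshly built Vec behind a new Arc, snapshots taken earlier stay valid and unchanged, which is what lets the status and reconnect paths iterate without holding a lock across await points.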
diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs
index 908b113..ef5a766 100644
--- a/src/transport/middle_proxy/pool_writer.rs
+++ b/src/transport/middle_proxy/pool_writer.rs
@@ -195,7 +195,9 @@ impl MePool {
             drain_deadline_epoch_secs: drain_deadline_epoch_secs.clone(),
             allow_drain_fallback: allow_drain_fallback.clone(),
         };
-        self.writers.write().await.push(writer.clone());
+        self.writers
+            .update(|writers| writers.push(writer.clone()))
+            .await;
         self.registry.register_writer(writer_id, tx.clone()).await;
         self.registry.mark_writer_idle(writer_id).await;
         self.conn_count.fetch_add(1, Ordering::Relaxed);
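A smaller technique from the middle_relay.rs hunks is also worth a note: both the rotation path and clear_desync_dedup_for_testing recover from a poisoned rotation-state mutex by resetting the protected state and clearing the poison flag instead of propagating the panic. A standalone sketch of that recovery pattern against std::sync::Mutex (clear_poison needs Rust 1.77 or later; DedupState and lock_or_reset are hypothetical names):

use std::sync::{Mutex, MutexGuard};

#[derive(Default, Debug)]
struct DedupState {
    rotations: u64,
}

/// Locks the mutex; if a previous holder panicked, resets the protected state
/// to a known-good default and clears the poison flag so later callers see Ok.
fn lock_or_reset(state: &Mutex<DedupState>) -> MutexGuard<'_, DedupState> {
    match state.lock() {
        Ok(guard) => guard,
        Err(poisoned) => {
            let mut guard = poisoned.into_inner();
            *guard = DedupState::default();
            state.clear_poison();
            guard
        }
    }
}

fn main() {
    let state = Mutex::new(DedupState { rotations: 7 });

    // Poison the mutex by panicking while the lock is held (the panic message
    // printed here is expected).
    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let _guard = state.lock().unwrap();
        panic!("simulated panic while holding the lock");
    }));
    assert!(state.is_poisoned());

    // Recovery path: the state is reset and the poison flag is cleared.
    let guard = lock_or_reset(&state);
    assert_eq!(guard.rotations, 0);
    drop(guard);
    assert!(!state.is_poisoned());
}

In the diff the protected state is only current_started_at, so the cost of a reset after recovery is at most one extra rotation of the dedup buckets.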