diff --git a/src/config/load.rs b/src/config/load.rs index f57de17..d14510f 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -57,7 +57,10 @@ fn normalize_mask_host_to_ascii(host: &str, field: &str) -> Result { if host.starts_with('[') && host.ends_with(']') { let inner = &host[1..host.len() - 1]; let ip = inner.parse::().map_err(|_| { - ProxyError::Config(format!("Invalid {field}: '{}'. IPv6 literal is invalid", host)) + ProxyError::Config(format!( + "Invalid {field}: '{}'. IPv6 literal is invalid", + host + )) })?; return match ip { std::net::IpAddr::V6(v6) => Ok(format!("[{v6}]")), @@ -2063,18 +2066,29 @@ impl ProxyConfig { } let mut exclusive_mask = HashMap::with_capacity(config.censorship.exclusive_mask.len()); + let mut exclusive_mask_targets = + HashMap::with_capacity(config.censorship.exclusive_mask.len()); for (domain, target) in std::mem::take(&mut config.censorship.exclusive_mask) { - let domain = normalize_domain_to_ascii( - &domain, - "censorship.exclusive_mask domain", - )?; - let target = normalize_exclusive_mask_target( - &target, - "censorship.exclusive_mask target", - )?; + let domain = normalize_domain_to_ascii(&domain, "censorship.exclusive_mask domain")?; + let target = + normalize_exclusive_mask_target(&target, "censorship.exclusive_mask target")?; + let Some((host, port)) = parse_exclusive_mask_target(&target) else { + return Err(ProxyError::Config(format!( + "Invalid censorship.exclusive_mask target for '{}': '{}'. Expected host:port with port > 0", + domain, target + ))); + }; + exclusive_mask_targets.insert( + domain.clone(), + ExclusiveMaskTarget { + host: host.to_string(), + port, + }, + ); exclusive_mask.insert(domain, target); } config.censorship.exclusive_mask = exclusive_mask; + config.censorship.exclusive_mask_targets = exclusive_mask_targets; // Migration: prefer_ipv6 -> network.prefer. if config.general.prefer_ipv6 { diff --git a/src/config/types.rs b/src/config/types.rs index f1cc816..b707dff 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1687,6 +1687,14 @@ impl Default for TlsFetchConfig { } } +#[derive(Debug, Clone)] +pub struct ExclusiveMaskTarget { + /// Target host after IDNA/IP normalization. + pub host: String, + /// TCP port for the selected target. + pub port: u16, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] @@ -1722,6 +1730,10 @@ pub struct AntiCensorshipConfig { #[serde(default)] pub exclusive_mask: HashMap, + /// Parsed runtime cache for per-SNI TCP mask targets. + #[serde(skip)] + pub exclusive_mask_targets: HashMap, + #[serde(default)] pub mask_unix_sock: Option, @@ -1846,6 +1858,7 @@ impl Default for AntiCensorshipConfig { mask_host: None, mask_port: default_mask_port(), exclusive_mask: HashMap::new(), + exclusive_mask_targets: HashMap::new(), mask_unix_sock: None, fake_cert_len: default_fake_cert_len(), tls_emulation: true, diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index 15a864f..e35ad73 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -3,38 +3,61 @@ #![allow(dead_code)] use std::collections::HashMap; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::net::IpAddr; use std::sync::Arc; use std::sync::Mutex; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering}; use std::time::{Duration, Instant}; -use tokio::sync::{Mutex as AsyncMutex, RwLock, RwLockWriteGuard}; +use dashmap::DashMap; +use tokio::sync::{Mutex as AsyncMutex, RwLock}; use crate::config::UserMaxUniqueIpsMode; const CLEANUP_DRAIN_BATCH_LIMIT: usize = 1024; const MAX_ACTIVE_IP_ENTRIES: u64 = 131_072; const MAX_RECENT_IP_ENTRIES: u64 = 262_144; +const USER_IP_TRACKER_SHARDS: usize = 64; +const USER_IP_TRACKER_SHARD_MASK: usize = USER_IP_TRACKER_SHARDS - 1; + +mod admission; +mod cleanup; +mod snapshot; +#[cfg(test)] +mod tests; + +#[derive(Debug, Default)] +struct UserIpShard { + active_ips: HashMap>, + recent_ips: HashMap>, +} + +#[derive(Debug, Default)] +struct CleanupShard { + queue: Mutex>>, +} /// Tracks active and recent client IPs for per-user admission control. #[derive(Debug, Clone)] pub struct UserIpTracker { - active_ips: Arc>>>, - recent_ips: Arc>>>, + shards: Arc]>>, active_entry_count: Arc, recent_entry_count: Arc, active_cap_rejects: Arc, recent_cap_rejects: Arc, cleanup_deferred_releases: Arc, - max_ips: Arc>>, - default_max_ips: Arc>, - limit_mode: Arc>, - limit_window: Arc>, + max_ips: Arc>, + default_max_ips: Arc, + limit_mode: Arc, + limit_window_secs: Arc, last_compact_epoch_secs: Arc, cleanup_queue_len: Arc, - cleanup_queue: Arc>>, - cleanup_drain_lock: Arc>, + cleanup_shards: Arc>, + cleanup_drain_locks: Arc]>>, + #[cfg(test)] + cleanup_queue_poison_probe: Arc>>, } /// Point-in-time memory counters for user/IP limiter state. @@ -60,26 +83,78 @@ pub struct UserIpTrackerMemoryStats { impl UserIpTracker { pub fn new() -> Self { + let shards = std::iter::repeat_with(|| RwLock::new(UserIpShard::default())) + .take(USER_IP_TRACKER_SHARDS) + .collect::>() + .into_boxed_slice(); + let cleanup_shards = std::iter::repeat_with(CleanupShard::default) + .take(USER_IP_TRACKER_SHARDS) + .collect::>() + .into_boxed_slice(); + let cleanup_drain_locks = std::iter::repeat_with(|| AsyncMutex::new(())) + .take(USER_IP_TRACKER_SHARDS) + .collect::>() + .into_boxed_slice(); Self { - active_ips: Arc::new(RwLock::new(HashMap::new())), - recent_ips: Arc::new(RwLock::new(HashMap::new())), + shards: Arc::new(shards), active_entry_count: Arc::new(AtomicU64::new(0)), recent_entry_count: Arc::new(AtomicU64::new(0)), active_cap_rejects: Arc::new(AtomicU64::new(0)), recent_cap_rejects: Arc::new(AtomicU64::new(0)), cleanup_deferred_releases: Arc::new(AtomicU64::new(0)), - max_ips: Arc::new(RwLock::new(HashMap::new())), - default_max_ips: Arc::new(RwLock::new(0)), - limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)), - limit_window: Arc::new(RwLock::new(Duration::from_secs(30))), + max_ips: Arc::new(DashMap::new()), + default_max_ips: Arc::new(AtomicUsize::new(0)), + limit_mode: Arc::new(AtomicU8::new(Self::mode_to_u8( + UserMaxUniqueIpsMode::ActiveWindow, + ))), + limit_window_secs: Arc::new(AtomicU64::new(30)), last_compact_epoch_secs: Arc::new(AtomicU64::new(0)), cleanup_queue_len: Arc::new(AtomicU64::new(0)), - cleanup_queue: Arc::new(Mutex::new(HashMap::new())), - cleanup_drain_lock: Arc::new(AsyncMutex::new(())), + cleanup_shards: Arc::new(cleanup_shards), + cleanup_drain_locks: Arc::new(cleanup_drain_locks), + #[cfg(test)] + cleanup_queue_poison_probe: Arc::new(Mutex::new(HashMap::new())), } } - fn decrement_counter(counter: &AtomicU64, amount: usize) { + pub(super) fn mode_to_u8(mode: UserMaxUniqueIpsMode) -> u8 { + match mode { + UserMaxUniqueIpsMode::ActiveWindow => 0, + UserMaxUniqueIpsMode::TimeWindow => 1, + UserMaxUniqueIpsMode::Combined => 2, + } + } + + pub(super) fn mode_from_u8(raw: u8) -> UserMaxUniqueIpsMode { + match raw { + 1 => UserMaxUniqueIpsMode::TimeWindow, + 2 => UserMaxUniqueIpsMode::Combined, + _ => UserMaxUniqueIpsMode::ActiveWindow, + } + } + + pub(super) fn shard_idx(username: &str) -> usize { + let mut hasher = DefaultHasher::new(); + username.hash(&mut hasher); + (hasher.finish() as usize) & USER_IP_TRACKER_SHARD_MASK + } + + pub(super) fn limit_window(&self) -> Duration { + Duration::from_secs(self.limit_window_secs.load(Ordering::Relaxed).max(1)) + } + + pub(super) fn user_limit(&self, username: &str) -> Option { + self.max_ips + .get(username) + .map(|limit| *limit) + .filter(|limit| *limit > 0) + .or_else(|| { + let default_limit = self.default_max_ips.load(Ordering::Relaxed); + (default_limit > 0).then_some(default_limit) + }) + } + + pub(super) fn decrement_counter(counter: &AtomicU64, amount: usize) { if amount == 0 { return; } @@ -89,7 +164,7 @@ impl UserIpTracker { }); } - fn apply_active_cleanup( + pub(super) fn apply_active_cleanup( active_ips: &mut HashMap>, user: &str, ip: IpAddr, @@ -117,575 +192,49 @@ impl UserIpTracker { removed_active_entries } - /// Queues a deferred active IP cleanup for a later async drain. - pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) { - match self.cleanup_queue.lock() { - Ok(mut queue) => { - let count = queue.entry((user, ip)).or_insert(0); - if *count == 0 { - self.cleanup_queue_len.fetch_add(1, Ordering::Relaxed); - } - *count = count.saturating_add(1); - self.cleanup_deferred_releases - .fetch_add(1, Ordering::Relaxed); - } - Err(poisoned) => { - let mut queue = poisoned.into_inner(); - let count = queue.entry((user.clone(), ip)).or_insert(0); - if *count == 0 { - self.cleanup_queue_len.fetch_add(1, Ordering::Relaxed); - } - *count = count.saturating_add(1); - self.cleanup_deferred_releases - .fetch_add(1, Ordering::Relaxed); - self.cleanup_queue.clear_poison(); - tracing::warn!( - "UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})", - user, - ip - ); - } + pub(super) fn try_increment_counter(counter: &AtomicU64, cap: u64) -> bool { + counter + .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |current| { + (current < cap).then_some(current + 1) + }) + .is_ok() + } + + pub(super) fn pop_one_cleanup( + queue: &mut HashMap>, + ) -> Option<(String, IpAddr, usize)> { + let user = queue.keys().next().cloned()?; + let ip = queue.get(&user)?.keys().next().copied()?; + let count = queue.get_mut(&user)?.remove(&ip)?; + let remove_user = queue + .get(&user) + .map(|user_queue| user_queue.is_empty()) + .unwrap_or(false); + if remove_user { + queue.remove(&user); } + Some((user, ip, count)) } #[cfg(test)] - pub(crate) fn cleanup_queue_len_for_tests(&self) -> usize { - self.cleanup_queue - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .len() - } - - #[cfg(test)] - pub(crate) fn cleanup_queue_mutex_for_tests( - &self, - ) -> Arc>> { - Arc::clone(&self.cleanup_queue) - } - - pub(crate) async fn drain_cleanup_queue(&self) { - if self.cleanup_queue_len.load(Ordering::Relaxed) == 0 { - return; - } - let Ok(_drain_guard) = self.cleanup_drain_lock.try_lock() else { - return; - }; - - let to_remove = { - match self.cleanup_queue.lock() { - Ok(mut queue) => { - if queue.is_empty() { - return; - } - let mut drained = - HashMap::with_capacity(queue.len().min(CLEANUP_DRAIN_BATCH_LIMIT)); - for _ in 0..CLEANUP_DRAIN_BATCH_LIMIT { - let Some(key) = queue.keys().next().cloned() else { - break; - }; - if let Some(count) = queue.remove(&key) { - self.cleanup_queue_len.fetch_sub(1, Ordering::Relaxed); - drained.insert(key, count); - } - } - drained - } - Err(poisoned) => { - let mut queue = poisoned.into_inner(); - if queue.is_empty() { - self.cleanup_queue.clear_poison(); - return; - } - let mut drained = - HashMap::with_capacity(queue.len().min(CLEANUP_DRAIN_BATCH_LIMIT)); - for _ in 0..CLEANUP_DRAIN_BATCH_LIMIT { - let Some(key) = queue.keys().next().cloned() else { - break; - }; - if let Some(count) = queue.remove(&key) { - self.cleanup_queue_len.fetch_sub(1, Ordering::Relaxed); - drained.insert(key, count); - } - } - self.cleanup_queue.clear_poison(); - drained - } + pub(super) fn observe_cleanup_poison_for_tests(&self) { + match self.cleanup_queue_poison_probe.lock() { + Ok(_) => {} + Err(_) => { + self.cleanup_queue_poison_probe.clear_poison(); } - }; - if to_remove.is_empty() { - return; } - - let mut active_ips = self.active_ips.write().await; - let mut removed_active_entries = 0usize; - for ((user, ip), pending_count) in to_remove { - removed_active_entries = removed_active_entries.saturating_add( - Self::apply_active_cleanup(&mut active_ips, &user, ip, pending_count), - ); - } - Self::decrement_counter(&self.active_entry_count, removed_active_entries); } - fn now_epoch_secs() -> u64 { + #[cfg(not(test))] + pub(super) fn observe_cleanup_poison_for_tests(&self) {} + + pub(super) fn now_epoch_secs() -> u64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() } - - async fn active_and_recent_write( - &self, - ) -> ( - RwLockWriteGuard<'_, HashMap>>, - RwLockWriteGuard<'_, HashMap>>, - ) { - loop { - let active_ips = self.active_ips.write().await; - match self.recent_ips.try_write() { - Ok(recent_ips) => return (active_ips, recent_ips), - Err(_) => { - drop(active_ips); - tokio::task::yield_now().await; - } - } - } - } - - async fn maybe_compact_empty_users(&self) { - const COMPACT_INTERVAL_SECS: u64 = 60; - let now_epoch_secs = Self::now_epoch_secs(); - let last_compact_epoch_secs = self.last_compact_epoch_secs.load(Ordering::Relaxed); - if now_epoch_secs.saturating_sub(last_compact_epoch_secs) < COMPACT_INTERVAL_SECS { - return; - } - if self - .last_compact_epoch_secs - .compare_exchange( - last_compact_epoch_secs, - now_epoch_secs, - Ordering::AcqRel, - Ordering::Relaxed, - ) - .is_err() - { - return; - } - - let window = *self.limit_window.read().await; - let now = Instant::now(); - let (mut active_ips, mut recent_ips) = self.active_and_recent_write().await; - - let mut pruned_recent_entries = 0usize; - for user_recent in recent_ips.values_mut() { - pruned_recent_entries = - pruned_recent_entries.saturating_add(Self::prune_recent(user_recent, now, window)); - } - Self::decrement_counter(&self.recent_entry_count, pruned_recent_entries); - - let mut users = - Vec::::with_capacity(active_ips.len().saturating_add(recent_ips.len())); - users.extend(active_ips.keys().cloned()); - for user in recent_ips.keys() { - if !active_ips.contains_key(user) { - users.push(user.clone()); - } - } - - for user in users { - let active_empty = active_ips - .get(&user) - .map(|ips| ips.is_empty()) - .unwrap_or(true); - let recent_empty = recent_ips - .get(&user) - .map(|ips| ips.is_empty()) - .unwrap_or(true); - if active_empty && recent_empty { - active_ips.remove(&user); - recent_ips.remove(&user); - } - } - } - - pub async fn run_periodic_maintenance(self: Arc) { - let mut interval = tokio::time::interval(Duration::from_secs(1)); - loop { - interval.tick().await; - self.drain_cleanup_queue().await; - self.maybe_compact_empty_users().await; - } - } - - pub async fn memory_stats(&self) -> UserIpTrackerMemoryStats { - let cleanup_queue_len = self.cleanup_queue_len.load(Ordering::Relaxed) as usize; - let active_ips = self.active_ips.read().await; - let recent_ips = self.recent_ips.read().await; - let active_entries = active_ips.values().map(HashMap::len).sum(); - let recent_entries = recent_ips.values().map(HashMap::len).sum(); - - UserIpTrackerMemoryStats { - active_users: active_ips.len(), - recent_users: recent_ips.len(), - active_entries, - recent_entries, - cleanup_queue_len, - active_cap_rejects: self.active_cap_rejects.load(Ordering::Relaxed), - recent_cap_rejects: self.recent_cap_rejects.load(Ordering::Relaxed), - cleanup_deferred_releases: self.cleanup_deferred_releases.load(Ordering::Relaxed), - } - } - - pub async fn set_limit_policy(&self, mode: UserMaxUniqueIpsMode, window_secs: u64) { - { - let mut current_mode = self.limit_mode.write().await; - *current_mode = mode; - } - let mut current_window = self.limit_window.write().await; - *current_window = Duration::from_secs(window_secs.max(1)); - } - - pub async fn set_user_limit(&self, username: &str, max_ips: usize) { - let mut limits = self.max_ips.write().await; - limits.insert(username.to_string(), max_ips); - } - - pub async fn remove_user_limit(&self, username: &str) { - let mut limits = self.max_ips.write().await; - limits.remove(username); - } - - pub async fn load_limits(&self, default_limit: usize, limits: &HashMap) { - let mut default_max_ips = self.default_max_ips.write().await; - *default_max_ips = default_limit; - drop(default_max_ips); - let mut max_ips = self.max_ips.write().await; - max_ips.clone_from(limits); - } - - fn prune_recent( - user_recent: &mut HashMap, - now: Instant, - window: Duration, - ) -> usize { - if user_recent.is_empty() { - return 0; - } - let before = user_recent.len(); - user_recent.retain(|_, seen_at| now.duration_since(*seen_at) <= window); - before.saturating_sub(user_recent.len()) - } - - pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> { - self.drain_cleanup_queue().await; - self.maybe_compact_empty_users().await; - let default_max_ips = *self.default_max_ips.read().await; - let limit = { - let max_ips = self.max_ips.read().await; - max_ips - .get(username) - .copied() - .filter(|limit| *limit > 0) - .or((default_max_ips > 0).then_some(default_max_ips)) - }; - let mode = *self.limit_mode.read().await; - let window = *self.limit_window.read().await; - let now = Instant::now(); - - let (mut active_ips, mut recent_ips) = self.active_and_recent_write().await; - if !active_ips.contains_key(username) { - active_ips.insert(username.to_string(), HashMap::new()); - } - if !recent_ips.contains_key(username) { - recent_ips.insert(username.to_string(), HashMap::new()); - } - let Some(user_active) = active_ips.get_mut(username) else { - return Err(format!("IP tracker active entry unavailable for user '{username}'")); - }; - let Some(user_recent) = recent_ips.get_mut(username) else { - return Err(format!("IP tracker recent entry unavailable for user '{username}'")); - }; - let pruned_recent_entries = Self::prune_recent(user_recent, now, window); - Self::decrement_counter(&self.recent_entry_count, pruned_recent_entries); - let recent_contains_ip = user_recent.contains_key(&ip); - - if let Some(count) = user_active.get_mut(&ip) { - if !recent_contains_ip - && self.recent_entry_count.load(Ordering::Relaxed) >= MAX_RECENT_IP_ENTRIES - { - self.recent_cap_rejects.fetch_add(1, Ordering::Relaxed); - return Err(format!( - "IP tracker recent entry cap reached: entries={}/{}", - self.recent_entry_count.load(Ordering::Relaxed), - MAX_RECENT_IP_ENTRIES - )); - } - *count = count.saturating_add(1); - if user_recent.insert(ip, now).is_none() { - self.recent_entry_count.fetch_add(1, Ordering::Relaxed); - } - return Ok(()); - } - - let is_new_ip = !recent_contains_ip; - - if let Some(limit) = limit { - let active_limit_reached = user_active.len() >= limit; - let recent_limit_reached = user_recent.len() >= limit && is_new_ip; - let deny = match mode { - UserMaxUniqueIpsMode::ActiveWindow => active_limit_reached, - UserMaxUniqueIpsMode::TimeWindow => recent_limit_reached, - UserMaxUniqueIpsMode::Combined => active_limit_reached || recent_limit_reached, - }; - - if deny { - return Err(format!( - "IP limit reached for user '{}': active={}/{} recent={}/{} mode={:?}", - username, - user_active.len(), - limit, - user_recent.len(), - limit, - mode - )); - } - } - - if self.active_entry_count.load(Ordering::Relaxed) >= MAX_ACTIVE_IP_ENTRIES { - self.active_cap_rejects.fetch_add(1, Ordering::Relaxed); - return Err(format!( - "IP tracker active entry cap reached: entries={}/{}", - self.active_entry_count.load(Ordering::Relaxed), - MAX_ACTIVE_IP_ENTRIES - )); - } - if is_new_ip && self.recent_entry_count.load(Ordering::Relaxed) >= MAX_RECENT_IP_ENTRIES { - self.recent_cap_rejects.fetch_add(1, Ordering::Relaxed); - return Err(format!( - "IP tracker recent entry cap reached: entries={}/{}", - self.recent_entry_count.load(Ordering::Relaxed), - MAX_RECENT_IP_ENTRIES - )); - } - - if user_active.insert(ip, 1).is_none() { - self.active_entry_count.fetch_add(1, Ordering::Relaxed); - } - if user_recent.insert(ip, now).is_none() { - self.recent_entry_count.fetch_add(1, Ordering::Relaxed); - } - Ok(()) - } - - pub async fn remove_ip(&self, username: &str, ip: IpAddr) { - self.maybe_compact_empty_users().await; - let mut active_ips = self.active_ips.write().await; - let mut removed_active_entries = 0usize; - if let Some(user_ips) = active_ips.get_mut(username) { - if let Some(count) = user_ips.get_mut(&ip) { - if *count > 1 { - *count -= 1; - } else { - if user_ips.remove(&ip).is_some() { - removed_active_entries = 1; - } - } - } - if user_ips.is_empty() { - active_ips.remove(username); - } - } - Self::decrement_counter(&self.active_entry_count, removed_active_entries); - } - - pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap { - self.drain_cleanup_queue().await; - self.get_recent_counts_for_users_snapshot(users).await - } - - pub(crate) async fn get_recent_counts_for_users_snapshot( - &self, - users: &[String], - ) -> HashMap { - let window = *self.limit_window.read().await; - let now = Instant::now(); - let recent_ips = self.recent_ips.read().await; - - let mut counts = HashMap::with_capacity(users.len()); - for user in users { - let count = if let Some(user_recent) = recent_ips.get(user) { - user_recent - .values() - .filter(|seen_at| now.duration_since(**seen_at) <= window) - .count() - } else { - 0 - }; - counts.insert(user.clone(), count); - } - counts - } - - pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap> { - self.drain_cleanup_queue().await; - let active_ips = self.active_ips.read().await; - let mut out = HashMap::with_capacity(users.len()); - for user in users { - let mut ips = active_ips - .get(user) - .map(|per_ip| per_ip.keys().copied().collect::>()) - .unwrap_or_else(Vec::new); - ips.sort(); - out.insert(user.clone(), ips); - } - out - } - - pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap> { - self.drain_cleanup_queue().await; - let window = *self.limit_window.read().await; - let now = Instant::now(); - let recent_ips = self.recent_ips.read().await; - - let mut out = HashMap::with_capacity(users.len()); - for user in users { - let mut ips = if let Some(user_recent) = recent_ips.get(user) { - user_recent - .iter() - .filter(|(_, seen_at)| now.duration_since(**seen_at) <= window) - .map(|(ip, _)| *ip) - .collect::>() - } else { - Vec::new() - }; - ips.sort(); - out.insert(user.clone(), ips); - } - out - } - - pub async fn get_active_ip_count(&self, username: &str) -> usize { - self.drain_cleanup_queue().await; - let active_ips = self.active_ips.read().await; - active_ips.get(username).map(|ips| ips.len()).unwrap_or(0) - } - - pub async fn get_active_ips(&self, username: &str) -> Vec { - self.drain_cleanup_queue().await; - let active_ips = self.active_ips.read().await; - active_ips - .get(username) - .map(|ips| ips.keys().copied().collect()) - .unwrap_or_else(Vec::new) - } - - pub async fn get_stats(&self) -> Vec<(String, usize, usize)> { - self.drain_cleanup_queue().await; - self.get_stats_snapshot().await - } - - pub(crate) async fn get_stats_snapshot(&self) -> Vec<(String, usize, usize)> { - let active_ips = self.active_ips.read().await; - let active_counts = active_ips - .iter() - .map(|(username, user_ips)| (username.clone(), user_ips.len())) - .collect::>(); - drop(active_ips); - - let max_ips = self.max_ips.read().await; - let default_max_ips = *self.default_max_ips.read().await; - - let mut stats = Vec::with_capacity(active_counts.len()); - for (username, active_count) in active_counts { - let limit = max_ips - .get(&username) - .copied() - .filter(|limit| *limit > 0) - .or((default_max_ips > 0).then_some(default_max_ips)) - .unwrap_or(0); - stats.push((username, active_count, limit)); - } - - stats.sort_by(|a, b| a.0.cmp(&b.0)); - stats - } - - pub async fn clear_user_ips(&self, username: &str) { - let mut active_ips = self.active_ips.write().await; - let removed_active_entries = active_ips - .remove(username) - .map(|ips| ips.len()) - .unwrap_or(0); - drop(active_ips); - Self::decrement_counter(&self.active_entry_count, removed_active_entries); - - let mut recent_ips = self.recent_ips.write().await; - let removed_recent_entries = recent_ips - .remove(username) - .map(|ips| ips.len()) - .unwrap_or(0); - Self::decrement_counter(&self.recent_entry_count, removed_recent_entries); - } - - pub async fn clear_all(&self) { - let mut active_ips = self.active_ips.write().await; - active_ips.clear(); - drop(active_ips); - self.active_entry_count.store(0, Ordering::Relaxed); - - let mut recent_ips = self.recent_ips.write().await; - recent_ips.clear(); - self.recent_entry_count.store(0, Ordering::Relaxed); - } - - pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool { - self.drain_cleanup_queue().await; - let active_ips = self.active_ips.read().await; - active_ips - .get(username) - .map(|ips| ips.contains_key(&ip)) - .unwrap_or(false) - } - - pub async fn get_user_limit(&self, username: &str) -> Option { - let default_max_ips = *self.default_max_ips.read().await; - let max_ips = self.max_ips.read().await; - max_ips - .get(username) - .copied() - .filter(|limit| *limit > 0) - .or((default_max_ips > 0).then_some(default_max_ips)) - } - - pub async fn format_stats(&self) -> String { - let stats = self.get_stats().await; - - if stats.is_empty() { - return String::from("No active users"); - } - - let mut output = String::from("User IP Statistics:\n"); - output.push_str("==================\n"); - - for (username, active_count, limit) in stats { - output.push_str(&format!( - "User: {:<20} Active IPs: {}/{}\n", - username, - active_count, - if limit > 0 { - limit.to_string() - } else { - "unlimited".to_string() - } - )); - - let ips = self.get_active_ips(&username).await; - for ip in ips { - output.push_str(&format!(" - {}\n", ip)); - } - } - - output - } } impl Default for UserIpTracker { @@ -693,388 +242,3 @@ impl Default for UserIpTracker { Self::new() } } - -#[cfg(test)] -mod tests { - use super::*; - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - use std::sync::atomic::Ordering; - - fn test_ipv4(oct1: u8, oct2: u8, oct3: u8, oct4: u8) -> IpAddr { - IpAddr::V4(Ipv4Addr::new(oct1, oct2, oct3, oct4)) - } - - fn test_ipv6() -> IpAddr { - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)) - } - - #[tokio::test] - async fn test_basic_ip_limit() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 2).await; - - let ip1 = test_ipv4(192, 168, 1, 1); - let ip2 = test_ipv4(192, 168, 1, 2); - let ip3 = test_ipv4(192, 168, 1, 3); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip3).await.is_err()); - - assert_eq!(tracker.get_active_ip_count("test_user").await, 2); - } - - #[tokio::test] - async fn test_active_window_rejects_new_ip_and_keeps_existing_session() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 1).await; - tracker - .set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30) - .await; - - let ip1 = test_ipv4(10, 10, 10, 1); - let ip2 = test_ipv4(10, 10, 10, 2); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.is_ip_active("test_user", ip1).await); - assert!(tracker.check_and_add("test_user", ip2).await.is_err()); - - // Existing session remains active; only new unique IP is denied. - assert!(tracker.is_ip_active("test_user", ip1).await); - assert_eq!(tracker.get_active_ip_count("test_user").await, 1); - } - - #[tokio::test] - async fn test_reconnection_from_same_ip() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 2).await; - - let ip1 = test_ipv4(192, 168, 1, 1); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert_eq!(tracker.get_active_ip_count("test_user").await, 1); - } - - #[tokio::test] - async fn test_same_ip_disconnect_keeps_active_while_other_session_alive() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 2).await; - - let ip1 = test_ipv4(192, 168, 1, 1); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert_eq!(tracker.get_active_ip_count("test_user").await, 1); - - tracker.remove_ip("test_user", ip1).await; - assert_eq!(tracker.get_active_ip_count("test_user").await, 1); - - tracker.remove_ip("test_user", ip1).await; - assert_eq!(tracker.get_active_ip_count("test_user").await, 0); - } - - #[tokio::test] - async fn test_ip_removal() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 2).await; - - let ip1 = test_ipv4(192, 168, 1, 1); - let ip2 = test_ipv4(192, 168, 1, 2); - let ip3 = test_ipv4(192, 168, 1, 3); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip3).await.is_err()); - - tracker.remove_ip("test_user", ip1).await; - - assert!(tracker.check_and_add("test_user", ip3).await.is_ok()); - assert_eq!(tracker.get_active_ip_count("test_user").await, 2); - } - - #[tokio::test] - async fn test_no_limit() { - let tracker = UserIpTracker::new(); - - let ip1 = test_ipv4(192, 168, 1, 1); - let ip2 = test_ipv4(192, 168, 1, 2); - let ip3 = test_ipv4(192, 168, 1, 3); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip3).await.is_ok()); - - assert_eq!(tracker.get_active_ip_count("test_user").await, 3); - } - - #[tokio::test] - async fn test_multiple_users() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("user1", 2).await; - tracker.set_user_limit("user2", 1).await; - - let ip1 = test_ipv4(192, 168, 1, 1); - let ip2 = test_ipv4(192, 168, 1, 2); - - assert!(tracker.check_and_add("user1", ip1).await.is_ok()); - assert!(tracker.check_and_add("user1", ip2).await.is_ok()); - - assert!(tracker.check_and_add("user2", ip1).await.is_ok()); - assert!(tracker.check_and_add("user2", ip2).await.is_err()); - } - - #[tokio::test] - async fn test_ipv6_support() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 2).await; - - let ipv4 = test_ipv4(192, 168, 1, 1); - let ipv6 = test_ipv6(); - - assert!(tracker.check_and_add("test_user", ipv4).await.is_ok()); - assert!(tracker.check_and_add("test_user", ipv6).await.is_ok()); - - assert_eq!(tracker.get_active_ip_count("test_user").await, 2); - } - - #[tokio::test] - async fn test_get_active_ips() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 3).await; - - let ip1 = test_ipv4(192, 168, 1, 1); - let ip2 = test_ipv4(192, 168, 1, 2); - - tracker.check_and_add("test_user", ip1).await.unwrap(); - tracker.check_and_add("test_user", ip2).await.unwrap(); - - let active_ips = tracker.get_active_ips("test_user").await; - assert_eq!(active_ips.len(), 2); - assert!(active_ips.contains(&ip1)); - assert!(active_ips.contains(&ip2)); - } - - #[tokio::test] - async fn test_stats() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("user1", 3).await; - tracker.set_user_limit("user2", 2).await; - - let ip1 = test_ipv4(192, 168, 1, 1); - let ip2 = test_ipv4(192, 168, 1, 2); - - tracker.check_and_add("user1", ip1).await.unwrap(); - tracker.check_and_add("user2", ip2).await.unwrap(); - - let stats = tracker.get_stats().await; - assert_eq!(stats.len(), 2); - - assert!(stats.iter().any(|(name, _, _)| name == "user1")); - assert!(stats.iter().any(|(name, _, _)| name == "user2")); - } - - #[tokio::test] - async fn test_clear_user_ips() { - let tracker = UserIpTracker::new(); - let ip1 = test_ipv4(192, 168, 1, 1); - - tracker.check_and_add("test_user", ip1).await.unwrap(); - assert_eq!(tracker.get_active_ip_count("test_user").await, 1); - - tracker.clear_user_ips("test_user").await; - assert_eq!(tracker.get_active_ip_count("test_user").await, 0); - } - - #[tokio::test] - async fn test_is_ip_active() { - let tracker = UserIpTracker::new(); - let ip1 = test_ipv4(192, 168, 1, 1); - let ip2 = test_ipv4(192, 168, 1, 2); - - tracker.check_and_add("test_user", ip1).await.unwrap(); - - assert!(tracker.is_ip_active("test_user", ip1).await); - assert!(!tracker.is_ip_active("test_user", ip2).await); - } - - #[tokio::test] - async fn test_load_limits_from_config() { - let tracker = UserIpTracker::new(); - - let mut config_limits = HashMap::new(); - config_limits.insert("user1".to_string(), 5); - config_limits.insert("user2".to_string(), 3); - - tracker.load_limits(0, &config_limits).await; - - assert_eq!(tracker.get_user_limit("user1").await, Some(5)); - assert_eq!(tracker.get_user_limit("user2").await, Some(3)); - assert_eq!(tracker.get_user_limit("user3").await, None); - } - - #[tokio::test] - async fn test_load_limits_replaces_previous_map() { - let tracker = UserIpTracker::new(); - - let mut first = HashMap::new(); - first.insert("user1".to_string(), 2); - first.insert("user2".to_string(), 3); - tracker.load_limits(0, &first).await; - - let mut second = HashMap::new(); - second.insert("user2".to_string(), 5); - tracker.load_limits(0, &second).await; - - assert_eq!(tracker.get_user_limit("user1").await, None); - assert_eq!(tracker.get_user_limit("user2").await, Some(5)); - } - - #[tokio::test] - async fn test_global_each_limit_applies_without_user_override() { - let tracker = UserIpTracker::new(); - tracker.load_limits(2, &HashMap::new()).await; - - let ip1 = test_ipv4(172, 16, 0, 1); - let ip2 = test_ipv4(172, 16, 0, 2); - let ip3 = test_ipv4(172, 16, 0, 3); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip3).await.is_err()); - assert_eq!(tracker.get_user_limit("test_user").await, Some(2)); - } - - #[tokio::test] - async fn test_user_override_wins_over_global_each_limit() { - let tracker = UserIpTracker::new(); - let mut limits = HashMap::new(); - limits.insert("test_user".to_string(), 1); - tracker.load_limits(3, &limits).await; - - let ip1 = test_ipv4(172, 17, 0, 1); - let ip2 = test_ipv4(172, 17, 0, 2); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip2).await.is_err()); - assert_eq!(tracker.get_user_limit("test_user").await, Some(1)); - } - - #[tokio::test] - async fn test_time_window_mode_blocks_recent_ip_churn() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 1).await; - tracker - .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 30) - .await; - - let ip1 = test_ipv4(10, 0, 0, 1); - let ip2 = test_ipv4(10, 0, 0, 2); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - tracker.remove_ip("test_user", ip1).await; - assert!(tracker.check_and_add("test_user", ip2).await.is_err()); - } - - #[tokio::test] - async fn test_combined_mode_enforces_active_and_recent_limits() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 1).await; - tracker - .set_limit_policy(UserMaxUniqueIpsMode::Combined, 30) - .await; - - let ip1 = test_ipv4(10, 0, 1, 1); - let ip2 = test_ipv4(10, 0, 1, 2); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - assert!(tracker.check_and_add("test_user", ip2).await.is_err()); - - tracker.remove_ip("test_user", ip1).await; - assert!(tracker.check_and_add("test_user", ip2).await.is_err()); - } - - #[tokio::test] - async fn test_time_window_expires() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 1).await; - tracker - .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) - .await; - - let ip1 = test_ipv4(10, 1, 0, 1); - let ip2 = test_ipv4(10, 1, 0, 2); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - tracker.remove_ip("test_user", ip1).await; - assert!(tracker.check_and_add("test_user", ip2).await.is_err()); - - tokio::time::sleep(Duration::from_millis(1100)).await; - assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); - } - - #[tokio::test] - async fn test_memory_stats_reports_queue_and_entry_counts() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 4).await; - let ip1 = test_ipv4(10, 2, 0, 1); - let ip2 = test_ipv4(10, 2, 0, 2); - - tracker.check_and_add("test_user", ip1).await.unwrap(); - tracker.check_and_add("test_user", ip2).await.unwrap(); - tracker.enqueue_cleanup("test_user".to_string(), ip1); - - let snapshot = tracker.memory_stats().await; - assert_eq!(snapshot.active_users, 1); - assert_eq!(snapshot.recent_users, 1); - assert_eq!(snapshot.active_entries, 2); - assert_eq!(snapshot.recent_entries, 2); - assert_eq!(snapshot.cleanup_queue_len, 1); - } - - #[tokio::test] - async fn test_compact_prunes_stale_recent_entries() { - let tracker = UserIpTracker::new(); - tracker - .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) - .await; - - let stale_user = "stale-user".to_string(); - let stale_ip = test_ipv4(10, 3, 0, 1); - { - let mut recent_ips = tracker.recent_ips.write().await; - recent_ips - .entry(stale_user.clone()) - .or_insert_with(HashMap::new) - .insert(stale_ip, Instant::now() - Duration::from_secs(5)); - } - - tracker.last_compact_epoch_secs.store(0, Ordering::Relaxed); - tracker - .check_and_add("trigger-user", test_ipv4(10, 3, 0, 2)) - .await - .unwrap(); - - let recent_ips = tracker.recent_ips.read().await; - let stale_exists = recent_ips - .get(&stale_user) - .map(|ips| ips.contains_key(&stale_ip)) - .unwrap_or(false); - assert!(!stale_exists); - } - - #[tokio::test] - async fn test_time_window_allows_same_ip_reconnect() { - let tracker = UserIpTracker::new(); - tracker.set_user_limit("test_user", 1).await; - tracker - .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) - .await; - - let ip1 = test_ipv4(10, 4, 0, 1); - - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - tracker.remove_ip("test_user", ip1).await; - assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); - } -} diff --git a/src/ip_tracker/admission.rs b/src/ip_tracker/admission.rs new file mode 100644 index 0000000..314b273 --- /dev/null +++ b/src/ip_tracker/admission.rs @@ -0,0 +1,173 @@ +use super::*; + +impl UserIpTracker { + pub async fn set_limit_policy(&self, mode: UserMaxUniqueIpsMode, window_secs: u64) { + self.limit_mode + .store(Self::mode_to_u8(mode), Ordering::Relaxed); + self.limit_window_secs + .store(window_secs.max(1), Ordering::Relaxed); + } + + pub async fn set_user_limit(&self, username: &str, max_ips: usize) { + self.max_ips.insert(username.to_string(), max_ips); + } + + pub async fn remove_user_limit(&self, username: &str) { + self.max_ips.remove(username); + } + + pub async fn load_limits(&self, default_limit: usize, limits: &HashMap) { + self.default_max_ips.store(default_limit, Ordering::Relaxed); + self.max_ips.clear(); + for (username, limit) in limits { + self.max_ips.insert(username.clone(), *limit); + } + } + + pub(super) fn prune_recent( + user_recent: &mut HashMap, + now: Instant, + window: Duration, + ) -> usize { + if user_recent.is_empty() { + return 0; + } + let before = user_recent.len(); + user_recent.retain(|_, seen_at| now.duration_since(*seen_at) <= window); + before.saturating_sub(user_recent.len()) + } + + pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> { + self.drain_cleanup_for_user(username).await; + self.maybe_compact_empty_users().await; + let limit = self.user_limit(username); + let mode = Self::mode_from_u8(self.limit_mode.load(Ordering::Relaxed)); + let window = self.limit_window(); + let now = Instant::now(); + + let shard_idx = Self::shard_idx(username); + let mut shard = self.shards[shard_idx].write().await; + let user_active = shard.active_ips.entry(username.to_string()).or_default(); + let active_contains_ip = user_active.contains_key(&ip); + let active_len = user_active.len(); + let user_recent = shard.recent_ips.entry(username.to_string()).or_default(); + let pruned_recent_entries = Self::prune_recent(user_recent, now, window); + Self::decrement_counter(&self.recent_entry_count, pruned_recent_entries); + let recent_contains_ip = user_recent.contains_key(&ip); + let recent_len = user_recent.len(); + + if active_contains_ip { + if !recent_contains_ip + && !Self::try_increment_counter(&self.recent_entry_count, MAX_RECENT_IP_ENTRIES) + { + self.recent_cap_rejects.fetch_add(1, Ordering::Relaxed); + return Err(format!( + "IP tracker recent entry cap reached: entries={}/{}", + self.recent_entry_count.load(Ordering::Relaxed), + MAX_RECENT_IP_ENTRIES + )); + } + let Some(count) = shard + .active_ips + .get_mut(username) + .and_then(|user_active| user_active.get_mut(&ip)) + else { + return Err(format!( + "IP tracker active entry unavailable for user '{username}'" + )); + }; + *count = count.saturating_add(1); + if let Some(user_recent) = shard.recent_ips.get_mut(username) { + user_recent.insert(ip, now); + } + return Ok(()); + } + + let is_new_ip = !recent_contains_ip; + + if let Some(limit) = limit { + let active_limit_reached = active_len >= limit; + let recent_limit_reached = recent_len >= limit && is_new_ip; + let deny = match mode { + UserMaxUniqueIpsMode::ActiveWindow => active_limit_reached, + UserMaxUniqueIpsMode::TimeWindow => recent_limit_reached, + UserMaxUniqueIpsMode::Combined => active_limit_reached || recent_limit_reached, + }; + + if deny { + return Err(format!( + "IP limit reached for user '{}': active={}/{} recent={}/{} mode={:?}", + username, active_len, limit, recent_len, limit, mode + )); + } + } + + if !Self::try_increment_counter(&self.active_entry_count, MAX_ACTIVE_IP_ENTRIES) { + self.active_cap_rejects.fetch_add(1, Ordering::Relaxed); + return Err(format!( + "IP tracker active entry cap reached: entries={}/{}", + self.active_entry_count.load(Ordering::Relaxed), + MAX_ACTIVE_IP_ENTRIES + )); + } + let mut reserved_recent = false; + if is_new_ip { + if !Self::try_increment_counter(&self.recent_entry_count, MAX_RECENT_IP_ENTRIES) { + Self::decrement_counter(&self.active_entry_count, 1); + self.recent_cap_rejects.fetch_add(1, Ordering::Relaxed); + return Err(format!( + "IP tracker recent entry cap reached: entries={}/{}", + self.recent_entry_count.load(Ordering::Relaxed), + MAX_RECENT_IP_ENTRIES + )); + } + reserved_recent = true; + } + + let Some(user_active) = shard.active_ips.get_mut(username) else { + Self::decrement_counter(&self.active_entry_count, 1); + if reserved_recent { + Self::decrement_counter(&self.recent_entry_count, 1); + } + return Err(format!( + "IP tracker active entry unavailable for user '{username}'" + )); + }; + if user_active.insert(ip, 1).is_some() { + Self::decrement_counter(&self.active_entry_count, 1); + } + let Some(user_recent) = shard.recent_ips.get_mut(username) else { + Self::decrement_counter(&self.active_entry_count, 1); + if reserved_recent { + Self::decrement_counter(&self.recent_entry_count, 1); + } + return Err(format!( + "IP tracker recent entry unavailable for user '{username}'" + )); + }; + if user_recent.insert(ip, now).is_some() && reserved_recent { + Self::decrement_counter(&self.recent_entry_count, 1); + } + Ok(()) + } + + pub async fn remove_ip(&self, username: &str, ip: IpAddr) { + self.maybe_compact_empty_users().await; + let shard_idx = Self::shard_idx(username); + let mut shard = self.shards[shard_idx].write().await; + let mut removed_active_entries = 0usize; + if let Some(user_ips) = shard.active_ips.get_mut(username) { + if let Some(count) = user_ips.get_mut(&ip) { + if *count > 1 { + *count -= 1; + } else if user_ips.remove(&ip).is_some() { + removed_active_entries = 1; + } + } + if user_ips.is_empty() { + shard.active_ips.remove(username); + } + } + Self::decrement_counter(&self.active_entry_count, removed_active_entries); + } +} diff --git a/src/ip_tracker/cleanup.rs b/src/ip_tracker/cleanup.rs new file mode 100644 index 0000000..86b4f8b --- /dev/null +++ b/src/ip_tracker/cleanup.rs @@ -0,0 +1,148 @@ +use super::*; + +impl UserIpTracker { + /// Queues a deferred active IP cleanup for a later async drain. + pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) { + self.observe_cleanup_poison_for_tests(); + let shard_idx = Self::shard_idx(&user); + let cleanup_shard = &self.cleanup_shards[shard_idx]; + match cleanup_shard.queue.lock() { + Ok(mut queue) => { + let user_queue = queue.entry(user).or_default(); + let count = user_queue.entry(ip).or_insert(0); + if *count == 0 { + self.cleanup_queue_len.fetch_add(1, Ordering::Relaxed); + } + *count = count.saturating_add(1); + self.cleanup_deferred_releases + .fetch_add(1, Ordering::Relaxed); + } + Err(poisoned) => { + let mut queue = poisoned.into_inner(); + let user_queue = queue.entry(user.clone()).or_default(); + let count = user_queue.entry(ip).or_insert(0); + if *count == 0 { + self.cleanup_queue_len.fetch_add(1, Ordering::Relaxed); + } + *count = count.saturating_add(1); + self.cleanup_deferred_releases + .fetch_add(1, Ordering::Relaxed); + cleanup_shard.queue.clear_poison(); + tracing::warn!( + "UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})", + user, + ip + ); + } + } + } + + #[cfg(test)] + pub(crate) fn cleanup_queue_len_for_tests(&self) -> usize { + self.cleanup_queue_len.load(Ordering::Relaxed) as usize + } + + #[cfg(test)] + pub(crate) fn cleanup_queue_mutex_for_tests( + &self, + ) -> Arc>> { + Arc::clone(&self.cleanup_queue_poison_probe) + } + + pub(crate) async fn drain_cleanup_queue(&self) { + if self.cleanup_queue_len.load(Ordering::Relaxed) == 0 { + return; + } + for shard_idx in 0..USER_IP_TRACKER_SHARDS { + self.drain_cleanup_shard(shard_idx).await; + } + } + + pub(super) async fn drain_cleanup_for_user(&self, user: &str) { + if self.cleanup_queue_len.load(Ordering::Relaxed) == 0 { + return; + } + let shard_idx = Self::shard_idx(user); + let cleanup_shard = &self.cleanup_shards[shard_idx]; + let to_remove = match cleanup_shard.queue.lock() { + Ok(mut queue) => queue.remove(user).unwrap_or_default(), + Err(poisoned) => { + let mut queue = poisoned.into_inner(); + let drained = queue.remove(user).unwrap_or_default(); + cleanup_shard.queue.clear_poison(); + drained + } + }; + if to_remove.is_empty() { + return; + } + self.cleanup_queue_len + .fetch_sub(to_remove.len() as u64, Ordering::Relaxed); + let mut shard = self.shards[shard_idx].write().await; + let mut removed_active_entries = 0usize; + for (ip, pending_count) in to_remove { + removed_active_entries = removed_active_entries.saturating_add( + Self::apply_active_cleanup(&mut shard.active_ips, user, ip, pending_count), + ); + } + Self::decrement_counter(&self.active_entry_count, removed_active_entries); + } + + pub(super) async fn drain_cleanup_shard(&self, shard_idx: usize) { + let Ok(_drain_guard) = self.cleanup_drain_locks[shard_idx].try_lock() else { + return; + }; + + let cleanup_shard = &self.cleanup_shards[shard_idx]; + let to_remove = { + match cleanup_shard.queue.lock() { + Ok(mut queue) => { + if queue.is_empty() { + return; + } + let mut drained = + HashMap::with_capacity(queue.len().min(CLEANUP_DRAIN_BATCH_LIMIT)); + for _ in 0..CLEANUP_DRAIN_BATCH_LIMIT { + let Some((user, ip, count)) = Self::pop_one_cleanup(&mut queue) else { + break; + }; + self.cleanup_queue_len.fetch_sub(1, Ordering::Relaxed); + drained.insert((user, ip), count); + } + drained + } + Err(poisoned) => { + let mut queue = poisoned.into_inner(); + if queue.is_empty() { + cleanup_shard.queue.clear_poison(); + return; + } + let mut drained = + HashMap::with_capacity(queue.len().min(CLEANUP_DRAIN_BATCH_LIMIT)); + for _ in 0..CLEANUP_DRAIN_BATCH_LIMIT { + let Some((user, ip, count)) = Self::pop_one_cleanup(&mut queue) else { + break; + }; + self.cleanup_queue_len.fetch_sub(1, Ordering::Relaxed); + drained.insert((user, ip), count); + } + cleanup_shard.queue.clear_poison(); + drained + } + } + }; + drop(_drain_guard); + if to_remove.is_empty() { + return; + } + + let mut shard = self.shards[shard_idx].write().await; + let mut removed_active_entries = 0usize; + for ((user, ip), pending_count) in to_remove { + removed_active_entries = removed_active_entries.saturating_add( + Self::apply_active_cleanup(&mut shard.active_ips, &user, ip, pending_count), + ); + } + Self::decrement_counter(&self.active_entry_count, removed_active_entries); + } +} diff --git a/src/ip_tracker/snapshot.rs b/src/ip_tracker/snapshot.rs new file mode 100644 index 0000000..0d96b4b --- /dev/null +++ b/src/ip_tracker/snapshot.rs @@ -0,0 +1,309 @@ +use super::*; + +impl UserIpTracker { + pub(super) async fn maybe_compact_empty_users(&self) { + const COMPACT_INTERVAL_SECS: u64 = 60; + let now_epoch_secs = Self::now_epoch_secs(); + let last_compact_epoch_secs = self.last_compact_epoch_secs.load(Ordering::Relaxed); + if now_epoch_secs.saturating_sub(last_compact_epoch_secs) < COMPACT_INTERVAL_SECS { + return; + } + if self + .last_compact_epoch_secs + .compare_exchange( + last_compact_epoch_secs, + now_epoch_secs, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_err() + { + return; + } + + let window = self.limit_window(); + let now = Instant::now(); + for shard_lock in self.shards.iter() { + let mut shard = shard_lock.write().await; + let mut pruned_recent_entries = 0usize; + for user_recent in shard.recent_ips.values_mut() { + pruned_recent_entries = pruned_recent_entries.saturating_add(Self::prune_recent( + user_recent, + now, + window, + )); + } + Self::decrement_counter(&self.recent_entry_count, pruned_recent_entries); + + let mut users = Vec::::with_capacity( + shard + .active_ips + .len() + .saturating_add(shard.recent_ips.len()), + ); + users.extend(shard.active_ips.keys().cloned()); + for user in shard.recent_ips.keys() { + if !shard.active_ips.contains_key(user) { + users.push(user.clone()); + } + } + + for user in users { + let active_empty = shard + .active_ips + .get(&user) + .map(|ips| ips.is_empty()) + .unwrap_or(true); + let recent_empty = shard + .recent_ips + .get(&user) + .map(|ips| ips.is_empty()) + .unwrap_or(true); + if active_empty && recent_empty { + shard.active_ips.remove(&user); + shard.recent_ips.remove(&user); + } + } + } + } + + pub async fn run_periodic_maintenance(self: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(1)); + loop { + interval.tick().await; + self.drain_cleanup_queue().await; + self.maybe_compact_empty_users().await; + } + } + + pub async fn memory_stats(&self) -> UserIpTrackerMemoryStats { + let cleanup_queue_len = self.cleanup_queue_len.load(Ordering::Relaxed) as usize; + let mut active_users = 0usize; + let mut recent_users = 0usize; + let mut active_entries = 0usize; + let mut recent_entries = 0usize; + for shard_lock in self.shards.iter() { + let shard = shard_lock.read().await; + active_users = active_users.saturating_add(shard.active_ips.len()); + recent_users = recent_users.saturating_add(shard.recent_ips.len()); + active_entries = + active_entries.saturating_add(shard.active_ips.values().map(HashMap::len).sum()); + recent_entries = + recent_entries.saturating_add(shard.recent_ips.values().map(HashMap::len).sum()); + } + + UserIpTrackerMemoryStats { + active_users, + recent_users, + active_entries, + recent_entries, + cleanup_queue_len, + active_cap_rejects: self.active_cap_rejects.load(Ordering::Relaxed), + recent_cap_rejects: self.recent_cap_rejects.load(Ordering::Relaxed), + cleanup_deferred_releases: self.cleanup_deferred_releases.load(Ordering::Relaxed), + } + } + + pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap { + self.drain_cleanup_queue().await; + self.get_recent_counts_for_users_snapshot(users).await + } + + pub(crate) async fn get_recent_counts_for_users_snapshot( + &self, + users: &[String], + ) -> HashMap { + let window = self.limit_window(); + let now = Instant::now(); + + let mut counts = HashMap::with_capacity(users.len()); + for user in users { + let shard_idx = Self::shard_idx(user); + let shard = self.shards[shard_idx].read().await; + let count = if let Some(user_recent) = shard.recent_ips.get(user) { + user_recent + .values() + .filter(|seen_at| now.duration_since(**seen_at) <= window) + .count() + } else { + 0 + }; + counts.insert(user.clone(), count); + } + counts + } + + pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap> { + self.drain_cleanup_queue().await; + let mut out = HashMap::with_capacity(users.len()); + for user in users { + let shard_idx = Self::shard_idx(user); + let shard = self.shards[shard_idx].read().await; + let mut ips = shard + .active_ips + .get(user) + .map(|per_ip| per_ip.keys().copied().collect::>()) + .unwrap_or_else(Vec::new); + ips.sort(); + out.insert(user.clone(), ips); + } + out + } + + pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap> { + self.drain_cleanup_queue().await; + let window = self.limit_window(); + let now = Instant::now(); + + let mut out = HashMap::with_capacity(users.len()); + for user in users { + let shard_idx = Self::shard_idx(user); + let shard = self.shards[shard_idx].read().await; + let mut ips = if let Some(user_recent) = shard.recent_ips.get(user) { + user_recent + .iter() + .filter(|(_, seen_at)| now.duration_since(**seen_at) <= window) + .map(|(ip, _)| *ip) + .collect::>() + } else { + Vec::new() + }; + ips.sort(); + out.insert(user.clone(), ips); + } + out + } + + pub async fn get_active_ip_count(&self, username: &str) -> usize { + self.drain_cleanup_queue().await; + let shard_idx = Self::shard_idx(username); + let shard = self.shards[shard_idx].read().await; + shard + .active_ips + .get(username) + .map(|ips| ips.len()) + .unwrap_or(0) + } + + pub async fn get_active_ips(&self, username: &str) -> Vec { + self.drain_cleanup_queue().await; + let shard_idx = Self::shard_idx(username); + let shard = self.shards[shard_idx].read().await; + shard + .active_ips + .get(username) + .map(|ips| ips.keys().copied().collect()) + .unwrap_or_else(Vec::new) + } + + pub async fn get_stats(&self) -> Vec<(String, usize, usize)> { + self.drain_cleanup_queue().await; + self.get_stats_snapshot().await + } + + pub(crate) async fn get_stats_snapshot(&self) -> Vec<(String, usize, usize)> { + let mut active_counts = Vec::new(); + for shard_lock in self.shards.iter() { + let shard = shard_lock.read().await; + active_counts.extend( + shard + .active_ips + .iter() + .map(|(username, user_ips)| (username.clone(), user_ips.len())), + ); + } + + let mut stats = Vec::with_capacity(active_counts.len()); + for (username, active_count) in active_counts { + let limit = self.user_limit(&username).unwrap_or(0); + stats.push((username, active_count, limit)); + } + + stats.sort_by(|a, b| a.0.cmp(&b.0)); + stats + } + + pub async fn clear_user_ips(&self, username: &str) { + let shard_idx = Self::shard_idx(username); + let mut shard = self.shards[shard_idx].write().await; + let removed_active_entries = shard + .active_ips + .remove(username) + .map(|ips| ips.len()) + .unwrap_or(0); + Self::decrement_counter(&self.active_entry_count, removed_active_entries); + + let removed_recent_entries = shard + .recent_ips + .remove(username) + .map(|ips| ips.len()) + .unwrap_or(0); + Self::decrement_counter(&self.recent_entry_count, removed_recent_entries); + } + + pub async fn clear_all(&self) { + for shard_lock in self.shards.iter() { + let mut shard = shard_lock.write().await; + shard.active_ips.clear(); + shard.recent_ips.clear(); + } + self.active_entry_count.store(0, Ordering::Relaxed); + self.recent_entry_count.store(0, Ordering::Relaxed); + for cleanup_shard in self.cleanup_shards.iter() { + match cleanup_shard.queue.lock() { + Ok(mut queue) => queue.clear(), + Err(poisoned) => { + poisoned.into_inner().clear(); + cleanup_shard.queue.clear_poison(); + } + } + } + self.cleanup_queue_len.store(0, Ordering::Relaxed); + } + + pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool { + self.drain_cleanup_queue().await; + let shard_idx = Self::shard_idx(username); + let shard = self.shards[shard_idx].read().await; + shard + .active_ips + .get(username) + .map(|ips| ips.contains_key(&ip)) + .unwrap_or(false) + } + + pub async fn get_user_limit(&self, username: &str) -> Option { + self.user_limit(username) + } + + pub async fn format_stats(&self) -> String { + let stats = self.get_stats().await; + + if stats.is_empty() { + return String::from("No active users"); + } + + let mut output = String::from("User IP Statistics:\n"); + output.push_str("==================\n"); + + for (username, active_count, limit) in stats { + output.push_str(&format!( + "User: {:<20} Active IPs: {}/{}\n", + username, + active_count, + if limit > 0 { + limit.to_string() + } else { + "unlimited".to_string() + } + )); + + let ips = self.get_active_ips(&username).await; + for ip in ips { + output.push_str(&format!(" - {}\n", ip)); + } + } + + output + } +} diff --git a/src/ip_tracker/tests.rs b/src/ip_tracker/tests.rs new file mode 100644 index 0000000..ba1b1fe --- /dev/null +++ b/src/ip_tracker/tests.rs @@ -0,0 +1,385 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::sync::atomic::Ordering; + +fn test_ipv4(oct1: u8, oct2: u8, oct3: u8, oct4: u8) -> IpAddr { + IpAddr::V4(Ipv4Addr::new(oct1, oct2, oct3, oct4)) +} + +fn test_ipv6() -> IpAddr { + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)) +} + +#[tokio::test] +async fn test_basic_ip_limit() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 2).await; + + let ip1 = test_ipv4(192, 168, 1, 1); + let ip2 = test_ipv4(192, 168, 1, 2); + let ip3 = test_ipv4(192, 168, 1, 3); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip3).await.is_err()); + + assert_eq!(tracker.get_active_ip_count("test_user").await, 2); +} + +#[tokio::test] +async fn test_active_window_rejects_new_ip_and_keeps_existing_session() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30) + .await; + + let ip1 = test_ipv4(10, 10, 10, 1); + let ip2 = test_ipv4(10, 10, 10, 2); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.is_ip_active("test_user", ip1).await); + assert!(tracker.check_and_add("test_user", ip2).await.is_err()); + + // Existing session remains active; only new unique IP is denied. + assert!(tracker.is_ip_active("test_user", ip1).await); + assert_eq!(tracker.get_active_ip_count("test_user").await, 1); +} + +#[tokio::test] +async fn test_reconnection_from_same_ip() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 2).await; + + let ip1 = test_ipv4(192, 168, 1, 1); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert_eq!(tracker.get_active_ip_count("test_user").await, 1); +} + +#[tokio::test] +async fn test_same_ip_disconnect_keeps_active_while_other_session_alive() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 2).await; + + let ip1 = test_ipv4(192, 168, 1, 1); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert_eq!(tracker.get_active_ip_count("test_user").await, 1); + + tracker.remove_ip("test_user", ip1).await; + assert_eq!(tracker.get_active_ip_count("test_user").await, 1); + + tracker.remove_ip("test_user", ip1).await; + assert_eq!(tracker.get_active_ip_count("test_user").await, 0); +} + +#[tokio::test] +async fn test_ip_removal() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 2).await; + + let ip1 = test_ipv4(192, 168, 1, 1); + let ip2 = test_ipv4(192, 168, 1, 2); + let ip3 = test_ipv4(192, 168, 1, 3); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip3).await.is_err()); + + tracker.remove_ip("test_user", ip1).await; + + assert!(tracker.check_and_add("test_user", ip3).await.is_ok()); + assert_eq!(tracker.get_active_ip_count("test_user").await, 2); +} + +#[tokio::test] +async fn test_no_limit() { + let tracker = UserIpTracker::new(); + + let ip1 = test_ipv4(192, 168, 1, 1); + let ip2 = test_ipv4(192, 168, 1, 2); + let ip3 = test_ipv4(192, 168, 1, 3); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip3).await.is_ok()); + + assert_eq!(tracker.get_active_ip_count("test_user").await, 3); +} + +#[tokio::test] +async fn test_multiple_users() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("user1", 2).await; + tracker.set_user_limit("user2", 1).await; + + let ip1 = test_ipv4(192, 168, 1, 1); + let ip2 = test_ipv4(192, 168, 1, 2); + + assert!(tracker.check_and_add("user1", ip1).await.is_ok()); + assert!(tracker.check_and_add("user1", ip2).await.is_ok()); + + assert!(tracker.check_and_add("user2", ip1).await.is_ok()); + assert!(tracker.check_and_add("user2", ip2).await.is_err()); +} + +#[tokio::test] +async fn test_ipv6_support() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 2).await; + + let ipv4 = test_ipv4(192, 168, 1, 1); + let ipv6 = test_ipv6(); + + assert!(tracker.check_and_add("test_user", ipv4).await.is_ok()); + assert!(tracker.check_and_add("test_user", ipv6).await.is_ok()); + + assert_eq!(tracker.get_active_ip_count("test_user").await, 2); +} + +#[tokio::test] +async fn test_get_active_ips() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 3).await; + + let ip1 = test_ipv4(192, 168, 1, 1); + let ip2 = test_ipv4(192, 168, 1, 2); + + tracker.check_and_add("test_user", ip1).await.unwrap(); + tracker.check_and_add("test_user", ip2).await.unwrap(); + + let active_ips = tracker.get_active_ips("test_user").await; + assert_eq!(active_ips.len(), 2); + assert!(active_ips.contains(&ip1)); + assert!(active_ips.contains(&ip2)); +} + +#[tokio::test] +async fn test_stats() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("user1", 3).await; + tracker.set_user_limit("user2", 2).await; + + let ip1 = test_ipv4(192, 168, 1, 1); + let ip2 = test_ipv4(192, 168, 1, 2); + + tracker.check_and_add("user1", ip1).await.unwrap(); + tracker.check_and_add("user2", ip2).await.unwrap(); + + let stats = tracker.get_stats().await; + assert_eq!(stats.len(), 2); + + assert!(stats.iter().any(|(name, _, _)| name == "user1")); + assert!(stats.iter().any(|(name, _, _)| name == "user2")); +} + +#[tokio::test] +async fn test_clear_user_ips() { + let tracker = UserIpTracker::new(); + let ip1 = test_ipv4(192, 168, 1, 1); + + tracker.check_and_add("test_user", ip1).await.unwrap(); + assert_eq!(tracker.get_active_ip_count("test_user").await, 1); + + tracker.clear_user_ips("test_user").await; + assert_eq!(tracker.get_active_ip_count("test_user").await, 0); +} + +#[tokio::test] +async fn test_is_ip_active() { + let tracker = UserIpTracker::new(); + let ip1 = test_ipv4(192, 168, 1, 1); + let ip2 = test_ipv4(192, 168, 1, 2); + + tracker.check_and_add("test_user", ip1).await.unwrap(); + + assert!(tracker.is_ip_active("test_user", ip1).await); + assert!(!tracker.is_ip_active("test_user", ip2).await); +} + +#[tokio::test] +async fn test_load_limits_from_config() { + let tracker = UserIpTracker::new(); + + let mut config_limits = HashMap::new(); + config_limits.insert("user1".to_string(), 5); + config_limits.insert("user2".to_string(), 3); + + tracker.load_limits(0, &config_limits).await; + + assert_eq!(tracker.get_user_limit("user1").await, Some(5)); + assert_eq!(tracker.get_user_limit("user2").await, Some(3)); + assert_eq!(tracker.get_user_limit("user3").await, None); +} + +#[tokio::test] +async fn test_load_limits_replaces_previous_map() { + let tracker = UserIpTracker::new(); + + let mut first = HashMap::new(); + first.insert("user1".to_string(), 2); + first.insert("user2".to_string(), 3); + tracker.load_limits(0, &first).await; + + let mut second = HashMap::new(); + second.insert("user2".to_string(), 5); + tracker.load_limits(0, &second).await; + + assert_eq!(tracker.get_user_limit("user1").await, None); + assert_eq!(tracker.get_user_limit("user2").await, Some(5)); +} + +#[tokio::test] +async fn test_global_each_limit_applies_without_user_override() { + let tracker = UserIpTracker::new(); + tracker.load_limits(2, &HashMap::new()).await; + + let ip1 = test_ipv4(172, 16, 0, 1); + let ip2 = test_ipv4(172, 16, 0, 2); + let ip3 = test_ipv4(172, 16, 0, 3); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip3).await.is_err()); + assert_eq!(tracker.get_user_limit("test_user").await, Some(2)); +} + +#[tokio::test] +async fn test_user_override_wins_over_global_each_limit() { + let tracker = UserIpTracker::new(); + let mut limits = HashMap::new(); + limits.insert("test_user".to_string(), 1); + tracker.load_limits(3, &limits).await; + + let ip1 = test_ipv4(172, 17, 0, 1); + let ip2 = test_ipv4(172, 17, 0, 2); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip2).await.is_err()); + assert_eq!(tracker.get_user_limit("test_user").await, Some(1)); +} + +#[tokio::test] +async fn test_time_window_mode_blocks_recent_ip_churn() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 30) + .await; + + let ip1 = test_ipv4(10, 0, 0, 1); + let ip2 = test_ipv4(10, 0, 0, 2); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + tracker.remove_ip("test_user", ip1).await; + assert!(tracker.check_and_add("test_user", ip2).await.is_err()); +} + +#[tokio::test] +async fn test_combined_mode_enforces_active_and_recent_limits() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::Combined, 30) + .await; + + let ip1 = test_ipv4(10, 0, 1, 1); + let ip2 = test_ipv4(10, 0, 1, 2); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + assert!(tracker.check_and_add("test_user", ip2).await.is_err()); + + tracker.remove_ip("test_user", ip1).await; + assert!(tracker.check_and_add("test_user", ip2).await.is_err()); +} + +#[tokio::test] +async fn test_time_window_expires() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) + .await; + + let ip1 = test_ipv4(10, 1, 0, 1); + let ip2 = test_ipv4(10, 1, 0, 2); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + tracker.remove_ip("test_user", ip1).await; + assert!(tracker.check_and_add("test_user", ip2).await.is_err()); + + tokio::time::sleep(Duration::from_millis(1100)).await; + assert!(tracker.check_and_add("test_user", ip2).await.is_ok()); +} + +#[tokio::test] +async fn test_memory_stats_reports_queue_and_entry_counts() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 4).await; + let ip1 = test_ipv4(10, 2, 0, 1); + let ip2 = test_ipv4(10, 2, 0, 2); + + tracker.check_and_add("test_user", ip1).await.unwrap(); + tracker.check_and_add("test_user", ip2).await.unwrap(); + tracker.enqueue_cleanup("test_user".to_string(), ip1); + + let snapshot = tracker.memory_stats().await; + assert_eq!(snapshot.active_users, 1); + assert_eq!(snapshot.recent_users, 1); + assert_eq!(snapshot.active_entries, 2); + assert_eq!(snapshot.recent_entries, 2); + assert_eq!(snapshot.cleanup_queue_len, 1); +} + +#[tokio::test] +async fn test_compact_prunes_stale_recent_entries() { + let tracker = UserIpTracker::new(); + tracker + .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) + .await; + + let stale_user = "stale-user".to_string(); + let stale_ip = test_ipv4(10, 3, 0, 1); + { + let shard_idx = UserIpTracker::shard_idx(&stale_user); + let mut shard = tracker.shards[shard_idx].write().await; + shard + .recent_ips + .entry(stale_user.clone()) + .or_insert_with(HashMap::new) + .insert(stale_ip, Instant::now() - Duration::from_secs(5)); + } + + tracker.last_compact_epoch_secs.store(0, Ordering::Relaxed); + tracker + .check_and_add("trigger-user", test_ipv4(10, 3, 0, 2)) + .await + .unwrap(); + + let shard_idx = UserIpTracker::shard_idx(&stale_user); + let shard = tracker.shards[shard_idx].read().await; + let stale_exists = shard + .recent_ips + .get(&stale_user) + .map(|ips| ips.contains_key(&stale_ip)) + .unwrap_or(false); + assert!(!stale_exists); +} + +#[tokio::test] +async fn test_time_window_allows_same_ip_reconnect() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("test_user", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) + .await; + + let ip1 = test_ipv4(10, 4, 0, 1); + + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); + tracker.remove_ip("test_user", ip1).await; + assert!(tracker.check_and_add("test_user", ip1).await.is_ok()); +} diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index 0f3c14e..834704f 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -47,7 +47,9 @@ use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; use crate::transport::UpstreamManager; use crate::transport::middle_proxy::MePool; -use helpers::{parse_cli, print_maestro_line, resolve_runtime_base_dir, resolve_runtime_config_path}; +use helpers::{ + parse_cli, print_maestro_line, resolve_runtime_base_dir, resolve_runtime_config_path, +}; #[cfg(unix)] use crate::daemon::{DaemonOptions, PidFile, drop_privileges}; @@ -463,8 +465,7 @@ async fn run_telemt_core( let (api_config_tx, api_config_rx) = watch::channel(Arc::new(config.clone())); let (detected_ips_tx, detected_ips_rx) = watch::channel((None::, None::)); - let initial_direct_first = - config.general.use_middle_proxy && config.general.me2dc_fallback; + let initial_direct_first = config.general.use_middle_proxy && config.general.me2dc_fallback; let initial_admission_open = !config.general.use_middle_proxy || initial_direct_first; let (admission_tx, admission_rx) = watch::channel(initial_admission_open); let initial_route_mode = if !config.general.use_middle_proxy || initial_direct_first { @@ -694,7 +695,9 @@ async fn run_telemt_core( if direct_first_startup { startup_tracker.set_transport_mode("direct").await; startup_tracker.set_degraded(true).await; - info!("Transport: Direct DC startup fallback active; Middle-End bootstrap continues in background"); + info!( + "Transport: Direct DC startup fallback active; Middle-End bootstrap continues in background" + ); } else if me_pool.is_some() { startup_tracker.set_transport_mode("middle_proxy").await; startup_tracker.set_degraded(false).await; @@ -840,7 +843,9 @@ async fn run_telemt_core( if let Some(pool) = api_me_pool_ready.read().await.as_ref() { pool.set_runtime_ready(true); } - startup_tracker_ready.set_transport_mode("middle_proxy").await; + startup_tracker_ready + .set_transport_mode("middle_proxy") + .await; startup_tracker_ready.set_degraded(false).await; info!("Transport: Middle-End Proxy restored for new sessions"); } diff --git a/src/metrics.rs b/src/metrics.rs index e32cf06..61a26c5 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -726,6 +726,37 @@ async fn render_metrics( } ); + let _ = writeln!( + out, + "# HELP telemt_route_cutover_parked_current Sessions currently parked in route cutover stagger delay" + ); + let _ = writeln!(out, "# TYPE telemt_route_cutover_parked_current gauge"); + let _ = writeln!( + out, + "telemt_route_cutover_parked_current{{route=\"direct\"}} {}", + stats.get_route_cutover_parked_direct_current() + ); + let _ = writeln!( + out, + "telemt_route_cutover_parked_current{{route=\"middle\"}} {}", + stats.get_route_cutover_parked_middle_current() + ); + let _ = writeln!( + out, + "# HELP telemt_route_cutover_parked_total Sessions parked in route cutover stagger delay" + ); + let _ = writeln!(out, "# TYPE telemt_route_cutover_parked_total counter"); + let _ = writeln!( + out, + "telemt_route_cutover_parked_total{{route=\"direct\"}} {}", + stats.get_route_cutover_parked_direct_total() + ); + let _ = writeln!( + out, + "telemt_route_cutover_parked_total{{route=\"middle\"}} {}", + stats.get_route_cutover_parked_middle_total() + ); + let _ = writeln!( out, "# HELP telemt_quota_refund_bytes_total Reserved quota bytes returned before commit" diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index efebcd9..2fea54a 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -358,6 +358,7 @@ where delay_ms = delay.as_millis() as u64, "Cutover affected direct session, closing client connection" ); + let _cutover_park_lease = stats.acquire_direct_cutover_park_lease(); tokio::time::sleep(delay).await; break Err(ProxyError::RouteSwitched); } diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 85d737f..e0631c8 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -515,12 +515,28 @@ fn exclusive_mask_target_for_sni<'a>( config: &'a ProxyConfig, sni: &str, ) -> Option> { + if let Some(target) = config.censorship.exclusive_mask_targets.get(sni) { + return Some(MaskTcpTarget { + host: target.host.as_str(), + port: target.port, + }); + } if let Some(target) = config.censorship.exclusive_mask.get(sni) { return parse_exclusive_mask_target(target); } if sni.bytes().any(|byte| byte.is_ascii_uppercase()) { let normalized_sni = sni.to_ascii_lowercase(); + if let Some(target) = config + .censorship + .exclusive_mask_targets + .get(&normalized_sni) + { + return Some(MaskTcpTarget { + host: target.host.as_str(), + port: target.port, + }); + } if let Some(target) = config.censorship.exclusive_mask.get(&normalized_sni) { return parse_exclusive_mask_target(target); } diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index c8d6ce0..3f7c31c 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -35,14 +35,59 @@ use crate::stats::{ use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; -enum C2MeCommand { - Data { - payload: PooledBuffer, - flags: u32, - _permit: OwnedSemaphorePermit, - }, - Close, -} +mod c2me; +mod d2c; +mod desync; +mod idle; +mod quota; +mod session; + +pub(crate) use self::desync::DesyncDedupRotationState; +pub(crate) use self::idle::{RelayIdleCandidateRegistry, note_global_relay_pressure}; +pub(crate) use self::session::handle_via_middle_proxy; + +use self::c2me::{ + C2MeCommand, acquire_c2me_payload_permit, c2me_queued_permit_budget, enqueue_c2me_command_in, + should_yield_c2me_sender, +}; +use self::d2c::{ + MeD2cFlushPolicy, MeWriterResponseOutcome, classify_me_d2c_flush_reason, + flush_client_or_cancel, observe_me_d2c_flush_event, + process_me_writer_response_with_traffic_lease, +}; +use self::desync::{RelayForensicsState, hash_ip_in, report_desync_frame_too_large_in}; +use self::idle::{ + RelayClientIdlePolicy, RelayClientIdleState, clear_relay_idle_candidate_in, + maybe_evict_idle_candidate_on_pressure_in, note_relay_pressure_event_in, + read_client_payload_with_idle_policy_in, relay_pressure_event_seq_in, +}; +use self::quota::{ + MiddleQuotaReserveError, quota_soft_cap, reserve_user_quota_with_yield, + wait_for_traffic_budget, wait_for_traffic_budget_or_cancel, +}; + +#[cfg(test)] +use self::c2me::enqueue_c2me_command; +#[cfg(test)] +use self::d2c::{compute_intermediate_secure_wire_len, process_me_writer_response}; +#[cfg(test)] +pub(crate) use self::desync::{ + clear_desync_dedup_for_testing_in_shared, desync_dedup_get_for_testing, + desync_dedup_insert_for_testing, desync_dedup_keys_for_testing, desync_dedup_len_for_testing, + desync_forensics_len_bytes, hash_ip, report_desync_frame_too_large, + should_emit_full_desync_for_testing, +}; +#[cfg(test)] +use self::idle::RelayIdleCandidateMeta; +#[cfg(test)] +pub(crate) use self::idle::{ + clear_relay_idle_candidate_for_testing, clear_relay_idle_pressure_state_for_testing_in_shared, + mark_relay_idle_candidate_for_testing, maybe_evict_idle_candidate_on_pressure_for_testing, + note_relay_pressure_event_for_testing, oldest_relay_idle_candidate_for_testing, + read_client_payload, read_client_payload_legacy, read_client_payload_with_idle_policy, + relay_idle_mark_seq_for_testing, relay_pressure_event_seq_for_testing, + set_relay_pressure_state_for_testing, +}; const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536; @@ -68,1061 +113,6 @@ const QUOTA_RESERVE_BACKOFF_MAX_MS: u64 = 16; const QUOTA_RESERVE_MAX_BACKOFF_ROUNDS: usize = 16; const ME_CHILD_JOIN_TIMEOUT: Duration = Duration::from_secs(2); -enum MiddleQuotaReserveError { - LimitExceeded, - Contended, - Cancelled, - DeadlineExceeded, -} - -#[derive(Default)] -pub(crate) struct DesyncDedupRotationState { - current_started_at: Option, -} - -struct RelayForensicsState { - trace_id: u64, - conn_id: u64, - user: String, - peer: SocketAddr, - peer_hash: u64, - started_at: Instant, - bytes_c2me: u64, - bytes_me2c: Arc, - desync_all_full: bool, -} - -#[derive(Default)] -pub(crate) struct RelayIdleCandidateRegistry { - by_conn_id: HashMap, - ordered: BTreeSet<(u64, u64)>, - pressure_event_seq: u64, - pressure_consumed_seq: u64, -} - -#[derive(Clone, Copy)] -struct RelayIdleCandidateMeta { - mark_order_seq: u64, - mark_pressure_seq: u64, -} - -fn relay_idle_candidate_registry_lock_in( - shared: &ProxySharedState, -) -> std::sync::MutexGuard<'_, RelayIdleCandidateRegistry> { - let registry = &shared.middle_relay.relay_idle_registry; - match registry.lock() { - Ok(guard) => guard, - Err(poisoned) => { - let mut guard = poisoned.into_inner(); - *guard = RelayIdleCandidateRegistry::default(); - registry.clear_poison(); - guard - } - } -} - -fn mark_relay_idle_candidate_in(shared: &ProxySharedState, conn_id: u64) -> bool { - let mut guard = relay_idle_candidate_registry_lock_in(shared); - - if guard.by_conn_id.contains_key(&conn_id) { - return false; - } - - let mark_order_seq = shared - .middle_relay - .relay_idle_mark_seq - .fetch_add(1, Ordering::Relaxed) - .saturating_add(1); - let meta = RelayIdleCandidateMeta { - mark_order_seq, - mark_pressure_seq: guard.pressure_event_seq, - }; - guard.by_conn_id.insert(conn_id, meta); - guard.ordered.insert((meta.mark_order_seq, conn_id)); - true -} - -fn clear_relay_idle_candidate_in(shared: &ProxySharedState, conn_id: u64) { - let mut guard = relay_idle_candidate_registry_lock_in(shared); - - if let Some(meta) = guard.by_conn_id.remove(&conn_id) { - guard.ordered.remove(&(meta.mark_order_seq, conn_id)); - } -} - -fn note_relay_pressure_event_in(shared: &ProxySharedState) { - let mut guard = relay_idle_candidate_registry_lock_in(shared); - guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1); -} - -pub(crate) fn note_global_relay_pressure(shared: &ProxySharedState) { - note_relay_pressure_event_in(shared); -} - -fn relay_pressure_event_seq_in(shared: &ProxySharedState) -> u64 { - let guard = relay_idle_candidate_registry_lock_in(shared); - guard.pressure_event_seq -} - -fn maybe_evict_idle_candidate_on_pressure_in( - shared: &ProxySharedState, - conn_id: u64, - seen_pressure_seq: &mut u64, - stats: &Stats, -) -> bool { - let mut guard = relay_idle_candidate_registry_lock_in(shared); - - let latest_pressure_seq = guard.pressure_event_seq; - if latest_pressure_seq == *seen_pressure_seq { - return false; - } - *seen_pressure_seq = latest_pressure_seq; - - if latest_pressure_seq == guard.pressure_consumed_seq { - return false; - } - - if guard.ordered.is_empty() { - guard.pressure_consumed_seq = latest_pressure_seq; - return false; - } - - let oldest = guard - .ordered - .iter() - .next() - .map(|(_, candidate_conn_id)| *candidate_conn_id); - if oldest != Some(conn_id) { - return false; - } - - let Some(candidate_meta) = guard.by_conn_id.get(&conn_id).copied() else { - return false; - }; - - if latest_pressure_seq == candidate_meta.mark_pressure_seq { - return false; - } - - if let Some(meta) = guard.by_conn_id.remove(&conn_id) { - guard.ordered.remove(&(meta.mark_order_seq, conn_id)); - } - guard.pressure_consumed_seq = latest_pressure_seq; - stats.increment_relay_pressure_evict_total(); - true -} - -#[derive(Clone, Copy)] -struct MeD2cFlushPolicy { - max_frames: usize, - max_bytes: usize, - max_delay: Duration, - ack_flush_immediate: bool, - quota_soft_overshoot_bytes: u64, - frame_buf_shrink_threshold_bytes: usize, -} - -#[derive(Clone, Copy)] -struct RelayClientIdlePolicy { - enabled: bool, - soft_idle: Duration, - hard_idle: Duration, - grace_after_downstream_activity: Duration, - legacy_frame_read_timeout: Duration, -} - -impl RelayClientIdlePolicy { - fn from_config(config: &ProxyConfig) -> Self { - let frame_read_timeout = - Duration::from_secs(config.timeouts.relay_client_idle_hard_secs.max(1)); - if !config.timeouts.relay_idle_policy_v2_enabled { - return Self::disabled(frame_read_timeout); - } - - let soft_idle = Duration::from_secs(config.timeouts.relay_client_idle_soft_secs.max(1)); - let hard_idle = Duration::from_secs(config.timeouts.relay_client_idle_hard_secs.max(1)); - let grace_after_downstream_activity = Duration::from_secs( - config - .timeouts - .relay_idle_grace_after_downstream_activity_secs, - ); - - Self { - enabled: true, - soft_idle, - hard_idle, - grace_after_downstream_activity, - legacy_frame_read_timeout: frame_read_timeout, - } - } - - fn disabled(frame_read_timeout: Duration) -> Self { - Self { - enabled: false, - soft_idle: frame_read_timeout, - hard_idle: frame_read_timeout, - grace_after_downstream_activity: Duration::ZERO, - legacy_frame_read_timeout: frame_read_timeout, - } - } - - fn apply_pressure_caps(&mut self, profile: ConntrackPressureProfile) { - let pressure_soft_idle_cap = Duration::from_secs(profile.middle_soft_idle_cap_secs()); - let pressure_hard_idle_cap = Duration::from_secs(profile.middle_hard_idle_cap_secs()); - - self.soft_idle = self.soft_idle.min(pressure_soft_idle_cap); - self.hard_idle = self.hard_idle.min(pressure_hard_idle_cap); - if self.soft_idle > self.hard_idle { - self.soft_idle = self.hard_idle; - } - self.legacy_frame_read_timeout = self.legacy_frame_read_timeout.min(pressure_hard_idle_cap); - if self.grace_after_downstream_activity > self.hard_idle { - self.grace_after_downstream_activity = self.hard_idle; - } - } -} - -#[derive(Clone, Copy)] -struct RelayClientIdleState { - last_client_frame_at: Instant, - soft_idle_marked: bool, - tiny_frame_debt: u32, -} - -impl RelayClientIdleState { - fn new(now: Instant) -> Self { - Self { - last_client_frame_at: now, - soft_idle_marked: false, - tiny_frame_debt: 0, - } - } - - fn on_client_frame(&mut self, now: Instant) { - self.last_client_frame_at = now; - self.soft_idle_marked = false; - } - - fn on_client_tiny_frame(&mut self, now: Instant) { - self.last_client_frame_at = now; - } -} - -impl MeD2cFlushPolicy { - fn from_config(config: &ProxyConfig) -> Self { - Self { - max_frames: config - .general - .me_d2c_flush_batch_max_frames - .max(ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN), - max_bytes: config - .general - .me_d2c_flush_batch_max_bytes - .max(ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN), - max_delay: Duration::from_micros(config.general.me_d2c_flush_batch_max_delay_us), - ack_flush_immediate: config.general.me_d2c_ack_flush_immediate, - quota_soft_overshoot_bytes: config.general.me_quota_soft_overshoot_bytes, - frame_buf_shrink_threshold_bytes: config - .general - .me_d2c_frame_buf_shrink_threshold_bytes - .max(4096), - } - } -} - -#[cfg(test)] -fn hash_value(value: &T) -> u64 { - let mut hasher = DefaultHasher::new(); - value.hash(&mut hasher); - hasher.finish() -} - -fn hash_value_in(shared: &ProxySharedState, value: &T) -> u64 { - shared.middle_relay.desync_hasher.hash_one(value) -} - -#[cfg(test)] -fn hash_ip(ip: IpAddr) -> u64 { - hash_value(&ip) -} - -fn hash_ip_in(shared: &ProxySharedState, ip: IpAddr) -> u64 { - hash_value_in(shared, &ip) -} - -fn should_emit_full_desync_in( - shared: &ProxySharedState, - key: u64, - all_full: bool, - now: Instant, -) -> bool { - if all_full { - return true; - } - - let dedup_current = &shared.middle_relay.desync_dedup; - let dedup_previous = &shared.middle_relay.desync_dedup_previous; - let rotation_state = &shared.middle_relay.desync_dedup_rotation_state; - - let mut state = match rotation_state.lock() { - Ok(guard) => guard, - Err(poisoned) => { - let mut guard = poisoned.into_inner(); - *guard = DesyncDedupRotationState::default(); - rotation_state.clear_poison(); - guard - } - }; - - let rotate_now = match state.current_started_at { - Some(current_started_at) => match now.checked_duration_since(current_started_at) { - Some(elapsed) => elapsed >= DESYNC_DEDUP_WINDOW, - None => true, - }, - None => true, - }; - if rotate_now { - dedup_previous.clear(); - for entry in dedup_current.iter() { - dedup_previous.insert(*entry.key(), *entry.value()); - } - dedup_current.clear(); - state.current_started_at = Some(now); - } - - if let Some(seen_at) = dedup_current.get(&key).map(|entry| *entry.value()) { - let within_window = match now.checked_duration_since(seen_at) { - Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW, - None => true, - }; - if within_window { - return false; - } - dedup_current.insert(key, now); - return true; - } - - if let Some(seen_at) = dedup_previous.get(&key).map(|entry| *entry.value()) { - let within_window = match now.checked_duration_since(seen_at) { - Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW, - None => true, - }; - if within_window { - dedup_current.insert(key, seen_at); - return false; - } - dedup_previous.remove(&key); - } - - if dedup_current.len() >= DESYNC_DEDUP_MAX_ENTRIES { - dedup_previous.clear(); - for entry in dedup_current.iter() { - dedup_previous.insert(*entry.key(), *entry.value()); - } - dedup_current.clear(); - state.current_started_at = Some(now); - dedup_current.insert(key, now); - should_emit_full_desync_full_cache_in(shared, now) - } else { - dedup_current.insert(key, now); - true - } -} - -fn should_emit_full_desync_full_cache_in(shared: &ProxySharedState, now: Instant) -> bool { - let gate = &shared.middle_relay.desync_full_cache_last_emit_at; - let Ok(mut last_emit_at) = gate.lock() else { - return false; - }; - - match *last_emit_at { - None => { - *last_emit_at = Some(now); - true - } - Some(last) => { - let Some(elapsed) = now.checked_duration_since(last) else { - *last_emit_at = Some(now); - return true; - }; - if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL { - *last_emit_at = Some(now); - true - } else { - false - } - } - } -} - -fn desync_forensics_len_bytes(len: usize) -> ([u8; 4], bool) { - match u32::try_from(len) { - Ok(value) => (value.to_le_bytes(), false), - Err(_) => (u32::MAX.to_le_bytes(), true), - } -} - -fn report_desync_frame_too_large_in( - shared: &ProxySharedState, - state: &RelayForensicsState, - proto_tag: ProtoTag, - frame_counter: u64, - max_frame: usize, - len: usize, - raw_len_bytes: Option<[u8; 4]>, - stats: &Stats, -) -> ProxyError { - let (fallback_len_buf, len_buf_truncated) = desync_forensics_len_bytes(len); - let len_buf = raw_len_bytes.unwrap_or(fallback_len_buf); - let looks_like_tls = raw_len_bytes - .map(|b| b[0] == 0x16 && b[1] == 0x03) - .unwrap_or(false); - let looks_like_http = raw_len_bytes - .map(|b| matches!(b[0], b'G' | b'P' | b'H' | b'C' | b'D')) - .unwrap_or(false); - let now = Instant::now(); - let dedup_key = hash_value_in( - shared, - &( - state.user.as_str(), - state.peer_hash, - proto_tag, - DESYNC_ERROR_CLASS, - ), - ); - let emit_full = should_emit_full_desync_in(shared, dedup_key, state.desync_all_full, now); - let duration_ms = state.started_at.elapsed().as_millis() as u64; - let bytes_me2c = state.bytes_me2c.load(Ordering::Relaxed); - - stats.increment_desync_total(); - stats.increment_relay_protocol_desync_close_total(); - stats.observe_desync_frames_ok(frame_counter); - if emit_full { - stats.increment_desync_full_logged(); - warn!( - trace_id = format_args!("0x{:016x}", state.trace_id), - conn_id = state.conn_id, - user = %state.user, - peer_hash = format_args!("0x{:016x}", state.peer_hash), - proto = ?proto_tag, - mode = "middle_proxy", - is_tls = true, - duration_ms, - bytes_c2me = state.bytes_c2me, - bytes_me2c, - raw_len = len, - raw_len_hex = format_args!("0x{:08x}", len), - raw_len_bytes_truncated = len_buf_truncated, - raw_bytes = format_args!( - "{:02x} {:02x} {:02x} {:02x}", - len_buf[0], len_buf[1], len_buf[2], len_buf[3] - ), - max_frame, - tls_like = looks_like_tls, - http_like = looks_like_http, - frames_ok = frame_counter, - dedup_window_secs = DESYNC_DEDUP_WINDOW.as_secs(), - desync_all_full = state.desync_all_full, - full_reason = if state.desync_all_full { "desync_all_full" } else { "first_in_dedup_window" }, - error_class = DESYNC_ERROR_CLASS, - "Frame too large — crypto desync forensics" - ); - debug!( - trace_id = format_args!("0x{:016x}", state.trace_id), - conn_id = state.conn_id, - user = %state.user, - peer = %state.peer, - "Frame too large forensic peer detail" - ); - } else { - stats.increment_desync_suppressed(); - debug!( - trace_id = format_args!("0x{:016x}", state.trace_id), - conn_id = state.conn_id, - user = %state.user, - peer_hash = format_args!("0x{:016x}", state.peer_hash), - proto = ?proto_tag, - duration_ms, - bytes_c2me = state.bytes_c2me, - bytes_me2c, - raw_len = len, - frames_ok = frame_counter, - dedup_window_secs = DESYNC_DEDUP_WINDOW.as_secs(), - error_class = DESYNC_ERROR_CLASS, - "Frame too large — crypto desync forensic suppressed" - ); - } - - ProxyError::Proxy(format!( - "Frame too large: {len} (max {max_frame}), frames_ok={frame_counter}, conn_id={}, trace_id=0x{:016x}", - state.conn_id, state.trace_id - )) -} - -#[cfg(test)] -fn report_desync_frame_too_large( - state: &RelayForensicsState, - proto_tag: ProtoTag, - frame_counter: u64, - max_frame: usize, - len: usize, - raw_len_bytes: Option<[u8; 4]>, - stats: &Stats, -) -> ProxyError { - let shared = ProxySharedState::new(); - report_desync_frame_too_large_in( - shared.as_ref(), - state, - proto_tag, - frame_counter, - max_frame, - len, - raw_len_bytes, - stats, - ) -} - -fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool { - has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET -} - -fn c2me_payload_permits(payload_len: usize) -> u32 { - payload_len - .max(1) - .div_ceil(C2ME_QUEUED_BYTE_PERMIT_UNIT) - .min(u32::MAX as usize) as u32 -} - -fn c2me_queued_permit_budget(channel_capacity: usize, frame_limit: usize) -> usize { - channel_capacity - .saturating_mul(C2ME_QUEUED_PERMITS_PER_SLOT) - .max(c2me_payload_permits(frame_limit) as usize) - .max(1) -} - -async fn acquire_c2me_payload_permit( - semaphore: &Arc, - payload_len: usize, - send_timeout: Option, - stats: &Stats, -) -> Result { - let permits = c2me_payload_permits(payload_len); - let acquire = semaphore.clone().acquire_many_owned(permits); - match send_timeout { - Some(send_timeout) => match timeout(send_timeout, acquire).await { - Ok(Ok(permit)) => Ok(permit), - Ok(Err(_)) => Err(ProxyError::Proxy("ME sender byte budget closed".into())), - Err(_) => { - stats.increment_me_c2me_send_timeout_total(); - Err(ProxyError::Proxy("ME sender byte budget timeout".into())) - } - }, - None => acquire - .await - .map_err(|_| ProxyError::Proxy("ME sender byte budget closed".into())), - } -} - -fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { - limit.saturating_add(overshoot) -} - -async fn reserve_user_quota_with_yield( - user_stats: &UserStats, - bytes: u64, - limit: u64, - stats: &Stats, - cancel: &CancellationToken, - deadline: Option, -) -> std::result::Result { - let mut backoff_ms = QUOTA_RESERVE_BACKOFF_MIN_MS; - let mut backoff_rounds = 0usize; - loop { - for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { - match user_stats.quota_try_reserve(bytes, limit) { - Ok(total) => return Ok(total), - Err(QuotaReserveError::LimitExceeded) => { - return Err(MiddleQuotaReserveError::LimitExceeded); - } - Err(QuotaReserveError::Contended) => { - stats.increment_quota_contention_total(); - std::hint::spin_loop(); - } - } - } - - tokio::task::yield_now().await; - if deadline.is_some_and(|deadline| Instant::now() >= deadline) { - stats.increment_quota_contention_timeout_total(); - return Err(MiddleQuotaReserveError::DeadlineExceeded); - } - tokio::select! { - _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {} - _ = cancel.cancelled() => { - stats.increment_quota_acquire_cancelled_total(); - return Err(MiddleQuotaReserveError::Cancelled); - } - } - backoff_rounds = backoff_rounds.saturating_add(1); - if backoff_rounds >= QUOTA_RESERVE_MAX_BACKOFF_ROUNDS { - stats.increment_quota_contention_timeout_total(); - return Err(MiddleQuotaReserveError::Contended); - } - backoff_ms = backoff_ms - .saturating_mul(2) - .min(QUOTA_RESERVE_BACKOFF_MAX_MS); - } -} - -async fn wait_for_traffic_budget( - lease: Option<&Arc>, - direction: RateDirection, - bytes: u64, - deadline: Option, -) -> Result<()> { - if bytes == 0 { - return Ok(()); - } - let Some(lease) = lease else { - return Ok(()); - }; - - let mut remaining = bytes; - while remaining > 0 { - let consume = lease.try_consume(direction, remaining); - if consume.granted > 0 { - remaining = remaining.saturating_sub(consume.granted); - continue; - } - - let wait_started_at = Instant::now(); - if deadline.is_some_and(|deadline| wait_started_at >= deadline) { - return Err(ProxyError::TrafficBudgetWaitDeadlineExceeded); - } - tokio::time::sleep(next_refill_delay()).await; - let wait_ms = wait_started_at - .elapsed() - .as_millis() - .min(u128::from(u64::MAX)) as u64; - lease.observe_wait_ms( - direction, - consume.blocked_user, - consume.blocked_cidr, - wait_ms, - ); - } - - Ok(()) -} - -async fn wait_for_traffic_budget_or_cancel( - lease: Option<&Arc>, - direction: RateDirection, - bytes: u64, - cancel: &CancellationToken, - stats: &Stats, - deadline: Option, -) -> Result<()> { - if bytes == 0 { - return Ok(()); - } - let Some(lease) = lease else { - return Ok(()); - }; - - let mut remaining = bytes; - while remaining > 0 { - let consume = lease.try_consume(direction, remaining); - if consume.granted > 0 { - remaining = remaining.saturating_sub(consume.granted); - continue; - } - - let wait_started_at = Instant::now(); - if deadline.is_some_and(|deadline| wait_started_at >= deadline) { - stats.increment_flow_wait_middle_rate_limit_cancelled_total(); - return Err(ProxyError::TrafficBudgetWaitDeadlineExceeded); - } - tokio::select! { - _ = tokio::time::sleep(next_refill_delay()) => {} - _ = cancel.cancelled() => { - stats.increment_flow_wait_middle_rate_limit_cancelled_total(); - return Err(ProxyError::TrafficBudgetWaitCancelled); - } - } - let wait_ms = wait_started_at - .elapsed() - .as_millis() - .min(u128::from(u64::MAX)) as u64; - lease.observe_wait_ms( - direction, - consume.blocked_user, - consume.blocked_cidr, - wait_ms, - ); - stats.observe_flow_wait_middle_rate_limit_ms(wait_ms); - } - - Ok(()) -} - -fn classify_me_d2c_flush_reason( - flush_immediately: bool, - batch_frames: usize, - max_frames: usize, - batch_bytes: usize, - max_bytes: usize, - max_delay_fired: bool, -) -> MeD2cFlushReason { - if flush_immediately { - return MeD2cFlushReason::AckImmediate; - } - if batch_frames >= max_frames { - return MeD2cFlushReason::BatchFrames; - } - if batch_bytes >= max_bytes { - return MeD2cFlushReason::BatchBytes; - } - if max_delay_fired { - return MeD2cFlushReason::MaxDelay; - } - MeD2cFlushReason::QueueDrain -} - -fn observe_me_d2c_flush_event( - stats: &Stats, - reason: MeD2cFlushReason, - batch_frames: usize, - batch_bytes: usize, - flush_duration_us: Option, -) { - stats.increment_me_d2c_flush_reason(reason); - if batch_frames > 0 || batch_bytes > 0 { - stats.increment_me_d2c_batches_total(); - stats.add_me_d2c_batch_frames_total(batch_frames as u64); - stats.add_me_d2c_batch_bytes_total(batch_bytes as u64); - stats.observe_me_d2c_batch_frames(batch_frames as u64); - stats.observe_me_d2c_batch_bytes(batch_bytes as u64); - } - if let Some(duration_us) = flush_duration_us { - stats.observe_me_d2c_flush_duration_us(duration_us); - } -} - -#[cfg(test)] -pub(crate) fn mark_relay_idle_candidate_for_testing( - shared: &ProxySharedState, - conn_id: u64, -) -> bool { - let registry = &shared.middle_relay.relay_idle_registry; - let mut guard = match registry.lock() { - Ok(guard) => guard, - Err(poisoned) => { - let mut guard = poisoned.into_inner(); - *guard = RelayIdleCandidateRegistry::default(); - registry.clear_poison(); - guard - } - }; - - if guard.by_conn_id.contains_key(&conn_id) { - return false; - } - - let mark_order_seq = shared - .middle_relay - .relay_idle_mark_seq - .fetch_add(1, Ordering::Relaxed); - let mark_pressure_seq = guard.pressure_event_seq; - let meta = RelayIdleCandidateMeta { - mark_order_seq, - mark_pressure_seq, - }; - guard.by_conn_id.insert(conn_id, meta); - guard.ordered.insert((mark_order_seq, conn_id)); - true -} - -#[cfg(test)] -pub(crate) fn oldest_relay_idle_candidate_for_testing(shared: &ProxySharedState) -> Option { - let registry = &shared.middle_relay.relay_idle_registry; - let guard = match registry.lock() { - Ok(guard) => guard, - Err(poisoned) => { - let mut guard = poisoned.into_inner(); - *guard = RelayIdleCandidateRegistry::default(); - registry.clear_poison(); - guard - } - }; - guard.ordered.iter().next().map(|(_, conn_id)| *conn_id) -} - -#[cfg(test)] -pub(crate) fn clear_relay_idle_candidate_for_testing(shared: &ProxySharedState, conn_id: u64) { - let registry = &shared.middle_relay.relay_idle_registry; - let mut guard = match registry.lock() { - Ok(guard) => guard, - Err(poisoned) => { - let mut guard = poisoned.into_inner(); - *guard = RelayIdleCandidateRegistry::default(); - registry.clear_poison(); - guard - } - }; - if let Some(meta) = guard.by_conn_id.remove(&conn_id) { - guard.ordered.remove(&(meta.mark_order_seq, conn_id)); - } -} - -#[cfg(test)] -pub(crate) fn clear_relay_idle_pressure_state_for_testing_in_shared(shared: &ProxySharedState) { - if let Ok(mut guard) = shared.middle_relay.relay_idle_registry.lock() { - *guard = RelayIdleCandidateRegistry::default(); - } - shared - .middle_relay - .relay_idle_mark_seq - .store(0, Ordering::Relaxed); -} - -#[cfg(test)] -pub(crate) fn note_relay_pressure_event_for_testing(shared: &ProxySharedState) { - note_relay_pressure_event_in(shared); -} - -#[cfg(test)] -pub(crate) fn relay_pressure_event_seq_for_testing(shared: &ProxySharedState) -> u64 { - relay_pressure_event_seq_in(shared) -} - -#[cfg(test)] -pub(crate) fn relay_idle_mark_seq_for_testing(shared: &ProxySharedState) -> u64 { - shared - .middle_relay - .relay_idle_mark_seq - .load(Ordering::Relaxed) -} - -#[cfg(test)] -pub(crate) fn maybe_evict_idle_candidate_on_pressure_for_testing( - shared: &ProxySharedState, - conn_id: u64, - seen_pressure_seq: &mut u64, - stats: &Stats, -) -> bool { - maybe_evict_idle_candidate_on_pressure_in(shared, conn_id, seen_pressure_seq, stats) -} - -#[cfg(test)] -pub(crate) fn set_relay_pressure_state_for_testing( - shared: &ProxySharedState, - pressure_event_seq: u64, - pressure_consumed_seq: u64, -) { - let registry = &shared.middle_relay.relay_idle_registry; - let mut guard = match registry.lock() { - Ok(guard) => guard, - Err(poisoned) => { - let mut guard = poisoned.into_inner(); - *guard = RelayIdleCandidateRegistry::default(); - registry.clear_poison(); - guard - } - }; - guard.pressure_event_seq = pressure_event_seq; - guard.pressure_consumed_seq = pressure_consumed_seq; -} - -#[cfg(test)] -pub(crate) fn should_emit_full_desync_for_testing( - shared: &ProxySharedState, - key: u64, - all_full: bool, - now: Instant, -) -> bool { - if all_full { - return true; - } - - let dedup_current = &shared.middle_relay.desync_dedup; - let dedup_previous = &shared.middle_relay.desync_dedup_previous; - - let Ok(mut state) = shared.middle_relay.desync_dedup_rotation_state.lock() else { - return false; - }; - - let rotate_now = match state.current_started_at { - Some(current_started_at) => match now.checked_duration_since(current_started_at) { - Some(elapsed) => elapsed >= DESYNC_DEDUP_WINDOW, - None => true, - }, - None => true, - }; - if rotate_now { - dedup_previous.clear(); - for entry in dedup_current.iter() { - dedup_previous.insert(*entry.key(), *entry.value()); - } - dedup_current.clear(); - state.current_started_at = Some(now); - } - - if let Some(seen_at) = dedup_current.get(&key).map(|entry| *entry.value()) { - let within_window = match now.checked_duration_since(seen_at) { - Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW, - None => true, - }; - if within_window { - return false; - } - dedup_current.insert(key, now); - return true; - } - - if let Some(seen_at) = dedup_previous.get(&key).map(|entry| *entry.value()) { - let within_window = match now.checked_duration_since(seen_at) { - Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW, - None => true, - }; - if within_window { - dedup_current.insert(key, seen_at); - return false; - } - dedup_previous.remove(&key); - } - - if dedup_current.len() >= DESYNC_DEDUP_MAX_ENTRIES { - dedup_previous.clear(); - for entry in dedup_current.iter() { - dedup_previous.insert(*entry.key(), *entry.value()); - } - dedup_current.clear(); - state.current_started_at = Some(now); - dedup_current.insert(key, now); - let Ok(mut last_emit_at) = shared.middle_relay.desync_full_cache_last_emit_at.lock() else { - return false; - }; - return match *last_emit_at { - None => { - *last_emit_at = Some(now); - true - } - Some(last) => { - let Some(elapsed) = now.checked_duration_since(last) else { - *last_emit_at = Some(now); - return true; - }; - if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL { - *last_emit_at = Some(now); - true - } else { - false - } - } - }; - } - - dedup_current.insert(key, now); - true -} - -#[cfg(test)] -pub(crate) fn clear_desync_dedup_for_testing_in_shared(shared: &ProxySharedState) { - shared.middle_relay.desync_dedup.clear(); - shared.middle_relay.desync_dedup_previous.clear(); - if let Ok(mut rotation_state) = shared.middle_relay.desync_dedup_rotation_state.lock() { - *rotation_state = DesyncDedupRotationState::default(); - } - if let Ok(mut last_emit_at) = shared.middle_relay.desync_full_cache_last_emit_at.lock() { - *last_emit_at = None; - } -} - -#[cfg(test)] -pub(crate) fn desync_dedup_len_for_testing(shared: &ProxySharedState) -> usize { - shared.middle_relay.desync_dedup.len() -} - -#[cfg(test)] -pub(crate) fn desync_dedup_insert_for_testing(shared: &ProxySharedState, key: u64, at: Instant) { - shared.middle_relay.desync_dedup.insert(key, at); -} - -#[cfg(test)] -pub(crate) fn desync_dedup_get_for_testing(shared: &ProxySharedState, key: u64) -> Option { - shared - .middle_relay - .desync_dedup - .get(&key) - .map(|entry| *entry.value()) -} - -#[cfg(test)] -pub(crate) fn desync_dedup_keys_for_testing( - shared: &ProxySharedState, -) -> std::collections::HashSet { - shared - .middle_relay - .desync_dedup - .iter() - .map(|entry| *entry.key()) - .collect() -} - -async fn enqueue_c2me_command_in( - shared: &ProxySharedState, - tx: &mpsc::Sender, - cmd: C2MeCommand, - send_timeout: Option, - stats: &Stats, -) -> std::result::Result<(), mpsc::error::SendError> { - match tx.try_send(cmd) { - Ok(()) => Ok(()), - Err(mpsc::error::TrySendError::Closed(cmd)) => Err(mpsc::error::SendError(cmd)), - Err(mpsc::error::TrySendError::Full(cmd)) => { - stats.increment_me_c2me_send_full_total(); - stats.increment_me_c2me_send_high_water_total(); - note_relay_pressure_event_in(shared); - // Cooperative yield reduces burst catch-up when the per-conn queue is near saturation. - if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS { - tokio::task::yield_now().await; - } - let reserve_result = match send_timeout { - Some(send_timeout) => match timeout(send_timeout, tx.reserve()).await { - Ok(result) => result, - Err(_) => { - stats.increment_me_c2me_send_timeout_total(); - return Err(mpsc::error::SendError(cmd)); - } - }, - None => tx.reserve().await, - }; - match reserve_result { - Ok(permit) => { - permit.send(cmd); - Ok(()) - } - Err(_) => { - stats.increment_me_c2me_send_timeout_total(); - Err(mpsc::error::SendError(cmd)) - } - } - } - } -} - -#[cfg(test)] -async fn enqueue_c2me_command( - tx: &mpsc::Sender, - cmd: C2MeCommand, - send_timeout: Option, - stats: &Stats, -) -> std::result::Result<(), mpsc::error::SendError> { - let shared = ProxySharedState::new(); - enqueue_c2me_command_in(shared.as_ref(), tx, cmd, send_timeout, stats).await -} - #[cfg(test)] async fn run_relay_test_step_timeout(context: &'static str, fut: F) -> T where @@ -1133,1650 +123,6 @@ where .unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs())) } -pub(crate) async fn handle_via_middle_proxy( - mut crypto_reader: CryptoReader, - crypto_writer: CryptoWriter, - success: HandshakeSuccess, - me_pool: Arc, - stats: Arc, - config: Arc, - buffer_pool: Arc, - local_addr: SocketAddr, - rng: Arc, - mut route_rx: watch::Receiver, - route_snapshot: RouteCutoverState, - session_id: u64, - shared: Arc, -) -> Result<()> -where - R: AsyncRead + Unpin + Send + 'static, - W: AsyncWrite + Unpin + Send + 'static, -{ - let user = success.user.clone(); - let quota_limit = config.access.user_data_quota.get(&user).copied(); - let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); - let peer = success.peer; - let traffic_lease = shared.traffic_limiter.acquire_lease(&user, peer.ip()); - let proto_tag = success.proto_tag; - let pool_generation = me_pool.current_generation(); - - debug!( - user = %user, - peer = %peer, - dc = success.dc_idx, - proto = ?proto_tag, - mode = "middle_proxy", - pool_generation, - "Routing via Middle-End" - ); - - let (conn_id, me_rx) = me_pool.registry().register().await; - let trace_id = session_id; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut forensics = RelayForensicsState { - trace_id, - conn_id, - user: user.clone(), - peer, - peer_hash: hash_ip_in(shared.as_ref(), peer.ip()), - started_at: Instant::now(), - bytes_c2me: 0, - bytes_me2c: bytes_me2c.clone(), - desync_all_full: config.general.desync_all_full, - }; - - stats.increment_user_connects(&user); - let _me_connection_lease = stats.acquire_me_connection_lease(); - - if let Some(cutover) = - affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) - { - let delay = cutover_stagger_delay(session_id, cutover.generation); - warn!( - conn_id, - target_mode = cutover.mode.as_str(), - cutover_generation = cutover.generation, - delay_ms = delay.as_millis() as u64, - "Cutover affected middle session before relay start, closing client connection" - ); - tokio::time::sleep(delay).await; - let _ = me_pool.send_close(conn_id).await; - me_pool.registry().unregister(conn_id).await; - return Err(ProxyError::RouteSwitched); - } - - // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) - let user_tag: Option> = config - .access - .user_ad_tags - .get(&user) - .and_then(|s| hex::decode(s).ok()) - .filter(|v| v.len() == 16); - let global_tag: Option> = config - .general - .ad_tag - .as_ref() - .and_then(|s| hex::decode(s).ok()) - .filter(|v| v.len() == 16); - let effective_tag = user_tag.or(global_tag); - - let proto_flags = proto_flags_for_tag(proto_tag, effective_tag.is_some()); - debug!( - trace_id = format_args!("0x{:016x}", trace_id), - user = %user, - conn_id, - peer_hash = format_args!("0x{:016x}", forensics.peer_hash), - desync_all_full = forensics.desync_all_full, - proto_flags = format_args!("0x{:08x}", proto_flags), - pool_generation, - "ME relay started" - ); - - let translated_local_addr = me_pool.translate_our_addr(local_addr); - - let frame_limit = config.general.max_client_frame; - let mut relay_idle_policy = RelayClientIdlePolicy::from_config(&config); - let mut pressure_caps_applied = false; - if shared.conntrack_pressure_active() { - relay_idle_policy.apply_pressure_caps(config.server.conntrack_control.profile); - pressure_caps_applied = true; - } - let session_started_at = forensics.started_at; - let mut relay_idle_state = RelayClientIdleState::new(session_started_at); - let last_downstream_activity_ms = Arc::new(AtomicU64::new(0)); - - let c2me_channel_capacity = config - .general - .me_c2me_channel_capacity - .max(C2ME_CHANNEL_CAPACITY_FALLBACK); - let c2me_send_timeout = match config.general.me_c2me_send_timeout_ms { - 0 => None, - timeout_ms => Some(Duration::from_millis(timeout_ms)), - }; - let c2me_byte_budget = c2me_queued_permit_budget(c2me_channel_capacity, frame_limit); - let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget)); - let (c2me_tx, mut c2me_rx) = mpsc::channel::(c2me_channel_capacity); - let me_pool_c2me = me_pool.clone(); - let mut c2me_sender = tokio::spawn(async move { - let mut sent_since_yield = 0usize; - while let Some(cmd) = c2me_rx.recv().await { - match cmd { - C2MeCommand::Data { - payload, - flags, - _permit, - } => { - me_pool_c2me - .send_proxy_req( - conn_id, - success.dc_idx, - peer, - translated_local_addr, - payload.as_ref(), - flags, - effective_tag.as_deref(), - ) - .await?; - sent_since_yield = sent_since_yield.saturating_add(1); - if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) { - sent_since_yield = 0; - tokio::task::yield_now().await; - } - } - C2MeCommand::Close => { - let _ = me_pool_c2me.send_close(conn_id).await; - return Ok(()); - } - } - } - Ok(()) - }); - - let (stop_tx, mut stop_rx) = oneshot::channel::<()>(); - let flow_cancel = CancellationToken::new(); - let mut me_rx_task = me_rx; - let stats_clone = stats.clone(); - let rng_clone = rng.clone(); - let user_clone = user.clone(); - let quota_user_stats_me_writer = quota_user_stats.clone(); - let traffic_lease_me_writer = traffic_lease.clone(); - let flow_cancel_me_writer = flow_cancel.clone(); - let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); - let bytes_me2c_clone = bytes_me2c.clone(); - let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); - let mut me_writer = tokio::spawn(async move { - let mut writer = crypto_writer; - let mut frame_buf = Vec::with_capacity(16 * 1024); - let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes; - - fn shrink_session_vec(buf: &mut Vec, threshold: usize) { - if buf.capacity() > threshold { - buf.clear(); - buf.shrink_to(threshold); - } else { - buf.clear(); - } - } - - loop { - tokio::select! { - msg = me_rx_task.recv() => { - let Some(first) = msg else { - debug!(conn_id, "ME channel closed"); - shrink_session_vec(&mut frame_buf, shrink_threshold); - return Err(ProxyError::MiddleConnectionLost); - }; - - let mut batch_frames = 0usize; - let mut batch_bytes = 0usize; - let mut flush_immediately; - let mut max_delay_fired = false; - - let first_is_downstream_activity = - matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_traffic_lease( - first, - &mut writer, - proto_tag, - rng_clone.as_ref(), - &mut frame_buf, - stats_clone.as_ref(), - &user_clone, - quota_user_stats_me_writer.as_deref(), - quota_limit, - d2c_flush_policy.quota_soft_overshoot_bytes, - traffic_lease_me_writer.as_ref(), - &flow_cancel_me_writer, - bytes_me2c_clone.as_ref(), - conn_id, - d2c_flush_policy.ack_flush_immediate, - false, - ).await? { - MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { - if first_is_downstream_activity { - last_downstream_activity_ms_clone - .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); - } - batch_frames = batch_frames.saturating_add(frames); - batch_bytes = batch_bytes.saturating_add(bytes); - flush_immediately = immediate; - } - MeWriterResponseOutcome::Close => { - let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { - Some(Instant::now()) - } else { - None - }; - let _ = flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await; - let flush_duration_us = flush_started_at.map(|started| { - started - .elapsed() - .as_micros() - .min(u128::from(u64::MAX)) as u64 - }); - observe_me_d2c_flush_event( - stats_clone.as_ref(), - MeD2cFlushReason::Close, - batch_frames, - batch_bytes, - flush_duration_us, - ); - shrink_session_vec(&mut frame_buf, shrink_threshold); - return Ok(()); - } - } - - while !flush_immediately - && batch_frames < d2c_flush_policy.max_frames - && batch_bytes < d2c_flush_policy.max_bytes - { - let Ok(next) = me_rx_task.try_recv() else { - break; - }; - - let next_is_downstream_activity = - matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_traffic_lease( - next, - &mut writer, - proto_tag, - rng_clone.as_ref(), - &mut frame_buf, - stats_clone.as_ref(), - &user_clone, - quota_user_stats_me_writer.as_deref(), - quota_limit, - d2c_flush_policy.quota_soft_overshoot_bytes, - traffic_lease_me_writer.as_ref(), - &flow_cancel_me_writer, - bytes_me2c_clone.as_ref(), - conn_id, - d2c_flush_policy.ack_flush_immediate, - true, - ).await? { - MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { - if next_is_downstream_activity { - last_downstream_activity_ms_clone - .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); - } - batch_frames = batch_frames.saturating_add(frames); - batch_bytes = batch_bytes.saturating_add(bytes); - flush_immediately |= immediate; - } - MeWriterResponseOutcome::Close => { - let flush_started_at = - if stats_clone.telemetry_policy().me_level.allows_debug() { - Some(Instant::now()) - } else { - None - }; - let _ = - flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await; - let flush_duration_us = flush_started_at.map(|started| { - started - .elapsed() - .as_micros() - .min(u128::from(u64::MAX)) - as u64 - }); - observe_me_d2c_flush_event( - stats_clone.as_ref(), - MeD2cFlushReason::Close, - batch_frames, - batch_bytes, - flush_duration_us, - ); - shrink_session_vec(&mut frame_buf, shrink_threshold); - return Ok(()); - } - } - } - - if !flush_immediately - && !d2c_flush_policy.max_delay.is_zero() - && batch_frames < d2c_flush_policy.max_frames - && batch_bytes < d2c_flush_policy.max_bytes - { - stats_clone.increment_me_d2c_batch_timeout_armed_total(); - match tokio::time::timeout(d2c_flush_policy.max_delay, me_rx_task.recv()).await { - Ok(Some(next)) => { - let next_is_downstream_activity = - matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_traffic_lease( - next, - &mut writer, - proto_tag, - rng_clone.as_ref(), - &mut frame_buf, - stats_clone.as_ref(), - &user_clone, - quota_user_stats_me_writer.as_deref(), - quota_limit, - d2c_flush_policy.quota_soft_overshoot_bytes, - traffic_lease_me_writer.as_ref(), - &flow_cancel_me_writer, - bytes_me2c_clone.as_ref(), - conn_id, - d2c_flush_policy.ack_flush_immediate, - true, - ).await? { - MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { - if next_is_downstream_activity { - last_downstream_activity_ms_clone - .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); - } - batch_frames = batch_frames.saturating_add(frames); - batch_bytes = batch_bytes.saturating_add(bytes); - flush_immediately |= immediate; - } - MeWriterResponseOutcome::Close => { - let flush_started_at = if stats_clone - .telemetry_policy() - .me_level - .allows_debug() - { - Some(Instant::now()) - } else { - None - }; - let _ = flush_client_or_cancel( - &mut writer, - &flow_cancel_me_writer, - ) - .await; - let flush_duration_us = flush_started_at.map(|started| { - started - .elapsed() - .as_micros() - .min(u128::from(u64::MAX)) - as u64 - }); - observe_me_d2c_flush_event( - stats_clone.as_ref(), - MeD2cFlushReason::Close, - batch_frames, - batch_bytes, - flush_duration_us, - ); - shrink_session_vec(&mut frame_buf, shrink_threshold); - return Ok(()); - } - } - - while !flush_immediately - && batch_frames < d2c_flush_policy.max_frames - && batch_bytes < d2c_flush_policy.max_bytes - { - let Ok(extra) = me_rx_task.try_recv() else { - break; - }; - - let extra_is_downstream_activity = - matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_traffic_lease( - extra, - &mut writer, - proto_tag, - rng_clone.as_ref(), - &mut frame_buf, - stats_clone.as_ref(), - &user_clone, - quota_user_stats_me_writer.as_deref(), - quota_limit, - d2c_flush_policy.quota_soft_overshoot_bytes, - traffic_lease_me_writer.as_ref(), - &flow_cancel_me_writer, - bytes_me2c_clone.as_ref(), - conn_id, - d2c_flush_policy.ack_flush_immediate, - true, - ).await? { - MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { - if extra_is_downstream_activity { - last_downstream_activity_ms_clone - .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); - } - batch_frames = batch_frames.saturating_add(frames); - batch_bytes = batch_bytes.saturating_add(bytes); - flush_immediately |= immediate; - } - MeWriterResponseOutcome::Close => { - let flush_started_at = if stats_clone - .telemetry_policy() - .me_level - .allows_debug() - { - Some(Instant::now()) - } else { - None - }; - let _ = flush_client_or_cancel( - &mut writer, - &flow_cancel_me_writer, - ) - .await; - let flush_duration_us = flush_started_at.map(|started| { - started - .elapsed() - .as_micros() - .min(u128::from(u64::MAX)) - as u64 - }); - observe_me_d2c_flush_event( - stats_clone.as_ref(), - MeD2cFlushReason::Close, - batch_frames, - batch_bytes, - flush_duration_us, - ); - shrink_session_vec(&mut frame_buf, shrink_threshold); - return Ok(()); - } - } - } - } - Ok(None) => { - debug!(conn_id, "ME channel closed"); - shrink_session_vec(&mut frame_buf, shrink_threshold); - return Err(ProxyError::MiddleConnectionLost); - } - Err(_) => { - max_delay_fired = true; - stats_clone.increment_me_d2c_batch_timeout_fired_total(); - } - } - } - - let flush_reason = classify_me_d2c_flush_reason( - flush_immediately, - batch_frames, - d2c_flush_policy.max_frames, - batch_bytes, - d2c_flush_policy.max_bytes, - max_delay_fired, - ); - let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { - Some(Instant::now()) - } else { - None - }; - flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?; - let flush_duration_us = flush_started_at.map(|started| { - started - .elapsed() - .as_micros() - .min(u128::from(u64::MAX)) as u64 - }); - observe_me_d2c_flush_event( - stats_clone.as_ref(), - flush_reason, - batch_frames, - batch_bytes, - flush_duration_us, - ); - let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes; - let shrink_trigger = shrink_threshold - .saturating_mul(ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR); - if frame_buf.capacity() > shrink_trigger { - let cap_before = frame_buf.capacity(); - frame_buf.shrink_to(shrink_threshold); - let cap_after = frame_buf.capacity(); - let bytes_freed = cap_before.saturating_sub(cap_after) as u64; - stats_clone.observe_me_d2c_frame_buf_shrink(bytes_freed); - } - } - _ = &mut stop_rx => { - debug!(conn_id, "ME writer stop signal"); - shrink_session_vec(&mut frame_buf, shrink_threshold); - return Ok(()); - } - } - } - }); - - let mut main_result: Result<()> = Ok(()); - let mut client_closed = false; - let mut frame_counter: u64 = 0; - let mut route_watch_open = true; - let mut seen_pressure_seq = relay_pressure_event_seq_in(shared.as_ref()); - loop { - if shared.conntrack_pressure_active() && !pressure_caps_applied { - relay_idle_policy.apply_pressure_caps(config.server.conntrack_control.profile); - pressure_caps_applied = true; - } - - if relay_idle_policy.enabled - && maybe_evict_idle_candidate_on_pressure_in( - shared.as_ref(), - conn_id, - &mut seen_pressure_seq, - stats.as_ref(), - ) - { - info!( - conn_id, - trace_id = format_args!("0x{:016x}", trace_id), - user = %user, - "Middle-relay pressure eviction for idle-candidate session" - ); - let _ = enqueue_c2me_command_in( - shared.as_ref(), - &c2me_tx, - C2MeCommand::Close, - c2me_send_timeout, - stats.as_ref(), - ) - .await; - main_result = Err(ProxyError::Proxy( - "middle-relay session evicted under pressure (idle-candidate)".to_string(), - )); - break; - } - - if let Some(cutover) = - affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) - { - let delay = cutover_stagger_delay(session_id, cutover.generation); - warn!( - conn_id, - target_mode = cutover.mode.as_str(), - cutover_generation = cutover.generation, - delay_ms = delay.as_millis() as u64, - "Cutover affected middle session, closing client connection" - ); - tokio::time::sleep(delay).await; - let _ = enqueue_c2me_command_in( - shared.as_ref(), - &c2me_tx, - C2MeCommand::Close, - c2me_send_timeout, - stats.as_ref(), - ) - .await; - main_result = Err(ProxyError::RouteSwitched); - break; - } - - tokio::select! { - changed = route_rx.changed(), if route_watch_open => { - if changed.is_err() { - route_watch_open = false; - } - } - payload_result = read_client_payload_with_idle_policy_in( - &mut crypto_reader, - proto_tag, - frame_limit, - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - shared.as_ref(), - &relay_idle_policy, - &mut relay_idle_state, - last_downstream_activity_ms.as_ref(), - session_started_at, - ) => { - match payload_result { - Ok(Some((payload, quickack))) => { - trace!(conn_id, bytes = payload.len(), "C->ME frame"); - wait_for_traffic_budget( - traffic_lease.as_ref(), - RateDirection::Up, - payload.len() as u64, - None, - ) - .await?; - forensics.bytes_c2me = forensics - .bytes_c2me - .saturating_add(payload.len() as u64); - if let (Some(limit), Some(user_stats)) = - (quota_limit, quota_user_stats.as_deref()) - { - match reserve_user_quota_with_yield( - user_stats, - payload.len() as u64, - limit, - stats.as_ref(), - &flow_cancel, - None, - ) - .await - { - Ok(_) => {} - Err(MiddleQuotaReserveError::LimitExceeded) => { - main_result = Err(ProxyError::DataQuotaExceeded { - user: user.clone(), - }); - break; - } - Err(MiddleQuotaReserveError::Contended) => { - main_result = Err(ProxyError::Proxy( - "ME C->ME quota reservation contended".into(), - )); - break; - } - Err(MiddleQuotaReserveError::Cancelled) => { - main_result = Err(ProxyError::Proxy( - "ME C->ME quota reservation cancelled".into(), - )); - break; - } - Err(MiddleQuotaReserveError::DeadlineExceeded) => { - main_result = Err(ProxyError::Proxy( - "ME C->ME quota reservation deadline exceeded".into(), - )); - break; - } - } - stats.add_user_octets_from_handle(user_stats, payload.len() as u64); - } else { - stats.add_user_octets_from(&user, payload.len() as u64); - } - let mut flags = proto_flags; - if quickack { - flags |= RPC_FLAG_QUICKACK; - } - if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) { - flags |= RPC_FLAG_NOT_ENCRYPTED; - } - let payload_permit = match acquire_c2me_payload_permit( - &c2me_byte_semaphore, - payload.len(), - c2me_send_timeout, - stats.as_ref(), - ) - .await - { - Ok(permit) => permit, - Err(e) => { - main_result = Err(e); - break; - } - }; - // Keep client read loop lightweight: route heavy ME send path via a dedicated task. - if enqueue_c2me_command_in( - shared.as_ref(), - &c2me_tx, - C2MeCommand::Data { - payload, - flags, - _permit: payload_permit, - }, - c2me_send_timeout, - stats.as_ref(), - ) - .await - .is_err() - { - main_result = Err(ProxyError::Proxy("ME sender channel closed".into())); - break; - } - } - Ok(None) => { - debug!(conn_id, "Client EOF"); - client_closed = true; - let _ = enqueue_c2me_command_in( - shared.as_ref(), - &c2me_tx, - C2MeCommand::Close, - c2me_send_timeout, - stats.as_ref(), - ) - .await; - break; - } - Err(e) => { - main_result = Err(e); - break; - } - } - } - } - } - - drop(c2me_tx); - let c2me_result = match timeout(ME_CHILD_JOIN_TIMEOUT, &mut c2me_sender).await { - Ok(joined) => { - joined.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME sender join error: {e}")))) - } - Err(_) => { - stats.increment_me_child_join_timeout_total(); - stats.increment_me_child_abort_total(); - c2me_sender.abort(); - Err(ProxyError::Proxy("ME sender join timeout".into())) - } - }; - - flow_cancel.cancel(); - let _ = stop_tx.send(()); - let mut writer_result = match timeout(ME_CHILD_JOIN_TIMEOUT, &mut me_writer).await { - Ok(joined) => { - joined.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME writer join error: {e}")))) - } - Err(_) => { - stats.increment_me_child_join_timeout_total(); - stats.increment_me_child_abort_total(); - me_writer.abort(); - Err(ProxyError::Proxy("ME writer join timeout".into())) - } - }; - - // When client closes, but ME channel stopped as unregistered - it isnt error - if client_closed && matches!(writer_result, Err(ProxyError::MiddleConnectionLost)) { - writer_result = Ok(()); - } - - let result = match (main_result, c2me_result, writer_result) { - (Ok(()), Ok(()), Ok(())) => Ok(()), - (Err(e), _, _) => Err(e), - (_, Err(e), _) => Err(e), - (_, _, Err(e)) => Err(e), - }; - - debug!( - user = %user, - conn_id, - trace_id = format_args!("0x{:016x}", trace_id), - duration_ms = forensics.started_at.elapsed().as_millis() as u64, - bytes_c2me = forensics.bytes_c2me, - bytes_me2c = forensics.bytes_me2c.load(Ordering::Relaxed), - frames_ok = frame_counter, - "ME relay cleanup" - ); - - let close_reason = classify_conntrack_close_reason(&result); - let publish_result = shared.publish_conntrack_close_event(ConntrackCloseEvent { - src: peer, - dst: local_addr, - reason: close_reason, - }); - if !matches!( - publish_result, - ConntrackClosePublishResult::Sent | ConntrackClosePublishResult::Disabled - ) { - stats.increment_conntrack_close_event_drop_total(); - } - - clear_relay_idle_candidate_in(shared.as_ref(), conn_id); - me_pool.registry().unregister(conn_id).await; - buffer_pool.trim_to(buffer_pool.max_buffers().min(64)); - let pool_snapshot = buffer_pool.stats(); - stats.set_buffer_pool_gauges( - pool_snapshot.pooled, - pool_snapshot.allocated, - pool_snapshot.allocated.saturating_sub(pool_snapshot.pooled), - ); - result -} - -fn classify_conntrack_close_reason(result: &Result<()>) -> ConntrackCloseReason { - match result { - Ok(()) => ConntrackCloseReason::NormalEof, - Err(ProxyError::Io(error)) if matches!(error.kind(), std::io::ErrorKind::TimedOut) => { - ConntrackCloseReason::Timeout - } - Err(ProxyError::Io(error)) - if matches!( - error.kind(), - std::io::ErrorKind::ConnectionReset - | std::io::ErrorKind::ConnectionAborted - | std::io::ErrorKind::BrokenPipe - | std::io::ErrorKind::NotConnected - | std::io::ErrorKind::UnexpectedEof - ) => - { - ConntrackCloseReason::Reset - } - Err(ProxyError::Proxy(message)) - if message.contains("pressure") || message.contains("evicted") => - { - ConntrackCloseReason::Pressure - } - Err(_) => ConntrackCloseReason::Other, - } -} - -async fn read_client_payload_with_idle_policy_in( - client_reader: &mut CryptoReader, - proto_tag: ProtoTag, - max_frame: usize, - buffer_pool: &Arc, - forensics: &RelayForensicsState, - frame_counter: &mut u64, - stats: &Stats, - shared: &ProxySharedState, - idle_policy: &RelayClientIdlePolicy, - idle_state: &mut RelayClientIdleState, - last_downstream_activity_ms: &AtomicU64, - session_started_at: Instant, -) -> Result> -where - R: AsyncRead + Unpin + Send + 'static, -{ - const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4; - - async fn read_exact_with_policy( - client_reader: &mut CryptoReader, - buf: &mut [u8], - idle_policy: &RelayClientIdlePolicy, - idle_state: &mut RelayClientIdleState, - last_downstream_activity_ms: &AtomicU64, - session_started_at: Instant, - forensics: &RelayForensicsState, - stats: &Stats, - shared: &ProxySharedState, - read_label: &'static str, - ) -> Result<()> - where - R: AsyncRead + Unpin + Send + 'static, - { - fn hard_deadline( - idle_policy: &RelayClientIdlePolicy, - idle_state: &RelayClientIdleState, - session_started_at: Instant, - last_downstream_activity_ms: u64, - ) -> Instant { - let mut deadline = idle_state.last_client_frame_at + idle_policy.hard_idle; - if idle_policy.grace_after_downstream_activity.is_zero() { - return deadline; - } - - let downstream_at = - session_started_at + Duration::from_millis(last_downstream_activity_ms); - if downstream_at > idle_state.last_client_frame_at { - let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity; - if grace_deadline > deadline { - deadline = grace_deadline; - } - } - deadline - } - - let mut filled = 0usize; - while filled < buf.len() { - let timeout_window = if idle_policy.enabled { - let now = Instant::now(); - let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); - let hard_deadline = - hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); - if !idle_state.soft_idle_marked - && now.saturating_duration_since(idle_state.last_client_frame_at) - >= idle_policy.soft_idle - { - idle_state.soft_idle_marked = true; - if mark_relay_idle_candidate_in(shared, forensics.conn_id) { - stats.increment_relay_idle_soft_mark_total(); - } - info!( - trace_id = format_args!("0x{:016x}", forensics.trace_id), - conn_id = forensics.conn_id, - user = %forensics.user, - read_label, - soft_idle_secs = idle_policy.soft_idle.as_secs(), - hard_idle_secs = idle_policy.hard_idle.as_secs(), - grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), - "Middle-relay soft idle mark" - ); - } - - let soft_deadline = idle_state.last_client_frame_at + idle_policy.soft_idle; - let next_deadline = if idle_state.soft_idle_marked { - hard_deadline - } else { - soft_deadline.min(hard_deadline) - }; - let mut remaining = next_deadline.saturating_duration_since(now); - if remaining.is_zero() { - remaining = Duration::from_millis(1); - } - remaining.min(RELAY_IDLE_IO_POLL_MAX) - } else { - idle_policy.legacy_frame_read_timeout - }; - - let read_result = timeout(timeout_window, client_reader.read(&mut buf[filled..])).await; - match read_result { - Ok(Ok(0)) => { - return Err(ProxyError::Io(std::io::Error::from( - std::io::ErrorKind::UnexpectedEof, - ))); - } - Ok(Ok(n)) => { - filled = filled.saturating_add(n); - } - Ok(Err(e)) => return Err(ProxyError::Io(e)), - Err(_) if !idle_policy.enabled => { - return Err(ProxyError::Io(std::io::Error::new( - std::io::ErrorKind::TimedOut, - format!( - "middle-relay client frame read timeout while reading {read_label}" - ), - ))); - } - Err(_) => { - let now = Instant::now(); - let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); - let hard_deadline = - hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); - if now >= hard_deadline { - clear_relay_idle_candidate_in(shared, forensics.conn_id); - stats.increment_relay_idle_hard_close_total(); - let client_idle_secs = now - .saturating_duration_since(idle_state.last_client_frame_at) - .as_secs(); - let downstream_idle_secs = now - .saturating_duration_since( - session_started_at + Duration::from_millis(downstream_ms), - ) - .as_secs(); - warn!( - trace_id = format_args!("0x{:016x}", forensics.trace_id), - conn_id = forensics.conn_id, - user = %forensics.user, - read_label, - client_idle_secs, - downstream_idle_secs, - soft_idle_secs = idle_policy.soft_idle.as_secs(), - hard_idle_secs = idle_policy.hard_idle.as_secs(), - grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), - "Middle-relay hard idle close" - ); - return Err(ProxyError::Io(std::io::Error::new( - std::io::ErrorKind::TimedOut, - format!( - "middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}", - idle_policy.soft_idle.as_secs(), - idle_policy.hard_idle.as_secs(), - idle_policy.grace_after_downstream_activity.as_secs(), - ), - ))); - } - } - } - } - - Ok(()) - } - - let mut consecutive_zero_len_frames = 0u32; - loop { - let (len, quickack, raw_len_bytes) = match proto_tag { - ProtoTag::Abridged => { - let mut first = [0u8; 1]; - match read_exact_with_policy( - client_reader, - &mut first, - idle_policy, - idle_state, - last_downstream_activity_ms, - session_started_at, - forensics, - stats, - shared, - "abridged.first_len_byte", - ) - .await - { - Ok(()) => {} - Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - return Ok(None); - } - Err(e) => return Err(e), - } - - let quickack = (first[0] & 0x80) != 0; - let len_words = if (first[0] & 0x7f) == 0x7f { - let mut ext = [0u8; 3]; - read_exact_with_policy( - client_reader, - &mut ext, - idle_policy, - idle_state, - last_downstream_activity_ms, - session_started_at, - forensics, - stats, - shared, - "abridged.extended_len", - ) - .await?; - u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize - } else { - (first[0] & 0x7f) as usize - }; - - let len = len_words - .checked_mul(4) - .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?; - (len, quickack, None) - } - ProtoTag::Intermediate | ProtoTag::Secure => { - let mut len_buf = [0u8; 4]; - match read_exact_with_policy( - client_reader, - &mut len_buf, - idle_policy, - idle_state, - last_downstream_activity_ms, - session_started_at, - forensics, - stats, - shared, - "len_prefix", - ) - .await - { - Ok(()) => {} - Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - return Ok(None); - } - Err(e) => return Err(e), - } - let quickack = (len_buf[3] & 0x80) != 0; - ( - (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, - quickack, - Some(len_buf), - ) - } - }; - - if len == 0 { - idle_state.on_client_tiny_frame(Instant::now()); - idle_state.tiny_frame_debt = idle_state - .tiny_frame_debt - .saturating_add(TINY_FRAME_DEBT_PER_TINY); - if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT { - stats.increment_relay_protocol_desync_close_total(); - return Err(ProxyError::Proxy(format!( - "Tiny frame overhead limit exceeded: debt={}, conn_id={}", - idle_state.tiny_frame_debt, forensics.conn_id - ))); - } - - if !idle_policy.enabled { - consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1); - if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES { - stats.increment_relay_protocol_desync_close_total(); - return Err(ProxyError::Proxy( - "Excessive zero-length abridged frames".to_string(), - )); - } - } - continue; - } - if len < 4 && proto_tag != ProtoTag::Abridged { - warn!( - trace_id = format_args!("0x{:016x}", forensics.trace_id), - conn_id = forensics.conn_id, - user = %forensics.user, - len, - proto = ?proto_tag, - "Frame too small — corrupt or probe" - ); - stats.increment_relay_protocol_desync_close_total(); - return Err(ProxyError::Proxy(format!("Frame too small: {len}"))); - } - - if len > max_frame { - return Err(report_desync_frame_too_large_in( - shared, - forensics, - proto_tag, - *frame_counter, - max_frame, - len, - raw_len_bytes, - stats, - )); - } - - let secure_payload_len = if proto_tag == ProtoTag::Secure { - match secure_payload_len_from_wire_len(len) { - Some(payload_len) => payload_len, - None => { - stats.increment_secure_padding_invalid(); - stats.increment_relay_protocol_desync_close_total(); - return Err(ProxyError::Proxy(format!( - "Invalid secure frame length: {len}" - ))); - } - } - } else { - len - }; - - let mut payload = buffer_pool.get(); - payload.clear(); - let current_cap = payload.capacity(); - if current_cap < len { - payload.reserve(len - current_cap); - } - payload.resize(len, 0); - read_exact_with_policy( - client_reader, - &mut payload[..len], - idle_policy, - idle_state, - last_downstream_activity_ms, - session_started_at, - forensics, - stats, - shared, - "payload", - ) - .await?; - - // Secure Intermediate: strip validated trailing padding bytes. - if proto_tag == ProtoTag::Secure { - payload.truncate(secure_payload_len); - } - *frame_counter += 1; - idle_state.on_client_frame(Instant::now()); - idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1); - clear_relay_idle_candidate_in(shared, forensics.conn_id); - return Ok(Some((payload, quickack))); - } -} - -#[cfg(test)] -async fn read_client_payload_with_idle_policy( - client_reader: &mut CryptoReader, - proto_tag: ProtoTag, - max_frame: usize, - buffer_pool: &Arc, - forensics: &RelayForensicsState, - frame_counter: &mut u64, - stats: &Stats, - idle_policy: &RelayClientIdlePolicy, - idle_state: &mut RelayClientIdleState, - last_downstream_activity_ms: &AtomicU64, - session_started_at: Instant, -) -> Result> -where - R: AsyncRead + Unpin + Send + 'static, -{ - let shared = ProxySharedState::new(); - read_client_payload_with_idle_policy_in( - client_reader, - proto_tag, - max_frame, - buffer_pool, - forensics, - frame_counter, - stats, - shared.as_ref(), - idle_policy, - idle_state, - last_downstream_activity_ms, - session_started_at, - ) - .await -} - -#[cfg(test)] -async fn read_client_payload_legacy( - client_reader: &mut CryptoReader, - proto_tag: ProtoTag, - max_frame: usize, - frame_read_timeout: Duration, - buffer_pool: &Arc, - forensics: &RelayForensicsState, - frame_counter: &mut u64, - stats: &Stats, -) -> Result> -where - R: AsyncRead + Unpin + Send + 'static, -{ - let now = Instant::now(); - let shared = ProxySharedState::new(); - let mut idle_state = RelayClientIdleState::new(now); - let last_downstream_activity_ms = AtomicU64::new(0); - let idle_policy = RelayClientIdlePolicy::disabled(frame_read_timeout); - read_client_payload_with_idle_policy_in( - client_reader, - proto_tag, - max_frame, - buffer_pool, - forensics, - frame_counter, - stats, - shared.as_ref(), - &idle_policy, - &mut idle_state, - &last_downstream_activity_ms, - now, - ) - .await -} - -#[cfg(test)] -async fn read_client_payload( - client_reader: &mut CryptoReader, - proto_tag: ProtoTag, - max_frame: usize, - frame_read_timeout: Duration, - buffer_pool: &Arc, - forensics: &RelayForensicsState, - frame_counter: &mut u64, - stats: &Stats, -) -> Result> -where - R: AsyncRead + Unpin + Send + 'static, -{ - read_client_payload_legacy( - client_reader, - proto_tag, - max_frame, - frame_read_timeout, - buffer_pool, - forensics, - frame_counter, - stats, - ) - .await -} - -enum MeWriterResponseOutcome { - Continue { - frames: usize, - bytes: usize, - flush_immediately: bool, - }, - Close, -} - -#[cfg(test)] -async fn process_me_writer_response( - response: MeResponse, - client_writer: &mut CryptoWriter, - proto_tag: ProtoTag, - rng: &SecureRandom, - frame_buf: &mut Vec, - stats: &Stats, - user: &str, - quota_user_stats: Option<&UserStats>, - quota_limit: Option, - quota_soft_overshoot_bytes: u64, - bytes_me2c: &AtomicU64, - conn_id: u64, - ack_flush_immediate: bool, - batched: bool, -) -> Result -where - W: AsyncWrite + Unpin + Send + 'static, -{ - process_me_writer_response_with_traffic_lease( - response, - client_writer, - proto_tag, - rng, - frame_buf, - stats, - user, - quota_user_stats, - quota_limit, - quota_soft_overshoot_bytes, - None, - &CancellationToken::new(), - bytes_me2c, - conn_id, - ack_flush_immediate, - batched, - ) - .await -} - -async fn process_me_writer_response_with_traffic_lease( - response: MeResponse, - client_writer: &mut CryptoWriter, - proto_tag: ProtoTag, - rng: &SecureRandom, - frame_buf: &mut Vec, - stats: &Stats, - user: &str, - quota_user_stats: Option<&UserStats>, - quota_limit: Option, - quota_soft_overshoot_bytes: u64, - traffic_lease: Option<&Arc>, - cancel: &CancellationToken, - bytes_me2c: &AtomicU64, - conn_id: u64, - ack_flush_immediate: bool, - batched: bool, -) -> Result -where - W: AsyncWrite + Unpin + Send + 'static, -{ - match response { - MeResponse::Data { flags, data, .. } => { - if batched { - trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)"); - } else { - trace!(conn_id, bytes = data.len(), flags, "ME->C data"); - } - let data_len = data.len() as u64; - if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) { - let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); - match reserve_user_quota_with_yield( - user_stats, data_len, soft_limit, stats, cancel, None, - ) - .await - { - Ok(_) => {} - Err(MiddleQuotaReserveError::LimitExceeded) => { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); - } - Err(MiddleQuotaReserveError::Contended) => { - return Err(ProxyError::Proxy( - "ME D->C quota reservation contended".into(), - )); - } - Err(MiddleQuotaReserveError::Cancelled) => { - return Err(ProxyError::Proxy( - "ME D->C quota reservation cancelled".into(), - )); - } - Err(MiddleQuotaReserveError::DeadlineExceeded) => { - return Err(ProxyError::Proxy( - "ME D->C quota reservation deadline exceeded".into(), - )); - } - } - } - wait_for_traffic_budget_or_cancel( - traffic_lease, - RateDirection::Down, - data_len, - cancel, - stats, - None, - ) - .await?; - - let write_mode = match write_client_payload( - client_writer, - proto_tag, - flags, - &data, - rng, - frame_buf, - cancel, - ) - .await - { - Ok(mode) => mode, - Err(err) => { - if quota_limit.is_some() { - stats.add_quota_write_fail_bytes_total(data_len); - stats.increment_quota_write_fail_events_total(); - } - return Err(err); - } - }; - - bytes_me2c.fetch_add(data_len, Ordering::Relaxed); - if let Some(user_stats) = quota_user_stats { - stats.add_user_octets_to_handle(user_stats, data_len); - } else { - stats.add_user_octets_to(user, data_len); - } - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data_len); - stats.increment_me_d2c_write_mode(write_mode); - - Ok(MeWriterResponseOutcome::Continue { - frames: 1, - bytes: data.len(), - flush_immediately: false, - }) - } - MeResponse::Ack(confirm) => { - if batched { - trace!(conn_id, confirm, "ME->C quickack (batched)"); - } else { - trace!(conn_id, confirm, "ME->C quickack"); - } - wait_for_traffic_budget_or_cancel( - traffic_lease, - RateDirection::Down, - 4, - cancel, - stats, - None, - ) - .await?; - write_client_ack(client_writer, proto_tag, confirm, cancel).await?; - stats.increment_me_d2c_ack_frames_total(); - - Ok(MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: ack_flush_immediate, - }) - } - MeResponse::Close => { - if batched { - debug!(conn_id, "ME sent close (batched)"); - } else { - debug!(conn_id, "ME sent close"); - } - Ok(MeWriterResponseOutcome::Close) - } - } -} - -fn compute_intermediate_secure_wire_len( - data_len: usize, - padding_len: usize, - quickack: bool, -) -> Result<(u32, usize)> { - let wire_len = data_len - .checked_add(padding_len) - .ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?; - if wire_len > 0x7fff_ffffusize { - return Err(ProxyError::Proxy(format!( - "Intermediate/Secure frame too large: {wire_len}" - ))); - } - - let total = 4usize - .checked_add(wire_len) - .ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?; - let mut len_val = u32::try_from(wire_len) - .map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?; - if quickack { - len_val |= 0x8000_0000; - } - Ok((len_val, total)) -} - -async fn write_client_payload( - client_writer: &mut CryptoWriter, - proto_tag: ProtoTag, - flags: u32, - data: &[u8], - rng: &SecureRandom, - frame_buf: &mut Vec, - cancel: &CancellationToken, -) -> Result -where - W: AsyncWrite + Unpin + Send + 'static, -{ - let quickack = (flags & RPC_FLAG_QUICKACK) != 0; - - let write_mode = match proto_tag { - ProtoTag::Abridged => { - if !data.len().is_multiple_of(4) { - return Err(ProxyError::Proxy(format!( - "Abridged payload must be 4-byte aligned, got {}", - data.len() - ))); - } - - let len_words = data.len() / 4; - if len_words < 0x7f { - let mut first = len_words as u8; - if quickack { - first |= 0x80; - } - let wire_len = 1usize.saturating_add(data.len()); - if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { - frame_buf.clear(); - frame_buf.reserve(wire_len); - frame_buf.push(first); - frame_buf.extend_from_slice(data); - write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; - MeD2cWriteMode::Coalesced - } else { - let header = [first]; - write_all_client_or_cancel(client_writer, &header, cancel).await?; - write_all_client_or_cancel(client_writer, data, cancel).await?; - MeD2cWriteMode::Split - } - } else if len_words < (1 << 24) { - let mut first = 0x7fu8; - if quickack { - first |= 0x80; - } - let lw = (len_words as u32).to_le_bytes(); - let wire_len = 4usize.saturating_add(data.len()); - if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { - frame_buf.clear(); - frame_buf.reserve(wire_len); - frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); - frame_buf.extend_from_slice(data); - write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; - MeD2cWriteMode::Coalesced - } else { - let header = [first, lw[0], lw[1], lw[2]]; - write_all_client_or_cancel(client_writer, &header, cancel).await?; - write_all_client_or_cancel(client_writer, data, cancel).await?; - MeD2cWriteMode::Split - } - } else { - return Err(ProxyError::Proxy(format!( - "Abridged frame too large: {}", - data.len() - ))); - } - } - ProtoTag::Intermediate | ProtoTag::Secure => { - let padding_len = if proto_tag == ProtoTag::Secure { - if !is_valid_secure_payload_len(data.len()) { - return Err(ProxyError::Proxy(format!( - "Secure payload must be 4-byte aligned, got {}", - data.len() - ))); - } - secure_padding_len(data.len(), rng) - } else { - 0 - }; - - let (len_val, total) = - compute_intermediate_secure_wire_len(data.len(), padding_len, quickack)?; - if total <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { - frame_buf.clear(); - frame_buf.reserve(total); - frame_buf.extend_from_slice(&len_val.to_le_bytes()); - frame_buf.extend_from_slice(data); - if padding_len > 0 { - let start = frame_buf.len(); - frame_buf.resize(start + padding_len, 0); - rng.fill(&mut frame_buf[start..]); - } - write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; - MeD2cWriteMode::Coalesced - } else { - let header = len_val.to_le_bytes(); - write_all_client_or_cancel(client_writer, &header, cancel).await?; - write_all_client_or_cancel(client_writer, data, cancel).await?; - if padding_len > 0 { - frame_buf.clear(); - if frame_buf.capacity() < padding_len { - frame_buf.reserve(padding_len); - } - frame_buf.resize(padding_len, 0); - rng.fill(frame_buf.as_mut_slice()); - write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; - } - MeD2cWriteMode::Split - } - } - }; - - Ok(write_mode) -} - -async fn write_client_ack( - client_writer: &mut CryptoWriter, - proto_tag: ProtoTag, - confirm: u32, - cancel: &CancellationToken, -) -> Result<()> -where - W: AsyncWrite + Unpin + Send + 'static, -{ - let bytes = if proto_tag == ProtoTag::Abridged { - confirm.to_be_bytes() - } else { - confirm.to_le_bytes() - }; - write_all_client_or_cancel(client_writer, &bytes, cancel).await -} - -async fn write_all_client_or_cancel( - client_writer: &mut CryptoWriter, - bytes: &[u8], - cancel: &CancellationToken, -) -> Result<()> -where - W: AsyncWrite + Unpin + Send + 'static, -{ - tokio::select! { - result = client_writer.write_all(bytes) => result.map_err(ProxyError::Io), - _ = cancel.cancelled() => Err(ProxyError::MiddleClientWriterCancelled), - } -} - -async fn flush_client_or_cancel( - client_writer: &mut CryptoWriter, - cancel: &CancellationToken, -) -> Result<()> -where - W: AsyncWrite + Unpin + Send + 'static, -{ - tokio::select! { - result = client_writer.flush() => result.map_err(ProxyError::Io), - _ = cancel.cancelled() => Err(ProxyError::MiddleClientWriterCancelled), - } -} - #[cfg(test)] #[path = "tests/middle_relay_idle_policy_security_tests.rs"] mod idle_policy_security_tests; diff --git a/src/proxy/middle_relay/c2me.rs b/src/proxy/middle_relay/c2me.rs new file mode 100644 index 0000000..2b668e8 --- /dev/null +++ b/src/proxy/middle_relay/c2me.rs @@ -0,0 +1,104 @@ +use super::*; + +pub(in crate::proxy::middle_relay) enum C2MeCommand { + Data { + payload: PooledBuffer, + flags: u32, + _permit: OwnedSemaphorePermit, + }, + Close, +} + +pub(super) fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool { + has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET +} + +pub(super) fn c2me_payload_permits(payload_len: usize) -> u32 { + payload_len + .max(1) + .div_ceil(C2ME_QUEUED_BYTE_PERMIT_UNIT) + .min(u32::MAX as usize) as u32 +} + +pub(super) fn c2me_queued_permit_budget(channel_capacity: usize, frame_limit: usize) -> usize { + channel_capacity + .saturating_mul(C2ME_QUEUED_PERMITS_PER_SLOT) + .max(c2me_payload_permits(frame_limit) as usize) + .max(1) +} + +pub(super) async fn acquire_c2me_payload_permit( + semaphore: &Arc, + payload_len: usize, + send_timeout: Option, + stats: &Stats, +) -> Result { + let permits = c2me_payload_permits(payload_len); + let acquire = semaphore.clone().acquire_many_owned(permits); + match send_timeout { + Some(send_timeout) => match timeout(send_timeout, acquire).await { + Ok(Ok(permit)) => Ok(permit), + Ok(Err(_)) => Err(ProxyError::Proxy("ME sender byte budget closed".into())), + Err(_) => { + stats.increment_me_c2me_send_timeout_total(); + Err(ProxyError::Proxy("ME sender byte budget timeout".into())) + } + }, + None => acquire + .await + .map_err(|_| ProxyError::Proxy("ME sender byte budget closed".into())), + } +} + +pub(super) async fn enqueue_c2me_command_in( + shared: &ProxySharedState, + tx: &mpsc::Sender, + cmd: C2MeCommand, + send_timeout: Option, + stats: &Stats, +) -> std::result::Result<(), mpsc::error::SendError> { + match tx.try_send(cmd) { + Ok(()) => Ok(()), + Err(mpsc::error::TrySendError::Closed(cmd)) => Err(mpsc::error::SendError(cmd)), + Err(mpsc::error::TrySendError::Full(cmd)) => { + stats.increment_me_c2me_send_full_total(); + stats.increment_me_c2me_send_high_water_total(); + note_relay_pressure_event_in(shared); + // Cooperative yield reduces burst catch-up when the per-conn queue is near saturation. + if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS { + tokio::task::yield_now().await; + } + let reserve_result = match send_timeout { + Some(send_timeout) => match timeout(send_timeout, tx.reserve()).await { + Ok(result) => result, + Err(_) => { + stats.increment_me_c2me_send_timeout_total(); + return Err(mpsc::error::SendError(cmd)); + } + }, + None => tx.reserve().await, + }; + match reserve_result { + Ok(permit) => { + permit.send(cmd); + Ok(()) + } + Err(_) => { + stats.increment_me_c2me_send_timeout_total(); + Err(mpsc::error::SendError(cmd)) + } + } + } + } +} + +#[cfg(test)] +pub(crate) async fn enqueue_c2me_command( + tx: &mpsc::Sender, + cmd: C2MeCommand, + send_timeout: Option, + stats: &Stats, +) -> std::result::Result<(), mpsc::error::SendError> { + let shared = ProxySharedState::new(); + enqueue_c2me_command_in(shared.as_ref(), tx, cmd, send_timeout, stats).await +} diff --git a/src/proxy/middle_relay/d2c.rs b/src/proxy/middle_relay/d2c.rs new file mode 100644 index 0000000..6adc442 --- /dev/null +++ b/src/proxy/middle_relay/d2c.rs @@ -0,0 +1,456 @@ +use super::*; + +#[derive(Clone, Copy)] +pub(super) struct MeD2cFlushPolicy { + pub(super) max_frames: usize, + pub(super) max_bytes: usize, + pub(super) max_delay: Duration, + pub(super) ack_flush_immediate: bool, + pub(super) quota_soft_overshoot_bytes: u64, + pub(super) frame_buf_shrink_threshold_bytes: usize, +} + +impl MeD2cFlushPolicy { + pub(super) fn from_config(config: &ProxyConfig) -> Self { + Self { + max_frames: config + .general + .me_d2c_flush_batch_max_frames + .max(ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN), + max_bytes: config + .general + .me_d2c_flush_batch_max_bytes + .max(ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN), + max_delay: Duration::from_micros(config.general.me_d2c_flush_batch_max_delay_us), + ack_flush_immediate: config.general.me_d2c_ack_flush_immediate, + quota_soft_overshoot_bytes: config.general.me_quota_soft_overshoot_bytes, + frame_buf_shrink_threshold_bytes: config + .general + .me_d2c_frame_buf_shrink_threshold_bytes + .max(4096), + } + } +} + +pub(super) fn classify_me_d2c_flush_reason( + flush_immediately: bool, + batch_frames: usize, + max_frames: usize, + batch_bytes: usize, + max_bytes: usize, + max_delay_fired: bool, +) -> MeD2cFlushReason { + if flush_immediately { + return MeD2cFlushReason::AckImmediate; + } + if batch_frames >= max_frames { + return MeD2cFlushReason::BatchFrames; + } + if batch_bytes >= max_bytes { + return MeD2cFlushReason::BatchBytes; + } + if max_delay_fired { + return MeD2cFlushReason::MaxDelay; + } + MeD2cFlushReason::QueueDrain +} + +pub(super) fn observe_me_d2c_flush_event( + stats: &Stats, + reason: MeD2cFlushReason, + batch_frames: usize, + batch_bytes: usize, + flush_duration_us: Option, +) { + stats.increment_me_d2c_flush_reason(reason); + if batch_frames > 0 || batch_bytes > 0 { + stats.increment_me_d2c_batches_total(); + stats.add_me_d2c_batch_frames_total(batch_frames as u64); + stats.add_me_d2c_batch_bytes_total(batch_bytes as u64); + stats.observe_me_d2c_batch_frames(batch_frames as u64); + stats.observe_me_d2c_batch_bytes(batch_bytes as u64); + } + if let Some(duration_us) = flush_duration_us { + stats.observe_me_d2c_flush_duration_us(duration_us); + } +} + +pub(super) enum MeWriterResponseOutcome { + Continue { + frames: usize, + bytes: usize, + flush_immediately: bool, + }, + Close, +} + +#[cfg(test)] +pub(crate) async fn process_me_writer_response( + response: MeResponse, + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + rng: &SecureRandom, + frame_buf: &mut Vec, + stats: &Stats, + user: &str, + quota_user_stats: Option<&UserStats>, + quota_limit: Option, + quota_soft_overshoot_bytes: u64, + bytes_me2c: &AtomicU64, + conn_id: u64, + ack_flush_immediate: bool, + batched: bool, +) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + process_me_writer_response_with_traffic_lease( + response, + client_writer, + proto_tag, + rng, + frame_buf, + stats, + user, + quota_user_stats, + quota_limit, + quota_soft_overshoot_bytes, + None, + &CancellationToken::new(), + bytes_me2c, + conn_id, + ack_flush_immediate, + batched, + ) + .await +} + +pub(crate) async fn process_me_writer_response_with_traffic_lease( + response: MeResponse, + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + rng: &SecureRandom, + frame_buf: &mut Vec, + stats: &Stats, + user: &str, + quota_user_stats: Option<&UserStats>, + quota_limit: Option, + quota_soft_overshoot_bytes: u64, + traffic_lease: Option<&Arc>, + cancel: &CancellationToken, + bytes_me2c: &AtomicU64, + conn_id: u64, + ack_flush_immediate: bool, + batched: bool, +) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + match response { + MeResponse::Data { flags, data, .. } => { + if batched { + trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)"); + } else { + trace!(conn_id, bytes = data.len(), flags, "ME->C data"); + } + let data_len = data.len() as u64; + if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) { + let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); + match reserve_user_quota_with_yield( + user_stats, data_len, soft_limit, stats, cancel, None, + ) + .await + { + Ok(_) => {} + Err(MiddleQuotaReserveError::LimitExceeded) => { + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + Err(MiddleQuotaReserveError::Contended) => { + return Err(ProxyError::Proxy( + "ME D->C quota reservation contended".into(), + )); + } + Err(MiddleQuotaReserveError::Cancelled) => { + return Err(ProxyError::Proxy( + "ME D->C quota reservation cancelled".into(), + )); + } + Err(MiddleQuotaReserveError::DeadlineExceeded) => { + return Err(ProxyError::Proxy( + "ME D->C quota reservation deadline exceeded".into(), + )); + } + } + } + wait_for_traffic_budget_or_cancel( + traffic_lease, + RateDirection::Down, + data_len, + cancel, + stats, + None, + ) + .await?; + + let write_mode = match write_client_payload( + client_writer, + proto_tag, + flags, + &data, + rng, + frame_buf, + cancel, + ) + .await + { + Ok(mode) => mode, + Err(err) => { + if quota_limit.is_some() { + stats.add_quota_write_fail_bytes_total(data_len); + stats.increment_quota_write_fail_events_total(); + } + return Err(err); + } + }; + + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + if let Some(user_stats) = quota_user_stats { + stats.add_user_octets_to_handle(user_stats, data_len); + } else { + stats.add_user_octets_to(user, data_len); + } + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); + + Ok(MeWriterResponseOutcome::Continue { + frames: 1, + bytes: data.len(), + flush_immediately: false, + }) + } + MeResponse::Ack(confirm) => { + if batched { + trace!(conn_id, confirm, "ME->C quickack (batched)"); + } else { + trace!(conn_id, confirm, "ME->C quickack"); + } + wait_for_traffic_budget_or_cancel( + traffic_lease, + RateDirection::Down, + 4, + cancel, + stats, + None, + ) + .await?; + write_client_ack(client_writer, proto_tag, confirm, cancel).await?; + stats.increment_me_d2c_ack_frames_total(); + + Ok(MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: ack_flush_immediate, + }) + } + MeResponse::Close => { + if batched { + debug!(conn_id, "ME sent close (batched)"); + } else { + debug!(conn_id, "ME sent close"); + } + Ok(MeWriterResponseOutcome::Close) + } + } +} + +/// Computes the intermediate/secure wire length while rejecting lossy casts. +pub(in crate::proxy::middle_relay) fn compute_intermediate_secure_wire_len( + data_len: usize, + padding_len: usize, + quickack: bool, +) -> Result<(u32, usize)> { + let wire_len = data_len + .checked_add(padding_len) + .ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?; + if wire_len > 0x7fff_ffffusize { + return Err(ProxyError::Proxy(format!( + "Intermediate/Secure frame too large: {wire_len}" + ))); + } + + let total = 4usize + .checked_add(wire_len) + .ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?; + let mut len_val = u32::try_from(wire_len) + .map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?; + if quickack { + len_val |= 0x8000_0000; + } + Ok((len_val, total)) +} + +pub(super) async fn write_client_payload( + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + flags: u32, + data: &[u8], + rng: &SecureRandom, + frame_buf: &mut Vec, + cancel: &CancellationToken, +) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + let quickack = (flags & RPC_FLAG_QUICKACK) != 0; + + let write_mode = match proto_tag { + ProtoTag::Abridged => { + if !data.len().is_multiple_of(4) { + return Err(ProxyError::Proxy(format!( + "Abridged payload must be 4-byte aligned, got {}", + data.len() + ))); + } + + let len_words = data.len() / 4; + if len_words < 0x7f { + let mut first = len_words as u8; + if quickack { + first |= 0x80; + } + let wire_len = 1usize.saturating_add(data.len()); + if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { + frame_buf.clear(); + frame_buf.reserve(wire_len); + frame_buf.push(first); + frame_buf.extend_from_slice(data); + write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; + MeD2cWriteMode::Coalesced + } else { + let header = [first]; + write_all_client_or_cancel(client_writer, &header, cancel).await?; + write_all_client_or_cancel(client_writer, data, cancel).await?; + MeD2cWriteMode::Split + } + } else if len_words < (1 << 24) { + let mut first = 0x7fu8; + if quickack { + first |= 0x80; + } + let lw = (len_words as u32).to_le_bytes(); + let wire_len = 4usize.saturating_add(data.len()); + if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { + frame_buf.clear(); + frame_buf.reserve(wire_len); + frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); + frame_buf.extend_from_slice(data); + write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; + MeD2cWriteMode::Coalesced + } else { + let header = [first, lw[0], lw[1], lw[2]]; + write_all_client_or_cancel(client_writer, &header, cancel).await?; + write_all_client_or_cancel(client_writer, data, cancel).await?; + MeD2cWriteMode::Split + } + } else { + return Err(ProxyError::Proxy(format!( + "Abridged frame too large: {}", + data.len() + ))); + } + } + ProtoTag::Intermediate | ProtoTag::Secure => { + let padding_len = if proto_tag == ProtoTag::Secure { + if !is_valid_secure_payload_len(data.len()) { + return Err(ProxyError::Proxy(format!( + "Secure payload must be 4-byte aligned, got {}", + data.len() + ))); + } + secure_padding_len(data.len(), rng) + } else { + 0 + }; + + let (len_val, total) = + compute_intermediate_secure_wire_len(data.len(), padding_len, quickack)?; + if total <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { + frame_buf.clear(); + frame_buf.reserve(total); + frame_buf.extend_from_slice(&len_val.to_le_bytes()); + frame_buf.extend_from_slice(data); + if padding_len > 0 { + let start = frame_buf.len(); + frame_buf.resize(start + padding_len, 0); + rng.fill(&mut frame_buf[start..]); + } + write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; + MeD2cWriteMode::Coalesced + } else { + let header = len_val.to_le_bytes(); + write_all_client_or_cancel(client_writer, &header, cancel).await?; + write_all_client_or_cancel(client_writer, data, cancel).await?; + if padding_len > 0 { + frame_buf.clear(); + if frame_buf.capacity() < padding_len { + frame_buf.reserve(padding_len); + } + frame_buf.resize(padding_len, 0); + rng.fill(frame_buf.as_mut_slice()); + write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; + } + MeD2cWriteMode::Split + } + } + }; + + Ok(write_mode) +} + +pub(super) async fn write_client_ack( + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + confirm: u32, + cancel: &CancellationToken, +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + let bytes = if proto_tag == ProtoTag::Abridged { + confirm.to_be_bytes() + } else { + confirm.to_le_bytes() + }; + write_all_client_or_cancel(client_writer, &bytes, cancel).await +} + +pub(super) async fn write_all_client_or_cancel( + client_writer: &mut CryptoWriter, + bytes: &[u8], + cancel: &CancellationToken, +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + tokio::select! { + result = client_writer.write_all(bytes) => result.map_err(ProxyError::Io), + _ = cancel.cancelled() => Err(ProxyError::MiddleClientWriterCancelled), + } +} + +pub(super) async fn flush_client_or_cancel( + client_writer: &mut CryptoWriter, + cancel: &CancellationToken, +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + tokio::select! { + result = client_writer.flush() => result.map_err(ProxyError::Io), + _ = cancel.cancelled() => Err(ProxyError::MiddleClientWriterCancelled), + } +} diff --git a/src/proxy/middle_relay/desync.rs b/src/proxy/middle_relay/desync.rs new file mode 100644 index 0000000..02ca5b9 --- /dev/null +++ b/src/proxy/middle_relay/desync.rs @@ -0,0 +1,406 @@ +use super::*; + +#[derive(Default)] +pub(crate) struct DesyncDedupRotationState { + current_started_at: Option, +} + +pub(in crate::proxy::middle_relay) struct RelayForensicsState { + pub(in crate::proxy::middle_relay) trace_id: u64, + pub(in crate::proxy::middle_relay) conn_id: u64, + pub(in crate::proxy::middle_relay) user: String, + pub(in crate::proxy::middle_relay) peer: SocketAddr, + pub(in crate::proxy::middle_relay) peer_hash: u64, + pub(in crate::proxy::middle_relay) started_at: Instant, + pub(in crate::proxy::middle_relay) bytes_c2me: u64, + pub(in crate::proxy::middle_relay) bytes_me2c: Arc, + pub(in crate::proxy::middle_relay) desync_all_full: bool, +} + +#[cfg(test)] +pub(crate) fn hash_value(value: &T) -> u64 { + let mut hasher = DefaultHasher::new(); + value.hash(&mut hasher); + hasher.finish() +} + +fn hash_value_in(shared: &ProxySharedState, value: &T) -> u64 { + shared.middle_relay.desync_hasher.hash_one(value) +} + +#[cfg(test)] +pub(crate) fn hash_ip(ip: IpAddr) -> u64 { + hash_value(&ip) +} + +pub(super) fn hash_ip_in(shared: &ProxySharedState, ip: IpAddr) -> u64 { + hash_value_in(shared, &ip) +} + +fn should_emit_full_desync_in( + shared: &ProxySharedState, + key: u64, + all_full: bool, + now: Instant, +) -> bool { + if all_full { + return true; + } + + let dedup_current = &shared.middle_relay.desync_dedup; + let dedup_previous = &shared.middle_relay.desync_dedup_previous; + let rotation_state = &shared.middle_relay.desync_dedup_rotation_state; + + let mut state = match rotation_state.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = DesyncDedupRotationState::default(); + rotation_state.clear_poison(); + guard + } + }; + + let rotate_now = match state.current_started_at { + Some(current_started_at) => match now.checked_duration_since(current_started_at) { + Some(elapsed) => elapsed >= DESYNC_DEDUP_WINDOW, + None => true, + }, + None => true, + }; + if rotate_now { + dedup_previous.clear(); + for entry in dedup_current.iter() { + dedup_previous.insert(*entry.key(), *entry.value()); + } + dedup_current.clear(); + state.current_started_at = Some(now); + } + + if let Some(seen_at) = dedup_current.get(&key).map(|entry| *entry.value()) { + let within_window = match now.checked_duration_since(seen_at) { + Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW, + None => true, + }; + if within_window { + return false; + } + dedup_current.insert(key, now); + return true; + } + + if let Some(seen_at) = dedup_previous.get(&key).map(|entry| *entry.value()) { + let within_window = match now.checked_duration_since(seen_at) { + Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW, + None => true, + }; + if within_window { + dedup_current.insert(key, seen_at); + return false; + } + dedup_previous.remove(&key); + } + + if dedup_current.len() >= DESYNC_DEDUP_MAX_ENTRIES { + dedup_previous.clear(); + for entry in dedup_current.iter() { + dedup_previous.insert(*entry.key(), *entry.value()); + } + dedup_current.clear(); + state.current_started_at = Some(now); + dedup_current.insert(key, now); + should_emit_full_desync_full_cache_in(shared, now) + } else { + dedup_current.insert(key, now); + true + } +} + +fn should_emit_full_desync_full_cache_in(shared: &ProxySharedState, now: Instant) -> bool { + let gate = &shared.middle_relay.desync_full_cache_last_emit_at; + let Ok(mut last_emit_at) = gate.lock() else { + return false; + }; + + match *last_emit_at { + None => { + *last_emit_at = Some(now); + true + } + Some(last) => { + let Some(elapsed) = now.checked_duration_since(last) else { + *last_emit_at = Some(now); + return true; + }; + if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL { + *last_emit_at = Some(now); + true + } else { + false + } + } + } +} + +pub(crate) fn desync_forensics_len_bytes(len: usize) -> ([u8; 4], bool) { + match u32::try_from(len) { + Ok(value) => (value.to_le_bytes(), false), + Err(_) => (u32::MAX.to_le_bytes(), true), + } +} + +pub(super) fn report_desync_frame_too_large_in( + shared: &ProxySharedState, + state: &RelayForensicsState, + proto_tag: ProtoTag, + frame_counter: u64, + max_frame: usize, + len: usize, + raw_len_bytes: Option<[u8; 4]>, + stats: &Stats, +) -> ProxyError { + let (fallback_len_buf, len_buf_truncated) = desync_forensics_len_bytes(len); + let len_buf = raw_len_bytes.unwrap_or(fallback_len_buf); + let looks_like_tls = raw_len_bytes + .map(|b| b[0] == 0x16 && b[1] == 0x03) + .unwrap_or(false); + let looks_like_http = raw_len_bytes + .map(|b| matches!(b[0], b'G' | b'P' | b'H' | b'C' | b'D')) + .unwrap_or(false); + let now = Instant::now(); + let dedup_key = hash_value_in( + shared, + &( + state.user.as_str(), + state.peer_hash, + proto_tag, + DESYNC_ERROR_CLASS, + ), + ); + let emit_full = should_emit_full_desync_in(shared, dedup_key, state.desync_all_full, now); + let duration_ms = state.started_at.elapsed().as_millis() as u64; + let bytes_me2c = state.bytes_me2c.load(Ordering::Relaxed); + + stats.increment_desync_total(); + stats.increment_relay_protocol_desync_close_total(); + stats.observe_desync_frames_ok(frame_counter); + if emit_full { + stats.increment_desync_full_logged(); + warn!( + trace_id = format_args!("0x{:016x}", state.trace_id), + conn_id = state.conn_id, + user = %state.user, + peer_hash = format_args!("0x{:016x}", state.peer_hash), + proto = ?proto_tag, + mode = "middle_proxy", + is_tls = true, + duration_ms, + bytes_c2me = state.bytes_c2me, + bytes_me2c, + raw_len = len, + raw_len_hex = format_args!("0x{:08x}", len), + raw_len_bytes_truncated = len_buf_truncated, + raw_bytes = format_args!( + "{:02x} {:02x} {:02x} {:02x}", + len_buf[0], len_buf[1], len_buf[2], len_buf[3] + ), + max_frame, + tls_like = looks_like_tls, + http_like = looks_like_http, + frames_ok = frame_counter, + dedup_window_secs = DESYNC_DEDUP_WINDOW.as_secs(), + desync_all_full = state.desync_all_full, + full_reason = if state.desync_all_full { "desync_all_full" } else { "first_in_dedup_window" }, + error_class = DESYNC_ERROR_CLASS, + "Frame too large — crypto desync forensics" + ); + debug!( + trace_id = format_args!("0x{:016x}", state.trace_id), + conn_id = state.conn_id, + user = %state.user, + peer = %state.peer, + "Frame too large forensic peer detail" + ); + } else { + stats.increment_desync_suppressed(); + debug!( + trace_id = format_args!("0x{:016x}", state.trace_id), + conn_id = state.conn_id, + user = %state.user, + peer_hash = format_args!("0x{:016x}", state.peer_hash), + proto = ?proto_tag, + duration_ms, + bytes_c2me = state.bytes_c2me, + bytes_me2c, + raw_len = len, + frames_ok = frame_counter, + dedup_window_secs = DESYNC_DEDUP_WINDOW.as_secs(), + error_class = DESYNC_ERROR_CLASS, + "Frame too large — crypto desync forensic suppressed" + ); + } + + ProxyError::Proxy(format!( + "Frame too large: {len} (max {max_frame}), frames_ok={frame_counter}, conn_id={}, trace_id=0x{:016x}", + state.conn_id, state.trace_id + )) +} + +#[cfg(test)] +pub(crate) fn report_desync_frame_too_large( + state: &RelayForensicsState, + proto_tag: ProtoTag, + frame_counter: u64, + max_frame: usize, + len: usize, + raw_len_bytes: Option<[u8; 4]>, + stats: &Stats, +) -> ProxyError { + let shared = ProxySharedState::new(); + report_desync_frame_too_large_in( + shared.as_ref(), + state, + proto_tag, + frame_counter, + max_frame, + len, + raw_len_bytes, + stats, + ) +} + +#[cfg(test)] +pub(crate) fn should_emit_full_desync_for_testing( + shared: &ProxySharedState, + key: u64, + all_full: bool, + now: Instant, +) -> bool { + if all_full { + return true; + } + + let dedup_current = &shared.middle_relay.desync_dedup; + let dedup_previous = &shared.middle_relay.desync_dedup_previous; + + let Ok(mut state) = shared.middle_relay.desync_dedup_rotation_state.lock() else { + return false; + }; + + let rotate_now = match state.current_started_at { + Some(current_started_at) => match now.checked_duration_since(current_started_at) { + Some(elapsed) => elapsed >= DESYNC_DEDUP_WINDOW, + None => true, + }, + None => true, + }; + if rotate_now { + dedup_previous.clear(); + for entry in dedup_current.iter() { + dedup_previous.insert(*entry.key(), *entry.value()); + } + dedup_current.clear(); + state.current_started_at = Some(now); + } + + if let Some(seen_at) = dedup_current.get(&key).map(|entry| *entry.value()) { + let within_window = match now.checked_duration_since(seen_at) { + Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW, + None => true, + }; + if within_window { + return false; + } + dedup_current.insert(key, now); + return true; + } + + if let Some(seen_at) = dedup_previous.get(&key).map(|entry| *entry.value()) { + let within_window = match now.checked_duration_since(seen_at) { + Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW, + None => true, + }; + if within_window { + dedup_current.insert(key, seen_at); + return false; + } + dedup_previous.remove(&key); + } + + if dedup_current.len() >= DESYNC_DEDUP_MAX_ENTRIES { + dedup_previous.clear(); + for entry in dedup_current.iter() { + dedup_previous.insert(*entry.key(), *entry.value()); + } + dedup_current.clear(); + state.current_started_at = Some(now); + dedup_current.insert(key, now); + let Ok(mut last_emit_at) = shared.middle_relay.desync_full_cache_last_emit_at.lock() else { + return false; + }; + return match *last_emit_at { + None => { + *last_emit_at = Some(now); + true + } + Some(last) => { + let Some(elapsed) = now.checked_duration_since(last) else { + *last_emit_at = Some(now); + return true; + }; + if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL { + *last_emit_at = Some(now); + true + } else { + false + } + } + }; + } + + dedup_current.insert(key, now); + true +} + +#[cfg(test)] +pub(crate) fn clear_desync_dedup_for_testing_in_shared(shared: &ProxySharedState) { + shared.middle_relay.desync_dedup.clear(); + shared.middle_relay.desync_dedup_previous.clear(); + if let Ok(mut rotation_state) = shared.middle_relay.desync_dedup_rotation_state.lock() { + *rotation_state = DesyncDedupRotationState::default(); + } + if let Ok(mut last_emit_at) = shared.middle_relay.desync_full_cache_last_emit_at.lock() { + *last_emit_at = None; + } +} + +#[cfg(test)] +pub(crate) fn desync_dedup_len_for_testing(shared: &ProxySharedState) -> usize { + shared.middle_relay.desync_dedup.len() +} + +#[cfg(test)] +pub(crate) fn desync_dedup_insert_for_testing(shared: &ProxySharedState, key: u64, at: Instant) { + shared.middle_relay.desync_dedup.insert(key, at); +} + +#[cfg(test)] +pub(crate) fn desync_dedup_get_for_testing(shared: &ProxySharedState, key: u64) -> Option { + shared + .middle_relay + .desync_dedup + .get(&key) + .map(|entry| *entry.value()) +} + +#[cfg(test)] +pub(crate) fn desync_dedup_keys_for_testing( + shared: &ProxySharedState, +) -> std::collections::HashSet { + shared + .middle_relay + .desync_dedup + .iter() + .map(|entry| *entry.key()) + .collect() +} diff --git a/src/proxy/middle_relay/idle.rs b/src/proxy/middle_relay/idle.rs new file mode 100644 index 0000000..3a33869 --- /dev/null +++ b/src/proxy/middle_relay/idle.rs @@ -0,0 +1,341 @@ +use super::*; + +mod read; + +pub(crate) use self::read::read_client_payload_with_idle_policy_in; +#[cfg(test)] +pub(crate) use self::read::{ + read_client_payload, read_client_payload_legacy, read_client_payload_with_idle_policy, +}; + +#[derive(Default)] +pub(crate) struct RelayIdleCandidateRegistry { + pub(in crate::proxy::middle_relay) by_conn_id: HashMap, + pub(in crate::proxy::middle_relay) ordered: BTreeSet<(u64, u64)>, + pressure_event_seq: u64, + pressure_consumed_seq: u64, +} + +/// Queue metadata used to preserve FIFO ordering for idle relay eviction. +#[derive(Clone, Copy)] +pub(in crate::proxy::middle_relay) struct RelayIdleCandidateMeta { + pub(in crate::proxy::middle_relay) mark_order_seq: u64, + pub(in crate::proxy::middle_relay) mark_pressure_seq: u64, +} + +pub(super) fn relay_idle_candidate_registry_lock_in( + shared: &ProxySharedState, +) -> std::sync::MutexGuard<'_, RelayIdleCandidateRegistry> { + let registry = &shared.middle_relay.relay_idle_registry; + match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = RelayIdleCandidateRegistry::default(); + registry.clear_poison(); + guard + } + } +} + +pub(super) fn mark_relay_idle_candidate_in(shared: &ProxySharedState, conn_id: u64) -> bool { + let mut guard = relay_idle_candidate_registry_lock_in(shared); + + if guard.by_conn_id.contains_key(&conn_id) { + return false; + } + + let mark_order_seq = shared + .middle_relay + .relay_idle_mark_seq + .fetch_add(1, Ordering::Relaxed) + .saturating_add(1); + let meta = RelayIdleCandidateMeta { + mark_order_seq, + mark_pressure_seq: guard.pressure_event_seq, + }; + guard.by_conn_id.insert(conn_id, meta); + guard.ordered.insert((meta.mark_order_seq, conn_id)); + true +} + +pub(super) fn clear_relay_idle_candidate_in(shared: &ProxySharedState, conn_id: u64) { + let mut guard = relay_idle_candidate_registry_lock_in(shared); + + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } +} + +pub(super) fn note_relay_pressure_event_in(shared: &ProxySharedState) { + let mut guard = relay_idle_candidate_registry_lock_in(shared); + guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1); +} + +pub(crate) fn note_global_relay_pressure(shared: &ProxySharedState) { + note_relay_pressure_event_in(shared); +} + +pub(super) fn relay_pressure_event_seq_in(shared: &ProxySharedState) -> u64 { + let guard = relay_idle_candidate_registry_lock_in(shared); + guard.pressure_event_seq +} + +pub(super) fn maybe_evict_idle_candidate_on_pressure_in( + shared: &ProxySharedState, + conn_id: u64, + seen_pressure_seq: &mut u64, + stats: &Stats, +) -> bool { + let mut guard = relay_idle_candidate_registry_lock_in(shared); + + let latest_pressure_seq = guard.pressure_event_seq; + if latest_pressure_seq == *seen_pressure_seq { + return false; + } + *seen_pressure_seq = latest_pressure_seq; + + if latest_pressure_seq == guard.pressure_consumed_seq { + return false; + } + + if guard.ordered.is_empty() { + guard.pressure_consumed_seq = latest_pressure_seq; + return false; + } + + let oldest = guard + .ordered + .iter() + .next() + .map(|(_, candidate_conn_id)| *candidate_conn_id); + if oldest != Some(conn_id) { + return false; + } + + let Some(candidate_meta) = guard.by_conn_id.get(&conn_id).copied() else { + return false; + }; + + if latest_pressure_seq == candidate_meta.mark_pressure_seq { + return false; + } + + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } + guard.pressure_consumed_seq = latest_pressure_seq; + stats.increment_relay_pressure_evict_total(); + true +} + +#[derive(Clone, Copy)] +pub(in crate::proxy::middle_relay) struct RelayClientIdlePolicy { + pub(in crate::proxy::middle_relay) enabled: bool, + pub(in crate::proxy::middle_relay) soft_idle: Duration, + pub(in crate::proxy::middle_relay) hard_idle: Duration, + pub(in crate::proxy::middle_relay) grace_after_downstream_activity: Duration, + pub(in crate::proxy::middle_relay) legacy_frame_read_timeout: Duration, +} + +impl RelayClientIdlePolicy { + pub(super) fn from_config(config: &ProxyConfig) -> Self { + let frame_read_timeout = + Duration::from_secs(config.timeouts.relay_client_idle_hard_secs.max(1)); + if !config.timeouts.relay_idle_policy_v2_enabled { + return Self::disabled(frame_read_timeout); + } + + let soft_idle = Duration::from_secs(config.timeouts.relay_client_idle_soft_secs.max(1)); + let hard_idle = Duration::from_secs(config.timeouts.relay_client_idle_hard_secs.max(1)); + let grace_after_downstream_activity = Duration::from_secs( + config + .timeouts + .relay_idle_grace_after_downstream_activity_secs, + ); + + Self { + enabled: true, + soft_idle, + hard_idle, + grace_after_downstream_activity, + legacy_frame_read_timeout: frame_read_timeout, + } + } + + pub(in crate::proxy::middle_relay) fn disabled(frame_read_timeout: Duration) -> Self { + Self { + enabled: false, + soft_idle: frame_read_timeout, + hard_idle: frame_read_timeout, + grace_after_downstream_activity: Duration::ZERO, + legacy_frame_read_timeout: frame_read_timeout, + } + } + + pub(super) fn apply_pressure_caps(&mut self, profile: ConntrackPressureProfile) { + let pressure_soft_idle_cap = Duration::from_secs(profile.middle_soft_idle_cap_secs()); + let pressure_hard_idle_cap = Duration::from_secs(profile.middle_hard_idle_cap_secs()); + + self.soft_idle = self.soft_idle.min(pressure_soft_idle_cap); + self.hard_idle = self.hard_idle.min(pressure_hard_idle_cap); + if self.soft_idle > self.hard_idle { + self.soft_idle = self.hard_idle; + } + self.legacy_frame_read_timeout = self.legacy_frame_read_timeout.min(pressure_hard_idle_cap); + if self.grace_after_downstream_activity > self.hard_idle { + self.grace_after_downstream_activity = self.hard_idle; + } + } +} + +#[derive(Clone, Copy)] +pub(in crate::proxy::middle_relay) struct RelayClientIdleState { + pub(in crate::proxy::middle_relay) last_client_frame_at: Instant, + pub(in crate::proxy::middle_relay) soft_idle_marked: bool, + pub(in crate::proxy::middle_relay) tiny_frame_debt: u32, +} + +impl RelayClientIdleState { + pub(super) fn new(now: Instant) -> Self { + Self { + last_client_frame_at: now, + soft_idle_marked: false, + tiny_frame_debt: 0, + } + } + + pub(super) fn on_client_frame(&mut self, now: Instant) { + self.last_client_frame_at = now; + self.soft_idle_marked = false; + } + + pub(super) fn on_client_tiny_frame(&mut self, now: Instant) { + self.last_client_frame_at = now; + } +} + +#[cfg(test)] +pub(crate) fn mark_relay_idle_candidate_for_testing( + shared: &ProxySharedState, + conn_id: u64, +) -> bool { + let registry = &shared.middle_relay.relay_idle_registry; + let mut guard = match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = RelayIdleCandidateRegistry::default(); + registry.clear_poison(); + guard + } + }; + + if guard.by_conn_id.contains_key(&conn_id) { + return false; + } + + let mark_order_seq = shared + .middle_relay + .relay_idle_mark_seq + .fetch_add(1, Ordering::Relaxed); + let mark_pressure_seq = guard.pressure_event_seq; + let meta = RelayIdleCandidateMeta { + mark_order_seq, + mark_pressure_seq, + }; + guard.by_conn_id.insert(conn_id, meta); + guard.ordered.insert((mark_order_seq, conn_id)); + true +} + +#[cfg(test)] +pub(crate) fn oldest_relay_idle_candidate_for_testing(shared: &ProxySharedState) -> Option { + let registry = &shared.middle_relay.relay_idle_registry; + let guard = match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = RelayIdleCandidateRegistry::default(); + registry.clear_poison(); + guard + } + }; + guard.ordered.iter().next().map(|(_, conn_id)| *conn_id) +} + +#[cfg(test)] +pub(crate) fn clear_relay_idle_candidate_for_testing(shared: &ProxySharedState, conn_id: u64) { + let registry = &shared.middle_relay.relay_idle_registry; + let mut guard = match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = RelayIdleCandidateRegistry::default(); + registry.clear_poison(); + guard + } + }; + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } +} + +#[cfg(test)] +pub(crate) fn clear_relay_idle_pressure_state_for_testing_in_shared(shared: &ProxySharedState) { + if let Ok(mut guard) = shared.middle_relay.relay_idle_registry.lock() { + *guard = RelayIdleCandidateRegistry::default(); + } + shared + .middle_relay + .relay_idle_mark_seq + .store(0, Ordering::Relaxed); +} + +#[cfg(test)] +pub(crate) fn note_relay_pressure_event_for_testing(shared: &ProxySharedState) { + note_relay_pressure_event_in(shared); +} + +#[cfg(test)] +pub(crate) fn relay_pressure_event_seq_for_testing(shared: &ProxySharedState) -> u64 { + relay_pressure_event_seq_in(shared) +} + +#[cfg(test)] +pub(crate) fn relay_idle_mark_seq_for_testing(shared: &ProxySharedState) -> u64 { + shared + .middle_relay + .relay_idle_mark_seq + .load(Ordering::Relaxed) +} + +#[cfg(test)] +pub(crate) fn maybe_evict_idle_candidate_on_pressure_for_testing( + shared: &ProxySharedState, + conn_id: u64, + seen_pressure_seq: &mut u64, + stats: &Stats, +) -> bool { + maybe_evict_idle_candidate_on_pressure_in(shared, conn_id, seen_pressure_seq, stats) +} + +#[cfg(test)] +pub(crate) fn set_relay_pressure_state_for_testing( + shared: &ProxySharedState, + pressure_event_seq: u64, + pressure_consumed_seq: u64, +) { + let registry = &shared.middle_relay.relay_idle_registry; + let mut guard = match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = RelayIdleCandidateRegistry::default(); + registry.clear_poison(); + guard + } + }; + guard.pressure_event_seq = pressure_event_seq; + guard.pressure_consumed_seq = pressure_consumed_seq; +} diff --git a/src/proxy/middle_relay/idle/read.rs b/src/proxy/middle_relay/idle/read.rs new file mode 100644 index 0000000..270f104 --- /dev/null +++ b/src/proxy/middle_relay/idle/read.rs @@ -0,0 +1,442 @@ +use super::*; + +pub(crate) async fn read_client_payload_with_idle_policy_in( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, + shared: &ProxySharedState, + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4; + + async fn read_exact_with_policy( + client_reader: &mut CryptoReader, + buf: &mut [u8], + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, + forensics: &RelayForensicsState, + stats: &Stats, + shared: &ProxySharedState, + read_label: &'static str, + ) -> Result<()> + where + R: AsyncRead + Unpin + Send + 'static, + { + fn hard_deadline( + idle_policy: &RelayClientIdlePolicy, + idle_state: &RelayClientIdleState, + session_started_at: Instant, + last_downstream_activity_ms: u64, + ) -> Instant { + let mut deadline = idle_state.last_client_frame_at + idle_policy.hard_idle; + if idle_policy.grace_after_downstream_activity.is_zero() { + return deadline; + } + + let downstream_at = + session_started_at + Duration::from_millis(last_downstream_activity_ms); + if downstream_at > idle_state.last_client_frame_at { + let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity; + if grace_deadline > deadline { + deadline = grace_deadline; + } + } + deadline + } + + let mut filled = 0usize; + while filled < buf.len() { + let timeout_window = if idle_policy.enabled { + let now = Instant::now(); + let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); + let hard_deadline = + hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); + if !idle_state.soft_idle_marked + && now.saturating_duration_since(idle_state.last_client_frame_at) + >= idle_policy.soft_idle + { + idle_state.soft_idle_marked = true; + if mark_relay_idle_candidate_in(shared, forensics.conn_id) { + stats.increment_relay_idle_soft_mark_total(); + } + info!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay soft idle mark" + ); + } + + let soft_deadline = idle_state.last_client_frame_at + idle_policy.soft_idle; + let next_deadline = if idle_state.soft_idle_marked { + hard_deadline + } else { + soft_deadline.min(hard_deadline) + }; + let mut remaining = next_deadline.saturating_duration_since(now); + if remaining.is_zero() { + remaining = Duration::from_millis(1); + } + remaining.min(RELAY_IDLE_IO_POLL_MAX) + } else { + idle_policy.legacy_frame_read_timeout + }; + + let read_result = timeout(timeout_window, client_reader.read(&mut buf[filled..])).await; + match read_result { + Ok(Ok(0)) => { + return Err(ProxyError::Io(std::io::Error::from( + std::io::ErrorKind::UnexpectedEof, + ))); + } + Ok(Ok(n)) => { + filled = filled.saturating_add(n); + } + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) if !idle_policy.enabled => { + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!( + "middle-relay client frame read timeout while reading {read_label}" + ), + ))); + } + Err(_) => { + let now = Instant::now(); + let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); + let hard_deadline = + hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); + if now >= hard_deadline { + clear_relay_idle_candidate_in(shared, forensics.conn_id); + stats.increment_relay_idle_hard_close_total(); + let client_idle_secs = now + .saturating_duration_since(idle_state.last_client_frame_at) + .as_secs(); + let downstream_idle_secs = now + .saturating_duration_since( + session_started_at + Duration::from_millis(downstream_ms), + ) + .as_secs(); + warn!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + client_idle_secs, + downstream_idle_secs, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay hard idle close" + ); + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!( + "middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}", + idle_policy.soft_idle.as_secs(), + idle_policy.hard_idle.as_secs(), + idle_policy.grace_after_downstream_activity.as_secs(), + ), + ))); + } + } + } + } + + Ok(()) + } + + let mut consecutive_zero_len_frames = 0u32; + loop { + let (len, quickack, raw_len_bytes) = match proto_tag { + ProtoTag::Abridged => { + let mut first = [0u8; 1]; + match read_exact_with_policy( + client_reader, + &mut first, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + shared, + "abridged.first_len_byte", + ) + .await + { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), + } + + let quickack = (first[0] & 0x80) != 0; + let len_words = if (first[0] & 0x7f) == 0x7f { + let mut ext = [0u8; 3]; + read_exact_with_policy( + client_reader, + &mut ext, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + shared, + "abridged.extended_len", + ) + .await?; + u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize + } else { + (first[0] & 0x7f) as usize + }; + + let len = len_words + .checked_mul(4) + .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?; + (len, quickack, None) + } + ProtoTag::Intermediate | ProtoTag::Secure => { + let mut len_buf = [0u8; 4]; + match read_exact_with_policy( + client_reader, + &mut len_buf, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + shared, + "len_prefix", + ) + .await + { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), + } + let quickack = (len_buf[3] & 0x80) != 0; + ( + (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, + quickack, + Some(len_buf), + ) + } + }; + + if len == 0 { + idle_state.on_client_tiny_frame(Instant::now()); + idle_state.tiny_frame_debt = idle_state + .tiny_frame_debt + .saturating_add(TINY_FRAME_DEBT_PER_TINY); + if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT { + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy(format!( + "Tiny frame overhead limit exceeded: debt={}, conn_id={}", + idle_state.tiny_frame_debt, forensics.conn_id + ))); + } + + if !idle_policy.enabled { + consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1); + if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES { + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy( + "Excessive zero-length abridged frames".to_string(), + )); + } + } + continue; + } + if len < 4 && proto_tag != ProtoTag::Abridged { + warn!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + len, + proto = ?proto_tag, + "Frame too small — corrupt or probe" + ); + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy(format!("Frame too small: {len}"))); + } + + if len > max_frame { + return Err(report_desync_frame_too_large_in( + shared, + forensics, + proto_tag, + *frame_counter, + max_frame, + len, + raw_len_bytes, + stats, + )); + } + + let secure_payload_len = if proto_tag == ProtoTag::Secure { + match secure_payload_len_from_wire_len(len) { + Some(payload_len) => payload_len, + None => { + stats.increment_secure_padding_invalid(); + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy(format!( + "Invalid secure frame length: {len}" + ))); + } + } + } else { + len + }; + + let mut payload = buffer_pool.get(); + payload.clear(); + let current_cap = payload.capacity(); + if current_cap < len { + payload.reserve(len - current_cap); + } + payload.resize(len, 0); + read_exact_with_policy( + client_reader, + &mut payload[..len], + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + shared, + "payload", + ) + .await?; + + // Secure Intermediate: strip validated trailing padding bytes. + if proto_tag == ProtoTag::Secure { + payload.truncate(secure_payload_len); + } + *frame_counter += 1; + idle_state.on_client_frame(Instant::now()); + idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1); + clear_relay_idle_candidate_in(shared, forensics.conn_id); + return Ok(Some((payload, quickack))); + } +} + +#[cfg(test)] +pub(crate) async fn read_client_payload_with_idle_policy( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + let shared = ProxySharedState::new(); + read_client_payload_with_idle_policy_in( + client_reader, + proto_tag, + max_frame, + buffer_pool, + forensics, + frame_counter, + stats, + shared.as_ref(), + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + ) + .await +} + +#[cfg(test)] +pub(crate) async fn read_client_payload_legacy( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + let now = Instant::now(); + let shared = ProxySharedState::new(); + let mut idle_state = RelayClientIdleState::new(now); + let last_downstream_activity_ms = AtomicU64::new(0); + let idle_policy = RelayClientIdlePolicy::disabled(frame_read_timeout); + read_client_payload_with_idle_policy_in( + client_reader, + proto_tag, + max_frame, + buffer_pool, + forensics, + frame_counter, + stats, + shared.as_ref(), + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + now, + ) + .await +} + +#[cfg(test)] +pub(crate) async fn read_client_payload( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + read_client_payload_legacy( + client_reader, + proto_tag, + max_frame, + frame_read_timeout, + buffer_pool, + forensics, + frame_counter, + stats, + ) + .await +} diff --git a/src/proxy/middle_relay/quota.rs b/src/proxy/middle_relay/quota.rs new file mode 100644 index 0000000..618056c --- /dev/null +++ b/src/proxy/middle_relay/quota.rs @@ -0,0 +1,151 @@ +use super::*; + +pub(super) enum MiddleQuotaReserveError { + LimitExceeded, + Contended, + Cancelled, + DeadlineExceeded, +} + +pub(super) fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { + limit.saturating_add(overshoot) +} + +pub(super) async fn reserve_user_quota_with_yield( + user_stats: &UserStats, + bytes: u64, + limit: u64, + stats: &Stats, + cancel: &CancellationToken, + deadline: Option, +) -> std::result::Result { + let mut backoff_ms = QUOTA_RESERVE_BACKOFF_MIN_MS; + let mut backoff_rounds = 0usize; + loop { + for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { + match user_stats.quota_try_reserve(bytes, limit) { + Ok(total) => return Ok(total), + Err(QuotaReserveError::LimitExceeded) => { + return Err(MiddleQuotaReserveError::LimitExceeded); + } + Err(QuotaReserveError::Contended) => { + stats.increment_quota_contention_total(); + std::hint::spin_loop(); + } + } + } + + tokio::task::yield_now().await; + if deadline.is_some_and(|deadline| Instant::now() >= deadline) { + stats.increment_quota_contention_timeout_total(); + return Err(MiddleQuotaReserveError::DeadlineExceeded); + } + tokio::select! { + _ = tokio::time::sleep(Duration::from_millis(backoff_ms)) => {} + _ = cancel.cancelled() => { + stats.increment_quota_acquire_cancelled_total(); + return Err(MiddleQuotaReserveError::Cancelled); + } + } + backoff_rounds = backoff_rounds.saturating_add(1); + if backoff_rounds >= QUOTA_RESERVE_MAX_BACKOFF_ROUNDS { + stats.increment_quota_contention_timeout_total(); + return Err(MiddleQuotaReserveError::Contended); + } + backoff_ms = backoff_ms + .saturating_mul(2) + .min(QUOTA_RESERVE_BACKOFF_MAX_MS); + } +} + +pub(super) async fn wait_for_traffic_budget( + lease: Option<&Arc>, + direction: RateDirection, + bytes: u64, + deadline: Option, +) -> Result<()> { + if bytes == 0 { + return Ok(()); + } + let Some(lease) = lease else { + return Ok(()); + }; + + let mut remaining = bytes; + while remaining > 0 { + let consume = lease.try_consume(direction, remaining); + if consume.granted > 0 { + remaining = remaining.saturating_sub(consume.granted); + continue; + } + + let wait_started_at = Instant::now(); + if deadline.is_some_and(|deadline| wait_started_at >= deadline) { + return Err(ProxyError::TrafficBudgetWaitDeadlineExceeded); + } + tokio::time::sleep(next_refill_delay()).await; + let wait_ms = wait_started_at + .elapsed() + .as_millis() + .min(u128::from(u64::MAX)) as u64; + lease.observe_wait_ms( + direction, + consume.blocked_user, + consume.blocked_cidr, + wait_ms, + ); + } + + Ok(()) +} + +pub(super) async fn wait_for_traffic_budget_or_cancel( + lease: Option<&Arc>, + direction: RateDirection, + bytes: u64, + cancel: &CancellationToken, + stats: &Stats, + deadline: Option, +) -> Result<()> { + if bytes == 0 { + return Ok(()); + } + let Some(lease) = lease else { + return Ok(()); + }; + + let mut remaining = bytes; + while remaining > 0 { + let consume = lease.try_consume(direction, remaining); + if consume.granted > 0 { + remaining = remaining.saturating_sub(consume.granted); + continue; + } + + let wait_started_at = Instant::now(); + if deadline.is_some_and(|deadline| wait_started_at >= deadline) { + stats.increment_flow_wait_middle_rate_limit_cancelled_total(); + return Err(ProxyError::TrafficBudgetWaitDeadlineExceeded); + } + tokio::select! { + _ = tokio::time::sleep(next_refill_delay()) => {} + _ = cancel.cancelled() => { + stats.increment_flow_wait_middle_rate_limit_cancelled_total(); + return Err(ProxyError::TrafficBudgetWaitCancelled); + } + } + let wait_ms = wait_started_at + .elapsed() + .as_millis() + .min(u128::from(u64::MAX)) as u64; + lease.observe_wait_ms( + direction, + consume.blocked_user, + consume.blocked_cidr, + wait_ms, + ); + stats.observe_flow_wait_middle_rate_limit_ms(wait_ms); + } + + Ok(()) +} diff --git a/src/proxy/middle_relay/session.rs b/src/proxy/middle_relay/session.rs new file mode 100644 index 0000000..81cf297 --- /dev/null +++ b/src/proxy/middle_relay/session.rs @@ -0,0 +1,830 @@ +use super::*; + +pub(crate) async fn handle_via_middle_proxy( + mut crypto_reader: CryptoReader, + crypto_writer: CryptoWriter, + success: HandshakeSuccess, + me_pool: Arc, + stats: Arc, + config: Arc, + buffer_pool: Arc, + local_addr: SocketAddr, + rng: Arc, + mut route_rx: watch::Receiver, + route_snapshot: RouteCutoverState, + session_id: u64, + shared: Arc, +) -> Result<()> +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + let user = success.user.clone(); + let quota_limit = config.access.user_data_quota.get(&user).copied(); + let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); + let peer = success.peer; + let traffic_lease = shared.traffic_limiter.acquire_lease(&user, peer.ip()); + let proto_tag = success.proto_tag; + let pool_generation = me_pool.current_generation(); + + debug!( + user = %user, + peer = %peer, + dc = success.dc_idx, + proto = ?proto_tag, + mode = "middle_proxy", + pool_generation, + "Routing via Middle-End" + ); + + let (conn_id, me_rx) = me_pool.registry().register().await; + let trace_id = session_id; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut forensics = RelayForensicsState { + trace_id, + conn_id, + user: user.clone(), + peer, + peer_hash: hash_ip_in(shared.as_ref(), peer.ip()), + started_at: Instant::now(), + bytes_c2me: 0, + bytes_me2c: bytes_me2c.clone(), + desync_all_full: config.general.desync_all_full, + }; + + stats.increment_user_connects(&user); + let _me_connection_lease = stats.acquire_me_connection_lease(); + + if let Some(cutover) = + affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) + { + let delay = cutover_stagger_delay(session_id, cutover.generation); + warn!( + conn_id, + target_mode = cutover.mode.as_str(), + cutover_generation = cutover.generation, + delay_ms = delay.as_millis() as u64, + "Cutover affected middle session before relay start, closing client connection" + ); + let _cutover_park_lease = stats.acquire_middle_cutover_park_lease(); + tokio::time::sleep(delay).await; + let _ = me_pool.send_close(conn_id).await; + me_pool.registry().unregister(conn_id).await; + return Err(ProxyError::RouteSwitched); + } + + // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) + let user_tag: Option> = config + .access + .user_ad_tags + .get(&user) + .and_then(|s| hex::decode(s).ok()) + .filter(|v| v.len() == 16); + let global_tag: Option> = config + .general + .ad_tag + .as_ref() + .and_then(|s| hex::decode(s).ok()) + .filter(|v| v.len() == 16); + let effective_tag = user_tag.or(global_tag); + + let proto_flags = proto_flags_for_tag(proto_tag, effective_tag.is_some()); + let effective_tag_array = effective_tag + .as_deref() + .and_then(|tag| <[u8; 16]>::try_from(tag).ok()); + debug!( + trace_id = format_args!("0x{:016x}", trace_id), + user = %user, + conn_id, + peer_hash = format_args!("0x{:016x}", forensics.peer_hash), + desync_all_full = forensics.desync_all_full, + proto_flags = format_args!("0x{:08x}", proto_flags), + pool_generation, + "ME relay started" + ); + + let translated_local_addr = me_pool.translate_our_addr(local_addr); + + let frame_limit = config.general.max_client_frame; + let mut relay_idle_policy = RelayClientIdlePolicy::from_config(&config); + let mut pressure_caps_applied = false; + if shared.conntrack_pressure_active() { + relay_idle_policy.apply_pressure_caps(config.server.conntrack_control.profile); + pressure_caps_applied = true; + } + let session_started_at = forensics.started_at; + let mut relay_idle_state = RelayClientIdleState::new(session_started_at); + let last_downstream_activity_ms = Arc::new(AtomicU64::new(0)); + + let c2me_channel_capacity = config + .general + .me_c2me_channel_capacity + .max(C2ME_CHANNEL_CAPACITY_FALLBACK); + let c2me_send_timeout = match config.general.me_c2me_send_timeout_ms { + 0 => None, + timeout_ms => Some(Duration::from_millis(timeout_ms)), + }; + let c2me_byte_budget = c2me_queued_permit_budget(c2me_channel_capacity, frame_limit); + let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget)); + let (c2me_tx, mut c2me_rx) = mpsc::channel::(c2me_channel_capacity); + let me_pool_c2me = me_pool.clone(); + let mut c2me_sender = tokio::spawn(async move { + let mut sent_since_yield = 0usize; + while let Some(cmd) = c2me_rx.recv().await { + match cmd { + C2MeCommand::Data { + payload, + flags, + _permit, + } => { + me_pool_c2me + .send_proxy_req_pooled( + conn_id, + success.dc_idx, + peer, + translated_local_addr, + payload, + flags, + effective_tag_array, + ) + .await?; + sent_since_yield = sent_since_yield.saturating_add(1); + if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) { + sent_since_yield = 0; + tokio::task::yield_now().await; + } + } + C2MeCommand::Close => { + let _ = me_pool_c2me.send_close(conn_id).await; + return Ok(()); + } + } + } + Ok(()) + }); + + let (stop_tx, mut stop_rx) = oneshot::channel::<()>(); + let flow_cancel = CancellationToken::new(); + let mut me_rx_task = me_rx; + let stats_clone = stats.clone(); + let rng_clone = rng.clone(); + let user_clone = user.clone(); + let quota_user_stats_me_writer = quota_user_stats.clone(); + let traffic_lease_me_writer = traffic_lease.clone(); + let flow_cancel_me_writer = flow_cancel.clone(); + let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); + let bytes_me2c_clone = bytes_me2c.clone(); + let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); + let mut me_writer = tokio::spawn(async move { + let mut writer = crypto_writer; + let mut frame_buf = Vec::with_capacity(16 * 1024); + let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes; + + fn shrink_session_vec(buf: &mut Vec, threshold: usize) { + if buf.capacity() > threshold { + buf.clear(); + buf.shrink_to(threshold); + } else { + buf.clear(); + } + } + + loop { + tokio::select! { + msg = me_rx_task.recv() => { + let Some(first) = msg else { + debug!(conn_id, "ME channel closed"); + shrink_session_vec(&mut frame_buf, shrink_threshold); + return Err(ProxyError::MiddleConnectionLost); + }; + + let mut batch_frames = 0usize; + let mut batch_bytes = 0usize; + let mut flush_immediately; + let mut max_delay_fired = false; + + let first_is_downstream_activity = + matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); + match process_me_writer_response_with_traffic_lease( + first, + &mut writer, + proto_tag, + rng_clone.as_ref(), + &mut frame_buf, + stats_clone.as_ref(), + &user_clone, + quota_user_stats_me_writer.as_deref(), + quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), + &flow_cancel_me_writer, + bytes_me2c_clone.as_ref(), + conn_id, + d2c_flush_policy.ack_flush_immediate, + false, + ).await? { + MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if first_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } + batch_frames = batch_frames.saturating_add(frames); + batch_bytes = batch_bytes.saturating_add(bytes); + flush_immediately = immediate; + } + MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; + let _ = flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); + shrink_session_vec(&mut frame_buf, shrink_threshold); + return Ok(()); + } + } + + while !flush_immediately + && batch_frames < d2c_flush_policy.max_frames + && batch_bytes < d2c_flush_policy.max_bytes + { + let Ok(next) = me_rx_task.try_recv() else { + break; + }; + + let next_is_downstream_activity = + matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); + match process_me_writer_response_with_traffic_lease( + next, + &mut writer, + proto_tag, + rng_clone.as_ref(), + &mut frame_buf, + stats_clone.as_ref(), + &user_clone, + quota_user_stats_me_writer.as_deref(), + quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), + &flow_cancel_me_writer, + bytes_me2c_clone.as_ref(), + conn_id, + d2c_flush_policy.ack_flush_immediate, + true, + ).await? { + MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if next_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } + batch_frames = batch_frames.saturating_add(frames); + batch_bytes = batch_bytes.saturating_add(bytes); + flush_immediately |= immediate; + } + MeWriterResponseOutcome::Close => { + let flush_started_at = + if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; + let _ = + flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); + shrink_session_vec(&mut frame_buf, shrink_threshold); + return Ok(()); + } + } + } + + if !flush_immediately + && !d2c_flush_policy.max_delay.is_zero() + && batch_frames < d2c_flush_policy.max_frames + && batch_bytes < d2c_flush_policy.max_bytes + { + stats_clone.increment_me_d2c_batch_timeout_armed_total(); + match tokio::time::timeout(d2c_flush_policy.max_delay, me_rx_task.recv()).await { + Ok(Some(next)) => { + let next_is_downstream_activity = + matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); + match process_me_writer_response_with_traffic_lease( + next, + &mut writer, + proto_tag, + rng_clone.as_ref(), + &mut frame_buf, + stats_clone.as_ref(), + &user_clone, + quota_user_stats_me_writer.as_deref(), + quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), + &flow_cancel_me_writer, + bytes_me2c_clone.as_ref(), + conn_id, + d2c_flush_policy.ack_flush_immediate, + true, + ).await? { + MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if next_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } + batch_frames = batch_frames.saturating_add(frames); + batch_bytes = batch_bytes.saturating_add(bytes); + flush_immediately |= immediate; + } + MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone + .telemetry_policy() + .me_level + .allows_debug() + { + Some(Instant::now()) + } else { + None + }; + let _ = flush_client_or_cancel( + &mut writer, + &flow_cancel_me_writer, + ) + .await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); + shrink_session_vec(&mut frame_buf, shrink_threshold); + return Ok(()); + } + } + + while !flush_immediately + && batch_frames < d2c_flush_policy.max_frames + && batch_bytes < d2c_flush_policy.max_bytes + { + let Ok(extra) = me_rx_task.try_recv() else { + break; + }; + + let extra_is_downstream_activity = + matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); + match process_me_writer_response_with_traffic_lease( + extra, + &mut writer, + proto_tag, + rng_clone.as_ref(), + &mut frame_buf, + stats_clone.as_ref(), + &user_clone, + quota_user_stats_me_writer.as_deref(), + quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), + &flow_cancel_me_writer, + bytes_me2c_clone.as_ref(), + conn_id, + d2c_flush_policy.ack_flush_immediate, + true, + ).await? { + MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if extra_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } + batch_frames = batch_frames.saturating_add(frames); + batch_bytes = batch_bytes.saturating_add(bytes); + flush_immediately |= immediate; + } + MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone + .telemetry_policy() + .me_level + .allows_debug() + { + Some(Instant::now()) + } else { + None + }; + let _ = flush_client_or_cancel( + &mut writer, + &flow_cancel_me_writer, + ) + .await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); + shrink_session_vec(&mut frame_buf, shrink_threshold); + return Ok(()); + } + } + } + } + Ok(None) => { + debug!(conn_id, "ME channel closed"); + shrink_session_vec(&mut frame_buf, shrink_threshold); + return Err(ProxyError::MiddleConnectionLost); + } + Err(_) => { + max_delay_fired = true; + stats_clone.increment_me_d2c_batch_timeout_fired_total(); + } + } + } + + let flush_reason = classify_me_d2c_flush_reason( + flush_immediately, + batch_frames, + d2c_flush_policy.max_frames, + batch_bytes, + d2c_flush_policy.max_bytes, + max_delay_fired, + ); + let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; + flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + flush_reason, + batch_frames, + batch_bytes, + flush_duration_us, + ); + let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes; + let shrink_trigger = shrink_threshold + .saturating_mul(ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR); + if frame_buf.capacity() > shrink_trigger { + let cap_before = frame_buf.capacity(); + frame_buf.shrink_to(shrink_threshold); + let cap_after = frame_buf.capacity(); + let bytes_freed = cap_before.saturating_sub(cap_after) as u64; + stats_clone.observe_me_d2c_frame_buf_shrink(bytes_freed); + } + } + _ = &mut stop_rx => { + debug!(conn_id, "ME writer stop signal"); + shrink_session_vec(&mut frame_buf, shrink_threshold); + return Ok(()); + } + } + } + }); + + let mut main_result: Result<()> = Ok(()); + let mut client_closed = false; + let mut frame_counter: u64 = 0; + let mut route_watch_open = true; + let mut seen_pressure_seq = relay_pressure_event_seq_in(shared.as_ref()); + loop { + if shared.conntrack_pressure_active() && !pressure_caps_applied { + relay_idle_policy.apply_pressure_caps(config.server.conntrack_control.profile); + pressure_caps_applied = true; + } + + if relay_idle_policy.enabled + && maybe_evict_idle_candidate_on_pressure_in( + shared.as_ref(), + conn_id, + &mut seen_pressure_seq, + stats.as_ref(), + ) + { + info!( + conn_id, + trace_id = format_args!("0x{:016x}", trace_id), + user = %user, + "Middle-relay pressure eviction for idle-candidate session" + ); + let _ = enqueue_c2me_command_in( + shared.as_ref(), + &c2me_tx, + C2MeCommand::Close, + c2me_send_timeout, + stats.as_ref(), + ) + .await; + main_result = Err(ProxyError::Proxy( + "middle-relay session evicted under pressure (idle-candidate)".to_string(), + )); + break; + } + + if let Some(cutover) = + affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) + { + let delay = cutover_stagger_delay(session_id, cutover.generation); + warn!( + conn_id, + target_mode = cutover.mode.as_str(), + cutover_generation = cutover.generation, + delay_ms = delay.as_millis() as u64, + "Cutover affected middle session, closing client connection" + ); + let _cutover_park_lease = stats.acquire_middle_cutover_park_lease(); + tokio::time::sleep(delay).await; + let _ = enqueue_c2me_command_in( + shared.as_ref(), + &c2me_tx, + C2MeCommand::Close, + c2me_send_timeout, + stats.as_ref(), + ) + .await; + main_result = Err(ProxyError::RouteSwitched); + break; + } + + tokio::select! { + changed = route_rx.changed(), if route_watch_open => { + if changed.is_err() { + route_watch_open = false; + } + } + payload_result = read_client_payload_with_idle_policy_in( + &mut crypto_reader, + proto_tag, + frame_limit, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + shared.as_ref(), + &relay_idle_policy, + &mut relay_idle_state, + last_downstream_activity_ms.as_ref(), + session_started_at, + ) => { + match payload_result { + Ok(Some((payload, quickack))) => { + trace!(conn_id, bytes = payload.len(), "C->ME frame"); + wait_for_traffic_budget( + traffic_lease.as_ref(), + RateDirection::Up, + payload.len() as u64, + None, + ) + .await?; + forensics.bytes_c2me = forensics + .bytes_c2me + .saturating_add(payload.len() as u64); + if let (Some(limit), Some(user_stats)) = + (quota_limit, quota_user_stats.as_deref()) + { + match reserve_user_quota_with_yield( + user_stats, + payload.len() as u64, + limit, + stats.as_ref(), + &flow_cancel, + None, + ) + .await + { + Ok(_) => {} + Err(MiddleQuotaReserveError::LimitExceeded) => { + main_result = Err(ProxyError::DataQuotaExceeded { + user: user.clone(), + }); + break; + } + Err(MiddleQuotaReserveError::Contended) => { + main_result = Err(ProxyError::Proxy( + "ME C->ME quota reservation contended".into(), + )); + break; + } + Err(MiddleQuotaReserveError::Cancelled) => { + main_result = Err(ProxyError::Proxy( + "ME C->ME quota reservation cancelled".into(), + )); + break; + } + Err(MiddleQuotaReserveError::DeadlineExceeded) => { + main_result = Err(ProxyError::Proxy( + "ME C->ME quota reservation deadline exceeded".into(), + )); + break; + } + } + stats.add_user_octets_from_handle(user_stats, payload.len() as u64); + } else { + stats.add_user_octets_from(&user, payload.len() as u64); + } + let mut flags = proto_flags; + if quickack { + flags |= RPC_FLAG_QUICKACK; + } + if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) { + flags |= RPC_FLAG_NOT_ENCRYPTED; + } + let payload_permit = match acquire_c2me_payload_permit( + &c2me_byte_semaphore, + payload.len(), + c2me_send_timeout, + stats.as_ref(), + ) + .await + { + Ok(permit) => permit, + Err(e) => { + main_result = Err(e); + break; + } + }; + // Keep client read loop lightweight: route heavy ME send path via a dedicated task. + if enqueue_c2me_command_in( + shared.as_ref(), + &c2me_tx, + C2MeCommand::Data { + payload, + flags, + _permit: payload_permit, + }, + c2me_send_timeout, + stats.as_ref(), + ) + .await + .is_err() + { + main_result = Err(ProxyError::Proxy("ME sender channel closed".into())); + break; + } + } + Ok(None) => { + debug!(conn_id, "Client EOF"); + client_closed = true; + let _ = enqueue_c2me_command_in( + shared.as_ref(), + &c2me_tx, + C2MeCommand::Close, + c2me_send_timeout, + stats.as_ref(), + ) + .await; + break; + } + Err(e) => { + main_result = Err(e); + break; + } + } + } + } + } + + drop(c2me_tx); + let c2me_result = match timeout(ME_CHILD_JOIN_TIMEOUT, &mut c2me_sender).await { + Ok(joined) => { + joined.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME sender join error: {e}")))) + } + Err(_) => { + stats.increment_me_child_join_timeout_total(); + stats.increment_me_child_abort_total(); + c2me_sender.abort(); + Err(ProxyError::Proxy("ME sender join timeout".into())) + } + }; + + flow_cancel.cancel(); + let _ = stop_tx.send(()); + let mut writer_result = match timeout(ME_CHILD_JOIN_TIMEOUT, &mut me_writer).await { + Ok(joined) => { + joined.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME writer join error: {e}")))) + } + Err(_) => { + stats.increment_me_child_join_timeout_total(); + stats.increment_me_child_abort_total(); + me_writer.abort(); + Err(ProxyError::Proxy("ME writer join timeout".into())) + } + }; + + // When client closes, but ME channel stopped as unregistered - it isnt error + if client_closed && matches!(writer_result, Err(ProxyError::MiddleConnectionLost)) { + writer_result = Ok(()); + } + + let result = match (main_result, c2me_result, writer_result) { + (Ok(()), Ok(()), Ok(())) => Ok(()), + (Err(e), _, _) => Err(e), + (_, Err(e), _) => Err(e), + (_, _, Err(e)) => Err(e), + }; + + debug!( + user = %user, + conn_id, + trace_id = format_args!("0x{:016x}", trace_id), + duration_ms = forensics.started_at.elapsed().as_millis() as u64, + bytes_c2me = forensics.bytes_c2me, + bytes_me2c = forensics.bytes_me2c.load(Ordering::Relaxed), + frames_ok = frame_counter, + "ME relay cleanup" + ); + + let close_reason = classify_conntrack_close_reason(&result); + let publish_result = shared.publish_conntrack_close_event(ConntrackCloseEvent { + src: peer, + dst: local_addr, + reason: close_reason, + }); + if !matches!( + publish_result, + ConntrackClosePublishResult::Sent | ConntrackClosePublishResult::Disabled + ) { + stats.increment_conntrack_close_event_drop_total(); + } + + clear_relay_idle_candidate_in(shared.as_ref(), conn_id); + me_pool.registry().unregister(conn_id).await; + buffer_pool.trim_to(buffer_pool.max_buffers().min(64)); + let pool_snapshot = buffer_pool.stats(); + stats.set_buffer_pool_gauges( + pool_snapshot.pooled, + pool_snapshot.allocated, + pool_snapshot.allocated.saturating_sub(pool_snapshot.pooled), + ); + result +} + +fn classify_conntrack_close_reason(result: &Result<()>) -> ConntrackCloseReason { + match result { + Ok(()) => ConntrackCloseReason::NormalEof, + Err(ProxyError::Io(error)) if matches!(error.kind(), std::io::ErrorKind::TimedOut) => { + ConntrackCloseReason::Timeout + } + Err(ProxyError::Io(error)) + if matches!( + error.kind(), + std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::BrokenPipe + | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::UnexpectedEof + ) => + { + ConntrackCloseReason::Reset + } + Err(ProxyError::Proxy(message)) + if message.contains("pressure") || message.contains("evicted") => + { + ConntrackCloseReason::Pressure + } + Err(_) => ConntrackCloseReason::Other, + } +} diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index eb5fceb..5ea9e87 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -52,18 +52,15 @@ //! - `SharedCounters` (atomics) let the watchdog read stats without locking use crate::error::{ProxyError, Result}; -use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay}; -use crate::stats::{Stats, UserStats}; +use crate::proxy::traffic_limiter::TrafficLease; +use crate::stats::Stats; use crate::stream::BufferPool; -use std::io; -use std::pin::Pin; use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::task::{Context, Poll}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; -use tokio::time::{Instant, Sleep}; -use tracing::{debug, trace, warn}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, copy_bidirectional_with_sizes}; +use tokio::time::Instant; +use tracing::{debug, warn}; // ============= Constants ============= @@ -85,700 +82,11 @@ fn watchdog_delta(current: u64, previous: u64) -> u64 { current.saturating_sub(previous) } -// ============= CombinedStream ============= - -/// Combines separate read and write halves into a single bidirectional stream. -/// -/// `copy_bidirectional` requires `AsyncRead + AsyncWrite` on each side, -/// but the handshake layer produces split reader/writer pairs -/// (e.g. `CryptoReader>` + `CryptoWriter<...>`). -/// -/// This wrapper reunifies them with zero overhead — each trait method -/// delegates directly to the corresponding half. No buffering, no copies. -/// -/// Safety: `poll_read` only touches `reader`, `poll_write` only touches `writer`, -/// so there's no aliasing even though both are called on the same `&mut self`. -struct CombinedStream { - reader: R, - writer: W, -} - -impl CombinedStream { - fn new(reader: R, writer: W) -> Self { - Self { reader, writer } - } -} - -impl AsyncRead for CombinedStream { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().reader).poll_read(cx, buf) - } -} - -impl AsyncWrite for CombinedStream { - #[inline] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_flush(cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) - } -} - -// ============= SharedCounters ============= - -/// Atomic counters shared between the relay (via StatsIo) and the watchdog task. -/// -/// Using `Relaxed` ordering is sufficient because: -/// - Counters are monotonically increasing (no ABA problem) -/// - Slight staleness in watchdog reads is harmless (±10s check interval anyway) -/// - No ordering dependencies between different counters -struct SharedCounters { - /// Bytes read from client (C→S direction) - c2s_bytes: AtomicU64, - /// Bytes written to client (S→C direction) - s2c_bytes: AtomicU64, - /// Number of poll_read completions (≈ C→S chunks) - c2s_ops: AtomicU64, - /// Number of poll_write completions (≈ S→C chunks) - s2c_ops: AtomicU64, - /// Milliseconds since relay epoch of last I/O activity - last_activity_ms: AtomicU64, -} - -impl SharedCounters { - fn new() -> Self { - Self { - c2s_bytes: AtomicU64::new(0), - s2c_bytes: AtomicU64::new(0), - c2s_ops: AtomicU64::new(0), - s2c_ops: AtomicU64::new(0), - last_activity_ms: AtomicU64::new(0), - } - } - - /// Record activity at this instant. - #[inline] - fn touch(&self, now: Instant, epoch: Instant) { - let ms = now.duration_since(epoch).as_millis() as u64; - self.last_activity_ms.store(ms, Ordering::Relaxed); - } - - /// How long since last recorded activity. - fn idle_duration(&self, now: Instant, epoch: Instant) -> Duration { - let last_ms = self.last_activity_ms.load(Ordering::Relaxed); - let now_ms = now.duration_since(epoch).as_millis() as u64; - Duration::from_millis(now_ms.saturating_sub(last_ms)) - } -} - -// ============= StatsIo ============= - -/// Transparent I/O wrapper that tracks per-user statistics and activity. -/// -/// Wraps the **client** side of the relay. Direction mapping: -/// -/// | poll method | direction | stats updated | -/// |-------------|-----------|--------------------------------------| -/// | `poll_read` | C→S | `octets_from`, `msgs_from`, counters | -/// | `poll_write` | S→C | `octets_to`, `msgs_to`, counters | -/// -/// Both update the shared activity timestamp for the watchdog. -/// -/// Note on message counts: the original code counted one `read()`/`write_all()` -/// as one "message". Here we count `poll_read`/`poll_write` completions instead. -/// Byte counts are identical; op counts may differ slightly due to different -/// internal buffering in `copy_bidirectional`. This is fine for monitoring. -struct StatsIo { - inner: S, - counters: Arc, - stats: Arc, - user: String, - user_stats: Arc, - traffic_lease: Option>, - c2s_rate_debt_bytes: u64, - c2s_wait: RateWaitState, - s2c_wait: RateWaitState, - quota_wait: RateWaitState, - quota_limit: Option, - quota_exceeded: Arc, - quota_bytes_since_check: u64, - epoch: Instant, -} - -#[derive(Default)] -struct RateWaitState { - sleep: Option>>, - started_at: Option, - blocked_user: bool, - blocked_cidr: bool, -} - -impl StatsIo { - #[cfg(test)] - fn new( - inner: S, - counters: Arc, - stats: Arc, - user: String, - quota_limit: Option, - quota_exceeded: Arc, - epoch: Instant, - ) -> Self { - Self::new_with_traffic_lease( - inner, - counters, - stats, - user, - None, - quota_limit, - quota_exceeded, - epoch, - ) - } - - fn new_with_traffic_lease( - inner: S, - counters: Arc, - stats: Arc, - user: String, - traffic_lease: Option>, - quota_limit: Option, - quota_exceeded: Arc, - epoch: Instant, - ) -> Self { - // Mark initial activity so the watchdog doesn't fire before data flows - counters.touch(Instant::now(), epoch); - let user_stats = stats.get_or_create_user_stats_handle(&user); - Self { - inner, - counters, - stats, - user, - user_stats, - traffic_lease, - c2s_rate_debt_bytes: 0, - c2s_wait: RateWaitState::default(), - s2c_wait: RateWaitState::default(), - quota_wait: RateWaitState::default(), - quota_limit, - quota_exceeded, - quota_bytes_since_check: 0, - epoch, - } - } - - fn record_wait( - wait: &mut RateWaitState, - lease: Option<&Arc>, - direction: RateDirection, - ) { - let Some(started_at) = wait.started_at.take() else { - return; - }; - let wait_ms = started_at.elapsed().as_millis().min(u128::from(u64::MAX)) as u64; - if let Some(lease) = lease { - lease.observe_wait_ms(direction, wait.blocked_user, wait.blocked_cidr, wait_ms); - } - wait.blocked_user = false; - wait.blocked_cidr = false; - } - - fn arm_wait(wait: &mut RateWaitState, blocked_user: bool, blocked_cidr: bool) { - if wait.sleep.is_none() { - wait.sleep = Some(Box::pin(tokio::time::sleep(next_refill_delay()))); - wait.started_at = Some(Instant::now()); - } - wait.blocked_user |= blocked_user; - wait.blocked_cidr |= blocked_cidr; - } - - fn poll_wait( - wait: &mut RateWaitState, - cx: &mut Context<'_>, - lease: Option<&Arc>, - direction: RateDirection, - ) -> Poll<()> { - let Some(sleep) = wait.sleep.as_mut() else { - return Poll::Ready(()); - }; - if sleep.as_mut().poll(cx).is_pending() { - return Poll::Pending; - } - wait.sleep = None; - Self::record_wait(wait, lease, direction); - Poll::Ready(()) - } - - fn settle_c2s_rate_debt(&mut self, cx: &mut Context<'_>) -> Poll<()> { - let Some(lease) = self.traffic_lease.as_ref() else { - self.c2s_rate_debt_bytes = 0; - return Poll::Ready(()); - }; - - while self.c2s_rate_debt_bytes > 0 { - let consume = lease.try_consume(RateDirection::Up, self.c2s_rate_debt_bytes); - if consume.granted > 0 { - self.c2s_rate_debt_bytes = self.c2s_rate_debt_bytes.saturating_sub(consume.granted); - continue; - } - Self::arm_wait( - &mut self.c2s_wait, - consume.blocked_user, - consume.blocked_cidr, - ); - if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() - { - return Poll::Pending; - } - } - - if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() { - return Poll::Pending; - } - - Poll::Ready(()) - } - - fn arm_quota_wait(&mut self, cx: &mut Context<'_>) -> Poll<()> { - Self::arm_wait(&mut self.quota_wait, false, false); - Self::poll_wait(&mut self.quota_wait, cx, None, RateDirection::Up) - } -} - -#[derive(Debug)] -struct QuotaIoSentinel; - -impl std::fmt::Display for QuotaIoSentinel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("user data quota exceeded") - } -} - -impl std::error::Error for QuotaIoSentinel {} - -fn quota_io_error() -> io::Error { - io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel) -} - -fn is_quota_io_error(err: &io::Error) -> bool { - err.kind() == io::ErrorKind::PermissionDenied - && err - .get_ref() - .and_then(|source| source.downcast_ref::()) - .is_some() -} - -const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024; -const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024; -const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024; -const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; -const QUOTA_RESERVE_SPIN_RETRIES: usize = 64; -const QUOTA_RESERVE_MAX_ROUNDS: usize = 8; - -#[inline] -fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { - remaining_before.saturating_div(2).clamp( - QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES, - QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES, - ) -} - -#[inline] -fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool { - remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES -} - -fn refund_reserved_quota_bytes(user_stats: &UserStats, reserved_bytes: u64) { - if reserved_bytes == 0 { - return; - } - let mut current = user_stats.quota_used.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(reserved_bytes); - match user_stats.quota_used.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => return, - Err(observed) => current = observed, - } - } -} - -impl AsyncRead for StatsIo { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Acquire) { - return Poll::Ready(Err(quota_io_error())); - } - if this.settle_c2s_rate_debt(cx).is_pending() { - return Poll::Pending; - } - if buf.remaining() == 0 { - return Pin::new(&mut this.inner).poll_read(cx, buf); - } - - let mut remaining_before = None; - let mut reserved_read_bytes = 0u64; - let mut read_limit = buf.remaining(); - if let Some(limit) = this.quota_limit { - let used_before = this.user_stats.quota_used(); - let remaining = limit.saturating_sub(used_before); - if remaining == 0 { - this.quota_exceeded.store(true, Ordering::Release); - return Poll::Ready(Err(quota_io_error())); - } - remaining_before = Some(remaining); - read_limit = read_limit.min(remaining as usize); - if read_limit == 0 { - this.quota_exceeded.store(true, Ordering::Release); - return Poll::Ready(Err(quota_io_error())); - } - - let desired = read_limit as u64; - let mut reserve_rounds = 0usize; - while reserved_read_bytes == 0 { - for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { - match this.user_stats.quota_try_reserve(desired, limit) { - Ok(_) => { - reserved_read_bytes = desired; - break; - } - Err(crate::stats::QuotaReserveError::LimitExceeded) => { - this.quota_exceeded.store(true, Ordering::Release); - return Poll::Ready(Err(quota_io_error())); - } - Err(crate::stats::QuotaReserveError::Contended) => { - this.stats.increment_quota_contention_total(); - } - } - } - - if reserved_read_bytes == 0 { - reserve_rounds = reserve_rounds.saturating_add(1); - if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { - this.stats.increment_quota_contention_timeout_total(); - if this.arm_quota_wait(cx).is_pending() { - return Poll::Pending; - } - reserve_rounds = 0; - } - } - } - } - - let limited_read = read_limit < buf.remaining(); - let read_result = if limited_read { - let mut limited_buf = ReadBuf::new(buf.initialize_unfilled_to(read_limit)); - match Pin::new(&mut this.inner).poll_read(cx, &mut limited_buf) { - Poll::Ready(Ok(())) => { - let n = limited_buf.filled().len(); - buf.advance(n); - Poll::Ready(Ok(n)) - } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - Poll::Pending => Poll::Pending, - } - } else { - let before = buf.filled().len(); - match Pin::new(&mut this.inner).poll_read(cx, buf) { - Poll::Ready(Ok(())) => { - let n = buf.filled().len() - before; - Poll::Ready(Ok(n)) - } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - Poll::Pending => Poll::Pending, - } - }; - - match read_result { - Poll::Ready(Ok(n)) => { - if reserved_read_bytes > n as u64 { - let refund_bytes = reserved_read_bytes - n as u64; - refund_reserved_quota_bytes(this.user_stats.as_ref(), refund_bytes); - this.stats.add_quota_refund_bytes_total(refund_bytes); - } - if n > 0 { - let n_to_charge = n as u64; - - if let Some(remaining) = remaining_before { - if should_immediate_quota_check(remaining, n_to_charge) { - this.quota_bytes_since_check = 0; - } else { - this.quota_bytes_since_check = - this.quota_bytes_since_check.saturating_add(n_to_charge); - let interval = quota_adaptive_interval_bytes(remaining); - if this.quota_bytes_since_check >= interval { - this.quota_bytes_since_check = 0; - } - } - } - if let Some(limit) = this.quota_limit - && this.user_stats.quota_used() >= limit - { - this.quota_exceeded.store(true, Ordering::Release); - } - - // C→S: client sent data - this.counters - .c2s_bytes - .fetch_add(n_to_charge, Ordering::Relaxed); - this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); - this.counters.touch(Instant::now(), this.epoch); - - this.stats - .add_user_traffic_from_handle(this.user_stats.as_ref(), n_to_charge); - if this.traffic_lease.is_some() { - this.c2s_rate_debt_bytes = - this.c2s_rate_debt_bytes.saturating_add(n_to_charge); - let _ = this.settle_c2s_rate_debt(cx); - } - - trace!(user = %this.user, bytes = n, "C->S"); - } - Poll::Ready(Ok(())) - } - Poll::Pending => { - if reserved_read_bytes > 0 { - refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_read_bytes); - this.stats.add_quota_refund_bytes_total(reserved_read_bytes); - } - Poll::Pending - } - Poll::Ready(Err(err)) => { - if reserved_read_bytes > 0 { - refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_read_bytes); - this.stats.add_quota_refund_bytes_total(reserved_read_bytes); - } - Poll::Ready(Err(err)) - } - } - } -} - -impl AsyncWrite for StatsIo { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Acquire) { - return Poll::Ready(Err(quota_io_error())); - } - - let mut shaper_reserved_bytes = 0u64; - let mut write_buf = buf; - if let Some(lease) = this.traffic_lease.as_ref() { - if !buf.is_empty() { - loop { - let consume = lease.try_consume(RateDirection::Down, buf.len() as u64); - if consume.granted > 0 { - shaper_reserved_bytes = consume.granted; - if consume.granted < buf.len() as u64 { - write_buf = &buf[..consume.granted as usize]; - } - let _ = Self::poll_wait( - &mut this.s2c_wait, - cx, - Some(lease), - RateDirection::Down, - ); - break; - } - - Self::arm_wait( - &mut this.s2c_wait, - consume.blocked_user, - consume.blocked_cidr, - ); - if Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down) - .is_pending() - { - return Poll::Pending; - } - } - } else { - let _ = Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down); - } - } - - let mut remaining_before = None; - let mut reserved_bytes = 0u64; - if let Some(limit) = this.quota_limit { - if !write_buf.is_empty() { - let mut reserve_rounds = 0usize; - while reserved_bytes == 0 { - let used_before = this.user_stats.quota_used(); - let remaining = limit.saturating_sub(used_before); - if remaining == 0 { - if let Some(lease) = this.traffic_lease.as_ref() { - lease.refund(RateDirection::Down, shaper_reserved_bytes); - } - this.quota_exceeded.store(true, Ordering::Release); - return Poll::Ready(Err(quota_io_error())); - } - remaining_before = Some(remaining); - - let desired = remaining.min(write_buf.len() as u64); - let mut saw_contention = false; - for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { - match this.user_stats.quota_try_reserve(desired, limit) { - Ok(_) => { - reserved_bytes = desired; - write_buf = &write_buf[..desired as usize]; - break; - } - Err(crate::stats::QuotaReserveError::LimitExceeded) => { - break; - } - Err(crate::stats::QuotaReserveError::Contended) => { - this.stats.increment_quota_contention_total(); - saw_contention = true; - } - } - } - - if reserved_bytes == 0 { - reserve_rounds = reserve_rounds.saturating_add(1); - if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { - this.stats.increment_quota_contention_timeout_total(); - if let Some(lease) = this.traffic_lease.as_ref() { - lease.refund(RateDirection::Down, shaper_reserved_bytes); - } - let _ = this.arm_quota_wait(cx); - return Poll::Pending; - } else if saw_contention { - std::hint::spin_loop(); - } - } - } - } else { - let used_before = this.user_stats.quota_used(); - let remaining = limit.saturating_sub(used_before); - if remaining == 0 { - if let Some(lease) = this.traffic_lease.as_ref() { - lease.refund(RateDirection::Down, shaper_reserved_bytes); - } - this.quota_exceeded.store(true, Ordering::Release); - return Poll::Ready(Err(quota_io_error())); - } - remaining_before = Some(remaining); - } - } - - match Pin::new(&mut this.inner).poll_write(cx, write_buf) { - Poll::Ready(Ok(n)) => { - if reserved_bytes > n as u64 { - let refund_bytes = reserved_bytes - n as u64; - refund_reserved_quota_bytes(this.user_stats.as_ref(), refund_bytes); - this.stats.add_quota_refund_bytes_total(refund_bytes); - } - if shaper_reserved_bytes > n as u64 - && let Some(lease) = this.traffic_lease.as_ref() - { - lease.refund(RateDirection::Down, shaper_reserved_bytes - n as u64); - } - if n > 0 { - if let Some(lease) = this.traffic_lease.as_ref() { - Self::record_wait(&mut this.s2c_wait, Some(lease), RateDirection::Down); - } - let n_to_charge = n as u64; - - // S→C: data written to client - this.counters - .s2c_bytes - .fetch_add(n_to_charge, Ordering::Relaxed); - this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); - this.counters.touch(Instant::now(), this.epoch); - - this.stats - .add_user_traffic_to_handle(this.user_stats.as_ref(), n_to_charge); - - if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { - if should_immediate_quota_check(remaining, n_to_charge) { - this.quota_bytes_since_check = 0; - if this.user_stats.quota_used() >= limit { - this.quota_exceeded.store(true, Ordering::Release); - } - } else { - this.quota_bytes_since_check = - this.quota_bytes_since_check.saturating_add(n_to_charge); - let interval = quota_adaptive_interval_bytes(remaining); - if this.quota_bytes_since_check >= interval { - this.quota_bytes_since_check = 0; - if this.user_stats.quota_used() >= limit { - this.quota_exceeded.store(true, Ordering::Release); - } - } - } - } - - trace!(user = %this.user, bytes = n, "S->C"); - } - Poll::Ready(Ok(n)) - } - Poll::Ready(Err(err)) => { - if reserved_bytes > 0 { - refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); - this.stats.add_quota_refund_bytes_total(reserved_bytes); - } - if shaper_reserved_bytes > 0 - && let Some(lease) = this.traffic_lease.as_ref() - { - lease.refund(RateDirection::Down, shaper_reserved_bytes); - } - Poll::Ready(Err(err)) - } - Poll::Pending => { - if reserved_bytes > 0 { - refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); - this.stats.add_quota_refund_bytes_total(reserved_bytes); - } - if shaper_reserved_bytes > 0 - && let Some(lease) = this.traffic_lease.as_ref() - { - lease.refund(RateDirection::Down, shaper_reserved_bytes); - } - Poll::Pending - } - } - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().inner).poll_flush(cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) - } -} +mod io; +use self::io::{CombinedStream, SharedCounters, StatsIo, is_quota_io_error}; +#[cfg(test)] +use self::io::{quota_adaptive_interval_bytes, should_immediate_quota_check}; // ============= Relay ============= /// Relay data bidirectionally between client and server. diff --git a/src/proxy/relay/io.rs b/src/proxy/relay/io.rs new file mode 100644 index 0000000..fb30f3f --- /dev/null +++ b/src/proxy/relay/io.rs @@ -0,0 +1,551 @@ +use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay}; +use crate::stats::{Stats, UserStats}; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::time::{Instant, Sleep}; +use tracing::trace; + +mod combined; +mod counters; +mod quota; + +pub(super) use self::combined::CombinedStream; +pub(super) use self::counters::SharedCounters; +pub(super) use self::quota::is_quota_io_error; +use self::quota::{ + QUOTA_RESERVE_MAX_ROUNDS, QUOTA_RESERVE_SPIN_RETRIES, quota_io_error, + refund_reserved_quota_bytes, +}; +pub(super) use self::quota::{quota_adaptive_interval_bytes, should_immediate_quota_check}; + +/// Transparent I/O wrapper that tracks per-user statistics and activity. +/// +/// Wraps the **client** side of the relay. Direction mapping: +/// +/// | poll method | direction | stats updated | +/// |-------------|-----------|--------------------------------------| +/// | `poll_read` | C→S | `octets_from`, `msgs_from`, counters | +/// | `poll_write` | S→C | `octets_to`, `msgs_to`, counters | +/// +/// Both update the shared activity timestamp for the watchdog. +/// +/// Note on message counts: the original code counted one `read()`/`write_all()` +/// as one "message". Here we count `poll_read`/`poll_write` completions instead. +/// Byte counts are identical; op counts may differ slightly due to different +/// internal buffering in `copy_bidirectional`. This is fine for monitoring. +pub(super) struct StatsIo { + inner: S, + counters: Arc, + stats: Arc, + user: String, + user_stats: Arc, + traffic_lease: Option>, + c2s_rate_debt_bytes: u64, + c2s_wait: RateWaitState, + s2c_wait: RateWaitState, + quota_wait: RateWaitState, + quota_limit: Option, + quota_exceeded: Arc, + pub(super) quota_bytes_since_check: u64, + epoch: Instant, +} + +#[derive(Default)] +struct RateWaitState { + sleep: Option>>, + started_at: Option, + blocked_user: bool, + blocked_cidr: bool, +} + +impl StatsIo { + /// Creates a StatsIo wrapper without a traffic lease for relay unit tests. + #[cfg(test)] + pub(super) fn new( + inner: S, + counters: Arc, + stats: Arc, + user: String, + quota_limit: Option, + quota_exceeded: Arc, + epoch: Instant, + ) -> Self { + Self::new_with_traffic_lease( + inner, + counters, + stats, + user, + None, + quota_limit, + quota_exceeded, + epoch, + ) + } + + pub(super) fn new_with_traffic_lease( + inner: S, + counters: Arc, + stats: Arc, + user: String, + traffic_lease: Option>, + quota_limit: Option, + quota_exceeded: Arc, + epoch: Instant, + ) -> Self { + // Mark initial activity so the watchdog doesn't fire before data flows + counters.touch(Instant::now(), epoch); + let user_stats = stats.get_or_create_user_stats_handle(&user); + Self { + inner, + counters, + stats, + user, + user_stats, + traffic_lease, + c2s_rate_debt_bytes: 0, + c2s_wait: RateWaitState::default(), + s2c_wait: RateWaitState::default(), + quota_wait: RateWaitState::default(), + quota_limit, + quota_exceeded, + quota_bytes_since_check: 0, + epoch, + } + } + + fn record_wait( + wait: &mut RateWaitState, + lease: Option<&Arc>, + direction: RateDirection, + ) { + let Some(started_at) = wait.started_at.take() else { + return; + }; + let wait_ms = started_at.elapsed().as_millis().min(u128::from(u64::MAX)) as u64; + if let Some(lease) = lease { + lease.observe_wait_ms(direction, wait.blocked_user, wait.blocked_cidr, wait_ms); + } + wait.blocked_user = false; + wait.blocked_cidr = false; + } + + fn arm_wait(wait: &mut RateWaitState, blocked_user: bool, blocked_cidr: bool) { + if wait.sleep.is_none() { + wait.sleep = Some(Box::pin(tokio::time::sleep(next_refill_delay()))); + wait.started_at = Some(Instant::now()); + } + wait.blocked_user |= blocked_user; + wait.blocked_cidr |= blocked_cidr; + } + + fn poll_wait( + wait: &mut RateWaitState, + cx: &mut Context<'_>, + lease: Option<&Arc>, + direction: RateDirection, + ) -> Poll<()> { + let Some(sleep) = wait.sleep.as_mut() else { + return Poll::Ready(()); + }; + if sleep.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + wait.sleep = None; + Self::record_wait(wait, lease, direction); + Poll::Ready(()) + } + + fn settle_c2s_rate_debt(&mut self, cx: &mut Context<'_>) -> Poll<()> { + let Some(lease) = self.traffic_lease.as_ref() else { + self.c2s_rate_debt_bytes = 0; + return Poll::Ready(()); + }; + + while self.c2s_rate_debt_bytes > 0 { + let consume = lease.try_consume(RateDirection::Up, self.c2s_rate_debt_bytes); + if consume.granted > 0 { + self.c2s_rate_debt_bytes = self.c2s_rate_debt_bytes.saturating_sub(consume.granted); + continue; + } + Self::arm_wait( + &mut self.c2s_wait, + consume.blocked_user, + consume.blocked_cidr, + ); + if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() + { + return Poll::Pending; + } + } + + if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() { + return Poll::Pending; + } + + Poll::Ready(()) + } + + fn arm_quota_wait(&mut self, cx: &mut Context<'_>) -> Poll<()> { + Self::arm_wait(&mut self.quota_wait, false, false); + Self::poll_wait(&mut self.quota_wait, cx, None, RateDirection::Up) + } +} + +impl AsyncRead for StatsIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + if this.quota_exceeded.load(Ordering::Acquire) { + return Poll::Ready(Err(quota_io_error())); + } + if this.settle_c2s_rate_debt(cx).is_pending() { + return Poll::Pending; + } + if buf.remaining() == 0 { + return Pin::new(&mut this.inner).poll_read(cx, buf); + } + + let mut remaining_before = None; + let mut reserved_read_bytes = 0u64; + let mut read_limit = buf.remaining(); + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + remaining_before = Some(remaining); + read_limit = read_limit.min(remaining as usize); + if read_limit == 0 { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + + let desired = read_limit as u64; + let mut reserve_rounds = 0usize; + while reserved_read_bytes == 0 { + for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { + match this.user_stats.quota_try_reserve(desired, limit) { + Ok(_) => { + reserved_read_bytes = desired; + break; + } + Err(crate::stats::QuotaReserveError::LimitExceeded) => { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + Err(crate::stats::QuotaReserveError::Contended) => { + this.stats.increment_quota_contention_total(); + } + } + } + + if reserved_read_bytes == 0 { + reserve_rounds = reserve_rounds.saturating_add(1); + if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + this.stats.increment_quota_contention_timeout_total(); + if this.arm_quota_wait(cx).is_pending() { + return Poll::Pending; + } + reserve_rounds = 0; + } + } + } + } + + let limited_read = read_limit < buf.remaining(); + let read_result = if limited_read { + let mut limited_buf = ReadBuf::new(buf.initialize_unfilled_to(read_limit)); + match Pin::new(&mut this.inner).poll_read(cx, &mut limited_buf) { + Poll::Ready(Ok(())) => { + let n = limited_buf.filled().len(); + buf.advance(n); + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } + } else { + let before = buf.filled().len(); + match Pin::new(&mut this.inner).poll_read(cx, buf) { + Poll::Ready(Ok(())) => { + let n = buf.filled().len() - before; + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } + }; + + match read_result { + Poll::Ready(Ok(n)) => { + if reserved_read_bytes > n as u64 { + let refund_bytes = reserved_read_bytes - n as u64; + refund_reserved_quota_bytes(this.user_stats.as_ref(), refund_bytes); + this.stats.add_quota_refund_bytes_total(refund_bytes); + } + if n > 0 { + let n_to_charge = n as u64; + + if let Some(remaining) = remaining_before { + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + } + } + } + if let Some(limit) = this.quota_limit + && this.user_stats.quota_used() >= limit + { + this.quota_exceeded.store(true, Ordering::Release); + } + + // C→S: client sent data + this.counters + .c2s_bytes + .fetch_add(n_to_charge, Ordering::Relaxed); + this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); + this.counters.touch(Instant::now(), this.epoch); + + this.stats + .add_user_traffic_from_handle(this.user_stats.as_ref(), n_to_charge); + if this.traffic_lease.is_some() { + this.c2s_rate_debt_bytes = + this.c2s_rate_debt_bytes.saturating_add(n_to_charge); + let _ = this.settle_c2s_rate_debt(cx); + } + + trace!(user = %this.user, bytes = n, "C->S"); + } + Poll::Ready(Ok(())) + } + Poll::Pending => { + if reserved_read_bytes > 0 { + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_read_bytes); + this.stats.add_quota_refund_bytes_total(reserved_read_bytes); + } + Poll::Pending + } + Poll::Ready(Err(err)) => { + if reserved_read_bytes > 0 { + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_read_bytes); + this.stats.add_quota_refund_bytes_total(reserved_read_bytes); + } + Poll::Ready(Err(err)) + } + } + } +} + +impl AsyncWrite for StatsIo { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + if this.quota_exceeded.load(Ordering::Acquire) { + return Poll::Ready(Err(quota_io_error())); + } + + let mut shaper_reserved_bytes = 0u64; + let mut write_buf = buf; + if let Some(lease) = this.traffic_lease.as_ref() { + if !buf.is_empty() { + loop { + let consume = lease.try_consume(RateDirection::Down, buf.len() as u64); + if consume.granted > 0 { + shaper_reserved_bytes = consume.granted; + if consume.granted < buf.len() as u64 { + write_buf = &buf[..consume.granted as usize]; + } + let _ = Self::poll_wait( + &mut this.s2c_wait, + cx, + Some(lease), + RateDirection::Down, + ); + break; + } + + Self::arm_wait( + &mut this.s2c_wait, + consume.blocked_user, + consume.blocked_cidr, + ); + if Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down) + .is_pending() + { + return Poll::Pending; + } + } + } else { + let _ = Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down); + } + } + + let mut remaining_before = None; + let mut reserved_bytes = 0u64; + if let Some(limit) = this.quota_limit { + if !write_buf.is_empty() { + let mut reserve_rounds = 0usize; + while reserved_bytes == 0 { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + remaining_before = Some(remaining); + + let desired = remaining.min(write_buf.len() as u64); + let mut saw_contention = false; + for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { + match this.user_stats.quota_try_reserve(desired, limit) { + Ok(_) => { + reserved_bytes = desired; + write_buf = &write_buf[..desired as usize]; + break; + } + Err(crate::stats::QuotaReserveError::LimitExceeded) => { + break; + } + Err(crate::stats::QuotaReserveError::Contended) => { + this.stats.increment_quota_contention_total(); + saw_contention = true; + } + } + } + + if reserved_bytes == 0 { + reserve_rounds = reserve_rounds.saturating_add(1); + if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + this.stats.increment_quota_contention_timeout_total(); + if let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } + let _ = this.arm_quota_wait(cx); + return Poll::Pending; + } else if saw_contention { + std::hint::spin_loop(); + } + } + } + } else { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + remaining_before = Some(remaining); + } + } + + match Pin::new(&mut this.inner).poll_write(cx, write_buf) { + Poll::Ready(Ok(n)) => { + if reserved_bytes > n as u64 { + let refund_bytes = reserved_bytes - n as u64; + refund_reserved_quota_bytes(this.user_stats.as_ref(), refund_bytes); + this.stats.add_quota_refund_bytes_total(refund_bytes); + } + if shaper_reserved_bytes > n as u64 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes - n as u64); + } + if n > 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + Self::record_wait(&mut this.s2c_wait, Some(lease), RateDirection::Down); + } + let n_to_charge = n as u64; + + // S→C: data written to client + this.counters + .s2c_bytes + .fetch_add(n_to_charge, Ordering::Relaxed); + this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); + this.counters.touch(Instant::now(), this.epoch); + + this.stats + .add_user_traffic_to_handle(this.user_stats.as_ref(), n_to_charge); + + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } + } + + trace!(user = %this.user, bytes = n, "S->C"); + } + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) => { + if reserved_bytes > 0 { + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); + this.stats.add_quota_refund_bytes_total(reserved_bytes); + } + if shaper_reserved_bytes > 0 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } + Poll::Ready(Err(err)) + } + Poll::Pending => { + if reserved_bytes > 0 { + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); + this.stats.add_quota_refund_bytes_total(reserved_bytes); + } + if shaper_reserved_bytes > 0 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } + Poll::Pending + } + } + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_flush(cx) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) + } +} diff --git a/src/proxy/relay/io/combined.rs b/src/proxy/relay/io/combined.rs new file mode 100644 index 0000000..6ddb131 --- /dev/null +++ b/src/proxy/relay/io/combined.rs @@ -0,0 +1,61 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +// ============= CombinedStream ============= + +/// Combines separate read and write halves into a single bidirectional stream. +/// +/// `copy_bidirectional` requires `AsyncRead + AsyncWrite` on each side, +/// but the handshake layer produces split reader/writer pairs +/// (e.g. `CryptoReader>` + `CryptoWriter<...>`). +/// +/// This wrapper reunifies them with zero overhead — each trait method +/// delegates directly to the corresponding half. No buffering, no copies. +/// +/// Safety: `poll_read` only touches `reader`, `poll_write` only touches `writer`, +/// so there's no aliasing even though both are called on the same `&mut self`. +pub(in crate::proxy::relay) struct CombinedStream { + reader: R, + writer: W, +} + +impl CombinedStream { + pub(in crate::proxy::relay) fn new(reader: R, writer: W) -> Self { + Self { reader, writer } + } +} + +impl AsyncRead for CombinedStream { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().reader).poll_read(cx, buf) + } +} + +impl AsyncWrite for CombinedStream { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_flush(cx) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) + } +} diff --git a/src/proxy/relay/io/counters.rs b/src/proxy/relay/io/counters.rs new file mode 100644 index 0000000..963b88d --- /dev/null +++ b/src/proxy/relay/io/counters.rs @@ -0,0 +1,51 @@ +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; + +use tokio::time::Instant; + +// ============= SharedCounters ============= + +/// Atomic counters shared between the relay (via StatsIo) and the watchdog task. +/// +/// Using `Relaxed` ordering is sufficient because: +/// - Counters are monotonically increasing (no ABA problem) +/// - Slight staleness in watchdog reads is harmless (±10s check interval anyway) +/// - No ordering dependencies between different counters +pub(in crate::proxy::relay) struct SharedCounters { + /// Bytes read from client (C→S direction) + pub(in crate::proxy::relay) c2s_bytes: AtomicU64, + /// Bytes written to client (S→C direction) + pub(in crate::proxy::relay) s2c_bytes: AtomicU64, + /// Number of poll_read completions (≈ C→S chunks) + pub(in crate::proxy::relay) c2s_ops: AtomicU64, + /// Number of poll_write completions (≈ S→C chunks) + pub(in crate::proxy::relay) s2c_ops: AtomicU64, + /// Milliseconds since relay epoch of last I/O activity + last_activity_ms: AtomicU64, +} + +impl SharedCounters { + pub(in crate::proxy::relay) fn new() -> Self { + Self { + c2s_bytes: AtomicU64::new(0), + s2c_bytes: AtomicU64::new(0), + c2s_ops: AtomicU64::new(0), + s2c_ops: AtomicU64::new(0), + last_activity_ms: AtomicU64::new(0), + } + } + + /// Record activity at this instant. + #[inline] + pub(in crate::proxy::relay) fn touch(&self, now: Instant, epoch: Instant) { + let ms = now.duration_since(epoch).as_millis() as u64; + self.last_activity_ms.store(ms, Ordering::Relaxed); + } + + /// How long since last recorded activity. + pub(in crate::proxy::relay) fn idle_duration(&self, now: Instant, epoch: Instant) -> Duration { + let last_ms = self.last_activity_ms.load(Ordering::Relaxed); + let now_ms = now.duration_since(epoch).as_millis() as u64; + Duration::from_millis(now_ms.saturating_sub(last_ms)) + } +} diff --git a/src/proxy/relay/io/quota.rs b/src/proxy/relay/io/quota.rs new file mode 100644 index 0000000..1faf0b5 --- /dev/null +++ b/src/proxy/relay/io/quota.rs @@ -0,0 +1,68 @@ +use crate::stats::UserStats; +use std::io; +use std::sync::atomic::Ordering; + +#[derive(Debug)] +struct QuotaIoSentinel; + +impl std::fmt::Display for QuotaIoSentinel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("user data quota exceeded") + } +} + +impl std::error::Error for QuotaIoSentinel {} + +pub(super) fn quota_io_error() -> io::Error { + io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel) +} + +pub(in crate::proxy::relay) fn is_quota_io_error(err: &io::Error) -> bool { + err.kind() == io::ErrorKind::PermissionDenied + && err + .get_ref() + .and_then(|source| source.downcast_ref::()) + .is_some() +} + +const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024; +const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; +pub(super) const QUOTA_RESERVE_SPIN_RETRIES: usize = 64; +pub(super) const QUOTA_RESERVE_MAX_ROUNDS: usize = 8; + +#[inline] +pub(in crate::proxy::relay) fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { + remaining_before.saturating_div(2).clamp( + QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES, + QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES, + ) +} + +#[inline] +pub(in crate::proxy::relay) fn should_immediate_quota_check( + remaining_before: u64, + charge_bytes: u64, +) -> bool { + remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES +} + +pub(super) fn refund_reserved_quota_bytes(user_stats: &UserStats, reserved_bytes: u64) { + if reserved_bytes == 0 { + return; + } + let mut current = user_stats.quota_used.load(Ordering::Relaxed); + loop { + let next = current.saturating_sub(reserved_bytes); + match user_stats.quota_used.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return, + Err(observed) => current = observed, + } + } +} diff --git a/src/proxy/tests/relay_baseline_invariant_tests.rs b/src/proxy/tests/relay_baseline_invariant_tests.rs index 998be2d..7b2fab5 100644 --- a/src/proxy/tests/relay_baseline_invariant_tests.rs +++ b/src/proxy/tests/relay_baseline_invariant_tests.rs @@ -3,7 +3,9 @@ use crate::error::ProxyError; use crate::stats::Stats; use crate::stream::BufferPool; use std::io; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf, duplex}; use tokio::time::{Duration, timeout}; diff --git a/src/stats/core_counters.rs b/src/stats/core_counters.rs new file mode 100644 index 0000000..9017270 --- /dev/null +++ b/src/stats/core_counters.rs @@ -0,0 +1,266 @@ +use super::*; + +impl Stats { + pub fn apply_telemetry_policy(&self, policy: TelemetryPolicy) { + self.telemetry_core_enabled + .store(policy.core_enabled, Ordering::Relaxed); + self.telemetry_user_enabled + .store(policy.user_enabled, Ordering::Relaxed); + self.telemetry_me_level + .store(policy.me_level.as_u8(), Ordering::Relaxed); + } + + pub fn telemetry_policy(&self) -> TelemetryPolicy { + TelemetryPolicy { + core_enabled: self.telemetry_core_enabled(), + user_enabled: self.telemetry_user_enabled(), + me_level: self.telemetry_me_level(), + } + } + + pub fn increment_connects_all(&self) { + if self.telemetry_core_enabled() { + self.connects_all.fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_connects_bad_with_class(&self, class: &'static str) { + if !self.telemetry_core_enabled() { + return; + } + self.connects_bad.fetch_add(1, Ordering::Relaxed); + let entry = self + .connects_bad_classes + .entry(class) + .or_insert_with(|| AtomicU64::new(0)); + entry.fetch_add(1, Ordering::Relaxed); + } + + pub fn increment_connects_bad(&self) { + self.increment_connects_bad_with_class("other"); + } + + pub fn increment_handshake_failure_class(&self, class: &'static str) { + if !self.telemetry_core_enabled() { + return; + } + let entry = self + .handshake_failure_classes + .entry(class) + .or_insert_with(|| AtomicU64::new(0)); + entry.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_current_connections_direct(&self) { + self.current_connections_direct + .fetch_add(1, Ordering::Relaxed); + } + pub fn decrement_current_connections_direct(&self) { + Self::decrement_atomic_saturating(&self.current_connections_direct); + } + pub fn increment_current_connections_me(&self) { + self.current_connections_me.fetch_add(1, Ordering::Relaxed); + } + pub fn decrement_current_connections_me(&self) { + Self::decrement_atomic_saturating(&self.current_connections_me); + } + + pub fn acquire_direct_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_direct(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct) + } + + pub fn acquire_me_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_me(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle) + } + + pub(super) fn decrement_route_cutover_parked_direct(&self) { + Self::decrement_atomic_saturating(&self.route_cutover_parked_direct_current); + } + + pub(super) fn decrement_route_cutover_parked_middle(&self) { + Self::decrement_atomic_saturating(&self.route_cutover_parked_middle_current); + } + + pub fn acquire_direct_cutover_park_lease(self: &Arc) -> RouteCutoverParkLease { + self.route_cutover_parked_direct_current + .fetch_add(1, Ordering::Relaxed); + self.route_cutover_parked_direct_total + .fetch_add(1, Ordering::Relaxed); + RouteCutoverParkLease::new(self.clone(), RouteCutoverParkGauge::Direct) + } + + pub fn acquire_middle_cutover_park_lease(self: &Arc) -> RouteCutoverParkLease { + self.route_cutover_parked_middle_current + .fetch_add(1, Ordering::Relaxed); + self.route_cutover_parked_middle_total + .fetch_add(1, Ordering::Relaxed); + RouteCutoverParkLease::new(self.clone(), RouteCutoverParkGauge::Middle) + } + pub fn increment_handshake_timeouts(&self) { + if self.telemetry_core_enabled() { + self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_accept_permit_timeout_total(&self) { + if self.telemetry_core_enabled() { + self.accept_permit_timeout_total + .fetch_add(1, Ordering::Relaxed); + } + } + + pub fn set_conntrack_control_enabled(&self, enabled: bool) { + self.conntrack_control_enabled_gauge + .store(enabled, Ordering::Relaxed); + } + + pub fn set_conntrack_control_available(&self, available: bool) { + self.conntrack_control_available_gauge + .store(available, Ordering::Relaxed); + } + + pub fn set_conntrack_pressure_active(&self, active: bool) { + self.conntrack_pressure_active_gauge + .store(active, Ordering::Relaxed); + } + + pub fn set_conntrack_event_queue_depth(&self, depth: u64) { + self.conntrack_event_queue_depth_gauge + .store(depth, Ordering::Relaxed); + } + + pub fn set_conntrack_rule_apply_ok(&self, ok: bool) { + self.conntrack_rule_apply_ok_gauge + .store(ok, Ordering::Relaxed); + } + + pub fn increment_conntrack_delete_attempt_total(&self) { + if self.telemetry_core_enabled() { + self.conntrack_delete_attempt_total + .fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_conntrack_delete_success_total(&self) { + if self.telemetry_core_enabled() { + self.conntrack_delete_success_total + .fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_conntrack_delete_not_found_total(&self) { + if self.telemetry_core_enabled() { + self.conntrack_delete_not_found_total + .fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_conntrack_delete_error_total(&self) { + if self.telemetry_core_enabled() { + self.conntrack_delete_error_total + .fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_conntrack_close_event_drop_total(&self) { + if self.telemetry_core_enabled() { + self.conntrack_close_event_drop_total + .fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_upstream_connect_attempt_total(&self) { + if self.telemetry_core_enabled() { + self.upstream_connect_attempt_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_upstream_connect_success_total(&self) { + if self.telemetry_core_enabled() { + self.upstream_connect_success_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_upstream_connect_fail_total(&self) { + if self.telemetry_core_enabled() { + self.upstream_connect_fail_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_upstream_connect_failfast_hard_error_total(&self) { + if self.telemetry_core_enabled() { + self.upstream_connect_failfast_hard_error_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn observe_upstream_connect_attempts_per_request(&self, attempts: u32) { + if !self.telemetry_core_enabled() { + return; + } + match attempts { + 0 => {} + 1 => { + self.upstream_connect_attempts_bucket_1 + .fetch_add(1, Ordering::Relaxed); + } + 2 => { + self.upstream_connect_attempts_bucket_2 + .fetch_add(1, Ordering::Relaxed); + } + 3..=4 => { + self.upstream_connect_attempts_bucket_3_4 + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.upstream_connect_attempts_bucket_gt_4 + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_upstream_connect_duration_ms(&self, duration_ms: u64, success: bool) { + if !self.telemetry_core_enabled() { + return; + } + let bucket = match duration_ms { + 0..=100 => 0u8, + 101..=500 => 1u8, + 501..=1000 => 2u8, + _ => 3u8, + }; + match (success, bucket) { + (true, 0) => { + self.upstream_connect_duration_success_bucket_le_100ms + .fetch_add(1, Ordering::Relaxed); + } + (true, 1) => { + self.upstream_connect_duration_success_bucket_101_500ms + .fetch_add(1, Ordering::Relaxed); + } + (true, 2) => { + self.upstream_connect_duration_success_bucket_501_1000ms + .fetch_add(1, Ordering::Relaxed); + } + (true, _) => { + self.upstream_connect_duration_success_bucket_gt_1000ms + .fetch_add(1, Ordering::Relaxed); + } + (false, 0) => { + self.upstream_connect_duration_fail_bucket_le_100ms + .fetch_add(1, Ordering::Relaxed); + } + (false, 1) => { + self.upstream_connect_duration_fail_bucket_101_500ms + .fetch_add(1, Ordering::Relaxed); + } + (false, 2) => { + self.upstream_connect_duration_fail_bucket_501_1000ms + .fetch_add(1, Ordering::Relaxed); + } + (false, _) => { + self.upstream_connect_duration_fail_bucket_gt_1000ms + .fetch_add(1, Ordering::Relaxed); + } + } + } +} diff --git a/src/stats/core_getters.rs b/src/stats/core_getters.rs new file mode 100644 index 0000000..c1730c1 --- /dev/null +++ b/src/stats/core_getters.rs @@ -0,0 +1,283 @@ +use super::*; + +impl Stats { + pub fn get_connects_all(&self) -> u64 { + self.connects_all.load(Ordering::Relaxed) + } + pub fn get_connects_bad(&self) -> u64 { + self.connects_bad.load(Ordering::Relaxed) + } + + pub fn get_connects_bad_class_counts(&self) -> Vec<(String, u64)> { + let mut out: Vec<(String, u64)> = self + .connects_bad_classes + .iter() + .map(|entry| { + ( + entry.key().to_string(), + entry.value().load(Ordering::Relaxed), + ) + }) + .collect(); + out.sort_by(|a, b| a.0.cmp(&b.0)); + out + } + + pub fn get_handshake_failure_class_counts(&self) -> Vec<(String, u64)> { + let mut out: Vec<(String, u64)> = self + .handshake_failure_classes + .iter() + .map(|entry| { + ( + entry.key().to_string(), + entry.value().load(Ordering::Relaxed), + ) + }) + .collect(); + out.sort_by(|a, b| a.0.cmp(&b.0)); + out + } + + pub fn get_accept_permit_timeout_total(&self) -> u64 { + self.accept_permit_timeout_total.load(Ordering::Relaxed) + } + pub fn get_current_connections_direct(&self) -> u64 { + self.current_connections_direct.load(Ordering::Relaxed) + } + pub fn get_current_connections_me(&self) -> u64 { + self.current_connections_me.load(Ordering::Relaxed) + } + pub fn get_route_cutover_parked_direct_current(&self) -> u64 { + self.route_cutover_parked_direct_current + .load(Ordering::Relaxed) + } + pub fn get_route_cutover_parked_middle_current(&self) -> u64 { + self.route_cutover_parked_middle_current + .load(Ordering::Relaxed) + } + pub fn get_route_cutover_parked_direct_total(&self) -> u64 { + self.route_cutover_parked_direct_total + .load(Ordering::Relaxed) + } + pub fn get_route_cutover_parked_middle_total(&self) -> u64 { + self.route_cutover_parked_middle_total + .load(Ordering::Relaxed) + } + pub fn get_current_connections_total(&self) -> u64 { + self.get_current_connections_direct() + .saturating_add(self.get_current_connections_me()) + } + pub fn get_conntrack_control_enabled(&self) -> bool { + self.conntrack_control_enabled_gauge.load(Ordering::Relaxed) + } + pub fn get_conntrack_control_available(&self) -> bool { + self.conntrack_control_available_gauge + .load(Ordering::Relaxed) + } + pub fn get_conntrack_pressure_active(&self) -> bool { + self.conntrack_pressure_active_gauge.load(Ordering::Relaxed) + } + pub fn get_conntrack_event_queue_depth(&self) -> u64 { + self.conntrack_event_queue_depth_gauge + .load(Ordering::Relaxed) + } + pub fn get_conntrack_rule_apply_ok(&self) -> bool { + self.conntrack_rule_apply_ok_gauge.load(Ordering::Relaxed) + } + pub fn get_conntrack_delete_attempt_total(&self) -> u64 { + self.conntrack_delete_attempt_total.load(Ordering::Relaxed) + } + pub fn get_conntrack_delete_success_total(&self) -> u64 { + self.conntrack_delete_success_total.load(Ordering::Relaxed) + } + pub fn get_conntrack_delete_not_found_total(&self) -> u64 { + self.conntrack_delete_not_found_total + .load(Ordering::Relaxed) + } + pub fn get_conntrack_delete_error_total(&self) -> u64 { + self.conntrack_delete_error_total.load(Ordering::Relaxed) + } + pub fn get_conntrack_close_event_drop_total(&self) -> u64 { + self.conntrack_close_event_drop_total + .load(Ordering::Relaxed) + } + pub fn get_me_keepalive_sent(&self) -> u64 { + self.me_keepalive_sent.load(Ordering::Relaxed) + } + pub fn get_me_keepalive_failed(&self) -> u64 { + self.me_keepalive_failed.load(Ordering::Relaxed) + } + pub fn get_me_keepalive_pong(&self) -> u64 { + self.me_keepalive_pong.load(Ordering::Relaxed) + } + pub fn get_me_keepalive_timeout(&self) -> u64 { + self.me_keepalive_timeout.load(Ordering::Relaxed) + } + pub fn get_me_rpc_proxy_req_signal_sent_total(&self) -> u64 { + self.me_rpc_proxy_req_signal_sent_total + .load(Ordering::Relaxed) + } + pub fn get_me_rpc_proxy_req_signal_failed_total(&self) -> u64 { + self.me_rpc_proxy_req_signal_failed_total + .load(Ordering::Relaxed) + } + pub fn get_me_rpc_proxy_req_signal_skipped_no_meta_total(&self) -> u64 { + self.me_rpc_proxy_req_signal_skipped_no_meta_total + .load(Ordering::Relaxed) + } + pub fn get_me_rpc_proxy_req_signal_response_total(&self) -> u64 { + self.me_rpc_proxy_req_signal_response_total + .load(Ordering::Relaxed) + } + pub fn get_me_rpc_proxy_req_signal_close_sent_total(&self) -> u64 { + self.me_rpc_proxy_req_signal_close_sent_total + .load(Ordering::Relaxed) + } + pub fn get_me_reconnect_attempts(&self) -> u64 { + self.me_reconnect_attempts.load(Ordering::Relaxed) + } + pub fn get_me_reconnect_success(&self) -> u64 { + self.me_reconnect_success.load(Ordering::Relaxed) + } + pub fn get_me_handshake_reject_total(&self) -> u64 { + self.me_handshake_reject_total.load(Ordering::Relaxed) + } + pub fn get_me_reader_eof_total(&self) -> u64 { + self.me_reader_eof_total.load(Ordering::Relaxed) + } + pub fn get_me_idle_close_by_peer_total(&self) -> u64 { + self.me_idle_close_by_peer_total.load(Ordering::Relaxed) + } + pub fn get_relay_idle_soft_mark_total(&self) -> u64 { + self.relay_idle_soft_mark_total.load(Ordering::Relaxed) + } + pub fn get_relay_idle_hard_close_total(&self) -> u64 { + self.relay_idle_hard_close_total.load(Ordering::Relaxed) + } + pub fn get_relay_pressure_evict_total(&self) -> u64 { + self.relay_pressure_evict_total.load(Ordering::Relaxed) + } + pub fn get_relay_protocol_desync_close_total(&self) -> u64 { + self.relay_protocol_desync_close_total + .load(Ordering::Relaxed) + } + pub fn get_me_crc_mismatch(&self) -> u64 { + self.me_crc_mismatch.load(Ordering::Relaxed) + } + pub fn get_me_seq_mismatch(&self) -> u64 { + self.me_seq_mismatch.load(Ordering::Relaxed) + } + pub fn get_me_endpoint_quarantine_total(&self) -> u64 { + self.me_endpoint_quarantine_total.load(Ordering::Relaxed) + } + pub fn get_me_endpoint_quarantine_unexpected_total(&self) -> u64 { + self.me_endpoint_quarantine_unexpected_total + .load(Ordering::Relaxed) + } + pub fn get_me_endpoint_quarantine_draining_suppressed_total(&self) -> u64 { + self.me_endpoint_quarantine_draining_suppressed_total + .load(Ordering::Relaxed) + } + pub fn get_me_kdf_drift_total(&self) -> u64 { + self.me_kdf_drift_total.load(Ordering::Relaxed) + } + pub fn get_me_kdf_port_only_drift_total(&self) -> u64 { + self.me_kdf_port_only_drift_total.load(Ordering::Relaxed) + } + pub fn get_me_hardswap_pending_reuse_total(&self) -> u64 { + self.me_hardswap_pending_reuse_total.load(Ordering::Relaxed) + } + pub fn get_me_hardswap_pending_ttl_expired_total(&self) -> u64 { + self.me_hardswap_pending_ttl_expired_total + .load(Ordering::Relaxed) + } + pub fn get_me_single_endpoint_outage_enter_total(&self) -> u64 { + self.me_single_endpoint_outage_enter_total + .load(Ordering::Relaxed) + } + pub fn get_me_single_endpoint_outage_exit_total(&self) -> u64 { + self.me_single_endpoint_outage_exit_total + .load(Ordering::Relaxed) + } + pub fn get_me_single_endpoint_outage_reconnect_attempt_total(&self) -> u64 { + self.me_single_endpoint_outage_reconnect_attempt_total + .load(Ordering::Relaxed) + } + pub fn get_me_single_endpoint_outage_reconnect_success_total(&self) -> u64 { + self.me_single_endpoint_outage_reconnect_success_total + .load(Ordering::Relaxed) + } + pub fn get_me_single_endpoint_quarantine_bypass_total(&self) -> u64 { + self.me_single_endpoint_quarantine_bypass_total + .load(Ordering::Relaxed) + } + pub fn get_me_single_endpoint_shadow_rotate_total(&self) -> u64 { + self.me_single_endpoint_shadow_rotate_total + .load(Ordering::Relaxed) + } + pub fn get_me_single_endpoint_shadow_rotate_skipped_quarantine_total(&self) -> u64 { + self.me_single_endpoint_shadow_rotate_skipped_quarantine_total + .load(Ordering::Relaxed) + } + pub fn get_me_floor_mode_switch_total(&self) -> u64 { + self.me_floor_mode_switch_total.load(Ordering::Relaxed) + } + pub fn get_me_floor_mode_switch_static_to_adaptive_total(&self) -> u64 { + self.me_floor_mode_switch_static_to_adaptive_total + .load(Ordering::Relaxed) + } + pub fn get_me_floor_mode_switch_adaptive_to_static_total(&self) -> u64 { + self.me_floor_mode_switch_adaptive_to_static_total + .load(Ordering::Relaxed) + } + pub fn get_me_floor_cpu_cores_detected_gauge(&self) -> u64 { + self.me_floor_cpu_cores_detected_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_floor_cpu_cores_effective_gauge(&self) -> u64 { + self.me_floor_cpu_cores_effective_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_floor_global_cap_raw_gauge(&self) -> u64 { + self.me_floor_global_cap_raw_gauge.load(Ordering::Relaxed) + } + pub fn get_me_floor_global_cap_effective_gauge(&self) -> u64 { + self.me_floor_global_cap_effective_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_floor_target_writers_total_gauge(&self) -> u64 { + self.me_floor_target_writers_total_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_floor_active_cap_configured_gauge(&self) -> u64 { + self.me_floor_active_cap_configured_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_floor_active_cap_effective_gauge(&self) -> u64 { + self.me_floor_active_cap_effective_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_floor_warm_cap_configured_gauge(&self) -> u64 { + self.me_floor_warm_cap_configured_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_floor_warm_cap_effective_gauge(&self) -> u64 { + self.me_floor_warm_cap_effective_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_writers_active_current_gauge(&self) -> u64 { + self.me_writers_active_current_gauge.load(Ordering::Relaxed) + } + pub fn get_me_writers_warm_current_gauge(&self) -> u64 { + self.me_writers_warm_current_gauge.load(Ordering::Relaxed) + } + pub fn get_me_floor_cap_block_total(&self) -> u64 { + self.me_floor_cap_block_total.load(Ordering::Relaxed) + } + pub fn get_me_floor_swap_idle_total(&self) -> u64 { + self.me_floor_swap_idle_total.load(Ordering::Relaxed) + } + pub fn get_me_floor_swap_idle_failed_total(&self) -> u64 { + self.me_floor_swap_idle_failed_total.load(Ordering::Relaxed) + } +} diff --git a/src/stats/helpers.rs b/src/stats/helpers.rs new file mode 100644 index 0000000..7a398d4 --- /dev/null +++ b/src/stats/helpers.rs @@ -0,0 +1,208 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use crate::config::MeTelemetryLevel; + +use super::*; + +impl Stats { + pub(super) fn telemetry_me_level(&self) -> MeTelemetryLevel { + MeTelemetryLevel::from_u8(self.telemetry_me_level.load(Ordering::Relaxed)) + } + + pub(super) fn telemetry_core_enabled(&self) -> bool { + self.telemetry_core_enabled.load(Ordering::Relaxed) + } + + pub(super) fn telemetry_user_enabled(&self) -> bool { + self.telemetry_user_enabled.load(Ordering::Relaxed) + } + + pub(super) fn telemetry_me_allows_normal(&self) -> bool { + self.telemetry_me_level().allows_normal() + } + + pub(super) fn telemetry_me_allows_debug(&self) -> bool { + self.telemetry_me_level().allows_debug() + } + + pub(super) fn decrement_atomic_saturating(counter: &AtomicU64) { + let mut current = counter.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match counter.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + + pub(super) fn now_epoch_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + } + + pub(super) fn refresh_cached_epoch_secs(&self) -> u64 { + let now_epoch_secs = Self::now_epoch_secs(); + self.cached_epoch_secs + .store(now_epoch_secs, Ordering::Relaxed); + now_epoch_secs + } + + pub(super) fn cached_epoch_secs(&self) -> u64 { + let cached = self.cached_epoch_secs.load(Ordering::Relaxed); + if cached != 0 { + return cached; + } + self.refresh_cached_epoch_secs() + } + + pub(super) fn touch_user_stats(&self, stats: &UserStats) { + stats + .last_seen_epoch_secs + .store(self.cached_epoch_secs(), Ordering::Relaxed); + } + + pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc { + if let Some(existing) = self.user_stats.get(user) { + let handle = Arc::clone(existing.value()); + self.touch_user_stats(handle.as_ref()); + return handle; + } + + let entry = self.user_stats.entry(user.to_string()).or_default(); + if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 { + self.touch_user_stats(entry.value().as_ref()); + } + Arc::clone(entry.value()) + } + + pub(crate) async fn run_periodic_user_stats_maintenance(self: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + loop { + interval.tick().await; + self.maybe_cleanup_user_stats(); + } + } + + #[inline] + pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + self.touch_user_stats(user_stats); + user_stats + .octets_from_client + .fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + self.touch_user_stats(user_stats); + user_stats + .octets_to_client + .fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_user_traffic_from_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + self.touch_user_stats(user_stats); + user_stats + .octets_from_client + .fetch_add(bytes, Ordering::Relaxed); + user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_user_traffic_to_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + self.touch_user_stats(user_stats); + user_stats + .octets_to_client + .fetch_add(bytes, Ordering::Relaxed); + user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + self.touch_user_stats(user_stats); + user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + self.touch_user_stats(user_stats); + user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + } + + /// Charges already committed bytes in a post-I/O path. + /// + /// This helper is intentionally separate from `quota_try_reserve` to avoid + /// mixing reserve and post-charge on a single I/O event. + #[inline] + pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 { + self.touch_user_stats(user_stats); + user_stats + .quota_used + .fetch_add(bytes, Ordering::Relaxed) + .saturating_add(bytes) + } + + pub(super) fn maybe_cleanup_user_stats(&self) { + const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; + const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; + + let now_epoch_secs = self.refresh_cached_epoch_secs(); + let last_cleanup_epoch_secs = self + .user_stats_last_cleanup_epoch_secs + .load(Ordering::Relaxed); + if now_epoch_secs.saturating_sub(last_cleanup_epoch_secs) < USER_STATS_CLEANUP_INTERVAL_SECS + { + return; + } + if self + .user_stats_last_cleanup_epoch_secs + .compare_exchange( + last_cleanup_epoch_secs, + now_epoch_secs, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_err() + { + return; + } + + self.user_stats.retain(|_, stats| { + if stats.curr_connects.load(Ordering::Relaxed) > 0 { + return true; + } + let last_seen_epoch_secs = stats.last_seen_epoch_secs.load(Ordering::Relaxed); + now_epoch_secs.saturating_sub(last_seen_epoch_secs) <= USER_STATS_IDLE_TTL_SECS + }); + } +} diff --git a/src/stats/me_counters.rs b/src/stats/me_counters.rs new file mode 100644 index 0000000..a6c9fb4 --- /dev/null +++ b/src/stats/me_counters.rs @@ -0,0 +1,442 @@ +use super::*; + +impl Stats { + pub fn increment_me_keepalive_sent(&self) { + if self.telemetry_me_allows_debug() { + self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_keepalive_failed(&self) { + if self.telemetry_me_allows_normal() { + self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_keepalive_pong(&self) { + if self.telemetry_me_allows_debug() { + self.me_keepalive_pong.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_keepalive_timeout(&self) { + if self.telemetry_me_allows_normal() { + self.me_keepalive_timeout.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_keepalive_timeout_by(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_keepalive_timeout + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn increment_me_rpc_proxy_req_signal_sent_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_rpc_proxy_req_signal_sent_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_rpc_proxy_req_signal_failed_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_rpc_proxy_req_signal_failed_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_rpc_proxy_req_signal_skipped_no_meta_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_rpc_proxy_req_signal_skipped_no_meta_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_rpc_proxy_req_signal_response_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_rpc_proxy_req_signal_response_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_rpc_proxy_req_signal_close_sent_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_rpc_proxy_req_signal_close_sent_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_reconnect_attempt(&self) { + if self.telemetry_me_allows_normal() { + self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_reconnect_success(&self) { + if self.telemetry_me_allows_normal() { + self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_handshake_reject_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_handshake_reject_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_handshake_error_code(&self, code: i32) { + if !self.telemetry_me_allows_normal() { + return; + } + let entry = self + .me_handshake_error_codes + .entry(code) + .or_insert_with(|| AtomicU64::new(0)); + entry.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_me_reader_eof_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_reader_eof_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_idle_close_by_peer_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_idle_close_by_peer_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_idle_soft_mark_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_idle_soft_mark_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_idle_hard_close_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_idle_hard_close_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_pressure_evict_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_pressure_evict_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_protocol_desync_close_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_protocol_desync_close_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_crc_mismatch(&self) { + if self.telemetry_me_allows_normal() { + self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_seq_mismatch(&self) { + if self.telemetry_me_allows_normal() { + self.me_seq_mismatch.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_route_drop_no_conn(&self) { + if self.telemetry_me_allows_normal() { + self.me_route_drop_no_conn.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_route_drop_channel_closed(&self) { + if self.telemetry_me_allows_normal() { + self.me_route_drop_channel_closed + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_route_drop_queue_full(&self) { + if self.telemetry_me_allows_normal() { + self.me_route_drop_queue_full + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_route_drop_queue_full_base(&self) { + if self.telemetry_me_allows_normal() { + self.me_route_drop_queue_full_base + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_route_drop_queue_full_high(&self) { + if self.telemetry_me_allows_normal() { + self.me_route_drop_queue_full_high + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn set_me_fair_pressure_state_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_pressure_state_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_fair_active_flows_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_active_flows_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_fair_queued_bytes_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_queued_bytes_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_fair_standing_flows_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_standing_flows_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_fair_backpressured_flows_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_backpressured_flows_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_scheduler_rounds_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_scheduler_rounds_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_deficit_grants_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_deficit_grants_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_deficit_skips_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_deficit_skips_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_enqueue_rejects_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_enqueue_rejects_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_shed_drops_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_shed_drops_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_penalties_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_penalties_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_downstream_stalls_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_downstream_stalls_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_batches_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batches_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn add_me_d2c_batch_frames_total(&self, frames: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batch_frames_total + .fetch_add(frames, Ordering::Relaxed); + } + } + pub fn add_me_d2c_batch_bytes_total(&self, bytes: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batch_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_flush_reason(&self, reason: MeD2cFlushReason) { + if !self.telemetry_me_allows_normal() { + return; + } + match reason { + MeD2cFlushReason::QueueDrain => { + self.me_d2c_flush_reason_queue_drain_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::BatchFrames => { + self.me_d2c_flush_reason_batch_frames_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::BatchBytes => { + self.me_d2c_flush_reason_batch_bytes_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::MaxDelay => { + self.me_d2c_flush_reason_max_delay_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::AckImmediate => { + self.me_d2c_flush_reason_ack_immediate_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::Close => { + self.me_d2c_flush_reason_close_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_data_frames_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_d2c_data_frames_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_ack_frames_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_d2c_ack_frames_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn add_me_d2c_payload_bytes_total(&self, bytes: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_payload_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_write_mode(&self, mode: MeD2cWriteMode) { + if !self.telemetry_me_allows_normal() { + return; + } + match mode { + MeD2cWriteMode::Coalesced => { + self.me_d2c_write_mode_coalesced_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cWriteMode::Split => { + self.me_d2c_write_mode_split_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_quota_reject_total(&self, stage: MeD2cQuotaRejectStage) { + if !self.telemetry_me_allows_normal() { + return; + } + match stage { + MeD2cQuotaRejectStage::PreWrite => { + self.me_d2c_quota_reject_pre_write_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cQuotaRejectStage::PostWrite => { + self.me_d2c_quota_reject_post_write_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_frame_buf_shrink(&self, bytes_freed: u64) { + if !self.telemetry_me_allows_normal() { + return; + } + self.me_d2c_frame_buf_shrink_total + .fetch_add(1, Ordering::Relaxed); + self.me_d2c_frame_buf_shrink_bytes_total + .fetch_add(bytes_freed, Ordering::Relaxed); + } + pub fn observe_me_d2c_batch_frames(&self, frames: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match frames { + 0 => {} + 1 => { + self.me_d2c_batch_frames_bucket_1 + .fetch_add(1, Ordering::Relaxed); + } + 2..=4 => { + self.me_d2c_batch_frames_bucket_2_4 + .fetch_add(1, Ordering::Relaxed); + } + 5..=8 => { + self.me_d2c_batch_frames_bucket_5_8 + .fetch_add(1, Ordering::Relaxed); + } + 9..=16 => { + self.me_d2c_batch_frames_bucket_9_16 + .fetch_add(1, Ordering::Relaxed); + } + 17..=32 => { + self.me_d2c_batch_frames_bucket_17_32 + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_batch_frames_bucket_gt_32 + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_batch_bytes(&self, bytes: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match bytes { + 0..=1024 => { + self.me_d2c_batch_bytes_bucket_0_1k + .fetch_add(1, Ordering::Relaxed); + } + 1025..=4096 => { + self.me_d2c_batch_bytes_bucket_1k_4k + .fetch_add(1, Ordering::Relaxed); + } + 4097..=16_384 => { + self.me_d2c_batch_bytes_bucket_4k_16k + .fetch_add(1, Ordering::Relaxed); + } + 16_385..=65_536 => { + self.me_d2c_batch_bytes_bucket_16k_64k + .fetch_add(1, Ordering::Relaxed); + } + 65_537..=131_072 => { + self.me_d2c_batch_bytes_bucket_64k_128k + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_batch_bytes_bucket_gt_128k + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_flush_duration_us(&self, duration_us: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match duration_us { + 0..=50 => { + self.me_d2c_flush_duration_us_bucket_0_50 + .fetch_add(1, Ordering::Relaxed); + } + 51..=200 => { + self.me_d2c_flush_duration_us_bucket_51_200 + .fetch_add(1, Ordering::Relaxed); + } + 201..=1000 => { + self.me_d2c_flush_duration_us_bucket_201_1000 + .fetch_add(1, Ordering::Relaxed); + } + 1001..=5000 => { + self.me_d2c_flush_duration_us_bucket_1001_5000 + .fetch_add(1, Ordering::Relaxed); + } + 5001..=20_000 => { + self.me_d2c_flush_duration_us_bucket_5001_20000 + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_flush_duration_us_bucket_gt_20000 + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_batch_timeout_armed_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_d2c_batch_timeout_armed_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_batch_timeout_fired_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_d2c_batch_timeout_fired_total + .fetch_add(1, Ordering::Relaxed); + } + } +} diff --git a/src/stats/me_getters.rs b/src/stats/me_getters.rs new file mode 100644 index 0000000..0737089 --- /dev/null +++ b/src/stats/me_getters.rs @@ -0,0 +1,398 @@ +use super::*; + +impl Stats { + pub fn get_me_handshake_error_code_counts(&self) -> Vec<(i32, u64)> { + let mut out: Vec<(i32, u64)> = self + .me_handshake_error_codes + .iter() + .map(|entry| (*entry.key(), entry.value().load(Ordering::Relaxed))) + .collect(); + out.sort_by_key(|(code, _)| *code); + out + } + pub fn get_me_route_drop_no_conn(&self) -> u64 { + self.me_route_drop_no_conn.load(Ordering::Relaxed) + } + pub fn get_me_route_drop_channel_closed(&self) -> u64 { + self.me_route_drop_channel_closed.load(Ordering::Relaxed) + } + pub fn get_me_route_drop_queue_full(&self) -> u64 { + self.me_route_drop_queue_full.load(Ordering::Relaxed) + } + pub fn get_me_route_drop_queue_full_base(&self) -> u64 { + self.me_route_drop_queue_full_base.load(Ordering::Relaxed) + } + pub fn get_me_route_drop_queue_full_high(&self) -> u64 { + self.me_route_drop_queue_full_high.load(Ordering::Relaxed) + } + pub fn get_me_fair_pressure_state_gauge(&self) -> u64 { + self.me_fair_pressure_state_gauge.load(Ordering::Relaxed) + } + pub fn get_me_fair_active_flows_gauge(&self) -> u64 { + self.me_fair_active_flows_gauge.load(Ordering::Relaxed) + } + pub fn get_me_fair_queued_bytes_gauge(&self) -> u64 { + self.me_fair_queued_bytes_gauge.load(Ordering::Relaxed) + } + pub fn get_me_fair_standing_flows_gauge(&self) -> u64 { + self.me_fair_standing_flows_gauge.load(Ordering::Relaxed) + } + pub fn get_me_fair_backpressured_flows_gauge(&self) -> u64 { + self.me_fair_backpressured_flows_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_fair_scheduler_rounds_total(&self) -> u64 { + self.me_fair_scheduler_rounds_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_deficit_grants_total(&self) -> u64 { + self.me_fair_deficit_grants_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_deficit_skips_total(&self) -> u64 { + self.me_fair_deficit_skips_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_enqueue_rejects_total(&self) -> u64 { + self.me_fair_enqueue_rejects_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_shed_drops_total(&self) -> u64 { + self.me_fair_shed_drops_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_penalties_total(&self) -> u64 { + self.me_fair_penalties_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_downstream_stalls_total(&self) -> u64 { + self.me_fair_downstream_stalls_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batches_total(&self) -> u64 { + self.me_d2c_batches_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_total(&self) -> u64 { + self.me_d2c_batch_frames_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_total(&self) -> u64 { + self.me_d2c_batch_bytes_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_queue_drain_total(&self) -> u64 { + self.me_d2c_flush_reason_queue_drain_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_batch_frames_total(&self) -> u64 { + self.me_d2c_flush_reason_batch_frames_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_batch_bytes_total(&self) -> u64 { + self.me_d2c_flush_reason_batch_bytes_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_max_delay_total(&self) -> u64 { + self.me_d2c_flush_reason_max_delay_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_ack_immediate_total(&self) -> u64 { + self.me_d2c_flush_reason_ack_immediate_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_close_total(&self) -> u64 { + self.me_d2c_flush_reason_close_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_data_frames_total(&self) -> u64 { + self.me_d2c_data_frames_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_ack_frames_total(&self) -> u64 { + self.me_d2c_ack_frames_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_payload_bytes_total(&self) -> u64 { + self.me_d2c_payload_bytes_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_write_mode_coalesced_total(&self) -> u64 { + self.me_d2c_write_mode_coalesced_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_write_mode_split_total(&self) -> u64 { + self.me_d2c_write_mode_split_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_quota_reject_pre_write_total(&self) -> u64 { + self.me_d2c_quota_reject_pre_write_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_quota_reject_post_write_total(&self) -> u64 { + self.me_d2c_quota_reject_post_write_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_frame_buf_shrink_total(&self) -> u64 { + self.me_d2c_frame_buf_shrink_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_frame_buf_shrink_bytes_total(&self) -> u64 { + self.me_d2c_frame_buf_shrink_bytes_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_1(&self) -> u64 { + self.me_d2c_batch_frames_bucket_1.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_2_4(&self) -> u64 { + self.me_d2c_batch_frames_bucket_2_4.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_5_8(&self) -> u64 { + self.me_d2c_batch_frames_bucket_5_8.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_9_16(&self) -> u64 { + self.me_d2c_batch_frames_bucket_9_16.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_17_32(&self) -> u64 { + self.me_d2c_batch_frames_bucket_17_32 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_gt_32(&self) -> u64 { + self.me_d2c_batch_frames_bucket_gt_32 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_0_1k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_0_1k.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_1k_4k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_4k_16k + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_16k_64k + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_64k_128k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_64k_128k + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_gt_128k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_gt_128k + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_0_50(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_0_50 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_51_200(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_51_200 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_201_1000(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_201_1000 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_1001_5000(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_1001_5000 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_5001_20000(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_5001_20000 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_gt_20000(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_gt_20000 + .load(Ordering::Relaxed) + } + + pub fn get_buffer_pool_pooled_gauge(&self) -> u64 { + self.buffer_pool_pooled_gauge.load(Ordering::Relaxed) + } + + pub fn get_buffer_pool_allocated_gauge(&self) -> u64 { + self.buffer_pool_allocated_gauge.load(Ordering::Relaxed) + } + + pub fn get_buffer_pool_in_use_gauge(&self) -> u64 { + self.buffer_pool_in_use_gauge.load(Ordering::Relaxed) + } + + pub fn get_me_c2me_send_full_total(&self) -> u64 { + self.me_c2me_send_full_total.load(Ordering::Relaxed) + } + + pub fn get_me_c2me_send_high_water_total(&self) -> u64 { + self.me_c2me_send_high_water_total.load(Ordering::Relaxed) + } + + pub fn get_me_c2me_send_timeout_total(&self) -> u64 { + self.me_c2me_send_timeout_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_timeout_armed_total(&self) -> u64 { + self.me_d2c_batch_timeout_armed_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_timeout_fired_total(&self) -> u64 { + self.me_d2c_batch_timeout_fired_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_sorted_rr_success_try_total(&self) -> u64 { + self.me_writer_pick_sorted_rr_success_try_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_sorted_rr_success_fallback_total(&self) -> u64 { + self.me_writer_pick_sorted_rr_success_fallback_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_sorted_rr_full_total(&self) -> u64 { + self.me_writer_pick_sorted_rr_full_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_sorted_rr_closed_total(&self) -> u64 { + self.me_writer_pick_sorted_rr_closed_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_sorted_rr_no_candidate_total(&self) -> u64 { + self.me_writer_pick_sorted_rr_no_candidate_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_p2c_success_try_total(&self) -> u64 { + self.me_writer_pick_p2c_success_try_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_p2c_success_fallback_total(&self) -> u64 { + self.me_writer_pick_p2c_success_fallback_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_p2c_full_total(&self) -> u64 { + self.me_writer_pick_p2c_full_total.load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_p2c_closed_total(&self) -> u64 { + self.me_writer_pick_p2c_closed_total.load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_p2c_no_candidate_total(&self) -> u64 { + self.me_writer_pick_p2c_no_candidate_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_blocking_fallback_total(&self) -> u64 { + self.me_writer_pick_blocking_fallback_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_pick_mode_switch_total(&self) -> u64 { + self.me_writer_pick_mode_switch_total + .load(Ordering::Relaxed) + } + pub fn get_me_socks_kdf_strict_reject(&self) -> u64 { + self.me_socks_kdf_strict_reject.load(Ordering::Relaxed) + } + pub fn get_me_socks_kdf_compat_fallback(&self) -> u64 { + self.me_socks_kdf_compat_fallback.load(Ordering::Relaxed) + } + pub fn get_secure_padding_invalid(&self) -> u64 { + self.secure_padding_invalid.load(Ordering::Relaxed) + } + pub fn get_desync_total(&self) -> u64 { + self.desync_total.load(Ordering::Relaxed) + } + pub fn get_desync_full_logged(&self) -> u64 { + self.desync_full_logged.load(Ordering::Relaxed) + } + pub fn get_desync_suppressed(&self) -> u64 { + self.desync_suppressed.load(Ordering::Relaxed) + } + pub fn get_desync_frames_bucket_0(&self) -> u64 { + self.desync_frames_bucket_0.load(Ordering::Relaxed) + } + pub fn get_desync_frames_bucket_1_2(&self) -> u64 { + self.desync_frames_bucket_1_2.load(Ordering::Relaxed) + } + pub fn get_desync_frames_bucket_3_10(&self) -> u64 { + self.desync_frames_bucket_3_10.load(Ordering::Relaxed) + } + pub fn get_desync_frames_bucket_gt_10(&self) -> u64 { + self.desync_frames_bucket_gt_10.load(Ordering::Relaxed) + } + pub fn get_pool_swap_total(&self) -> u64 { + self.pool_swap_total.load(Ordering::Relaxed) + } + pub fn get_pool_drain_active(&self) -> u64 { + self.pool_drain_active.load(Ordering::Relaxed) + } + pub fn get_pool_force_close_total(&self) -> u64 { + self.pool_force_close_total.load(Ordering::Relaxed) + } + pub fn get_pool_stale_pick_total(&self) -> u64 { + self.pool_stale_pick_total.load(Ordering::Relaxed) + } + pub fn get_me_writer_removed_total(&self) -> u64 { + self.me_writer_removed_total.load(Ordering::Relaxed) + } + pub fn get_me_writer_removed_unexpected_total(&self) -> u64 { + self.me_writer_removed_unexpected_total + .load(Ordering::Relaxed) + } + pub fn get_me_refill_triggered_total(&self) -> u64 { + self.me_refill_triggered_total.load(Ordering::Relaxed) + } + pub fn get_me_refill_skipped_inflight_total(&self) -> u64 { + self.me_refill_skipped_inflight_total + .load(Ordering::Relaxed) + } + pub fn get_me_refill_failed_total(&self) -> u64 { + self.me_refill_failed_total.load(Ordering::Relaxed) + } + pub fn get_me_writer_restored_same_endpoint_total(&self) -> u64 { + self.me_writer_restored_same_endpoint_total + .load(Ordering::Relaxed) + } + pub fn get_me_writer_restored_fallback_total(&self) -> u64 { + self.me_writer_restored_fallback_total + .load(Ordering::Relaxed) + } + pub fn get_me_no_writer_failfast_total(&self) -> u64 { + self.me_no_writer_failfast_total.load(Ordering::Relaxed) + } + pub fn get_me_hybrid_timeout_total(&self) -> u64 { + self.me_hybrid_timeout_total.load(Ordering::Relaxed) + } + pub fn get_me_async_recovery_trigger_total(&self) -> u64 { + self.me_async_recovery_trigger_total.load(Ordering::Relaxed) + } + pub fn get_me_inline_recovery_total(&self) -> u64 { + self.me_inline_recovery_total.load(Ordering::Relaxed) + } + pub fn get_ip_reservation_rollback_tcp_limit_total(&self) -> u64 { + self.ip_reservation_rollback_tcp_limit_total + .load(Ordering::Relaxed) + } + pub fn get_ip_reservation_rollback_quota_limit_total(&self) -> u64 { + self.ip_reservation_rollback_quota_limit_total + .load(Ordering::Relaxed) + } + pub fn get_quota_refund_bytes_total(&self) -> u64 { + self.quota_refund_bytes_total.load(Ordering::Relaxed) + } + pub fn get_quota_contention_total(&self) -> u64 { + self.quota_contention_total.load(Ordering::Relaxed) + } + pub fn get_quota_contention_timeout_total(&self) -> u64 { + self.quota_contention_timeout_total.load(Ordering::Relaxed) + } + pub fn get_quota_acquire_cancelled_total(&self) -> u64 { + self.quota_acquire_cancelled_total.load(Ordering::Relaxed) + } + pub fn get_quota_write_fail_bytes_total(&self) -> u64 { + self.quota_write_fail_bytes_total.load(Ordering::Relaxed) + } + pub fn get_quota_write_fail_events_total(&self) -> u64 { + self.quota_write_fail_events_total.load(Ordering::Relaxed) + } + pub fn get_me_child_join_timeout_total(&self) -> u64 { + self.me_child_join_timeout_total.load(Ordering::Relaxed) + } + pub fn get_me_child_abort_total(&self) -> u64 { + self.me_child_abort_total.load(Ordering::Relaxed) + } + pub fn get_flow_wait_middle_rate_limit_total(&self) -> u64 { + self.flow_wait_middle_rate_limit_total + .load(Ordering::Relaxed) + } + pub fn get_flow_wait_middle_rate_limit_cancelled_total(&self) -> u64 { + self.flow_wait_middle_rate_limit_cancelled_total + .load(Ordering::Relaxed) + } + pub fn get_flow_wait_middle_rate_limit_ms_total(&self) -> u64 { + self.flow_wait_middle_rate_limit_ms_total + .load(Ordering::Relaxed) + } + pub fn get_session_drop_fallback_total(&self) -> u64 { + self.session_drop_fallback_total.load(Ordering::Relaxed) + } +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 50bc4e2..7360fc1 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -3,22 +3,26 @@ #![allow(dead_code)] pub mod beobachten; +mod core_counters; +mod core_getters; +mod helpers; +mod me_counters; +mod me_getters; +mod replay; pub mod telemetry; +mod users; +mod writer_counters; use dashmap::DashMap; -use lru::LruCache; -use parking_lot::Mutex; -use std::collections::hash_map::DefaultHasher; -use std::collections::{HashMap, VecDeque}; -use std::hash::{Hash, Hasher}; -use std::num::NonZeroUsize; +use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; -use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -use tracing::debug; +use std::time::Instant; +#[allow(unused_imports)] +pub use self::replay::{ReplayChecker, ReplayStats}; use self::telemetry::TelemetryPolicy; -use crate::config::{MeTelemetryLevel, MeWriterPickMode}; +use crate::config::MeWriterPickMode; #[derive(Clone, Copy)] enum RouteConnectionGauge { @@ -26,6 +30,12 @@ enum RouteConnectionGauge { Middle, } +#[derive(Clone, Copy)] +enum RouteCutoverParkGauge { + Direct, + Middle, +} + #[derive(Debug, Clone, Copy)] pub enum MeD2cFlushReason { QueueDrain, @@ -55,6 +65,13 @@ pub struct RouteConnectionLease { active: bool, } +#[must_use = "RouteCutoverParkLease must be kept alive while a route cutover is parked"] +pub struct RouteCutoverParkLease { + stats: Arc, + gauge: RouteCutoverParkGauge, + active: bool, +} + impl RouteConnectionLease { fn new(stats: Arc, gauge: RouteConnectionGauge) -> Self { Self { @@ -82,6 +99,28 @@ impl Drop for RouteConnectionLease { } } +impl RouteCutoverParkLease { + fn new(stats: Arc, gauge: RouteCutoverParkGauge) -> Self { + Self { + stats, + gauge, + active: true, + } + } +} + +impl Drop for RouteCutoverParkLease { + fn drop(&mut self) { + if !self.active { + return; + } + match self.gauge { + RouteCutoverParkGauge::Direct => self.stats.decrement_route_cutover_parked_direct(), + RouteCutoverParkGauge::Middle => self.stats.decrement_route_cutover_parked_middle(), + } + } +} + // ============= Stats ============= #[derive(Default)] @@ -92,6 +131,10 @@ pub struct Stats { handshake_failure_classes: DashMap<&'static str, AtomicU64>, current_connections_direct: AtomicU64, current_connections_me: AtomicU64, + route_cutover_parked_direct_current: AtomicU64, + route_cutover_parked_middle_current: AtomicU64, + route_cutover_parked_direct_total: AtomicU64, + route_cutover_parked_middle_total: AtomicU64, handshake_timeouts: AtomicU64, accept_permit_timeout_total: AtomicU64, conntrack_control_enabled_gauge: AtomicBool, @@ -363,2938 +406,10 @@ impl Stats { *stats.start_time.write() = Some(Instant::now()); stats } - - fn telemetry_me_level(&self) -> MeTelemetryLevel { - MeTelemetryLevel::from_u8(self.telemetry_me_level.load(Ordering::Relaxed)) - } - - fn telemetry_core_enabled(&self) -> bool { - self.telemetry_core_enabled.load(Ordering::Relaxed) - } - - fn telemetry_user_enabled(&self) -> bool { - self.telemetry_user_enabled.load(Ordering::Relaxed) - } - - fn telemetry_me_allows_normal(&self) -> bool { - self.telemetry_me_level().allows_normal() - } - - fn telemetry_me_allows_debug(&self) -> bool { - self.telemetry_me_level().allows_debug() - } - - fn decrement_atomic_saturating(counter: &AtomicU64) { - let mut current = counter.load(Ordering::Relaxed); - loop { - if current == 0 { - break; - } - match counter.compare_exchange_weak( - current, - current - 1, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(actual) => current = actual, - } - } - } - - fn now_epoch_secs() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs() - } - - fn refresh_cached_epoch_secs(&self) -> u64 { - let now_epoch_secs = Self::now_epoch_secs(); - self.cached_epoch_secs - .store(now_epoch_secs, Ordering::Relaxed); - now_epoch_secs - } - - fn cached_epoch_secs(&self) -> u64 { - let cached = self.cached_epoch_secs.load(Ordering::Relaxed); - if cached != 0 { - return cached; - } - self.refresh_cached_epoch_secs() - } - - fn touch_user_stats(&self, stats: &UserStats) { - stats - .last_seen_epoch_secs - .store(self.cached_epoch_secs(), Ordering::Relaxed); - } - - pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc { - if let Some(existing) = self.user_stats.get(user) { - let handle = Arc::clone(existing.value()); - self.touch_user_stats(handle.as_ref()); - return handle; - } - - let entry = self.user_stats.entry(user.to_string()).or_default(); - if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 { - self.touch_user_stats(entry.value().as_ref()); - } - Arc::clone(entry.value()) - } - - pub(crate) async fn run_periodic_user_stats_maintenance(self: Arc) { - let mut interval = tokio::time::interval(Duration::from_secs(60)); - loop { - interval.tick().await; - self.maybe_cleanup_user_stats(); - } - } - - #[inline] - pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - self.touch_user_stats(user_stats); - user_stats - .octets_from_client - .fetch_add(bytes, Ordering::Relaxed); - } - - #[inline] - pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - self.touch_user_stats(user_stats); - user_stats - .octets_to_client - .fetch_add(bytes, Ordering::Relaxed); - } - - #[inline] - pub(crate) fn add_user_traffic_from_handle(&self, user_stats: &UserStats, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - self.touch_user_stats(user_stats); - user_stats - .octets_from_client - .fetch_add(bytes, Ordering::Relaxed); - user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); - } - - #[inline] - pub(crate) fn add_user_traffic_to_handle(&self, user_stats: &UserStats, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - self.touch_user_stats(user_stats); - user_stats - .octets_to_client - .fetch_add(bytes, Ordering::Relaxed); - user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); - } - - #[inline] - pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) { - if !self.telemetry_user_enabled() { - return; - } - self.touch_user_stats(user_stats); - user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); - } - - #[inline] - pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) { - if !self.telemetry_user_enabled() { - return; - } - self.touch_user_stats(user_stats); - user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); - } - - /// Charges already committed bytes in a post-I/O path. - /// - /// This helper is intentionally separate from `quota_try_reserve` to avoid - /// mixing reserve and post-charge on a single I/O event. - #[inline] - pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 { - self.touch_user_stats(user_stats); - user_stats - .quota_used - .fetch_add(bytes, Ordering::Relaxed) - .saturating_add(bytes) - } - - fn maybe_cleanup_user_stats(&self) { - const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; - const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; - - let now_epoch_secs = self.refresh_cached_epoch_secs(); - let last_cleanup_epoch_secs = self - .user_stats_last_cleanup_epoch_secs - .load(Ordering::Relaxed); - if now_epoch_secs.saturating_sub(last_cleanup_epoch_secs) < USER_STATS_CLEANUP_INTERVAL_SECS - { - return; - } - if self - .user_stats_last_cleanup_epoch_secs - .compare_exchange( - last_cleanup_epoch_secs, - now_epoch_secs, - Ordering::AcqRel, - Ordering::Relaxed, - ) - .is_err() - { - return; - } - - self.user_stats.retain(|_, stats| { - if stats.curr_connects.load(Ordering::Relaxed) > 0 { - return true; - } - let last_seen_epoch_secs = stats.last_seen_epoch_secs.load(Ordering::Relaxed); - now_epoch_secs.saturating_sub(last_seen_epoch_secs) <= USER_STATS_IDLE_TTL_SECS - }); - } - - pub fn apply_telemetry_policy(&self, policy: TelemetryPolicy) { - self.telemetry_core_enabled - .store(policy.core_enabled, Ordering::Relaxed); - self.telemetry_user_enabled - .store(policy.user_enabled, Ordering::Relaxed); - self.telemetry_me_level - .store(policy.me_level.as_u8(), Ordering::Relaxed); - } - - pub fn telemetry_policy(&self) -> TelemetryPolicy { - TelemetryPolicy { - core_enabled: self.telemetry_core_enabled(), - user_enabled: self.telemetry_user_enabled(), - me_level: self.telemetry_me_level(), - } - } - - pub fn increment_connects_all(&self) { - if self.telemetry_core_enabled() { - self.connects_all.fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_connects_bad_with_class(&self, class: &'static str) { - if !self.telemetry_core_enabled() { - return; - } - self.connects_bad.fetch_add(1, Ordering::Relaxed); - let entry = self - .connects_bad_classes - .entry(class) - .or_insert_with(|| AtomicU64::new(0)); - entry.fetch_add(1, Ordering::Relaxed); - } - - pub fn increment_connects_bad(&self) { - self.increment_connects_bad_with_class("other"); - } - - pub fn increment_handshake_failure_class(&self, class: &'static str) { - if !self.telemetry_core_enabled() { - return; - } - let entry = self - .handshake_failure_classes - .entry(class) - .or_insert_with(|| AtomicU64::new(0)); - entry.fetch_add(1, Ordering::Relaxed); - } - pub fn increment_current_connections_direct(&self) { - self.current_connections_direct - .fetch_add(1, Ordering::Relaxed); - } - pub fn decrement_current_connections_direct(&self) { - Self::decrement_atomic_saturating(&self.current_connections_direct); - } - pub fn increment_current_connections_me(&self) { - self.current_connections_me.fetch_add(1, Ordering::Relaxed); - } - pub fn decrement_current_connections_me(&self) { - Self::decrement_atomic_saturating(&self.current_connections_me); - } - - pub fn acquire_direct_connection_lease(self: &Arc) -> RouteConnectionLease { - self.increment_current_connections_direct(); - RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct) - } - - pub fn acquire_me_connection_lease(self: &Arc) -> RouteConnectionLease { - self.increment_current_connections_me(); - RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle) - } - pub fn increment_handshake_timeouts(&self) { - if self.telemetry_core_enabled() { - self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_accept_permit_timeout_total(&self) { - if self.telemetry_core_enabled() { - self.accept_permit_timeout_total - .fetch_add(1, Ordering::Relaxed); - } - } - - pub fn set_conntrack_control_enabled(&self, enabled: bool) { - self.conntrack_control_enabled_gauge - .store(enabled, Ordering::Relaxed); - } - - pub fn set_conntrack_control_available(&self, available: bool) { - self.conntrack_control_available_gauge - .store(available, Ordering::Relaxed); - } - - pub fn set_conntrack_pressure_active(&self, active: bool) { - self.conntrack_pressure_active_gauge - .store(active, Ordering::Relaxed); - } - - pub fn set_conntrack_event_queue_depth(&self, depth: u64) { - self.conntrack_event_queue_depth_gauge - .store(depth, Ordering::Relaxed); - } - - pub fn set_conntrack_rule_apply_ok(&self, ok: bool) { - self.conntrack_rule_apply_ok_gauge - .store(ok, Ordering::Relaxed); - } - - pub fn increment_conntrack_delete_attempt_total(&self) { - if self.telemetry_core_enabled() { - self.conntrack_delete_attempt_total - .fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_conntrack_delete_success_total(&self) { - if self.telemetry_core_enabled() { - self.conntrack_delete_success_total - .fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_conntrack_delete_not_found_total(&self) { - if self.telemetry_core_enabled() { - self.conntrack_delete_not_found_total - .fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_conntrack_delete_error_total(&self) { - if self.telemetry_core_enabled() { - self.conntrack_delete_error_total - .fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_conntrack_close_event_drop_total(&self) { - if self.telemetry_core_enabled() { - self.conntrack_close_event_drop_total - .fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_upstream_connect_attempt_total(&self) { - if self.telemetry_core_enabled() { - self.upstream_connect_attempt_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_upstream_connect_success_total(&self) { - if self.telemetry_core_enabled() { - self.upstream_connect_success_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_upstream_connect_fail_total(&self) { - if self.telemetry_core_enabled() { - self.upstream_connect_fail_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_upstream_connect_failfast_hard_error_total(&self) { - if self.telemetry_core_enabled() { - self.upstream_connect_failfast_hard_error_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn observe_upstream_connect_attempts_per_request(&self, attempts: u32) { - if !self.telemetry_core_enabled() { - return; - } - match attempts { - 0 => {} - 1 => { - self.upstream_connect_attempts_bucket_1 - .fetch_add(1, Ordering::Relaxed); - } - 2 => { - self.upstream_connect_attempts_bucket_2 - .fetch_add(1, Ordering::Relaxed); - } - 3..=4 => { - self.upstream_connect_attempts_bucket_3_4 - .fetch_add(1, Ordering::Relaxed); - } - _ => { - self.upstream_connect_attempts_bucket_gt_4 - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn observe_upstream_connect_duration_ms(&self, duration_ms: u64, success: bool) { - if !self.telemetry_core_enabled() { - return; - } - let bucket = match duration_ms { - 0..=100 => 0u8, - 101..=500 => 1u8, - 501..=1000 => 2u8, - _ => 3u8, - }; - match (success, bucket) { - (true, 0) => { - self.upstream_connect_duration_success_bucket_le_100ms - .fetch_add(1, Ordering::Relaxed); - } - (true, 1) => { - self.upstream_connect_duration_success_bucket_101_500ms - .fetch_add(1, Ordering::Relaxed); - } - (true, 2) => { - self.upstream_connect_duration_success_bucket_501_1000ms - .fetch_add(1, Ordering::Relaxed); - } - (true, _) => { - self.upstream_connect_duration_success_bucket_gt_1000ms - .fetch_add(1, Ordering::Relaxed); - } - (false, 0) => { - self.upstream_connect_duration_fail_bucket_le_100ms - .fetch_add(1, Ordering::Relaxed); - } - (false, 1) => { - self.upstream_connect_duration_fail_bucket_101_500ms - .fetch_add(1, Ordering::Relaxed); - } - (false, 2) => { - self.upstream_connect_duration_fail_bucket_501_1000ms - .fetch_add(1, Ordering::Relaxed); - } - (false, _) => { - self.upstream_connect_duration_fail_bucket_gt_1000ms - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_keepalive_sent(&self) { - if self.telemetry_me_allows_debug() { - self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_keepalive_failed(&self) { - if self.telemetry_me_allows_normal() { - self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_keepalive_pong(&self) { - if self.telemetry_me_allows_debug() { - self.me_keepalive_pong.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_keepalive_timeout(&self) { - if self.telemetry_me_allows_normal() { - self.me_keepalive_timeout.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_keepalive_timeout_by(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_keepalive_timeout - .fetch_add(value, Ordering::Relaxed); - } - } - pub fn increment_me_rpc_proxy_req_signal_sent_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_rpc_proxy_req_signal_sent_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_rpc_proxy_req_signal_failed_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_rpc_proxy_req_signal_failed_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_rpc_proxy_req_signal_skipped_no_meta_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_rpc_proxy_req_signal_skipped_no_meta_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_rpc_proxy_req_signal_response_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_rpc_proxy_req_signal_response_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_rpc_proxy_req_signal_close_sent_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_rpc_proxy_req_signal_close_sent_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_reconnect_attempt(&self) { - if self.telemetry_me_allows_normal() { - self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_reconnect_success(&self) { - if self.telemetry_me_allows_normal() { - self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_handshake_reject_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_handshake_reject_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_handshake_error_code(&self, code: i32) { - if !self.telemetry_me_allows_normal() { - return; - } - let entry = self - .me_handshake_error_codes - .entry(code) - .or_insert_with(|| AtomicU64::new(0)); - entry.fetch_add(1, Ordering::Relaxed); - } - pub fn increment_me_reader_eof_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_reader_eof_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_idle_close_by_peer_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_idle_close_by_peer_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_relay_idle_soft_mark_total(&self) { - if self.telemetry_me_allows_normal() { - self.relay_idle_soft_mark_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_relay_idle_hard_close_total(&self) { - if self.telemetry_me_allows_normal() { - self.relay_idle_hard_close_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_relay_pressure_evict_total(&self) { - if self.telemetry_me_allows_normal() { - self.relay_pressure_evict_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_relay_protocol_desync_close_total(&self) { - if self.telemetry_me_allows_normal() { - self.relay_protocol_desync_close_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_crc_mismatch(&self) { - if self.telemetry_me_allows_normal() { - self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_seq_mismatch(&self) { - if self.telemetry_me_allows_normal() { - self.me_seq_mismatch.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_route_drop_no_conn(&self) { - if self.telemetry_me_allows_normal() { - self.me_route_drop_no_conn.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_route_drop_channel_closed(&self) { - if self.telemetry_me_allows_normal() { - self.me_route_drop_channel_closed - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_route_drop_queue_full(&self) { - if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_route_drop_queue_full_base(&self) { - if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full_base - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_route_drop_queue_full_high(&self) { - if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full_high - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn set_me_fair_pressure_state_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_fair_pressure_state_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_fair_active_flows_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_fair_active_flows_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_fair_queued_bytes_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_fair_queued_bytes_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_fair_standing_flows_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_fair_standing_flows_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_fair_backpressured_flows_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_fair_backpressured_flows_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn add_me_fair_scheduler_rounds_total(&self, value: u64) { - if self.telemetry_me_allows_normal() && value > 0 { - self.me_fair_scheduler_rounds_total - .fetch_add(value, Ordering::Relaxed); - } - } - pub fn add_me_fair_deficit_grants_total(&self, value: u64) { - if self.telemetry_me_allows_normal() && value > 0 { - self.me_fair_deficit_grants_total - .fetch_add(value, Ordering::Relaxed); - } - } - pub fn add_me_fair_deficit_skips_total(&self, value: u64) { - if self.telemetry_me_allows_normal() && value > 0 { - self.me_fair_deficit_skips_total - .fetch_add(value, Ordering::Relaxed); - } - } - pub fn add_me_fair_enqueue_rejects_total(&self, value: u64) { - if self.telemetry_me_allows_normal() && value > 0 { - self.me_fair_enqueue_rejects_total - .fetch_add(value, Ordering::Relaxed); - } - } - pub fn add_me_fair_shed_drops_total(&self, value: u64) { - if self.telemetry_me_allows_normal() && value > 0 { - self.me_fair_shed_drops_total - .fetch_add(value, Ordering::Relaxed); - } - } - pub fn add_me_fair_penalties_total(&self, value: u64) { - if self.telemetry_me_allows_normal() && value > 0 { - self.me_fair_penalties_total - .fetch_add(value, Ordering::Relaxed); - } - } - pub fn add_me_fair_downstream_stalls_total(&self, value: u64) { - if self.telemetry_me_allows_normal() && value > 0 { - self.me_fair_downstream_stalls_total - .fetch_add(value, Ordering::Relaxed); - } - } - pub fn increment_me_d2c_batches_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_d2c_batches_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn add_me_d2c_batch_frames_total(&self, frames: u64) { - if self.telemetry_me_allows_normal() { - self.me_d2c_batch_frames_total - .fetch_add(frames, Ordering::Relaxed); - } - } - pub fn add_me_d2c_batch_bytes_total(&self, bytes: u64) { - if self.telemetry_me_allows_normal() { - self.me_d2c_batch_bytes_total - .fetch_add(bytes, Ordering::Relaxed); - } - } - pub fn increment_me_d2c_flush_reason(&self, reason: MeD2cFlushReason) { - if !self.telemetry_me_allows_normal() { - return; - } - match reason { - MeD2cFlushReason::QueueDrain => { - self.me_d2c_flush_reason_queue_drain_total - .fetch_add(1, Ordering::Relaxed); - } - MeD2cFlushReason::BatchFrames => { - self.me_d2c_flush_reason_batch_frames_total - .fetch_add(1, Ordering::Relaxed); - } - MeD2cFlushReason::BatchBytes => { - self.me_d2c_flush_reason_batch_bytes_total - .fetch_add(1, Ordering::Relaxed); - } - MeD2cFlushReason::MaxDelay => { - self.me_d2c_flush_reason_max_delay_total - .fetch_add(1, Ordering::Relaxed); - } - MeD2cFlushReason::AckImmediate => { - self.me_d2c_flush_reason_ack_immediate_total - .fetch_add(1, Ordering::Relaxed); - } - MeD2cFlushReason::Close => { - self.me_d2c_flush_reason_close_total - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_d2c_data_frames_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_d2c_data_frames_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_d2c_ack_frames_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_d2c_ack_frames_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn add_me_d2c_payload_bytes_total(&self, bytes: u64) { - if self.telemetry_me_allows_normal() { - self.me_d2c_payload_bytes_total - .fetch_add(bytes, Ordering::Relaxed); - } - } - pub fn increment_me_d2c_write_mode(&self, mode: MeD2cWriteMode) { - if !self.telemetry_me_allows_normal() { - return; - } - match mode { - MeD2cWriteMode::Coalesced => { - self.me_d2c_write_mode_coalesced_total - .fetch_add(1, Ordering::Relaxed); - } - MeD2cWriteMode::Split => { - self.me_d2c_write_mode_split_total - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_d2c_quota_reject_total(&self, stage: MeD2cQuotaRejectStage) { - if !self.telemetry_me_allows_normal() { - return; - } - match stage { - MeD2cQuotaRejectStage::PreWrite => { - self.me_d2c_quota_reject_pre_write_total - .fetch_add(1, Ordering::Relaxed); - } - MeD2cQuotaRejectStage::PostWrite => { - self.me_d2c_quota_reject_post_write_total - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn observe_me_d2c_frame_buf_shrink(&self, bytes_freed: u64) { - if !self.telemetry_me_allows_normal() { - return; - } - self.me_d2c_frame_buf_shrink_total - .fetch_add(1, Ordering::Relaxed); - self.me_d2c_frame_buf_shrink_bytes_total - .fetch_add(bytes_freed, Ordering::Relaxed); - } - pub fn observe_me_d2c_batch_frames(&self, frames: u64) { - if !self.telemetry_me_allows_debug() { - return; - } - match frames { - 0 => {} - 1 => { - self.me_d2c_batch_frames_bucket_1 - .fetch_add(1, Ordering::Relaxed); - } - 2..=4 => { - self.me_d2c_batch_frames_bucket_2_4 - .fetch_add(1, Ordering::Relaxed); - } - 5..=8 => { - self.me_d2c_batch_frames_bucket_5_8 - .fetch_add(1, Ordering::Relaxed); - } - 9..=16 => { - self.me_d2c_batch_frames_bucket_9_16 - .fetch_add(1, Ordering::Relaxed); - } - 17..=32 => { - self.me_d2c_batch_frames_bucket_17_32 - .fetch_add(1, Ordering::Relaxed); - } - _ => { - self.me_d2c_batch_frames_bucket_gt_32 - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn observe_me_d2c_batch_bytes(&self, bytes: u64) { - if !self.telemetry_me_allows_debug() { - return; - } - match bytes { - 0..=1024 => { - self.me_d2c_batch_bytes_bucket_0_1k - .fetch_add(1, Ordering::Relaxed); - } - 1025..=4096 => { - self.me_d2c_batch_bytes_bucket_1k_4k - .fetch_add(1, Ordering::Relaxed); - } - 4097..=16_384 => { - self.me_d2c_batch_bytes_bucket_4k_16k - .fetch_add(1, Ordering::Relaxed); - } - 16_385..=65_536 => { - self.me_d2c_batch_bytes_bucket_16k_64k - .fetch_add(1, Ordering::Relaxed); - } - 65_537..=131_072 => { - self.me_d2c_batch_bytes_bucket_64k_128k - .fetch_add(1, Ordering::Relaxed); - } - _ => { - self.me_d2c_batch_bytes_bucket_gt_128k - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn observe_me_d2c_flush_duration_us(&self, duration_us: u64) { - if !self.telemetry_me_allows_debug() { - return; - } - match duration_us { - 0..=50 => { - self.me_d2c_flush_duration_us_bucket_0_50 - .fetch_add(1, Ordering::Relaxed); - } - 51..=200 => { - self.me_d2c_flush_duration_us_bucket_51_200 - .fetch_add(1, Ordering::Relaxed); - } - 201..=1000 => { - self.me_d2c_flush_duration_us_bucket_201_1000 - .fetch_add(1, Ordering::Relaxed); - } - 1001..=5000 => { - self.me_d2c_flush_duration_us_bucket_1001_5000 - .fetch_add(1, Ordering::Relaxed); - } - 5001..=20_000 => { - self.me_d2c_flush_duration_us_bucket_5001_20000 - .fetch_add(1, Ordering::Relaxed); - } - _ => { - self.me_d2c_flush_duration_us_bucket_gt_20000 - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_d2c_batch_timeout_armed_total(&self) { - if self.telemetry_me_allows_debug() { - self.me_d2c_batch_timeout_armed_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_d2c_batch_timeout_fired_total(&self) { - if self.telemetry_me_allows_debug() { - self.me_d2c_batch_timeout_fired_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_pick_success_try_total(&self, mode: MeWriterPickMode) { - if !self.telemetry_me_allows_normal() { - return; - } - match mode { - MeWriterPickMode::SortedRr => { - self.me_writer_pick_sorted_rr_success_try_total - .fetch_add(1, Ordering::Relaxed); - } - MeWriterPickMode::P2c => { - self.me_writer_pick_p2c_success_try_total - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_writer_pick_success_fallback_total(&self, mode: MeWriterPickMode) { - if !self.telemetry_me_allows_normal() { - return; - } - match mode { - MeWriterPickMode::SortedRr => { - self.me_writer_pick_sorted_rr_success_fallback_total - .fetch_add(1, Ordering::Relaxed); - } - MeWriterPickMode::P2c => { - self.me_writer_pick_p2c_success_fallback_total - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_writer_pick_full_total(&self, mode: MeWriterPickMode) { - if !self.telemetry_me_allows_normal() { - return; - } - match mode { - MeWriterPickMode::SortedRr => { - self.me_writer_pick_sorted_rr_full_total - .fetch_add(1, Ordering::Relaxed); - } - MeWriterPickMode::P2c => { - self.me_writer_pick_p2c_full_total - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_writer_pick_closed_total(&self, mode: MeWriterPickMode) { - if !self.telemetry_me_allows_normal() { - return; - } - match mode { - MeWriterPickMode::SortedRr => { - self.me_writer_pick_sorted_rr_closed_total - .fetch_add(1, Ordering::Relaxed); - } - MeWriterPickMode::P2c => { - self.me_writer_pick_p2c_closed_total - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_writer_pick_no_candidate_total(&self, mode: MeWriterPickMode) { - if !self.telemetry_me_allows_normal() { - return; - } - match mode { - MeWriterPickMode::SortedRr => { - self.me_writer_pick_sorted_rr_no_candidate_total - .fetch_add(1, Ordering::Relaxed); - } - MeWriterPickMode::P2c => { - self.me_writer_pick_p2c_no_candidate_total - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_me_writer_pick_blocking_fallback_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_pick_blocking_fallback_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_pick_mode_switch_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_pick_mode_switch_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_socks_kdf_strict_reject(&self) { - if self.telemetry_me_allows_normal() { - self.me_socks_kdf_strict_reject - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_socks_kdf_compat_fallback(&self) { - if self.telemetry_me_allows_debug() { - self.me_socks_kdf_compat_fallback - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_secure_padding_invalid(&self) { - if self.telemetry_me_allows_normal() { - self.secure_padding_invalid.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_desync_total(&self) { - if self.telemetry_me_allows_normal() { - self.desync_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_desync_full_logged(&self) { - if self.telemetry_me_allows_normal() { - self.desync_full_logged.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_desync_suppressed(&self) { - if self.telemetry_me_allows_normal() { - self.desync_suppressed.fetch_add(1, Ordering::Relaxed); - } - } - pub fn observe_desync_frames_ok(&self, frames_ok: u64) { - if !self.telemetry_me_allows_normal() { - return; - } - match frames_ok { - 0 => { - self.desync_frames_bucket_0.fetch_add(1, Ordering::Relaxed); - } - 1..=2 => { - self.desync_frames_bucket_1_2 - .fetch_add(1, Ordering::Relaxed); - } - 3..=10 => { - self.desync_frames_bucket_3_10 - .fetch_add(1, Ordering::Relaxed); - } - _ => { - self.desync_frames_bucket_gt_10 - .fetch_add(1, Ordering::Relaxed); - } - } - } - pub fn increment_pool_swap_total(&self) { - if self.telemetry_me_allows_normal() { - self.pool_swap_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_pool_drain_active(&self) { - if self.telemetry_me_allows_debug() { - self.pool_drain_active.fetch_add(1, Ordering::Relaxed); - } - } - pub fn decrement_pool_drain_active(&self) { - if !self.telemetry_me_allows_debug() { - return; - } - let mut current = self.pool_drain_active.load(Ordering::Relaxed); - loop { - if current == 0 { - break; - } - match self.pool_drain_active.compare_exchange_weak( - current, - current - 1, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(actual) => current = actual, - } - } - } - pub fn increment_pool_force_close_total(&self) { - if self.telemetry_me_allows_normal() { - self.pool_force_close_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_pool_stale_pick_total(&self) { - if self.telemetry_me_allows_normal() { - self.pool_stale_pick_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_removed_total(&self) { - if self.telemetry_me_allows_debug() { - self.me_writer_removed_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_removed_unexpected_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_removed_unexpected_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_refill_triggered_total(&self) { - if self.telemetry_me_allows_debug() { - self.me_refill_triggered_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_refill_skipped_inflight_total(&self) { - if self.telemetry_me_allows_debug() { - self.me_refill_skipped_inflight_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_refill_failed_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_refill_failed_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_restored_same_endpoint_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_restored_same_endpoint_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_restored_fallback_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_restored_fallback_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_no_writer_failfast_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_no_writer_failfast_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_hybrid_timeout_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_hybrid_timeout_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_async_recovery_trigger_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_async_recovery_trigger_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_inline_recovery_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_inline_recovery_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_ip_reservation_rollback_tcp_limit_total(&self) { - if self.telemetry_core_enabled() { - self.ip_reservation_rollback_tcp_limit_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_ip_reservation_rollback_quota_limit_total(&self) { - if self.telemetry_core_enabled() { - self.ip_reservation_rollback_quota_limit_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn add_quota_refund_bytes_total(&self, bytes: u64) { - if self.telemetry_core_enabled() { - self.quota_refund_bytes_total - .fetch_add(bytes, Ordering::Relaxed); - } - } - pub fn increment_quota_contention_total(&self) { - if self.telemetry_core_enabled() { - self.quota_contention_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_quota_contention_timeout_total(&self) { - if self.telemetry_core_enabled() { - self.quota_contention_timeout_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_quota_acquire_cancelled_total(&self) { - if self.telemetry_core_enabled() { - self.quota_acquire_cancelled_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) { - if self.telemetry_core_enabled() { - self.quota_write_fail_bytes_total - .fetch_add(bytes, Ordering::Relaxed); - } - } - pub fn increment_quota_write_fail_events_total(&self) { - if self.telemetry_core_enabled() { - self.quota_write_fail_events_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_child_join_timeout_total(&self) { - if self.telemetry_core_enabled() { - self.me_child_join_timeout_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_child_abort_total(&self) { - if self.telemetry_core_enabled() { - self.me_child_abort_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn observe_flow_wait_middle_rate_limit_ms(&self, wait_ms: u64) { - if self.telemetry_core_enabled() { - self.flow_wait_middle_rate_limit_total - .fetch_add(1, Ordering::Relaxed); - self.flow_wait_middle_rate_limit_ms_total - .fetch_add(wait_ms, Ordering::Relaxed); - } - } - pub fn increment_flow_wait_middle_rate_limit_cancelled_total(&self) { - if self.telemetry_core_enabled() { - self.flow_wait_middle_rate_limit_cancelled_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_session_drop_fallback_total(&self) { - if self.telemetry_core_enabled() { - self.session_drop_fallback_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_endpoint_quarantine_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_endpoint_quarantine_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_endpoint_quarantine_unexpected_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_endpoint_quarantine_unexpected_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_endpoint_quarantine_draining_suppressed_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_endpoint_quarantine_draining_suppressed_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_kdf_drift_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_kdf_drift_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_kdf_port_only_drift_total(&self) { - if self.telemetry_me_allows_debug() { - self.me_kdf_port_only_drift_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_hardswap_pending_reuse_total(&self) { - if self.telemetry_me_allows_debug() { - self.me_hardswap_pending_reuse_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_hardswap_pending_ttl_expired_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_hardswap_pending_ttl_expired_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_single_endpoint_outage_enter_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_single_endpoint_outage_enter_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_single_endpoint_outage_exit_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_single_endpoint_outage_exit_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_single_endpoint_outage_reconnect_attempt_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_single_endpoint_outage_reconnect_attempt_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_single_endpoint_outage_reconnect_success_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_single_endpoint_outage_reconnect_success_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_single_endpoint_quarantine_bypass_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_single_endpoint_quarantine_bypass_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_single_endpoint_shadow_rotate_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_single_endpoint_shadow_rotate_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_single_endpoint_shadow_rotate_skipped_quarantine_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_single_endpoint_shadow_rotate_skipped_quarantine_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_floor_mode_switch_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_floor_mode_switch_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_floor_mode_switch_static_to_adaptive_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_floor_mode_switch_static_to_adaptive_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_floor_mode_switch_adaptive_to_static_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_floor_mode_switch_adaptive_to_static_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn set_me_floor_cpu_cores_detected_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_cpu_cores_detected_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_floor_cpu_cores_effective_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_cpu_cores_effective_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_floor_global_cap_raw_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_global_cap_raw_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_floor_global_cap_effective_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_global_cap_effective_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_floor_target_writers_total_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_target_writers_total_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_floor_active_cap_configured_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_active_cap_configured_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_floor_active_cap_effective_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_active_cap_effective_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_floor_warm_cap_configured_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_warm_cap_configured_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_floor_warm_cap_effective_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_floor_warm_cap_effective_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_writers_active_current_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_writers_active_current_gauge - .store(value, Ordering::Relaxed); - } - } - pub fn set_me_writers_warm_current_gauge(&self, value: u64) { - if self.telemetry_me_allows_normal() { - self.me_writers_warm_current_gauge - .store(value, Ordering::Relaxed); - } - } - - pub fn set_buffer_pool_gauges(&self, pooled: usize, allocated: usize, in_use: usize) { - if self.telemetry_me_allows_normal() { - self.buffer_pool_pooled_gauge - .store(pooled as u64, Ordering::Relaxed); - self.buffer_pool_allocated_gauge - .store(allocated as u64, Ordering::Relaxed); - self.buffer_pool_in_use_gauge - .store(in_use as u64, Ordering::Relaxed); - } - } - - pub fn increment_me_c2me_send_full_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_c2me_send_full_total.fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_me_c2me_send_high_water_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_c2me_send_high_water_total - .fetch_add(1, Ordering::Relaxed); - } - } - - pub fn increment_me_c2me_send_timeout_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_c2me_send_timeout_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_floor_cap_block_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_floor_cap_block_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_floor_swap_idle_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_floor_swap_idle_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_floor_swap_idle_failed_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_floor_swap_idle_failed_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn get_connects_all(&self) -> u64 { - self.connects_all.load(Ordering::Relaxed) - } - pub fn get_connects_bad(&self) -> u64 { - self.connects_bad.load(Ordering::Relaxed) - } - - pub fn get_connects_bad_class_counts(&self) -> Vec<(String, u64)> { - let mut out: Vec<(String, u64)> = self - .connects_bad_classes - .iter() - .map(|entry| { - ( - entry.key().to_string(), - entry.value().load(Ordering::Relaxed), - ) - }) - .collect(); - out.sort_by(|a, b| a.0.cmp(&b.0)); - out - } - - pub fn get_handshake_failure_class_counts(&self) -> Vec<(String, u64)> { - let mut out: Vec<(String, u64)> = self - .handshake_failure_classes - .iter() - .map(|entry| { - ( - entry.key().to_string(), - entry.value().load(Ordering::Relaxed), - ) - }) - .collect(); - out.sort_by(|a, b| a.0.cmp(&b.0)); - out - } - - pub fn get_accept_permit_timeout_total(&self) -> u64 { - self.accept_permit_timeout_total.load(Ordering::Relaxed) - } - pub fn get_current_connections_direct(&self) -> u64 { - self.current_connections_direct.load(Ordering::Relaxed) - } - pub fn get_current_connections_me(&self) -> u64 { - self.current_connections_me.load(Ordering::Relaxed) - } - pub fn get_current_connections_total(&self) -> u64 { - self.get_current_connections_direct() - .saturating_add(self.get_current_connections_me()) - } - pub fn get_conntrack_control_enabled(&self) -> bool { - self.conntrack_control_enabled_gauge.load(Ordering::Relaxed) - } - pub fn get_conntrack_control_available(&self) -> bool { - self.conntrack_control_available_gauge - .load(Ordering::Relaxed) - } - pub fn get_conntrack_pressure_active(&self) -> bool { - self.conntrack_pressure_active_gauge.load(Ordering::Relaxed) - } - pub fn get_conntrack_event_queue_depth(&self) -> u64 { - self.conntrack_event_queue_depth_gauge - .load(Ordering::Relaxed) - } - pub fn get_conntrack_rule_apply_ok(&self) -> bool { - self.conntrack_rule_apply_ok_gauge.load(Ordering::Relaxed) - } - pub fn get_conntrack_delete_attempt_total(&self) -> u64 { - self.conntrack_delete_attempt_total.load(Ordering::Relaxed) - } - pub fn get_conntrack_delete_success_total(&self) -> u64 { - self.conntrack_delete_success_total.load(Ordering::Relaxed) - } - pub fn get_conntrack_delete_not_found_total(&self) -> u64 { - self.conntrack_delete_not_found_total - .load(Ordering::Relaxed) - } - pub fn get_conntrack_delete_error_total(&self) -> u64 { - self.conntrack_delete_error_total.load(Ordering::Relaxed) - } - pub fn get_conntrack_close_event_drop_total(&self) -> u64 { - self.conntrack_close_event_drop_total - .load(Ordering::Relaxed) - } - pub fn get_me_keepalive_sent(&self) -> u64 { - self.me_keepalive_sent.load(Ordering::Relaxed) - } - pub fn get_me_keepalive_failed(&self) -> u64 { - self.me_keepalive_failed.load(Ordering::Relaxed) - } - pub fn get_me_keepalive_pong(&self) -> u64 { - self.me_keepalive_pong.load(Ordering::Relaxed) - } - pub fn get_me_keepalive_timeout(&self) -> u64 { - self.me_keepalive_timeout.load(Ordering::Relaxed) - } - pub fn get_me_rpc_proxy_req_signal_sent_total(&self) -> u64 { - self.me_rpc_proxy_req_signal_sent_total - .load(Ordering::Relaxed) - } - pub fn get_me_rpc_proxy_req_signal_failed_total(&self) -> u64 { - self.me_rpc_proxy_req_signal_failed_total - .load(Ordering::Relaxed) - } - pub fn get_me_rpc_proxy_req_signal_skipped_no_meta_total(&self) -> u64 { - self.me_rpc_proxy_req_signal_skipped_no_meta_total - .load(Ordering::Relaxed) - } - pub fn get_me_rpc_proxy_req_signal_response_total(&self) -> u64 { - self.me_rpc_proxy_req_signal_response_total - .load(Ordering::Relaxed) - } - pub fn get_me_rpc_proxy_req_signal_close_sent_total(&self) -> u64 { - self.me_rpc_proxy_req_signal_close_sent_total - .load(Ordering::Relaxed) - } - pub fn get_me_reconnect_attempts(&self) -> u64 { - self.me_reconnect_attempts.load(Ordering::Relaxed) - } - pub fn get_me_reconnect_success(&self) -> u64 { - self.me_reconnect_success.load(Ordering::Relaxed) - } - pub fn get_me_handshake_reject_total(&self) -> u64 { - self.me_handshake_reject_total.load(Ordering::Relaxed) - } - pub fn get_me_reader_eof_total(&self) -> u64 { - self.me_reader_eof_total.load(Ordering::Relaxed) - } - pub fn get_me_idle_close_by_peer_total(&self) -> u64 { - self.me_idle_close_by_peer_total.load(Ordering::Relaxed) - } - pub fn get_relay_idle_soft_mark_total(&self) -> u64 { - self.relay_idle_soft_mark_total.load(Ordering::Relaxed) - } - pub fn get_relay_idle_hard_close_total(&self) -> u64 { - self.relay_idle_hard_close_total.load(Ordering::Relaxed) - } - pub fn get_relay_pressure_evict_total(&self) -> u64 { - self.relay_pressure_evict_total.load(Ordering::Relaxed) - } - pub fn get_relay_protocol_desync_close_total(&self) -> u64 { - self.relay_protocol_desync_close_total - .load(Ordering::Relaxed) - } - pub fn get_me_crc_mismatch(&self) -> u64 { - self.me_crc_mismatch.load(Ordering::Relaxed) - } - pub fn get_me_seq_mismatch(&self) -> u64 { - self.me_seq_mismatch.load(Ordering::Relaxed) - } - pub fn get_me_endpoint_quarantine_total(&self) -> u64 { - self.me_endpoint_quarantine_total.load(Ordering::Relaxed) - } - pub fn get_me_endpoint_quarantine_unexpected_total(&self) -> u64 { - self.me_endpoint_quarantine_unexpected_total - .load(Ordering::Relaxed) - } - pub fn get_me_endpoint_quarantine_draining_suppressed_total(&self) -> u64 { - self.me_endpoint_quarantine_draining_suppressed_total - .load(Ordering::Relaxed) - } - pub fn get_me_kdf_drift_total(&self) -> u64 { - self.me_kdf_drift_total.load(Ordering::Relaxed) - } - pub fn get_me_kdf_port_only_drift_total(&self) -> u64 { - self.me_kdf_port_only_drift_total.load(Ordering::Relaxed) - } - pub fn get_me_hardswap_pending_reuse_total(&self) -> u64 { - self.me_hardswap_pending_reuse_total.load(Ordering::Relaxed) - } - pub fn get_me_hardswap_pending_ttl_expired_total(&self) -> u64 { - self.me_hardswap_pending_ttl_expired_total - .load(Ordering::Relaxed) - } - pub fn get_me_single_endpoint_outage_enter_total(&self) -> u64 { - self.me_single_endpoint_outage_enter_total - .load(Ordering::Relaxed) - } - pub fn get_me_single_endpoint_outage_exit_total(&self) -> u64 { - self.me_single_endpoint_outage_exit_total - .load(Ordering::Relaxed) - } - pub fn get_me_single_endpoint_outage_reconnect_attempt_total(&self) -> u64 { - self.me_single_endpoint_outage_reconnect_attempt_total - .load(Ordering::Relaxed) - } - pub fn get_me_single_endpoint_outage_reconnect_success_total(&self) -> u64 { - self.me_single_endpoint_outage_reconnect_success_total - .load(Ordering::Relaxed) - } - pub fn get_me_single_endpoint_quarantine_bypass_total(&self) -> u64 { - self.me_single_endpoint_quarantine_bypass_total - .load(Ordering::Relaxed) - } - pub fn get_me_single_endpoint_shadow_rotate_total(&self) -> u64 { - self.me_single_endpoint_shadow_rotate_total - .load(Ordering::Relaxed) - } - pub fn get_me_single_endpoint_shadow_rotate_skipped_quarantine_total(&self) -> u64 { - self.me_single_endpoint_shadow_rotate_skipped_quarantine_total - .load(Ordering::Relaxed) - } - pub fn get_me_floor_mode_switch_total(&self) -> u64 { - self.me_floor_mode_switch_total.load(Ordering::Relaxed) - } - pub fn get_me_floor_mode_switch_static_to_adaptive_total(&self) -> u64 { - self.me_floor_mode_switch_static_to_adaptive_total - .load(Ordering::Relaxed) - } - pub fn get_me_floor_mode_switch_adaptive_to_static_total(&self) -> u64 { - self.me_floor_mode_switch_adaptive_to_static_total - .load(Ordering::Relaxed) - } - pub fn get_me_floor_cpu_cores_detected_gauge(&self) -> u64 { - self.me_floor_cpu_cores_detected_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_floor_cpu_cores_effective_gauge(&self) -> u64 { - self.me_floor_cpu_cores_effective_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_floor_global_cap_raw_gauge(&self) -> u64 { - self.me_floor_global_cap_raw_gauge.load(Ordering::Relaxed) - } - pub fn get_me_floor_global_cap_effective_gauge(&self) -> u64 { - self.me_floor_global_cap_effective_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_floor_target_writers_total_gauge(&self) -> u64 { - self.me_floor_target_writers_total_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_floor_active_cap_configured_gauge(&self) -> u64 { - self.me_floor_active_cap_configured_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_floor_active_cap_effective_gauge(&self) -> u64 { - self.me_floor_active_cap_effective_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_floor_warm_cap_configured_gauge(&self) -> u64 { - self.me_floor_warm_cap_configured_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_floor_warm_cap_effective_gauge(&self) -> u64 { - self.me_floor_warm_cap_effective_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_writers_active_current_gauge(&self) -> u64 { - self.me_writers_active_current_gauge.load(Ordering::Relaxed) - } - pub fn get_me_writers_warm_current_gauge(&self) -> u64 { - self.me_writers_warm_current_gauge.load(Ordering::Relaxed) - } - pub fn get_me_floor_cap_block_total(&self) -> u64 { - self.me_floor_cap_block_total.load(Ordering::Relaxed) - } - pub fn get_me_floor_swap_idle_total(&self) -> u64 { - self.me_floor_swap_idle_total.load(Ordering::Relaxed) - } - pub fn get_me_floor_swap_idle_failed_total(&self) -> u64 { - self.me_floor_swap_idle_failed_total.load(Ordering::Relaxed) - } - pub fn get_me_handshake_error_code_counts(&self) -> Vec<(i32, u64)> { - let mut out: Vec<(i32, u64)> = self - .me_handshake_error_codes - .iter() - .map(|entry| (*entry.key(), entry.value().load(Ordering::Relaxed))) - .collect(); - out.sort_by_key(|(code, _)| *code); - out - } - pub fn get_me_route_drop_no_conn(&self) -> u64 { - self.me_route_drop_no_conn.load(Ordering::Relaxed) - } - pub fn get_me_route_drop_channel_closed(&self) -> u64 { - self.me_route_drop_channel_closed.load(Ordering::Relaxed) - } - pub fn get_me_route_drop_queue_full(&self) -> u64 { - self.me_route_drop_queue_full.load(Ordering::Relaxed) - } - pub fn get_me_route_drop_queue_full_base(&self) -> u64 { - self.me_route_drop_queue_full_base.load(Ordering::Relaxed) - } - pub fn get_me_route_drop_queue_full_high(&self) -> u64 { - self.me_route_drop_queue_full_high.load(Ordering::Relaxed) - } - pub fn get_me_fair_pressure_state_gauge(&self) -> u64 { - self.me_fair_pressure_state_gauge.load(Ordering::Relaxed) - } - pub fn get_me_fair_active_flows_gauge(&self) -> u64 { - self.me_fair_active_flows_gauge.load(Ordering::Relaxed) - } - pub fn get_me_fair_queued_bytes_gauge(&self) -> u64 { - self.me_fair_queued_bytes_gauge.load(Ordering::Relaxed) - } - pub fn get_me_fair_standing_flows_gauge(&self) -> u64 { - self.me_fair_standing_flows_gauge.load(Ordering::Relaxed) - } - pub fn get_me_fair_backpressured_flows_gauge(&self) -> u64 { - self.me_fair_backpressured_flows_gauge - .load(Ordering::Relaxed) - } - pub fn get_me_fair_scheduler_rounds_total(&self) -> u64 { - self.me_fair_scheduler_rounds_total.load(Ordering::Relaxed) - } - pub fn get_me_fair_deficit_grants_total(&self) -> u64 { - self.me_fair_deficit_grants_total.load(Ordering::Relaxed) - } - pub fn get_me_fair_deficit_skips_total(&self) -> u64 { - self.me_fair_deficit_skips_total.load(Ordering::Relaxed) - } - pub fn get_me_fair_enqueue_rejects_total(&self) -> u64 { - self.me_fair_enqueue_rejects_total.load(Ordering::Relaxed) - } - pub fn get_me_fair_shed_drops_total(&self) -> u64 { - self.me_fair_shed_drops_total.load(Ordering::Relaxed) - } - pub fn get_me_fair_penalties_total(&self) -> u64 { - self.me_fair_penalties_total.load(Ordering::Relaxed) - } - pub fn get_me_fair_downstream_stalls_total(&self) -> u64 { - self.me_fair_downstream_stalls_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batches_total(&self) -> u64 { - self.me_d2c_batches_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_frames_total(&self) -> u64 { - self.me_d2c_batch_frames_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_bytes_total(&self) -> u64 { - self.me_d2c_batch_bytes_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_reason_queue_drain_total(&self) -> u64 { - self.me_d2c_flush_reason_queue_drain_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_reason_batch_frames_total(&self) -> u64 { - self.me_d2c_flush_reason_batch_frames_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_reason_batch_bytes_total(&self) -> u64 { - self.me_d2c_flush_reason_batch_bytes_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_reason_max_delay_total(&self) -> u64 { - self.me_d2c_flush_reason_max_delay_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_reason_ack_immediate_total(&self) -> u64 { - self.me_d2c_flush_reason_ack_immediate_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_reason_close_total(&self) -> u64 { - self.me_d2c_flush_reason_close_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_data_frames_total(&self) -> u64 { - self.me_d2c_data_frames_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_ack_frames_total(&self) -> u64 { - self.me_d2c_ack_frames_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_payload_bytes_total(&self) -> u64 { - self.me_d2c_payload_bytes_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_write_mode_coalesced_total(&self) -> u64 { - self.me_d2c_write_mode_coalesced_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_write_mode_split_total(&self) -> u64 { - self.me_d2c_write_mode_split_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_quota_reject_pre_write_total(&self) -> u64 { - self.me_d2c_quota_reject_pre_write_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_quota_reject_post_write_total(&self) -> u64 { - self.me_d2c_quota_reject_post_write_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_frame_buf_shrink_total(&self) -> u64 { - self.me_d2c_frame_buf_shrink_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_frame_buf_shrink_bytes_total(&self) -> u64 { - self.me_d2c_frame_buf_shrink_bytes_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_frames_bucket_1(&self) -> u64 { - self.me_d2c_batch_frames_bucket_1.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_frames_bucket_2_4(&self) -> u64 { - self.me_d2c_batch_frames_bucket_2_4.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_frames_bucket_5_8(&self) -> u64 { - self.me_d2c_batch_frames_bucket_5_8.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_frames_bucket_9_16(&self) -> u64 { - self.me_d2c_batch_frames_bucket_9_16.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_frames_bucket_17_32(&self) -> u64 { - self.me_d2c_batch_frames_bucket_17_32 - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_frames_bucket_gt_32(&self) -> u64 { - self.me_d2c_batch_frames_bucket_gt_32 - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_bytes_bucket_0_1k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_0_1k.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_bytes_bucket_1k_4k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_4k_16k - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_16k_64k - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_bytes_bucket_64k_128k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_64k_128k - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_bytes_bucket_gt_128k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_gt_128k - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_duration_us_bucket_0_50(&self) -> u64 { - self.me_d2c_flush_duration_us_bucket_0_50 - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_duration_us_bucket_51_200(&self) -> u64 { - self.me_d2c_flush_duration_us_bucket_51_200 - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_duration_us_bucket_201_1000(&self) -> u64 { - self.me_d2c_flush_duration_us_bucket_201_1000 - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_duration_us_bucket_1001_5000(&self) -> u64 { - self.me_d2c_flush_duration_us_bucket_1001_5000 - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_duration_us_bucket_5001_20000(&self) -> u64 { - self.me_d2c_flush_duration_us_bucket_5001_20000 - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_flush_duration_us_bucket_gt_20000(&self) -> u64 { - self.me_d2c_flush_duration_us_bucket_gt_20000 - .load(Ordering::Relaxed) - } - - pub fn get_buffer_pool_pooled_gauge(&self) -> u64 { - self.buffer_pool_pooled_gauge.load(Ordering::Relaxed) - } - - pub fn get_buffer_pool_allocated_gauge(&self) -> u64 { - self.buffer_pool_allocated_gauge.load(Ordering::Relaxed) - } - - pub fn get_buffer_pool_in_use_gauge(&self) -> u64 { - self.buffer_pool_in_use_gauge.load(Ordering::Relaxed) - } - - pub fn get_me_c2me_send_full_total(&self) -> u64 { - self.me_c2me_send_full_total.load(Ordering::Relaxed) - } - - pub fn get_me_c2me_send_high_water_total(&self) -> u64 { - self.me_c2me_send_high_water_total.load(Ordering::Relaxed) - } - - pub fn get_me_c2me_send_timeout_total(&self) -> u64 { - self.me_c2me_send_timeout_total.load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_timeout_armed_total(&self) -> u64 { - self.me_d2c_batch_timeout_armed_total - .load(Ordering::Relaxed) - } - pub fn get_me_d2c_batch_timeout_fired_total(&self) -> u64 { - self.me_d2c_batch_timeout_fired_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_sorted_rr_success_try_total(&self) -> u64 { - self.me_writer_pick_sorted_rr_success_try_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_sorted_rr_success_fallback_total(&self) -> u64 { - self.me_writer_pick_sorted_rr_success_fallback_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_sorted_rr_full_total(&self) -> u64 { - self.me_writer_pick_sorted_rr_full_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_sorted_rr_closed_total(&self) -> u64 { - self.me_writer_pick_sorted_rr_closed_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_sorted_rr_no_candidate_total(&self) -> u64 { - self.me_writer_pick_sorted_rr_no_candidate_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_p2c_success_try_total(&self) -> u64 { - self.me_writer_pick_p2c_success_try_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_p2c_success_fallback_total(&self) -> u64 { - self.me_writer_pick_p2c_success_fallback_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_p2c_full_total(&self) -> u64 { - self.me_writer_pick_p2c_full_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_p2c_closed_total(&self) -> u64 { - self.me_writer_pick_p2c_closed_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_p2c_no_candidate_total(&self) -> u64 { - self.me_writer_pick_p2c_no_candidate_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_blocking_fallback_total(&self) -> u64 { - self.me_writer_pick_blocking_fallback_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_pick_mode_switch_total(&self) -> u64 { - self.me_writer_pick_mode_switch_total - .load(Ordering::Relaxed) - } - pub fn get_me_socks_kdf_strict_reject(&self) -> u64 { - self.me_socks_kdf_strict_reject.load(Ordering::Relaxed) - } - pub fn get_me_socks_kdf_compat_fallback(&self) -> u64 { - self.me_socks_kdf_compat_fallback.load(Ordering::Relaxed) - } - pub fn get_secure_padding_invalid(&self) -> u64 { - self.secure_padding_invalid.load(Ordering::Relaxed) - } - pub fn get_desync_total(&self) -> u64 { - self.desync_total.load(Ordering::Relaxed) - } - pub fn get_desync_full_logged(&self) -> u64 { - self.desync_full_logged.load(Ordering::Relaxed) - } - pub fn get_desync_suppressed(&self) -> u64 { - self.desync_suppressed.load(Ordering::Relaxed) - } - pub fn get_desync_frames_bucket_0(&self) -> u64 { - self.desync_frames_bucket_0.load(Ordering::Relaxed) - } - pub fn get_desync_frames_bucket_1_2(&self) -> u64 { - self.desync_frames_bucket_1_2.load(Ordering::Relaxed) - } - pub fn get_desync_frames_bucket_3_10(&self) -> u64 { - self.desync_frames_bucket_3_10.load(Ordering::Relaxed) - } - pub fn get_desync_frames_bucket_gt_10(&self) -> u64 { - self.desync_frames_bucket_gt_10.load(Ordering::Relaxed) - } - pub fn get_pool_swap_total(&self) -> u64 { - self.pool_swap_total.load(Ordering::Relaxed) - } - pub fn get_pool_drain_active(&self) -> u64 { - self.pool_drain_active.load(Ordering::Relaxed) - } - pub fn get_pool_force_close_total(&self) -> u64 { - self.pool_force_close_total.load(Ordering::Relaxed) - } - pub fn get_pool_stale_pick_total(&self) -> u64 { - self.pool_stale_pick_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_removed_total(&self) -> u64 { - self.me_writer_removed_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_removed_unexpected_total(&self) -> u64 { - self.me_writer_removed_unexpected_total - .load(Ordering::Relaxed) - } - pub fn get_me_refill_triggered_total(&self) -> u64 { - self.me_refill_triggered_total.load(Ordering::Relaxed) - } - pub fn get_me_refill_skipped_inflight_total(&self) -> u64 { - self.me_refill_skipped_inflight_total - .load(Ordering::Relaxed) - } - pub fn get_me_refill_failed_total(&self) -> u64 { - self.me_refill_failed_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_restored_same_endpoint_total(&self) -> u64 { - self.me_writer_restored_same_endpoint_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_restored_fallback_total(&self) -> u64 { - self.me_writer_restored_fallback_total - .load(Ordering::Relaxed) - } - pub fn get_me_no_writer_failfast_total(&self) -> u64 { - self.me_no_writer_failfast_total.load(Ordering::Relaxed) - } - pub fn get_me_hybrid_timeout_total(&self) -> u64 { - self.me_hybrid_timeout_total.load(Ordering::Relaxed) - } - pub fn get_me_async_recovery_trigger_total(&self) -> u64 { - self.me_async_recovery_trigger_total.load(Ordering::Relaxed) - } - pub fn get_me_inline_recovery_total(&self) -> u64 { - self.me_inline_recovery_total.load(Ordering::Relaxed) - } - pub fn get_ip_reservation_rollback_tcp_limit_total(&self) -> u64 { - self.ip_reservation_rollback_tcp_limit_total - .load(Ordering::Relaxed) - } - pub fn get_ip_reservation_rollback_quota_limit_total(&self) -> u64 { - self.ip_reservation_rollback_quota_limit_total - .load(Ordering::Relaxed) - } - pub fn get_quota_refund_bytes_total(&self) -> u64 { - self.quota_refund_bytes_total.load(Ordering::Relaxed) - } - pub fn get_quota_contention_total(&self) -> u64 { - self.quota_contention_total.load(Ordering::Relaxed) - } - pub fn get_quota_contention_timeout_total(&self) -> u64 { - self.quota_contention_timeout_total.load(Ordering::Relaxed) - } - pub fn get_quota_acquire_cancelled_total(&self) -> u64 { - self.quota_acquire_cancelled_total.load(Ordering::Relaxed) - } - pub fn get_quota_write_fail_bytes_total(&self) -> u64 { - self.quota_write_fail_bytes_total.load(Ordering::Relaxed) - } - pub fn get_quota_write_fail_events_total(&self) -> u64 { - self.quota_write_fail_events_total.load(Ordering::Relaxed) - } - pub fn get_me_child_join_timeout_total(&self) -> u64 { - self.me_child_join_timeout_total.load(Ordering::Relaxed) - } - pub fn get_me_child_abort_total(&self) -> u64 { - self.me_child_abort_total.load(Ordering::Relaxed) - } - pub fn get_flow_wait_middle_rate_limit_total(&self) -> u64 { - self.flow_wait_middle_rate_limit_total - .load(Ordering::Relaxed) - } - pub fn get_flow_wait_middle_rate_limit_cancelled_total(&self) -> u64 { - self.flow_wait_middle_rate_limit_cancelled_total - .load(Ordering::Relaxed) - } - pub fn get_flow_wait_middle_rate_limit_ms_total(&self) -> u64 { - self.flow_wait_middle_rate_limit_ms_total - .load(Ordering::Relaxed) - } - pub fn get_session_drop_fallback_total(&self) -> u64 { - self.session_drop_fallback_total.load(Ordering::Relaxed) - } - - pub fn increment_user_connects(&self, user: &str) { - if !self.telemetry_user_enabled() { - return; - } - let stats = self.get_or_create_user_stats_handle(user); - self.touch_user_stats(stats.as_ref()); - stats.connects.fetch_add(1, Ordering::Relaxed); - } - - pub fn increment_user_curr_connects(&self, user: &str) { - if !self.telemetry_user_enabled() { - return; - } - let stats = self.get_or_create_user_stats_handle(user); - self.touch_user_stats(stats.as_ref()); - stats.curr_connects.fetch_add(1, Ordering::Relaxed); - } - - pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option) -> bool { - if !self.telemetry_user_enabled() { - return true; - } - - let stats = self.get_or_create_user_stats_handle(user); - self.touch_user_stats(stats.as_ref()); - - let counter = &stats.curr_connects; - let mut current = counter.load(Ordering::Relaxed); - loop { - if let Some(max) = limit - && current >= max - { - return false; - } - match counter.compare_exchange_weak( - current, - current.saturating_add(1), - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => return true, - Err(actual) => current = actual, - } - } - } - - pub fn decrement_user_curr_connects(&self, user: &str) { - if let Some(stats) = self.user_stats.get(user) { - self.touch_user_stats(stats.value().as_ref()); - let counter = &stats.curr_connects; - let mut current = counter.load(Ordering::Relaxed); - loop { - if current == 0 { - break; - } - match counter.compare_exchange_weak( - current, - current - 1, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(actual) => current = actual, - } - } - } - } - - pub fn get_user_curr_connects(&self, user: &str) -> u64 { - self.user_stats - .get(user) - .map(|s| s.curr_connects.load(Ordering::Relaxed)) - .unwrap_or(0) - } - - pub fn add_user_octets_from(&self, user: &str, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - let stats = self.get_or_create_user_stats_handle(user); - self.add_user_octets_from_handle(stats.as_ref(), bytes); - } - - pub fn add_user_octets_to(&self, user: &str, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - let stats = self.get_or_create_user_stats_handle(user); - self.add_user_octets_to_handle(stats.as_ref(), bytes); - } - - pub fn increment_user_msgs_from(&self, user: &str) { - if !self.telemetry_user_enabled() { - return; - } - let stats = self.get_or_create_user_stats_handle(user); - self.increment_user_msgs_from_handle(stats.as_ref()); - } - - pub fn increment_user_msgs_to(&self, user: &str) { - if !self.telemetry_user_enabled() { - return; - } - let stats = self.get_or_create_user_stats_handle(user); - self.increment_user_msgs_to_handle(stats.as_ref()); - } - - pub fn get_user_total_octets(&self, user: &str) -> u64 { - self.user_stats - .get(user) - .map(|s| { - s.octets_from_client.load(Ordering::Relaxed) - + s.octets_to_client.load(Ordering::Relaxed) - }) - .unwrap_or(0) - } - - pub fn get_user_quota_used(&self, user: &str) -> u64 { - self.user_stats - .get(user) - .map(|s| s.quota_used.load(Ordering::Relaxed)) - .unwrap_or(0) - } - - pub fn load_user_quota_state(&self, user: &str, used_bytes: u64, last_reset_epoch_secs: u64) { - let stats = self.get_or_create_user_stats_handle(user); - stats.quota_used.store(used_bytes, Ordering::Relaxed); - stats - .quota_last_reset_epoch_secs - .store(last_reset_epoch_secs, Ordering::Relaxed); - } - - pub fn reset_user_quota(&self, user: &str) -> UserQuotaSnapshot { - let stats = self.get_or_create_user_stats_handle(user); - let last_reset_epoch_secs = Self::now_epoch_secs(); - stats.quota_used.store(0, Ordering::Relaxed); - stats - .quota_last_reset_epoch_secs - .store(last_reset_epoch_secs, Ordering::Relaxed); - UserQuotaSnapshot { - used_bytes: 0, - last_reset_epoch_secs, - } - } - - pub fn user_quota_snapshot(&self) -> HashMap { - let mut out = HashMap::new(); - for entry in self.user_stats.iter() { - let stats = entry.value(); - let used_bytes = stats.quota_used.load(Ordering::Relaxed); - let last_reset_epoch_secs = stats.quota_last_reset_epoch_secs.load(Ordering::Relaxed); - if used_bytes == 0 && last_reset_epoch_secs == 0 { - continue; - } - out.insert( - entry.key().clone(), - UserQuotaSnapshot { - used_bytes, - last_reset_epoch_secs, - }, - ); - } - out - } - - pub fn get_handshake_timeouts(&self) -> u64 { - self.handshake_timeouts.load(Ordering::Relaxed) - } - pub fn get_upstream_connect_attempt_total(&self) -> u64 { - self.upstream_connect_attempt_total.load(Ordering::Relaxed) - } - pub fn get_upstream_connect_success_total(&self) -> u64 { - self.upstream_connect_success_total.load(Ordering::Relaxed) - } - pub fn get_upstream_connect_fail_total(&self) -> u64 { - self.upstream_connect_fail_total.load(Ordering::Relaxed) - } - pub fn get_upstream_connect_failfast_hard_error_total(&self) -> u64 { - self.upstream_connect_failfast_hard_error_total - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_attempts_bucket_1(&self) -> u64 { - self.upstream_connect_attempts_bucket_1 - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_attempts_bucket_2(&self) -> u64 { - self.upstream_connect_attempts_bucket_2 - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_attempts_bucket_3_4(&self) -> u64 { - self.upstream_connect_attempts_bucket_3_4 - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_attempts_bucket_gt_4(&self) -> u64 { - self.upstream_connect_attempts_bucket_gt_4 - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_duration_success_bucket_le_100ms(&self) -> u64 { - self.upstream_connect_duration_success_bucket_le_100ms - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_duration_success_bucket_101_500ms(&self) -> u64 { - self.upstream_connect_duration_success_bucket_101_500ms - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_duration_success_bucket_501_1000ms(&self) -> u64 { - self.upstream_connect_duration_success_bucket_501_1000ms - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_duration_success_bucket_gt_1000ms(&self) -> u64 { - self.upstream_connect_duration_success_bucket_gt_1000ms - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_duration_fail_bucket_le_100ms(&self) -> u64 { - self.upstream_connect_duration_fail_bucket_le_100ms - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_duration_fail_bucket_101_500ms(&self) -> u64 { - self.upstream_connect_duration_fail_bucket_101_500ms - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_duration_fail_bucket_501_1000ms(&self) -> u64 { - self.upstream_connect_duration_fail_bucket_501_1000ms - .load(Ordering::Relaxed) - } - pub fn get_upstream_connect_duration_fail_bucket_gt_1000ms(&self) -> u64 { - self.upstream_connect_duration_fail_bucket_gt_1000ms - .load(Ordering::Relaxed) - } - - pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc> { - self.user_stats.iter() - } - - /// Current number of retained per-user stats entries. - pub fn user_stats_len(&self) -> usize { - self.user_stats.len() - } - - pub fn uptime_secs(&self) -> f64 { - self.start_time - .read() - .map(|t| t.elapsed().as_secs_f64()) - .unwrap_or(0.0) - } -} - -// ============= Replay Checker ============= - -pub struct ReplayChecker { - handshake_shards: Vec>, - tls_shards: Vec>, - shard_mask: usize, - window: Duration, - tls_window: Duration, - checks: AtomicU64, - hits: AtomicU64, - additions: AtomicU64, - cleanups: AtomicU64, -} - -struct ReplayEntry { - seen_at: Instant, - seq: u64, -} - -struct ReplayShard { - cache: LruCache, ReplayEntry>, - queue: VecDeque<(Instant, Arc<[u8]>, u64)>, - seq_counter: u64, - capacity: usize, -} - -impl ReplayShard { - fn new(cap: NonZeroUsize) -> Self { - Self { - cache: LruCache::new(cap), - queue: VecDeque::with_capacity(cap.get()), - seq_counter: 0, - capacity: cap.get(), - } - } - - fn next_seq(&mut self) -> u64 { - self.seq_counter += 1; - self.seq_counter - } - - fn cleanup(&mut self, now: Instant, window: Duration) { - if window.is_zero() { - self.cache.clear(); - self.queue.clear(); - return; - } - let cutoff = now.checked_sub(window).unwrap_or(now); - - while let Some((ts, _, _)) = self.queue.front() { - if *ts >= cutoff { - break; - } - self.evict_queue_front(); - } - } - - fn evict_queue_front(&mut self) { - let Some((_, key, queue_seq)) = self.queue.pop_front() else { - return; - }; - - if let Some(entry) = self.cache.peek(key.as_ref()) - && entry.seq == queue_seq - { - self.cache.pop(key.as_ref()); - } - } - - fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool { - if window.is_zero() { - return false; - } - self.cleanup(now, window); - // key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]> - self.cache.get(key).is_some() - } - - fn add(&mut self, key: &[u8], now: Instant, window: Duration) { - if window.is_zero() { - return; - } - self.cleanup(now, window); - if self.cache.peek(key).is_some() { - return; - } - while self.queue.len() >= self.capacity { - self.evict_queue_front(); - } - - let seq = self.next_seq(); - let shared_key: Arc<[u8]> = Arc::from(key); - - self.cache - .put(Arc::clone(&shared_key), ReplayEntry { seen_at: now, seq }); - self.queue.push_back((now, shared_key, seq)); - } - - fn len(&self) -> usize { - self.cache.len() - } -} - -impl ReplayChecker { - pub fn new(total_capacity: usize, window: Duration) -> Self { - const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120); - let num_shards = 64; - let shard_capacity = (total_capacity / num_shards).max(1); - let cap = NonZeroUsize::new(shard_capacity).unwrap(); - - let mut handshake_shards = Vec::with_capacity(num_shards); - let mut tls_shards = Vec::with_capacity(num_shards); - for _ in 0..num_shards { - handshake_shards.push(Mutex::new(ReplayShard::new(cap))); - tls_shards.push(Mutex::new(ReplayShard::new(cap))); - } - - Self { - handshake_shards, - tls_shards, - shard_mask: num_shards - 1, - window, - tls_window: window.max(MIN_TLS_REPLAY_WINDOW), - checks: AtomicU64::new(0), - hits: AtomicU64::new(0), - additions: AtomicU64::new(0), - cleanups: AtomicU64::new(0), - } - } - - fn get_shard_idx(&self, key: &[u8]) -> usize { - let mut hasher = DefaultHasher::new(); - key.hash(&mut hasher); - (hasher.finish() as usize) & self.shard_mask - } - - fn check_and_add_internal( - &self, - data: &[u8], - shards: &[Mutex], - window: Duration, - ) -> bool { - self.checks.fetch_add(1, Ordering::Relaxed); - let idx = self.get_shard_idx(data); - let mut shard = shards[idx].lock(); - let now = Instant::now(); - let found = shard.check(data, now, window); - if found { - self.hits.fetch_add(1, Ordering::Relaxed); - } else { - shard.add(data, now, window); - self.additions.fetch_add(1, Ordering::Relaxed); - } - found - } - - fn check_only_internal( - &self, - data: &[u8], - shards: &[Mutex], - window: Duration, - ) -> bool { - self.checks.fetch_add(1, Ordering::Relaxed); - let idx = self.get_shard_idx(data); - let mut shard = shards[idx].lock(); - let found = shard.check(data, Instant::now(), window); - if found { - self.hits.fetch_add(1, Ordering::Relaxed); - } - found - } - - fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { - self.additions.fetch_add(1, Ordering::Relaxed); - let idx = self.get_shard_idx(data); - let mut shard = shards[idx].lock(); - shard.add(data, Instant::now(), window); - } - - pub fn check_and_add_handshake(&self, data: &[u8]) -> bool { - self.check_and_add_internal(data, &self.handshake_shards, self.window) - } - - pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool { - self.check_and_add_internal(data, &self.tls_shards, self.tls_window) - } - - // Compatibility helpers (non-atomic split operations) — prefer check_and_add_*. - pub fn check_handshake(&self, data: &[u8]) -> bool { - self.check_and_add_handshake(data) - } - pub fn add_handshake(&self, data: &[u8]) { - self.add_only(data, &self.handshake_shards, self.window) - } - pub fn check_tls_digest(&self, data: &[u8]) -> bool { - self.check_only_internal(data, &self.tls_shards, self.tls_window) - } - pub fn add_tls_digest(&self, data: &[u8]) { - self.add_only(data, &self.tls_shards, self.tls_window) - } - - pub fn stats(&self) -> ReplayStats { - let mut total_entries = 0; - let mut total_queue_len = 0; - for shard in &self.handshake_shards { - let s = shard.lock(); - total_entries += s.cache.len(); - total_queue_len += s.queue.len(); - } - for shard in &self.tls_shards { - let s = shard.lock(); - total_entries += s.cache.len(); - total_queue_len += s.queue.len(); - } - - ReplayStats { - total_entries, - total_queue_len, - total_checks: self.checks.load(Ordering::Relaxed), - total_hits: self.hits.load(Ordering::Relaxed), - total_additions: self.additions.load(Ordering::Relaxed), - total_cleanups: self.cleanups.load(Ordering::Relaxed), - num_shards: self.handshake_shards.len() + self.tls_shards.len(), - window_secs: self.window.as_secs(), - } - } - - pub async fn run_periodic_cleanup(&self) { - let interval = if self.window.as_secs() > 60 { - Duration::from_secs(30) - } else { - Duration::from_secs((self.window.as_secs().max(1) / 2).max(1)) - }; - - loop { - tokio::time::sleep(interval).await; - - let now = Instant::now(); - let mut cleaned = 0usize; - - for shard_mutex in &self.handshake_shards { - let mut shard = shard_mutex.lock(); - let before = shard.len(); - shard.cleanup(now, self.window); - let after = shard.len(); - cleaned += before.saturating_sub(after); - } - for shard_mutex in &self.tls_shards { - let mut shard = shard_mutex.lock(); - let before = shard.len(); - shard.cleanup(now, self.tls_window); - let after = shard.len(); - cleaned += before.saturating_sub(after); - } - - self.cleanups.fetch_add(1, Ordering::Relaxed); - - if cleaned > 0 { - debug!(cleaned = cleaned, "Replay checker: periodic cleanup"); - } - } - } -} - -#[derive(Debug, Clone)] -pub struct ReplayStats { - pub total_entries: usize, - pub total_queue_len: usize, - pub total_checks: u64, - pub total_hits: u64, - pub total_additions: u64, - pub total_cleanups: u64, - pub num_shards: usize, - pub window_secs: u64, -} - -impl ReplayStats { - pub fn hit_rate(&self) -> f64 { - if self.total_checks == 0 { - 0.0 - } else { - (self.total_hits as f64 / self.total_checks as f64) * 100.0 - } - } - - pub fn ghost_ratio(&self) -> f64 { - if self.total_entries == 0 { - 0.0 - } else { - self.total_queue_len as f64 / self.total_entries as f64 - } - } } #[cfg(test)] -mod tests { - use super::*; - use crate::config::MeTelemetryLevel; - use std::sync::Arc; - use std::sync::atomic::{AtomicU64, Ordering}; - - #[test] - fn test_stats_shared_counters() { - let stats = Arc::new(Stats::new()); - stats.increment_connects_all(); - stats.increment_connects_all(); - stats.increment_connects_all(); - assert_eq!(stats.get_connects_all(), 3); - } - - #[test] - fn test_telemetry_policy_disables_core_and_user_counters() { - let stats = Stats::new(); - stats.apply_telemetry_policy(TelemetryPolicy { - core_enabled: false, - user_enabled: false, - me_level: MeTelemetryLevel::Normal, - }); - - stats.increment_connects_all(); - stats.increment_user_connects("alice"); - stats.add_user_octets_from("alice", 1024); - assert_eq!(stats.get_connects_all(), 0); - assert_eq!(stats.get_user_curr_connects("alice"), 0); - assert_eq!(stats.get_user_total_octets("alice"), 0); - } - - #[test] - fn test_telemetry_policy_me_silent_blocks_me_counters() { - let stats = Stats::new(); - stats.apply_telemetry_policy(TelemetryPolicy { - core_enabled: true, - user_enabled: true, - me_level: MeTelemetryLevel::Silent, - }); - - stats.increment_me_crc_mismatch(); - stats.increment_me_keepalive_sent(); - stats.increment_me_route_drop_queue_full(); - stats.increment_me_d2c_batches_total(); - stats.add_me_d2c_batch_frames_total(4); - stats.add_me_d2c_batch_bytes_total(4096); - stats.increment_me_d2c_flush_reason(MeD2cFlushReason::BatchBytes); - stats.increment_me_d2c_write_mode(MeD2cWriteMode::Coalesced); - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - stats.observe_me_d2c_frame_buf_shrink(1024); - stats.observe_me_d2c_batch_frames(4); - stats.observe_me_d2c_batch_bytes(4096); - stats.observe_me_d2c_flush_duration_us(120); - stats.increment_me_d2c_batch_timeout_armed_total(); - stats.increment_me_d2c_batch_timeout_fired_total(); - assert_eq!(stats.get_me_crc_mismatch(), 0); - assert_eq!(stats.get_me_keepalive_sent(), 0); - assert_eq!(stats.get_me_route_drop_queue_full(), 0); - assert_eq!(stats.get_me_d2c_batches_total(), 0); - assert_eq!(stats.get_me_d2c_flush_reason_batch_bytes_total(), 0); - assert_eq!(stats.get_me_d2c_write_mode_coalesced_total(), 0); - assert_eq!(stats.get_me_d2c_quota_reject_pre_write_total(), 0); - assert_eq!(stats.get_me_d2c_frame_buf_shrink_total(), 0); - assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0); - assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0); - assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0); - assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0); - assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0); - } - - #[test] - fn test_telemetry_policy_me_normal_blocks_d2c_debug_metrics() { - let stats = Stats::new(); - stats.apply_telemetry_policy(TelemetryPolicy { - core_enabled: true, - user_enabled: true, - me_level: MeTelemetryLevel::Normal, - }); - - stats.increment_me_d2c_batches_total(); - stats.add_me_d2c_batch_frames_total(2); - stats.add_me_d2c_batch_bytes_total(2048); - stats.increment_me_d2c_flush_reason(MeD2cFlushReason::QueueDrain); - stats.observe_me_d2c_batch_frames(2); - stats.observe_me_d2c_batch_bytes(2048); - stats.observe_me_d2c_flush_duration_us(100); - stats.increment_me_d2c_batch_timeout_armed_total(); - stats.increment_me_d2c_batch_timeout_fired_total(); - - assert_eq!(stats.get_me_d2c_batches_total(), 1); - assert_eq!(stats.get_me_d2c_batch_frames_total(), 2); - assert_eq!(stats.get_me_d2c_batch_bytes_total(), 2048); - assert_eq!(stats.get_me_d2c_flush_reason_queue_drain_total(), 1); - assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0); - assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0); - assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0); - assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0); - assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0); - } - - #[test] - fn test_telemetry_policy_me_debug_enables_d2c_debug_metrics() { - let stats = Stats::new(); - stats.apply_telemetry_policy(TelemetryPolicy { - core_enabled: true, - user_enabled: true, - me_level: MeTelemetryLevel::Debug, - }); - - stats.observe_me_d2c_batch_frames(7); - stats.observe_me_d2c_batch_bytes(70_000); - stats.observe_me_d2c_flush_duration_us(1400); - stats.increment_me_d2c_batch_timeout_armed_total(); - stats.increment_me_d2c_batch_timeout_fired_total(); - - assert_eq!(stats.get_me_d2c_batch_frames_bucket_5_8(), 1); - assert_eq!(stats.get_me_d2c_batch_bytes_bucket_64k_128k(), 1); - assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_1001_5000(), 1); - assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 1); - assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 1); - } - - #[test] - fn test_replay_checker_basic() { - let checker = ReplayChecker::new(100, Duration::from_secs(60)); - assert!(!checker.check_handshake(b"test1")); // first time, inserts - assert!(checker.check_handshake(b"test1")); // duplicate - assert!(!checker.check_handshake(b"test2")); // new key inserts - } - - #[test] - fn test_replay_checker_duplicate_add() { - let checker = ReplayChecker::new(100, Duration::from_secs(60)); - checker.add_handshake(b"dup"); - checker.add_handshake(b"dup"); - assert!(checker.check_handshake(b"dup")); - } - - #[test] - fn test_replay_checker_expiration() { - let checker = ReplayChecker::new(100, Duration::from_millis(50)); - assert!(!checker.check_handshake(b"expire")); - assert!(checker.check_handshake(b"expire")); - std::thread::sleep(Duration::from_millis(100)); - assert!(!checker.check_handshake(b"expire")); - } - - #[test] - fn test_replay_checker_zero_window_does_not_retain_entries() { - let checker = ReplayChecker::new(100, Duration::ZERO); - - for _ in 0..1_000 { - assert!(!checker.check_handshake(b"no-retain")); - checker.add_handshake(b"no-retain"); - } - - let stats = checker.stats(); - assert_eq!(stats.total_entries, 0); - assert_eq!(stats.total_queue_len, 0); - } - - #[test] - fn test_replay_checker_stats() { - let checker = ReplayChecker::new(100, Duration::from_secs(60)); - assert!(!checker.check_handshake(b"k1")); - assert!(!checker.check_handshake(b"k2")); - assert!(checker.check_handshake(b"k1")); - assert!(!checker.check_handshake(b"k3")); - let stats = checker.stats(); - assert_eq!(stats.total_additions, 3); - assert_eq!(stats.total_checks, 4); - assert_eq!(stats.total_hits, 1); - } - - #[test] - fn test_replay_checker_many_keys() { - let checker = ReplayChecker::new(10_000, Duration::from_secs(60)); - for i in 0..500u32 { - checker.add_handshake(&i.to_le_bytes()); - } - for i in 0..500u32 { - assert!(checker.check_handshake(&i.to_le_bytes())); - } - assert_eq!(checker.stats().total_entries, 500); - } - - #[test] - fn test_quota_reserve_under_contention_hits_limit_exactly() { - let user_stats = Arc::new(UserStats::default()); - let successes = Arc::new(AtomicU64::new(0)); - let limit = 8_192u64; - let mut workers = Vec::new(); - - for _ in 0..8 { - let user_stats = user_stats.clone(); - let successes = successes.clone(); - workers.push(std::thread::spawn(move || { - loop { - match user_stats.quota_try_reserve(1, limit) { - Ok(_) => { - successes.fetch_add(1, Ordering::Relaxed); - } - Err(QuotaReserveError::Contended) => { - std::hint::spin_loop(); - } - Err(QuotaReserveError::LimitExceeded) => { - break; - } - } - } - })); - } - - for worker in workers { - worker.join().expect("worker thread must finish"); - } - - assert_eq!( - successes.load(Ordering::Relaxed), - limit, - "successful reservations must stop exactly at limit" - ); - assert_eq!(user_stats.quota_used(), limit); - } - - #[test] - fn test_quota_reserve_200x_1k_reaches_100k_without_overshoot() { - let user_stats = Arc::new(UserStats::default()); - let successes = Arc::new(AtomicU64::new(0)); - let failures = Arc::new(AtomicU64::new(0)); - let attempts = 200usize; - let reserve_bytes = 1_024u64; - let limit = 100 * 1_024u64; - let mut workers = Vec::with_capacity(attempts); - - for _ in 0..attempts { - let user_stats = user_stats.clone(); - let successes = successes.clone(); - let failures = failures.clone(); - workers.push(std::thread::spawn(move || { - loop { - match user_stats.quota_try_reserve(reserve_bytes, limit) { - Ok(_) => { - successes.fetch_add(1, Ordering::Relaxed); - return; - } - Err(QuotaReserveError::LimitExceeded) => { - failures.fetch_add(1, Ordering::Relaxed); - return; - } - Err(QuotaReserveError::Contended) => { - std::hint::spin_loop(); - } - } - } - })); - } - - for worker in workers { - worker.join().expect("reservation worker must finish"); - } - - assert_eq!( - successes.load(Ordering::Relaxed), - 100, - "exactly 100 reservations of 1 KiB must fit into a 100 KiB quota" - ); - assert_eq!( - failures.load(Ordering::Relaxed), - 100, - "remaining workers must fail once quota is fully reserved" - ); - assert_eq!(user_stats.quota_used(), limit); - } - - #[test] - fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() { - let stats = Stats::new(); - let user = "quota-authoritative-user"; - let user_stats = stats.get_or_create_user_stats_handle(user); - - stats.add_user_octets_to_handle(&user_stats, 5); - assert_eq!(stats.get_user_total_octets(user), 5); - assert_eq!(stats.get_user_quota_used(user), 0); - - stats.quota_charge_post_write(&user_stats, 7); - assert_eq!(stats.get_user_total_octets(user), 5); - assert_eq!(stats.get_user_quota_used(user), 7); - } - - #[test] - fn test_cached_handle_survives_map_cleanup_until_last_drop() { - let stats = Stats::new(); - let user = "quota-handle-lifetime-user"; - let user_stats = stats.get_or_create_user_stats_handle(user); - let weak = Arc::downgrade(&user_stats); - - stats.user_stats.remove(user); - assert!( - stats.user_stats.get(user).is_none(), - "map cleanup should remove idle entry" - ); - assert!( - weak.upgrade().is_some(), - "cached handle must keep user stats object alive after map removal" - ); - - stats.quota_charge_post_write(user_stats.as_ref(), 3); - assert_eq!(user_stats.quota_used(), 3); - - drop(user_stats); - assert!( - weak.upgrade().is_none(), - "user stats object must be dropped after the last cached handle is released" - ); - } -} +mod tests; #[cfg(test)] #[path = "tests/connection_lease_security_tests.rs"] diff --git a/src/stats/replay.rs b/src/stats/replay.rs new file mode 100644 index 0000000..c0e2257 --- /dev/null +++ b/src/stats/replay.rs @@ -0,0 +1,356 @@ +use std::borrow::Borrow; +use std::collections::VecDeque; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::num::NonZeroUsize; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use lru::LruCache; +use parking_lot::Mutex; +use tracing::debug; + +const REPLAY_INLINE_KEY_CAP: usize = 48; + +#[derive(Clone)] +enum ReplayKey { + Inline { + len: u8, + bytes: [u8; REPLAY_INLINE_KEY_CAP], + }, + Heap(Arc<[u8]>), +} + +impl ReplayKey { + fn from_slice(key: &[u8]) -> Self { + if key.len() <= REPLAY_INLINE_KEY_CAP { + let mut bytes = [0u8; REPLAY_INLINE_KEY_CAP]; + bytes[..key.len()].copy_from_slice(key); + return Self::Inline { + len: key.len() as u8, + bytes, + }; + } + + Self::Heap(Arc::from(key)) + } + + fn as_slice(&self) -> &[u8] { + match self { + Self::Inline { len, bytes } => &bytes[..*len as usize], + Self::Heap(bytes) => bytes.as_ref(), + } + } +} + +impl Borrow<[u8]> for ReplayKey { + fn borrow(&self) -> &[u8] { + self.as_slice() + } +} + +impl PartialEq for ReplayKey { + fn eq(&self, other: &Self) -> bool { + self.as_slice() == other.as_slice() + } +} + +impl Eq for ReplayKey {} + +impl Hash for ReplayKey { + fn hash(&self, state: &mut H) { + self.as_slice().hash(state); + } +} + +pub struct ReplayChecker { + handshake_shards: Vec>, + tls_shards: Vec>, + shard_mask: usize, + window: Duration, + tls_window: Duration, + checks: AtomicU64, + hits: AtomicU64, + additions: AtomicU64, + cleanups: AtomicU64, +} + +struct ReplayEntry { + seq: u64, +} + +struct ReplayShard { + cache: LruCache, + queue: VecDeque<(Instant, ReplayKey, u64)>, + seq_counter: u64, + capacity: usize, +} + +impl ReplayShard { + fn new(cap: NonZeroUsize) -> Self { + Self { + cache: LruCache::new(cap), + queue: VecDeque::with_capacity(cap.get()), + seq_counter: 0, + capacity: cap.get(), + } + } + + fn next_seq(&mut self) -> u64 { + self.seq_counter += 1; + self.seq_counter + } + + fn cleanup(&mut self, now: Instant, window: Duration) { + if window.is_zero() { + self.cache.clear(); + self.queue.clear(); + return; + } + let cutoff = now.checked_sub(window).unwrap_or(now); + + while let Some((ts, _, _)) = self.queue.front() { + if *ts >= cutoff { + break; + } + self.evict_queue_front(); + } + } + + fn evict_queue_front(&mut self) { + let Some((_, key, queue_seq)) = self.queue.pop_front() else { + return; + }; + + if let Some(entry) = self.cache.peek(key.as_slice()) + && entry.seq == queue_seq + { + self.cache.pop(key.as_slice()); + } + } + + fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool { + if window.is_zero() { + return false; + } + self.cleanup(now, window); + self.cache.get(key).is_some() + } + + fn add_owned(&mut self, key: ReplayKey, now: Instant, window: Duration) { + if window.is_zero() { + return; + } + self.cleanup(now, window); + if self.cache.peek(key.as_slice()).is_some() { + return; + } + while self.queue.len() >= self.capacity { + self.evict_queue_front(); + } + + let seq = self.next_seq(); + self.cache.put(key.clone(), ReplayEntry { seq }); + self.queue.push_back((now, key, seq)); + } + + fn len(&self) -> usize { + self.cache.len() + } +} + +impl ReplayChecker { + pub fn new(total_capacity: usize, window: Duration) -> Self { + const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120); + let num_shards = 64; + let shard_capacity = (total_capacity / num_shards).max(1); + let cap = NonZeroUsize::new(shard_capacity).unwrap(); + + let mut handshake_shards = Vec::with_capacity(num_shards); + let mut tls_shards = Vec::with_capacity(num_shards); + for _ in 0..num_shards { + handshake_shards.push(Mutex::new(ReplayShard::new(cap))); + tls_shards.push(Mutex::new(ReplayShard::new(cap))); + } + + Self { + handshake_shards, + tls_shards, + shard_mask: num_shards - 1, + window, + tls_window: window.max(MIN_TLS_REPLAY_WINDOW), + checks: AtomicU64::new(0), + hits: AtomicU64::new(0), + additions: AtomicU64::new(0), + cleanups: AtomicU64::new(0), + } + } + + fn get_shard_idx(&self, key: &[u8]) -> usize { + let mut hasher = DefaultHasher::new(); + key.hash(&mut hasher); + (hasher.finish() as usize) & self.shard_mask + } + + fn check_and_add_internal( + &self, + data: &[u8], + shards: &[Mutex], + window: Duration, + ) -> bool { + self.checks.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let owned_key = ReplayKey::from_slice(data); + let mut shard = shards[idx].lock(); + let now = Instant::now(); + let found = shard.check(data, now, window); + if found { + self.hits.fetch_add(1, Ordering::Relaxed); + } else { + shard.add_owned(owned_key, now, window); + self.additions.fetch_add(1, Ordering::Relaxed); + } + found + } + + fn check_only_internal( + &self, + data: &[u8], + shards: &[Mutex], + window: Duration, + ) -> bool { + self.checks.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = shards[idx].lock(); + let found = shard.check(data, Instant::now(), window); + if found { + self.hits.fetch_add(1, Ordering::Relaxed); + } + found + } + + fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { + self.additions.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let owned_key = ReplayKey::from_slice(data); + let mut shard = shards[idx].lock(); + shard.add_owned(owned_key, Instant::now(), window); + } + + pub fn check_and_add_handshake(&self, data: &[u8]) -> bool { + self.check_and_add_internal(data, &self.handshake_shards, self.window) + } + + pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool { + self.check_and_add_internal(data, &self.tls_shards, self.tls_window) + } + + pub fn check_handshake(&self, data: &[u8]) -> bool { + self.check_and_add_handshake(data) + } + + pub fn add_handshake(&self, data: &[u8]) { + self.add_only(data, &self.handshake_shards, self.window) + } + + pub fn check_tls_digest(&self, data: &[u8]) -> bool { + self.check_only_internal(data, &self.tls_shards, self.tls_window) + } + + pub fn add_tls_digest(&self, data: &[u8]) { + self.add_only(data, &self.tls_shards, self.tls_window) + } + + pub fn stats(&self) -> ReplayStats { + let mut total_entries = 0; + let mut total_queue_len = 0; + for shard in &self.handshake_shards { + let s = shard.lock(); + total_entries += s.cache.len(); + total_queue_len += s.queue.len(); + } + for shard in &self.tls_shards { + let s = shard.lock(); + total_entries += s.cache.len(); + total_queue_len += s.queue.len(); + } + + ReplayStats { + total_entries, + total_queue_len, + total_checks: self.checks.load(Ordering::Relaxed), + total_hits: self.hits.load(Ordering::Relaxed), + total_additions: self.additions.load(Ordering::Relaxed), + total_cleanups: self.cleanups.load(Ordering::Relaxed), + num_shards: self.handshake_shards.len() + self.tls_shards.len(), + window_secs: self.window.as_secs(), + } + } + + pub async fn run_periodic_cleanup(&self) { + let interval = if self.window.as_secs() > 60 { + Duration::from_secs(30) + } else { + Duration::from_secs((self.window.as_secs().max(1) / 2).max(1)) + }; + + loop { + tokio::time::sleep(interval).await; + + let now = Instant::now(); + let mut cleaned = 0usize; + + for shard_mutex in &self.handshake_shards { + let mut shard = shard_mutex.lock(); + let before = shard.len(); + shard.cleanup(now, self.window); + let after = shard.len(); + cleaned += before.saturating_sub(after); + } + for shard_mutex in &self.tls_shards { + let mut shard = shard_mutex.lock(); + let before = shard.len(); + shard.cleanup(now, self.tls_window); + let after = shard.len(); + cleaned += before.saturating_sub(after); + } + + self.cleanups.fetch_add(1, Ordering::Relaxed); + + if cleaned > 0 { + debug!(cleaned = cleaned, "Replay checker: periodic cleanup"); + } + } + } +} + +#[derive(Debug, Clone)] +pub struct ReplayStats { + pub total_entries: usize, + pub total_queue_len: usize, + pub total_checks: u64, + pub total_hits: u64, + pub total_additions: u64, + pub total_cleanups: u64, + pub num_shards: usize, + pub window_secs: u64, +} + +impl ReplayStats { + pub fn hit_rate(&self) -> f64 { + if self.total_checks == 0 { + 0.0 + } else { + (self.total_hits as f64 / self.total_checks as f64) * 100.0 + } + } + + pub fn ghost_ratio(&self) -> f64 { + if self.total_entries == 0 { + 0.0 + } else { + self.total_queue_len as f64 / self.total_entries as f64 + } + } +} diff --git a/src/stats/tests.rs b/src/stats/tests.rs new file mode 100644 index 0000000..650b123 --- /dev/null +++ b/src/stats/tests.rs @@ -0,0 +1,317 @@ +use super::*; +use crate::config::MeTelemetryLevel; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; + +#[test] +fn test_stats_shared_counters() { + let stats = Arc::new(Stats::new()); + stats.increment_connects_all(); + stats.increment_connects_all(); + stats.increment_connects_all(); + assert_eq!(stats.get_connects_all(), 3); +} + +#[test] +fn test_telemetry_policy_disables_core_and_user_counters() { + let stats = Stats::new(); + stats.apply_telemetry_policy(TelemetryPolicy { + core_enabled: false, + user_enabled: false, + me_level: MeTelemetryLevel::Normal, + }); + + stats.increment_connects_all(); + stats.increment_user_connects("alice"); + stats.add_user_octets_from("alice", 1024); + assert_eq!(stats.get_connects_all(), 0); + assert_eq!(stats.get_user_curr_connects("alice"), 0); + assert_eq!(stats.get_user_total_octets("alice"), 0); +} + +#[test] +fn test_telemetry_policy_me_silent_blocks_me_counters() { + let stats = Stats::new(); + stats.apply_telemetry_policy(TelemetryPolicy { + core_enabled: true, + user_enabled: true, + me_level: MeTelemetryLevel::Silent, + }); + + stats.increment_me_crc_mismatch(); + stats.increment_me_keepalive_sent(); + stats.increment_me_route_drop_queue_full(); + stats.increment_me_d2c_batches_total(); + stats.add_me_d2c_batch_frames_total(4); + stats.add_me_d2c_batch_bytes_total(4096); + stats.increment_me_d2c_flush_reason(MeD2cFlushReason::BatchBytes); + stats.increment_me_d2c_write_mode(MeD2cWriteMode::Coalesced); + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + stats.observe_me_d2c_frame_buf_shrink(1024); + stats.observe_me_d2c_batch_frames(4); + stats.observe_me_d2c_batch_bytes(4096); + stats.observe_me_d2c_flush_duration_us(120); + stats.increment_me_d2c_batch_timeout_armed_total(); + stats.increment_me_d2c_batch_timeout_fired_total(); + assert_eq!(stats.get_me_crc_mismatch(), 0); + assert_eq!(stats.get_me_keepalive_sent(), 0); + assert_eq!(stats.get_me_route_drop_queue_full(), 0); + assert_eq!(stats.get_me_d2c_batches_total(), 0); + assert_eq!(stats.get_me_d2c_flush_reason_batch_bytes_total(), 0); + assert_eq!(stats.get_me_d2c_write_mode_coalesced_total(), 0); + assert_eq!(stats.get_me_d2c_quota_reject_pre_write_total(), 0); + assert_eq!(stats.get_me_d2c_frame_buf_shrink_total(), 0); + assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0); + assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0); + assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0); + assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0); + assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0); +} + +#[test] +fn test_telemetry_policy_me_normal_blocks_d2c_debug_metrics() { + let stats = Stats::new(); + stats.apply_telemetry_policy(TelemetryPolicy { + core_enabled: true, + user_enabled: true, + me_level: MeTelemetryLevel::Normal, + }); + + stats.increment_me_d2c_batches_total(); + stats.add_me_d2c_batch_frames_total(2); + stats.add_me_d2c_batch_bytes_total(2048); + stats.increment_me_d2c_flush_reason(MeD2cFlushReason::QueueDrain); + stats.observe_me_d2c_batch_frames(2); + stats.observe_me_d2c_batch_bytes(2048); + stats.observe_me_d2c_flush_duration_us(100); + stats.increment_me_d2c_batch_timeout_armed_total(); + stats.increment_me_d2c_batch_timeout_fired_total(); + + assert_eq!(stats.get_me_d2c_batches_total(), 1); + assert_eq!(stats.get_me_d2c_batch_frames_total(), 2); + assert_eq!(stats.get_me_d2c_batch_bytes_total(), 2048); + assert_eq!(stats.get_me_d2c_flush_reason_queue_drain_total(), 1); + assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0); + assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0); + assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0); + assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0); + assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0); +} + +#[test] +fn test_telemetry_policy_me_debug_enables_d2c_debug_metrics() { + let stats = Stats::new(); + stats.apply_telemetry_policy(TelemetryPolicy { + core_enabled: true, + user_enabled: true, + me_level: MeTelemetryLevel::Debug, + }); + + stats.observe_me_d2c_batch_frames(7); + stats.observe_me_d2c_batch_bytes(70_000); + stats.observe_me_d2c_flush_duration_us(1400); + stats.increment_me_d2c_batch_timeout_armed_total(); + stats.increment_me_d2c_batch_timeout_fired_total(); + + assert_eq!(stats.get_me_d2c_batch_frames_bucket_5_8(), 1); + assert_eq!(stats.get_me_d2c_batch_bytes_bucket_64k_128k(), 1); + assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_1001_5000(), 1); + assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 1); + assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 1); +} + +#[test] +fn test_replay_checker_basic() { + let checker = ReplayChecker::new(100, Duration::from_secs(60)); + assert!(!checker.check_handshake(b"test1")); // first time, inserts + assert!(checker.check_handshake(b"test1")); // duplicate + assert!(!checker.check_handshake(b"test2")); // new key inserts +} + +#[test] +fn test_replay_checker_duplicate_add() { + let checker = ReplayChecker::new(100, Duration::from_secs(60)); + checker.add_handshake(b"dup"); + checker.add_handshake(b"dup"); + assert!(checker.check_handshake(b"dup")); +} + +#[test] +fn test_replay_checker_expiration() { + let checker = ReplayChecker::new(100, Duration::from_millis(50)); + assert!(!checker.check_handshake(b"expire")); + assert!(checker.check_handshake(b"expire")); + std::thread::sleep(Duration::from_millis(100)); + assert!(!checker.check_handshake(b"expire")); +} + +#[test] +fn test_replay_checker_zero_window_does_not_retain_entries() { + let checker = ReplayChecker::new(100, Duration::ZERO); + + for _ in 0..1_000 { + assert!(!checker.check_handshake(b"no-retain")); + checker.add_handshake(b"no-retain"); + } + + let stats = checker.stats(); + assert_eq!(stats.total_entries, 0); + assert_eq!(stats.total_queue_len, 0); +} + +#[test] +fn test_replay_checker_stats() { + let checker = ReplayChecker::new(100, Duration::from_secs(60)); + assert!(!checker.check_handshake(b"k1")); + assert!(!checker.check_handshake(b"k2")); + assert!(checker.check_handshake(b"k1")); + assert!(!checker.check_handshake(b"k3")); + let stats = checker.stats(); + assert_eq!(stats.total_additions, 3); + assert_eq!(stats.total_checks, 4); + assert_eq!(stats.total_hits, 1); +} + +#[test] +fn test_replay_checker_many_keys() { + let checker = ReplayChecker::new(10_000, Duration::from_secs(60)); + for i in 0..500u32 { + checker.add_handshake(&i.to_le_bytes()); + } + for i in 0..500u32 { + assert!(checker.check_handshake(&i.to_le_bytes())); + } + assert_eq!(checker.stats().total_entries, 500); +} + +#[test] +fn test_quota_reserve_under_contention_hits_limit_exactly() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let limit = 8_192u64; + let mut workers = Vec::new(); + + for _ in 0..8 { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(1, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + Err(QuotaReserveError::LimitExceeded) => { + break; + } + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + limit, + "successful reservations must stop exactly at limit" + ); + assert_eq!(user_stats.quota_used(), limit); +} + +#[test] +fn test_quota_reserve_200x_1k_reaches_100k_without_overshoot() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let failures = Arc::new(AtomicU64::new(0)); + let attempts = 200usize; + let reserve_bytes = 1_024u64; + let limit = 100 * 1_024u64; + let mut workers = Vec::with_capacity(attempts); + + for _ in 0..attempts { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + let failures = failures.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(reserve_bytes, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + return; + } + Err(QuotaReserveError::LimitExceeded) => { + failures.fetch_add(1, Ordering::Relaxed); + return; + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + } + } + })); + } + + for worker in workers { + worker.join().expect("reservation worker must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + 100, + "exactly 100 reservations of 1 KiB must fit into a 100 KiB quota" + ); + assert_eq!( + failures.load(Ordering::Relaxed), + 100, + "remaining workers must fail once quota is fully reserved" + ); + assert_eq!(user_stats.quota_used(), limit); +} + +#[test] +fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() { + let stats = Stats::new(); + let user = "quota-authoritative-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + + stats.add_user_octets_to_handle(&user_stats, 5); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 0); + + stats.quota_charge_post_write(&user_stats, 7); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 7); +} + +#[test] +fn test_cached_handle_survives_map_cleanup_until_last_drop() { + let stats = Stats::new(); + let user = "quota-handle-lifetime-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let weak = Arc::downgrade(&user_stats); + + stats.user_stats.remove(user); + assert!( + stats.user_stats.get(user).is_none(), + "map cleanup should remove idle entry" + ); + assert!( + weak.upgrade().is_some(), + "cached handle must keep user stats object alive after map removal" + ); + + stats.quota_charge_post_write(user_stats.as_ref(), 3); + assert_eq!(user_stats.quota_used(), 3); + + drop(user_stats); + assert!( + weak.upgrade().is_none(), + "user stats object must be dropped after the last cached handle is released" + ); +} diff --git a/src/stats/users.rs b/src/stats/users.rs new file mode 100644 index 0000000..72849cf --- /dev/null +++ b/src/stats/users.rs @@ -0,0 +1,249 @@ +use super::*; + +impl Stats { + pub fn increment_user_connects(&self, user: &str) { + if !self.telemetry_user_enabled() { + return; + } + let stats = self.get_or_create_user_stats_handle(user); + self.touch_user_stats(stats.as_ref()); + stats.connects.fetch_add(1, Ordering::Relaxed); + } + + pub fn increment_user_curr_connects(&self, user: &str) { + if !self.telemetry_user_enabled() { + return; + } + let stats = self.get_or_create_user_stats_handle(user); + self.touch_user_stats(stats.as_ref()); + stats.curr_connects.fetch_add(1, Ordering::Relaxed); + } + + pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option) -> bool { + if !self.telemetry_user_enabled() { + return true; + } + + let stats = self.get_or_create_user_stats_handle(user); + self.touch_user_stats(stats.as_ref()); + + let counter = &stats.curr_connects; + let mut current = counter.load(Ordering::Relaxed); + loop { + if let Some(max) = limit + && current >= max + { + return false; + } + match counter.compare_exchange_weak( + current, + current.saturating_add(1), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return true, + Err(actual) => current = actual, + } + } + } + + pub fn decrement_user_curr_connects(&self, user: &str) { + if let Some(stats) = self.user_stats.get(user) { + self.touch_user_stats(stats.value().as_ref()); + let counter = &stats.curr_connects; + let mut current = counter.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match counter.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + } + + pub fn get_user_curr_connects(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| s.curr_connects.load(Ordering::Relaxed)) + .unwrap_or(0) + } + + pub fn add_user_octets_from(&self, user: &str, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_from_handle(stats.as_ref(), bytes); + } + + pub fn add_user_octets_to(&self, user: &str, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_to_handle(stats.as_ref(), bytes); + } + + pub fn increment_user_msgs_from(&self, user: &str) { + if !self.telemetry_user_enabled() { + return; + } + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_from_handle(stats.as_ref()); + } + + pub fn increment_user_msgs_to(&self, user: &str) { + if !self.telemetry_user_enabled() { + return; + } + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_to_handle(stats.as_ref()); + } + + pub fn get_user_total_octets(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| { + s.octets_from_client.load(Ordering::Relaxed) + + s.octets_to_client.load(Ordering::Relaxed) + }) + .unwrap_or(0) + } + + pub fn get_user_quota_used(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| s.quota_used.load(Ordering::Relaxed)) + .unwrap_or(0) + } + + pub fn load_user_quota_state(&self, user: &str, used_bytes: u64, last_reset_epoch_secs: u64) { + let stats = self.get_or_create_user_stats_handle(user); + stats.quota_used.store(used_bytes, Ordering::Relaxed); + stats + .quota_last_reset_epoch_secs + .store(last_reset_epoch_secs, Ordering::Relaxed); + } + + pub fn reset_user_quota(&self, user: &str) -> UserQuotaSnapshot { + let stats = self.get_or_create_user_stats_handle(user); + let last_reset_epoch_secs = Self::now_epoch_secs(); + stats.quota_used.store(0, Ordering::Relaxed); + stats + .quota_last_reset_epoch_secs + .store(last_reset_epoch_secs, Ordering::Relaxed); + UserQuotaSnapshot { + used_bytes: 0, + last_reset_epoch_secs, + } + } + + pub fn user_quota_snapshot(&self) -> HashMap { + let mut out = HashMap::new(); + for entry in self.user_stats.iter() { + let stats = entry.value(); + let used_bytes = stats.quota_used.load(Ordering::Relaxed); + let last_reset_epoch_secs = stats.quota_last_reset_epoch_secs.load(Ordering::Relaxed); + if used_bytes == 0 && last_reset_epoch_secs == 0 { + continue; + } + out.insert( + entry.key().clone(), + UserQuotaSnapshot { + used_bytes, + last_reset_epoch_secs, + }, + ); + } + out + } + + pub fn get_handshake_timeouts(&self) -> u64 { + self.handshake_timeouts.load(Ordering::Relaxed) + } + pub fn get_upstream_connect_attempt_total(&self) -> u64 { + self.upstream_connect_attempt_total.load(Ordering::Relaxed) + } + pub fn get_upstream_connect_success_total(&self) -> u64 { + self.upstream_connect_success_total.load(Ordering::Relaxed) + } + pub fn get_upstream_connect_fail_total(&self) -> u64 { + self.upstream_connect_fail_total.load(Ordering::Relaxed) + } + pub fn get_upstream_connect_failfast_hard_error_total(&self) -> u64 { + self.upstream_connect_failfast_hard_error_total + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_attempts_bucket_1(&self) -> u64 { + self.upstream_connect_attempts_bucket_1 + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_attempts_bucket_2(&self) -> u64 { + self.upstream_connect_attempts_bucket_2 + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_attempts_bucket_3_4(&self) -> u64 { + self.upstream_connect_attempts_bucket_3_4 + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_attempts_bucket_gt_4(&self) -> u64 { + self.upstream_connect_attempts_bucket_gt_4 + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_duration_success_bucket_le_100ms(&self) -> u64 { + self.upstream_connect_duration_success_bucket_le_100ms + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_duration_success_bucket_101_500ms(&self) -> u64 { + self.upstream_connect_duration_success_bucket_101_500ms + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_duration_success_bucket_501_1000ms(&self) -> u64 { + self.upstream_connect_duration_success_bucket_501_1000ms + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_duration_success_bucket_gt_1000ms(&self) -> u64 { + self.upstream_connect_duration_success_bucket_gt_1000ms + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_duration_fail_bucket_le_100ms(&self) -> u64 { + self.upstream_connect_duration_fail_bucket_le_100ms + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_duration_fail_bucket_101_500ms(&self) -> u64 { + self.upstream_connect_duration_fail_bucket_101_500ms + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_duration_fail_bucket_501_1000ms(&self) -> u64 { + self.upstream_connect_duration_fail_bucket_501_1000ms + .load(Ordering::Relaxed) + } + pub fn get_upstream_connect_duration_fail_bucket_gt_1000ms(&self) -> u64 { + self.upstream_connect_duration_fail_bucket_gt_1000ms + .load(Ordering::Relaxed) + } + + pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc> { + self.user_stats.iter() + } + + /// Current number of retained per-user stats entries. + pub fn user_stats_len(&self) -> usize { + self.user_stats.len() + } + + pub fn uptime_secs(&self) -> f64 { + self.start_time + .read() + .map(|t| t.elapsed().as_secs_f64()) + .unwrap_or(0.0) + } +} diff --git a/src/stats/writer_counters.rs b/src/stats/writer_counters.rs new file mode 100644 index 0000000..2977202 --- /dev/null +++ b/src/stats/writer_counters.rs @@ -0,0 +1,542 @@ +use super::*; + +impl Stats { + pub fn increment_me_writer_pick_success_try_total(&self, mode: MeWriterPickMode) { + if !self.telemetry_me_allows_normal() { + return; + } + match mode { + MeWriterPickMode::SortedRr => { + self.me_writer_pick_sorted_rr_success_try_total + .fetch_add(1, Ordering::Relaxed); + } + MeWriterPickMode::P2c => { + self.me_writer_pick_p2c_success_try_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_writer_pick_success_fallback_total(&self, mode: MeWriterPickMode) { + if !self.telemetry_me_allows_normal() { + return; + } + match mode { + MeWriterPickMode::SortedRr => { + self.me_writer_pick_sorted_rr_success_fallback_total + .fetch_add(1, Ordering::Relaxed); + } + MeWriterPickMode::P2c => { + self.me_writer_pick_p2c_success_fallback_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_writer_pick_full_total(&self, mode: MeWriterPickMode) { + if !self.telemetry_me_allows_normal() { + return; + } + match mode { + MeWriterPickMode::SortedRr => { + self.me_writer_pick_sorted_rr_full_total + .fetch_add(1, Ordering::Relaxed); + } + MeWriterPickMode::P2c => { + self.me_writer_pick_p2c_full_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_writer_pick_closed_total(&self, mode: MeWriterPickMode) { + if !self.telemetry_me_allows_normal() { + return; + } + match mode { + MeWriterPickMode::SortedRr => { + self.me_writer_pick_sorted_rr_closed_total + .fetch_add(1, Ordering::Relaxed); + } + MeWriterPickMode::P2c => { + self.me_writer_pick_p2c_closed_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_writer_pick_no_candidate_total(&self, mode: MeWriterPickMode) { + if !self.telemetry_me_allows_normal() { + return; + } + match mode { + MeWriterPickMode::SortedRr => { + self.me_writer_pick_sorted_rr_no_candidate_total + .fetch_add(1, Ordering::Relaxed); + } + MeWriterPickMode::P2c => { + self.me_writer_pick_p2c_no_candidate_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_writer_pick_blocking_fallback_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_writer_pick_blocking_fallback_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_writer_pick_mode_switch_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_writer_pick_mode_switch_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_socks_kdf_strict_reject(&self) { + if self.telemetry_me_allows_normal() { + self.me_socks_kdf_strict_reject + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_socks_kdf_compat_fallback(&self) { + if self.telemetry_me_allows_debug() { + self.me_socks_kdf_compat_fallback + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_secure_padding_invalid(&self) { + if self.telemetry_me_allows_normal() { + self.secure_padding_invalid.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_desync_total(&self) { + if self.telemetry_me_allows_normal() { + self.desync_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_desync_full_logged(&self) { + if self.telemetry_me_allows_normal() { + self.desync_full_logged.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_desync_suppressed(&self) { + if self.telemetry_me_allows_normal() { + self.desync_suppressed.fetch_add(1, Ordering::Relaxed); + } + } + pub fn observe_desync_frames_ok(&self, frames_ok: u64) { + if !self.telemetry_me_allows_normal() { + return; + } + match frames_ok { + 0 => { + self.desync_frames_bucket_0.fetch_add(1, Ordering::Relaxed); + } + 1..=2 => { + self.desync_frames_bucket_1_2 + .fetch_add(1, Ordering::Relaxed); + } + 3..=10 => { + self.desync_frames_bucket_3_10 + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.desync_frames_bucket_gt_10 + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_pool_swap_total(&self) { + if self.telemetry_me_allows_normal() { + self.pool_swap_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_pool_drain_active(&self) { + if self.telemetry_me_allows_debug() { + self.pool_drain_active.fetch_add(1, Ordering::Relaxed); + } + } + pub fn decrement_pool_drain_active(&self) { + if !self.telemetry_me_allows_debug() { + return; + } + let mut current = self.pool_drain_active.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match self.pool_drain_active.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + pub fn increment_pool_force_close_total(&self) { + if self.telemetry_me_allows_normal() { + self.pool_force_close_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_pool_stale_pick_total(&self) { + if self.telemetry_me_allows_normal() { + self.pool_stale_pick_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_writer_removed_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_writer_removed_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_writer_removed_unexpected_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_writer_removed_unexpected_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_refill_triggered_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_refill_triggered_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_refill_skipped_inflight_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_refill_skipped_inflight_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_refill_failed_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_refill_failed_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_writer_restored_same_endpoint_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_writer_restored_same_endpoint_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_writer_restored_fallback_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_writer_restored_fallback_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_no_writer_failfast_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_no_writer_failfast_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_hybrid_timeout_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_hybrid_timeout_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_async_recovery_trigger_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_async_recovery_trigger_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_inline_recovery_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_inline_recovery_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_ip_reservation_rollback_tcp_limit_total(&self) { + if self.telemetry_core_enabled() { + self.ip_reservation_rollback_tcp_limit_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_ip_reservation_rollback_quota_limit_total(&self) { + if self.telemetry_core_enabled() { + self.ip_reservation_rollback_quota_limit_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn add_quota_refund_bytes_total(&self, bytes: u64) { + if self.telemetry_core_enabled() { + self.quota_refund_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_quota_contention_total(&self) { + if self.telemetry_core_enabled() { + self.quota_contention_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_quota_contention_timeout_total(&self) { + if self.telemetry_core_enabled() { + self.quota_contention_timeout_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_quota_acquire_cancelled_total(&self) { + if self.telemetry_core_enabled() { + self.quota_acquire_cancelled_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) { + if self.telemetry_core_enabled() { + self.quota_write_fail_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_quota_write_fail_events_total(&self) { + if self.telemetry_core_enabled() { + self.quota_write_fail_events_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_child_join_timeout_total(&self) { + if self.telemetry_core_enabled() { + self.me_child_join_timeout_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_child_abort_total(&self) { + if self.telemetry_core_enabled() { + self.me_child_abort_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn observe_flow_wait_middle_rate_limit_ms(&self, wait_ms: u64) { + if self.telemetry_core_enabled() { + self.flow_wait_middle_rate_limit_total + .fetch_add(1, Ordering::Relaxed); + self.flow_wait_middle_rate_limit_ms_total + .fetch_add(wait_ms, Ordering::Relaxed); + } + } + pub fn increment_flow_wait_middle_rate_limit_cancelled_total(&self) { + if self.telemetry_core_enabled() { + self.flow_wait_middle_rate_limit_cancelled_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_session_drop_fallback_total(&self) { + if self.telemetry_core_enabled() { + self.session_drop_fallback_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_endpoint_quarantine_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_endpoint_quarantine_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_endpoint_quarantine_unexpected_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_endpoint_quarantine_unexpected_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_endpoint_quarantine_draining_suppressed_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_endpoint_quarantine_draining_suppressed_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_kdf_drift_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_kdf_drift_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_kdf_port_only_drift_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_kdf_port_only_drift_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_hardswap_pending_reuse_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_hardswap_pending_reuse_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_hardswap_pending_ttl_expired_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_hardswap_pending_ttl_expired_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_single_endpoint_outage_enter_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_single_endpoint_outage_enter_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_single_endpoint_outage_exit_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_single_endpoint_outage_exit_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_single_endpoint_outage_reconnect_attempt_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_single_endpoint_outage_reconnect_attempt_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_single_endpoint_outage_reconnect_success_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_single_endpoint_outage_reconnect_success_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_single_endpoint_quarantine_bypass_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_single_endpoint_quarantine_bypass_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_single_endpoint_shadow_rotate_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_single_endpoint_shadow_rotate_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_single_endpoint_shadow_rotate_skipped_quarantine_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_single_endpoint_shadow_rotate_skipped_quarantine_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_floor_mode_switch_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_floor_mode_switch_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_floor_mode_switch_static_to_adaptive_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_floor_mode_switch_static_to_adaptive_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_floor_mode_switch_adaptive_to_static_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_floor_mode_switch_adaptive_to_static_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn set_me_floor_cpu_cores_detected_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_cpu_cores_detected_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_floor_cpu_cores_effective_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_cpu_cores_effective_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_floor_global_cap_raw_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_global_cap_raw_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_floor_global_cap_effective_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_global_cap_effective_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_floor_target_writers_total_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_target_writers_total_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_floor_active_cap_configured_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_active_cap_configured_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_floor_active_cap_effective_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_active_cap_effective_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_floor_warm_cap_configured_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_warm_cap_configured_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_floor_warm_cap_effective_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_floor_warm_cap_effective_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_writers_active_current_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_writers_active_current_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_writers_warm_current_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_writers_warm_current_gauge + .store(value, Ordering::Relaxed); + } + } + + pub fn set_buffer_pool_gauges(&self, pooled: usize, allocated: usize, in_use: usize) { + if self.telemetry_me_allows_normal() { + self.buffer_pool_pooled_gauge + .store(pooled as u64, Ordering::Relaxed); + self.buffer_pool_allocated_gauge + .store(allocated as u64, Ordering::Relaxed); + self.buffer_pool_in_use_gauge + .store(in_use as u64, Ordering::Relaxed); + } + } + + pub fn increment_me_c2me_send_full_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_c2me_send_full_total.fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_me_c2me_send_high_water_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_c2me_send_high_water_total + .fetch_add(1, Ordering::Relaxed); + } + } + + pub fn increment_me_c2me_send_timeout_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_c2me_send_timeout_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_floor_cap_block_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_floor_cap_block_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_floor_swap_idle_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_floor_swap_idle_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_floor_swap_idle_failed_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_floor_swap_idle_failed_total + .fetch_add(1, Ordering::Relaxed); + } + } +} diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 3037c52..f603293 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -4,6 +4,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::crypto::{AesCbc, crc32, crc32c}; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; +use crate::stream::PooledBuffer; + +use super::wire::{append_proxy_req_payload_into, proxy_req_payload_len}; const RPC_WRITER_FRAME_BUF_SHRINK_THRESHOLD: usize = 256 * 1024; const RPC_WRITER_FRAME_BUF_RETAIN: usize = 64 * 1024; @@ -12,10 +15,21 @@ const RPC_WRITER_FRAME_BUF_RETAIN: usize = 64 * 1024; pub(crate) enum WriterCommand { Data(Bytes), DataAndFlush(Bytes), + ProxyReq(ProxyReqCommand), ControlAndFlush([u8; 12]), Close, } +/// Structured proxy request command that lets the writer encode directly into its frame buffer. +pub(crate) struct ProxyReqCommand { + pub(crate) conn_id: u64, + pub(crate) client_addr: std::net::SocketAddr, + pub(crate) our_addr: std::net::SocketAddr, + pub(crate) proto_flags: u32, + pub(crate) proxy_tag: Option<[u8; 16]>, + pub(crate) payload: PooledBuffer, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum RpcChecksumMode { Crc32, @@ -249,7 +263,37 @@ impl RpcWriter { pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> { build_rpc_frame_into(&mut self.frame_buf, self.seq_no, payload, self.crc_mode); self.seq_no = self.seq_no.wrapping_add(1); + self.encrypt_and_write_frame().await + } + pub(crate) async fn send_proxy_req(&mut self, command: &ProxyReqCommand) -> Result<()> { + let payload_len = proxy_req_payload_len( + command.payload.len(), + command.proxy_tag.as_ref().map(|tag| tag.as_slice()), + command.proto_flags, + ); + let total_len = 4 + 4 + payload_len + 4; + self.frame_buf.clear(); + self.frame_buf.reserve(total_len + 15); + self.frame_buf + .extend_from_slice(&(total_len as u32).to_le_bytes()); + self.frame_buf.extend_from_slice(&self.seq_no.to_le_bytes()); + append_proxy_req_payload_into( + &mut self.frame_buf, + command.conn_id, + command.client_addr, + command.our_addr, + command.payload.as_ref(), + command.proxy_tag.as_ref().map(|tag| tag.as_slice()), + command.proto_flags, + ); + let c = rpc_crc(self.crc_mode, &self.frame_buf); + self.frame_buf.extend_from_slice(&c.to_le_bytes()); + self.seq_no = self.seq_no.wrapping_add(1); + self.encrypt_and_write_frame().await + } + + async fn encrypt_and_write_frame(&mut self) -> Result<()> { let pad = (16 - (self.frame_buf.len() % 16)) % 16; let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00]; for i in 0..pad { diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 52ffa58..35021aa 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -60,6 +60,9 @@ async fn writer_command_loop( Some(WriterCommand::DataAndFlush(payload)) => { rpc_writer.send_and_flush(&payload).await?; } + Some(WriterCommand::ProxyReq(command)) => { + rpc_writer.send_proxy_req(&command).await?; + } Some(WriterCommand::ControlAndFlush(payload)) => { rpc_writer.send_and_flush(&payload).await?; } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 8095da8..0f3925e 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -2,14 +2,13 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::time::{SystemTime, UNIX_EPOCH}; use dashmap::DashMap; -use tokio::sync::mpsc::error::TrySendError; use tokio::sync::{Mutex, Semaphore, mpsc}; +use super::MeResponse; use super::codec::WriterCommand; -use super::{MeResponse, RouteBytePermit}; const ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS: u64 = 25; const ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS: u64 = 120; @@ -18,6 +17,8 @@ const ROUTE_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024; const ROUTE_QUEUED_PERMITS_PER_SLOT: usize = 4; const ROUTE_QUEUED_MAX_FRAME_PERMITS: usize = 1024; +mod writer; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteResult { Routed, @@ -218,760 +219,7 @@ impl ConnRegistry { ); (id, rx) } - - pub async fn register_writer(&self, writer_id: u64, tx: mpsc::Sender) { - let mut binding = self.binding.inner.lock().await; - binding.writers.insert(writer_id, tx.clone()); - binding - .conns_for_writer - .entry(writer_id) - .or_insert_with(HashSet::new); - self.writers.map.insert(writer_id, tx); - } - - /// Unregister connection, returning associated writer_id if any. - pub async fn unregister(&self, id: u64) -> Option { - self.routing.map.remove(&id); - self.routing.byte_budget.remove(&id); - self.hot_binding.map.remove(&id); - let mut binding = self.binding.inner.lock().await; - binding.meta.remove(&id); - if let Some(writer_id) = binding.writer_for_conn.remove(&id) { - let became_empty = if let Some(set) = binding.conns_for_writer.get_mut(&writer_id) { - set.remove(&id); - set.is_empty() - } else { - false - }; - if became_empty { - binding - .writer_idle_since_epoch_secs - .insert(writer_id, Self::now_epoch_secs()); - } - return Some(writer_id); - } - None - } - - async fn attach_route_byte_permit( - &self, - id: u64, - resp: MeResponse, - timeout_ms: Option, - ) -> std::result::Result { - let MeResponse::Data { - flags, - data, - route_permit, - } = resp - else { - return Ok(resp); - }; - - if route_permit.is_some() { - return Ok(MeResponse::Data { - flags, - data, - route_permit, - }); - } - - let Some(semaphore) = self - .routing - .byte_budget - .get(&id) - .map(|entry| entry.value().clone()) - else { - return Err(RouteResult::NoConn); - }; - let permits = Self::route_data_permits(data.len()); - let permit = match timeout_ms { - Some(0) => semaphore - .try_acquire_many_owned(permits) - .map_err(|_| RouteResult::QueueFullHigh)?, - Some(timeout_ms) => { - let acquire = semaphore.acquire_many_owned(permits); - match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire).await - { - Ok(Ok(permit)) => permit, - Ok(Err(_)) => return Err(RouteResult::ChannelClosed), - Err(_) => return Err(RouteResult::QueueFullHigh), - } - } - None => semaphore - .acquire_many_owned(permits) - .await - .map_err(|_| RouteResult::ChannelClosed)?, - }; - - Ok(MeResponse::Data { - flags, - data, - route_permit: Some(RouteBytePermit::new(permit)), - }) - } - - #[allow(dead_code)] - pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { - let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); - - let Some(tx) = tx else { - return RouteResult::NoConn; - }; - - let base_timeout_ms = self - .route_backpressure_base_timeout_ms - .load(Ordering::Relaxed) - .max(1); - let resp = match self - .attach_route_byte_permit(id, resp, Some(base_timeout_ms)) - .await - { - Ok(resp) => resp, - Err(result) => return result, - }; - - match tx.try_send(resp) { - Ok(()) => RouteResult::Routed, - Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, - Err(TrySendError::Full(resp)) => { - // Absorb short bursts without dropping/closing the session immediately. - let high_timeout_ms = self - .route_backpressure_high_timeout_ms - .load(Ordering::Relaxed) - .max(base_timeout_ms); - let high_watermark_pct = self - .route_backpressure_high_watermark_pct - .load(Ordering::Relaxed) - .clamp(1, 100); - let used = self.route_channel_capacity.saturating_sub(tx.capacity()); - let used_pct = if self.route_channel_capacity == 0 { - 100 - } else { - (used.saturating_mul(100) / self.route_channel_capacity) as u8 - }; - let high_profile = used_pct >= high_watermark_pct; - let timeout_ms = if high_profile { - high_timeout_ms - } else { - base_timeout_ms - }; - let timeout_dur = Duration::from_millis(timeout_ms); - - match tokio::time::timeout(timeout_dur, tx.send(resp)).await { - Ok(Ok(())) => RouteResult::Routed, - Ok(Err(_)) => RouteResult::ChannelClosed, - Err(_) => { - if high_profile { - RouteResult::QueueFullHigh - } else { - RouteResult::QueueFullBase - } - } - } - } - } - } - - pub async fn route_nowait(&self, id: u64, resp: MeResponse) -> RouteResult { - let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); - - let Some(tx) = tx else { - return RouteResult::NoConn; - }; - let resp = match self.attach_route_byte_permit(id, resp, Some(0)).await { - Ok(resp) => resp, - Err(result) => return result, - }; - - match tx.try_send(resp) { - Ok(()) => RouteResult::Routed, - Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, - Err(TrySendError::Full(_)) => RouteResult::QueueFullBase, - } - } - - pub async fn route_with_timeout( - &self, - id: u64, - resp: MeResponse, - timeout_ms: u64, - ) -> RouteResult { - if timeout_ms == 0 { - return self.route_nowait(id, resp).await; - } - - let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); - - let Some(tx) = tx else { - return RouteResult::NoConn; - }; - let resp = match self - .attach_route_byte_permit(id, resp, Some(timeout_ms)) - .await - { - Ok(resp) => resp, - Err(result) => return result, - }; - - match tx.try_send(resp) { - Ok(()) => RouteResult::Routed, - Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, - Err(TrySendError::Full(resp)) => { - let high_watermark_pct = self - .route_backpressure_high_watermark_pct - .load(Ordering::Relaxed) - .clamp(1, 100); - let used = self.route_channel_capacity.saturating_sub(tx.capacity()); - let used_pct = if self.route_channel_capacity == 0 { - 100 - } else { - (used.saturating_mul(100) / self.route_channel_capacity) as u8 - }; - let high_profile = used_pct >= high_watermark_pct; - let timeout_dur = Duration::from_millis(timeout_ms.max(1)); - - match tokio::time::timeout(timeout_dur, tx.send(resp)).await { - Ok(Ok(())) => RouteResult::Routed, - Ok(Err(_)) => RouteResult::ChannelClosed, - Err(_) => { - if high_profile { - RouteResult::QueueFullHigh - } else { - RouteResult::QueueFullBase - } - } - } - } - } - } - - pub async fn bind_writer(&self, conn_id: u64, writer_id: u64, meta: ConnMeta) -> bool { - let mut binding = self.binding.inner.lock().await; - // ROUTING IS THE SOURCE OF TRUTH: - // never keep/attach writer binding for a connection that is already - // absent from the routing table. - if !self.routing.map.contains_key(&conn_id) { - return false; - } - if !binding.writers.contains_key(&writer_id) { - return false; - } - - let previous_writer_id = binding.writer_for_conn.insert(conn_id, writer_id); - if let Some(previous_writer_id) = previous_writer_id - && previous_writer_id != writer_id - { - let became_empty = - if let Some(set) = binding.conns_for_writer.get_mut(&previous_writer_id) { - set.remove(&conn_id); - set.is_empty() - } else { - false - }; - if became_empty { - binding - .writer_idle_since_epoch_secs - .insert(previous_writer_id, Self::now_epoch_secs()); - } - } - - binding.meta.insert(conn_id, meta.clone()); - binding.last_meta_for_writer.insert(writer_id, meta.clone()); - binding.writer_idle_since_epoch_secs.remove(&writer_id); - binding - .conns_for_writer - .entry(writer_id) - .or_insert_with(HashSet::new) - .insert(conn_id); - self.hot_binding - .map - .insert(conn_id, HotConnBinding { writer_id, meta }); - true - } - - pub async fn mark_writer_idle(&self, writer_id: u64) { - let mut binding = self.binding.inner.lock().await; - binding - .conns_for_writer - .entry(writer_id) - .or_insert_with(HashSet::new); - binding - .writer_idle_since_epoch_secs - .entry(writer_id) - .or_insert(Self::now_epoch_secs()); - } - - pub async fn get_last_writer_meta(&self, writer_id: u64) -> Option { - let binding = self.binding.inner.lock().await; - binding.last_meta_for_writer.get(&writer_id).cloned() - } - - pub async fn writer_idle_since_snapshot(&self) -> HashMap { - let binding = self.binding.inner.lock().await; - binding.writer_idle_since_epoch_secs.clone() - } - - pub async fn writer_idle_since_for_writer_ids(&self, writer_ids: &[u64]) -> HashMap { - let binding = self.binding.inner.lock().await; - let mut out = HashMap::::with_capacity(writer_ids.len()); - for writer_id in writer_ids { - if let Some(idle_since) = binding.writer_idle_since_epoch_secs.get(writer_id).copied() { - out.insert(*writer_id, idle_since); - } - } - out - } - - pub(super) async fn writer_activity_snapshot(&self) -> WriterActivitySnapshot { - let binding = self.binding.inner.lock().await; - let mut bound_clients_by_writer = HashMap::::new(); - let mut active_sessions_by_target_dc = HashMap::::new(); - - for (writer_id, conn_ids) in &binding.conns_for_writer { - bound_clients_by_writer.insert(*writer_id, conn_ids.len()); - } - for conn_meta in binding.meta.values() { - if conn_meta.target_dc == 0 { - continue; - } - *active_sessions_by_target_dc - .entry(conn_meta.target_dc) - .or_insert(0) += 1; - } - - WriterActivitySnapshot { - bound_clients_by_writer, - active_sessions_by_target_dc, - } - } - - pub async fn get_writer(&self, conn_id: u64) -> Option { - if !self.routing.map.contains_key(&conn_id) { - return None; - } - - let writer_id = self - .hot_binding - .map - .get(&conn_id) - .map(|entry| entry.writer_id)?; - let writer = self - .writers - .map - .get(&writer_id) - .map(|entry| entry.value().clone())?; - Some(ConnWriter { - writer_id, - tx: writer, - }) - } - - /// Returns the active writer and routing metadata from one hot-binding lookup. - pub async fn get_writer_with_meta(&self, conn_id: u64) -> Option<(ConnWriter, ConnMeta)> { - if !self.routing.map.contains_key(&conn_id) { - return None; - } - - let hot = self.hot_binding.map.get(&conn_id)?; - let writer_id = hot.writer_id; - let meta = hot.meta.clone(); - let writer = self - .writers - .map - .get(&writer_id) - .map(|entry| entry.value().clone())?; - Some(( - ConnWriter { - writer_id, - tx: writer, - }, - meta, - )) - } - - pub async fn active_conn_ids(&self) -> Vec { - let binding = self.binding.inner.lock().await; - binding.writer_for_conn.keys().copied().collect() - } - - pub async fn writer_lost(&self, writer_id: u64) -> Vec { - let mut binding = self.binding.inner.lock().await; - binding.writers.remove(&writer_id); - self.writers.map.remove(&writer_id); - binding.last_meta_for_writer.remove(&writer_id); - binding.writer_idle_since_epoch_secs.remove(&writer_id); - let conns = binding - .conns_for_writer - .remove(&writer_id) - .unwrap_or_default() - .into_iter() - .collect::>(); - - let mut out = Vec::new(); - for conn_id in conns { - if binding.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { - continue; - } - binding.writer_for_conn.remove(&conn_id); - let remove_hot = self - .hot_binding - .map - .get(&conn_id) - .map(|hot| hot.writer_id == writer_id) - .unwrap_or(false); - if remove_hot { - self.hot_binding.map.remove(&conn_id); - } - if let Some(m) = binding.meta.get(&conn_id) { - out.push(BoundConn { - conn_id, - meta: m.clone(), - }); - } - } - out - } - - #[allow(dead_code)] - pub async fn get_meta(&self, conn_id: u64) -> Option { - self.hot_binding - .map - .get(&conn_id) - .map(|entry| entry.meta.clone()) - } - - pub async fn is_writer_empty(&self, writer_id: u64) -> bool { - let binding = self.binding.inner.lock().await; - binding - .conns_for_writer - .get(&writer_id) - .map(|s| s.is_empty()) - .unwrap_or(true) - } - - #[allow(dead_code)] - pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool { - let mut binding = self.binding.inner.lock().await; - let Some(conn_ids) = binding.conns_for_writer.get(&writer_id) else { - // Writer is already absent from the registry. - return true; - }; - if !conn_ids.is_empty() { - return false; - } - - binding.writers.remove(&writer_id); - self.writers.map.remove(&writer_id); - binding.last_meta_for_writer.remove(&writer_id); - binding.writer_idle_since_epoch_secs.remove(&writer_id); - binding.conns_for_writer.remove(&writer_id); - true - } - - #[allow(dead_code)] - pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet { - let binding = self.binding.inner.lock().await; - let mut out = HashSet::::with_capacity(writer_ids.len()); - for writer_id in writer_ids { - if let Some(conns) = binding.conns_for_writer.get(writer_id) - && !conns.is_empty() - { - out.insert(*writer_id); - } - } - out - } } #[cfg(test)] -mod tests { - use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - - use bytes::Bytes; - - use super::{ConnMeta, ConnRegistry, RouteResult}; - use crate::transport::middle_proxy::MeResponse; - - #[tokio::test] - async fn writer_activity_snapshot_tracks_writer_and_dc_load() { - let registry = ConnRegistry::new(); - - let (conn_a, _rx_a) = registry.register().await; - let (conn_b, _rx_b) = registry.register().await; - let (conn_c, _rx_c) = registry.register().await; - let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); - let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); - registry.register_writer(10, writer_tx_a.clone()).await; - registry.register_writer(20, writer_tx_b.clone()).await; - - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); - assert!( - registry - .bind_writer( - conn_a, - 10, - ConnMeta { - target_dc: 2, - client_addr: addr, - our_addr: addr, - proto_flags: 0, - }, - ) - .await - ); - assert!( - registry - .bind_writer( - conn_b, - 10, - ConnMeta { - target_dc: -2, - client_addr: addr, - our_addr: addr, - proto_flags: 0, - }, - ) - .await - ); - assert!( - registry - .bind_writer( - conn_c, - 20, - ConnMeta { - target_dc: 4, - client_addr: addr, - our_addr: addr, - proto_flags: 0, - }, - ) - .await - ); - - let snapshot = registry.writer_activity_snapshot().await; - assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&2)); - assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1)); - assert_eq!(snapshot.active_sessions_by_target_dc.get(&2), Some(&1)); - assert_eq!(snapshot.active_sessions_by_target_dc.get(&-2), Some(&1)); - assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1)); - } - - #[tokio::test] - async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() { - let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1); - let (conn_id, mut rx) = registry.register().await; - let routed = registry - .route_nowait( - conn_id, - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAA]), - route_permit: None, - }, - ) - .await; - assert!(matches!(routed, RouteResult::Routed)); - - let blocked = registry - .route_nowait( - conn_id, - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xBB]), - route_permit: None, - }, - ) - .await; - assert!( - matches!(blocked, RouteResult::QueueFullHigh), - "byte budget must reject data before count capacity is exhausted" - ); - - drop(rx.recv().await); - - let routed_after_drain = registry - .route_nowait( - conn_id, - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xCC]), - route_permit: None, - }, - ) - .await; - assert!( - matches!(routed_after_drain, RouteResult::Routed), - "receiving queued data must release byte permits" - ); - } - - #[tokio::test] - async fn bind_writer_rebinds_conn_atomically() { - let registry = ConnRegistry::new(); - let (conn_id, _rx) = registry.register().await; - let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); - let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); - registry.register_writer(10, writer_tx_a).await; - registry.register_writer(20, writer_tx_b).await; - - let client_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); - let first_our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 443); - let second_our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 443); - - assert!( - registry - .bind_writer( - conn_id, - 10, - ConnMeta { - target_dc: 2, - client_addr, - our_addr: first_our_addr, - proto_flags: 1, - }, - ) - .await - ); - assert!( - registry - .bind_writer( - conn_id, - 20, - ConnMeta { - target_dc: 2, - client_addr, - our_addr: second_our_addr, - proto_flags: 2, - }, - ) - .await - ); - - let writer = registry.get_writer(conn_id).await.expect("writer binding"); - assert_eq!(writer.writer_id, 20); - - let meta = registry.get_meta(conn_id).await.expect("conn meta"); - assert_eq!(meta.our_addr, second_our_addr); - assert_eq!(meta.proto_flags, 2); - - let snapshot = registry.writer_activity_snapshot().await; - assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&0)); - assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1)); - assert!( - registry - .writer_idle_since_snapshot() - .await - .contains_key(&10) - ); - } - - #[tokio::test] - async fn writer_lost_does_not_drop_rebound_conn() { - let registry = ConnRegistry::new(); - let (conn_id, _rx) = registry.register().await; - let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); - let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); - registry.register_writer(10, writer_tx_a).await; - registry.register_writer(20, writer_tx_b).await; - - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); - assert!( - registry - .bind_writer( - conn_id, - 10, - ConnMeta { - target_dc: 2, - client_addr: addr, - our_addr: addr, - proto_flags: 0, - }, - ) - .await - ); - assert!( - registry - .bind_writer( - conn_id, - 20, - ConnMeta { - target_dc: 2, - client_addr: addr, - our_addr: addr, - proto_flags: 1, - }, - ) - .await - ); - - let lost = registry.writer_lost(10).await; - assert!(lost.is_empty()); - assert_eq!( - registry - .get_writer(conn_id) - .await - .expect("writer") - .writer_id, - 20 - ); - - let removed_writer = registry.unregister(conn_id).await; - assert_eq!(removed_writer, Some(20)); - assert!(registry.is_writer_empty(20).await); - } - - #[tokio::test] - async fn bind_writer_rejects_unregistered_writer() { - let registry = ConnRegistry::new(); - let (conn_id, _rx) = registry.register().await; - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); - - assert!( - !registry - .bind_writer( - conn_id, - 10, - ConnMeta { - target_dc: 2, - client_addr: addr, - our_addr: addr, - proto_flags: 0, - }, - ) - .await - ); - assert!(registry.get_writer(conn_id).await.is_none()); - } - - #[tokio::test] - async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() { - let registry = ConnRegistry::new(); - let (conn_id, _rx) = registry.register().await; - let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); - let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); - registry.register_writer(10, writer_tx_a).await; - registry.register_writer(20, writer_tx_b).await; - - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); - assert!( - registry - .bind_writer( - conn_id, - 10, - ConnMeta { - target_dc: 2, - client_addr: addr, - our_addr: addr, - proto_flags: 0, - }, - ) - .await - ); - - let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await; - assert!(non_empty.contains(&10)); - assert!(!non_empty.contains(&20)); - assert!(!non_empty.contains(&30)); - } -} +mod tests; diff --git a/src/transport/middle_proxy/registry/tests.rs b/src/transport/middle_proxy/registry/tests.rs new file mode 100644 index 0000000..3567a4b --- /dev/null +++ b/src/transport/middle_proxy/registry/tests.rs @@ -0,0 +1,288 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use bytes::Bytes; + +use super::{ConnMeta, ConnRegistry, RouteResult}; +use crate::transport::middle_proxy::MeResponse; + +#[tokio::test] +async fn writer_activity_snapshot_tracks_writer_and_dc_load() { + let registry = ConnRegistry::new(); + + let (conn_a, _rx_a) = registry.register().await; + let (conn_b, _rx_b) = registry.register().await; + let (conn_c, _rx_c) = registry.register().await; + let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); + let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); + registry.register_writer(10, writer_tx_a.clone()).await; + registry.register_writer(20, writer_tx_b.clone()).await; + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + assert!( + registry + .bind_writer( + conn_a, + 10, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + assert!( + registry + .bind_writer( + conn_b, + 10, + ConnMeta { + target_dc: -2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + assert!( + registry + .bind_writer( + conn_c, + 20, + ConnMeta { + target_dc: 4, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + + let snapshot = registry.writer_activity_snapshot().await; + assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&2)); + assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1)); + assert_eq!(snapshot.active_sessions_by_target_dc.get(&2), Some(&1)); + assert_eq!(snapshot.active_sessions_by_target_dc.get(&-2), Some(&1)); + assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1)); +} + +#[tokio::test] +async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() { + let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1); + let (conn_id, mut rx) = registry.register().await; + let routed = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA]), + route_permit: None, + }, + ) + .await; + assert!(matches!(routed, RouteResult::Routed)); + + let blocked = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xBB]), + route_permit: None, + }, + ) + .await; + assert!( + matches!(blocked, RouteResult::QueueFullHigh), + "byte budget must reject data before count capacity is exhausted" + ); + + drop(rx.recv().await); + + let routed_after_drain = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xCC]), + route_permit: None, + }, + ) + .await; + assert!( + matches!(routed_after_drain, RouteResult::Routed), + "receiving queued data must release byte permits" + ); +} + +#[tokio::test] +async fn bind_writer_rebinds_conn_atomically() { + let registry = ConnRegistry::new(); + let (conn_id, _rx) = registry.register().await; + let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); + let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); + registry.register_writer(10, writer_tx_a).await; + registry.register_writer(20, writer_tx_b).await; + + let client_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + let first_our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 443); + let second_our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 443); + + assert!( + registry + .bind_writer( + conn_id, + 10, + ConnMeta { + target_dc: 2, + client_addr, + our_addr: first_our_addr, + proto_flags: 1, + }, + ) + .await + ); + assert!( + registry + .bind_writer( + conn_id, + 20, + ConnMeta { + target_dc: 2, + client_addr, + our_addr: second_our_addr, + proto_flags: 2, + }, + ) + .await + ); + + let writer = registry.get_writer(conn_id).await.expect("writer binding"); + assert_eq!(writer.writer_id, 20); + + let meta = registry.get_meta(conn_id).await.expect("conn meta"); + assert_eq!(meta.our_addr, second_our_addr); + assert_eq!(meta.proto_flags, 2); + + let snapshot = registry.writer_activity_snapshot().await; + assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&0)); + assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1)); + assert!( + registry + .writer_idle_since_snapshot() + .await + .contains_key(&10) + ); +} + +#[tokio::test] +async fn writer_lost_does_not_drop_rebound_conn() { + let registry = ConnRegistry::new(); + let (conn_id, _rx) = registry.register().await; + let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); + let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); + registry.register_writer(10, writer_tx_a).await; + registry.register_writer(20, writer_tx_b).await; + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + assert!( + registry + .bind_writer( + conn_id, + 10, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + assert!( + registry + .bind_writer( + conn_id, + 20, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 1, + }, + ) + .await + ); + + let lost = registry.writer_lost(10).await; + assert!(lost.is_empty()); + assert_eq!( + registry + .get_writer(conn_id) + .await + .expect("writer") + .writer_id, + 20 + ); + + let removed_writer = registry.unregister(conn_id).await; + assert_eq!(removed_writer, Some(20)); + assert!(registry.is_writer_empty(20).await); +} + +#[tokio::test] +async fn bind_writer_rejects_unregistered_writer() { + let registry = ConnRegistry::new(); + let (conn_id, _rx) = registry.register().await; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + + assert!( + !registry + .bind_writer( + conn_id, + 10, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + assert!(registry.get_writer(conn_id).await.is_none()); +} + +#[tokio::test] +async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() { + let registry = ConnRegistry::new(); + let (conn_id, _rx) = registry.register().await; + let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); + let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); + registry.register_writer(10, writer_tx_a).await; + registry.register_writer(20, writer_tx_b).await; + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + assert!( + registry + .bind_writer( + conn_id, + 10, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + + let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await; + assert!(non_empty.contains(&10)); + assert!(!non_empty.contains(&20)); + assert!(!non_empty.contains(&30)); +} diff --git a/src/transport/middle_proxy/registry/writer.rs b/src/transport/middle_proxy/registry/writer.rs new file mode 100644 index 0000000..c2817f0 --- /dev/null +++ b/src/transport/middle_proxy/registry/writer.rs @@ -0,0 +1,481 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TrySendError; + +use super::super::codec::WriterCommand; +use super::super::{MeResponse, RouteBytePermit}; +use super::{ + BoundConn, ConnMeta, ConnRegistry, ConnWriter, HotConnBinding, RouteResult, + WriterActivitySnapshot, +}; + +impl ConnRegistry { + pub async fn register_writer(&self, writer_id: u64, tx: mpsc::Sender) { + let mut binding = self.binding.inner.lock().await; + binding.writers.insert(writer_id, tx.clone()); + binding + .conns_for_writer + .entry(writer_id) + .or_insert_with(HashSet::new); + self.writers.map.insert(writer_id, tx); + } + + /// Unregister connection, returning associated writer_id if any. + pub async fn unregister(&self, id: u64) -> Option { + self.routing.map.remove(&id); + self.routing.byte_budget.remove(&id); + self.hot_binding.map.remove(&id); + let mut binding = self.binding.inner.lock().await; + binding.meta.remove(&id); + if let Some(writer_id) = binding.writer_for_conn.remove(&id) { + let became_empty = if let Some(set) = binding.conns_for_writer.get_mut(&writer_id) { + set.remove(&id); + set.is_empty() + } else { + false + }; + if became_empty { + binding + .writer_idle_since_epoch_secs + .insert(writer_id, Self::now_epoch_secs()); + } + return Some(writer_id); + } + None + } + + async fn attach_route_byte_permit( + &self, + id: u64, + resp: MeResponse, + timeout_ms: Option, + ) -> std::result::Result { + let MeResponse::Data { + flags, + data, + route_permit, + } = resp + else { + return Ok(resp); + }; + + if route_permit.is_some() { + return Ok(MeResponse::Data { + flags, + data, + route_permit, + }); + } + + let Some(semaphore) = self + .routing + .byte_budget + .get(&id) + .map(|entry| entry.value().clone()) + else { + return Err(RouteResult::NoConn); + }; + let permits = Self::route_data_permits(data.len()); + let permit = match timeout_ms { + Some(0) => semaphore + .try_acquire_many_owned(permits) + .map_err(|_| RouteResult::QueueFullHigh)?, + Some(timeout_ms) => { + let acquire = semaphore.acquire_many_owned(permits); + match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire).await + { + Ok(Ok(permit)) => permit, + Ok(Err(_)) => return Err(RouteResult::ChannelClosed), + Err(_) => return Err(RouteResult::QueueFullHigh), + } + } + None => semaphore + .acquire_many_owned(permits) + .await + .map_err(|_| RouteResult::ChannelClosed)?, + }; + + Ok(MeResponse::Data { + flags, + data, + route_permit: Some(RouteBytePermit::new(permit)), + }) + } + + #[allow(dead_code)] + pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { + let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); + + let Some(tx) = tx else { + return RouteResult::NoConn; + }; + + let base_timeout_ms = self + .route_backpressure_base_timeout_ms + .load(Ordering::Relaxed) + .max(1); + let resp = match self + .attach_route_byte_permit(id, resp, Some(base_timeout_ms)) + .await + { + Ok(resp) => resp, + Err(result) => return result, + }; + + match tx.try_send(resp) { + Ok(()) => RouteResult::Routed, + Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, + Err(TrySendError::Full(resp)) => { + // Absorb short bursts without dropping/closing the session immediately. + let high_timeout_ms = self + .route_backpressure_high_timeout_ms + .load(Ordering::Relaxed) + .max(base_timeout_ms); + let high_watermark_pct = self + .route_backpressure_high_watermark_pct + .load(Ordering::Relaxed) + .clamp(1, 100); + let used = self.route_channel_capacity.saturating_sub(tx.capacity()); + let used_pct = if self.route_channel_capacity == 0 { + 100 + } else { + (used.saturating_mul(100) / self.route_channel_capacity) as u8 + }; + let high_profile = used_pct >= high_watermark_pct; + let timeout_ms = if high_profile { + high_timeout_ms + } else { + base_timeout_ms + }; + let timeout_dur = Duration::from_millis(timeout_ms); + + match tokio::time::timeout(timeout_dur, tx.send(resp)).await { + Ok(Ok(())) => RouteResult::Routed, + Ok(Err(_)) => RouteResult::ChannelClosed, + Err(_) => { + if high_profile { + RouteResult::QueueFullHigh + } else { + RouteResult::QueueFullBase + } + } + } + } + } + } + + pub async fn route_nowait(&self, id: u64, resp: MeResponse) -> RouteResult { + let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); + + let Some(tx) = tx else { + return RouteResult::NoConn; + }; + let resp = match self.attach_route_byte_permit(id, resp, Some(0)).await { + Ok(resp) => resp, + Err(result) => return result, + }; + + match tx.try_send(resp) { + Ok(()) => RouteResult::Routed, + Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, + Err(TrySendError::Full(_)) => RouteResult::QueueFullBase, + } + } + + pub async fn route_with_timeout( + &self, + id: u64, + resp: MeResponse, + timeout_ms: u64, + ) -> RouteResult { + if timeout_ms == 0 { + return self.route_nowait(id, resp).await; + } + + let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); + + let Some(tx) = tx else { + return RouteResult::NoConn; + }; + let resp = match self + .attach_route_byte_permit(id, resp, Some(timeout_ms)) + .await + { + Ok(resp) => resp, + Err(result) => return result, + }; + + match tx.try_send(resp) { + Ok(()) => RouteResult::Routed, + Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, + Err(TrySendError::Full(resp)) => { + let high_watermark_pct = self + .route_backpressure_high_watermark_pct + .load(Ordering::Relaxed) + .clamp(1, 100); + let used = self.route_channel_capacity.saturating_sub(tx.capacity()); + let used_pct = if self.route_channel_capacity == 0 { + 100 + } else { + (used.saturating_mul(100) / self.route_channel_capacity) as u8 + }; + let high_profile = used_pct >= high_watermark_pct; + let timeout_dur = Duration::from_millis(timeout_ms.max(1)); + + match tokio::time::timeout(timeout_dur, tx.send(resp)).await { + Ok(Ok(())) => RouteResult::Routed, + Ok(Err(_)) => RouteResult::ChannelClosed, + Err(_) => { + if high_profile { + RouteResult::QueueFullHigh + } else { + RouteResult::QueueFullBase + } + } + } + } + } + } + + pub async fn bind_writer(&self, conn_id: u64, writer_id: u64, meta: ConnMeta) -> bool { + let mut binding = self.binding.inner.lock().await; + // ROUTING IS THE SOURCE OF TRUTH: + // never keep/attach writer binding for a connection that is already + // absent from the routing table. + if !self.routing.map.contains_key(&conn_id) { + return false; + } + if !binding.writers.contains_key(&writer_id) { + return false; + } + + let previous_writer_id = binding.writer_for_conn.insert(conn_id, writer_id); + if let Some(previous_writer_id) = previous_writer_id + && previous_writer_id != writer_id + { + let became_empty = + if let Some(set) = binding.conns_for_writer.get_mut(&previous_writer_id) { + set.remove(&conn_id); + set.is_empty() + } else { + false + }; + if became_empty { + binding + .writer_idle_since_epoch_secs + .insert(previous_writer_id, Self::now_epoch_secs()); + } + } + + binding.meta.insert(conn_id, meta.clone()); + binding.last_meta_for_writer.insert(writer_id, meta.clone()); + binding.writer_idle_since_epoch_secs.remove(&writer_id); + binding + .conns_for_writer + .entry(writer_id) + .or_insert_with(HashSet::new) + .insert(conn_id); + self.hot_binding + .map + .insert(conn_id, HotConnBinding { writer_id, meta }); + true + } + + pub async fn mark_writer_idle(&self, writer_id: u64) { + let mut binding = self.binding.inner.lock().await; + binding + .conns_for_writer + .entry(writer_id) + .or_insert_with(HashSet::new); + binding + .writer_idle_since_epoch_secs + .entry(writer_id) + .or_insert(Self::now_epoch_secs()); + } + + pub async fn get_last_writer_meta(&self, writer_id: u64) -> Option { + let binding = self.binding.inner.lock().await; + binding.last_meta_for_writer.get(&writer_id).cloned() + } + + pub async fn writer_idle_since_snapshot(&self) -> HashMap { + let binding = self.binding.inner.lock().await; + binding.writer_idle_since_epoch_secs.clone() + } + + pub async fn writer_idle_since_for_writer_ids(&self, writer_ids: &[u64]) -> HashMap { + let binding = self.binding.inner.lock().await; + let mut out = HashMap::::with_capacity(writer_ids.len()); + for writer_id in writer_ids { + if let Some(idle_since) = binding.writer_idle_since_epoch_secs.get(writer_id).copied() { + out.insert(*writer_id, idle_since); + } + } + out + } + + pub(in crate::transport::middle_proxy) async fn writer_activity_snapshot( + &self, + ) -> WriterActivitySnapshot { + let binding = self.binding.inner.lock().await; + let mut bound_clients_by_writer = HashMap::::new(); + let mut active_sessions_by_target_dc = HashMap::::new(); + + for (writer_id, conn_ids) in &binding.conns_for_writer { + bound_clients_by_writer.insert(*writer_id, conn_ids.len()); + } + for conn_meta in binding.meta.values() { + if conn_meta.target_dc == 0 { + continue; + } + *active_sessions_by_target_dc + .entry(conn_meta.target_dc) + .or_insert(0) += 1; + } + + WriterActivitySnapshot { + bound_clients_by_writer, + active_sessions_by_target_dc, + } + } + + pub async fn get_writer(&self, conn_id: u64) -> Option { + if !self.routing.map.contains_key(&conn_id) { + return None; + } + + let writer_id = self + .hot_binding + .map + .get(&conn_id) + .map(|entry| entry.writer_id)?; + let writer = self + .writers + .map + .get(&writer_id) + .map(|entry| entry.value().clone())?; + Some(ConnWriter { + writer_id, + tx: writer, + }) + } + + /// Returns the active writer and routing metadata from one hot-binding lookup. + pub async fn get_writer_with_meta(&self, conn_id: u64) -> Option<(ConnWriter, ConnMeta)> { + if !self.routing.map.contains_key(&conn_id) { + return None; + } + + let hot = self.hot_binding.map.get(&conn_id)?; + let writer_id = hot.writer_id; + let meta = hot.meta.clone(); + let writer = self + .writers + .map + .get(&writer_id) + .map(|entry| entry.value().clone())?; + Some(( + ConnWriter { + writer_id, + tx: writer, + }, + meta, + )) + } + + pub async fn active_conn_ids(&self) -> Vec { + let binding = self.binding.inner.lock().await; + binding.writer_for_conn.keys().copied().collect() + } + + pub async fn writer_lost(&self, writer_id: u64) -> Vec { + let mut binding = self.binding.inner.lock().await; + binding.writers.remove(&writer_id); + self.writers.map.remove(&writer_id); + binding.last_meta_for_writer.remove(&writer_id); + binding.writer_idle_since_epoch_secs.remove(&writer_id); + let conns = binding + .conns_for_writer + .remove(&writer_id) + .unwrap_or_default() + .into_iter() + .collect::>(); + + let mut out = Vec::new(); + for conn_id in conns { + if binding.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { + continue; + } + binding.writer_for_conn.remove(&conn_id); + let remove_hot = self + .hot_binding + .map + .get(&conn_id) + .map(|hot| hot.writer_id == writer_id) + .unwrap_or(false); + if remove_hot { + self.hot_binding.map.remove(&conn_id); + } + if let Some(m) = binding.meta.get(&conn_id) { + out.push(BoundConn { + conn_id, + meta: m.clone(), + }); + } + } + out + } + + #[allow(dead_code)] + pub async fn get_meta(&self, conn_id: u64) -> Option { + self.hot_binding + .map + .get(&conn_id) + .map(|entry| entry.meta.clone()) + } + + pub async fn is_writer_empty(&self, writer_id: u64) -> bool { + let binding = self.binding.inner.lock().await; + binding + .conns_for_writer + .get(&writer_id) + .map(|s| s.is_empty()) + .unwrap_or(true) + } + + #[allow(dead_code)] + pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool { + let mut binding = self.binding.inner.lock().await; + let Some(conn_ids) = binding.conns_for_writer.get(&writer_id) else { + // Writer is already absent from the registry. + return true; + }; + if !conn_ids.is_empty() { + return false; + } + + binding.writers.remove(&writer_id); + self.writers.map.remove(&writer_id); + binding.last_meta_for_writer.remove(&writer_id); + binding.writer_idle_since_epoch_secs.remove(&writer_id); + binding.conns_for_writer.remove(&writer_id); + true + } + + #[allow(dead_code)] + pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet { + let binding = self.binding.inner.lock().await; + let mut out = HashSet::::with_capacity(writer_ids.len()); + for writer_id in writer_ids { + if let Some(conns) = binding.conns_for_writer.get(writer_id) + && !conns.is_empty() + { + out.insert(*writer_id); + } + } + out + } +} diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 8217588..847d60e 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -1,7 +1,6 @@ #![allow(clippy::too_many_arguments)] use std::cmp::Reverse; -use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::Ordering; @@ -10,16 +9,14 @@ use std::time::{Duration, Instant}; use tokio::sync::mpsc::error::TrySendError; use tracing::{debug, warn}; +use super::MePool; +use super::codec::{ProxyReqCommand, WriterCommand}; +use super::registry::ConnMeta; +use super::wire::build_proxy_req_payload; use crate::config::{MeRouteNoWriterMode, MeWriterPickMode}; use crate::error::{ProxyError, Result}; use crate::network::IpFamily; -use crate::protocol::constants::{RPC_CLOSE_CONN_U32, RPC_CLOSE_EXT_U32}; - -use super::MePool; -use super::codec::{WriterCommand, build_control_payload}; -use super::pool::WriterContour; -use super::registry::ConnMeta; -use super::wire::build_proxy_req_payload; +use crate::stream::PooledBuffer; use rand::seq::SliceRandom; const IDLE_WRITER_PENALTY_MID_SECS: u64 = 45; @@ -33,6 +30,21 @@ const PICK_PENALTY_DRAINING: u64 = 600; const PICK_PENALTY_STALE: u64 = 300; const PICK_PENALTY_DEGRADED: u64 = 250; +mod close; +mod recovery; +mod selection; + +fn proxy_tag_array(tag: Option<&[u8]>) -> Option<[u8; 16]> { + tag.and_then(|tag| <[u8; 16]>::try_from(tag).ok()) +} + +fn proxy_req_payload_from_command(cmd: WriterCommand) -> Option { + match cmd { + WriterCommand::ProxyReq(command) => Some(command.payload), + _ => None, + } +} + impl MePool { /// Send RPC_PROXY_REQ. `tag_override`: per-user ad_tag (from access.user_ad_tags); if None, uses pool default. pub async fn send_proxy_req( @@ -84,14 +96,10 @@ impl MePool { let mut hybrid_wait_current = hybrid_wait_step; loop { - if let Some((current, current_meta)) = - self.registry.get_writer_with_meta(conn_id).await + if let Some((current, current_meta)) = self.registry.get_writer_with_meta(conn_id).await { let (current_payload, _) = build_routed_payload(current_meta.our_addr); - match current - .tx - .try_send(WriterCommand::Data(current_payload)) - { + match current.tx.try_send(WriterCommand::Data(current_payload)) { Ok(()) => { self.note_hybrid_route_success(); return Ok(()); @@ -528,401 +536,93 @@ impl MePool { } } - async fn wait_for_writer_until(&self, deadline: Instant) -> bool { - let mut rx = self.writer_epoch.subscribe(); - if !self.writers.read().await.is_empty() { - return true; - } - let now = Instant::now(); - if now >= deadline { - return !self.writers.read().await.is_empty(); - } - let timeout = deadline.saturating_duration_since(now); - if tokio::time::timeout(timeout, rx.changed()).await.is_ok() { - return !self.writers.read().await.is_empty(); - } - !self.writers.read().await.is_empty() - } - - async fn wait_for_candidate_until(&self, routed_dc: i32, deadline: Instant) -> bool { - let mut rx = self.writer_epoch.subscribe(); - loop { - if self.has_candidate_for_target_dc(routed_dc).await { - return true; - } - - let now = Instant::now(); - if now >= deadline { - return self.has_candidate_for_target_dc(routed_dc).await; - } - - if self.has_candidate_for_target_dc(routed_dc).await { - return true; - } - let remaining = deadline.saturating_duration_since(Instant::now()); - if remaining.is_zero() { - return self.has_candidate_for_target_dc(routed_dc).await; - } - if tokio::time::timeout(remaining, rx.changed()).await.is_err() { - return self.has_candidate_for_target_dc(routed_dc).await; - } - } - } - - async fn has_candidate_for_target_dc(&self, routed_dc: i32) -> bool { - let writers_snapshot = { - let ws = self.writers.read().await; - if ws.is_empty() { - return false; - } - ws.clone() - }; - let mut candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, routed_dc, false) - .await; - if candidate_indices.is_empty() { - candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) - .await; - } - !candidate_indices.is_empty() - } - - async fn trigger_async_recovery_for_target_dc(self: &Arc, routed_dc: i32) -> bool { - let endpoints = self.endpoint_candidates_for_target_dc(routed_dc).await; - if endpoints.is_empty() { - return false; - } - self.stats.increment_me_async_recovery_trigger_total(); - for addr in endpoints.into_iter().take(8) { - self.trigger_immediate_refill_for_dc(addr, routed_dc); - } - true - } - - async fn trigger_async_recovery_global(self: &Arc) { - self.stats.increment_me_async_recovery_trigger_total(); - let mut seen = HashSet::<(i32, SocketAddr)>::new(); - for family in self.family_order() { - let map_guard = match family { - IpFamily::V4 => self.proxy_map_v4.read().await, - IpFamily::V6 => self.proxy_map_v6.read().await, - }; - for (dc, addrs) in map_guard.iter() { - for (ip, port) in addrs { - let addr = SocketAddr::new(*ip, *port); - if seen.insert((*dc, addr)) { - self.trigger_immediate_refill_for_dc(addr, *dc); - } - if seen.len() >= 8 { - return; - } - } - } - } - } - - async fn endpoint_candidates_for_target_dc(&self, routed_dc: i32) -> Vec { - self.preferred_endpoints_for_dc(routed_dc).await - } - - async fn maybe_trigger_hybrid_recovery( + /// Send RPC_PROXY_REQ while keeping the first bound-writer path allocation-light. + pub async fn send_proxy_req_pooled( self: &Arc, - routed_dc: i32, - hybrid_recovery_round: &mut u32, - hybrid_last_recovery_at: &mut Option, - hybrid_wait_step: Duration, - ) { - if !self.try_consume_hybrid_recovery_trigger_slot(HYBRID_RECOVERY_TRIGGER_MIN_INTERVAL_MS) { - return; - } - if let Some(last) = *hybrid_last_recovery_at - && last.elapsed() < hybrid_wait_step - { - return; - } + conn_id: u64, + target_dc: i16, + client_addr: SocketAddr, + our_addr: SocketAddr, + payload: PooledBuffer, + proto_flags: u32, + tag_override: Option<[u8; 16]>, + ) -> Result<()> { + let tag = tag_override.or_else(|| proxy_tag_array(self.proxy_tag.as_deref())); - let round = *hybrid_recovery_round; - let target_triggered = self.trigger_async_recovery_for_target_dc(routed_dc).await; - if !target_triggered || round.is_multiple_of(HYBRID_GLOBAL_BURST_PERIOD_ROUNDS) { - self.trigger_async_recovery_global().await; - } - *hybrid_recovery_round = round.saturating_add(1); - *hybrid_last_recovery_at = Some(Instant::now()); - } - - fn hybrid_total_wait_budget(&self) -> Duration { - let base = self - .route_runtime - .me_route_hybrid_max_wait - .max(Duration::from_millis(50)); - let now_ms = Self::now_epoch_millis(); - let last_success_ms = self - .route_runtime - .me_route_last_success_epoch_ms - .load(Ordering::Relaxed); - if last_success_ms != 0 - && now_ms.saturating_sub(last_success_ms) <= HYBRID_RECENT_SUCCESS_WINDOW_MS - { - return base.saturating_mul(2); - } - base - } - - fn note_hybrid_route_success(&self) { - self.route_runtime - .me_route_last_success_epoch_ms - .store(Self::now_epoch_millis(), Ordering::Relaxed); - } - - fn on_hybrid_timeout(&self, deadline: Instant, routed_dc: i32) { - self.stats.increment_me_hybrid_timeout_total(); - let now_ms = Self::now_epoch_millis(); - let mut last_warn_ms = self - .route_runtime - .me_route_hybrid_timeout_warn_epoch_ms - .load(Ordering::Relaxed); - while now_ms.saturating_sub(last_warn_ms) >= HYBRID_TIMEOUT_WARN_RATE_LIMIT_MS { - match self - .route_runtime - .me_route_hybrid_timeout_warn_epoch_ms - .compare_exchange_weak(last_warn_ms, now_ms, Ordering::AcqRel, Ordering::Relaxed) - { - Ok(_) => { - warn!( - routed_dc, - budget_ms = self.hybrid_total_wait_budget().as_millis() as u64, - elapsed_ms = deadline.elapsed().as_millis() as u64, - "ME hybrid route timeout reached" - ); - break; + if let Some((current, current_meta)) = self.registry.get_writer_with_meta(conn_id).await { + let command = WriterCommand::ProxyReq(ProxyReqCommand { + conn_id, + client_addr, + our_addr: current_meta.our_addr, + proto_flags, + proxy_tag: tag, + payload, + }); + match current.tx.try_send(command) { + Ok(()) => { + self.note_hybrid_route_success(); + return Ok(()); } - Err(actual) => last_warn_ms = actual, - } - } - } - - fn try_consume_hybrid_recovery_trigger_slot(&self, min_interval_ms: u64) -> bool { - let now_ms = Self::now_epoch_millis(); - let mut last_trigger_ms = self - .route_runtime - .me_async_recovery_last_trigger_epoch_ms - .load(Ordering::Relaxed); - loop { - if now_ms.saturating_sub(last_trigger_ms) < min_interval_ms { - return false; - } - match self - .route_runtime - .me_async_recovery_last_trigger_epoch_ms - .compare_exchange_weak(last_trigger_ms, now_ms, Ordering::AcqRel, Ordering::Relaxed) - { - Ok(_) => return true, - Err(actual) => last_trigger_ms = actual, - } - } - } - - pub async fn send_close(self: &Arc, conn_id: u64) -> Result<()> { - if let Some(w) = self.registry.get_writer(conn_id).await { - let payload = build_control_payload(RPC_CLOSE_EXT_U32, conn_id); - if w.tx - .send(WriterCommand::ControlAndFlush(payload)) - .await - .is_err() - { - debug!("ME close write failed"); - self.remove_writer_and_close_clients(w.writer_id).await; - } - } else { - debug!(conn_id, "ME close skipped (writer missing)"); - } - - self.registry.unregister(conn_id).await; - Ok(()) - } - - pub async fn send_close_conn(self: &Arc, conn_id: u64) -> Result<()> { - if let Some(w) = self.registry.get_writer(conn_id).await { - let payload = build_control_payload(RPC_CLOSE_CONN_U32, conn_id); - match w.tx.try_send(WriterCommand::ControlAndFlush(payload)) { - Ok(()) => {} - Err(TrySendError::Full(cmd)) => { - let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await; - } - Err(TrySendError::Closed(_)) => { - debug!(conn_id, "ME close_conn skipped: writer channel closed"); + Err(TrySendError::Full(cmd)) => match current.tx.send(cmd).await { + Ok(()) => { + self.note_hybrid_route_success(); + return Ok(()); + } + Err(send_err) => { + let Some(payload) = proxy_req_payload_from_command(send_err.0) else { + return Err(ProxyError::Proxy( + "ME writer rejected unexpected command type".into(), + )); + }; + warn!(writer_id = current.writer_id, "ME writer channel closed"); + self.remove_writer_and_close_clients(current.writer_id) + .await; + return self + .send_proxy_req( + conn_id, + target_dc, + client_addr, + our_addr, + payload.as_ref(), + proto_flags, + tag.as_ref().map(|tag| tag.as_slice()), + ) + .await; + } + }, + Err(TrySendError::Closed(cmd)) => { + let Some(payload) = proxy_req_payload_from_command(cmd) else { + return Err(ProxyError::Proxy( + "ME writer rejected unexpected command type".into(), + )); + }; + warn!(writer_id = current.writer_id, "ME writer channel closed"); + self.remove_writer_and_close_clients(current.writer_id) + .await; + return self + .send_proxy_req( + conn_id, + target_dc, + client_addr, + our_addr, + payload.as_ref(), + proto_flags, + tag.as_ref().map(|tag| tag.as_slice()), + ) + .await; } } - } else { - debug!(conn_id, "ME close_conn skipped (writer missing)"); } - self.registry.unregister(conn_id).await; - Ok(()) - } - - pub async fn shutdown_send_close_conn_all(self: &Arc) -> usize { - let conn_ids = self.registry.active_conn_ids().await; - let total = conn_ids.len(); - for conn_id in conn_ids { - let _ = self.send_close_conn(conn_id).await; - } - total - } - - pub fn connection_count(&self) -> usize { - self.conn_count.load(Ordering::Relaxed) - } - - pub(super) async fn candidate_indices_for_dc( - &self, - writers: &[super::pool::MeWriter], - routed_dc: i32, - include_warm: bool, - ) -> Vec { - let preferred = self.preferred_endpoints_for_dc(routed_dc).await; - if preferred.is_empty() { - return Vec::new(); - } - - let mut out = Vec::new(); - for (idx, w) in writers.iter().enumerate() { - if !self.writer_eligible_for_selection(w, include_warm) { - continue; - } - if w.writer_dc == routed_dc && preferred.contains(&w.addr) { - out.push(idx); - } - } - out - } - - fn writer_eligible_for_selection( - &self, - writer: &super::pool::MeWriter, - include_warm: bool, - ) -> bool { - if !self.writer_accepts_new_binding(writer) { - return false; - } - - match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { - WriterContour::Active => true, - WriterContour::Warm => include_warm, - WriterContour::Draining => true, - } - } - - fn writer_contour_rank_for_selection(&self, writer: &super::pool::MeWriter) -> usize { - match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { - WriterContour::Active => 0, - WriterContour::Warm => 1, - WriterContour::Draining => 2, - } - } - - fn writer_idle_rank_for_selection( - &self, - writer: &super::pool::MeWriter, - idle_since_by_writer: &HashMap, - now_epoch_secs: u64, - ) -> usize { - let Some(idle_since) = idle_since_by_writer.get(&writer.id).copied() else { - return 0; - }; - let idle_age_secs = now_epoch_secs.saturating_sub(idle_since); - if idle_age_secs >= IDLE_WRITER_PENALTY_HIGH_SECS { - 2 - } else if idle_age_secs >= IDLE_WRITER_PENALTY_MID_SECS { - 1 - } else { - 0 - } - } - - fn writer_pick_score( - &self, - writer: &super::pool::MeWriter, - idle_since_by_writer: &HashMap, - now_epoch_secs: u64, - ) -> u64 { - let contour_penalty = match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { - WriterContour::Active => 0, - WriterContour::Warm => PICK_PENALTY_WARM, - WriterContour::Draining => PICK_PENALTY_DRAINING, - }; - let stale_penalty = if writer.generation < self.current_generation() { - PICK_PENALTY_STALE - } else { - 0 - }; - let degraded_penalty = if writer.degraded.load(Ordering::Relaxed) { - PICK_PENALTY_DEGRADED - } else { - 0 - }; - let idle_penalty = - (self.writer_idle_rank_for_selection(writer, idle_since_by_writer, now_epoch_secs) - as u64) - * 100; - let queue_cap = self.writer_lifecycle.writer_cmd_channel_capacity.max(1) as u64; - let queue_remaining = writer.tx.capacity() as u64; - let queue_used = queue_cap.saturating_sub(queue_remaining.min(queue_cap)); - let queue_util_pct = queue_used.saturating_mul(100) / queue_cap; - let queue_penalty = queue_util_pct.saturating_mul(4); - let rtt_penalty = - ((writer.rtt_ema_ms_x10.load(Ordering::Relaxed) as u64).saturating_add(5) / 10) - .min(400); - - contour_penalty - .saturating_add(stale_penalty) - .saturating_add(degraded_penalty) - .saturating_add(idle_penalty) - .saturating_add(queue_penalty) - .saturating_add(rtt_penalty) - } - - fn p2c_ordered_candidate_indices( - &self, - candidate_indices: &[usize], - writers_snapshot: &[super::pool::MeWriter], - idle_since_by_writer: &HashMap, - now_epoch_secs: u64, - start: usize, - sample_size: usize, - ) -> Vec { - let total = candidate_indices.len(); - if total == 0 { - return Vec::new(); - } - - let mut sampled = Vec::::with_capacity(sample_size.min(total)); - let mut seen = HashSet::::with_capacity(total); - for offset in 0..sample_size.min(total) { - let idx = candidate_indices[(start + offset) % total]; - if seen.insert(idx) { - sampled.push(idx); - } - } - - sampled.sort_by_key(|idx| { - let writer = &writers_snapshot[*idx]; - ( - self.writer_pick_score(writer, idle_since_by_writer, now_epoch_secs), - writer.addr, - writer.id, - ) - }); - - let mut ordered = Vec::::with_capacity(total); - ordered.extend(sampled.iter().copied()); - for offset in 0..total { - let idx = candidate_indices[(start + offset) % total]; - if seen.insert(idx) { - ordered.push(idx); - } - } - ordered + self.send_proxy_req( + conn_id, + target_dc, + client_addr, + our_addr, + payload.as_ref(), + proto_flags, + tag.as_ref().map(|tag| tag.as_slice()), + ) + .await } } diff --git a/src/transport/middle_proxy/send/close.rs b/src/transport/middle_proxy/send/close.rs new file mode 100644 index 0000000..ddd25c9 --- /dev/null +++ b/src/transport/middle_proxy/send/close.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use tokio::sync::mpsc::error::TrySendError; +use tracing::debug; + +use crate::error::Result; +use crate::protocol::constants::{RPC_CLOSE_CONN_U32, RPC_CLOSE_EXT_U32}; + +use super::super::MePool; +use super::super::codec::{WriterCommand, build_control_payload}; + +impl MePool { + pub async fn send_close(self: &Arc, conn_id: u64) -> Result<()> { + if let Some(w) = self.registry.get_writer(conn_id).await { + let payload = build_control_payload(RPC_CLOSE_EXT_U32, conn_id); + if w.tx + .send(WriterCommand::ControlAndFlush(payload)) + .await + .is_err() + { + debug!("ME close write failed"); + self.remove_writer_and_close_clients(w.writer_id).await; + } + } else { + debug!(conn_id, "ME close skipped (writer missing)"); + } + + self.registry.unregister(conn_id).await; + Ok(()) + } + + pub async fn send_close_conn(self: &Arc, conn_id: u64) -> Result<()> { + if let Some(w) = self.registry.get_writer(conn_id).await { + let payload = build_control_payload(RPC_CLOSE_CONN_U32, conn_id); + match w.tx.try_send(WriterCommand::ControlAndFlush(payload)) { + Ok(()) => {} + Err(TrySendError::Full(cmd)) => { + let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await; + } + Err(TrySendError::Closed(_)) => { + debug!(conn_id, "ME close_conn skipped: writer channel closed"); + } + } + } else { + debug!(conn_id, "ME close_conn skipped (writer missing)"); + } + + self.registry.unregister(conn_id).await; + Ok(()) + } + + pub async fn shutdown_send_close_conn_all(self: &Arc) -> usize { + let conn_ids = self.registry.active_conn_ids().await; + let total = conn_ids.len(); + for conn_id in conn_ids { + let _ = self.send_close_conn(conn_id).await; + } + total + } + + pub fn connection_count(&self) -> usize { + self.conn_count.load(Ordering::Relaxed) + } +} diff --git a/src/transport/middle_proxy/send/recovery.rs b/src/transport/middle_proxy/send/recovery.rs new file mode 100644 index 0000000..85772da --- /dev/null +++ b/src/transport/middle_proxy/send/recovery.rs @@ -0,0 +1,218 @@ +use std::collections::HashSet; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::time::{Duration, Instant}; + +use tracing::warn; + +use crate::network::IpFamily; + +use super::super::MePool; +use super::{ + HYBRID_GLOBAL_BURST_PERIOD_ROUNDS, HYBRID_RECENT_SUCCESS_WINDOW_MS, + HYBRID_RECOVERY_TRIGGER_MIN_INTERVAL_MS, HYBRID_TIMEOUT_WARN_RATE_LIMIT_MS, +}; + +impl MePool { + pub(super) async fn wait_for_writer_until(&self, deadline: Instant) -> bool { + let mut rx = self.writer_epoch.subscribe(); + if !self.writers.read().await.is_empty() { + return true; + } + let now = Instant::now(); + if now >= deadline { + return !self.writers.read().await.is_empty(); + } + let timeout = deadline.saturating_duration_since(now); + if tokio::time::timeout(timeout, rx.changed()).await.is_ok() { + return !self.writers.read().await.is_empty(); + } + !self.writers.read().await.is_empty() + } + + pub(super) async fn wait_for_candidate_until(&self, routed_dc: i32, deadline: Instant) -> bool { + let mut rx = self.writer_epoch.subscribe(); + loop { + if self.has_candidate_for_target_dc(routed_dc).await { + return true; + } + + let now = Instant::now(); + if now >= deadline { + return self.has_candidate_for_target_dc(routed_dc).await; + } + + if self.has_candidate_for_target_dc(routed_dc).await { + return true; + } + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return self.has_candidate_for_target_dc(routed_dc).await; + } + if tokio::time::timeout(remaining, rx.changed()).await.is_err() { + return self.has_candidate_for_target_dc(routed_dc).await; + } + } + } + + pub(super) async fn has_candidate_for_target_dc(&self, routed_dc: i32) -> bool { + let writers_snapshot = { + let ws = self.writers.read().await; + if ws.is_empty() { + return false; + } + ws.clone() + }; + let mut candidate_indices = self + .candidate_indices_for_dc(&writers_snapshot, routed_dc, false) + .await; + if candidate_indices.is_empty() { + candidate_indices = self + .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) + .await; + } + !candidate_indices.is_empty() + } + + pub(super) async fn trigger_async_recovery_for_target_dc( + self: &Arc, + routed_dc: i32, + ) -> bool { + let endpoints = self.endpoint_candidates_for_target_dc(routed_dc).await; + if endpoints.is_empty() { + return false; + } + self.stats.increment_me_async_recovery_trigger_total(); + for addr in endpoints.into_iter().take(8) { + self.trigger_immediate_refill_for_dc(addr, routed_dc); + } + true + } + + pub(super) async fn trigger_async_recovery_global(self: &Arc) { + self.stats.increment_me_async_recovery_trigger_total(); + let mut seen = HashSet::<(i32, SocketAddr)>::new(); + for family in self.family_order() { + let map_guard = match family { + IpFamily::V4 => self.proxy_map_v4.read().await, + IpFamily::V6 => self.proxy_map_v6.read().await, + }; + for (dc, addrs) in map_guard.iter() { + for (ip, port) in addrs { + let addr = SocketAddr::new(*ip, *port); + if seen.insert((*dc, addr)) { + self.trigger_immediate_refill_for_dc(addr, *dc); + } + if seen.len() >= 8 { + return; + } + } + } + } + } + + pub(super) async fn endpoint_candidates_for_target_dc( + &self, + routed_dc: i32, + ) -> Vec { + self.preferred_endpoints_for_dc(routed_dc).await + } + + pub(super) async fn maybe_trigger_hybrid_recovery( + self: &Arc, + routed_dc: i32, + hybrid_recovery_round: &mut u32, + hybrid_last_recovery_at: &mut Option, + hybrid_wait_step: Duration, + ) { + if !self.try_consume_hybrid_recovery_trigger_slot(HYBRID_RECOVERY_TRIGGER_MIN_INTERVAL_MS) { + return; + } + if let Some(last) = *hybrid_last_recovery_at + && last.elapsed() < hybrid_wait_step + { + return; + } + + let round = *hybrid_recovery_round; + let target_triggered = self.trigger_async_recovery_for_target_dc(routed_dc).await; + if !target_triggered || round.is_multiple_of(HYBRID_GLOBAL_BURST_PERIOD_ROUNDS) { + self.trigger_async_recovery_global().await; + } + *hybrid_recovery_round = round.saturating_add(1); + *hybrid_last_recovery_at = Some(Instant::now()); + } + + pub(super) fn hybrid_total_wait_budget(&self) -> Duration { + let base = self + .route_runtime + .me_route_hybrid_max_wait + .max(Duration::from_millis(50)); + let now_ms = Self::now_epoch_millis(); + let last_success_ms = self + .route_runtime + .me_route_last_success_epoch_ms + .load(Ordering::Relaxed); + if last_success_ms != 0 + && now_ms.saturating_sub(last_success_ms) <= HYBRID_RECENT_SUCCESS_WINDOW_MS + { + return base.saturating_mul(2); + } + base + } + + pub(super) fn note_hybrid_route_success(&self) { + self.route_runtime + .me_route_last_success_epoch_ms + .store(Self::now_epoch_millis(), Ordering::Relaxed); + } + + pub(super) fn on_hybrid_timeout(&self, deadline: Instant, routed_dc: i32) { + self.stats.increment_me_hybrid_timeout_total(); + let now_ms = Self::now_epoch_millis(); + let mut last_warn_ms = self + .route_runtime + .me_route_hybrid_timeout_warn_epoch_ms + .load(Ordering::Relaxed); + while now_ms.saturating_sub(last_warn_ms) >= HYBRID_TIMEOUT_WARN_RATE_LIMIT_MS { + match self + .route_runtime + .me_route_hybrid_timeout_warn_epoch_ms + .compare_exchange_weak(last_warn_ms, now_ms, Ordering::AcqRel, Ordering::Relaxed) + { + Ok(_) => { + warn!( + routed_dc, + budget_ms = self.hybrid_total_wait_budget().as_millis() as u64, + elapsed_ms = deadline.elapsed().as_millis() as u64, + "ME hybrid route timeout reached" + ); + break; + } + Err(actual) => last_warn_ms = actual, + } + } + } + + pub(super) fn try_consume_hybrid_recovery_trigger_slot(&self, min_interval_ms: u64) -> bool { + let now_ms = Self::now_epoch_millis(); + let mut last_trigger_ms = self + .route_runtime + .me_async_recovery_last_trigger_epoch_ms + .load(Ordering::Relaxed); + loop { + if now_ms.saturating_sub(last_trigger_ms) < min_interval_ms { + return false; + } + match self + .route_runtime + .me_async_recovery_last_trigger_epoch_ms + .compare_exchange_weak(last_trigger_ms, now_ms, Ordering::AcqRel, Ordering::Relaxed) + { + Ok(_) => return true, + Err(actual) => last_trigger_ms = actual, + } + } + } +} diff --git a/src/transport/middle_proxy/send/selection.rs b/src/transport/middle_proxy/send/selection.rs new file mode 100644 index 0000000..ac05fa1 --- /dev/null +++ b/src/transport/middle_proxy/send/selection.rs @@ -0,0 +1,165 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::Ordering; + +use super::super::MePool; +use super::super::pool::WriterContour; +use super::{ + IDLE_WRITER_PENALTY_HIGH_SECS, IDLE_WRITER_PENALTY_MID_SECS, PICK_PENALTY_DEGRADED, + PICK_PENALTY_DRAINING, PICK_PENALTY_STALE, PICK_PENALTY_WARM, +}; + +impl MePool { + pub(super) async fn candidate_indices_for_dc( + &self, + writers: &[super::super::pool::MeWriter], + routed_dc: i32, + include_warm: bool, + ) -> Vec { + let preferred = self.preferred_endpoints_for_dc(routed_dc).await; + if preferred.is_empty() { + return Vec::new(); + } + + let mut out = Vec::new(); + for (idx, w) in writers.iter().enumerate() { + if !self.writer_eligible_for_selection(w, include_warm) { + continue; + } + if w.writer_dc == routed_dc && preferred.contains(&w.addr) { + out.push(idx); + } + } + out + } + + pub(super) fn writer_eligible_for_selection( + &self, + writer: &super::super::pool::MeWriter, + include_warm: bool, + ) -> bool { + if !self.writer_accepts_new_binding(writer) { + return false; + } + + match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + WriterContour::Active => true, + WriterContour::Warm => include_warm, + WriterContour::Draining => true, + } + } + + pub(super) fn writer_contour_rank_for_selection( + &self, + writer: &super::super::pool::MeWriter, + ) -> usize { + match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + WriterContour::Active => 0, + WriterContour::Warm => 1, + WriterContour::Draining => 2, + } + } + + pub(super) fn writer_idle_rank_for_selection( + &self, + writer: &super::super::pool::MeWriter, + idle_since_by_writer: &HashMap, + now_epoch_secs: u64, + ) -> usize { + let Some(idle_since) = idle_since_by_writer.get(&writer.id).copied() else { + return 0; + }; + let idle_age_secs = now_epoch_secs.saturating_sub(idle_since); + if idle_age_secs >= IDLE_WRITER_PENALTY_HIGH_SECS { + 2 + } else if idle_age_secs >= IDLE_WRITER_PENALTY_MID_SECS { + 1 + } else { + 0 + } + } + + pub(super) fn writer_pick_score( + &self, + writer: &super::super::pool::MeWriter, + idle_since_by_writer: &HashMap, + now_epoch_secs: u64, + ) -> u64 { + let contour_penalty = match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + WriterContour::Active => 0, + WriterContour::Warm => PICK_PENALTY_WARM, + WriterContour::Draining => PICK_PENALTY_DRAINING, + }; + let stale_penalty = if writer.generation < self.current_generation() { + PICK_PENALTY_STALE + } else { + 0 + }; + let degraded_penalty = if writer.degraded.load(Ordering::Relaxed) { + PICK_PENALTY_DEGRADED + } else { + 0 + }; + let idle_penalty = + (self.writer_idle_rank_for_selection(writer, idle_since_by_writer, now_epoch_secs) + as u64) + * 100; + let queue_cap = self.writer_lifecycle.writer_cmd_channel_capacity.max(1) as u64; + let queue_remaining = writer.tx.capacity() as u64; + let queue_used = queue_cap.saturating_sub(queue_remaining.min(queue_cap)); + let queue_util_pct = queue_used.saturating_mul(100) / queue_cap; + let queue_penalty = queue_util_pct.saturating_mul(4); + let rtt_penalty = + ((writer.rtt_ema_ms_x10.load(Ordering::Relaxed) as u64).saturating_add(5) / 10) + .min(400); + + contour_penalty + .saturating_add(stale_penalty) + .saturating_add(degraded_penalty) + .saturating_add(idle_penalty) + .saturating_add(queue_penalty) + .saturating_add(rtt_penalty) + } + + pub(super) fn p2c_ordered_candidate_indices( + &self, + candidate_indices: &[usize], + writers_snapshot: &[super::super::pool::MeWriter], + idle_since_by_writer: &HashMap, + now_epoch_secs: u64, + start: usize, + sample_size: usize, + ) -> Vec { + let total = candidate_indices.len(); + if total == 0 { + return Vec::new(); + } + + let mut sampled = Vec::::with_capacity(sample_size.min(total)); + let mut seen = HashSet::::with_capacity(total); + for offset in 0..sample_size.min(total) { + let idx = candidate_indices[(start + offset) % total]; + if seen.insert(idx) { + sampled.push(idx); + } + } + + sampled.sort_by_key(|idx| { + let writer = &writers_snapshot[*idx]; + ( + self.writer_pick_score(writer, idle_since_by_writer, now_epoch_secs), + writer.addr, + writer.id, + ) + }); + + let mut ordered = Vec::::with_capacity(total); + ordered.extend(sampled.iter().copied()); + for offset in 0..total { + let idx = candidate_indices[(start + offset) % total]; + if seen.insert(idx) { + ordered.push(idx); + } + } + ordered + } +} diff --git a/src/transport/middle_proxy/tests/send_adversarial_tests.rs b/src/transport/middle_proxy/tests/send_adversarial_tests.rs index 074fa33..4050fa1 100644 --- a/src/transport/middle_proxy/tests/send_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/send_adversarial_tests.rs @@ -165,6 +165,7 @@ async fn recv_data_count(rx: &mut mpsc::Receiver, budget: Duratio match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await { Ok(Some(WriterCommand::Data(_))) => data_count += 1, Ok(Some(WriterCommand::DataAndFlush(_))) => data_count += 1, + Ok(Some(WriterCommand::ProxyReq(_))) => data_count += 1, Ok(Some(WriterCommand::ControlAndFlush(_))) => data_count += 1, Ok(Some(WriterCommand::Close)) => {} Ok(None) => break, diff --git a/src/transport/middle_proxy/wire.rs b/src/transport/middle_proxy/wire.rs index f7830ea..addd0b9 100644 --- a/src/transport/middle_proxy/wire.rs +++ b/src/transport/middle_proxy/wire.rs @@ -42,22 +42,45 @@ fn append_mapped_addr_and_port(buf: &mut Vec, addr: SocketAddr) { buf.extend_from_slice(&(addr.port() as u32).to_le_bytes()); } -pub(crate) fn build_proxy_req_payload( +fn proxy_tag_wire_len(tag: &[u8]) -> usize { + if tag.len() < 254 { + 4 + 1 + tag.len() + ((4 - ((1 + tag.len()) % 4)) % 4) + } else { + 4 + 4 + tag.len() + ((4 - (tag.len() % 4)) % 4) + } +} + +/// Returns the exact unencrypted RPC_PROXY_REQ payload length for pre-sizing frame buffers. +pub(crate) fn proxy_req_payload_len( + data_len: usize, + proxy_tag: Option<&[u8]>, + proto_flags: u32, +) -> usize { + let base_len = 4 + 4 + 8 + 20 + 20; + let extra_len = if proto_flags & RPC_FLAG_HAS_AD_TAG != 0 { + 4 + proxy_tag.map(proxy_tag_wire_len).unwrap_or(0) + } else { + 0 + }; + base_len + extra_len + data_len +} + +/// Appends RPC_PROXY_REQ payload bytes without allocating an intermediate payload buffer. +pub(crate) fn append_proxy_req_payload_into( + b: &mut Vec, conn_id: u64, client_addr: SocketAddr, our_addr: SocketAddr, data: &[u8], proxy_tag: Option<&[u8]>, proto_flags: u32, -) -> Bytes { - let mut b = Vec::with_capacity(128 + data.len()); - +) { b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes()); b.extend_from_slice(&proto_flags.to_le_bytes()); b.extend_from_slice(&conn_id.to_le_bytes()); - append_mapped_addr_and_port(&mut b, client_addr); - append_mapped_addr_and_port(&mut b, our_addr); + append_mapped_addr_and_port(b, client_addr); + append_mapped_addr_and_port(b, our_addr); if proto_flags & RPC_FLAG_HAS_AD_TAG != 0 { let extra_start = b.len(); @@ -86,6 +109,26 @@ pub(crate) fn build_proxy_req_payload( } b.extend_from_slice(data); +} + +pub(crate) fn build_proxy_req_payload( + conn_id: u64, + client_addr: SocketAddr, + our_addr: SocketAddr, + data: &[u8], + proxy_tag: Option<&[u8]>, + proto_flags: u32, +) -> Bytes { + let mut b = Vec::with_capacity(128 + data.len()); + append_proxy_req_payload_into( + &mut b, + conn_id, + client_addr, + our_addr, + data, + proxy_tag, + proto_flags, + ); Bytes::from(b) }