From ecd6a19246ba73a8403e5bbac78d85d47c19deec Mon Sep 17 00:00:00 2001
From: Alexey <247128645+axkurcom@users.noreply.github.com>
Date: Tue, 31 Mar 2026 18:40:04 +0300
Subject: [PATCH] Cleanup Methods for Memory Consistency

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
---
 src/ip_tracker.rs         | 87 +++++++++++++++++++++++++++++++++++++++
 src/metrics.rs            | 48 +++++++++++++++++++++
 src/stats/mod.rs          | 27 +++++++++++-
 src/stream/buffer_pool.rs | 69 +++++++++++++++++++++++++++----
 4 files changed, 221 insertions(+), 10 deletions(-)

diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs
index 76ea424..b4d934f 100644
--- a/src/ip_tracker.rs
+++ b/src/ip_tracker.rs
@@ -26,6 +26,15 @@ pub struct UserIpTracker {
     cleanup_drain_lock: Arc<Mutex<()>>,
 }
 
+#[derive(Debug, Clone, Copy)]
+pub struct UserIpTrackerMemoryStats {
+    pub active_users: usize,
+    pub recent_users: usize,
+    pub active_entries: usize,
+    pub recent_entries: usize,
+    pub cleanup_queue_len: usize,
+}
+
 impl UserIpTracker {
     pub fn new() -> Self {
         Self {
@@ -141,6 +150,13 @@ impl UserIpTracker {
         let mut active_ips = self.active_ips.write().await;
         let mut recent_ips = self.recent_ips.write().await;
 
+        let window = *self.limit_window.read().await;
+        let now = Instant::now();
+
+        for user_recent in recent_ips.values_mut() {
+            Self::prune_recent(user_recent, now, window);
+        }
+
         let mut users =
             Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len()));
         users.extend(active_ips.keys().cloned());
@@ -166,6 +182,26 @@ impl UserIpTracker {
         }
     }
 
+    pub async fn memory_stats(&self) -> UserIpTrackerMemoryStats {
+        let cleanup_queue_len = self
+            .cleanup_queue
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner())
+            .len();
+        let active_ips = self.active_ips.read().await;
+        let recent_ips = self.recent_ips.read().await;
+        let active_entries = active_ips.values().map(HashMap::len).sum();
+        let recent_entries = recent_ips.values().map(HashMap::len).sum();
+
+        UserIpTrackerMemoryStats {
+            active_users: active_ips.len(),
+            recent_users: recent_ips.len(),
+            active_entries,
+            recent_entries,
+            cleanup_queue_len,
+        }
+    }
+
     pub async fn set_limit_policy(&self, mode: UserMaxUniqueIpsMode, window_secs: u64) {
         {
             let mut current_mode = self.limit_mode.write().await;
@@ -451,6 +487,7 @@ impl Default for UserIpTracker {
 mod tests {
     use super::*;
     use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
+    use std::sync::atomic::Ordering;
 
     fn test_ipv4(oct1: u8, oct2: u8, oct3: u8, oct4: u8) -> IpAddr {
         IpAddr::V4(Ipv4Addr::new(oct1, oct2, oct3, oct4))
@@ -764,4 +801,54 @@ mod tests {
         tokio::time::sleep(Duration::from_millis(1100)).await;
         assert!(tracker.check_and_add("test_user", ip2).await.is_ok());
     }
+
+    #[tokio::test]
+    async fn test_memory_stats_reports_queue_and_entry_counts() {
+        let tracker = UserIpTracker::new();
+        tracker.set_user_limit("test_user", 4).await;
+        let ip1 = test_ipv4(10, 2, 0, 1);
+        let ip2 = test_ipv4(10, 2, 0, 2);
+
+        tracker.check_and_add("test_user", ip1).await.unwrap();
+        tracker.check_and_add("test_user", ip2).await.unwrap();
+        tracker.enqueue_cleanup("test_user".to_string(), ip1);
+
+        let snapshot = tracker.memory_stats().await;
+        assert_eq!(snapshot.active_users, 1);
+        assert_eq!(snapshot.recent_users, 1);
+        assert_eq!(snapshot.active_entries, 2);
+        assert_eq!(snapshot.recent_entries, 2);
+        assert_eq!(snapshot.cleanup_queue_len, 1);
+    }
+
+    #[tokio::test]
+    async fn test_compact_prunes_stale_recent_entries() {
+        let tracker = UserIpTracker::new();
+        tracker
+            .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
+            .await;
+
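+        // With a 1-second window, the 5-second-old entry inserted below is already stale.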
+        let stale_user = "stale-user".to_string();
+        let stale_ip = test_ipv4(10, 3, 0, 1);
+        {
+            let mut recent_ips = tracker.recent_ips.write().await;
+            recent_ips
+                .entry(stale_user.clone())
+                .or_insert_with(HashMap::new)
+                .insert(stale_ip, Instant::now() - Duration::from_secs(5));
+        }
+
+        tracker.last_compact_epoch_secs.store(0, Ordering::Relaxed);
+        tracker
+            .check_and_add("trigger-user", test_ipv4(10, 3, 0, 2))
+            .await
+            .unwrap();
+
+        let recent_ips = tracker.recent_ips.read().await;
+        let stale_exists = recent_ips
+            .get(&stale_user)
+            .map(|ips| ips.contains_key(&stale_ip))
+            .unwrap_or(false);
+        assert!(!stale_exists);
+    }
 }
diff --git a/src/metrics.rs b/src/metrics.rs
index 3a88a5b..c05737f 100644
--- a/src/metrics.rs
+++ b/src/metrics.rs
@@ -2490,6 +2490,48 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
         if user_enabled { 0 } else { 1 }
     );
 
+    let ip_memory = ip_tracker.memory_stats().await;
+    let _ = writeln!(
+        out,
+        "# HELP telemt_ip_tracker_users Number of users tracked by IP limiter state"
+    );
+    let _ = writeln!(out, "# TYPE telemt_ip_tracker_users gauge");
+    let _ = writeln!(
+        out,
+        "telemt_ip_tracker_users{{scope=\"active\"}} {}",
+        ip_memory.active_users
+    );
+    let _ = writeln!(
+        out,
+        "telemt_ip_tracker_users{{scope=\"recent\"}} {}",
+        ip_memory.recent_users
+    );
+    let _ = writeln!(
+        out,
+        "# HELP telemt_ip_tracker_entries Number of IP entries tracked by limiter state"
+    );
+    let _ = writeln!(out, "# TYPE telemt_ip_tracker_entries gauge");
+    let _ = writeln!(
+        out,
+        "telemt_ip_tracker_entries{{scope=\"active\"}} {}",
+        ip_memory.active_entries
+    );
+    let _ = writeln!(
+        out,
+        "telemt_ip_tracker_entries{{scope=\"recent\"}} {}",
+        ip_memory.recent_entries
+    );
+    let _ = writeln!(
+        out,
+        "# HELP telemt_ip_tracker_cleanup_queue_len Deferred disconnect cleanup queue length"
+    );
+    let _ = writeln!(out, "# TYPE telemt_ip_tracker_cleanup_queue_len gauge");
+    let _ = writeln!(
+        out,
+        "telemt_ip_tracker_cleanup_queue_len {}",
+        ip_memory.cleanup_queue_len
+    );
+
     if user_enabled {
         for entry in stats.iter_user_stats() {
             let user = entry.key();
@@ -2728,6 +2770,9 @@
         assert!(output.contains("telemt_user_unique_ips_recent_window{user=\"alice\"} 1"));
         assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 4"));
         assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.250000"));
+        assert!(output.contains("telemt_ip_tracker_users{scope=\"active\"} 1"));
+        assert!(output.contains("telemt_ip_tracker_entries{scope=\"active\"} 1"));
+        assert!(output.contains("telemt_ip_tracker_cleanup_queue_len 0"));
     }
 
     #[tokio::test]
@@ -2799,6 +2844,9 @@
         assert!(output.contains("# TYPE telemt_user_unique_ips_recent_window gauge"));
         assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge"));
         assert!(output.contains("# TYPE telemt_user_unique_ips_utilization gauge"));
+        assert!(output.contains("# TYPE telemt_ip_tracker_users gauge"));
+        assert!(output.contains("# TYPE telemt_ip_tracker_entries gauge"));
+        assert!(output.contains("# TYPE telemt_ip_tracker_cleanup_queue_len gauge"));
     }
 
     #[tokio::test]
diff --git a/src/stats/mod.rs b/src/stats/mod.rs
index 2d1f413..18cf360 100644
--- a/src/stats/mod.rs
+++ b/src/stats/mod.rs
@@ -2171,6 +2171,8 @@ impl ReplayShard {
 
     fn cleanup(&mut self, now: Instant, window: Duration) {
         if window.is_zero() {
+            self.cache.clear();
+            self.queue.clear();
             return;
         }
         let cutoff = now.checked_sub(window).unwrap_or(now);
@@ -2192,13 +2194,22 @@
     }
 
     fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
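+        // A zero window retains nothing, so no key can ever count as a replay.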
+        if window.is_zero() {
+            return false;
+        }
         self.cleanup(now, window);
         // key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
         self.cache.get(key).is_some()
     }
 
     fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
+        if window.is_zero() {
+            return;
+        }
         self.cleanup(now, window);
+        if self.cache.peek(key).is_some() {
+            return;
+        }
         let seq = self.next_seq();
 
         let boxed_key: Box<[u8]> = key.into();
@@ -2341,7 +2352,7 @@
         let interval = if self.window.as_secs() > 60 {
             Duration::from_secs(30)
         } else {
-            Duration::from_secs(self.window.as_secs().max(1) / 2)
+            Duration::from_secs((self.window.as_secs().max(1) / 2).max(1))
         };
 
         loop {
@@ -2553,6 +2564,20 @@
         assert!(!checker.check_handshake(b"expire"));
     }
 
+    #[test]
+    fn test_replay_checker_zero_window_does_not_retain_entries() {
+        let checker = ReplayChecker::new(100, Duration::ZERO);
+
+        for _ in 0..1_000 {
+            assert!(!checker.check_handshake(b"no-retain"));
+            checker.add_handshake(b"no-retain");
+        }
+
+        let stats = checker.stats();
+        assert_eq!(stats.total_entries, 0);
+        assert_eq!(stats.total_queue_len, 0);
+    }
+
     #[test]
     fn test_replay_checker_stats() {
         let checker = ReplayChecker::new(100, Duration::from_secs(60));
diff --git a/src/stream/buffer_pool.rs b/src/stream/buffer_pool.rs
index 6cdac60..0040fa5 100644
--- a/src/stream/buffer_pool.rs
+++ b/src/stream/buffer_pool.rs
@@ -35,6 +35,10 @@ pub struct BufferPool {
     misses: AtomicUsize,
     /// Number of successful reuses
     hits: AtomicUsize,
+    /// Number of non-standard buffers replaced with a fresh default-sized buffer
+    replaced_nonstandard: AtomicUsize,
+    /// Number of buffers dropped because the pool queue was full
+    dropped_pool_full: AtomicUsize,
 }
 
 impl BufferPool {
@@ -52,6 +56,8 @@ impl BufferPool {
             allocated: AtomicUsize::new(0),
             misses: AtomicUsize::new(0),
             hits: AtomicUsize::new(0),
+            replaced_nonstandard: AtomicUsize::new(0),
+            dropped_pool_full: AtomicUsize::new(0),
         }
     }
 
@@ -91,17 +97,36 @@ impl BufferPool {
 
     /// Return a buffer to the pool
     fn return_buffer(&self, mut buffer: BytesMut) {
-        // Clear the buffer but keep capacity
-        buffer.clear();
+        const MAX_RETAINED_BUFFER_FACTOR: usize = 2;
 
-        // Only return if we haven't exceeded max and buffer is right size
-        if buffer.capacity() >= self.buffer_size {
-            // Try to push to pool, if full just drop
-            let _ = self.buffers.push(buffer);
+        // Clear the buffer but keep capacity.
+        buffer.clear();
+        let max_retained_capacity = self
+            .buffer_size
+            .saturating_mul(MAX_RETAINED_BUFFER_FACTOR)
+            .max(self.buffer_size);
+
+        // Keep only near-default capacities in the pool. Oversized buffers keep
+        // RSS elevated for hours under churn; replace them with default-sized
+        // buffers before re-pooling.
+        if buffer.capacity() < self.buffer_size || buffer.capacity() > max_retained_capacity {
+            self.replaced_nonstandard.fetch_add(1, Ordering::Relaxed);
+            buffer = BytesMut::with_capacity(self.buffer_size);
         }
-        // If buffer was dropped (pool full), decrement allocated
-        // Actually we don't decrement here because the buffer might have been
-        // grown beyond our size - we just let it go
+
+        // Try to return into the queue; if full, drop and update accounting.
+ if self.buffers.push(buffer).is_err() { + self.dropped_pool_full.fetch_add(1, Ordering::Relaxed); + self.decrement_allocated(); + } + } + + fn decrement_allocated(&self) { + let _ = self + .allocated + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| { + Some(current.saturating_sub(1)) + }); } /// Get pool statistics @@ -113,6 +138,8 @@ impl BufferPool { buffer_size: self.buffer_size, hits: self.hits.load(Ordering::Relaxed), misses: self.misses.load(Ordering::Relaxed), + replaced_nonstandard: self.replaced_nonstandard.load(Ordering::Relaxed), + dropped_pool_full: self.dropped_pool_full.load(Ordering::Relaxed), } } @@ -160,6 +187,10 @@ pub struct PoolStats { pub hits: usize, /// Number of cache misses (new allocation) pub misses: usize, + /// Number of non-standard buffers replaced during return + pub replaced_nonstandard: usize, + /// Number of buffers dropped because the pool queue was full + pub dropped_pool_full: usize, } impl PoolStats { @@ -185,6 +216,7 @@ pub struct PooledBuffer { impl PooledBuffer { /// Take the inner buffer, preventing return to pool pub fn take(mut self) -> BytesMut { + self.pool.decrement_allocated(); self.buffer.take().unwrap() } @@ -364,6 +396,25 @@ mod tests { let stats = pool.stats(); assert_eq!(stats.pooled, 0); + assert_eq!(stats.allocated, 0); + } + + #[test] + fn test_pool_replaces_oversized_buffers() { + let pool = Arc::new(BufferPool::with_config(1024, 10)); + + { + let mut buf = pool.get(); + buf.reserve(8192); + assert!(buf.capacity() > 2048); + } + + let stats = pool.stats(); + assert_eq!(stats.replaced_nonstandard, 1); + assert_eq!(stats.pooled, 1); + + let buf = pool.get(); + assert!(buf.capacity() <= 2048); } #[test]
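
Not part of the patch itself: a minimal sketch of how the new counters could be
consumed outside the Prometheus endpoint, e.g. from a periodic watchdog task.
The task name, the 60-second period, the backlog threshold, and the use of
eprintln!/tokio::time::interval are illustrative assumptions; only
UserIpTracker::memory_stats() and BufferPool::stats() come from this patch.

    use std::sync::Arc;
    use std::time::Duration;

    // Hypothetical watchdog; period and threshold are placeholders.
    async fn memory_watchdog(tracker: Arc<UserIpTracker>, pool: Arc<BufferPool>) {
        let mut tick = tokio::time::interval(Duration::from_secs(60));
        loop {
            tick.tick().await;
            let ip = tracker.memory_stats().await;
            let buf = pool.stats();
            // A queue that only grows means deferred disconnect cleanup is not draining.
            if ip.cleanup_queue_len > 10_000 {
                eprintln!("ip_tracker cleanup backlog: {}", ip.cleanup_queue_len);
            }
            eprintln!(
                "ip_tracker users={}/{} entries={}/{} pool replaced={} dropped={}",
                ip.active_users, ip.recent_users, ip.active_entries, ip.recent_entries,
                buf.replaced_nonstandard, buf.dropped_pool_full
            );
        }
    }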