//! Statistics use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Instant, Duration}; use dashmap::DashMap; use parking_lot::{RwLock, Mutex}; use lru::LruCache; use std::num::NonZeroUsize; use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; use std::collections::VecDeque; /// Thread-safe statistics #[derive(Default)] pub struct Stats { // Global counters connects_all: AtomicU64, connects_bad: AtomicU64, handshake_timeouts: AtomicU64, // Per-user stats user_stats: DashMap, // Start time start_time: RwLock>, } /// Per-user statistics #[derive(Default)] pub struct UserStats { pub connects: AtomicU64, pub curr_connects: AtomicU64, pub octets_from_client: AtomicU64, pub octets_to_client: AtomicU64, pub msgs_from_client: AtomicU64, pub msgs_to_client: AtomicU64, } impl Stats { pub fn new() -> Self { let stats = Self::default(); *stats.start_time.write() = Some(Instant::now()); stats } // Global stats pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); } pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); } pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); } pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } // User stats pub fn increment_user_connects(&self, user: &str) { self.user_stats .entry(user.to_string()) .or_default() .connects .fetch_add(1, Ordering::Relaxed); } pub fn increment_user_curr_connects(&self, user: &str) { self.user_stats .entry(user.to_string()) .or_default() .curr_connects .fetch_add(1, Ordering::Relaxed); } pub fn decrement_user_curr_connects(&self, user: &str) { if let Some(stats) = self.user_stats.get(user) { stats.curr_connects.fetch_sub(1, Ordering::Relaxed); } } pub fn get_user_curr_connects(&self, user: &str) -> u64 { self.user_stats .get(user) .map(|s| s.curr_connects.load(Ordering::Relaxed)) .unwrap_or(0) } pub fn add_user_octets_from(&self, user: &str, bytes: u64) { self.user_stats .entry(user.to_string()) .or_default() .octets_from_client .fetch_add(bytes, Ordering::Relaxed); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { self.user_stats .entry(user.to_string()) .or_default() .octets_to_client .fetch_add(bytes, Ordering::Relaxed); } pub fn increment_user_msgs_from(&self, user: &str) { self.user_stats .entry(user.to_string()) .or_default() .msgs_from_client .fetch_add(1, Ordering::Relaxed); } pub fn increment_user_msgs_to(&self, user: &str) { self.user_stats .entry(user.to_string()) .or_default() .msgs_to_client .fetch_add(1, Ordering::Relaxed); } pub fn get_user_total_octets(&self, user: &str) -> u64 { self.user_stats .get(user) .map(|s| { s.octets_from_client.load(Ordering::Relaxed) + s.octets_to_client.load(Ordering::Relaxed) }) .unwrap_or(0) } pub fn uptime_secs(&self) -> f64 { self.start_time.read() .map(|t| t.elapsed().as_secs_f64()) .unwrap_or(0.0) } } /// Sharded Replay attack checker using LRU cache + sliding window /// Uses multiple independent LRU caches to reduce lock contention pub struct ReplayChecker { shards: Vec>, shard_mask: usize, window: Duration, } struct ReplayEntry { seen_at: Instant, } struct ReplayShard { cache: LruCache, ReplayEntry>, queue: VecDeque<(Instant, Vec)>, } impl ReplayShard { fn new(cap: NonZeroUsize) -> Self { Self { cache: LruCache::new(cap), queue: VecDeque::with_capacity(cap.get()), } } fn cleanup(&mut self, now: Instant, window: Duration) { if window.is_zero() { return; } let cutoff = now - window; while let Some((ts, _)) = self.queue.front() { if *ts >= cutoff { break; } let (ts_old, key_old) = self.queue.pop_front().unwrap(); if let Some(entry) = self.cache.get(&key_old) { if entry.seen_at <= ts_old { self.cache.pop(&key_old); } } } } } impl ReplayChecker { /// Create new replay checker with specified capacity per shard /// Total capacity = capacity * num_shards pub fn new(total_capacity: usize, window: Duration) -> Self { // Use 64 shards for good concurrency let num_shards = 64; let shard_capacity = (total_capacity / num_shards).max(1); let cap = NonZeroUsize::new(shard_capacity).unwrap(); let mut shards = Vec::with_capacity(num_shards); for _ in 0..num_shards { shards.push(Mutex::new(ReplayShard::new(cap))); } Self { shards, shard_mask: num_shards - 1, window, } } fn get_shard(&self, key: &[u8]) -> usize { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); (hasher.finish() as usize) & self.shard_mask } fn check(&self, data: &[u8]) -> bool { let shard_idx = self.get_shard(data); let mut shard = self.shards[shard_idx].lock(); let now = Instant::now(); shard.cleanup(now, self.window); let key = data.to_vec(); shard.cache.get(&key).is_some() } fn add(&self, data: &[u8]) { let shard_idx = self.get_shard(data); let mut shard = self.shards[shard_idx].lock(); let now = Instant::now(); shard.cleanup(now, self.window); let key = data.to_vec(); shard.cache.put(key.clone(), ReplayEntry { seen_at: now }); shard.queue.push_back((now, key)); } pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) } pub fn add_handshake(&self, data: &[u8]) { self.add(data) } pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) } pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_stats_shared_counters() { let stats = Arc::new(Stats::new()); let stats1 = Arc::clone(&stats); let stats2 = Arc::clone(&stats); stats1.increment_connects_all(); stats2.increment_connects_all(); stats1.increment_connects_all(); assert_eq!(stats.get_connects_all(), 3); } #[test] fn test_replay_checker_sharding() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); let data1 = b"test1"; let data2 = b"test2"; checker.add_handshake(data1); assert!(checker.check_handshake(data1)); assert!(!checker.check_handshake(data2)); checker.add_handshake(data2); assert!(checker.check_handshake(data2)); } }