use std::borrow::Borrow; use std::collections::VecDeque; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::num::NonZeroUsize; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; use lru::LruCache; use parking_lot::Mutex; use tracing::debug; const REPLAY_INLINE_KEY_CAP: usize = 48; #[derive(Clone)] enum ReplayKey { Inline { len: u8, bytes: [u8; REPLAY_INLINE_KEY_CAP], }, Heap(Arc<[u8]>), } impl ReplayKey { fn from_slice(key: &[u8]) -> Self { if key.len() <= REPLAY_INLINE_KEY_CAP { let mut bytes = [0u8; REPLAY_INLINE_KEY_CAP]; bytes[..key.len()].copy_from_slice(key); return Self::Inline { len: key.len() as u8, bytes, }; } Self::Heap(Arc::from(key)) } fn as_slice(&self) -> &[u8] { match self { Self::Inline { len, bytes } => &bytes[..*len as usize], Self::Heap(bytes) => bytes.as_ref(), } } } impl Borrow<[u8]> for ReplayKey { fn borrow(&self) -> &[u8] { self.as_slice() } } impl PartialEq for ReplayKey { fn eq(&self, other: &Self) -> bool { self.as_slice() == other.as_slice() } } impl Eq for ReplayKey {} impl Hash for ReplayKey { fn hash(&self, state: &mut H) { self.as_slice().hash(state); } } pub struct ReplayChecker { handshake_shards: Vec>, tls_shards: Vec>, shard_mask: usize, window: Duration, tls_window: Duration, checks: AtomicU64, hits: AtomicU64, additions: AtomicU64, cleanups: AtomicU64, } struct ReplayEntry { seq: u64, } struct ReplayShard { cache: LruCache, queue: VecDeque<(Instant, ReplayKey, u64)>, seq_counter: u64, capacity: usize, } impl ReplayShard { fn new(cap: NonZeroUsize) -> Self { Self { cache: LruCache::new(cap), queue: VecDeque::with_capacity(cap.get()), seq_counter: 0, capacity: cap.get(), } } fn next_seq(&mut self) -> u64 { self.seq_counter += 1; self.seq_counter } fn cleanup(&mut self, now: Instant, window: Duration) { if window.is_zero() { self.cache.clear(); self.queue.clear(); return; } let cutoff = now.checked_sub(window).unwrap_or(now); while let Some((ts, _, _)) = self.queue.front() { if *ts >= cutoff { break; } self.evict_queue_front(); } } fn evict_queue_front(&mut self) { let Some((_, key, queue_seq)) = self.queue.pop_front() else { return; }; if let Some(entry) = self.cache.peek(key.as_slice()) && entry.seq == queue_seq { self.cache.pop(key.as_slice()); } } fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool { if window.is_zero() { return false; } self.cleanup(now, window); self.cache.get(key).is_some() } fn add_owned(&mut self, key: ReplayKey, now: Instant, window: Duration) { if window.is_zero() { return; } self.cleanup(now, window); if self.cache.peek(key.as_slice()).is_some() { return; } while self.queue.len() >= self.capacity { self.evict_queue_front(); } let seq = self.next_seq(); self.cache.put(key.clone(), ReplayEntry { seq }); self.queue.push_back((now, key, seq)); } fn len(&self) -> usize { self.cache.len() } } impl ReplayChecker { pub fn new(total_capacity: usize, window: Duration) -> Self { const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120); let num_shards = 64; let shard_capacity = (total_capacity / num_shards).max(1); let cap = NonZeroUsize::new(shard_capacity).unwrap(); let mut handshake_shards = Vec::with_capacity(num_shards); let mut tls_shards = Vec::with_capacity(num_shards); for _ in 0..num_shards { handshake_shards.push(Mutex::new(ReplayShard::new(cap))); tls_shards.push(Mutex::new(ReplayShard::new(cap))); } Self { handshake_shards, tls_shards, shard_mask: num_shards - 1, window, tls_window: window.max(MIN_TLS_REPLAY_WINDOW), checks: AtomicU64::new(0), hits: AtomicU64::new(0), additions: AtomicU64::new(0), cleanups: AtomicU64::new(0), } } fn get_shard_idx(&self, key: &[u8]) -> usize { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); (hasher.finish() as usize) & self.shard_mask } fn check_and_add_internal( &self, data: &[u8], shards: &[Mutex], window: Duration, ) -> bool { self.checks.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); let owned_key = ReplayKey::from_slice(data); let mut shard = shards[idx].lock(); let now = Instant::now(); let found = shard.check(data, now, window); if found { self.hits.fetch_add(1, Ordering::Relaxed); } else { shard.add_owned(owned_key, now, window); self.additions.fetch_add(1, Ordering::Relaxed); } found } fn check_only_internal( &self, data: &[u8], shards: &[Mutex], window: Duration, ) -> bool { self.checks.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); let mut shard = shards[idx].lock(); let found = shard.check(data, Instant::now(), window); if found { self.hits.fetch_add(1, Ordering::Relaxed); } found } fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { self.additions.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); let owned_key = ReplayKey::from_slice(data); let mut shard = shards[idx].lock(); shard.add_owned(owned_key, Instant::now(), window); } pub fn check_and_add_handshake(&self, data: &[u8]) -> bool { self.check_and_add_internal(data, &self.handshake_shards, self.window) } pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_internal(data, &self.tls_shards, self.tls_window) } pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) } pub fn add_handshake(&self, data: &[u8]) { self.add_only(data, &self.handshake_shards, self.window) } pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_only_internal(data, &self.tls_shards, self.tls_window) } pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data, &self.tls_shards, self.tls_window) } pub fn stats(&self) -> ReplayStats { let mut total_entries = 0; let mut total_queue_len = 0; for shard in &self.handshake_shards { let s = shard.lock(); total_entries += s.cache.len(); total_queue_len += s.queue.len(); } for shard in &self.tls_shards { let s = shard.lock(); total_entries += s.cache.len(); total_queue_len += s.queue.len(); } ReplayStats { total_entries, total_queue_len, total_checks: self.checks.load(Ordering::Relaxed), total_hits: self.hits.load(Ordering::Relaxed), total_additions: self.additions.load(Ordering::Relaxed), total_cleanups: self.cleanups.load(Ordering::Relaxed), num_shards: self.handshake_shards.len() + self.tls_shards.len(), window_secs: self.window.as_secs(), } } pub async fn run_periodic_cleanup(&self) { let interval = if self.window.as_secs() > 60 { Duration::from_secs(30) } else { Duration::from_secs((self.window.as_secs().max(1) / 2).max(1)) }; loop { tokio::time::sleep(interval).await; let now = Instant::now(); let mut cleaned = 0usize; for shard_mutex in &self.handshake_shards { let mut shard = shard_mutex.lock(); let before = shard.len(); shard.cleanup(now, self.window); let after = shard.len(); cleaned += before.saturating_sub(after); } for shard_mutex in &self.tls_shards { let mut shard = shard_mutex.lock(); let before = shard.len(); shard.cleanup(now, self.tls_window); let after = shard.len(); cleaned += before.saturating_sub(after); } self.cleanups.fetch_add(1, Ordering::Relaxed); if cleaned > 0 { debug!(cleaned = cleaned, "Replay checker: periodic cleanup"); } } } } #[derive(Debug, Clone)] pub struct ReplayStats { pub total_entries: usize, pub total_queue_len: usize, pub total_checks: u64, pub total_hits: u64, pub total_additions: u64, pub total_cleanups: u64, pub num_shards: usize, pub window_secs: u64, } impl ReplayStats { pub fn hit_rate(&self) -> f64 { if self.total_checks == 0 { 0.0 } else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 } } pub fn ghost_ratio(&self) -> f64 { if self.total_entries == 0 { 0.0 } else { self.total_queue_len as f64 / self.total_entries as f64 } } }