Files
telemt/src/stats/replay.rs
T

357 lines
9.8 KiB
Rust

use std::borrow::Borrow;
use std::collections::VecDeque;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use lru::LruCache;
use parking_lot::Mutex;
use tracing::debug;
const REPLAY_INLINE_KEY_CAP: usize = 48;
#[derive(Clone)]
enum ReplayKey {
Inline {
len: u8,
bytes: [u8; REPLAY_INLINE_KEY_CAP],
},
Heap(Arc<[u8]>),
}
impl ReplayKey {
fn from_slice(key: &[u8]) -> Self {
if key.len() <= REPLAY_INLINE_KEY_CAP {
let mut bytes = [0u8; REPLAY_INLINE_KEY_CAP];
bytes[..key.len()].copy_from_slice(key);
return Self::Inline {
len: key.len() as u8,
bytes,
};
}
Self::Heap(Arc::from(key))
}
fn as_slice(&self) -> &[u8] {
match self {
Self::Inline { len, bytes } => &bytes[..*len as usize],
Self::Heap(bytes) => bytes.as_ref(),
}
}
}
impl Borrow<[u8]> for ReplayKey {
fn borrow(&self) -> &[u8] {
self.as_slice()
}
}
impl PartialEq for ReplayKey {
fn eq(&self, other: &Self) -> bool {
self.as_slice() == other.as_slice()
}
}
impl Eq for ReplayKey {}
impl Hash for ReplayKey {
fn hash<H: Hasher>(&self, state: &mut H) {
self.as_slice().hash(state);
}
}
pub struct ReplayChecker {
handshake_shards: Vec<Mutex<ReplayShard>>,
tls_shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize,
window: Duration,
tls_window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
cleanups: AtomicU64,
}
struct ReplayEntry {
seq: u64,
}
struct ReplayShard {
cache: LruCache<ReplayKey, ReplayEntry>,
queue: VecDeque<(Instant, ReplayKey, u64)>,
seq_counter: u64,
capacity: usize,
}
impl ReplayShard {
fn new(cap: NonZeroUsize) -> Self {
Self {
cache: LruCache::new(cap),
queue: VecDeque::with_capacity(cap.get()),
seq_counter: 0,
capacity: cap.get(),
}
}
fn next_seq(&mut self) -> u64 {
self.seq_counter += 1;
self.seq_counter
}
fn cleanup(&mut self, now: Instant, window: Duration) {
if window.is_zero() {
self.cache.clear();
self.queue.clear();
return;
}
let cutoff = now.checked_sub(window).unwrap_or(now);
while let Some((ts, _, _)) = self.queue.front() {
if *ts >= cutoff {
break;
}
self.evict_queue_front();
}
}
fn evict_queue_front(&mut self) {
let Some((_, key, queue_seq)) = self.queue.pop_front() else {
return;
};
if let Some(entry) = self.cache.peek(key.as_slice())
&& entry.seq == queue_seq
{
self.cache.pop(key.as_slice());
}
}
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
if window.is_zero() {
return false;
}
self.cleanup(now, window);
self.cache.get(key).is_some()
}
fn add_owned(&mut self, key: ReplayKey, now: Instant, window: Duration) {
if window.is_zero() {
return;
}
self.cleanup(now, window);
if self.cache.peek(key.as_slice()).is_some() {
return;
}
while self.queue.len() >= self.capacity {
self.evict_queue_front();
}
let seq = self.next_seq();
self.cache.put(key.clone(), ReplayEntry { seq });
self.queue.push_back((now, key, seq));
}
fn len(&self) -> usize {
self.cache.len()
}
}
impl ReplayChecker {
pub fn new(total_capacity: usize, window: Duration) -> Self {
const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120);
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut handshake_shards = Vec::with_capacity(num_shards);
let mut tls_shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
handshake_shards.push(Mutex::new(ReplayShard::new(cap)));
tls_shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self {
handshake_shards,
tls_shards,
shard_mask: num_shards - 1,
window,
tls_window: window.max(MIN_TLS_REPLAY_WINDOW),
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
cleanups: AtomicU64::new(0),
}
}
fn get_shard_idx(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
}
fn check_and_add_internal(
&self,
data: &[u8],
shards: &[Mutex<ReplayShard>],
window: Duration,
) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let owned_key = ReplayKey::from_slice(data);
let mut shard = shards[idx].lock();
let now = Instant::now();
let found = shard.check(data, now, window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
} else {
shard.add_owned(owned_key, now, window);
self.additions.fetch_add(1, Ordering::Relaxed);
}
found
}
fn check_only_internal(
&self,
data: &[u8],
shards: &[Mutex<ReplayShard>],
window: Duration,
) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = shards[idx].lock();
let found = shard.check(data, Instant::now(), window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
}
found
}
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let owned_key = ReplayKey::from_slice(data);
let mut shard = shards[idx].lock();
shard.add_owned(owned_key, Instant::now(), window);
}
pub fn check_and_add_handshake(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data, &self.handshake_shards, self.window)
}
pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data, &self.tls_shards, self.tls_window)
}
pub fn check_handshake(&self, data: &[u8]) -> bool {
self.check_and_add_handshake(data)
}
pub fn add_handshake(&self, data: &[u8]) {
self.add_only(data, &self.handshake_shards, self.window)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
self.check_only_internal(data, &self.tls_shards, self.tls_window)
}
pub fn add_tls_digest(&self, data: &[u8]) {
self.add_only(data, &self.tls_shards, self.tls_window)
}
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.handshake_shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
for shard in &self.tls_shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
ReplayStats {
total_entries,
total_queue_len,
total_checks: self.checks.load(Ordering::Relaxed),
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.handshake_shards.len() + self.tls_shards.len(),
window_secs: self.window.as_secs(),
}
}
pub async fn run_periodic_cleanup(&self) {
let interval = if self.window.as_secs() > 60 {
Duration::from_secs(30)
} else {
Duration::from_secs((self.window.as_secs().max(1) / 2).max(1))
};
loop {
tokio::time::sleep(interval).await;
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.handshake_shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
for shard_mutex in &self.tls_shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.tls_window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
self.cleanups.fetch_add(1, Ordering::Relaxed);
if cleaned > 0 {
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
}
}
}
}
#[derive(Debug, Clone)]
pub struct ReplayStats {
pub total_entries: usize,
pub total_queue_len: usize,
pub total_checks: u64,
pub total_hits: u64,
pub total_additions: u64,
pub total_cleanups: u64,
pub num_shards: usize,
pub window_secs: u64,
}
impl ReplayStats {
pub fn hit_rate(&self) -> f64 {
if self.total_checks == 0 {
0.0
} else {
(self.total_hits as f64 / self.total_checks as f64) * 100.0
}
}
pub fn ghost_ratio(&self) -> f64 {
if self.total_entries == 0 {
0.0
} else {
self.total_queue_len as f64 / self.total_entries as f64
}
}
}