feat: enhance quota user lock management and testing

- Adjusted QUOTA_USER_LOCKS_MAX based on test and non-test configurations to improve flexibility.
- Implemented logic to retain existing locks when the maximum quota is reached, ensuring efficient memory usage.
- Added comprehensive tests for quota user lock functionality, including cache reuse, saturation behavior, and race conditions.
- Enhanced StatsIo struct to manage wake scheduling for read and write operations, preventing unnecessary self-wakes.
- Introduced separate replay checker domains for handshake and TLS to ensure isolation and prevent cross-pollution of keys.
- Added security tests for replay checker to validate domain separation and window clamping behavior.
This commit is contained in:
David Osipov
2026-03-18 23:55:08 +04:00
parent 20e205189c
commit c7cf37898b
18 changed files with 1896 additions and 49 deletions

View File

@@ -1508,9 +1508,11 @@ impl Stats {
// ============= Replay Checker =============
pub struct ReplayChecker {
shards: Vec<Mutex<ReplayShard>>,
handshake_shards: Vec<Mutex<ReplayShard>>,
tls_shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize,
window: Duration,
tls_window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
@@ -1587,19 +1589,24 @@ impl ReplayShard {
impl ReplayChecker {
pub fn new(total_capacity: usize, window: Duration) -> Self {
const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120);
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
let mut handshake_shards = Vec::with_capacity(num_shards);
let mut tls_shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(ReplayShard::new(cap)));
handshake_shards.push(Mutex::new(ReplayShard::new(cap)));
tls_shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self {
shards,
handshake_shards,
tls_shards,
shard_mask: num_shards - 1,
window,
tls_window: window.max(MIN_TLS_REPLAY_WINDOW),
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
@@ -1613,46 +1620,60 @@ impl ReplayChecker {
(hasher.finish() as usize) & self.shard_mask
}
fn check_and_add_internal(&self, data: &[u8]) -> bool {
fn check_and_add_internal(
&self,
data: &[u8],
shards: &[Mutex<ReplayShard>],
window: Duration,
) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
let mut shard = shards[idx].lock();
let now = Instant::now();
let found = shard.check(data, now, self.window);
let found = shard.check(data, now, window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
} else {
shard.add(data, now, self.window);
shard.add(data, now, window);
self.additions.fetch_add(1, Ordering::Relaxed);
}
found
}
fn add_only(&self, data: &[u8]) {
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window);
let mut shard = shards[idx].lock();
shard.add(data, Instant::now(), window);
}
pub fn check_and_add_handshake(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
self.check_and_add_internal(data, &self.handshake_shards, self.window)
}
pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
self.check_and_add_internal(data, &self.tls_shards, self.tls_window)
}
// Compatibility helpers (non-atomic split operations) — prefer check_and_add_*.
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) }
pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) }
pub fn add_handshake(&self, data: &[u8]) {
self.add_only(data, &self.handshake_shards, self.window)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) }
pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) }
pub fn add_tls_digest(&self, data: &[u8]) {
self.add_only(data, &self.tls_shards, self.tls_window)
}
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.shards {
for shard in &self.handshake_shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
for shard in &self.tls_shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
@@ -1665,7 +1686,7 @@ impl ReplayChecker {
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.shards.len(),
num_shards: self.handshake_shards.len() + self.tls_shards.len(),
window_secs: self.window.as_secs(),
}
}
@@ -1683,13 +1704,20 @@ impl ReplayChecker {
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.shards {
for shard_mutex in &self.handshake_shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
for shard_mutex in &self.tls_shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.tls_window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
self.cleanups.fetch_add(1, Ordering::Relaxed);
@@ -1815,7 +1843,7 @@ mod tests {
fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(10_000, Duration::from_secs(60));
for i in 0..500u32 {
checker.add_only(&i.to_le_bytes());
checker.add_handshake(&i.to_le_bytes());
}
for i in 0..500u32 {
assert!(checker.check_handshake(&i.to_le_bytes()));
@@ -1827,3 +1855,7 @@ mod tests {
#[cfg(test)]
#[path = "connection_lease_security_tests.rs"]
mod connection_lease_security_tests;
#[cfg(test)]
#[path = "replay_checker_security_tests.rs"]
mod replay_checker_security_tests;

View File

@@ -0,0 +1,80 @@
use super::*;
use std::time::Duration;
#[test]
fn replay_checker_keeps_tls_and_handshake_domains_isolated_for_same_key() {
let checker = ReplayChecker::new(128, Duration::from_millis(20));
let key = b"same-key-domain-separation";
assert!(
!checker.check_and_add_handshake(key),
"first handshake use should be fresh"
);
assert!(
!checker.check_and_add_tls_digest(key),
"same bytes in TLS domain should still be fresh"
);
assert!(
checker.check_and_add_handshake(key),
"second handshake use should be replay-hit"
);
assert!(
checker.check_and_add_tls_digest(key),
"second TLS use should be replay-hit independently"
);
}
#[test]
fn replay_checker_tls_window_is_clamped_beyond_small_handshake_window() {
let checker = ReplayChecker::new(128, Duration::from_millis(20));
let handshake_key = b"short-window-handshake";
let tls_key = b"short-window-tls";
assert!(!checker.check_and_add_handshake(handshake_key));
assert!(!checker.check_and_add_tls_digest(tls_key));
std::thread::sleep(Duration::from_millis(80));
assert!(
!checker.check_and_add_handshake(handshake_key),
"handshake key should expire under short configured window"
);
assert!(
checker.check_and_add_tls_digest(tls_key),
"TLS key should still be replay-hit because TLS window is clamped to a secure minimum"
);
}
#[test]
fn replay_checker_compat_add_paths_do_not_cross_pollute_domains() {
let checker = ReplayChecker::new(128, Duration::from_secs(1));
let key = b"compat-domain-separation";
checker.add_handshake(key);
assert!(
checker.check_and_add_handshake(key),
"handshake add helper must populate handshake domain"
);
assert!(
!checker.check_and_add_tls_digest(key),
"handshake add helper must not pollute TLS domain"
);
checker.add_tls_digest(key);
assert!(
checker.check_and_add_tls_digest(key),
"TLS add helper must populate TLS domain"
);
}
#[test]
fn replay_checker_stats_reflect_dual_shard_domains() {
let checker = ReplayChecker::new(128, Duration::from_secs(1));
let stats = checker.stats();
assert_eq!(
stats.num_shards, 128,
"stats should expose both shard domains (handshake + TLS)"
);
}