mirror of https://github.com/telemt/telemt.git
Enhance TLS Emulator with ALPN Support and Add Adversarial Tests
- Modified `build_emulated_server_hello` to accept ALPN (Application-Layer Protocol Negotiation) as an optional parameter, allowing for the embedding of ALPN markers in the application data payload. - Implemented logic to handle oversized ALPN values and ensure they do not interfere with the application data payload. - Added new security tests in `emulator_security_tests.rs` to validate the behavior of the ALPN embedding, including scenarios for oversized ALPN and preference for certificate payloads over ALPN markers. - Introduced `send_adversarial_tests.rs` to cover edge cases and potential issues in the middle proxy's send functionality, ensuring robustness against various failure modes. - Updated `middle_proxy` module to include new test modules and ensure proper handling of writer commands during data transmission.
This commit is contained in:
parent
97d4a1c5c8
commit
20e205189c
|
|
@ -7,8 +7,9 @@ use std::net::IpAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::{Mutex as AsyncMutex, RwLock};
|
||||||
|
|
||||||
use crate::config::UserMaxUniqueIpsMode;
|
use crate::config::UserMaxUniqueIpsMode;
|
||||||
|
|
||||||
|
|
@ -21,6 +22,8 @@ pub struct UserIpTracker {
|
||||||
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
|
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
|
||||||
limit_window: Arc<RwLock<Duration>>,
|
limit_window: Arc<RwLock<Duration>>,
|
||||||
last_compact_epoch_secs: Arc<AtomicU64>,
|
last_compact_epoch_secs: Arc<AtomicU64>,
|
||||||
|
pub(crate) cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
|
||||||
|
cleanup_drain_lock: Arc<AsyncMutex<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UserIpTracker {
|
impl UserIpTracker {
|
||||||
|
|
@ -33,6 +36,67 @@ impl UserIpTracker {
|
||||||
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
|
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
|
||||||
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
|
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
|
||||||
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
|
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||||
|
cleanup_queue: Arc::new(Mutex::new(Vec::new())),
|
||||||
|
cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
|
||||||
|
match self.cleanup_queue.lock() {
|
||||||
|
Ok(mut queue) => queue.push((user, ip)),
|
||||||
|
Err(poisoned) => {
|
||||||
|
let mut queue = poisoned.into_inner();
|
||||||
|
queue.push((user.clone(), ip));
|
||||||
|
self.cleanup_queue.clear_poison();
|
||||||
|
tracing::warn!(
|
||||||
|
"UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
|
||||||
|
user,
|
||||||
|
ip
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn drain_cleanup_queue(&self) {
|
||||||
|
// Serialize queue draining and active-IP mutation so check-and-add cannot
|
||||||
|
// observe stale active entries that are already queued for removal.
|
||||||
|
let _drain_guard = self.cleanup_drain_lock.lock().await;
|
||||||
|
let to_remove = {
|
||||||
|
match self.cleanup_queue.lock() {
|
||||||
|
Ok(mut queue) => {
|
||||||
|
if queue.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::mem::take(&mut *queue)
|
||||||
|
}
|
||||||
|
Err(poisoned) => {
|
||||||
|
let mut queue = poisoned.into_inner();
|
||||||
|
if queue.is_empty() {
|
||||||
|
self.cleanup_queue.clear_poison();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let drained = std::mem::take(&mut *queue);
|
||||||
|
self.cleanup_queue.clear_poison();
|
||||||
|
drained
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut active_ips = self.active_ips.write().await;
|
||||||
|
for (user, ip) in to_remove {
|
||||||
|
if let Some(user_ips) = active_ips.get_mut(&user) {
|
||||||
|
if let Some(count) = user_ips.get_mut(&ip) {
|
||||||
|
if *count > 1 {
|
||||||
|
*count -= 1;
|
||||||
|
} else {
|
||||||
|
user_ips.remove(&ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if user_ips.is_empty() {
|
||||||
|
active_ips.remove(&user);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -118,6 +182,7 @@ impl UserIpTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
|
pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
|
||||||
|
self.drain_cleanup_queue().await;
|
||||||
self.maybe_compact_empty_users().await;
|
self.maybe_compact_empty_users().await;
|
||||||
let default_max_ips = *self.default_max_ips.read().await;
|
let default_max_ips = *self.default_max_ips.read().await;
|
||||||
let limit = {
|
let limit = {
|
||||||
|
|
@ -194,6 +259,7 @@ impl UserIpTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap<String, usize> {
|
pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap<String, usize> {
|
||||||
|
self.drain_cleanup_queue().await;
|
||||||
let window = *self.limit_window.read().await;
|
let window = *self.limit_window.read().await;
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let recent_ips = self.recent_ips.read().await;
|
let recent_ips = self.recent_ips.read().await;
|
||||||
|
|
@ -214,6 +280,7 @@ impl UserIpTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
|
pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
|
||||||
|
self.drain_cleanup_queue().await;
|
||||||
let active_ips = self.active_ips.read().await;
|
let active_ips = self.active_ips.read().await;
|
||||||
let mut out = HashMap::with_capacity(users.len());
|
let mut out = HashMap::with_capacity(users.len());
|
||||||
for user in users {
|
for user in users {
|
||||||
|
|
@ -228,6 +295,7 @@ impl UserIpTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
|
pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
|
||||||
|
self.drain_cleanup_queue().await;
|
||||||
let window = *self.limit_window.read().await;
|
let window = *self.limit_window.read().await;
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let recent_ips = self.recent_ips.read().await;
|
let recent_ips = self.recent_ips.read().await;
|
||||||
|
|
@ -250,11 +318,13 @@ impl UserIpTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_active_ip_count(&self, username: &str) -> usize {
|
pub async fn get_active_ip_count(&self, username: &str) -> usize {
|
||||||
|
self.drain_cleanup_queue().await;
|
||||||
let active_ips = self.active_ips.read().await;
|
let active_ips = self.active_ips.read().await;
|
||||||
active_ips.get(username).map(|ips| ips.len()).unwrap_or(0)
|
active_ips.get(username).map(|ips| ips.len()).unwrap_or(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_active_ips(&self, username: &str) -> Vec<IpAddr> {
|
pub async fn get_active_ips(&self, username: &str) -> Vec<IpAddr> {
|
||||||
|
self.drain_cleanup_queue().await;
|
||||||
let active_ips = self.active_ips.read().await;
|
let active_ips = self.active_ips.read().await;
|
||||||
active_ips
|
active_ips
|
||||||
.get(username)
|
.get(username)
|
||||||
|
|
@ -263,6 +333,7 @@ impl UserIpTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
|
pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
|
||||||
|
self.drain_cleanup_queue().await;
|
||||||
let active_ips = self.active_ips.read().await;
|
let active_ips = self.active_ips.read().await;
|
||||||
let max_ips = self.max_ips.read().await;
|
let max_ips = self.max_ips.read().await;
|
||||||
let default_max_ips = *self.default_max_ips.read().await;
|
let default_max_ips = *self.default_max_ips.read().await;
|
||||||
|
|
@ -301,6 +372,7 @@ impl UserIpTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool {
|
pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool {
|
||||||
|
self.drain_cleanup_queue().await;
|
||||||
let active_ips = self.active_ips.read().await;
|
let active_ips = self.active_ips.read().await;
|
||||||
active_ips
|
active_ips
|
||||||
.get(username)
|
.get(username)
|
||||||
|
|
|
||||||
|
|
@ -448,3 +448,172 @@ async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() {
|
||||||
|
|
||||||
assert!(tracker.get_active_ip_count("cc").await <= 8);
|
assert!(tracker.get_active_ip_count("cc").await <= 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_cleanup_recovers_from_poisoned_mutex() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
let ip = ip_from_idx(99);
|
||||||
|
|
||||||
|
// Poison the lock by panicking while holding it
|
||||||
|
let result = std::panic::catch_unwind(|| {
|
||||||
|
let _guard = tracker.cleanup_queue.lock().unwrap();
|
||||||
|
panic!("Intentional poison panic");
|
||||||
|
});
|
||||||
|
assert!(result.is_err(), "Expected panic to poison mutex");
|
||||||
|
|
||||||
|
// Attempt to enqueue anyway; should hit the poison catch arm and still insert
|
||||||
|
tracker.enqueue_cleanup("poison-user".to_string(), ip);
|
||||||
|
|
||||||
|
tracker.drain_cleanup_queue().await;
|
||||||
|
|
||||||
|
assert_eq!(tracker.get_active_ip_count("poison-user").await, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
|
async fn mass_reconnect_sync_cleanup_prevents_temporary_reservation_bloat() {
|
||||||
|
// Tests that synchronous M-01 drop mechanism protects against starvation
|
||||||
|
let tracker = Arc::new(UserIpTracker::new());
|
||||||
|
tracker.set_user_limit("mass", 5).await;
|
||||||
|
|
||||||
|
let ip = ip_from_idx(42);
|
||||||
|
let mut join_handles = Vec::new();
|
||||||
|
|
||||||
|
// 10,000 rapid concurrent requests hitting the same IP limit
|
||||||
|
for _ in 0..10_000 {
|
||||||
|
let tracker_clone = tracker.clone();
|
||||||
|
join_handles.push(tokio::spawn(async move {
|
||||||
|
if tracker_clone.check_and_add("mass", ip).await.is_ok() {
|
||||||
|
// Instantly enqueue cleanup, simulating synchronous reservation drop
|
||||||
|
tracker_clone.enqueue_cleanup("mass".to_string(), ip);
|
||||||
|
// The next caller will drain it before acquiring again
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for handle in join_handles {
|
||||||
|
let _ = handle.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force flush
|
||||||
|
tracker.drain_cleanup_queue().await;
|
||||||
|
assert_eq!(tracker.get_active_ip_count("mass").await, 0, "No leaked footprints");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() {
|
||||||
|
// Regression guard: concurrent cleanup draining must not produce false
|
||||||
|
// limit denials for a new IP when the previous IP is already queued.
|
||||||
|
let tracker = Arc::new(UserIpTracker::new());
|
||||||
|
tracker.set_user_limit("racer", 1).await;
|
||||||
|
let ip1 = ip_from_idx(1);
|
||||||
|
let ip2 = ip_from_idx(2);
|
||||||
|
|
||||||
|
// Initial state: add ip1
|
||||||
|
tracker.check_and_add("racer", ip1).await.unwrap();
|
||||||
|
|
||||||
|
// User disconnects from ip1, queuing it
|
||||||
|
tracker.enqueue_cleanup("racer".to_string(), ip1);
|
||||||
|
|
||||||
|
let mut saw_false_rejection = false;
|
||||||
|
for _ in 0..100 {
|
||||||
|
// Queue cleanup then race explicit drain and check-and-add on the alternative IP.
|
||||||
|
tracker.enqueue_cleanup("racer".to_string(), ip1);
|
||||||
|
let tracker_a = tracker.clone();
|
||||||
|
let tracker_b = tracker.clone();
|
||||||
|
|
||||||
|
let drain_handle = tokio::spawn(async move {
|
||||||
|
tracker_a.drain_cleanup_queue().await;
|
||||||
|
});
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
tracker_b.check_and_add("racer", ip2).await
|
||||||
|
});
|
||||||
|
|
||||||
|
drain_handle.await.unwrap();
|
||||||
|
let res = handle.await.unwrap();
|
||||||
|
if res.is_err() {
|
||||||
|
saw_false_rejection = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore baseline for next iteration.
|
||||||
|
tracker.remove_ip("racer", ip2).await;
|
||||||
|
tracker.check_and_add("racer", ip1).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!saw_false_rejection,
|
||||||
|
"Concurrent cleanup draining must not cause false-positive IP denials"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn poisoned_cleanup_queue_still_releases_slot_for_next_ip() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("poison-slot", 1).await;
|
||||||
|
let ip1 = ip_from_idx(7001);
|
||||||
|
let ip2 = ip_from_idx(7002);
|
||||||
|
|
||||||
|
tracker.check_and_add("poison-slot", ip1).await.unwrap();
|
||||||
|
|
||||||
|
// Poison the queue lock as an adversarial condition.
|
||||||
|
let _ = std::panic::catch_unwind(|| {
|
||||||
|
let _guard = tracker.cleanup_queue.lock().unwrap();
|
||||||
|
panic!("intentional queue poison");
|
||||||
|
});
|
||||||
|
|
||||||
|
// Disconnect path must still queue cleanup so the next IP can be admitted.
|
||||||
|
tracker.enqueue_cleanup("poison-slot".to_string(), ip1);
|
||||||
|
let admitted = tracker.check_and_add("poison-slot", ip2).await;
|
||||||
|
assert!(
|
||||||
|
admitted.is_ok(),
|
||||||
|
"cleanup queue poison must not permanently block slot release for the next IP"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn duplicate_cleanup_entries_do_not_break_future_admission() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("dup-cleanup", 1).await;
|
||||||
|
let ip1 = ip_from_idx(7101);
|
||||||
|
let ip2 = ip_from_idx(7102);
|
||||||
|
|
||||||
|
tracker.check_and_add("dup-cleanup", ip1).await.unwrap();
|
||||||
|
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
|
||||||
|
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
|
||||||
|
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
|
||||||
|
|
||||||
|
tracker.drain_cleanup_queue().await;
|
||||||
|
|
||||||
|
assert_eq!(tracker.get_active_ip_count("dup-cleanup").await, 0);
|
||||||
|
assert!(
|
||||||
|
tracker.check_and_add("dup-cleanup", ip2).await.is_ok(),
|
||||||
|
"extra queued cleanup entries must not leave user stuck in denied state"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("poison-stress", 1).await;
|
||||||
|
let ip_primary = ip_from_idx(7201);
|
||||||
|
let ip_alt = ip_from_idx(7202);
|
||||||
|
|
||||||
|
tracker.check_and_add("poison-stress", ip_primary).await.unwrap();
|
||||||
|
|
||||||
|
for _ in 0..64 {
|
||||||
|
let _ = std::panic::catch_unwind(|| {
|
||||||
|
let _guard = tracker.cleanup_queue.lock().unwrap();
|
||||||
|
panic!("intentional queue poison in stress loop");
|
||||||
|
});
|
||||||
|
|
||||||
|
tracker.enqueue_cleanup("poison-stress".to_string(), ip_primary);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
tracker.check_and_add("poison-stress", ip_alt).await.is_ok(),
|
||||||
|
"poison recovery must preserve admission progress under repeated queue poisoning"
|
||||||
|
);
|
||||||
|
|
||||||
|
tracker.remove_ip("poison-stress", ip_alt).await;
|
||||||
|
tracker.check_and_add("poison-stress", ip_primary).await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,9 @@ pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before
|
||||||
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
|
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
|
||||||
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
|
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
|
||||||
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
|
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
|
||||||
|
/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance
|
||||||
|
/// windows when replay TTL is configured very large.
|
||||||
|
pub const BOOT_TIME_COMPAT_MAX_SECS: u32 = 2 * 60;
|
||||||
|
|
||||||
// ============= Private Constants =============
|
// ============= Private Constants =============
|
||||||
|
|
||||||
|
|
@ -66,6 +69,7 @@ pub struct TlsValidation {
|
||||||
/// Client digest for response generation
|
/// Client digest for response generation
|
||||||
pub digest: [u8; TLS_DIGEST_LEN],
|
pub digest: [u8; TLS_DIGEST_LEN],
|
||||||
/// Timestamp extracted from digest
|
/// Timestamp extracted from digest
|
||||||
|
|
||||||
pub timestamp: u32,
|
pub timestamp: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -121,6 +125,7 @@ impl TlsExtensionBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build final extensions with length prefix
|
/// Build final extensions with length prefix
|
||||||
|
|
||||||
fn build(self) -> Vec<u8> {
|
fn build(self) -> Vec<u8> {
|
||||||
let mut result = Vec::with_capacity(2 + self.extensions.len());
|
let mut result = Vec::with_capacity(2 + self.extensions.len());
|
||||||
|
|
||||||
|
|
@ -135,7 +140,7 @@ impl TlsExtensionBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get current extensions without length prefix (for calculation)
|
/// Get current extensions without length prefix (for calculation)
|
||||||
#[allow(dead_code)]
|
|
||||||
fn as_bytes(&self) -> &[u8] {
|
fn as_bytes(&self) -> &[u8] {
|
||||||
&self.extensions
|
&self.extensions
|
||||||
}
|
}
|
||||||
|
|
@ -251,6 +256,7 @@ impl ServerHelloBuilder {
|
||||||
/// Returns validation result if a matching user is found.
|
/// Returns validation result if a matching user is found.
|
||||||
/// The result **must** be used — ignoring it silently bypasses authentication.
|
/// The result **must** be used — ignoring it silently bypasses authentication.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
|
|
||||||
pub fn validate_tls_handshake(
|
pub fn validate_tls_handshake(
|
||||||
handshake: &[u8],
|
handshake: &[u8],
|
||||||
secrets: &[(String, Vec<u8>)],
|
secrets: &[(String, Vec<u8>)],
|
||||||
|
|
@ -266,9 +272,9 @@ pub fn validate_tls_handshake(
|
||||||
|
|
||||||
/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL.
|
/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL.
|
||||||
///
|
///
|
||||||
/// A boot-time timestamp is only accepted when it falls below both
|
/// A boot-time timestamp is only accepted when it falls below all three
|
||||||
/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp
|
/// bounds: `BOOT_TIME_MAX_SECS`, configured replay window, and
|
||||||
/// reuse outside replay cache coverage.
|
/// `BOOT_TIME_COMPAT_MAX_SECS`, preventing oversized compatibility windows.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn validate_tls_handshake_with_replay_window(
|
pub fn validate_tls_handshake_with_replay_window(
|
||||||
handshake: &[u8],
|
handshake: &[u8],
|
||||||
|
|
@ -292,7 +298,9 @@ pub fn validate_tls_handshake_with_replay_window(
|
||||||
let boot_time_cap_secs = if ignore_time_skew {
|
let boot_time_cap_secs = if ignore_time_skew {
|
||||||
0
|
0
|
||||||
} else {
|
} else {
|
||||||
BOOT_TIME_MAX_SECS.min(replay_window_u32)
|
BOOT_TIME_MAX_SECS
|
||||||
|
.min(replay_window_u32)
|
||||||
|
.min(BOOT_TIME_COMPAT_MAX_SECS)
|
||||||
};
|
};
|
||||||
|
|
||||||
validate_tls_handshake_at_time_with_boot_cap(
|
validate_tls_handshake_at_time_with_boot_cap(
|
||||||
|
|
@ -312,6 +320,7 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
|
||||||
i64::try_from(d.as_secs()).ok()
|
i64::try_from(d.as_secs()).ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn validate_tls_handshake_at_time(
|
fn validate_tls_handshake_at_time(
|
||||||
handshake: &[u8],
|
handshake: &[u8],
|
||||||
secrets: &[(String, Vec<u8>)],
|
secrets: &[(String, Vec<u8>)],
|
||||||
|
|
@ -437,7 +446,7 @@ pub fn build_server_hello(
|
||||||
session_id: &[u8],
|
session_id: &[u8],
|
||||||
fake_cert_len: usize,
|
fake_cert_len: usize,
|
||||||
rng: &SecureRandom,
|
rng: &SecureRandom,
|
||||||
_alpn: Option<Vec<u8>>,
|
alpn: Option<Vec<u8>>,
|
||||||
new_session_tickets: u8,
|
new_session_tickets: u8,
|
||||||
) -> Vec<u8> {
|
) -> Vec<u8> {
|
||||||
const MIN_APP_DATA: usize = 64;
|
const MIN_APP_DATA: usize = 64;
|
||||||
|
|
@ -459,8 +468,27 @@ pub fn build_server_hello(
|
||||||
0x01, // CCS byte
|
0x01, // CCS byte
|
||||||
];
|
];
|
||||||
|
|
||||||
// Build fake certificate (Application Data record)
|
// Build first encrypted flight mimic as opaque ApplicationData bytes.
|
||||||
let fake_cert = rng.bytes(fake_cert_len);
|
// Embed a compact EncryptedExtensions-like ALPN block when selected.
|
||||||
|
let mut fake_cert = Vec::with_capacity(fake_cert_len);
|
||||||
|
if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) {
|
||||||
|
let proto_list_len = 1usize + proto.len();
|
||||||
|
let ext_data_len = 2usize + proto_list_len;
|
||||||
|
let marker_len = 4usize + ext_data_len;
|
||||||
|
if marker_len <= fake_cert_len {
|
||||||
|
fake_cert.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||||
|
fake_cert.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
|
||||||
|
fake_cert.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
|
||||||
|
fake_cert.push(proto.len() as u8);
|
||||||
|
fake_cert.extend_from_slice(proto);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fake_cert.len() < fake_cert_len {
|
||||||
|
fake_cert.extend_from_slice(&rng.bytes(fake_cert_len - fake_cert.len()));
|
||||||
|
} else if fake_cert.len() > fake_cert_len {
|
||||||
|
fake_cert.truncate(fake_cert_len);
|
||||||
|
}
|
||||||
|
|
||||||
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
|
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
|
||||||
app_data_record.push(TLS_RECORD_APPLICATION);
|
app_data_record.push(TLS_RECORD_APPLICATION);
|
||||||
app_data_record.extend_from_slice(&TLS_VERSION);
|
app_data_record.extend_from_slice(&TLS_VERSION);
|
||||||
|
|
@ -472,8 +500,9 @@ pub fn build_server_hello(
|
||||||
// Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted;
|
// Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted;
|
||||||
// here we mimic with opaque ApplicationData records of plausible size).
|
// here we mimic with opaque ApplicationData records of plausible size).
|
||||||
let mut tickets = Vec::new();
|
let mut tickets = Vec::new();
|
||||||
if new_session_tickets > 0 {
|
let ticket_count = new_session_tickets.min(4);
|
||||||
for _ in 0..new_session_tickets {
|
if ticket_count > 0 {
|
||||||
|
for _ in 0..ticket_count {
|
||||||
let ticket_len: usize = rng.range(48) + 48; // 48-95 bytes
|
let ticket_len: usize = rng.range(48) + 48; // 48-95 bytes
|
||||||
let mut record = Vec::with_capacity(5 + ticket_len);
|
let mut record = Vec::with_capacity(5 + ticket_len);
|
||||||
record.push(TLS_RECORD_APPLICATION);
|
record.push(TLS_RECORD_APPLICATION);
|
||||||
|
|
@ -678,6 +707,7 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse TLS record header, returns (record_type, length)
|
/// Parse TLS record header, returns (record_type, length)
|
||||||
|
|
||||||
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
|
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
|
||||||
let record_type = header[0];
|
let record_type = header[0];
|
||||||
let version = [header[1], header[2]];
|
let version = [header[1], header[2]];
|
||||||
|
|
|
||||||
|
|
@ -300,8 +300,8 @@ fn boot_time_timestamp_accepted_without_ignore_flag() {
|
||||||
// Timestamps below the boot-time threshold are treated as client uptime,
|
// Timestamps below the boot-time threshold are treated as client uptime,
|
||||||
// not real wall-clock time. The proxy allows them regardless of skew.
|
// not real wall-clock time. The proxy allows them regardless of skew.
|
||||||
let secret = b"boot_time_test";
|
let secret = b"boot_time_test";
|
||||||
// Keep this safely below BOOT_TIME_MAX_SECS to assert bypass behavior.
|
// Keep this safely below compatibility cap to assert bypass behavior.
|
||||||
let boot_ts: u32 = BOOT_TIME_MAX_SECS / 2;
|
let boot_ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1);
|
||||||
let handshake = make_valid_tls_handshake(secret, boot_ts);
|
let handshake = make_valid_tls_handshake(secret, boot_ts);
|
||||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -663,13 +663,14 @@ fn zero_length_session_id_accepted() {
|
||||||
// Boot-time threshold — exact boundary precision
|
// Boot-time threshold — exact boundary precision
|
||||||
// ------------------------------------------------------------------
|
// ------------------------------------------------------------------
|
||||||
|
|
||||||
/// timestamp = BOOT_TIME_MAX_SECS - 1 is the last value inside the boot-time window.
|
/// timestamp = BOOT_TIME_COMPAT_MAX_SECS - 1 is the last value inside
|
||||||
|
/// the runtime boot-time compatibility window.
|
||||||
/// is_boot_time = true → skew check is skipped entirely → accepted even
|
/// is_boot_time = true → skew check is skipped entirely → accepted even
|
||||||
/// when `now` is far from the timestamp.
|
/// when `now` is far from the timestamp.
|
||||||
#[test]
|
#[test]
|
||||||
fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
|
fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
|
||||||
let secret = b"boot_last_value_test";
|
let secret = b"boot_last_value_test";
|
||||||
let ts: u32 = BOOT_TIME_MAX_SECS - 1;
|
let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS - 1;
|
||||||
let h = make_valid_tls_handshake(secret, ts);
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
|
@ -677,32 +678,48 @@ fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
|
||||||
// Boot-time bypass must prevent the skew check from running.
|
// Boot-time bypass must prevent the skew check from running.
|
||||||
assert!(
|
assert!(
|
||||||
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(),
|
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(),
|
||||||
"ts=BOOT_TIME_MAX_SECS-1 must bypass skew check regardless of now"
|
"ts=BOOT_TIME_COMPAT_MAX_SECS-1 must bypass skew check regardless of now"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// timestamp = BOOT_TIME_MAX_SECS is the first value outside the boot-time window.
|
/// timestamp = BOOT_TIME_COMPAT_MAX_SECS is the first value outside the
|
||||||
|
/// runtime boot-time compatibility window.
|
||||||
/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this:
|
/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this:
|
||||||
/// once with now chosen so the skew passes (accepted) and once where it fails.
|
/// once with now chosen so the skew passes (accepted) and once where it fails.
|
||||||
#[test]
|
#[test]
|
||||||
fn timestamp_at_boot_threshold_triggers_skew_check() {
|
fn timestamp_at_boot_threshold_triggers_skew_check() {
|
||||||
let secret = b"boot_exact_value_test";
|
let secret = b"boot_exact_value_test";
|
||||||
let ts: u32 = BOOT_TIME_MAX_SECS;
|
let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS;
|
||||||
let h = make_valid_tls_handshake(secret, ts);
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
// now = ts + 50 → time_diff = 50, within [-1200, 600] → accepted.
|
// now = ts + 50 → time_diff = 50, within [-1200, 600] → accepted.
|
||||||
let now_valid: i64 = ts as i64 + 50;
|
let now_valid: i64 = ts as i64 + 50;
|
||||||
assert!(
|
assert!(
|
||||||
validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(),
|
validate_tls_handshake_at_time_with_boot_cap(
|
||||||
"ts=BOOT_TIME_MAX_SECS within skew window must be accepted via skew check"
|
&h,
|
||||||
|
&secrets,
|
||||||
|
false,
|
||||||
|
now_valid,
|
||||||
|
BOOT_TIME_COMPAT_MAX_SECS,
|
||||||
|
)
|
||||||
|
.is_some(),
|
||||||
|
"ts=BOOT_TIME_COMPAT_MAX_SECS within skew window must be accepted via skew check"
|
||||||
);
|
);
|
||||||
|
|
||||||
// now = 0 → time_diff = -86_400_000, outside window → rejected.
|
// now = -1 → time_diff = -121 at the 120-second threshold, outside window
|
||||||
// If the boot-time bypass were wrongly applied here this would pass.
|
// for TIME_SKEW_MIN=-120. If boot-time bypass were wrongly applied this
|
||||||
|
// would pass.
|
||||||
assert!(
|
assert!(
|
||||||
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(),
|
validate_tls_handshake_at_time_with_boot_cap(
|
||||||
"ts=BOOT_TIME_MAX_SECS far from now must be rejected — no boot-time bypass"
|
&h,
|
||||||
|
&secrets,
|
||||||
|
false,
|
||||||
|
-1,
|
||||||
|
BOOT_TIME_COMPAT_MAX_SECS,
|
||||||
|
)
|
||||||
|
.is_none(),
|
||||||
|
"ts=BOOT_TIME_COMPAT_MAX_SECS far from now must be rejected — no boot-time bypass"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -723,7 +740,7 @@ fn replay_window_cap_disables_boot_bypass_for_old_timestamps() {
|
||||||
#[test]
|
#[test]
|
||||||
fn replay_window_cap_still_allows_small_boot_timestamp() {
|
fn replay_window_cap_still_allows_small_boot_timestamp() {
|
||||||
let secret = b"boot_cap_enabled_test";
|
let secret = b"boot_cap_enabled_test";
|
||||||
let ts: u32 = 120;
|
let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1);
|
||||||
let h = make_valid_tls_handshake(secret, ts);
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
|
@ -734,6 +751,20 @@ fn replay_window_cap_still_allows_small_boot_timestamp() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn large_replay_window_is_hard_capped_for_boot_compatibility() {
|
||||||
|
let secret = b"boot_cap_hard_limit_test";
|
||||||
|
let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS + 1;
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, u64::MAX);
|
||||||
|
assert!(
|
||||||
|
result.is_none(),
|
||||||
|
"very large replay window must not expand boot-time bypass beyond hard compatibility cap"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() {
|
fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() {
|
||||||
let secret = b"ignore_skew_boot_cap_decouple_test";
|
let secret = b"ignore_skew_boot_cap_decouple_test";
|
||||||
|
|
@ -743,7 +774,7 @@ fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() {
|
||||||
|
|
||||||
let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0);
|
let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0);
|
||||||
let cap_nonzero =
|
let cap_nonzero =
|
||||||
validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_MAX_SECS);
|
validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_COMPAT_MAX_SECS);
|
||||||
|
|
||||||
assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC");
|
assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC");
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -1889,6 +1920,228 @@ fn server_hello_new_session_ticket_count_matches_configuration() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_new_session_ticket_count_is_safely_capped() {
|
||||||
|
let secret = b"ticket_count_cap_test";
|
||||||
|
let client_digest = [0x44u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0x54; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
|
||||||
|
let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, u8::MAX);
|
||||||
|
|
||||||
|
let mut pos = 0usize;
|
||||||
|
let mut app_records = 0usize;
|
||||||
|
while pos + 5 <= response.len() {
|
||||||
|
let rtype = response[pos];
|
||||||
|
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
|
||||||
|
let next = pos + 5 + rlen;
|
||||||
|
assert!(next <= response.len(), "TLS record must stay inside response bounds");
|
||||||
|
if rtype == TLS_RECORD_APPLICATION {
|
||||||
|
app_records += 1;
|
||||||
|
}
|
||||||
|
pos = next;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
app_records,
|
||||||
|
5,
|
||||||
|
"response must cap ticket-like tail records to four plus one main application record"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_application_data_contains_alpn_marker_when_selected() {
|
||||||
|
let secret = b"alpn_marker_test";
|
||||||
|
let client_digest = [0x55u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0xAB; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
|
||||||
|
let response = build_server_hello(
|
||||||
|
secret,
|
||||||
|
&client_digest,
|
||||||
|
&session_id,
|
||||||
|
512,
|
||||||
|
&rng,
|
||||||
|
Some(b"h2".to_vec()),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_pos = 5 + sh_len;
|
||||||
|
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
|
||||||
|
let app_pos = ccs_pos + 5 + ccs_len;
|
||||||
|
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
|
||||||
|
let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];
|
||||||
|
|
||||||
|
let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
|
||||||
|
assert!(
|
||||||
|
app_payload.windows(expected.len()).any(|window| window == expected),
|
||||||
|
"first application payload must carry ALPN marker for selected protocol"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() {
|
||||||
|
let secret = b"alpn_oversize_ignore_test";
|
||||||
|
let client_digest = [0x56u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0xCD; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
let oversized_alpn = vec![b'x'; u8::MAX as usize + 1];
|
||||||
|
|
||||||
|
let response = build_server_hello(
|
||||||
|
secret,
|
||||||
|
&client_digest,
|
||||||
|
&session_id,
|
||||||
|
512,
|
||||||
|
&rng,
|
||||||
|
Some(oversized_alpn),
|
||||||
|
u8::MAX,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut pos = 0usize;
|
||||||
|
let mut app_records = 0usize;
|
||||||
|
let mut first_app_payload: Option<&[u8]> = None;
|
||||||
|
while pos + 5 <= response.len() {
|
||||||
|
let rtype = response[pos];
|
||||||
|
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
|
||||||
|
let next = pos + 5 + rlen;
|
||||||
|
assert!(next <= response.len(), "TLS record must stay inside response bounds");
|
||||||
|
if rtype == TLS_RECORD_APPLICATION {
|
||||||
|
app_records += 1;
|
||||||
|
if first_app_payload.is_none() {
|
||||||
|
first_app_payload = Some(&response[pos + 5..next]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos = next;
|
||||||
|
}
|
||||||
|
let marker = [0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x'];
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
app_records, 5,
|
||||||
|
"oversized ALPN must not change the four-ticket cap on tail records"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!first_app_payload
|
||||||
|
.expect("response must contain an application record")
|
||||||
|
.windows(marker.len())
|
||||||
|
.any(|window| window == marker),
|
||||||
|
"oversized ALPN must be ignored rather than embedded into the first application payload"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
|
||||||
|
let secret = b"alpn_too_large_to_fit_test";
|
||||||
|
let client_digest = [0x57u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0xEF; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
let oversized_alpn = vec![0xAB; u8::MAX as usize];
|
||||||
|
|
||||||
|
let response = build_server_hello(
|
||||||
|
secret,
|
||||||
|
&client_digest,
|
||||||
|
&session_id,
|
||||||
|
64,
|
||||||
|
&rng,
|
||||||
|
Some(oversized_alpn),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_pos = 5 + sh_len;
|
||||||
|
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
|
||||||
|
let app_pos = ccs_pos + 5 + ccs_len;
|
||||||
|
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
|
||||||
|
let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];
|
||||||
|
|
||||||
|
let mut marker_prefix = Vec::new();
|
||||||
|
marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||||
|
marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
|
||||||
|
marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
|
||||||
|
marker_prefix.push(0xff);
|
||||||
|
marker_prefix.extend_from_slice(&[0xab; 8]);
|
||||||
|
assert!(
|
||||||
|
!app_payload.starts_with(&marker_prefix),
|
||||||
|
"oversized ALPN must not be partially embedded into the ServerHello application record"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_embeds_full_alpn_marker_when_it_exactly_fits_fake_cert_len() {
|
||||||
|
let secret = b"alpn_exact_fit_test";
|
||||||
|
let client_digest = [0x58u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0xA5; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
let proto = vec![b'z'; 57];
|
||||||
|
|
||||||
|
// marker_len = 4 + (2 + (1 + proto_len)) = 7 + proto_len = 64
|
||||||
|
let response = build_server_hello(
|
||||||
|
secret,
|
||||||
|
&client_digest,
|
||||||
|
&session_id,
|
||||||
|
64,
|
||||||
|
&rng,
|
||||||
|
Some(proto.clone()),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_pos = 5 + sh_len;
|
||||||
|
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
|
||||||
|
let app_pos = ccs_pos + 5 + ccs_len;
|
||||||
|
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
|
||||||
|
let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];
|
||||||
|
|
||||||
|
let mut expected_marker = Vec::new();
|
||||||
|
expected_marker.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||||
|
expected_marker.extend_from_slice(&0x003Cu16.to_be_bytes());
|
||||||
|
expected_marker.extend_from_slice(&0x003Au16.to_be_bytes());
|
||||||
|
expected_marker.push(57u8);
|
||||||
|
expected_marker.extend_from_slice(&proto);
|
||||||
|
|
||||||
|
assert_eq!(app_payload.len(), expected_marker.len());
|
||||||
|
assert_eq!(app_payload, expected_marker.as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_does_not_embed_partial_alpn_marker_when_one_byte_short() {
|
||||||
|
let secret = b"alpn_one_byte_short_test";
|
||||||
|
let client_digest = [0x59u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0xA6; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
let proto = vec![0xAB; 58];
|
||||||
|
|
||||||
|
// marker_len = 65, fake_cert_len = 64 => marker must be fully skipped.
|
||||||
|
let response = build_server_hello(
|
||||||
|
secret,
|
||||||
|
&client_digest,
|
||||||
|
&session_id,
|
||||||
|
64,
|
||||||
|
&rng,
|
||||||
|
Some(proto),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_pos = 5 + sh_len;
|
||||||
|
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
|
||||||
|
let app_pos = ccs_pos + 5 + ccs_len;
|
||||||
|
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
|
||||||
|
let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];
|
||||||
|
|
||||||
|
let mut marker_prefix = Vec::new();
|
||||||
|
marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||||
|
marker_prefix.extend_from_slice(&0x003Du16.to_be_bytes());
|
||||||
|
marker_prefix.extend_from_slice(&0x003Bu16.to_be_bytes());
|
||||||
|
marker_prefix.push(58u8);
|
||||||
|
marker_prefix.extend_from_slice(&[0xAB; 8]);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!app_payload.starts_with(&marker_prefix),
|
||||||
|
"one-byte-short ALPN marker must be skipped entirely, not partially embedded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn exhaustive_tls_minor_version_classification_matches_policy() {
|
fn exhaustive_tls_minor_version_classification_matches_policy() {
|
||||||
for minor in 0u8..=u8::MAX {
|
for minor in 0u8..=u8::MAX {
|
||||||
|
|
|
||||||
|
|
@ -31,19 +31,16 @@ struct UserConnectionReservation {
|
||||||
user: String,
|
user: String,
|
||||||
ip: IpAddr,
|
ip: IpAddr,
|
||||||
active: bool,
|
active: bool,
|
||||||
runtime_handle: Option<tokio::runtime::Handle>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UserConnectionReservation {
|
impl UserConnectionReservation {
|
||||||
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
|
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
|
||||||
let runtime_handle = tokio::runtime::Handle::try_current().ok();
|
|
||||||
Self {
|
Self {
|
||||||
stats,
|
stats,
|
||||||
ip_tracker,
|
ip_tracker,
|
||||||
user,
|
user,
|
||||||
ip,
|
ip,
|
||||||
active: true,
|
active: true,
|
||||||
runtime_handle,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -64,29 +61,7 @@ impl Drop for UserConnectionReservation {
|
||||||
}
|
}
|
||||||
self.active = false;
|
self.active = false;
|
||||||
self.stats.decrement_user_curr_connects(&self.user);
|
self.stats.decrement_user_curr_connects(&self.user);
|
||||||
|
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
|
||||||
if let Some(handle) = &self.runtime_handle {
|
|
||||||
let ip_tracker = self.ip_tracker.clone();
|
|
||||||
let user = self.user.clone();
|
|
||||||
let ip = self.ip;
|
|
||||||
let handle = handle.clone();
|
|
||||||
handle.spawn(async move {
|
|
||||||
ip_tracker.remove_ip(&user, ip).await;
|
|
||||||
});
|
|
||||||
} else if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
|
||||||
let ip_tracker = self.ip_tracker.clone();
|
|
||||||
let user = self.user.clone();
|
|
||||||
let ip = self.ip;
|
|
||||||
handle.spawn(async move {
|
|
||||||
ip_tracker.remove_ip(&user, ip).await;
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
warn!(
|
|
||||||
user = %self.user,
|
|
||||||
ip = %self.ip,
|
|
||||||
"UserConnectionReservation dropped without Tokio runtime; IP reservation cleanup skipped"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,35 @@ where
|
||||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
|
||||||
|
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
|
||||||
|
let stats = Arc::new(crate::stats::Stats::new());
|
||||||
|
let user = "sync-drop-user".to_string();
|
||||||
|
let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap();
|
||||||
|
|
||||||
|
ip_tracker.set_user_limit(&user, 1).await;
|
||||||
|
ip_tracker.check_and_add(&user, ip).await.unwrap();
|
||||||
|
stats.increment_user_curr_connects(&user);
|
||||||
|
|
||||||
|
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1);
|
||||||
|
assert_eq!(stats.get_user_curr_connects(&user), 1);
|
||||||
|
|
||||||
|
let reservation = UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
|
||||||
|
|
||||||
|
// Drop the reservation synchronously without any tokio::spawn/await yielding!
|
||||||
|
drop(reservation);
|
||||||
|
|
||||||
|
// The IP is now inside the cleanup_queue, check that the queue has length 1
|
||||||
|
let queue_len = ip_tracker.cleanup_queue.lock().unwrap().len();
|
||||||
|
assert_eq!(queue_len, 1, "Reservation drop must push directly to synchronized IP queue");
|
||||||
|
|
||||||
|
assert_eq!(stats.get_user_curr_connects(&user), 0, "Stats must decrement immediately");
|
||||||
|
|
||||||
|
ip_tracker.drain_cleanup_queue().await;
|
||||||
|
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
|
async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
|
||||||
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -132,7 +132,11 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
|
||||||
}
|
}
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
{
|
{
|
||||||
OpenOptions::new().create(true).append(true).open(path)
|
let _ = path;
|
||||||
|
Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::PermissionDenied,
|
||||||
|
"unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -204,6 +208,7 @@ where
|
||||||
config.general.direct_relay_copy_buf_s2c_bytes,
|
config.general.direct_relay_copy_buf_s2c_bytes,
|
||||||
user,
|
user,
|
||||||
Arc::clone(&stats),
|
Arc::clone(&stats),
|
||||||
|
config.access.user_data_quota.get(user).copied(),
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
);
|
);
|
||||||
tokio::pin!(relay_result);
|
tokio::pin!(relay_result);
|
||||||
|
|
|
||||||
|
|
@ -241,7 +241,26 @@ fn auth_probe_record_failure_with_state(
|
||||||
rounds += 1;
|
rounds += 1;
|
||||||
if rounds > 8 {
|
if rounds > 8 {
|
||||||
auth_probe_note_saturation(now);
|
auth_probe_note_saturation(now);
|
||||||
return;
|
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
|
||||||
|
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
||||||
|
let key = *entry.key();
|
||||||
|
let fail_streak = entry.value().fail_streak;
|
||||||
|
let last_seen = entry.value().last_seen;
|
||||||
|
match eviction_candidate {
|
||||||
|
Some((_, current_fail, current_seen))
|
||||||
|
if fail_streak > current_fail
|
||||||
|
|| (fail_streak == current_fail && last_seen >= current_seen) =>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some((evict_key, _, _)) = eviction_candidate else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
state.remove(&evict_key);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut stale_keys = Vec::new();
|
let mut stale_keys = Vec::new();
|
||||||
|
|
@ -518,6 +537,7 @@ pub struct HandshakeSuccess {
|
||||||
/// Client address
|
/// Client address
|
||||||
pub peer: SocketAddr,
|
pub peer: SocketAddr,
|
||||||
/// Whether TLS was used
|
/// Whether TLS was used
|
||||||
|
|
||||||
pub is_tls: bool,
|
pub is_tls: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -716,7 +736,11 @@ where
|
||||||
R: AsyncRead + Unpin + Send,
|
R: AsyncRead + Unpin + Send,
|
||||||
W: AsyncWrite + Unpin + Send,
|
W: AsyncWrite + Unpin + Send,
|
||||||
{
|
{
|
||||||
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
trace!(
|
||||||
|
peer = %peer,
|
||||||
|
handshake_head = %hex::encode(&handshake[..8]),
|
||||||
|
"MTProto handshake prefix"
|
||||||
|
);
|
||||||
|
|
||||||
let throttle_now = Instant::now();
|
let throttle_now = Instant::now();
|
||||||
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
|
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
|
||||||
|
|
@ -916,6 +940,7 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
|
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
|
||||||
|
|
||||||
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
||||||
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
|
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
|
||||||
encrypted
|
encrypted
|
||||||
|
|
|
||||||
|
|
@ -1584,6 +1584,47 @@ fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_probe_over_cap_churn_still_tracks_newcomer_after_round_limit() {
|
||||||
|
let _guard = auth_probe_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
|
clear_auth_probe_state_for_testing();
|
||||||
|
|
||||||
|
let state = DashMap::new();
|
||||||
|
let now = Instant::now();
|
||||||
|
let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 32;
|
||||||
|
|
||||||
|
for idx in 0..initial {
|
||||||
|
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||||
|
10,
|
||||||
|
6,
|
||||||
|
((idx >> 8) & 0xff) as u8,
|
||||||
|
(idx & 0xff) as u8,
|
||||||
|
));
|
||||||
|
state.insert(
|
||||||
|
ip,
|
||||||
|
AuthProbeState {
|
||||||
|
fail_streak: 1,
|
||||||
|
blocked_until: now,
|
||||||
|
last_seen: now + Duration::from_millis((idx % 1024) as u64),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 114, 77));
|
||||||
|
auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_secs(1));
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
state.get(&newcomer).is_some(),
|
||||||
|
"new probe source must still be tracked even when map starts above hard cap"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
state.len() < initial + 1,
|
||||||
|
"round-limited eviction path must still reclaim capacity under over-cap churn"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn auth_probe_capacity_prefers_evicting_low_fail_streak_entries_first() {
|
fn auth_probe_capacity_prefers_evicting_low_fail_streak_entries_first() {
|
||||||
let _guard = auth_probe_test_lock()
|
let _guard = auth_probe_test_lock()
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,13 @@ use std::collections::hash_map::RandomState;
|
||||||
use std::hash::BuildHasher;
|
use std::hash::BuildHasher;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||||
use std::sync::{Arc, OnceLock};
|
use std::sync::{Arc, Mutex, OnceLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
#[cfg(test)]
|
|
||||||
use std::sync::Mutex;
|
|
||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex};
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tracing::{debug, trace, warn};
|
use tracing::{debug, trace, warn};
|
||||||
|
|
||||||
|
|
@ -35,14 +33,22 @@ enum C2MeCommand {
|
||||||
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
|
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
|
||||||
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
|
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
|
||||||
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
|
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
|
||||||
|
const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000);
|
||||||
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
|
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
|
||||||
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
||||||
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
|
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
|
||||||
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
||||||
|
#[cfg(test)]
|
||||||
|
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
|
||||||
|
#[cfg(not(test))]
|
||||||
|
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
||||||
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
||||||
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
||||||
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
|
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
|
||||||
|
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
|
||||||
|
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
|
||||||
|
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
|
||||||
|
|
||||||
struct RelayForensicsState {
|
struct RelayForensicsState {
|
||||||
trace_id: u64,
|
trace_id: u64,
|
||||||
|
|
@ -98,6 +104,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
|
let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
|
||||||
|
let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false));
|
||||||
|
if saturated_before {
|
||||||
|
ever_saturated.store(true, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(mut seen_at) = dedup.get_mut(&key) {
|
if let Some(mut seen_at) = dedup.get_mut(&key) {
|
||||||
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
|
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
|
||||||
|
|
@ -132,12 +143,52 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
||||||
};
|
};
|
||||||
dedup.remove(&evict_key);
|
dedup.remove(&evict_key);
|
||||||
dedup.insert(key, now);
|
dedup.insert(key, now);
|
||||||
return false;
|
return should_emit_full_desync_full_cache(now);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dedup.insert(key, now);
|
dedup.insert(key, now);
|
||||||
true
|
let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
|
||||||
|
// Preserve the first sequential insert that reaches capacity as a normal
|
||||||
|
// emit, while still gating concurrent newcomer churn after the cache has
|
||||||
|
// ever been observed at saturation.
|
||||||
|
let was_ever_saturated = if saturated_after {
|
||||||
|
ever_saturated.swap(true, Ordering::Relaxed)
|
||||||
|
} else {
|
||||||
|
ever_saturated.load(Ordering::Relaxed)
|
||||||
|
};
|
||||||
|
|
||||||
|
if saturated_before || (saturated_after && was_ever_saturated) {
|
||||||
|
should_emit_full_desync_full_cache(now)
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_emit_full_desync_full_cache(now: Instant) -> bool {
|
||||||
|
let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
|
||||||
|
let Ok(mut last_emit_at) = gate.lock() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
match *last_emit_at {
|
||||||
|
None => {
|
||||||
|
*last_emit_at = Some(now);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
Some(last) => {
|
||||||
|
let Some(elapsed) = now.checked_duration_since(last) else {
|
||||||
|
*last_emit_at = Some(now);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL {
|
||||||
|
*last_emit_at = Some(now);
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -145,6 +196,21 @@ fn clear_desync_dedup_for_testing() {
|
||||||
if let Some(dedup) = DESYNC_DEDUP.get() {
|
if let Some(dedup) = DESYNC_DEDUP.get() {
|
||||||
dedup.clear();
|
dedup.clear();
|
||||||
}
|
}
|
||||||
|
if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() {
|
||||||
|
ever_saturated.store(false, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() {
|
||||||
|
match last_emit_at.lock() {
|
||||||
|
Ok(mut guard) => {
|
||||||
|
*guard = None;
|
||||||
|
}
|
||||||
|
Err(poisoned) => {
|
||||||
|
let mut guard = poisoned.into_inner();
|
||||||
|
*guard = None;
|
||||||
|
last_emit_at.clear_poison();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -248,6 +314,38 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
|
||||||
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
|
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
|
||||||
|
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn quota_would_be_exceeded_for_user(
|
||||||
|
stats: &Stats,
|
||||||
|
user: &str,
|
||||||
|
quota_limit: Option<u64>,
|
||||||
|
bytes: u64,
|
||||||
|
) -> bool {
|
||||||
|
quota_limit.is_some_and(|quota| {
|
||||||
|
let used = stats.get_user_total_octets(user);
|
||||||
|
used >= quota || bytes > quota.saturating_sub(used)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
|
||||||
|
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||||
|
if let Some(existing) = locks.get(user) {
|
||||||
|
return Arc::clone(existing.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
let created = Arc::new(AsyncMutex::new(()));
|
||||||
|
match locks.entry(user.to_string()) {
|
||||||
|
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||||
|
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||||
|
entry.insert(Arc::clone(&created));
|
||||||
|
created
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn enqueue_c2me_command(
|
async fn enqueue_c2me_command(
|
||||||
tx: &mpsc::Sender<C2MeCommand>,
|
tx: &mpsc::Sender<C2MeCommand>,
|
||||||
cmd: C2MeCommand,
|
cmd: C2MeCommand,
|
||||||
|
|
@ -260,7 +358,14 @@ async fn enqueue_c2me_command(
|
||||||
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
|
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
|
||||||
tokio::task::yield_now().await;
|
tokio::task::yield_now().await;
|
||||||
}
|
}
|
||||||
tx.send(cmd).await
|
match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await {
|
||||||
|
Ok(Ok(permit)) => {
|
||||||
|
permit.send(cmd);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Ok(Err(_)) => Err(mpsc::error::SendError(cmd)),
|
||||||
|
Err(_) => Err(mpsc::error::SendError(cmd)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -284,6 +389,7 @@ where
|
||||||
W: AsyncWrite + Unpin + Send + 'static,
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
let user = success.user.clone();
|
let user = success.user.clone();
|
||||||
|
let quota_limit = config.access.user_data_quota.get(&user).copied();
|
||||||
let peer = success.peer;
|
let peer = success.peer;
|
||||||
let proto_tag = success.proto_tag;
|
let proto_tag = success.proto_tag;
|
||||||
let pool_generation = me_pool.current_generation();
|
let pool_generation = me_pool.current_generation();
|
||||||
|
|
@ -432,6 +538,7 @@ where
|
||||||
&mut frame_buf,
|
&mut frame_buf,
|
||||||
stats_clone.as_ref(),
|
stats_clone.as_ref(),
|
||||||
&user_clone,
|
&user_clone,
|
||||||
|
quota_limit,
|
||||||
bytes_me2c_clone.as_ref(),
|
bytes_me2c_clone.as_ref(),
|
||||||
conn_id,
|
conn_id,
|
||||||
d2c_flush_policy.ack_flush_immediate,
|
d2c_flush_policy.ack_flush_immediate,
|
||||||
|
|
@ -464,6 +571,7 @@ where
|
||||||
&mut frame_buf,
|
&mut frame_buf,
|
||||||
stats_clone.as_ref(),
|
stats_clone.as_ref(),
|
||||||
&user_clone,
|
&user_clone,
|
||||||
|
quota_limit,
|
||||||
bytes_me2c_clone.as_ref(),
|
bytes_me2c_clone.as_ref(),
|
||||||
conn_id,
|
conn_id,
|
||||||
d2c_flush_policy.ack_flush_immediate,
|
d2c_flush_policy.ack_flush_immediate,
|
||||||
|
|
@ -496,6 +604,7 @@ where
|
||||||
&mut frame_buf,
|
&mut frame_buf,
|
||||||
stats_clone.as_ref(),
|
stats_clone.as_ref(),
|
||||||
&user_clone,
|
&user_clone,
|
||||||
|
quota_limit,
|
||||||
bytes_me2c_clone.as_ref(),
|
bytes_me2c_clone.as_ref(),
|
||||||
conn_id,
|
conn_id,
|
||||||
d2c_flush_policy.ack_flush_immediate,
|
d2c_flush_policy.ack_flush_immediate,
|
||||||
|
|
@ -528,6 +637,7 @@ where
|
||||||
&mut frame_buf,
|
&mut frame_buf,
|
||||||
stats_clone.as_ref(),
|
stats_clone.as_ref(),
|
||||||
&user_clone,
|
&user_clone,
|
||||||
|
quota_limit,
|
||||||
bytes_me2c_clone.as_ref(),
|
bytes_me2c_clone.as_ref(),
|
||||||
conn_id,
|
conn_id,
|
||||||
d2c_flush_policy.ack_flush_immediate,
|
d2c_flush_policy.ack_flush_immediate,
|
||||||
|
|
@ -609,7 +719,19 @@ where
|
||||||
forensics.bytes_c2me = forensics
|
forensics.bytes_c2me = forensics
|
||||||
.bytes_c2me
|
.bytes_c2me
|
||||||
.saturating_add(payload.len() as u64);
|
.saturating_add(payload.len() as u64);
|
||||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
if let Some(limit) = quota_limit {
|
||||||
|
let quota_lock = quota_user_lock(&user);
|
||||||
|
let _quota_guard = quota_lock.lock().await;
|
||||||
|
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||||
|
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
|
||||||
|
main_result = Err(ProxyError::DataQuotaExceeded {
|
||||||
|
user: user.clone(),
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||||
|
}
|
||||||
let mut flags = proto_flags;
|
let mut flags = proto_flags;
|
||||||
if quickack {
|
if quickack {
|
||||||
flags |= RPC_FLAG_QUICKACK;
|
flags |= RPC_FLAG_QUICKACK;
|
||||||
|
|
@ -833,6 +955,7 @@ async fn process_me_writer_response<W>(
|
||||||
frame_buf: &mut Vec<u8>,
|
frame_buf: &mut Vec<u8>,
|
||||||
stats: &Stats,
|
stats: &Stats,
|
||||||
user: &str,
|
user: &str,
|
||||||
|
quota_limit: Option<u64>,
|
||||||
bytes_me2c: &AtomicU64,
|
bytes_me2c: &AtomicU64,
|
||||||
conn_id: u64,
|
conn_id: u64,
|
||||||
ack_flush_immediate: bool,
|
ack_flush_immediate: bool,
|
||||||
|
|
@ -848,17 +971,47 @@ where
|
||||||
} else {
|
} else {
|
||||||
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
|
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
|
||||||
}
|
}
|
||||||
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
|
let data_len = data.len() as u64;
|
||||||
stats.add_user_octets_to(user, data.len() as u64);
|
if let Some(limit) = quota_limit {
|
||||||
write_client_payload(
|
let quota_lock = quota_user_lock(user);
|
||||||
client_writer,
|
let _quota_guard = quota_lock.lock().await;
|
||||||
proto_tag,
|
if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) {
|
||||||
flags,
|
return Err(ProxyError::DataQuotaExceeded {
|
||||||
&data,
|
user: user.to_string(),
|
||||||
rng,
|
});
|
||||||
frame_buf,
|
}
|
||||||
)
|
write_client_payload(
|
||||||
.await?;
|
client_writer,
|
||||||
|
proto_tag,
|
||||||
|
flags,
|
||||||
|
&data,
|
||||||
|
rng,
|
||||||
|
frame_buf,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
|
||||||
|
stats.add_user_octets_to(user, data.len() as u64);
|
||||||
|
|
||||||
|
if quota_exceeded_for_user(stats, user, Some(limit)) {
|
||||||
|
return Err(ProxyError::DataQuotaExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
write_client_payload(
|
||||||
|
client_writer,
|
||||||
|
proto_tag,
|
||||||
|
flags,
|
||||||
|
&data,
|
||||||
|
rng,
|
||||||
|
frame_buf,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
|
||||||
|
stats.add_user_octets_to(user, data.len() as u64);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(MeWriterResponseOutcome::Continue {
|
Ok(MeWriterResponseOutcome::Continue {
|
||||||
frames: 1,
|
frames: 1,
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,9 @@ use rand::{Rng, SeedableRng};
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::AtomicU64;
|
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
|
||||||
use tokio::io::AsyncWriteExt;
|
use std::thread;
|
||||||
|
use tokio::io::AsyncReadExt;
|
||||||
use tokio::io::duplex;
|
use tokio::io::duplex;
|
||||||
use tokio::time::{Duration as TokioDuration, timeout};
|
use tokio::time::{Duration as TokioDuration, timeout};
|
||||||
|
|
||||||
|
|
@ -176,6 +177,36 @@ async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_full_queue_times_out_without_receiver_progress() {
|
||||||
|
let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[1]),
|
||||||
|
flags: 0,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
let result = enqueue_c2me_command(
|
||||||
|
&tx,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[2, 2]),
|
||||||
|
flags: 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"enqueue must fail when queue stays full beyond bounded timeout"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
started.elapsed() < TokioDuration::from_millis(400),
|
||||||
|
"full-queue timeout must resolve promptly"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn desync_dedup_cache_is_bounded() {
|
fn desync_dedup_cache_is_bounded() {
|
||||||
let _guard = desync_dedup_test_lock()
|
let _guard = desync_dedup_test_lock()
|
||||||
|
|
@ -192,12 +223,12 @@ fn desync_dedup_cache_is_bounded() {
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
!should_emit_full_desync(u64::MAX, false, now),
|
should_emit_full_desync(u64::MAX, false, now),
|
||||||
"new key above cap must remain suppressed to avoid log amplification"
|
"new key above cap must emit once after bounded eviction for forensic visibility"
|
||||||
);
|
);
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
!should_emit_full_desync(7, false, now),
|
!should_emit_full_desync(u64::MAX, false, now),
|
||||||
"already tracked key inside dedup window must stay suppressed"
|
"already tracked key inside dedup window must stay suppressed"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
@ -215,10 +246,18 @@ fn desync_dedup_full_cache_churn_stays_suppressed() {
|
||||||
}
|
}
|
||||||
|
|
||||||
for offset in 0..2048u64 {
|
for offset in 0..2048u64 {
|
||||||
assert!(
|
let emitted = should_emit_full_desync(u64::MAX - offset, false, now);
|
||||||
!should_emit_full_desync(u64::MAX - offset, false, now),
|
if offset == 0 {
|
||||||
"fresh full-cache churn must remain suppressed under pressure"
|
assert!(
|
||||||
);
|
emitted,
|
||||||
|
"first full-cache newcomer should emit for forensic visibility"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
assert!(
|
||||||
|
!emitted,
|
||||||
|
"full-cache newcomer churn inside emit interval must stay suppressed"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -296,18 +335,20 @@ fn stress_desync_dedup_churn_keeps_cache_hard_bounded() {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let total = DESYNC_DEDUP_MAX_ENTRIES + 8192;
|
let total = DESYNC_DEDUP_MAX_ENTRIES + 8192;
|
||||||
|
|
||||||
|
let mut emitted_count = 0usize;
|
||||||
for key in 0..total as u64 {
|
for key in 0..total as u64 {
|
||||||
let emitted = should_emit_full_desync(key, false, now);
|
let emitted = should_emit_full_desync(key, false, now);
|
||||||
if key < DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
if emitted {
|
||||||
assert!(emitted, "keys below cap must be admitted initially");
|
emitted_count += 1;
|
||||||
} else {
|
|
||||||
assert!(
|
|
||||||
!emitted,
|
|
||||||
"new keys above cap must stay suppressed under sustained churn"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
emitted_count,
|
||||||
|
DESYNC_DEDUP_MAX_ENTRIES + 1,
|
||||||
|
"after capacity is reached, same-tick newcomer churn must be rate-limited"
|
||||||
|
);
|
||||||
|
|
||||||
let len = DESYNC_DEDUP
|
let len = DESYNC_DEDUP
|
||||||
.get()
|
.get()
|
||||||
.expect("dedup cache must be initialized by stress run")
|
.expect("dedup cache must be initialized by stress run")
|
||||||
|
|
@ -318,6 +359,282 @@ fn stress_desync_dedup_churn_keeps_cache_hard_bounded() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn full_cache_newcomer_emission_is_rate_limited_but_periodic() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
|
let base_now = Instant::now();
|
||||||
|
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
dedup.insert(key, base_now - TokioDuration::from_millis(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same-tick newcomer storm: only the first should emit full forensic record.
|
||||||
|
let mut burst_emits = 0usize;
|
||||||
|
for i in 0..1024u64 {
|
||||||
|
if should_emit_full_desync(10_000_000 + i, false, base_now) {
|
||||||
|
burst_emits += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert_eq!(
|
||||||
|
burst_emits, 1,
|
||||||
|
"full-cache newcomer burst must be bounded to a single full emit per interval"
|
||||||
|
);
|
||||||
|
|
||||||
|
// After each interval elapses, one newcomer may emit again.
|
||||||
|
for step in 1..=6u64 {
|
||||||
|
let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32;
|
||||||
|
assert!(
|
||||||
|
should_emit_full_desync(20_000_000 + step, false, t),
|
||||||
|
"full-cache newcomer should re-emit once interval has elapsed"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!should_emit_full_desync(30_000_000 + step, false, t),
|
||||||
|
"additional newcomers in the same interval tick must remain suppressed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn full_cache_mode_override_emits_every_event() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
for i in 0..10_000u64 {
|
||||||
|
assert!(
|
||||||
|
should_emit_full_desync(100_000_000 + i, true, now),
|
||||||
|
"desync_all_full override must bypass dedup and rate-limit suppression"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn report_desync_stats_follow_rate_limited_full_cache_policy() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
|
let base_now = Instant::now();
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
dedup.insert(key, base_now - TokioDuration::from_millis(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
let stats = Stats::new();
|
||||||
|
let mut state = make_forensics_state();
|
||||||
|
state.started_at = base_now;
|
||||||
|
|
||||||
|
for i in 0..128u64 {
|
||||||
|
state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i;
|
||||||
|
let _ = report_desync_frame_too_large(
|
||||||
|
&state,
|
||||||
|
ProtoTag::Secure,
|
||||||
|
3,
|
||||||
|
1024,
|
||||||
|
4096,
|
||||||
|
Some([0x16, 0x03, 0x03, 0x00]),
|
||||||
|
&stats,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_desync_total(),
|
||||||
|
128,
|
||||||
|
"every detected desync must increment total counter"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_desync_full_logged(),
|
||||||
|
1,
|
||||||
|
"same-interval full-cache newcomer storm must allow only one full forensic emit"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_desync_suppressed(),
|
||||||
|
127,
|
||||||
|
"remaining same-interval full-cache newcomer events must be suppressed"
|
||||||
|
);
|
||||||
|
|
||||||
|
// After one full interval in real wall clock, a newcomer should emit again.
|
||||||
|
thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20));
|
||||||
|
state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64;
|
||||||
|
let _ = report_desync_frame_too_large(
|
||||||
|
&state,
|
||||||
|
ProtoTag::Secure,
|
||||||
|
4,
|
||||||
|
1024,
|
||||||
|
4097,
|
||||||
|
Some([0x16, 0x03, 0x03, 0x01]),
|
||||||
|
&stats,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_desync_full_logged(),
|
||||||
|
2,
|
||||||
|
"full forensic emission must recover after rate-limit interval"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
|
let base_now = Instant::now();
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
dedup.insert(key, base_now - TokioDuration::from_millis(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
let emits = Arc::new(AtomicUsize::new(0));
|
||||||
|
let mut workers = Vec::new();
|
||||||
|
for worker_id in 0..32u64 {
|
||||||
|
let emits = Arc::clone(&emits);
|
||||||
|
workers.push(thread::spawn(move || {
|
||||||
|
for i in 0..512u64 {
|
||||||
|
let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i;
|
||||||
|
if should_emit_full_desync(key, false, base_now) {
|
||||||
|
emits.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for worker in workers {
|
||||||
|
worker.join().expect("worker thread must not panic");
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
emits.load(Ordering::Relaxed),
|
||||||
|
1,
|
||||||
|
"concurrent same-interval full-cache storm must allow only one full forensic emit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_full_cache_rate_limit_oracle_matches_model() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
|
let base_now = Instant::now();
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
dedup.insert(key, base_now - TokioDuration::from_millis(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD);
|
||||||
|
let mut model_last_emit: Option<Instant> = None;
|
||||||
|
|
||||||
|
for i in 0..4096u64 {
|
||||||
|
let jitter_ms: u64 = rng.random_range(0..=3000);
|
||||||
|
let t = base_now + TokioDuration::from_millis(jitter_ms);
|
||||||
|
let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::<u64>();
|
||||||
|
let actual = should_emit_full_desync(key, false, t);
|
||||||
|
|
||||||
|
let expected = match model_last_emit {
|
||||||
|
None => {
|
||||||
|
model_last_emit = Some(t);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
Some(last) => {
|
||||||
|
match t.checked_duration_since(last) {
|
||||||
|
Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => {
|
||||||
|
model_last_emit = Some(t);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
Some(_) => false,
|
||||||
|
None => {
|
||||||
|
// Match production fail-open behavior for non-monotonic synthetic input.
|
||||||
|
model_last_emit = Some(t);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
actual, expected,
|
||||||
|
"full-cache rate-limit gate diverged from reference model under light fuzz"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn full_cache_gate_lock_poison_is_fail_closed_without_panic() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
|
let base_now = Instant::now();
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
dedup.insert(key, base_now - TokioDuration::from_millis(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Poison the full-cache gate lock intentionally.
|
||||||
|
let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
|
||||||
|
let _ = std::panic::catch_unwind(|| {
|
||||||
|
let _lock = gate.lock().expect("gate lock must be lockable before poison");
|
||||||
|
panic!("intentional gate poison for fail-closed regression");
|
||||||
|
});
|
||||||
|
|
||||||
|
let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now);
|
||||||
|
assert!(
|
||||||
|
!emitted,
|
||||||
|
"poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES,
|
||||||
|
"dedup cache must remain bounded even when gate lock is poisoned"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
|
let base_now = Instant::now();
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
dedup.insert(key, base_now - TokioDuration::from_millis(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
// First event seeds the gate.
|
||||||
|
assert!(should_emit_full_desync(
|
||||||
|
0xABCD_0000_0000_0001,
|
||||||
|
false,
|
||||||
|
base_now + TokioDuration::from_millis(900)
|
||||||
|
));
|
||||||
|
|
||||||
|
// Synthetic earlier timestamp must not panic; it should fail-open and reset gate.
|
||||||
|
assert!(should_emit_full_desync(
|
||||||
|
0xABCD_0000_0000_0002,
|
||||||
|
false,
|
||||||
|
base_now + TokioDuration::from_millis(100)
|
||||||
|
));
|
||||||
|
|
||||||
|
// Same instant again remains suppressed after reset.
|
||||||
|
assert!(!should_emit_full_desync(
|
||||||
|
0xABCD_0000_0000_0003,
|
||||||
|
false,
|
||||||
|
base_now + TokioDuration::from_millis(100)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() {
|
fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() {
|
||||||
let _guard = desync_dedup_test_lock()
|
let _guard = desync_dedup_test_lock()
|
||||||
|
|
@ -338,8 +655,8 @@ fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() {
|
||||||
let newcomer_key = u64::MAX;
|
let newcomer_key = u64::MAX;
|
||||||
let emitted = should_emit_full_desync(newcomer_key, false, base_now);
|
let emitted = should_emit_full_desync(newcomer_key, false, base_now);
|
||||||
assert!(
|
assert!(
|
||||||
!emitted,
|
emitted,
|
||||||
"new entry under full fresh cache must stay suppressed"
|
"new entry under full fresh cache must emit after bounded eviction"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
dedup.get(&newcomer_key).is_some(),
|
dedup.get(&newcomer_key).is_some(),
|
||||||
|
|
@ -406,6 +723,24 @@ fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() {
|
||||||
panic!("expected at least one post-window sample to re-emit forensic record");
|
panic!("expected at least one post-window sample to re-emit forensic record");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"]
|
||||||
|
fn should_emit_full_desync_filters_duplicates() {
|
||||||
|
unimplemented!("Stub for M-04");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"]
|
||||||
|
fn desync_dedup_eviction_under_map_full_condition() {
|
||||||
|
unimplemented!("Stub for M-04");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore = "Tracking for M-05: Verify C2ME channel full path yields then sends under backpressure"]
|
||||||
|
async fn c2me_channel_full_path_yields_then_sends() {
|
||||||
|
unimplemented!("Stub for M-05");
|
||||||
|
}
|
||||||
|
|
||||||
fn make_forensics_state() -> RelayForensicsState {
|
fn make_forensics_state() -> RelayForensicsState {
|
||||||
RelayForensicsState {
|
RelayForensicsState {
|
||||||
trace_id: 1,
|
trace_id: 1,
|
||||||
|
|
@ -974,6 +1309,7 @@ async fn process_me_writer_response_ack_obeys_flush_policy() {
|
||||||
&mut frame_buf,
|
&mut frame_buf,
|
||||||
&stats,
|
&stats,
|
||||||
"user",
|
"user",
|
||||||
|
None,
|
||||||
&bytes_me2c,
|
&bytes_me2c,
|
||||||
77,
|
77,
|
||||||
true,
|
true,
|
||||||
|
|
@ -999,6 +1335,7 @@ async fn process_me_writer_response_ack_obeys_flush_policy() {
|
||||||
&mut frame_buf,
|
&mut frame_buf,
|
||||||
&stats,
|
&stats,
|
||||||
"user",
|
"user",
|
||||||
|
None,
|
||||||
&bytes_me2c,
|
&bytes_me2c,
|
||||||
77,
|
77,
|
||||||
false,
|
false,
|
||||||
|
|
@ -1038,6 +1375,7 @@ async fn process_me_writer_response_data_updates_byte_accounting() {
|
||||||
&mut frame_buf,
|
&mut frame_buf,
|
||||||
&stats,
|
&stats,
|
||||||
"user",
|
"user",
|
||||||
|
None,
|
||||||
&bytes_me2c,
|
&bytes_me2c,
|
||||||
88,
|
88,
|
||||||
false,
|
false,
|
||||||
|
|
@ -1061,6 +1399,162 @@ async fn process_me_writer_response_data_updates_byte_accounting() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_me_writer_response_data_enforces_live_user_quota() {
|
||||||
|
let (writer_side, mut reader_side) = duplex(1024);
|
||||||
|
let mut writer = make_crypto_writer(writer_side);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
|
||||||
|
stats.add_user_octets_from("quota-user", 10);
|
||||||
|
|
||||||
|
let result = process_me_writer_response(
|
||||||
|
MeResponse::Data {
|
||||||
|
flags: 0,
|
||||||
|
data: Bytes::from(vec![1u8, 2, 3, 4]),
|
||||||
|
},
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"quota-user",
|
||||||
|
Some(12),
|
||||||
|
&bytes_me2c,
|
||||||
|
89,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"),
|
||||||
|
"ME->client runtime path must terminate when live user quota is crossed"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut raw = [0u8; 1];
|
||||||
|
assert!(
|
||||||
|
timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw))
|
||||||
|
.await
|
||||||
|
.is_err(),
|
||||||
|
"quota exhaustion must not write any ciphertext to the client stream"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() {
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
let user = "quota-race-user";
|
||||||
|
|
||||||
|
let (writer_side_a, _reader_side_a) = duplex(1024);
|
||||||
|
let (writer_side_b, _reader_side_b) = duplex(1024);
|
||||||
|
let mut writer_a = make_crypto_writer(writer_side_a);
|
||||||
|
let mut writer_b = make_crypto_writer(writer_side_b);
|
||||||
|
let mut frame_buf_a = Vec::new();
|
||||||
|
let mut frame_buf_b = Vec::new();
|
||||||
|
let rng_a = SecureRandom::new();
|
||||||
|
let rng_b = SecureRandom::new();
|
||||||
|
|
||||||
|
let fut_a = process_me_writer_response(
|
||||||
|
MeResponse::Data {
|
||||||
|
flags: 0,
|
||||||
|
data: Bytes::from_static(&[0x11]),
|
||||||
|
},
|
||||||
|
&mut writer_a,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng_a,
|
||||||
|
&mut frame_buf_a,
|
||||||
|
&stats,
|
||||||
|
user,
|
||||||
|
Some(1),
|
||||||
|
&bytes_me2c,
|
||||||
|
91,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
let fut_b = process_me_writer_response(
|
||||||
|
MeResponse::Data {
|
||||||
|
flags: 0,
|
||||||
|
data: Bytes::from_static(&[0x22]),
|
||||||
|
},
|
||||||
|
&mut writer_b,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng_b,
|
||||||
|
&mut frame_buf_b,
|
||||||
|
&stats,
|
||||||
|
user,
|
||||||
|
Some(1),
|
||||||
|
&bytes_me2c,
|
||||||
|
92,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
let (result_a, result_b) = tokio::join!(fut_a, fut_b);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user")
|
||||||
|
|| matches!(result_a, Ok(_)),
|
||||||
|
"concurrent quota test must complete without panicking"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user")
|
||||||
|
|| matches!(result_b, Ok(_)),
|
||||||
|
"concurrent quota test must complete without panicking"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
stats.get_user_total_octets(user) <= 1,
|
||||||
|
"same-user concurrent middle-relay responses must not overshoot the configured quota"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() {
|
||||||
|
let (writer_side, mut reader_side) = duplex(1024);
|
||||||
|
let mut writer = make_crypto_writer(writer_side);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
|
||||||
|
stats.add_user_octets_to("partial-quota-user", 3);
|
||||||
|
|
||||||
|
let result = process_me_writer_response(
|
||||||
|
MeResponse::Data {
|
||||||
|
flags: 0,
|
||||||
|
data: Bytes::from(vec![1u8, 2, 3, 4]),
|
||||||
|
},
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"partial-quota-user",
|
||||||
|
Some(4),
|
||||||
|
&bytes_me2c,
|
||||||
|
90,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"),
|
||||||
|
"ME->client runtime path must reject oversized payloads before writing"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut raw = [0u8; 1];
|
||||||
|
assert!(
|
||||||
|
timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw))
|
||||||
|
.await
|
||||||
|
.is_err(),
|
||||||
|
"oversized payloads must not leak any partial ciphertext to the client stream"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn middle_relay_abort_midflight_releases_route_gauge() {
|
async fn middle_relay_abort_midflight_releases_route_gauge() {
|
||||||
let stats = Arc::new(Stats::new());
|
let stats = Arc::new(Stats::new());
|
||||||
|
|
|
||||||
|
|
@ -53,16 +53,17 @@
|
||||||
|
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex, OnceLock};
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use dashmap::DashMap;
|
||||||
use tokio::io::{
|
use tokio::io::{
|
||||||
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
|
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
|
||||||
};
|
};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::{debug, trace, warn};
|
use tracing::{debug, trace, warn};
|
||||||
use crate::error::Result;
|
use crate::error::{ProxyError, Result};
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::BufferPool;
|
use crate::stream::BufferPool;
|
||||||
|
|
||||||
|
|
@ -205,6 +206,8 @@ struct StatsIo<S> {
|
||||||
counters: Arc<SharedCounters>,
|
counters: Arc<SharedCounters>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
user: String,
|
user: String,
|
||||||
|
quota_limit: Option<u64>,
|
||||||
|
quota_exceeded: Arc<AtomicBool>,
|
||||||
epoch: Instant,
|
epoch: Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -214,11 +217,62 @@ impl<S> StatsIo<S> {
|
||||||
counters: Arc<SharedCounters>,
|
counters: Arc<SharedCounters>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
user: String,
|
user: String,
|
||||||
|
quota_limit: Option<u64>,
|
||||||
|
quota_exceeded: Arc<AtomicBool>,
|
||||||
epoch: Instant,
|
epoch: Instant,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Mark initial activity so the watchdog doesn't fire before data flows
|
// Mark initial activity so the watchdog doesn't fire before data flows
|
||||||
counters.touch(Instant::now(), epoch);
|
counters.touch(Instant::now(), epoch);
|
||||||
Self { inner, counters, stats, user, epoch }
|
Self {
|
||||||
|
inner,
|
||||||
|
counters,
|
||||||
|
stats,
|
||||||
|
user,
|
||||||
|
quota_limit,
|
||||||
|
quota_exceeded,
|
||||||
|
epoch,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct QuotaIoSentinel;
|
||||||
|
|
||||||
|
impl std::fmt::Display for QuotaIoSentinel {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.write_str("user data quota exceeded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for QuotaIoSentinel {}
|
||||||
|
|
||||||
|
fn quota_io_error() -> io::Error {
|
||||||
|
io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_quota_io_error(err: &io::Error) -> bool {
|
||||||
|
err.kind() == io::ErrorKind::PermissionDenied
|
||||||
|
&& err
|
||||||
|
.get_ref()
|
||||||
|
.and_then(|source| source.downcast_ref::<QuotaIoSentinel>())
|
||||||
|
.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||||
|
|
||||||
|
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||||
|
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||||
|
if let Some(existing) = locks.get(user) {
|
||||||
|
return Arc::clone(existing.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
let created = Arc::new(Mutex::new(()));
|
||||||
|
match locks.entry(user.to_string()) {
|
||||||
|
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||||
|
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||||
|
entry.insert(Arc::clone(&created));
|
||||||
|
created
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -229,6 +283,32 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||||
buf: &mut ReadBuf<'_>,
|
buf: &mut ReadBuf<'_>,
|
||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
|
if this.quota_exceeded.load(Ordering::Relaxed) {
|
||||||
|
return Poll::Ready(Err(quota_io_error()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let quota_lock = this
|
||||||
|
.quota_limit
|
||||||
|
.is_some()
|
||||||
|
.then(|| quota_user_lock(&this.user));
|
||||||
|
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
|
||||||
|
match lock.try_lock() {
|
||||||
|
Ok(guard) => Some(guard),
|
||||||
|
Err(_) => {
|
||||||
|
cx.waker().wake_by_ref();
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(limit) = this.quota_limit
|
||||||
|
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||||
|
{
|
||||||
|
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||||
|
return Poll::Ready(Err(quota_io_error()));
|
||||||
|
}
|
||||||
let before = buf.filled().len();
|
let before = buf.filled().len();
|
||||||
|
|
||||||
match Pin::new(&mut this.inner).poll_read(cx, buf) {
|
match Pin::new(&mut this.inner).poll_read(cx, buf) {
|
||||||
|
|
@ -243,6 +323,13 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||||
this.stats.add_user_octets_from(&this.user, n as u64);
|
this.stats.add_user_octets_from(&this.user, n as u64);
|
||||||
this.stats.increment_user_msgs_from(&this.user);
|
this.stats.increment_user_msgs_from(&this.user);
|
||||||
|
|
||||||
|
if let Some(limit) = this.quota_limit
|
||||||
|
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||||
|
{
|
||||||
|
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||||
|
return Poll::Ready(Err(quota_io_error()));
|
||||||
|
}
|
||||||
|
|
||||||
trace!(user = %this.user, bytes = n, "C->S");
|
trace!(user = %this.user, bytes = n, "C->S");
|
||||||
}
|
}
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
|
|
@ -259,8 +346,46 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||||
buf: &[u8],
|
buf: &[u8],
|
||||||
) -> Poll<io::Result<usize>> {
|
) -> Poll<io::Result<usize>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
|
if this.quota_exceeded.load(Ordering::Relaxed) {
|
||||||
|
return Poll::Ready(Err(quota_io_error()));
|
||||||
|
}
|
||||||
|
|
||||||
match Pin::new(&mut this.inner).poll_write(cx, buf) {
|
let quota_lock = this
|
||||||
|
.quota_limit
|
||||||
|
.is_some()
|
||||||
|
.then(|| quota_user_lock(&this.user));
|
||||||
|
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
|
||||||
|
match lock.try_lock() {
|
||||||
|
Ok(guard) => Some(guard),
|
||||||
|
Err(_) => {
|
||||||
|
cx.waker().wake_by_ref();
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let write_buf = if let Some(limit) = this.quota_limit {
|
||||||
|
let used = this.stats.get_user_total_octets(&this.user);
|
||||||
|
if used >= limit {
|
||||||
|
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||||
|
return Poll::Ready(Err(quota_io_error()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let remaining = (limit - used) as usize;
|
||||||
|
if buf.len() > remaining {
|
||||||
|
// Fail closed: do not emit partial S->C payload when remaining
|
||||||
|
// quota cannot accommodate the pending write request.
|
||||||
|
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||||
|
return Poll::Ready(Err(quota_io_error()));
|
||||||
|
}
|
||||||
|
buf
|
||||||
|
} else {
|
||||||
|
buf
|
||||||
|
};
|
||||||
|
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
|
||||||
Poll::Ready(Ok(n)) => {
|
Poll::Ready(Ok(n)) => {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
// S→C: data written to client
|
// S→C: data written to client
|
||||||
|
|
@ -271,6 +396,13 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||||
this.stats.add_user_octets_to(&this.user, n as u64);
|
this.stats.add_user_octets_to(&this.user, n as u64);
|
||||||
this.stats.increment_user_msgs_to(&this.user);
|
this.stats.increment_user_msgs_to(&this.user);
|
||||||
|
|
||||||
|
if let Some(limit) = this.quota_limit
|
||||||
|
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||||
|
{
|
||||||
|
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||||
|
return Poll::Ready(Err(quota_io_error()));
|
||||||
|
}
|
||||||
|
|
||||||
trace!(user = %this.user, bytes = n, "S->C");
|
trace!(user = %this.user, bytes = n, "S->C");
|
||||||
}
|
}
|
||||||
Poll::Ready(Ok(n))
|
Poll::Ready(Ok(n))
|
||||||
|
|
@ -307,7 +439,8 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||||
/// - Per-user stats: bytes and ops counted per direction
|
/// - Per-user stats: bytes and ops counted per direction
|
||||||
/// - Periodic rate logging: every 10 seconds when active
|
/// - Periodic rate logging: every 10 seconds when active
|
||||||
/// - Clean shutdown: both write sides are shut down on exit
|
/// - Clean shutdown: both write sides are shut down on exit
|
||||||
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
|
/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
|
||||||
|
/// other I/O failures are returned as `ProxyError::Io`
|
||||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||||
client_reader: CR,
|
client_reader: CR,
|
||||||
client_writer: CW,
|
client_writer: CW,
|
||||||
|
|
@ -317,6 +450,7 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||||
s2c_buf_size: usize,
|
s2c_buf_size: usize,
|
||||||
user: &str,
|
user: &str,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
|
quota_limit: Option<u64>,
|
||||||
_buffer_pool: Arc<BufferPool>,
|
_buffer_pool: Arc<BufferPool>,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
|
|
@ -327,6 +461,7 @@ where
|
||||||
{
|
{
|
||||||
let epoch = Instant::now();
|
let epoch = Instant::now();
|
||||||
let counters = Arc::new(SharedCounters::new());
|
let counters = Arc::new(SharedCounters::new());
|
||||||
|
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||||
let user_owned = user.to_string();
|
let user_owned = user.to_string();
|
||||||
|
|
||||||
// ── Combine split halves into bidirectional streams ──────────────
|
// ── Combine split halves into bidirectional streams ──────────────
|
||||||
|
|
@ -339,12 +474,15 @@ where
|
||||||
Arc::clone(&counters),
|
Arc::clone(&counters),
|
||||||
Arc::clone(&stats),
|
Arc::clone(&stats),
|
||||||
user_owned.clone(),
|
user_owned.clone(),
|
||||||
|
quota_limit,
|
||||||
|
Arc::clone("a_exceeded),
|
||||||
epoch,
|
epoch,
|
||||||
);
|
);
|
||||||
|
|
||||||
// ── Watchdog: activity timeout + periodic rate logging ──────────
|
// ── Watchdog: activity timeout + periodic rate logging ──────────
|
||||||
let wd_counters = Arc::clone(&counters);
|
let wd_counters = Arc::clone(&counters);
|
||||||
let wd_user = user_owned.clone();
|
let wd_user = user_owned.clone();
|
||||||
|
let wd_quota_exceeded = Arc::clone("a_exceeded);
|
||||||
|
|
||||||
let watchdog = async {
|
let watchdog = async {
|
||||||
let mut prev_c2s: u64 = 0;
|
let mut prev_c2s: u64 = 0;
|
||||||
|
|
@ -356,6 +494,11 @@ where
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let idle = wd_counters.idle_duration(now, epoch);
|
let idle = wd_counters.idle_duration(now, epoch);
|
||||||
|
|
||||||
|
if wd_quota_exceeded.load(Ordering::Relaxed) {
|
||||||
|
warn!(user = %wd_user, "User data quota reached, closing relay");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// ── Activity timeout ────────────────────────────────────
|
// ── Activity timeout ────────────────────────────────────
|
||||||
if idle >= ACTIVITY_TIMEOUT {
|
if idle >= ACTIVITY_TIMEOUT {
|
||||||
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
|
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
|
@ -439,6 +582,22 @@ where
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Some(Err(e)) if is_quota_io_error(&e) => {
|
||||||
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
warn!(
|
||||||
|
user = %user_owned,
|
||||||
|
c2s_bytes = c2s,
|
||||||
|
s2c_bytes = s2c,
|
||||||
|
c2s_msgs = c2s_ops,
|
||||||
|
s2c_msgs = s2c_ops,
|
||||||
|
duration_secs = duration.as_secs(),
|
||||||
|
"Data quota reached, closing relay"
|
||||||
|
);
|
||||||
|
Err(ProxyError::DataQuotaExceeded {
|
||||||
|
user: user_owned.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
Some(Err(e)) => {
|
Some(Err(e)) => {
|
||||||
// I/O error in one of the directions
|
// I/O error in one of the directions
|
||||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
|
@ -472,3 +631,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "relay_security_tests.rs"]
|
||||||
|
mod security_tests;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,972 @@
|
||||||
|
use super::relay_bidirectional;
|
||||||
|
use crate::error::ProxyError;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
use crate::stream::BufferPool;
|
||||||
|
use std::future::poll_fn;
|
||||||
|
use std::io;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use std::task::Waker;
|
||||||
|
use tokio::io::{AsyncRead, ReadBuf};
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex};
|
||||||
|
use tokio::time::{Duration, timeout};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_enforces_live_user_quota() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let user = "quota-user";
|
||||||
|
stats.add_user_octets_from(user, 6);
|
||||||
|
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(8),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
client_peer
|
||||||
|
.write_all(&[0x10, 0x20, 0x30, 0x40])
|
||||||
|
.await
|
||||||
|
.expect("client write must succeed");
|
||||||
|
|
||||||
|
let mut forwarded = [0u8; 4];
|
||||||
|
let _ = timeout(
|
||||||
|
Duration::from_millis(200),
|
||||||
|
server_peer.read_exact(&mut forwarded),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay task must finish under quota cutoff")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"),
|
||||||
|
"relay must surface a typed quota error once live quota is exceeded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let quota_user = "quota-exhausted-user";
|
||||||
|
stats.add_user_octets_from(quota_user, 1);
|
||||||
|
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
quota_user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(1),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
server_peer
|
||||||
|
.write_all(&[0xde, 0xad, 0xbe, 0xef])
|
||||||
|
.await
|
||||||
|
.expect("server write must succeed");
|
||||||
|
|
||||||
|
let mut observed = [0u8; 4];
|
||||||
|
let forwarded = timeout(
|
||||||
|
Duration::from_millis(200),
|
||||||
|
client_peer.read_exact(&mut observed),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay task must finish under quota cutoff")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
|
||||||
|
"no full server payload should be forwarded once quota is already exhausted"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||||
|
"relay must still terminate with a typed quota error"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let quota_user = "partial-leak-user";
|
||||||
|
stats.add_user_octets_from(quota_user, 3);
|
||||||
|
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
quota_user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(4),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
server_peer
|
||||||
|
.write_all(&[0x11, 0x22, 0x33, 0x44])
|
||||||
|
.await
|
||||||
|
.expect("server write must succeed");
|
||||||
|
|
||||||
|
let mut observed = [0u8; 8];
|
||||||
|
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
|
||||||
|
|
||||||
|
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay task must finish under quota cutoff")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!matches!(forwarded, Ok(Ok(n)) if n > 0),
|
||||||
|
"quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||||
|
"relay must still terminate with a typed quota error"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let quota_user = "zero-quota-user";
|
||||||
|
|
||||||
|
for payload_len in [1usize, 16, 512, 4096] {
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
quota_user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(0),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let payload = vec![0x7f; payload_len];
|
||||||
|
let _ = server_peer.write_all(&payload).await;
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; payload_len];
|
||||||
|
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
|
||||||
|
|
||||||
|
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay task must finish under zero-quota cutoff")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!matches!(forwarded, Ok(Ok(n)) if n > 0),
|
||||||
|
"zero quota must not forward any server bytes for payload_len={payload_len}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||||
|
"zero quota must terminate with the typed quota error for payload_len={payload_len}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let quota_user = "exact-boundary-user";
|
||||||
|
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
quota_user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(4),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
server_peer
|
||||||
|
.write_all(&[0x91, 0x92, 0x93, 0x94])
|
||||||
|
.await
|
||||||
|
.expect("server write must succeed at exact quota boundary");
|
||||||
|
|
||||||
|
let mut observed = [0u8; 4];
|
||||||
|
client_peer
|
||||||
|
.read_exact(&mut observed)
|
||||||
|
.await
|
||||||
|
.expect("client must receive the full payload at the exact quota boundary");
|
||||||
|
assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]);
|
||||||
|
|
||||||
|
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay task must finish after exact boundary delivery")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||||
|
"relay must close with a typed quota error after reaching the exact boundary"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let quota_user = "client-exhausted-user";
|
||||||
|
stats.add_user_octets_from(quota_user, 1);
|
||||||
|
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
quota_user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(1),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
client_peer
|
||||||
|
.write_all(&[0x51, 0x52, 0x53, 0x54])
|
||||||
|
.await
|
||||||
|
.expect("client write must succeed even when quota is already exhausted");
|
||||||
|
|
||||||
|
let mut observed = [0u8; 4];
|
||||||
|
let forwarded = timeout(
|
||||||
|
Duration::from_millis(200),
|
||||||
|
server_peer.read_exact(&mut observed),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay task must finish under quota cutoff")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
|
||||||
|
"client payload must not be fully forwarded once quota is already exhausted"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||||
|
"relay must still terminate with a typed quota error"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let quota_user = "quota-fuzz-user";
|
||||||
|
stats.add_user_octets_from(quota_user, 2);
|
||||||
|
|
||||||
|
for payload_len in [1usize, 32, 1024, 8192] {
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
quota_user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(2),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let payload = vec![0xaa; payload_len];
|
||||||
|
let _ = server_peer.write_all(&payload).await;
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; payload_len];
|
||||||
|
let forwarded = timeout(
|
||||||
|
Duration::from_millis(200),
|
||||||
|
client_peer.read_exact(&mut observed),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay task must finish under quota cutoff")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!matches!(forwarded, Ok(Ok(n)) if n == payload_len),
|
||||||
|
"quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||||
|
"relay must keep returning the typed quota error for payload_len={payload_len}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_terminates_on_activity_timeout() {
|
||||||
|
tokio::time::pause();
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let user = "timeout-user";
|
||||||
|
|
||||||
|
let (client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
None, // No quota
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Wait past the activity timeout threshold (1800 seconds) + buffer
|
||||||
|
tokio::time::sleep(Duration::from_secs(1805)).await;
|
||||||
|
|
||||||
|
// Resume time to process timeouts
|
||||||
|
tokio::time::resume();
|
||||||
|
|
||||||
|
let relay_result = timeout(Duration::from_secs(1), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay task must finish inside bounded timeout due to inactivity cutoff")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
relay_result.is_ok(),
|
||||||
|
"relay should complete successfully on scheduled inactivity timeout"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify client/server sockets are closed
|
||||||
|
drop(client_peer);
|
||||||
|
drop(server_peer);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_watchdog_resists_premature_execution() {
|
||||||
|
tokio::time::pause();
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let user = "activity-user";
|
||||||
|
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, server_peer) = duplex(4096);
|
||||||
|
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let mut relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
None,
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Advance by half the timeout
|
||||||
|
tokio::time::sleep(Duration::from_secs(900)).await;
|
||||||
|
|
||||||
|
// Provide activity
|
||||||
|
client_peer
|
||||||
|
.write_all(&[0xaa, 0xbb])
|
||||||
|
.await
|
||||||
|
.expect("client write must succeed");
|
||||||
|
client_peer.flush().await.unwrap();
|
||||||
|
|
||||||
|
// Advance by another half (total time since start is 1800, but since last activity is 900)
|
||||||
|
tokio::time::sleep(Duration::from_secs(900)).await;
|
||||||
|
|
||||||
|
tokio::time::resume();
|
||||||
|
|
||||||
|
// Re-evaluating the task, it should NOT have timed out and still be pending
|
||||||
|
let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await;
|
||||||
|
assert!(
|
||||||
|
relay_result.is_err(),
|
||||||
|
"Relay must not exit prematurely as long as activity was received before timeout"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Explicitly drop sockets to cleanly shut down relay loop
|
||||||
|
drop(client_peer);
|
||||||
|
drop(server_peer);
|
||||||
|
|
||||||
|
let completion = timeout(Duration::from_secs(1), relay_task).await
|
||||||
|
.expect("relay task must complete securely after client disconnection")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
assert!(completion.is_ok(), "relay exits clean");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_half_closure_terminates_cleanly() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let (client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, server_peer) = duplex(4096);
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "half-close", stats, None, Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Half closure: drop the client completely but leave the server active.
|
||||||
|
drop(client_peer);
|
||||||
|
|
||||||
|
// Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush.
|
||||||
|
// Eventually dropping the server cleanly closes the task.
|
||||||
|
drop(server_peer);
|
||||||
|
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_zero_length_noise_fuzzing() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz", stats, None, Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Flood with zero-length payloads (edge cases in stream framing logic sometimes loop)
|
||||||
|
for _ in 0..100 {
|
||||||
|
client_peer.write_all(&[]).await.unwrap();
|
||||||
|
}
|
||||||
|
client_peer.write_all(&[1, 2, 3]).await.unwrap();
|
||||||
|
client_peer.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut buf = [0u8; 3];
|
||||||
|
server_peer.read_exact(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf, &[1, 2, 3]);
|
||||||
|
|
||||||
|
drop(client_peer);
|
||||||
|
drop(server_peer);
|
||||||
|
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_asymmetric_backpressure() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
// Give the client stream an extremely narrow throughput limit explicitly
|
||||||
|
let (client_peer, relay_client) = duplex(1024);
|
||||||
|
let (relay_server, mut server_peer) = duplex(4096);
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "slowloris", stats, None, Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let payload = vec![0xba; 65536]; // 64k payload
|
||||||
|
|
||||||
|
// Server attempts to shove 64KB into a relay whose client pipe only holds 1KB!
|
||||||
|
let write_res = tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
write_res.is_err(),
|
||||||
|
"Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_peer);
|
||||||
|
drop(server_peer);
|
||||||
|
|
||||||
|
let completion = timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap();
|
||||||
|
assert!(
|
||||||
|
completion.is_ok() || completion.is_err(),
|
||||||
|
"Task must unwind reliably (either Ok or BrokenPipe Err) when dropped despite active backpressure locks"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
use rand::{Rng, SeedableRng, rngs::StdRng};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_light_fuzzing_temporal_jitter() {
|
||||||
|
tokio::time::pause();
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let (mut client_peer, relay_client) = duplex(4096);
|
||||||
|
let (relay_server, server_peer) = duplex(4096);
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||||
|
|
||||||
|
let mut relay_task = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz-user", stats, None, Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
|
||||||
|
|
||||||
|
for _ in 0..10 {
|
||||||
|
// Vary timing significantly up to 1600 seconds (limit is 1800s)
|
||||||
|
let jitter = rng.random_range(100..1600);
|
||||||
|
tokio::time::sleep(Duration::from_secs(jitter)).await;
|
||||||
|
|
||||||
|
client_peer.write_all(&[0x11]).await.unwrap();
|
||||||
|
client_peer.flush().await.unwrap();
|
||||||
|
|
||||||
|
// Ensure task has not died
|
||||||
|
let res = timeout(Duration::from_millis(10), &mut relay_task).await;
|
||||||
|
assert!(res.is_err(), "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses");
|
||||||
|
}
|
||||||
|
|
||||||
|
drop(client_peer);
|
||||||
|
drop(server_peer);
|
||||||
|
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FaultyReader {
|
||||||
|
error_once: Option<io::Error>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TwoPartyGate {
|
||||||
|
arrivals: AtomicUsize,
|
||||||
|
total_bytes: AtomicUsize,
|
||||||
|
wakers: Mutex<Vec<Waker>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TwoPartyGate {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
arrivals: AtomicUsize::new(0),
|
||||||
|
total_bytes: AtomicUsize::new(0),
|
||||||
|
wakers: Mutex::new(Vec::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool {
|
||||||
|
if self.arrivals.load(Ordering::Relaxed) >= 2 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
let prev = self.arrivals.fetch_add(1, Ordering::AcqRel);
|
||||||
|
if prev + 1 >= 2 {
|
||||||
|
let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner());
|
||||||
|
for waker in wakers.drain(..) {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner());
|
||||||
|
wakers.push(cx.waker().clone());
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn total_bytes(&self) -> usize {
|
||||||
|
self.total_bytes.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GateWriter {
|
||||||
|
gate: Arc<TwoPartyGate>,
|
||||||
|
entered: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GateWriter {
|
||||||
|
fn new(gate: Arc<TwoPartyGate>) -> Self {
|
||||||
|
Self {
|
||||||
|
gate,
|
||||||
|
entered: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for GateWriter {
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
if !self.entered {
|
||||||
|
self.entered = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.gate.arrive_or_park(cx) {
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.gate
|
||||||
|
.total_bytes
|
||||||
|
.fetch_add(buf.len(), Ordering::Relaxed);
|
||||||
|
Poll::Ready(Ok(buf.len()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GateReader {
|
||||||
|
gate: Arc<TwoPartyGate>,
|
||||||
|
entered: bool,
|
||||||
|
emitted: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GateReader {
|
||||||
|
fn new(gate: Arc<TwoPartyGate>) -> Self {
|
||||||
|
Self {
|
||||||
|
gate,
|
||||||
|
entered: false,
|
||||||
|
emitted: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for GateReader {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
if self.emitted {
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.entered {
|
||||||
|
self.entered = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.gate.arrive_or_park(cx) {
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.put_slice(&[0x42]);
|
||||||
|
self.gate.total_bytes.fetch_add(1, Ordering::Relaxed);
|
||||||
|
self.emitted = true;
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let gate = Arc::new(TwoPartyGate::new());
|
||||||
|
let user = "concurrent-quota-write".to_string();
|
||||||
|
|
||||||
|
let writer_a = super::StatsIo::new(
|
||||||
|
GateWriter::new(Arc::clone(&gate)),
|
||||||
|
Arc::new(super::SharedCounters::new()),
|
||||||
|
Arc::clone(&stats),
|
||||||
|
user.clone(),
|
||||||
|
Some(1),
|
||||||
|
Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||||
|
tokio::time::Instant::now(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let writer_b = super::StatsIo::new(
|
||||||
|
GateWriter::new(Arc::clone(&gate)),
|
||||||
|
Arc::new(super::SharedCounters::new()),
|
||||||
|
Arc::clone(&stats),
|
||||||
|
user.clone(),
|
||||||
|
Some(1),
|
||||||
|
Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||||
|
tokio::time::Instant::now(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let task_a = tokio::spawn(async move {
|
||||||
|
let mut w = writer_a;
|
||||||
|
AsyncWriteExt::write_all(&mut w, &[0x01]).await
|
||||||
|
});
|
||||||
|
let task_b = tokio::spawn(async move {
|
||||||
|
let mut w = writer_b;
|
||||||
|
AsyncWriteExt::write_all(&mut w, &[0x02]).await
|
||||||
|
});
|
||||||
|
|
||||||
|
let (res_a, res_b) = tokio::join!(task_a, task_b);
|
||||||
|
let _ = res_a.expect("task a must join");
|
||||||
|
let _ = res_b.expect("task b must join");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
gate.total_bytes() <= 1,
|
||||||
|
"concurrent same-user writes must not forward more than one byte under quota=1"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
stats.get_user_total_octets(&user) <= 1,
|
||||||
|
"concurrent same-user writes must not account over limit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let gate = Arc::new(TwoPartyGate::new());
|
||||||
|
let user = "concurrent-quota-read".to_string();
|
||||||
|
|
||||||
|
let reader_a = super::StatsIo::new(
|
||||||
|
GateReader::new(Arc::clone(&gate)),
|
||||||
|
Arc::new(super::SharedCounters::new()),
|
||||||
|
Arc::clone(&stats),
|
||||||
|
user.clone(),
|
||||||
|
Some(1),
|
||||||
|
Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||||
|
tokio::time::Instant::now(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let reader_b = super::StatsIo::new(
|
||||||
|
GateReader::new(Arc::clone(&gate)),
|
||||||
|
Arc::new(super::SharedCounters::new()),
|
||||||
|
Arc::clone(&stats),
|
||||||
|
user.clone(),
|
||||||
|
Some(1),
|
||||||
|
Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||||
|
tokio::time::Instant::now(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let task_a = tokio::spawn(async move {
|
||||||
|
let mut r = reader_a;
|
||||||
|
let mut one = [0u8; 1];
|
||||||
|
AsyncReadExt::read_exact(&mut r, &mut one).await
|
||||||
|
});
|
||||||
|
let task_b = tokio::spawn(async move {
|
||||||
|
let mut r = reader_b;
|
||||||
|
let mut one = [0u8; 1];
|
||||||
|
AsyncReadExt::read_exact(&mut r, &mut one).await
|
||||||
|
});
|
||||||
|
|
||||||
|
let (res_a, res_b) = tokio::join!(task_a, task_b);
|
||||||
|
let _ = res_a.expect("task a must join");
|
||||||
|
let _ = res_b.expect("task b must join");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
gate.total_bytes() <= 1,
|
||||||
|
"concurrent same-user reads must not consume more than one byte under quota=1"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
stats.get_user_total_octets(&user) <= 1,
|
||||||
|
"concurrent same-user reads must not account over limit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let user = "parallel-quota-user";
|
||||||
|
|
||||||
|
for _ in 0..128 {
|
||||||
|
let (mut client_peer_a, relay_client_a) = duplex(256);
|
||||||
|
let (relay_server_a, mut server_peer_a) = duplex(256);
|
||||||
|
let (mut client_peer_b, relay_client_b) = duplex(256);
|
||||||
|
let (relay_server_b, mut server_peer_b) = duplex(256);
|
||||||
|
|
||||||
|
let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a);
|
||||||
|
let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a);
|
||||||
|
let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b);
|
||||||
|
let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b);
|
||||||
|
|
||||||
|
let relay_a = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader_a,
|
||||||
|
client_writer_a,
|
||||||
|
server_reader_a,
|
||||||
|
server_writer_a,
|
||||||
|
64,
|
||||||
|
64,
|
||||||
|
user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(1),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let relay_b = tokio::spawn(relay_bidirectional(
|
||||||
|
client_reader_b,
|
||||||
|
client_writer_b,
|
||||||
|
server_reader_b,
|
||||||
|
server_writer_b,
|
||||||
|
64,
|
||||||
|
64,
|
||||||
|
user,
|
||||||
|
Arc::clone(&stats),
|
||||||
|
Some(1),
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
));
|
||||||
|
|
||||||
|
let _ = tokio::join!(
|
||||||
|
client_peer_a.write_all(&[0x01]),
|
||||||
|
server_peer_a.write_all(&[0x02]),
|
||||||
|
client_peer_b.write_all(&[0x03]),
|
||||||
|
server_peer_b.write_all(&[0x04]),
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = timeout(Duration::from_millis(50), poll_fn(|cx| {
|
||||||
|
let mut one = [0u8; 1];
|
||||||
|
let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one));
|
||||||
|
Poll::Ready(())
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
drop(client_peer_a);
|
||||||
|
drop(server_peer_a);
|
||||||
|
drop(client_peer_b);
|
||||||
|
drop(server_peer_b);
|
||||||
|
|
||||||
|
let _ = timeout(Duration::from_secs(1), relay_a).await;
|
||||||
|
let _ = timeout(Duration::from_secs(1), relay_b).await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
stats.get_user_total_octets(user) <= 1,
|
||||||
|
"parallel relays must not exceed configured quota"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FaultyReader {
|
||||||
|
fn permission_denied_with_message(message: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
error_once: Some(io::Error::new(io::ErrorKind::PermissionDenied, message.into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for FaultyReader {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
if let Some(err) = self.error_once.take() {
|
||||||
|
return Poll::Ready(Err(err));
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let (client_peer, relay_client) = duplex(4096);
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
|
||||||
|
let relay_result = relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
FaultyReader::permission_denied_with_message("user data quota exceeded"),
|
||||||
|
tokio::io::sink(),
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
"non-quota-permission-denied",
|
||||||
|
Arc::clone(&stats),
|
||||||
|
None,
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
drop(client_peer);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied),
|
||||||
|
"non-quota transport PermissionDenied errors must remain IO errors"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() {
|
||||||
|
let mut rng = StdRng::seed_from_u64(0xA11CE0B5);
|
||||||
|
|
||||||
|
for i in 0..128u64 {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let (client_peer, relay_client) = duplex(1024);
|
||||||
|
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||||
|
|
||||||
|
let random_len = rng.random_range(1..=48);
|
||||||
|
let mut msg = String::with_capacity(random_len);
|
||||||
|
for _ in 0..random_len {
|
||||||
|
let ch = (b'a' + (rng.random::<u8>() % 26)) as char;
|
||||||
|
msg.push(ch);
|
||||||
|
}
|
||||||
|
// Include the legacy quota string in a subset of fuzz cases to validate
|
||||||
|
// collision resistance against message-based classification.
|
||||||
|
if i % 7 == 0 {
|
||||||
|
msg = "user data quota exceeded".to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let relay_result = relay_bidirectional(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
FaultyReader::permission_denied_with_message(msg),
|
||||||
|
tokio::io::sink(),
|
||||||
|
1024,
|
||||||
|
1024,
|
||||||
|
"fuzz-perm-denied",
|
||||||
|
Arc::clone(&stats),
|
||||||
|
None,
|
||||||
|
Arc::new(BufferPool::new()),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
drop(client_peer);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied),
|
||||||
|
"transport PermissionDenied case must stay typed as IO regardless of message content"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -103,7 +103,7 @@ pub fn build_emulated_server_hello(
|
||||||
cached: &CachedTlsData,
|
cached: &CachedTlsData,
|
||||||
use_full_cert_payload: bool,
|
use_full_cert_payload: bool,
|
||||||
rng: &SecureRandom,
|
rng: &SecureRandom,
|
||||||
_alpn: Option<Vec<u8>>,
|
alpn: Option<Vec<u8>>,
|
||||||
new_session_tickets: u8,
|
new_session_tickets: u8,
|
||||||
) -> Vec<u8> {
|
) -> Vec<u8> {
|
||||||
// --- ServerHello ---
|
// --- ServerHello ---
|
||||||
|
|
@ -198,8 +198,22 @@ pub fn build_emulated_server_hello(
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut app_data = Vec::new();
|
let mut app_data = Vec::new();
|
||||||
|
let alpn_marker = alpn
|
||||||
|
.as_ref()
|
||||||
|
.filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize)
|
||||||
|
.map(|proto| {
|
||||||
|
let proto_list_len = 1usize + proto.len();
|
||||||
|
let ext_data_len = 2usize + proto_list_len;
|
||||||
|
let mut marker = Vec::with_capacity(4 + ext_data_len);
|
||||||
|
marker.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||||
|
marker.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
|
||||||
|
marker.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
|
||||||
|
marker.push(proto.len() as u8);
|
||||||
|
marker.extend_from_slice(proto);
|
||||||
|
marker
|
||||||
|
});
|
||||||
let mut payload_offset = 0usize;
|
let mut payload_offset = 0usize;
|
||||||
for size in sizes {
|
for (idx, size) in sizes.into_iter().enumerate() {
|
||||||
let mut rec = Vec::with_capacity(5 + size);
|
let mut rec = Vec::with_capacity(5 + size);
|
||||||
rec.push(TLS_RECORD_APPLICATION);
|
rec.push(TLS_RECORD_APPLICATION);
|
||||||
rec.extend_from_slice(&TLS_VERSION);
|
rec.extend_from_slice(&TLS_VERSION);
|
||||||
|
|
@ -224,7 +238,20 @@ pub fn build_emulated_server_hello(
|
||||||
}
|
}
|
||||||
} else if size > 17 {
|
} else if size > 17 {
|
||||||
let body_len = size - 17;
|
let body_len = size - 17;
|
||||||
rec.extend_from_slice(&rng.bytes(body_len));
|
let mut body = Vec::with_capacity(body_len);
|
||||||
|
if idx == 0 && let Some(marker) = &alpn_marker {
|
||||||
|
if marker.len() <= body_len {
|
||||||
|
body.extend_from_slice(marker);
|
||||||
|
if body_len > marker.len() {
|
||||||
|
body.extend_from_slice(&rng.bytes(body_len - marker.len()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body.extend_from_slice(&rng.bytes(body_len));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body.extend_from_slice(&rng.bytes(body_len));
|
||||||
|
}
|
||||||
|
rec.extend_from_slice(&body);
|
||||||
rec.push(0x16); // inner content type marker (handshake)
|
rec.push(0x16); // inner content type marker (handshake)
|
||||||
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag
|
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -236,8 +263,9 @@ pub fn build_emulated_server_hello(
|
||||||
// --- Combine ---
|
// --- Combine ---
|
||||||
// Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
|
// Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
|
||||||
let mut tickets = Vec::new();
|
let mut tickets = Vec::new();
|
||||||
if new_session_tickets > 0 {
|
let ticket_count = new_session_tickets.min(4);
|
||||||
for _ in 0..new_session_tickets {
|
if ticket_count > 0 {
|
||||||
|
for _ in 0..ticket_count {
|
||||||
let ticket_len: usize = rng.range(48) + 48;
|
let ticket_len: usize = rng.range(48) + 48;
|
||||||
let mut rec = Vec::with_capacity(5 + ticket_len);
|
let mut rec = Vec::with_capacity(5 + ticket_len);
|
||||||
rec.push(TLS_RECORD_APPLICATION);
|
rec.push(TLS_RECORD_APPLICATION);
|
||||||
|
|
@ -264,6 +292,10 @@ pub fn build_emulated_server_hello(
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "emulator_security_tests.rs"]
|
||||||
|
mod security_tests;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::time::SystemTime;
|
use std::time::SystemTime;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,136 @@
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE};
|
||||||
|
use crate::tls_front::emulator::build_emulated_server_hello;
|
||||||
|
use crate::tls_front::types::{
|
||||||
|
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn make_cached(cert_payload: Option<crate::tls_front::types::TlsCertPayload>) -> CachedTlsData {
|
||||||
|
CachedTlsData {
|
||||||
|
server_hello_template: ParsedServerHello {
|
||||||
|
version: [0x03, 0x03],
|
||||||
|
random: [0u8; 32],
|
||||||
|
session_id: Vec::new(),
|
||||||
|
cipher_suite: [0x13, 0x01],
|
||||||
|
compression: 0,
|
||||||
|
extensions: Vec::new(),
|
||||||
|
},
|
||||||
|
cert_info: None,
|
||||||
|
cert_payload,
|
||||||
|
app_data_records_sizes: vec![64],
|
||||||
|
total_app_data_len: 64,
|
||||||
|
behavior_profile: TlsBehaviorProfile {
|
||||||
|
change_cipher_spec_count: 1,
|
||||||
|
app_data_record_sizes: vec![64],
|
||||||
|
ticket_record_sizes: Vec::new(),
|
||||||
|
source: TlsProfileSource::Default,
|
||||||
|
},
|
||||||
|
fetched_at: SystemTime::now(),
|
||||||
|
domain: "example.com".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn first_app_data_payload(response: &[u8]) -> &[u8] {
|
||||||
|
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_start = 5 + hello_len;
|
||||||
|
let ccs_len = u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
|
||||||
|
let app_start = ccs_start + 5 + ccs_len;
|
||||||
|
let app_len = u16::from_be_bytes([response[app_start + 3], response[app_start + 4]]) as usize;
|
||||||
|
&response[app_start + 5..app_start + 5 + app_len]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
|
||||||
|
let cached = make_cached(None);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let oversized_alpn = vec![0xAB; u8::MAX as usize + 1];
|
||||||
|
|
||||||
|
let response = build_emulated_server_hello(
|
||||||
|
b"secret",
|
||||||
|
&[0x11; 32],
|
||||||
|
&[0x22; 16],
|
||||||
|
&cached,
|
||||||
|
true,
|
||||||
|
&rng,
|
||||||
|
Some(oversized_alpn),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
|
||||||
|
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_start = 5 + hello_len;
|
||||||
|
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
|
||||||
|
let app_start = ccs_start + 6;
|
||||||
|
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
|
||||||
|
|
||||||
|
let payload = first_app_data_payload(&response);
|
||||||
|
let mut marker_prefix = Vec::new();
|
||||||
|
marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||||
|
marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
|
||||||
|
marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
|
||||||
|
marker_prefix.push(0xff);
|
||||||
|
marker_prefix.extend_from_slice(&[0xab; 8]);
|
||||||
|
assert!(
|
||||||
|
!payload.starts_with(&marker_prefix),
|
||||||
|
"oversized ALPN must not be partially embedded into the emulated first application record"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn emulated_server_hello_embeds_full_alpn_marker_when_body_can_fit() {
|
||||||
|
let cached = make_cached(None);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
|
||||||
|
let response = build_emulated_server_hello(
|
||||||
|
b"secret",
|
||||||
|
&[0x31; 32],
|
||||||
|
&[0x41; 16],
|
||||||
|
&cached,
|
||||||
|
true,
|
||||||
|
&rng,
|
||||||
|
Some(b"h2".to_vec()),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
let payload = first_app_data_payload(&response);
|
||||||
|
let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
|
||||||
|
assert!(
|
||||||
|
payload.starts_with(&expected),
|
||||||
|
"when body has enough capacity, emulated first application record must include full ALPN marker"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn emulated_server_hello_prefers_cert_payload_over_alpn_marker() {
|
||||||
|
let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
|
||||||
|
let cached = make_cached(Some(TlsCertPayload {
|
||||||
|
cert_chain_der: vec![vec![0x30, 0x01, 0x00]],
|
||||||
|
certificate_message: cert_msg.clone(),
|
||||||
|
}));
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
|
||||||
|
let response = build_emulated_server_hello(
|
||||||
|
b"secret",
|
||||||
|
&[0x32; 32],
|
||||||
|
&[0x42; 16],
|
||||||
|
&cached,
|
||||||
|
true,
|
||||||
|
&rng,
|
||||||
|
Some(b"h2".to_vec()),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
let payload = first_app_data_payload(&response);
|
||||||
|
let alpn_marker = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
payload.starts_with(&cert_msg),
|
||||||
|
"when certificate payload is available, first record must start with cert payload bytes"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!payload.starts_with(&alpn_marker),
|
||||||
|
"ALPN marker must not displace selected certificate payload"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -27,6 +27,8 @@ mod health_regression_tests;
|
||||||
mod health_integration_tests;
|
mod health_integration_tests;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod health_adversarial_tests;
|
mod health_adversarial_tests;
|
||||||
|
#[cfg(test)]
|
||||||
|
mod send_adversarial_tests;
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -692,6 +692,7 @@ impl MePool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
pub(super) fn draining_active_runtime(&self) -> u64 {
|
pub(super) fn draining_active_runtime(&self) -> u64 {
|
||||||
self.draining_active_runtime.load(Ordering::Relaxed)
|
self.draining_active_runtime.load(Ordering::Relaxed)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -454,6 +454,7 @@ impl ConnRegistry {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
|
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
|
||||||
let inner = self.inner.read().await;
|
let inner = self.inner.read().await;
|
||||||
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
|
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
|
||||||
|
|
|
||||||
|
|
@ -372,17 +372,20 @@ impl MePool {
|
||||||
}
|
}
|
||||||
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
||||||
let (payload, meta) = build_routed_payload(effective_our_addr);
|
let (payload, meta) = build_routed_payload(effective_our_addr);
|
||||||
match w.tx.try_send(WriterCommand::Data(payload.clone())) {
|
match w.tx.clone().try_reserve_owned() {
|
||||||
Ok(()) => {
|
Ok(permit) => {
|
||||||
self.stats.increment_me_writer_pick_success_try_total(pick_mode);
|
|
||||||
if !self.registry.bind_writer(conn_id, w.id, meta).await {
|
if !self.registry.bind_writer(conn_id, w.id, meta).await {
|
||||||
debug!(
|
debug!(
|
||||||
conn_id,
|
conn_id,
|
||||||
writer_id = w.id,
|
writer_id = w.id,
|
||||||
"ME writer disappeared before bind commit, retrying"
|
"ME writer disappeared before bind commit, pruning stale writer"
|
||||||
);
|
);
|
||||||
|
drop(permit);
|
||||||
|
self.remove_writer_and_close_clients(w.id).await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
permit.send(WriterCommand::Data(payload.clone()));
|
||||||
|
self.stats.increment_me_writer_pick_success_try_total(pick_mode);
|
||||||
if w.generation < self.current_generation() {
|
if w.generation < self.current_generation() {
|
||||||
self.stats.increment_pool_stale_pick_total();
|
self.stats.increment_pool_stale_pick_total();
|
||||||
debug!(
|
debug!(
|
||||||
|
|
@ -422,18 +425,21 @@ impl MePool {
|
||||||
self.stats.increment_me_writer_pick_blocking_fallback_total();
|
self.stats.increment_me_writer_pick_blocking_fallback_total();
|
||||||
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
||||||
let (payload, meta) = build_routed_payload(effective_our_addr);
|
let (payload, meta) = build_routed_payload(effective_our_addr);
|
||||||
match w.tx.send(WriterCommand::Data(payload.clone())).await {
|
match w.tx.clone().reserve_owned().await {
|
||||||
Ok(()) => {
|
Ok(permit) => {
|
||||||
self.stats
|
|
||||||
.increment_me_writer_pick_success_fallback_total(pick_mode);
|
|
||||||
if !self.registry.bind_writer(conn_id, w.id, meta).await {
|
if !self.registry.bind_writer(conn_id, w.id, meta).await {
|
||||||
debug!(
|
debug!(
|
||||||
conn_id,
|
conn_id,
|
||||||
writer_id = w.id,
|
writer_id = w.id,
|
||||||
"ME writer disappeared before fallback bind commit, retrying"
|
"ME writer disappeared before fallback bind commit, pruning stale writer"
|
||||||
);
|
);
|
||||||
|
drop(permit);
|
||||||
|
self.remove_writer_and_close_clients(w.id).await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
permit.send(WriterCommand::Data(payload.clone()));
|
||||||
|
self.stats
|
||||||
|
.increment_me_writer_pick_success_fallback_total(pick_mode);
|
||||||
if w.generation < self.current_generation() {
|
if w.generation < self.current_generation() {
|
||||||
self.stats.increment_pool_stale_pick_total();
|
self.stats.increment_pool_stale_pick_total();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,263 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use super::codec::WriterCommand;
|
||||||
|
use super::pool::{MePool, MeWriter, WriterContour};
|
||||||
|
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::network::probe::NetworkDecision;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
|
||||||
|
async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) {
|
||||||
|
let general = GeneralConfig {
|
||||||
|
me_route_no_writer_mode: MeRouteNoWriterMode::AsyncRecoveryFailfast,
|
||||||
|
me_route_no_writer_wait_ms: 50,
|
||||||
|
me_writer_pick_mode: MeWriterPickMode::SortedRr,
|
||||||
|
me_deterministic_writer_sort: true,
|
||||||
|
..GeneralConfig::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let pool = MePool::new(
|
||||||
|
None,
|
||||||
|
vec![1u8; 32],
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
Vec::new(),
|
||||||
|
1,
|
||||||
|
None,
|
||||||
|
12,
|
||||||
|
1200,
|
||||||
|
HashMap::new(),
|
||||||
|
HashMap::new(),
|
||||||
|
None,
|
||||||
|
NetworkDecision::default(),
|
||||||
|
None,
|
||||||
|
rng.clone(),
|
||||||
|
Arc::new(Stats::default()),
|
||||||
|
general.me_keepalive_enabled,
|
||||||
|
general.me_keepalive_interval_secs,
|
||||||
|
general.me_keepalive_jitter_secs,
|
||||||
|
general.me_keepalive_payload_random,
|
||||||
|
general.rpc_proxy_req_every,
|
||||||
|
general.me_warmup_stagger_enabled,
|
||||||
|
general.me_warmup_step_delay_ms,
|
||||||
|
general.me_warmup_step_jitter_ms,
|
||||||
|
general.me_reconnect_max_concurrent_per_dc,
|
||||||
|
general.me_reconnect_backoff_base_ms,
|
||||||
|
general.me_reconnect_backoff_cap_ms,
|
||||||
|
general.me_reconnect_fast_retry_count,
|
||||||
|
general.me_single_endpoint_shadow_writers,
|
||||||
|
general.me_single_endpoint_outage_mode_enabled,
|
||||||
|
general.me_single_endpoint_outage_disable_quarantine,
|
||||||
|
general.me_single_endpoint_outage_backoff_min_ms,
|
||||||
|
general.me_single_endpoint_outage_backoff_max_ms,
|
||||||
|
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||||
|
general.me_floor_mode,
|
||||||
|
general.me_adaptive_floor_idle_secs,
|
||||||
|
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||||
|
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||||
|
general.me_adaptive_floor_recover_grace_secs,
|
||||||
|
general.me_adaptive_floor_writers_per_core_total,
|
||||||
|
general.me_adaptive_floor_cpu_cores_override,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_global,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_global,
|
||||||
|
general.hardswap,
|
||||||
|
general.me_pool_drain_ttl_secs,
|
||||||
|
general.me_pool_drain_threshold,
|
||||||
|
general.effective_me_pool_force_close_secs(),
|
||||||
|
general.me_pool_min_fresh_ratio,
|
||||||
|
general.me_hardswap_warmup_delay_min_ms,
|
||||||
|
general.me_hardswap_warmup_delay_max_ms,
|
||||||
|
general.me_hardswap_warmup_extra_passes,
|
||||||
|
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||||
|
general.me_bind_stale_mode,
|
||||||
|
general.me_bind_stale_ttl_secs,
|
||||||
|
general.me_secret_atomic_snapshot,
|
||||||
|
general.me_deterministic_writer_sort,
|
||||||
|
general.me_writer_pick_mode,
|
||||||
|
general.me_writer_pick_sample_size,
|
||||||
|
MeSocksKdfPolicy::default(),
|
||||||
|
general.me_writer_cmd_channel_capacity,
|
||||||
|
general.me_route_channel_capacity,
|
||||||
|
general.me_route_backpressure_base_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_watermark_pct,
|
||||||
|
general.me_reader_route_data_wait_ms,
|
||||||
|
general.me_health_interval_ms_unhealthy,
|
||||||
|
general.me_health_interval_ms_healthy,
|
||||||
|
general.me_warn_rate_limit_ms,
|
||||||
|
general.me_route_no_writer_mode,
|
||||||
|
general.me_route_no_writer_wait_ms,
|
||||||
|
general.me_route_inline_recovery_attempts,
|
||||||
|
general.me_route_inline_recovery_wait_ms,
|
||||||
|
);
|
||||||
|
|
||||||
|
(pool, rng)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn insert_writer(
|
||||||
|
pool: &Arc<MePool>,
|
||||||
|
writer_id: u64,
|
||||||
|
writer_dc: i32,
|
||||||
|
addr: SocketAddr,
|
||||||
|
register_in_registry: bool,
|
||||||
|
) -> mpsc::Receiver<WriterCommand> {
|
||||||
|
let (tx, rx) = mpsc::channel::<WriterCommand>(8);
|
||||||
|
let writer = MeWriter {
|
||||||
|
id: writer_id,
|
||||||
|
addr,
|
||||||
|
source_ip: addr.ip(),
|
||||||
|
writer_dc,
|
||||||
|
generation: pool.current_generation(),
|
||||||
|
contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())),
|
||||||
|
created_at: Instant::now(),
|
||||||
|
tx: tx.clone(),
|
||||||
|
cancel: CancellationToken::new(),
|
||||||
|
degraded: Arc::new(AtomicBool::new(false)),
|
||||||
|
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||||
|
draining: Arc::new(AtomicBool::new(false)),
|
||||||
|
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||||
|
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||||
|
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
|
||||||
|
pool.writers.write().await.push(writer);
|
||||||
|
{
|
||||||
|
let mut map = pool.proxy_map_v4.write().await;
|
||||||
|
map.entry(writer_dc)
|
||||||
|
.or_insert_with(Vec::new)
|
||||||
|
.push((addr.ip(), addr.port()));
|
||||||
|
}
|
||||||
|
pool.rebuild_endpoint_dc_map().await;
|
||||||
|
if register_in_registry {
|
||||||
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
|
}
|
||||||
|
rx
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duration) -> usize {
|
||||||
|
let start = Instant::now();
|
||||||
|
let mut data_count = 0usize;
|
||||||
|
while Instant::now().duration_since(start) < budget {
|
||||||
|
let remaining = budget.saturating_sub(Instant::now().duration_since(start));
|
||||||
|
match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
|
||||||
|
Ok(Some(WriterCommand::Data(_))) => data_count += 1,
|
||||||
|
Ok(Some(WriterCommand::DataAndFlush(_))) => data_count += 1,
|
||||||
|
Ok(Some(WriterCommand::Close)) => {}
|
||||||
|
Ok(None) => break,
|
||||||
|
Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data_count
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
|
||||||
|
let (pool, _rng) = make_pool().await;
|
||||||
|
pool.rr.store(0, Ordering::Relaxed);
|
||||||
|
|
||||||
|
let (conn_id, _rx) = pool.registry.register().await;
|
||||||
|
let mut stale_rx = insert_writer(
|
||||||
|
&pool,
|
||||||
|
10,
|
||||||
|
2,
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 10)), 443),
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let mut live_rx = insert_writer(
|
||||||
|
&pool,
|
||||||
|
11,
|
||||||
|
2,
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 11)), 443),
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let result = pool
|
||||||
|
.send_proxy_req(
|
||||||
|
conn_id,
|
||||||
|
2,
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30000),
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||||
|
b"hello",
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert_eq!(recv_data_count(&mut stale_rx, Duration::from_millis(50)).await, 0);
|
||||||
|
assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
|
||||||
|
|
||||||
|
let bound = pool.registry.get_writer(conn_id).await;
|
||||||
|
assert!(bound.is_some());
|
||||||
|
assert_eq!(bound.expect("writer should be bound").writer_id, 11);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay() {
|
||||||
|
let (pool, _rng) = make_pool().await;
|
||||||
|
pool.rr.store(0, Ordering::Relaxed);
|
||||||
|
|
||||||
|
let (conn_id, _rx) = pool.registry.register().await;
|
||||||
|
|
||||||
|
let mut stale_rx_1 = insert_writer(
|
||||||
|
&pool,
|
||||||
|
21,
|
||||||
|
2,
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 21)), 443),
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let mut stale_rx_2 = insert_writer(
|
||||||
|
&pool,
|
||||||
|
22,
|
||||||
|
2,
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 22)), 443),
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let mut live_rx = insert_writer(
|
||||||
|
&pool,
|
||||||
|
23,
|
||||||
|
2,
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 23)), 443),
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let result = pool
|
||||||
|
.send_proxy_req(
|
||||||
|
conn_id,
|
||||||
|
2,
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30001),
|
||||||
|
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||||
|
b"storm",
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert_eq!(recv_data_count(&mut stale_rx_1, Duration::from_millis(50)).await, 0);
|
||||||
|
assert_eq!(recv_data_count(&mut stale_rx_2, Duration::from_millis(50)).await, 0);
|
||||||
|
assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
|
||||||
|
|
||||||
|
let writers = pool.writers.read().await;
|
||||||
|
let writer_ids = writers.iter().map(|w| w.id).collect::<Vec<_>>();
|
||||||
|
drop(writers);
|
||||||
|
assert_eq!(writer_ids, vec![23]);
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue