From c0a3e43aa8e6e7e417fe64cc0f356634025c3c7d Mon Sep 17 00:00:00 2001 From: David Osipov Date: Sat, 21 Mar 2026 20:54:13 +0400 Subject: [PATCH] Add comprehensive security tests for proxy functionality - Introduced client TLS record wrapping tests to ensure correct handling of empty and oversized payloads. - Added integration tests for middle relay to validate quota saturation behavior under concurrent pressure. - Implemented high-risk security tests covering various payload scenarios, including alignment checks and boundary conditions. - Developed length cast hardening tests to verify proper handling of wire lengths and overflow conditions. - Created quota overflow lock tests to ensure stable behavior under saturation and reclaim scenarios. - Refactored existing middle relay security tests for improved clarity and consistency in lock handling. --- ...ls_length_cast_hardening_security_tests.rs | 37 + src/protocol/tls.rs | 33 +- src/proxy/client.rs | 26 +- src/proxy/middle_relay.rs | 90 ++- ...ls_record_wrap_hardening_security_tests.rs | 37 + ...lay_blackhat_campaign_integration_tests.rs | 112 +++ ...relay_coverage_high_risk_security_tests.rs | 708 ++++++++++++++++++ ...ay_length_cast_hardening_security_tests.rs | 75 ++ ...elay_quota_overflow_lock_security_tests.rs | 131 ++++ .../tests/middle_relay_security_tests.rs | 21 +- 10 files changed, 1238 insertions(+), 32 deletions(-) create mode 100644 src/protocol/tests/tls_length_cast_hardening_security_tests.rs create mode 100644 src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs create mode 100644 src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_length_cast_hardening_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs diff --git a/src/protocol/tests/tls_length_cast_hardening_security_tests.rs b/src/protocol/tests/tls_length_cast_hardening_security_tests.rs new file mode 100644 index 0000000..31418e4 --- /dev/null +++ b/src/protocol/tests/tls_length_cast_hardening_security_tests.rs @@ -0,0 +1,37 @@ +use super::*; + +#[test] +fn extension_builder_fails_closed_on_u16_length_overflow() { + let builder = TlsExtensionBuilder { + extensions: vec![0u8; (u16::MAX as usize) + 1], + }; + + let built = builder.build(); + assert!( + built.is_empty(), + "oversized extension blob must fail closed instead of truncating length field" + ); +} + +#[test] +fn server_hello_builder_fails_closed_on_session_id_len_overflow() { + let builder = ServerHelloBuilder { + random: [0u8; 32], + session_id: vec![0xAB; (u8::MAX as usize) + 1], + cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256, + compression: 0, + extensions: TlsExtensionBuilder::new(), + }; + + let message = builder.build_message(); + let record = builder.build_record(); + + assert!( + message.is_empty(), + "session_id length overflow must fail closed in message builder" + ); + assert!( + record.is_empty(), + "session_id length overflow must fail closed in record builder" + ); +} diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index b9bca49..613106e 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -183,10 +183,12 @@ impl TlsExtensionBuilder { /// Build final extensions with length prefix fn build(self) -> Vec { + let Ok(len) = u16::try_from(self.extensions.len()) else { + return Vec::new(); + }; let mut result = Vec::with_capacity(2 + self.extensions.len()); // Extensions length (2 bytes) - let len = self.extensions.len() as u16; result.extend_from_slice(&len.to_be_bytes()); // Extensions data @@ -241,8 +243,13 @@ impl ServerHelloBuilder { /// Build ServerHello message (without record header) fn build_message(&self) -> Vec { + let Ok(session_id_len) = u8::try_from(self.session_id.len()) else { + return Vec::new(); + }; let extensions = self.extensions.extensions.clone(); - let extensions_len = extensions.len() as u16; + let Ok(extensions_len) = u16::try_from(extensions.len()) else { + return Vec::new(); + }; // Calculate total length let body_len = 2 + // version @@ -251,6 +258,9 @@ impl ServerHelloBuilder { 2 + // cipher suite 1 + // compression 2 + extensions.len(); // extensions length + data + if body_len > 0x00ff_ffff { + return Vec::new(); + } let mut message = Vec::with_capacity(4 + body_len); @@ -258,7 +268,10 @@ impl ServerHelloBuilder { message.push(0x02); // ServerHello message type // 3-byte length - let len_bytes = (body_len as u32).to_be_bytes(); + let Ok(body_len_u32) = u32::try_from(body_len) else { + return Vec::new(); + }; + let len_bytes = body_len_u32.to_be_bytes(); message.extend_from_slice(&len_bytes[1..4]); // Server version (TLS 1.2 in header, actual version in extension) @@ -268,7 +281,7 @@ impl ServerHelloBuilder { message.extend_from_slice(&self.random); // Session ID - message.push(self.session_id.len() as u8); + message.push(session_id_len); message.extend_from_slice(&self.session_id); // Cipher suite @@ -289,13 +302,19 @@ impl ServerHelloBuilder { /// Build complete ServerHello TLS record fn build_record(&self) -> Vec { let message = self.build_message(); + if message.is_empty() { + return Vec::new(); + } + let Ok(message_len) = u16::try_from(message.len()) else { + return Vec::new(); + }; let mut record = Vec::with_capacity(5 + message.len()); // TLS record header record.push(TLS_RECORD_HANDSHAKE); record.extend_from_slice(&TLS_VERSION); - record.extend_from_slice(&(message.len() as u16).to_be_bytes()); + record.extend_from_slice(&message_len.to_be_bytes()); // Message record.extend_from_slice(&message); @@ -910,3 +929,7 @@ mod adversarial_tests; #[cfg(test)] #[path = "tests/tls_fuzz_security_tests.rs"] mod fuzz_security_tests; + +#[cfg(test)] +#[path = "tests/tls_length_cast_hardening_security_tests.rs"] +mod length_cast_hardening_security_tests; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index d71fc36..4b7f57e 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -116,11 +116,23 @@ fn beobachten_ttl(config: &ProxyConfig) -> Duration { } fn wrap_tls_application_record(payload: &[u8]) -> Vec { - let mut record = Vec::with_capacity(5 + payload.len()); - record.push(TLS_RECORD_APPLICATION); - record.extend_from_slice(&TLS_VERSION); - record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); - record.extend_from_slice(payload); + let chunks = payload.len().div_ceil(u16::MAX as usize).max(1); + let mut record = Vec::with_capacity(payload.len() + 5 * chunks); + + if payload.is_empty() { + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&0u16.to_be_bytes()); + return record; + } + + for chunk in payload.chunks(u16::MAX as usize) { + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(chunk.len() as u16).to_be_bytes()); + record.extend_from_slice(chunk); + } + record } @@ -1312,3 +1324,7 @@ mod masking_probe_evasion_blackhat_tests; #[cfg(test)] #[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] mod beobachten_ttl_bounds_security_tests; + +#[cfg(test)] +#[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"] +mod tls_record_wrap_hardening_security_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index d8d94d2..f56a606 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -49,11 +49,16 @@ const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const QUOTA_USER_LOCKS_MAX: usize = 64; #[cfg(not(test))] const QUOTA_USER_LOCKS_MAX: usize = 4_096; +#[cfg(test)] +const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; +#[cfg(not(test))] +const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); static DESYNC_HASHER: OnceLock = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); static DESYNC_DEDUP_EVER_SATURATED: OnceLock = OnceLock::new(); static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); +static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new(); static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); @@ -413,6 +418,13 @@ fn desync_dedup_test_lock() -> &'static Mutex<()> { TEST_LOCK.get_or_init(|| Mutex::new(())) } +fn desync_forensics_len_bytes(len: usize) -> ([u8; 4], bool) { + match u32::try_from(len) { + Ok(value) => (value.to_le_bytes(), false), + Err(_) => (u32::MAX.to_le_bytes(), true), + } +} + fn report_desync_frame_too_large( state: &RelayForensicsState, proto_tag: ProtoTag, @@ -422,7 +434,8 @@ fn report_desync_frame_too_large( raw_len_bytes: Option<[u8; 4]>, stats: &Stats, ) -> ProxyError { - let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes()); + let (fallback_len_buf, len_buf_truncated) = desync_forensics_len_bytes(len); + let len_buf = raw_len_bytes.unwrap_or(fallback_len_buf); let looks_like_tls = raw_len_bytes .map(|b| b[0] == 0x16 && b[1] == 0x03) .unwrap_or(false); @@ -458,6 +471,7 @@ fn report_desync_frame_too_large( bytes_me2c, raw_len = len, raw_len_hex = format_args!("0x{:08x}", len), + raw_len_bytes_truncated = len_buf_truncated, raw_bytes = format_args!( "{:02x} {:02x} {:02x} {:02x}", len_buf[0], len_buf[1], len_buf[2], len_buf[3] @@ -524,6 +538,30 @@ fn quota_would_be_exceeded_for_user( }) } +#[cfg(test)] +fn quota_user_lock_test_guard() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { + quota_user_lock_test_guard() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn quota_overflow_user_lock(user: &str) -> Arc> { + let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { + (0..QUOTA_OVERFLOW_LOCK_STRIPES) + .map(|_| Arc::new(AsyncMutex::new(()))) + .collect() + }); + + let hash = crc32fast::hash(user.as_bytes()) as usize; + Arc::clone(&stripes[hash % stripes.len()]) +} + fn quota_user_lock(user: &str) -> Arc> { let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); if let Some(existing) = locks.get(user) { @@ -535,7 +573,7 @@ fn quota_user_lock(user: &str) -> Arc> { } if locks.len() >= QUOTA_USER_LOCKS_MAX { - return Arc::new(AsyncMutex::new(())); + return quota_overflow_user_lock(user); } let created = Arc::new(AsyncMutex::new(())); @@ -1518,6 +1556,31 @@ where } } +fn compute_intermediate_secure_wire_len( + data_len: usize, + padding_len: usize, + quickack: bool, +) -> Result<(u32, usize)> { + let wire_len = data_len + .checked_add(padding_len) + .ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?; + if wire_len > 0x7fff_ffffusize { + return Err(ProxyError::Proxy(format!( + "Intermediate/Secure frame too large: {wire_len}" + ))); + } + + let total = 4usize + .checked_add(wire_len) + .ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?; + let mut len_val = u32::try_from(wire_len) + .map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?; + if quickack { + len_val |= 0x8000_0000; + } + Ok((len_val, total)) +} + async fn write_client_payload( client_writer: &mut CryptoWriter, proto_tag: ProtoTag, @@ -1587,11 +1650,8 @@ where } else { 0 }; - let mut len_val = (data.len() + padding_len) as u32; - if quickack { - len_val |= 0x8000_0000; - } - let total = 4 + data.len() + padding_len; + let (len_val, total) = + compute_intermediate_secure_wire_len(data.len(), padding_len, quickack)?; frame_buf.clear(); frame_buf.reserve(total); frame_buf.extend_from_slice(&len_val.to_le_bytes()); @@ -1645,3 +1705,19 @@ mod desync_all_full_dedup_security_tests; #[cfg(test)] #[path = "tests/middle_relay_stub_completion_security_tests.rs"] mod stub_completion_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"] +mod coverage_high_risk_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"] +mod quota_overflow_lock_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"] +mod length_cast_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] +mod blackhat_campaign_integration_tests; diff --git a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs new file mode 100644 index 0000000..08f52d1 --- /dev/null +++ b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs @@ -0,0 +1,37 @@ +use super::*; + +#[test] +fn wrap_tls_application_record_empty_payload_emits_zero_length_record() { + let record = wrap_tls_application_record(&[]); + assert_eq!(record.len(), 5); + assert_eq!(record[0], TLS_RECORD_APPLICATION); + assert_eq!(&record[1..3], &TLS_VERSION); + assert_eq!(&record[3..5], &0u16.to_be_bytes()); +} + +#[test] +fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() { + let total = (u16::MAX as usize) + 37; + let payload = vec![0xA5u8; total]; + let record = wrap_tls_application_record(&payload); + + let mut offset = 0usize; + let mut recovered = Vec::with_capacity(total); + let mut frames = 0usize; + + while offset + 5 <= record.len() { + assert_eq!(record[offset], TLS_RECORD_APPLICATION); + assert_eq!(&record[offset + 1..offset + 3], &TLS_VERSION); + let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize; + let body_start = offset + 5; + let body_end = body_start + len; + assert!(body_end <= record.len(), "declared TLS record length must be in-bounds"); + recovered.extend_from_slice(&record[body_start..body_end]); + offset = body_end; + frames += 1; + } + + assert_eq!(offset, record.len(), "record parser must consume exact output size"); + assert_eq!(frames, 2, "oversized payload should split into exactly two records"); + assert_eq!(recovered, payload, "chunked records must preserve full payload"); +} diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs new file mode 100644 index 0000000..2c9f3f6 --- /dev/null +++ b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs @@ -0,0 +1,112 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use tokio::sync::Barrier; +use tokio::time::{Duration, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!( + "middle-blackhat-held-{}-{idx}", + std::process::id() + ))); + } + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "precondition: bounded lock cache must be saturated" + ); + + let (tx, _rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Close) + .await + .expect("queue prefill should succeed"); + + let pressure_seq_before = relay_pressure_event_seq(); + let pressure_errors = Arc::new(AtomicUsize::new(0)); + let mut pressure_workers = Vec::new(); + for _ in 0..16 { + let tx = tx.clone(); + let pressure_errors = Arc::clone(&pressure_errors); + pressure_workers.push(tokio::spawn(async move { + if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() { + pressure_errors.fetch_add(1, Ordering::Relaxed); + } + })); + } + + let stats = Arc::new(Stats::new()); + let user = format!("middle-blackhat-quota-race-{}", std::process::id()); + let gate = Arc::new(Barrier::new(16)); + + let mut quota_workers = Vec::new(); + for _ in 0..16u8 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let gate = Arc::clone(&gate); + quota_workers.push(tokio::spawn(async move { + gate.wait().await; + let user_lock = quota_user_lock(&user); + let _quota_guard = user_lock.lock().await; + + if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) { + return false; + } + stats.add_user_octets_to(&user, 1); + true + })); + } + + let mut ok_count = 0usize; + let mut denied_count = 0usize; + for worker in quota_workers { + let result = timeout(Duration::from_secs(2), worker) + .await + .expect("quota worker must finish") + .expect("quota worker must not panic"); + if result { + ok_count += 1; + } else { + denied_count += 1; + } + } + + for worker in pressure_workers { + timeout(Duration::from_secs(2), worker) + .await + .expect("pressure worker must finish") + .expect("pressure worker must not panic"); + } + + assert_eq!( + stats.get_user_total_octets(&user), + 1, + "black-hat campaign must not overshoot same-user quota under saturation" + ); + assert!(ok_count <= 1, "at most one quota contender may succeed"); + assert!( + denied_count >= 15, + "all remaining contenders must be quota-denied" + ); + + let pressure_seq_after = relay_pressure_event_seq(); + assert!( + pressure_seq_after > pressure_seq_before, + "queue pressure leg must trigger pressure accounting" + ); + assert!( + pressure_errors.load(Ordering::Relaxed) >= 1, + "at least one pressure worker should fail from persistent backpressure" + ); + + drop(retained); +} diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs new file mode 100644 index 0000000..fff26b4 --- /dev/null +++ b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs @@ -0,0 +1,708 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::crypto::SecureRandom; +use crate::stats::Stats; +use crate::stream::{BufferPool, PooledBuffer}; +use std::sync::Arc; +use tokio::io::AsyncReadExt; +use tokio::io::duplex; +use tokio::sync::mpsc; +use tokio::time::{Duration as TokioDuration, timeout}; + +fn make_pooled_payload(data: &[u8]) -> PooledBuffer { + let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +#[tokio::test] +async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() { + let (mut read_side, write_side) = duplex(4096); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40]; + + write_client_payload( + &mut writer, + ProtoTag::Abridged, + RPC_FLAG_QUICKACK, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("abridged quickack payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 1 + payload.len()]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read serialized abridged frame"); + let plaintext = decryptor.decrypt(&encrypted); + + assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8)); + assert_eq!(&plaintext[1..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_abridged_extended_header_is_encoded_correctly() { + let (mut read_side, write_side) = duplex(16 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + // Boundary where abridged switches to extended length encoding. + let payload = vec![0x5Au8; 0x7f * 4]; + + write_client_payload( + &mut writer, + ProtoTag::Abridged, + RPC_FLAG_QUICKACK, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("extended abridged payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 4 + payload.len()]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read serialized extended abridged frame"); + let plaintext = decryptor.decrypt(&encrypted); + + assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set"); + assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]); + assert_eq!(&plaintext[4..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() { + let (_read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + let err = write_client_payload( + &mut writer, + ProtoTag::Abridged, + 0, + &[1, 2, 3], + &rng, + &mut frame_buf, + ) + .await + .expect_err("misaligned abridged payload must be rejected"); + + let msg = format!("{err}"); + assert!( + msg.contains("4-byte aligned"), + "error should explain alignment contract, got: {msg}" + ); +} + +#[tokio::test] +async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() { + let (_read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + let err = write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &[9, 8, 7, 6, 5], + &rng, + &mut frame_buf, + ) + .await + .expect_err("misaligned secure payload must be rejected"); + + let msg = format!("{err}"); + assert!( + msg.contains("Secure payload must be 4-byte aligned"), + "error should be explicit for fail-closed triage, got: {msg}" + ); +} + +#[tokio::test] +async fn write_client_payload_intermediate_quickack_sets_length_msb() { + let (mut read_side, write_side) = duplex(4096); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = b"hello-middle-relay"; + + write_client_payload( + &mut writer, + ProtoTag::Intermediate, + RPC_FLAG_QUICKACK, + payload, + &rng, + &mut frame_buf, + ) + .await + .expect("intermediate quickack payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 4 + payload.len()]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read intermediate frame"); + let plaintext = decryptor.decrypt(&encrypted); + + let mut len_bytes = [0u8; 4]; + len_bytes.copy_from_slice(&plaintext[..4]); + let len_with_flags = u32::from_le_bytes(len_bytes); + assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set"); + assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len()); + assert_eq!(&plaintext[4..], payload); +} + +#[tokio::test] +async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() { + let (mut read_side, write_side) = duplex(4096); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode. + + write_client_payload( + &mut writer, + ProtoTag::Secure, + RPC_FLAG_QUICKACK, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("secure quickack payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + // Secure mode adds 1..=3 bytes of randomized tail padding. + let mut encrypted_header = [0u8; 4]; + read_side + .read_exact(&mut encrypted_header) + .await + .expect("must read secure header"); + let decrypted_header = decryptor.decrypt(&encrypted_header); + let header: [u8; 4] = decrypted_header + .try_into() + .expect("decrypted secure header must be 4 bytes"); + let wire_len_raw = u32::from_le_bytes(header); + + assert_ne!( + wire_len_raw & 0x8000_0000, + 0, + "secure quickack bit must be set" + ); + + let wire_len = (wire_len_raw & 0x7fff_ffff) as usize; + assert!(wire_len >= payload.len()); + let padding_len = wire_len - payload.len(); + assert!( + (1..=3).contains(&padding_len), + "secure writer must add bounded random tail padding, got {padding_len}" + ); + + let mut encrypted_body = vec![0u8; wire_len]; + read_side + .read_exact(&mut encrypted_body) + .await + .expect("must read secure body"); + let decrypted_body = decryptor.decrypt(&encrypted_body); + assert_eq!(&decrypted_body[..payload.len()], payload.as_slice()); +} + +#[tokio::test] +#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"] +async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() { + let (_read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + // Exactly one 4-byte word above the encodable 24-bit abridged length range. + let payload = vec![0x00u8; (1 << 24) * 4]; + let err = write_client_payload( + &mut writer, + ProtoTag::Abridged, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect_err("oversized abridged payload must be rejected"); + + let msg = format!("{err}"); + assert!( + msg.contains("Abridged frame too large"), + "error must clearly indicate oversize fail-close path, got: {msg}" + ); +} + +#[tokio::test] +async fn write_client_ack_intermediate_is_little_endian() { + let (mut read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + + write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44) + .await + .expect("ack serialization should succeed"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = [0u8; 4]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read ack bytes"); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes()); +} + +#[tokio::test] +async fn write_client_ack_abridged_is_big_endian() { + let (mut read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + + write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF) + .await + .expect("ack serialization should succeed"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = [0u8; 4]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read ack bytes"); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes()); +} + +#[tokio::test] +async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() { + let (mut read_side, write_side) = duplex(1024 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = vec![0xABu8; 0x7e * 4]; + + write_client_payload( + &mut writer, + ProtoTag::Abridged, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("boundary payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 1 + payload.len()]; + read_side.read_exact(&mut encrypted).await.unwrap(); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain[0], 0x7e); + assert_eq!(&plain[1..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() { + let (mut read_side, write_side) = duplex(16 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = vec![0x42u8; 0x80 * 4]; + + write_client_payload( + &mut writer, + ProtoTag::Abridged, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("extended payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 4 + payload.len()]; + read_side.read_exact(&mut encrypted).await.unwrap(); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain[0], 0x7f); + assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]); + assert_eq!(&plain[4..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_intermediate_zero_length_emits_header_only() { + let (mut read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + write_client_payload( + &mut writer, + ProtoTag::Intermediate, + 0, + &[], + &rng, + &mut frame_buf, + ) + .await + .expect("zero-length intermediate payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = [0u8; 4]; + read_side.read_exact(&mut encrypted).await.unwrap(); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain.as_slice(), &[0, 0, 0, 0]); +} + +#[tokio::test] +async fn write_client_payload_intermediate_ignores_unrelated_flags() { + let (mut read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = [7u8; 12]; + + write_client_payload( + &mut writer, + ProtoTag::Intermediate, + 0x4000_0000, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = [0u8; 16]; + read_side.read_exact(&mut encrypted).await.unwrap(); + let plain = decryptor.decrypt(&encrypted); + let len = u32::from_le_bytes(plain[0..4].try_into().unwrap()); + assert_eq!(len, payload.len() as u32, "only quickack bit may affect header"); + assert_eq!(&plain[4..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_secure_without_quickack_keeps_msb_clear() { + let (mut read_side, write_side) = duplex(4096); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = [0x1Du8; 64]; + + write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted_header = [0u8; 4]; + read_side.read_exact(&mut encrypted_header).await.unwrap(); + let plain_header = decryptor.decrypt(&encrypted_header); + let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); + let wire_len_raw = u32::from_le_bytes(h); + assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear"); +} + +#[tokio::test] +async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() { + let (mut read_side, write_side) = duplex(256 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = [0x55u8; 100]; + let mut seen = [false; 4]; + + for _ in 0..96 { + write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("secure payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted_header = [0u8; 4]; + read_side.read_exact(&mut encrypted_header).await.unwrap(); + let plain_header = decryptor.decrypt(&encrypted_header); + let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); + let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize; + let padding_len = wire_len - payload.len(); + assert!((1..=3).contains(&padding_len)); + seen[padding_len] = true; + + let mut encrypted_body = vec![0u8; wire_len]; + read_side.read_exact(&mut encrypted_body).await.unwrap(); + let _ = decryptor.decrypt(&encrypted_body); + } + + let distinct = (1..=3).filter(|idx| seen[*idx]).count(); + assert!( + distinct >= 2, + "padding generator should not collapse to a single outcome under campaign" + ); +} + +#[tokio::test] +async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() { + let (mut read_side, write_side) = duplex(128 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + let p1 = vec![1u8; 8]; + let p2 = vec![2u8; 16]; + let p3 = vec![3u8; 20]; + + write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf) + .await + .unwrap(); + write_client_payload( + &mut writer, + ProtoTag::Intermediate, + RPC_FLAG_QUICKACK, + &p2, + &rng, + &mut frame_buf, + ) + .await + .unwrap(); + write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf) + .await + .unwrap(); + writer.flush().await.unwrap(); + + // Frame 1: abridged short. + let mut e1 = vec![0u8; 1 + p1.len()]; + read_side.read_exact(&mut e1).await.unwrap(); + let d1 = decryptor.decrypt(&e1); + assert_eq!(d1[0], (p1.len() / 4) as u8); + assert_eq!(&d1[1..], p1.as_slice()); + + // Frame 2: intermediate with quickack. + let mut e2 = vec![0u8; 4 + p2.len()]; + read_side.read_exact(&mut e2).await.unwrap(); + let d2 = decryptor.decrypt(&e2); + let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap()); + assert_ne!(l2 & 0x8000_0000, 0); + assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len()); + assert_eq!(&d2[4..], p2.as_slice()); + + // Frame 3: secure with bounded tail. + let mut e3h = [0u8; 4]; + read_side.read_exact(&mut e3h).await.unwrap(); + let d3h = decryptor.decrypt(&e3h); + let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize; + assert!(l3 >= p3.len()); + assert!((1..=3).contains(&(l3 - p3.len()))); + let mut e3b = vec![0u8; l3]; + read_side.read_exact(&mut e3b).await.unwrap(); + let d3b = decryptor.decrypt(&e3b); + assert_eq!(&d3b[..p3.len()], p3.as_slice()); +} + +#[test] +fn should_yield_sender_boundary_matrix_blackhat() { + assert!(!should_yield_c2me_sender(0, false)); + assert!(!should_yield_c2me_sender(0, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); + assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); + assert!(should_yield_c2me_sender( + C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024), + true + )); +} + +#[test] +fn should_yield_sender_light_fuzz_matches_oracle() { + let mut s: u64 = 0xD00D_BAAD_F00D_CAFE; + for _ in 0..5000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let sent = (s as usize) & 0x1fff; + let backlog = (s & 1) != 0; + + let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET; + assert_eq!(should_yield_c2me_sender(sent, backlog), expected); + } +} + +#[test] +fn quota_would_be_exceeded_exact_remaining_one_byte() { + let stats = Stats::new(); + let user = "quota-edge"; + let quota = 100u64; + stats.add_user_octets_to(user, 99); + + assert!( + !quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1), + "exactly remaining budget should be allowed" + ); + assert!( + quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), + "one byte beyond remaining budget must be rejected" + ); +} + +#[test] +fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() { + let stats = Stats::new(); + let user = "quota-saturating-edge"; + let quota = u64::MAX - 3; + stats.add_user_octets_to(user, u64::MAX - 4); + + assert!( + quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), + "saturating arithmetic edge must stay fail-closed" + ); +} + +#[test] +fn quota_exceeded_boundary_is_inclusive() { + let stats = Stats::new(); + let user = "quota-inclusive-boundary"; + stats.add_user_octets_to(user, 50); + + assert!(quota_exceeded_for_user(&stats, user, Some(50))); + assert!(!quota_exceeded_for_user(&stats, user, Some(51))); +} + +#[tokio::test] +async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() { + let (tx, mut rx) = mpsc::channel::(4); + enqueue_c2me_command(&tx, C2MeCommand::Close) + .await + .expect("close should enqueue on fast path"); + + let recv = timeout(TokioDuration::from_millis(50), rx.recv()) + .await + .expect("must receive close command") + .expect("close command should be present"); + assert!(matches!(recv, C2MeCommand::Close)); +} + +#[tokio::test] +async fn enqueue_c2me_data_full_then_drain_preserves_order() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[1]), + flags: 10, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let producer = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: make_pooled_payload(&[2, 2]), + flags: 20, + }, + ) + .await + }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + + let first = rx.recv().await.expect("first item should exist"); + match first { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[1]); + assert_eq!(flags, 10); + } + C2MeCommand::Close => panic!("unexpected close as first item"), + } + + producer.await.unwrap().expect("producer should complete"); + + let second = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("second item should exist"); + match second { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[2, 2]); + assert_eq!(flags, 20); + } + C2MeCommand::Close => panic!("unexpected close as second item"), + } +} diff --git a/src/proxy/tests/middle_relay_length_cast_hardening_security_tests.rs b/src/proxy/tests/middle_relay_length_cast_hardening_security_tests.rs new file mode 100644 index 0000000..6c6644d --- /dev/null +++ b/src/proxy/tests/middle_relay_length_cast_hardening_security_tests.rs @@ -0,0 +1,75 @@ +use super::*; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; + +#[test] +fn intermediate_secure_wire_len_allows_max_31bit_payload() { + let (len_val, total) = compute_intermediate_secure_wire_len(0x7fff_fffe, 1, true) + .expect("31-bit wire length should be accepted"); + + assert_eq!(len_val, 0xffff_ffff, "quickack must use top bit only"); + assert_eq!(total, 0x8000_0003); +} + +#[test] +fn intermediate_secure_wire_len_rejects_length_above_31bit_limit() { + let err = compute_intermediate_secure_wire_len(0x7fff_ffff, 1, false) + .expect_err("wire length above 31-bit must fail closed"); + assert!( + format!("{err}").contains("frame too large"), + "error should identify oversize frame path" + ); +} + +#[test] +fn intermediate_secure_wire_len_rejects_addition_overflow() { + let err = compute_intermediate_secure_wire_len(usize::MAX, 1, false) + .expect_err("overflowing addition must fail closed"); + assert!( + format!("{err}").contains("overflow"), + "error should clearly report overflow" + ); +} + +#[test] +fn desync_forensics_len_bytes_marks_truncation_for_oversize_values() { + let (small_bytes, small_truncated) = desync_forensics_len_bytes(0x1020_3040); + assert_eq!(small_bytes, 0x1020_3040u32.to_le_bytes()); + assert!(!small_truncated); + + let (huge_bytes, huge_truncated) = desync_forensics_len_bytes(usize::MAX); + assert_eq!(huge_bytes, u32::MAX.to_le_bytes()); + assert!(huge_truncated); +} + +#[test] +fn report_desync_frame_too_large_preserves_full_length_in_error_message() { + let state = RelayForensicsState { + trace_id: 0x1234, + conn_id: 0x5678, + user: "middle-desync-oversize".to_string(), + peer: "198.51.100.55:443".parse().expect("valid test peer"), + peer_hash: 0xAABBCCDD, + started_at: Instant::now(), + bytes_c2me: 7, + bytes_me2c: Arc::new(AtomicU64::new(9)), + desync_all_full: false, + }; + + let huge_len = usize::MAX; + let err = report_desync_frame_too_large( + &state, + ProtoTag::Intermediate, + 3, + 1024, + huge_len, + None, + &Stats::new(), + ); + + let msg = format!("{err}"); + assert!( + msg.contains(&huge_len.to_string()), + "error must preserve full usize length for forensics" + ); +} diff --git a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs b/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs new file mode 100644 index 0000000..d06e103 --- /dev/null +++ b/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs @@ -0,0 +1,131 @@ +use super::*; +use dashmap::DashMap; +use std::sync::Arc; + +#[test] +fn saturation_uses_stable_overflow_lock_without_cache_growth() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let prefix = format!("middle-quota-held-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); + + let user = format!("middle-quota-overflow-{}", std::process::id()); + let first = quota_user_lock(&user); + let second = quota_user_lock(&user); + + assert!( + Arc::ptr_eq(&first, &second), + "overflow user must get deterministic same lock while cache is saturated" + ); + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow path must not grow bounded lock map" + ); + assert!( + map.get(&user).is_none(), + "overflow user should stay outside bounded lock map under saturation" + ); + + drop(retained); +} + +#[test] +fn overflow_striping_keeps_different_users_distributed() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let prefix = format!("middle-quota-dist-held-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + let a = quota_user_lock("middle-overflow-user-a"); + let b = quota_user_lock("middle-overflow-user-b"); + let c = quota_user_lock("middle-overflow-user-c"); + + let distinct = [ + Arc::as_ptr(&a) as usize, + Arc::as_ptr(&b) as usize, + Arc::as_ptr(&c) as usize, + ] + .iter() + .copied() + .collect::>() + .len(); + + assert!( + distinct >= 2, + "striped overflow lock set should avoid collapsing all users to one lock" + ); + + drop(retained); +} + +#[test] +fn reclaim_path_caches_new_user_after_stale_entries_drop() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let prefix = format!("middle-quota-reclaim-held-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + drop(retained); + + let user = format!("middle-quota-reclaim-user-{}", std::process::id()); + let got = quota_user_lock(&user); + assert!(map.get(&user).is_some()); + assert!( + Arc::strong_count(&got) >= 2, + "after reclaim, lock should be held both by caller and map" + ); +} + +#[test] +fn overflow_path_same_user_is_stable_across_parallel_threads() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!( + "middle-quota-thread-held-{}-{idx}", + std::process::id() + ))); + } + + let user = format!("middle-quota-overflow-thread-user-{}", std::process::id()); + let mut workers = Vec::new(); + for _ in 0..32 { + let user = user.clone(); + workers.push(std::thread::spawn(move || quota_user_lock(&user))); + } + + let first = workers + .remove(0) + .join() + .expect("thread must return lock handle"); + for worker in workers { + let got = worker.join().expect("thread must return lock handle"); + assert!( + Arc::ptr_eq(&first, &got), + "same overflow user should resolve to one striped lock even under contention" + ); + } + + drop(retained); +} diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs index 8b4f7f1..4ec20df 100644 --- a/src/proxy/tests/middle_relay_security_tests.rs +++ b/src/proxy/tests/middle_relay_security_tests.rs @@ -15,7 +15,7 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; -use std::sync::{Mutex, OnceLock}; +use std::sync::Mutex; use std::thread; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; @@ -38,11 +38,6 @@ fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer payload } -fn quota_user_lock_test_lock() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - #[test] fn should_yield_sender_only_on_budget_with_backlog() { assert!(!should_yield_c2me_sender(0, true)); @@ -250,9 +245,7 @@ fn quota_user_lock_cache_reuses_entry_for_same_user() { #[test] fn quota_user_lock_cache_is_bounded_under_unique_churn() { - let _guard = quota_user_lock_test_lock() - .lock() - .expect("quota user lock test lock must be available"); + let _guard = super::quota_user_lock_test_scope(); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); map.clear(); @@ -270,10 +263,8 @@ fn quota_user_lock_cache_is_bounded_under_unique_churn() { } #[test] -fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { - let _guard = quota_user_lock_test_lock() - .lock() - .expect("quota user lock test lock must be available"); +fn quota_user_lock_cache_saturation_returns_stable_overflow_lock_without_growth() { + let _guard = super::quota_user_lock_test_scope(); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); for attempt in 0..8u32 { @@ -305,8 +296,8 @@ fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { "overflow path should not cache new user lock when map is saturated and all entries are retained" ); assert!( - !Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user lock should be ephemeral under saturation to preserve bounded cache size" + Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user lock should use deterministic striping under saturation" ); drop(retained);