From 6f17d4d2316889359266e6e5e2e07dea27c1ecb4 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Mon, 23 Mar 2026 12:04:41 +0400 Subject: [PATCH] Add comprehensive security tests for quota management and relay functionality - Introduced `relay_dual_lock_race_harness_security_tests.rs` to validate user liveness during lock hold and release cycles. - Added `relay_quota_extended_attack_surface_security_tests.rs` to cover various quota scenarios including positive, negative, edge cases, and adversarial conditions. - Implemented `relay_quota_lock_eviction_lifecycle_tdd_tests.rs` to ensure proper eviction of stale entries and lifecycle management of quota locks. - Created `relay_quota_lock_eviction_stress_security_tests.rs` to stress test the eviction mechanism under high churn conditions. - Enhanced `relay_quota_lock_pressure_adversarial_tests.rs` to verify reclaiming of unreferenced entries after explicit eviction. - Developed `relay_quota_retry_allocation_latency_security_tests.rs` to benchmark and validate latency and allocation behavior under contention. --- Cargo.lock | 4 +- src/maestro/runtime_tasks.rs | 31 + src/proxy/handshake.rs | 7 +- src/proxy/masking.rs | 112 ++- src/proxy/middle_relay.rs | 125 ++- src/proxy/quota_lock_registry.rs | 37 +- src/proxy/relay.rs | 149 ++- src/proxy/tests/client_security_tests.rs | 2 +- ...auth_probe_eviction_bias_security_tests.rs | 93 ++ ...e_auth_probe_scan_budget_security_tests.rs | 21 +- ...ake_auth_probe_scan_offset_stress_tests.rs | 21 +- .../tests/handshake_more_clever_tests.rs | 2 +- ..._extended_attack_surface_security_tests.rs | 217 +++++ ...erface_cache_concurrency_security_tests.rs | 41 + .../masking_interface_cache_security_tests.rs | 14 +- ...roduction_cap_regression_security_tests.rs | 289 ++++++ ...masking_self_target_loop_security_tests.rs | 54 +- ...g_timing_budget_coupling_security_tests.rs | 55 ++ ...relay_coverage_high_risk_security_tests.rs | 69 ++ ..._lock_release_regression_security_tests.rs | 295 ++++++ ...s_mode_lookup_efficiency_security_tests.rs | 116 +++ ...s_mode_quota_lock_matrix_security_tests.rs | 376 ++++++++ ...s_mode_quota_reservation_security_tests.rs | 254 +++++ .../middle_relay_hol_quota_security_tests.rs | 3 + ..._extended_attack_surface_security_tests.rs | 372 ++++++++ ...lay_quota_reservation_adversarial_tests.rs | 874 ++++++++++++++++++ ...uota_reservation_extreme_security_tests.rs | 399 ++++++++ ...y_frame_debt_concurrency_security_tests.rs | 34 +- ...rame_debt_proto_chunking_security_tests.rs | 59 +- ...le_relay_tiny_frame_debt_security_tests.rs | 282 +++++- ...pipeline_hol_integration_security_tests.rs | 267 ++++++ ...peline_latency_benchmark_security_tests.rs | 213 +++++ ...lay_cross_mode_quota_fairness_tdd_tests.rs | 381 +++++++- ...k_alternating_contention_security_tests.rs | 340 +++++++ ..._lock_backoff_regression_security_tests.rs | 74 ++ ...l_lock_contention_matrix_security_tests.rs | 325 +++++++ ...y_dual_lock_race_harness_security_tests.rs | 128 +++ ..._extended_attack_surface_security_tests.rs | 332 +++++++ ...quota_lock_eviction_lifecycle_tdd_tests.rs | 79 ++ ...ota_lock_eviction_stress_security_tests.rs | 153 +++ ...y_quota_lock_pressure_adversarial_tests.rs | 4 +- ...retry_allocation_latency_security_tests.rs | 249 +++++ 42 files changed, 6774 insertions(+), 178 deletions(-) create mode 100644 src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs create mode 100644 src/proxy/tests/masking_extended_attack_surface_security_tests.rs create mode 100644 src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs create mode 100644 src/proxy/tests/masking_production_cap_regression_security_tests.rs create mode 100644 src/proxy/tests/masking_timing_budget_coupling_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs create mode 100644 src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs create mode 100644 src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs create mode 100644 src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs create mode 100644 src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs create mode 100644 src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs create mode 100644 src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs create mode 100644 src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs diff --git a/Cargo.lock b/Cargo.lock index c4cde39..92da630 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1454,9 +1454,9 @@ dependencies = [ [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb" dependencies = [ "memchr", "serde", diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index d553eb9..066c853 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -32,6 +32,14 @@ pub(crate) struct RuntimeWatches { pub(crate) detected_ip_v6: Option, } +const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60; + +fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> { + crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs( + QUOTA_USER_LOCK_EVICT_INTERVAL_SECS, + )) +} + #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, @@ -69,6 +77,8 @@ pub(crate) async fn spawn_runtime_tasks( rc_clone.run_periodic_cleanup().await; }); + spawn_quota_lock_maintenance_task(); + let detected_ip_v4: Option = probe.detected_ipv4.map(IpAddr::V4); let detected_ip_v6: Option = probe.detected_ipv6.map(IpAddr::V6); debug!( @@ -360,3 +370,24 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc) { .await; startup_tracker.mark_ready().await; } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() { + crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests(); + + let handle = spawn_quota_lock_maintenance_task(); + tokio::time::sleep(std::time::Duration::from_millis(5)).await; + + assert_eq!( + crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(), + 1, + "runtime maintenance path must spawn exactly one quota lock evictor task per call" + ); + + handle.abort(); + } +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 3444a88..96994c7 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -131,8 +131,7 @@ fn auth_probe_scan_start_offset( return 0; } - let window = state_len.min(scan_limit); - auth_probe_eviction_offset(peer_ip, now) % window + auth_probe_eviction_offset(peer_ip, now) % state_len } fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { @@ -997,6 +996,10 @@ mod auth_probe_scan_budget_security_tests; #[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"] mod auth_probe_scan_offset_stress_tests; +#[cfg(test)] +#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"] +mod auth_probe_eviction_bias_security_tests; + #[cfg(test)] #[path = "tests/handshake_advanced_clever_tests.rs"] mod advanced_clever_tests; diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 7d970c2..841749c 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -19,6 +19,8 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; +#[cfg(unix)] +use tokio::sync::Mutex as AsyncMutex; use tokio::time::{Instant, timeout}; use tracing::debug; @@ -95,10 +97,6 @@ where Ok(Ok(())) => {} Ok(Err(_)) | Err(_) => break, } - - if total >= byte_cap { - break; - } } CopyOutcome { total, @@ -370,6 +368,9 @@ struct LocalInterfaceCache { static LOCAL_INTERFACE_CACHE: OnceLock> = OnceLock::new(); #[cfg(unix)] +static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock> = OnceLock::new(); + +#[cfg(all(unix, test))] fn local_interface_ips() -> Vec { let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default())); let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); @@ -386,11 +387,59 @@ fn local_interface_ips() -> Vec { guard.ips.clone() } -#[cfg(not(unix))] +#[cfg(unix)] +async fn local_interface_ips_async() -> Vec { + let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default())); + + { + let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if !stale { + return guard.ips.clone(); + } + } + + let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(())); + let _refresh_guard = refresh_lock.lock().await; + + { + let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if !stale { + return guard.ips.clone(); + } + } + + let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips) + .await + .unwrap_or_default(); + + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if stale { + guard.ips = choose_interface_snapshot(&guard.ips, refreshed); + guard.refreshed_at = Some(StdInstant::now()); + } + + guard.ips.clone() +} + +#[cfg(all(not(unix), test))] fn local_interface_ips() -> Vec { Vec::new() } +#[cfg(not(unix))] +async fn local_interface_ips_async() -> Vec { + Vec::new() +} + #[cfg(test)] static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0); @@ -457,6 +506,7 @@ fn is_mask_target_local_listener_with_interfaces( false } +#[cfg(test)] fn is_mask_target_local_listener( mask_host: &str, mask_port: u16, @@ -477,6 +527,26 @@ fn is_mask_target_local_listener( ) } +async fn is_mask_target_local_listener_async( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let interfaces = local_interface_ips_async().await; + is_mask_target_local_listener_with_interfaces( + mask_host, + mask_port, + local_addr, + resolved_override, + &interfaces, + ) +} + fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration { let minutes = config.general.beobachten_minutes; let clamped = minutes.clamp(1, 24 * 60); @@ -608,13 +678,15 @@ pub async fn handle_bad_client( .as_deref() .unwrap_or(&config.censorship.tls_domain); let mask_port = config.censorship.mask_port; - let outcome_started = Instant::now(); // Fail closed when fallback points at our own listener endpoint. // Self-referential masking can create recursive proxy loops under // misconfiguration and leak distinguishable load spikes to adversaries. let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port); - if is_mask_target_local_listener(mask_host, mask_port, local_addr, resolved_mask_addr) { + if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr) + .await + { + let outcome_started = Instant::now(); debug!( client_type = client_type, host = %mask_host, @@ -627,6 +699,8 @@ pub async fn handle_bad_client( return; } + let outcome_started = Instant::now(); + debug!( client_type = client_type, host = %mask_host, @@ -768,7 +842,13 @@ async fn consume_client_data(mut reader: R, byte_cap: usiz let mut total = 0usize; loop { - let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await { + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, }; @@ -804,6 +884,10 @@ mod masking_shape_above_cap_blur_security_tests; #[path = "tests/masking_timing_normalization_security_tests.rs"] mod masking_timing_normalization_security_tests; +#[cfg(test)] +#[path = "tests/masking_timing_budget_coupling_security_tests.rs"] +mod masking_timing_budget_coupling_security_tests; + #[cfg(test)] #[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"] mod masking_ab_envelope_blur_integration_security_tests; @@ -884,6 +968,18 @@ mod masking_interface_cache_security_tests; #[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"] mod masking_interface_cache_defense_in_depth_security_tests; +#[cfg(test)] +#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"] +mod masking_interface_cache_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/masking_production_cap_regression_security_tests.rs"] +mod masking_production_cap_regression_security_tests; + +#[cfg(test)] +#[path = "tests/masking_extended_attack_surface_security_tests.rs"] +mod masking_extended_attack_surface_security_tests; + #[cfg(test)] #[path = "tests/masking_padding_timeout_adversarial_tests.rs"] mod masking_padding_timeout_adversarial_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 0d2a748..b6b198c 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,5 +1,7 @@ use std::collections::hash_map::RandomState; use std::collections::{BTreeSet, HashMap}; +#[cfg(test)] +use std::future::Future; use std::hash::{BuildHasher, Hash}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; @@ -45,6 +47,8 @@ const TINY_FRAME_DEBT_LIMIT: u32 = 512; const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); #[cfg(not(test))] const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1); const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; @@ -561,11 +565,8 @@ fn quota_would_be_exceeded_for_user_soft( bytes: u64, overshoot: u64, ) -> bool { - quota_limit.is_some_and(|quota| { - let cap = quota_soft_cap(quota, overshoot); - let used = stats.get_user_total_octets(user); - used >= cap || bytes > cap.saturating_sub(used) - }) + let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot)); + quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes) } fn classify_me_d2c_flush_reason( @@ -683,7 +684,7 @@ fn quota_user_lock(user: &str) -> Arc> { } #[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { +pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) } @@ -712,6 +713,16 @@ async fn enqueue_c2me_command( } } +#[cfg(test)] +async fn run_relay_test_step_timeout(context: &'static str, fut: F) -> T +where + F: Future, +{ + timeout(RELAY_TEST_STEP_TIMEOUT, fut) + .await + .unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs())) +} + pub(crate) async fn handle_via_middle_proxy( mut crypto_reader: CryptoReader, crypto_writer: CryptoWriter, @@ -860,6 +871,7 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); + let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); @@ -881,7 +893,7 @@ where let first_is_downstream_activity = matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_cross_mode_lock( first, &mut writer, proto_tag, @@ -891,6 +903,7 @@ where &user_clone, quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -939,7 +952,7 @@ where let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_cross_mode_lock( next, &mut writer, proto_tag, @@ -949,6 +962,7 @@ where &user_clone, quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1000,7 +1014,7 @@ where Ok(Some(next)) => { let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_cross_mode_lock( next, &mut writer, proto_tag, @@ -1010,6 +1024,7 @@ where &user_clone, quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1063,7 +1078,7 @@ where let extra_is_downstream_activity = matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_cross_mode_lock( extra, &mut writer, proto_tag, @@ -1073,6 +1088,7 @@ where &user_clone, quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1252,10 +1268,7 @@ where )); break; }; - let _cross_mode_quota_guard = match cross_mode_lock.lock() { - Ok(guard) => guard, - Err(poisoned) => poisoned.into_inner(), - }; + let _cross_mode_quota_guard = cross_mode_lock.lock().await; stats.add_user_octets_from(&user, payload.len() as u64); if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { main_result = Err(ProxyError::DataQuotaExceeded { @@ -1741,6 +1754,7 @@ enum MeWriterResponseOutcome { Close, } +#[cfg(test)] async fn process_me_writer_response( response: MeResponse, client_writer: &mut CryptoWriter, @@ -1756,6 +1770,44 @@ async fn process_me_writer_response( ack_flush_immediate: bool, batched: bool, ) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + process_me_writer_response_with_cross_mode_lock( + response, + client_writer, + proto_tag, + rng, + frame_buf, + stats, + user, + quota_limit, + quota_soft_overshoot_bytes, + None, + bytes_me2c, + conn_id, + ack_flush_immediate, + batched, + ) + .await +} + +async fn process_me_writer_response_with_cross_mode_lock( + response: MeResponse, + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + rng: &SecureRandom, + frame_buf: &mut Vec, + stats: &Stats, + user: &str, + quota_limit: Option, + quota_soft_overshoot_bytes: u64, + cross_mode_quota_lock: Option<&Arc>>, + bytes_me2c: &AtomicU64, + conn_id: u64, + ack_flush_immediate: bool, + batched: bool, +) -> Result where W: AsyncWrite + Unpin + Send + 'static, { @@ -1768,8 +1820,23 @@ where } let data_len = data.len() as u64; if let Some(limit) = quota_limit { + let owned_cross_mode_lock; + let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock { + lock + } else { + owned_cross_mode_lock = + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user); + &owned_cross_mode_lock + }; + let cross_mode_quota_guard = cross_mode_lock.lock().await; let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); - if quota_would_be_exceeded_for_user(stats, user, Some(soft_limit), data_len) { + if quota_would_be_exceeded_for_user_soft( + stats, + user, + Some(limit), + data_len, + quota_soft_overshoot_bytes, + ) { stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1789,6 +1856,10 @@ where }); } + // Keep cross-mode lock scope explicit and minimal: quota reservation is serialized, + // but socket I/O proceeds without holding same-user cross-mode admission lock. + drop(cross_mode_quota_guard); + let write_mode = match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) .await @@ -2084,3 +2155,27 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests; #[cfg(test)] #[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"] +mod middle_relay_cross_mode_quota_reservation_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"] +mod middle_relay_cross_mode_quota_lock_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"] +mod middle_relay_cross_mode_lookup_efficiency_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"] +mod middle_relay_cross_mode_lock_release_regression_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"] +mod middle_relay_quota_extended_attack_surface_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"] +mod middle_relay_quota_reservation_extreme_security_tests; diff --git a/src/proxy/quota_lock_registry.rs b/src/proxy/quota_lock_registry.rs index ac64a57..7798b09 100644 --- a/src/proxy/quota_lock_registry.rs +++ b/src/proxy/quota_lock_registry.rs @@ -1,5 +1,9 @@ use dashmap::DashMap; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::{Arc, OnceLock}; +use tokio::sync::Mutex; + +#[cfg(test)] +use std::sync::atomic::{AtomicUsize, Ordering}; #[cfg(test)] const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64; @@ -13,6 +17,11 @@ const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); +#[cfg(test)] +static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0); +#[cfg(test)] +static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock> = OnceLock::new(); + fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES) @@ -25,6 +34,14 @@ fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { } pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { + #[cfg(test)] + { + CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed); + let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); + let mut entry = lookups.entry(user.to_string()).or_insert(0); + *entry += 1; + } + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); if let Some(existing) = locks.get(user) { return Arc::clone(existing.value()); @@ -48,6 +65,24 @@ pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { } } +#[cfg(test)] +pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() { + CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed); + let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); + lookups.clear(); +} + +#[cfg(test)] +pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize { + CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed) +} + +#[cfg(test)] +pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize { + let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); + lookups.get(user).map(|entry| *entry).unwrap_or(0) +} + #[cfg(test)] #[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"] mod quota_lock_registry_cross_mode_adversarial_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index dcacedd..55f1385 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -62,6 +62,7 @@ use std::sync::{Arc, Mutex, OnceLock}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; +use tokio::sync::Mutex as AsyncMutex; use tokio::time::{Instant, Sleep}; use tracing::{debug, trace, warn}; @@ -210,7 +211,7 @@ struct StatsIo { stats: Arc, user: String, quota_lock: Option>>, - cross_mode_quota_lock: Option>>, + cross_mode_quota_lock: Option>>, quota_limit: Option, quota_exceeded: Arc, quota_read_wake_scheduled: bool, @@ -289,6 +290,21 @@ const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16); #[cfg(not(test))] const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64); +#[cfg(test)] +static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0); +#[cfg(test)] +static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0); + +#[cfg(test)] +pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() { + QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed); +} + +#[cfg(test)] +pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 { + QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed) +} + #[inline] fn quota_contention_retry_delay(retry_attempt: u8) -> Duration { let shift = u32::from(retry_attempt.min(5)); @@ -317,6 +333,8 @@ fn poll_quota_retry_sleep( ) { if !*wake_scheduled { *wake_scheduled = true; + #[cfg(test)] + QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed); *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay( *retry_attempt, )))); @@ -368,16 +386,47 @@ fn quota_overflow_user_lock(user: &str) -> Arc> { Arc::clone(&stripes[hash % stripes.len()]) } +pub(crate) fn quota_user_lock_evict() { + if let Some(locks) = QUOTA_USER_LOCKS.get() { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } +} + +pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> { + let interval = interval.max(Duration::from_millis(1)); + #[cfg(test)] + QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed); + tokio::spawn(async move { + loop { + tokio::time::sleep(interval).await; + quota_user_lock_evict(); + } + }) +} + +#[cfg(test)] +pub(crate) fn spawn_quota_user_lock_evictor_for_tests( + interval: Duration, +) -> tokio::task::JoinHandle<()> { + spawn_quota_user_lock_evictor(interval) +} + +#[cfg(test)] +pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() { + QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed); +} + +#[cfg(test)] +pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 { + QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed) +} + fn quota_user_lock(user: &str) -> Arc> { let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); if let Some(existing) = locks.get(user) { return Arc::clone(existing.value()); } - if locks.len() >= QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - if locks.len() >= QUOTA_USER_LOCKS_MAX { return quota_overflow_user_lock(user); } @@ -393,7 +442,7 @@ fn quota_user_lock(user: &str) -> Arc> { } #[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { +pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) } @@ -410,14 +459,7 @@ impl AsyncRead for StatsIo { let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => { - reset_quota_retry_scheduler( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - ); - Some(guard) - } + Ok(guard) => Some(guard), Err(_) => { poll_quota_retry_sleep( &mut this.quota_read_retry_sleep, @@ -434,14 +476,7 @@ impl AsyncRead for StatsIo { let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => { - reset_quota_retry_scheduler( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - ); - Some(guard) - } + Ok(guard) => Some(guard), Err(_) => { poll_quota_retry_sleep( &mut this.quota_read_retry_sleep, @@ -456,6 +491,12 @@ impl AsyncRead for StatsIo { None }; + reset_quota_retry_scheduler( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + ); + if let Some(limit) = this.quota_limit && this.stats.get_user_total_octets(&this.user) >= limit { @@ -523,14 +564,7 @@ impl AsyncWrite for StatsIo { let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => { - reset_quota_retry_scheduler( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - ); - Some(guard) - } + Ok(guard) => Some(guard), Err(_) => { poll_quota_retry_sleep( &mut this.quota_write_retry_sleep, @@ -547,14 +581,7 @@ impl AsyncWrite for StatsIo { let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => { - reset_quota_retry_scheduler( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - ); - Some(guard) - } + Ok(guard) => Some(guard), Err(_) => { poll_quota_retry_sleep( &mut this.quota_write_retry_sleep, @@ -569,6 +596,12 @@ impl AsyncWrite for StatsIo { None }; + reset_quota_retry_scheduler( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + ); + let write_buf = if let Some(limit) = this.quota_limit { let used = this.stats.get_user_total_octets(&this.user); if used >= limit { @@ -861,6 +894,10 @@ mod relay_quota_model_adversarial_tests; #[path = "tests/relay_quota_overflow_regression_tests.rs"] mod relay_quota_overflow_regression_tests; +#[cfg(test)] +#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"] +mod relay_quota_extended_attack_surface_security_tests; + #[cfg(test)] #[path = "tests/relay_watchdog_delta_security_tests.rs"] mod relay_watchdog_delta_security_tests; @@ -889,6 +926,14 @@ mod relay_quota_retry_scheduler_tdd_tests; #[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"] mod relay_cross_mode_quota_fairness_tdd_tests; +#[cfg(test)] +#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"] +mod relay_cross_mode_pipeline_hol_integration_security_tests; + +#[cfg(test)] +#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"] +mod relay_cross_mode_pipeline_latency_benchmark_security_tests; + #[cfg(test)] #[path = "tests/relay_quota_retry_backoff_security_tests.rs"] mod relay_quota_retry_backoff_security_tests; @@ -896,3 +941,31 @@ mod relay_quota_retry_backoff_security_tests; #[cfg(test)] #[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"] mod relay_quota_retry_backoff_benchmark_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"] +mod relay_dual_lock_backoff_regression_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"] +mod relay_dual_lock_contention_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"] +mod relay_dual_lock_race_harness_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"] +mod relay_dual_lock_alternating_contention_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"] +mod relay_quota_retry_allocation_latency_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"] +mod relay_quota_lock_eviction_lifecycle_tdd_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"] +mod relay_quota_lock_eviction_stress_security_tests; diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 35f517a..2b1fae6 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use rand::rngs::StdRng; -use rand::RngCore; +use rand::Rng; use rand::SeedableRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; diff --git a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs new file mode 100644 index 0000000..6c48cc1 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs @@ -0,0 +1,93 @@ +use super::*; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn adversarial_large_state_offsets_escape_first_scan_window() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut saw_offset_outside_first_window = false; + for i in 0..8_192u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(131)) & 0xff) as u8, + )); + let now = base + Duration::from_nanos(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + if start >= scan_limit { + saw_offset_outside_first_window = true; + break; + } + } + + assert!( + saw_offset_outside_first_window, + "scan start offset must cover the full auth-probe state, not only the first scan window" + ); +} + +#[test] +fn stress_large_state_offsets_cover_many_scan_windows() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut covered_windows = HashSet::new(); + for i in 0..16_384u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(17)) & 0xff) as u8, + )); + let now = base + Duration::from_micros(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + covered_windows.insert(start / scan_limit); + } + + assert!( + covered_windows.len() >= 16, + "eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}", + covered_windows.len(), + state_len / scan_limit + ); +} + +#[test] +fn light_fuzz_offset_always_stays_inside_state_len() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xC0FF_EE12_3456_789Au64; + let base = Instant::now(); + + for _ in 0..8_192usize { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x0fff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + + assert!(start < state_len, "scan offset must stay inside state length"); + } +} \ No newline at end of file diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs index c5e57d7..ece6ff5 100644 --- a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -22,12 +22,13 @@ fn edge_zero_state_len_yields_zero_start_offset() { } #[test] -fn adversarial_large_state_must_bound_start_offset_to_scan_budget() { +fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() { let _guard = auth_probe_test_guard(); let base = Instant::now(); let scan_limit = 16usize; let state_len = 65_536usize; + let mut saw_offset_outside_window = false; for i in 0..2048u32 { let ip = IpAddr::V4(Ipv4Addr::new( 203, @@ -38,10 +39,19 @@ fn adversarial_large_state_must_bound_start_offset_to_scan_budget() { let now = base + Duration::from_micros(i as u64); let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); assert!( - start < scan_limit, - "start offset must stay within scan window; start={start}, limit={scan_limit}" + start < state_len, + "start offset must stay within state length; start={start}, len={state_len}" ); + if start >= scan_limit { + saw_offset_outside_window = true; + break; + } } + + assert!( + saw_offset_outside_window, + "large-state eviction must sample beyond the first scan window" + ); } #[test] @@ -80,11 +90,10 @@ fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() { let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1); let now = base + Duration::from_nanos(seed & 0xffff); let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); - let effective_window = state_len.min(scan_limit); assert!( - start < effective_window, - "scan offset must stay inside effective window" + start < state_len, + "scan offset must stay inside state length" ); } } \ No newline at end of file diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs index cdaf498..260a1b9 100644 --- a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs @@ -22,10 +22,10 @@ fn positive_same_ip_moving_time_yields_diverse_scan_offsets() { uniq.insert(offset); } - assert_eq!( - uniq.len(), - 16, - "offset randomization must cover the entire scan window over 512 samples" + assert!( + uniq.len() >= 256, + "offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})", + uniq.len() ); } @@ -45,10 +45,10 @@ fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() { uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16)); } - assert_eq!( - uniq.len(), - 16, - "scan offset distribution collapsed unexpectedly across peer set" + assert!( + uniq.len() >= 512, + "scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})", + uniq.len() ); } @@ -108,6 +108,9 @@ fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() { let now = base + Duration::from_nanos(seed & 0x1fff); let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); - assert!(offset < state_len.min(scan_limit)); + assert!( + offset < state_len, + "scan offset must always remain inside state length" + ); } } \ No newline at end of file diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs index b3da4df..77df442 100644 --- a/src/proxy/tests/handshake_more_clever_tests.rs +++ b/src/proxy/tests/handshake_more_clever_tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::crypto::{sha256, sha256_hmac, AesCtr}; use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; -use rand::{RngExt, SeedableRng}; +use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; diff --git a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..040f567 --- /dev/null +++ b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs @@ -0,0 +1,217 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +fn make_self_target_config( + timing_normalization_enabled: bool, + floor_ms: u64, + ceiling_ms: u64, + beobachten_enabled: bool, +) -> ProxyConfig { + let mut config = ProxyConfig::default(); + config.general.beobachten = beobachten_enabled; + config.general.beobachten_minutes = 5; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled; + config.censorship.mask_timing_normalization_floor_ms = floor_ms; + config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms; + config +} + +async fn run_self_target_refusal( + config: ProxyConfig, + peer: SocketAddr, + initial: &'static [u8], +) -> Duration { + let beobachten = BeobachtenStore::new(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten) + .await; + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + timeout(Duration::from_secs(3), task) + .await + .expect("self-target refusal must complete in bounded time") + .expect("self-target refusal task must not panic"); + + started.elapsed() +} + +#[tokio::test] +async fn positive_self_target_refusal_honors_normalization_floor() { + let config = make_self_target_config(true, 120, 120, false); + let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260), + "normalized self-target refusal must stay within expected envelope" + ); +} + +#[tokio::test] +async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() { + let config = make_self_target_config(false, 240, 240, false); + let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(180), + "non-normalized path must not inherit normalization floor delays" + ); +} + +#[tokio::test] +async fn edge_ceiling_below_floor_uses_floor_fail_closed() { + let config = make_self_target_config(true, 140, 80, false); + let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280), + "ceiling max { + max = elapsed; + } + assert!( + elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320), + "parallel probe latency must stay bounded under normalization" + ); + } + + assert!( + max.saturating_sub(min) <= Duration::from_millis(130), + "normalization should limit path variance across adversarial parallel probes" + ); +} + +#[tokio::test] +async fn integration_beobachten_records_probe_classification_on_refusal() { + let config = make_self_target_config(false, 0, 0, true); + let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer"); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + beobachten.snapshot_text(Duration::from_secs(60)) + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + let snapshot = timeout(Duration::from_secs(3), task) + .await + .expect("integration task must complete") + .expect("integration task must not panic"); + + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.71-1")); +} + +#[tokio::test] +async fn light_fuzz_timing_configuration_matrix_is_bounded() { + let mut seed = 0xA17E_55AA_2026_0323u64; + + for case in 0..48u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let enabled = (seed & 1) == 0; + let floor = (seed >> 8) % 180; + let ceiling = (seed >> 24) % 180; + let config = make_self_target_config(enabled, floor, ceiling, false); + let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16)) + .parse() + .expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(420), + "fuzz case must stay bounded and never hang" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() { + let workers = 64usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let config = make_self_target_config(false, 0, 0, false); + let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16) + .parse() + .expect("valid peer"); + run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await + })); + } + + timeout(Duration::from_secs(5), async { + for task in tasks { + let elapsed = task.await.expect("stress task must not panic"); + assert!( + elapsed < Duration::from_millis(260), + "stress refusal must remain bounded without normalization" + ); + } + }) + .await + .expect("high-fanout refusal workload must complete without deadlock"); +} \ No newline at end of file diff --git a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs new file mode 100644 index 0000000..8d99b8f --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs @@ -0,0 +1,41 @@ +#![cfg(unix)] + +use super::*; +use std::sync::{Mutex, OnceLock}; +use tokio::sync::Barrier; + +fn interface_cache_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let workers = 32usize; + let barrier = std::sync::Arc::new(Barrier::new(workers)); + let mut tasks = Vec::with_capacity(workers); + + for _ in 0..workers { + let barrier = std::sync::Arc::clone(&barrier); + tasks.push(tokio::spawn(async move { + barrier.wait().await; + is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await + })); + } + + for task in tasks { + let _ = task.await.expect("parallel cache task must not panic"); + } + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "parallel cold misses must coalesce into a single interface enumeration" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs index b14d7c3..6be99d0 100644 --- a/src/proxy/tests/masking_interface_cache_security_tests.rs +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -8,8 +8,8 @@ fn interface_cache_test_lock() -> &'static Mutex<()> { LOCK.get_or_init(|| Mutex::new(())) } -#[test] -fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { +#[tokio::test] +async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { let _guard = interface_cache_test_lock() .lock() .unwrap_or_else(|poison| poison.into_inner()); @@ -17,8 +17,8 @@ fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); - let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None); - let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None); + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; assert_eq!( local_interface_enumerations_for_tests(), @@ -27,15 +27,15 @@ fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within ); } -#[test] -fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { +#[tokio::test] +async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { let _guard = interface_cache_test_lock() .lock() .unwrap_or_else(|poison| poison.into_inner()); reset_local_interface_enumerations_for_tests(); let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); - let is_local = is_mask_target_local_listener("127.0.0.1", 8443, local_addr, None); + let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; assert!(!is_local, "different port must not be treated as local listener"); assert_eq!( diff --git a/src/proxy/tests/masking_production_cap_regression_security_tests.rs b/src/proxy/tests/masking_production_cap_regression_security_tests.rs new file mode 100644 index 0000000..f2368a1 --- /dev/null +++ b/src/proxy/tests/masking_production_cap_regression_security_tests.rs @@ -0,0 +1,289 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::{Duration, Instant, timeout}; + +const PROD_CAP_BYTES: usize = 5 * 1024 * 1024; + +struct FinitePatternReader { + remaining: usize, + chunk: usize, + read_calls: Arc, +} + +impl FinitePatternReader { + fn new(total: usize, chunk: usize, read_calls: Arc) -> Self { + Self { + remaining: total, + chunk, + read_calls, + } + } +} + +impl AsyncRead for FinitePatternReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + self.read_calls.fetch_add(1, Ordering::Relaxed); + + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(self.chunk).min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + let fill = vec![0x5Au8; take]; + buf.put_slice(&fill); + self.remaining -= take; + Poll::Ready(Ok(())) + } +} + +#[derive(Default)] +struct CountingWriter { + written: usize, +} + +impl AsyncWrite for CountingWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.written = self.written.saturating_add(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct NeverReadyReader; + +impl AsyncRead for NeverReadyReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Poll::Pending + } +} + +struct BudgetProbeReader { + remaining: usize, + total_read: Arc, +} + +impl BudgetProbeReader { + fn new(total: usize, total_read: Arc) -> Self { + Self { + remaining: total, + total_read, + } + } +} + +impl AsyncRead for BudgetProbeReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + let fill = vec![0xA5u8; take]; + buf.put_slice(&fill); + self.remaining -= take; + self.total_read.fetch_add(take, Ordering::Relaxed); + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn positive_copy_with_production_cap_stops_exactly_at_budget() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await; + + assert_eq!( + outcome.total, PROD_CAP_BYTES, + "copy path must stop at explicit production cap" + ); + assert_eq!(writer.written, PROD_CAP_BYTES); + assert!( + !outcome.ended_by_eof, + "byte-cap stop must not be misclassified as EOF" + ); +} + +#[tokio::test] +async fn negative_consume_with_zero_cap_performs_no_reads() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls)); + + consume_client_data_with_timeout_and_cap(reader, 0).await; + + assert_eq!( + read_calls.load(Ordering::Relaxed), + 0, + "zero cap must return before reading attacker-controlled bytes" + ); +} + +#[tokio::test] +async fn edge_copy_below_cap_reports_eof_without_overread() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let payload = 73 * 1024; + let mut reader = FinitePatternReader::new(payload, 3072, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await; + + assert_eq!(outcome.total, payload); + assert_eq!(writer.written, payload); + assert!( + outcome.ended_by_eof, + "finite upstream below cap must terminate via EOF path" + ); +} + +#[tokio::test] +async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() { + let started = Instant::now(); + + consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await; + + assert!( + started.elapsed() < Duration::from_millis(350), + "never-ready reader must be bounded by idle/relay timeout protections" + ); +} + +#[tokio::test] +async fn integration_consume_path_honors_production_cap_for_large_payload() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls); + + let bounded = timeout( + Duration::from_millis(350), + consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES), + ) + .await; + + assert!( + bounded.is_ok(), + "consume path with production cap must finish within bounded time" + ); +} + +#[tokio::test] +async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() { + let byte_cap = 5usize; + let total_read = Arc::new(AtomicUsize::new(0)); + let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read)); + + consume_client_data_with_timeout_and_cap(reader, byte_cap).await; + + assert!( + total_read.load(Ordering::Relaxed) <= byte_cap, + "consume path must not read more than configured byte cap" + ); +} + +#[tokio::test] +async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() { + let mut seed = 0x1234_5678_9ABC_DEF0u64; + + for _case in 0..96u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let cap = ((seed & 0x3ffff) as usize).saturating_add(1); + let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1); + let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1); + + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new(payload, chunk, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await; + let expected = payload.min(cap); + + assert_eq!( + outcome.total, expected, + "copy total must match min(payload, cap) under fuzzed inputs" + ); + assert_eq!(writer.written, expected); + if payload <= cap { + assert!(outcome.ended_by_eof); + } else { + assert!(!outcome.ended_by_eof); + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() { + let workers = 8usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new( + PROD_CAP_BYTES + (idx + 1) * 4096, + 4096 + (idx * 257), + read_calls, + ); + let mut writer = CountingWriter::default(); + copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await + })); + } + + timeout(Duration::from_secs(3), async { + for task in tasks { + let outcome = task.await.expect("stress task must not panic"); + assert_eq!( + outcome.total, PROD_CAP_BYTES, + "stress copy task must stay within production cap" + ); + assert!( + !outcome.ended_by_eof, + "stress task should end due to cap, not EOF" + ); + } + }) + .await + .expect("stress suite must complete in bounded time"); +} diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs index b92ce3d..18cb0d7 100644 --- a/src/proxy/tests/masking_self_target_loop_security_tests.rs +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -12,71 +12,77 @@ fn closed_local_port() -> u16 { port } -#[test] -fn self_target_detection_matches_literal_ipv4_listener() { +#[tokio::test] +async fn self_target_detection_matches_literal_ipv4_listener() { let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); - assert!(is_mask_target_local_listener( + assert!(is_mask_target_local_listener_async( "198.51.100.40", 443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_matches_bracketed_ipv6_listener() { +#[tokio::test] +async fn self_target_detection_matches_bracketed_ipv6_listener() { let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); - assert!(is_mask_target_local_listener( + assert!(is_mask_target_local_listener_async( "[2001:db8::44]", 8443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_keeps_same_ip_different_port_forwardable() { +#[tokio::test] +async fn self_target_detection_keeps_same_ip_different_port_forwardable() { let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener( + assert!(!is_mask_target_local_listener_async( "203.0.113.44", 8443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { +#[tokio::test] +async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); - assert!(is_mask_target_local_listener( + assert!(is_mask_target_local_listener_async( "::ffff:127.0.0.1", 443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_unspecified_bind_blocks_loopback_target() { +#[tokio::test] +async fn self_target_detection_unspecified_bind_blocks_loopback_target() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); - assert!(is_mask_target_local_listener( + assert!(is_mask_target_local_listener_async( "127.0.0.1", 443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { +#[tokio::test] +async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener( + assert!(!is_mask_target_local_listener_async( "mask.example", 443, local, Some(remote), - )); + ) + .await); } #[tokio::test] diff --git a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs new file mode 100644 index 0000000..1c342ea --- /dev/null +++ b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs @@ -0,0 +1,55 @@ +#![cfg(unix)] + +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer"); + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(())); + let held_refresh_guard = refresh_lock.lock().await; + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(80)).await; + drop(held_refresh_guard); + client.shutdown().await.expect("client shutdown must succeed"); + + timeout(Duration::from_secs(2), task) + .await + .expect("task must finish in bounded time") + .expect("task must not panic"); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350), + "timing normalization floor must start after pre-outcome self-target checks" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs index fff26b4..44c201f 100644 --- a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs +++ b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs @@ -645,6 +645,75 @@ fn quota_exceeded_boundary_is_inclusive() { assert!(!quota_exceeded_for_user(&stats, user, Some(51))); } +#[test] +fn quota_soft_helper_matches_capped_generic_helper_matrix() { + let stats = Stats::new(); + let user = "quota-soft-parity"; + + for used in [0u64, 1, 7, 63, 127, 255] { + stats.sub_user_octets_to(user, stats.get_user_total_octets(user)); + stats.add_user_octets_to(user, used); + + for quota in [8u64, 64, 128, 256] { + for overshoot in [0u64, 1, 5, 32] { + for bytes in [0u64, 1, 2, 7, 31, 64] { + let soft = quota_would_be_exceeded_for_user_soft( + &stats, + user, + Some(quota), + bytes, + overshoot, + ); + let capped = quota_would_be_exceeded_for_user( + &stats, + user, + Some(quota_soft_cap(quota, overshoot)), + bytes, + ); + assert_eq!( + soft, capped, + "soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}" + ); + } + } + } + } +} + +#[test] +fn quota_soft_helper_none_limit_never_rejects() { + let stats = Stats::new(); + let user = "quota-soft-none"; + stats.add_user_octets_to(user, u64::MAX); + + assert!(!quota_would_be_exceeded_for_user_soft( + &stats, + user, + None, + u64::MAX, + u64::MAX, + )); +} + +#[test] +fn quota_soft_cap_saturates_and_stays_fail_closed() { + let stats = Stats::new(); + let user = "quota-soft-saturating"; + let quota = u64::MAX - 2; + let overshoot = 100; + + assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX); + + stats.add_user_octets_to(user, u64::MAX - 1); + assert!(quota_would_be_exceeded_for_user_soft( + &stats, + user, + Some(quota), + 2, + overshoot, + )); +} + #[tokio::test] async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() { let (tx, mut rx) = mpsc::channel::(4); diff --git a/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs new file mode 100644 index 0000000..a787aa6 --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs @@ -0,0 +1,295 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use tokio::io::AsyncWrite; +use tokio::sync::Notify; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[derive(Default)] +struct BlockingWriteState { + write_entered: AtomicBool, + released: AtomicBool, + write_waker: Mutex>, + write_entered_notify: Notify, +} + +struct BlockingWrite { + state: Arc, +} + +impl BlockingWrite { + fn new(state: Arc) -> Self { + Self { state } + } +} + +impl AsyncWrite for BlockingWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.state.write_entered.store(true, Ordering::Release); + self.state.write_entered_notify.notify_waiters(); + + if self.state.released.load(Ordering::Acquire) { + return Poll::Ready(Ok(buf.len())); + } + + if let Ok(mut slot) = self.state.write_waker.lock() { + *slot = Some(cx.waker().clone()); + } + + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +async fn wait_until_blocking_write_entered(state: &Arc) { + for _ in 0..8 { + if state.write_entered.load(Ordering::Acquire) { + return; + } + let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; + } + + panic!("blocking writer did not enter poll_write in bounded time"); +} + +fn release_blocking_write(state: &Arc) { + state.released.store(true, Ordering::Release); + if let Ok(mut slot) = state.write_waker.lock() + && let Some(waker) = slot.take() + { + waker.wake(); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-release-regression-{}", std::process::id()); + let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let writer_state = Arc::new(BlockingWriteState::default()); + + let first = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + let writer_state = Arc::clone(&writer_state); + tokio::spawn(async move { + let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); + let mut frame_buf = Vec::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(4), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_000, + false, + false, + ) + .await + }) + }; + + wait_until_blocking_write_entered(&writer_state).await; + + let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) + .await + .expect("cross-mode lock must be released while first write is pending"); + drop(guard); + + let second = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + tokio::spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + timeout( + Duration::from_millis(150), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xEE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(4), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_001, + false, + false, + ), + ) + .await + }) + }; + + let second_result = second + .await + .expect("second task must not panic") + .expect("second write must not block on cross-mode lock"); + assert!( + matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })), + "second write must fail closed due to first write reservation" + ); + + release_blocking_write(&writer_state); + + let first_result = timeout(Duration::from_millis(300), first) + .await + .expect("first task timed out") + .expect("first task must not panic"); + assert!(first_result.is_ok()); + + assert_eq!(stats.get_user_total_octets(&user), 4); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-release-stress-{}", std::process::id()); + let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let writer_state = Arc::new(BlockingWriteState::default()); + + let first = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + let writer_state = Arc::clone(&writer_state); + tokio::spawn(async move { + let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); + let mut frame_buf = Vec::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x01, 0x02]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(3), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_100, + false, + false, + ) + .await + }) + }; + + wait_until_blocking_write_entered(&writer_state).await; + + let mut set = JoinSet::new(); + for idx in 0..48u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + timeout( + Duration::from_millis(200), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x10]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(3), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_200 + idx, + false, + false, + ), + ) + .await + }); + } + + let mut ok = 0usize; + let mut quota_exceeded = 0usize; + while let Some(done) = set.join_next().await { + let timed = done.expect("waiter task must not panic"); + let result = timed.expect("waiter must not block behind pending first write"); + match result { + Ok(_) => ok += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1, + Err(other) => panic!("unexpected error in waiter: {other:?}"), + } + } + + assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota"); + assert_eq!(quota_exceeded, 47); + + release_blocking_write(&writer_state); + + let first_result = timeout(Duration::from_millis(300), first) + .await + .expect("first task timed out") + .expect("first task must not panic"); + assert!(first_result.is_ok()); + + assert_eq!(stats.get_user_total_octets(&user), 3); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); +} diff --git a/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs new file mode 100644 index 0000000..37e1b87 --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs @@ -0,0 +1,116 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Mutex, OnceLock}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +fn lookup_counter_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test] +async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("middle-cross-mode-lookup-{}", std::process::id()); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + for idx in 0..8u64 { + let outcome = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAB]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + Some(&cross_mode_lock), + &bytes_me2c, + 20_000 + idx, + false, + false, + ) + .await; + + assert!(outcome.is_ok()); + } + + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 0, + "prefetched lock path must not re-query lock registry per frame" + ); + assert_eq!(stats.get_user_total_octets(&user), 8); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8); +} + +#[tokio::test] +async fn control_without_prefetched_lock_still_uses_registry_lookup_path() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("middle-cross-mode-lookup-control-{}", std::process::id()); + + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let outcome = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xCD]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + None, + &bytes_me2c, + 20_100, + false, + false, + ) + .await; + + assert!(outcome.is_ok()); + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 1, + "fallback path without prefetched lock should perform a registry lookup" + ); +} diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs new file mode 100644 index 0000000..bc7c857 --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs @@ -0,0 +1,376 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-positive-{}", std::process::id()); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(128), + 0, + &bytes_me2c, + 10_001, + false, + false, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(&user), 4); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); +} + +#[tokio::test] +async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-negative-{}", std::process::id()); + let held = cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock before ME->C call"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(256), + 0, + &bytes_me2c, + 10_002, + false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + drop(held_guard); +} + +#[tokio::test] +async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-edge-none-{}", std::process::id()); + let held = cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock while quota is disabled"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let outcome = timeout( + Duration::from_millis(80), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x11, 0x22]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + None, + 0, + &bytes_me2c, + 10_003, + false, + false, + ), + ) + .await + .expect("quota-none path must not wait on cross-mode lock"); + + assert!(outcome.is_ok()); + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-matrix-adversarial-{}", std::process::id()); + let limit = 64u64; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = Vec::new(); + + for idx in 0..256u64 { + let stats = Arc::clone(&stats); + let bytes_me2c = Arc::clone(&bytes_me2c); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xEE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(limit), + 0, + bytes_me2c.as_ref(), + 11_000 + idx, + false, + false, + ) + .await + })); + } + + let mut ok = 0usize; + for task in tasks { + match task.await.expect("task must not panic") { + Ok(_) => ok += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"), + } + } + + assert_eq!(ok, limit as usize); + assert_eq!(stats.get_user_total_octets(&user), limit); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() { + let user = format!("middle-cross-matrix-integration-{}", std::process::id()); + let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user); + let middle_lock = cross_mode_quota_user_lock_for_tests(&user); + assert!( + Arc::ptr_eq(&relay_lock, &middle_lock), + "relay and middle-relay must share the same cross-mode lock identity" + ); + + let held_guard = relay_lock + .try_lock() + .expect("test must hold shared cross-mode lock"); + + let stats = Stats::new(); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let middle_blocked = timeout( + Duration::from_millis(25), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x92]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 12_001, + false, + false, + ), + ) + .await; + assert!(middle_blocked.is_err()); + + drop(held_guard); + + let middle_ready = timeout( + Duration::from_millis(250), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x94]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 12_002, + false, + false, + ), + ) + .await + .expect("middle path must complete after release"); + + assert!(middle_ready.is_ok()); +} + +#[tokio::test] +async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-fuzz-{}", std::process::id()); + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xC0DE_1234_55AA_9988u64; + + for case in 0..96u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold = (seed & 0x03) == 0; + let mut held_lock = None; + let maybe_guard = if hold { + held_lock = Some(cross_mode_quota_user_lock_for_tests(&user)); + Some( + held_lock + .as_ref() + .expect("held lock should be present") + .try_lock() + .expect("cross-mode lock should be acquirable in fuzz round"), + ) + } else { + None + }; + + let payload_len = ((seed >> 8) as usize % 8) + 1; + let payload = vec![(seed & 0xff) as u8; payload_len]; + let before = stats.get_user_total_octets(&user); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let timed = timeout( + Duration::from_millis(20), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 13_000 + case as u64, + false, + false, + ), + ) + .await; + + if hold { + assert!(timed.is_err(), "held-lock fuzz round must block within timeout"); + assert_eq!(stats.get_user_total_octets(&user), before); + } else { + let done = timed.expect("unheld fuzz round must complete in time"); + assert!(done.is_ok()); + } + + drop(maybe_guard); + drop(held_lock); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), stats.get_user_total_octets(&user)); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() { + let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id()); + let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id()); + + let held = cross_mode_quota_user_lock_for_tests(&held_user); + let held_guard = held + .try_lock() + .expect("test must hold lock for blocked user"); + + let mut tasks = Vec::new(); + for idx in 0..64u64 { + let user = free_user.clone(); + tasks.push(tokio::spawn(async move { + let stats = Stats::new(); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA0]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1), + 0, + &bytes_me2c, + 14_000 + idx, + false, + false, + ) + .await + })); + } + + timeout(Duration::from_secs(2), async { + for task in tasks { + let done = task.await.expect("free-user task must not panic"); + assert!(done.is_ok()); + } + }) + .await + .expect("free-user tasks should complete without waiting for held user's lock"); + + drop(held_guard); +} diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs new file mode 100644 index 0000000..51092bd --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs @@ -0,0 +1,254 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use tokio::io::AsyncWrite; +use tokio::sync::Notify; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[derive(Default)] +struct BlockingWriteState { + write_entered: AtomicBool, + released: AtomicBool, + write_waker: Mutex>, + write_entered_notify: Notify, +} + +struct BlockingWrite { + state: Arc, +} + +impl BlockingWrite { + fn new(state: Arc) -> Self { + Self { state } + } +} + +impl AsyncWrite for BlockingWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.state.write_entered.store(true, Ordering::Release); + self.state.write_entered_notify.notify_waiters(); + + if self.state.released.load(Ordering::Acquire) { + return Poll::Ready(Ok(buf.len())); + } + + if let Ok(mut slot) = self.state.write_waker.lock() { + *slot = Some(cx.waker().clone()); + } + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +async fn wait_until_blocking_write_entered(state: &Arc) { + for _ in 0..8 { + if state.write_entered.load(Ordering::Acquire) { + return; + } + let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; + } + + panic!("blocking writer did not enter poll_write in bounded time"); +} + +fn release_blocking_write(state: &Arc) { + state.released.store(true, Ordering::Release); + if let Ok(mut slot) = state.write_waker.lock() + && let Some(waker) = slot.take() + { + waker.wake(); + } +} + +#[tokio::test] +async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() { + let stats = Stats::new(); + let user = format!("middle-me2c-cross-mode-held-{}", std::process::id()); + let held = cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock before ME->C write path"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 9901, + false, + false, + ), + ) + .await; + + assert!( + blocked.is_err(), + "ME->C quota reservation path must be serialized by held shared cross-mode lock" + ); + + drop(held_guard); + + let released = timeout( + Duration::from_millis(250), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x42]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 9902, + false, + false, + ), + ) + .await + .expect("ME->C write must complete after cross-mode lock release"); + + assert!(released.is_ok()); +} + +#[tokio::test] +async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() { + let stats = Stats::new(); + let user = format!("middle-me2c-cross-mode-free-{}", std::process::id()); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let outcome = timeout( + Duration::from_millis(250), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x55, 0x66]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 9903, + false, + false, + ), + ) + .await + .expect("uncontended ME->C path should not stall"); + + assert!(outcome.is_ok()); + assert_eq!(stats.get_user_total_octets(&user), 2); + assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id()); + let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let writer_state = Arc::new(BlockingWriteState::default()); + + let worker = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + let writer_state = Arc::clone(&writer_state); + tokio::spawn(async move { + let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); + let mut frame_buf = Vec::new(); + let rng = SecureRandom::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + stats.as_ref(), + &user, + Some(1024), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 9910, + false, + false, + ) + .await + }) + }; + + wait_until_blocking_write_entered(&writer_state).await; + + let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) + .await + .expect("cross-mode lock must be free while ME->C write is pending"); + drop(acquired_guard); + + release_blocking_write(&writer_state); + + let result = timeout(Duration::from_millis(300), worker) + .await + .expect("ME->C worker timed out after releasing blocking writer") + .expect("ME->C worker must not panic"); + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(&user), 4); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); +} diff --git a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs index 3d7929b..3ce0235 100644 --- a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs +++ b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs @@ -128,6 +128,7 @@ async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() &stats, user, quota_limit, + 0, &bytes_me2c, 7001, false, @@ -167,6 +168,7 @@ async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() &stats_fast, user, quota_limit, + 0, &bytes_fast, 7002, false, @@ -208,6 +210,7 @@ async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_byt &stats, user, Some(64), + 0, &bytes_me2c, 7003, false, diff --git a/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..29384e0 --- /dev/null +++ b/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs @@ -0,0 +1,372 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, OnceLock, Mutex}; +use tokio::sync::Mutex as AsyncMutex; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +fn lookup_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test] +async fn positive_me2c_quota_counts_bytes_exactly_once() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-positive-{}", std::process::id()); + let lock = Arc::new(AsyncMutex::new(())); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3, 4, 5]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(64), + 0, + Some(&lock), + &bytes_me2c, + 70_001, + false, + false, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(&user), 5); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); +} + +#[tokio::test] +async fn negative_held_crossmode_lock_blocks_me2c_write() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-negative-{}", std::process::id()); + + let lock = Arc::new(AsyncMutex::new(())); + let _held = lock.try_lock().expect("lock must be held"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xFE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(16), + 0, + Some(&lock), + &bytes_me2c, + 70_101, + false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn edge_zero_quota_zero_payload_is_fail_closed() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-edge-{}", std::process::id()); + + let lock = Arc::new(AsyncMutex::new(())); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(0), + 0, + Some(&lock), + &bytes_me2c, + 70_201, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(&user), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Arc::new(Stats::new()); + let user = format!("quota-middle-ext-blackhat-{}", std::process::id()); + let quota = 64u64; + let lock = Arc::new(AsyncMutex::new(())); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + + let mut set = JoinSet::new(); + for i in 0..256u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let lock = Arc::clone(&lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize]; + + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(quota), + 0, + Some(&lock), + bytes_me2c.as_ref(), + 70_301 + i, + false, + false, + ) + .await + }); + } + + let mut succeeded = 0usize; + while let Some(done) = set.join_next().await { + match done.expect("task must not panic") { + Ok(_) => succeeded += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error {other:?}"), + } + } + + assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed)); + assert!(stats.get_user_total_octets(&user) <= quota); + assert!(succeeded <= quota as usize); +} + +#[tokio::test] +async fn integration_shared_prefetched_lock_blocks_then_releases_writer() { + let stats = Stats::new(); + let user = format!("quota-middle-ext-integration-{}", std::process::id()); + let lock = Arc::new(AsyncMutex::new(())); + let held = lock + .try_lock() + .expect("integration test must hold prefetched lock first"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA1]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(8), + 0, + Some(&lock), + &bytes_me2c, + 70_360, + false, + false, + ), + ) + .await; + assert!(blocked.is_err()); + + drop(held); + + let after_release = timeout( + Duration::from_millis(150), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA2]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(8), + 0, + Some(&lock), + &bytes_me2c, + 70_361, + false, + false, + ), + ) + .await + .expect("writer should progress once the shared lock is released"); + + assert!(after_release.is_ok()); +} + +#[tokio::test] +async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-fuzz-{}", std::process::id()); + let mut seed = 0xCAFE_BABE_1234u64; + let bytes_me2c = AtomicU64::new(0); + + for case in 0..48u32 { + seed ^= seed << 5; + seed ^= seed >> 12; + seed ^= seed << 13; + let hold = (seed & 0x1) == 0; + + let lock = Arc::new(AsyncMutex::new(())); + let maybe_guard = if hold { + Some(lock.try_lock().unwrap()) + } else { + None + }; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let result = timeout( + Duration::from_millis(30), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(128), + 0, + Some(&lock), + &bytes_me2c, + 70_401 + case as u64, + false, + false, + ), + ) + .await; + + if hold { + assert!(result.is_err()); + } else { + assert!(result.unwrap().is_ok()); + } + + drop(maybe_guard); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() { + let _guard = lookup_test_lock().lock().unwrap(); + let held = Arc::new(AsyncMutex::new(())); + let _held_guard = held.try_lock().unwrap(); + + let mut set = JoinSet::new(); + for i in 0..48u64 { + set.spawn(async move { + let stats = Stats::new(); + let user = format!("quota-middle-ext-stress-free-{i}"); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let free_lock = Arc::new(AsyncMutex::new(())); + + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xEE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1), + 0, + Some(&free_lock), + &bytes_me2c, + 70_500 + i, + false, + false, + ) + .await + }); + } + + timeout(Duration::from_secs(2), async { + while let Some(task) = set.join_next().await { + task.unwrap().unwrap(); + } + }) + .await + .unwrap(); +} diff --git a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs index 717a375..963b3e0 100644 --- a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs +++ b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs @@ -5,6 +5,8 @@ use crate::stream::CryptoWriter; use bytes::Bytes; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; use tokio::task::JoinSet; fn make_crypto_writer(writer: W) -> CryptoWriter @@ -16,6 +18,77 @@ where CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +struct FailingWriter; + +impl AsyncWrite for FailingWriter { + fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Err(std::io::Error::other("forced writer failure"))) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct FailAfterBudgetWriter { + remaining: usize, + written: usize, +} + +impl FailAfterBudgetWriter { + fn new(remaining: usize) -> Self { + Self { + remaining, + written: 0, + } + } +} + +impl AsyncWrite for FailAfterBudgetWriter { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.remaining == 0 { + return Poll::Ready(Err(std::io::Error::other("forced short-write exhaustion"))); + } + + let n = self.remaining.min(buf.len()); + self.remaining -= n; + self.written += n; + Poll::Ready(Ok(n)) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + #[tokio::test] async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { let stats = Stats::new(); @@ -38,6 +111,7 @@ async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { &stats, user, Some(8), + 0, &bytes_me2c, 7101, false, @@ -62,6 +136,7 @@ async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { &stats, user, Some(8), + 0, &bytes_me2c, 7102, false, @@ -105,6 +180,7 @@ async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_count stats_ref.as_ref(), &user_owned, Some(quota_limit), + 0, bytes_ref.as_ref(), 7200 + idx, false, @@ -171,6 +247,7 @@ async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() &stats, user, Some(quota_limit), + 0, &bytes_me2c, 7300 + conn, false, @@ -189,4 +266,801 @@ async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() let total = stats.get_user_total_octets(user); assert!(total <= quota_limit); assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn positive_soft_overshoot_allows_burst_inside_soft_cap_then_blocks() { + let stats = Stats::new(); + let user = "soft-cap-boundary-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 10u64; + let overshoot = 3u64; + + stats.add_user_octets_from(user, 10); + + let mut writer_one = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_one = Vec::new(); + let first = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer_one, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_one, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7401, + false, + false, + ) + .await; + assert!(first.is_ok(), "soft-cap buffer should allow reaching limit+overshoot"); + assert_eq!(stats.get_user_total_octets(user), 13); + + let mut writer_two = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_two = Vec::new(); + let second = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[9]), + }, + &mut writer_two, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7402, + false, + false, + ) + .await; + assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 13); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); +} + +#[tokio::test] +async fn negative_soft_overshoot_rejects_when_payload_exceeds_remaining_soft_budget() { + let stats = Stats::new(); + let user = "soft-cap-remaining-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 10u64; + let overshoot = 4u64; + + stats.add_user_octets_from(user, 12); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7501, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 12); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn negative_write_failure_rolls_back_reservation_under_soft_cap_mode() { + let stats = Stats::new(); + let user = "soft-cap-rollback-user"; + let bytes_me2c = AtomicU64::new(0); + let mut writer = make_crypto_writer(FailingWriter); + let mut frame_buf = Vec::new(); + + stats.add_user_octets_from(user, 9); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(10), + 8, + &bytes_me2c, + 7601, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Io(_)))); + assert_eq!(stats.get_user_total_octets(user), 9); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_soft_cap_stress_never_exceeds_soft_limit() { + let stats = Arc::new(Stats::new()); + let user = "soft-cap-stress-user"; + let quota_limit = 40u64; + let overshoot = 5u64; + let soft_limit = quota_limit + overshoot; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x42]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(quota_limit), + overshoot, + bytes_ref.as_ref(), + 7700 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + match joined.expect("soft-cap stress task must not panic") { + Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error in soft-cap stress case: {other:?}"), + } + } + + let total = stats.get_user_total_octets(user); + assert!(total <= soft_limit, "soft-cap stress must never overshoot soft limit"); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn light_fuzz_soft_cap_matrix_keeps_counters_and_limits_consistent() { + let stats = Stats::new(); + let user = "soft-cap-fuzz-user"; + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0x9E37_79B9_7F4A_7C15u64; + + for conn in 0..1024u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let quota_limit = 32 + (seed & 0x3f); + let overshoot = seed.rotate_left(13) & 0x0f; + let len = ((seed >> 3) & 0x07) + 1; + let payload = vec![0xA5; len as usize]; + let before = stats.get_user_total_octets(user); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7800 + conn, + false, + false, + ) + .await; + + if let Err(ref err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. }), + "soft-cap fuzz produced unexpected error variant: {err:?}" + ); + } + + let after = stats.get_user_total_octets(user); + let soft_limit = quota_limit.saturating_add(overshoot); + match result { + Ok(_) => { + assert_eq!(after, before.saturating_add(len)); + assert!(after <= soft_limit, "accepted write must stay within active soft cap"); + } + Err(_) => { + assert_eq!(after, before, "rejected write must not mutate quota state"); + } + } + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + after, + "soft-cap fuzz must keep counters synchronized" + ); + } +} + +#[tokio::test] +async fn positive_no_quota_limit_accumulates_data_octets_exactly() { + let stats = Stats::new(); + let user = "no-quota-user"; + let bytes_me2c = AtomicU64::new(0); + let mut expected = 0u64; + + for (idx, len) in [1usize, 2, 3, 5, 8, 13, 21].iter().copied().enumerate() { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let payload = vec![0x41; len]; + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + None, + 0, + &bytes_me2c, + 7900 + idx as u64, + false, + false, + ) + .await; + + assert!(result.is_ok()); + expected += len as u64; + } + + assert_eq!(stats.get_user_total_octets(user), expected); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), expected); +} + +#[tokio::test] +async fn negative_zero_quota_rejects_non_empty_payload() { + let stats = Stats::new(); + let user = "zero-quota-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(0), + 0, + &bytes_me2c, + 8001, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn edge_zero_length_payload_with_zero_quota_is_fail_closed() { + let stats = Stats::new(); + let user = "zero-len-zero-quota-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(0), + 0, + &bytes_me2c, + 8002, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn positive_ack_response_does_not_touch_quota_counters() { + let stats = Stats::new(); + let user = "ack-accounting-user"; + let bytes_me2c = AtomicU64::new(11); + stats.add_user_octets_to(user, 23); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Ack(0x33445566), + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(24), + 0, + &bytes_me2c, + 8003, + true, + true, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(user), 23); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 11); +} + +#[tokio::test] +async fn edge_close_response_is_accounting_noop() { + let stats = Stats::new(); + let user = "close-accounting-user"; + let bytes_me2c = AtomicU64::new(19); + stats.add_user_octets_to(user, 31); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Close, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(40), + 3, + &bytes_me2c, + 8004, + false, + true, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(user), 31); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 19); +} + +#[tokio::test] +async fn negative_preloaded_above_soft_cap_rejects_even_single_byte() { + let stats = Stats::new(); + let user = "preloaded-over-soft-cap-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 20u64; + let overshoot = 2u64; + stats.add_user_octets_to(user, quota_limit + overshoot + 1); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 8005, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); + assert_eq!(stats.get_user_total_octets(user), quota_limit + overshoot + 1); +} + +#[tokio::test] +async fn adversarial_fail_writer_path_never_desynchronizes_quota_accounting() { + let stats = Stats::new(); + let user = "partial-write-rollback-user"; + let bytes_me2c = AtomicU64::new(0); + let mut writer = make_crypto_writer(FailAfterBudgetWriter::new(7)); + let mut frame_buf = Vec::new(); + let payload_len = 16 * 1024u64; + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![0x42; 16 * 1024]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(payload_len), + 0, + &bytes_me2c, + 8006, + false, + false, + ) + .await; + + let total_after = stats.get_user_total_octets(user); + let forensic_after = bytes_me2c.load(Ordering::Relaxed); + assert_eq!(forensic_after, total_after); + assert!( + total_after == 0 || total_after == payload_len, + "writer failure path must either roll back fully or commit exactly one payload" + ); + + // Regardless of whether I/O failure surfaced immediately or was deferred, + // accounting must remain fail-closed and prevent silent overshoot. + let mut writer_two = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_two = Vec::new(); + let second = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x99]), + }, + &mut writer_two, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(payload_len), + 0, + &bytes_me2c, + 8007, + false, + false, + ) + .await; + + if total_after == payload_len { + assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); + } else { + assert!(second.is_ok()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_oversized_frames_fail_closed_without_counter_leak() { + let stats = Arc::new(Stats::new()); + let user = "parallel-fail-rollback-user"; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![0xEE; 12 * 1024]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(512), + 0, + bytes_ref.as_ref(), + 8100 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + let result = joined.expect("parallel fail writer task must not panic"); + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + } + + assert_eq!(stats.get_user_total_octets(user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn integration_mixed_data_ack_close_sequence_preserves_data_only_accounting() { + let stats = Stats::new(); + let user = "mixed-sequence-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let data_one = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8201, + false, + false, + ) + .await; + assert!(data_one.is_ok()); + + let ack = process_me_writer_response( + MeResponse::Ack(0x0102_0304), + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8202, + true, + true, + ) + .await; + assert!(ack.is_ok()); + + let data_two = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[4, 5]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8203, + false, + true, + ) + .await; + assert!(data_two.is_ok()); + + let close = process_me_writer_response( + MeResponse::Close, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8204, + false, + true, + ) + .await; + assert!(close.is_ok()); + + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_multi_user_quota_isolation_no_cross_user_leakage() { + let stats = Arc::new(Stats::new()); + let user_a = "quota-isolation-a"; + let user_b = "quota-isolation-b"; + let limit_a = 50u64; + let limit_b = 80u64; + let bytes_a = Arc::new(AtomicU64::new(0)); + let bytes_b = Arc::new(AtomicU64::new(0)); + + let mut tasks = JoinSet::new(); + for idx in 0..200u64 { + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_a); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA1]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + user_a, + Some(limit_a), + 0, + bytes_ref.as_ref(), + 8300 + idx, + false, + false, + ) + .await + }); + } + + for idx in 0..220u64 { + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_b); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xB2]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + user_b, + Some(limit_b), + 0, + bytes_ref.as_ref(), + 8500 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + let result = joined.expect("quota isolation task must not panic"); + assert!(result.is_ok() || matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + } + + assert_eq!(stats.get_user_total_octets(user_a), limit_a); + assert_eq!(stats.get_user_total_octets(user_b), limit_b); + assert_eq!(bytes_a.load(Ordering::Relaxed), limit_a); + assert_eq!(bytes_b.load(Ordering::Relaxed), limit_b); +} + +#[tokio::test] +async fn light_fuzz_mixed_me_responses_preserve_quota_and_counter_invariants() { + let stats = Stats::new(); + let user = "mixed-fuzz-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 96u64; + let mut seed = 0xDEAD_BEEF_2026_0323u64; + + for idx in 0..2048u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let choice = (seed & 0x03) as u8; + let response = if choice == 0 { + MeResponse::Ack((seed >> 8) as u32) + } else if choice == 1 { + MeResponse::Close + } else { + let len = ((seed >> 16) & 0x07) as usize; + let mut payload = vec![0u8; len]; + payload.fill((seed & 0xff) as u8); + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + } + }; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + response, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + 0, + &bytes_me2c, + 8800 + idx, + (idx & 1) == 0, + (idx & 2) == 0, + ) + .await; + + if let Err(err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. }), + "mixed fuzz produced unexpected error variant: {err:?}" + ); + } + + let total = stats.get_user_total_octets(user); + assert!( + total <= quota_limit, + "mixed fuzz must keep usage at or below quota limit" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); + } } \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs new file mode 100644 index 0000000..e4d0c6e --- /dev/null +++ b/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs @@ -0,0 +1,399 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; +use tokio::sync::Mutex as AsyncMutex; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +fn lookup_counter_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test] +async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-positive-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + for idx in 0..12u64 { + let payload = vec![0x5A; ((idx % 4) + 1) as usize]; + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(512), + 0, + Some(&lock), + &bytes_me2c, + 31_000 + idx, + false, + false, + ) + .await; + + assert!(result.is_ok()); + } + + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 0, + "prefetched lock path must avoid hot-path registry lookups" + ); + assert_eq!( + stats.get_user_total_octets(&user), + bytes_me2c.load(Ordering::Relaxed), + "forensics and quota accounting must remain synchronized" + ); +} + +#[tokio::test] +async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-negative-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold lock before calling ME->C writer"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(64), + 0, + Some(&lock), + &bytes_me2c, + 31_100, + false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); + + drop(held_guard); +} + +#[tokio::test] +async fn edge_zero_quota_and_zero_payload_is_fail_closed() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-edge-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(0), + 0, + Some(&lock), + &bytes_me2c, + 31_200, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Arc::new(Stats::new()); + let user = format!("quota-extreme-blackhat-{}", std::process::id()); + let quota = 80u64; + let overshoot = 7u64; + let soft_limit = quota + overshoot; + let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + + let mut set = JoinSet::new(); + for idx in 0..256u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let lock = Arc::clone(&lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let len = ((idx % 5) + 1) as usize; + let payload = vec![0xAA; len]; + + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(quota), + overshoot, + Some(&lock), + bytes_me2c.as_ref(), + 31_300 + idx, + false, + false, + ) + .await + }); + } + + while let Some(done) = set.join_next().await { + match done.expect("task must not panic") { + Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"), + } + } + + let total = stats.get_user_total_octets(&user); + assert!( + total <= soft_limit, + "parallel adversarial race must stay under soft cap" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn integration_without_prefetched_lock_uses_registry_lookup_path() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-integration-{}", std::process::id()); + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + for idx in 0..3u64 { + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(16), + 0, + None, + &bytes_me2c, + 31_400 + idx, + false, + false, + ) + .await; + + assert!(result.is_ok()); + } + + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 3, + "control path should perform one lock-registry lookup per call" + ); +} + +#[tokio::test] +async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-fuzz-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xA11C_55EE_2026_0323u64; + + for idx in 0..512u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let quota = 24 + (seed & 0x3f); + let overshoot = (seed >> 13) & 0x0f; + let len = ((seed >> 19) & 0x07) + 1; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let before = stats.get_user_total_octets(&user); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![0x11; len as usize]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(quota), + overshoot, + Some(&lock), + &bytes_me2c, + 31_500 + idx, + false, + false, + ) + .await; + + let after = stats.get_user_total_octets(&user); + if result.is_ok() { + assert!(after >= before); + } else { + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(after, before); + } + assert_eq!(bytes_me2c.load(Ordering::Relaxed), after); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Arc::new(Stats::new()); + let user = format!("quota-extreme-stress-{}", std::process::id()); + let quota = 96u64; + let lock: Arc> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut set = JoinSet::new(); + for idx in 0..384u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let lock = Arc::clone(&lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xFF]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(quota), + 0, + Some(&lock), + bytes_me2c.as_ref(), + 31_600 + idx, + false, + false, + ) + .await + }); + } + + let mut success = 0usize; + while let Some(done) = set.join_next().await { + match done.expect("task must not panic") { + Ok(_) => success += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"), + } + } + + assert_eq!(success, quota as usize); + assert_eq!(stats.get_user_total_octets(&user), quota); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota); + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 0, + "stress prefetched path must not use lock registry lookups" + ); +} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs index 1bf3123..34fc454 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs @@ -7,7 +7,7 @@ use std::sync::atomic::AtomicU64; use std::time::Instant; use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; use tokio::task::JoinSet; -use tokio::time::{Duration as TokioDuration, sleep, timeout}; +use tokio::time::{Duration as TokioDuration, sleep}; fn make_crypto_reader(reader: T) -> CryptoReader where @@ -42,10 +42,10 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { fn make_enabled_idle_policy() -> RelayClientIdlePolicy { RelayClientIdlePolicy { enabled: true, - soft_idle: Duration::from_secs(30), - hard_idle: Duration::from_secs(60), + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), grace_after_downstream_activity: Duration::from_secs(0), - legacy_frame_read_timeout: Duration::from_secs(30), + legacy_frame_read_timeout: Duration::from_millis(50), } } @@ -94,8 +94,8 @@ async fn stress_parallel_pure_tiny_floods_all_fail_closed() { writer.write_all(&flood_encrypted).await.unwrap(); drop(writer); - let result = timeout( - TokioDuration::from_secs(1), + let result = run_relay_test_step_timeout( + "tiny flood task", read_once( &mut crypto_reader, ProtoTag::Abridged, @@ -104,8 +104,7 @@ async fn stress_parallel_pure_tiny_floods_all_fail_closed() { &mut idle_state, ), ) - .await - .expect("tiny flood task must complete"); + .await; assert!(matches!(result, Err(ProxyError::Proxy(_)))); assert_eq!(frame_counter, 0); @@ -140,8 +139,8 @@ async fn stress_parallel_benign_tiny_burst_then_real_all_pass() { let encrypted = encrypt_for_reader(&plaintext); writer.write_all(&encrypted).await.unwrap(); - let result = timeout( - TokioDuration::from_secs(1), + let result = run_relay_test_step_timeout( + "benign tiny burst read", read_once( &mut crypto_reader, ProtoTag::Abridged, @@ -151,7 +150,6 @@ async fn stress_parallel_benign_tiny_burst_then_real_all_pass() { ), ) .await - .expect("benign task must complete") .expect("benign payload must parse") .expect("benign payload must return frame"); @@ -196,8 +194,8 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { let mut closed = false; for _ in 0..220 { - let result = timeout( - TokioDuration::from_secs(1), + let result = run_relay_test_step_timeout( + "alternating jitter read step", read_once( &mut crypto_reader, ProtoTag::Abridged, @@ -206,8 +204,7 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { &mut idle_state, ), ) - .await - .expect("alternating reader step must complete"); + .await; match result { Ok(Some((_payload, _))) => {} @@ -336,8 +333,8 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() { drop(writer); for _ in 0..320 { - let step = timeout( - TokioDuration::from_secs(1), + let step = run_relay_test_step_timeout( + "fuzz case read step", read_once( &mut crypto_reader, ProtoTag::Abridged, @@ -346,8 +343,7 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() { &mut idle_state, ), ) - .await - .expect("fuzz case read step must complete"); + .await; match step { Ok(Some(_)) => {} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs index 0ff46a2..853b381 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use std::sync::atomic::AtomicU64; use std::time::Instant; use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; -use tokio::time::{Duration as TokioDuration, sleep, timeout}; +use tokio::time::{Duration as TokioDuration, sleep}; fn make_crypto_reader(reader: T) -> CryptoReader where @@ -41,10 +41,10 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { fn make_enabled_idle_policy() -> RelayClientIdlePolicy { RelayClientIdlePolicy { enabled: true, - soft_idle: Duration::from_secs(30), - hard_idle: Duration::from_secs(60), + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), grace_after_downstream_activity: Duration::from_secs(0), - legacy_frame_read_timeout: Duration::from_secs(30), + legacy_frame_read_timeout: Duration::from_millis(50), } } @@ -117,6 +117,11 @@ async fn read_once_with_state( .await } +fn is_fail_closed_outcome(result: &Result>) -> bool { + matches!(result, Err(ProxyError::Proxy(_))) + || matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut) +} + #[tokio::test] async fn intermediate_chunked_zero_flood_fail_closed() { let (reader, mut writer) = duplex(4096); @@ -134,8 +139,8 @@ async fn intermediate_chunked_zero_flood_fail_closed() { write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await; drop(writer); - let result = timeout( - TokioDuration::from_secs(2), + let result = run_relay_test_step_timeout( + "intermediate flood read", read_once_with_state( &mut crypto_reader, ProtoTag::Intermediate, @@ -144,10 +149,12 @@ async fn intermediate_chunked_zero_flood_fail_closed() { &mut idle_state, ), ) - .await - .expect("intermediate flood read must complete"); + .await; - assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert!( + is_fail_closed_outcome(&result), + "zero-length flood must fail closed via debt guard or idle timeout" + ); assert_eq!(frame_counter, 0); } @@ -168,8 +175,8 @@ async fn secure_chunked_zero_flood_fail_closed() { write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await; drop(writer); - let result = timeout( - TokioDuration::from_secs(2), + let result = run_relay_test_step_timeout( + "secure flood read", read_once_with_state( &mut crypto_reader, ProtoTag::Secure, @@ -178,10 +185,12 @@ async fn secure_chunked_zero_flood_fail_closed() { &mut idle_state, ), ) - .await - .expect("secure flood read must complete"); + .await; - assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert!( + is_fail_closed_outcome(&result), + "secure zero-length flood must fail closed via debt guard or idle timeout" + ); assert_eq!(frame_counter, 0); } @@ -208,8 +217,8 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() { let mut closed = false; for _ in 0..240 { - let step = timeout( - TokioDuration::from_secs(1), + let step = run_relay_test_step_timeout( + "intermediate alternating read step", read_once_with_state( &mut crypto_reader, ProtoTag::Intermediate, @@ -218,8 +227,7 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() { &mut idle_state, ), ) - .await - .expect("intermediate alternating read step must complete"); + .await; match step { Ok(Some(_)) => {} @@ -259,8 +267,8 @@ async fn secure_chunked_alternating_attack_closes_before_eof() { let mut closed = false; for _ in 0..240 { - let step = timeout( - TokioDuration::from_secs(1), + let step = run_relay_test_step_timeout( + "secure alternating read step", read_once_with_state( &mut crypto_reader, ProtoTag::Secure, @@ -269,8 +277,7 @@ async fn secure_chunked_alternating_attack_closes_before_eof() { &mut idle_state, ), ) - .await - .expect("secure alternating read step must complete"); + .await; match step { Ok(Some(_)) => {} @@ -394,8 +401,8 @@ async fn light_fuzz_proto_chunking_outcomes_are_bounded() { drop(writer); for _ in 0..260 { - let step = timeout( - TokioDuration::from_secs(1), + let step = run_relay_test_step_timeout( + "fuzz proto read step", read_once_with_state( &mut crypto_reader, proto, @@ -404,12 +411,12 @@ async fn light_fuzz_proto_chunking_outcomes_are_bounded() { &mut idle_state, ), ) - .await - .expect("fuzz proto read step must complete"); + .await; match step { Ok(Some((_payload, _))) => {} Err(ProxyError::Proxy(_)) => break, + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break, Ok(None) => break, Err(other) => panic!("unexpected proto chunking fuzz error: {other}"), } diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs index d0719c8..dee5dd9 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs @@ -40,13 +40,44 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { fn make_enabled_idle_policy() -> RelayClientIdlePolicy { RelayClientIdlePolicy { enabled: true, - soft_idle: Duration::from_secs(30), - hard_idle: Duration::from_secs(60), + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), grace_after_downstream_activity: Duration::from_secs(0), - legacy_frame_read_timeout: Duration::from_secs(30), + legacy_frame_read_timeout: Duration::from_millis(50), } } +async fn read_bounded( + crypto_reader: &mut CryptoReader, + proto_tag: ProtoTag, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, +) -> Result> { + run_relay_test_step_timeout( + "tiny-frame debt read step", + read_client_payload_with_idle_policy( + crypto_reader, + proto_tag, + 1024, + buffer_pool, + forensics, + frame_counter, + stats, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + ), + ) + .await +} + fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option, u32, usize) { let mut debt = 0u32; let mut reals = 0usize; @@ -246,10 +277,9 @@ async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() { writer.write_all(&flood_encrypted).await.unwrap(); drop(writer); - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Intermediate, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -282,10 +312,9 @@ async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() { writer.write_all(&flood_encrypted).await.unwrap(); drop(writer); - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Secure, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -325,10 +354,9 @@ async fn intermediate_alternating_zero_and_real_eventually_closes() { let mut closed = false; for _ in 0..220 { - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Intermediate, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -377,10 +405,9 @@ async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() { let encrypted = encrypt_for_reader(&plaintext); writer.write_all(&encrypted).await.unwrap(); - let first = read_client_payload_with_idle_policy( + let first = read_bounded( &mut crypto_reader, ProtoTag::Abridged, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -420,10 +447,9 @@ async fn idle_policy_enabled_zero_length_flood_is_fail_closed() { .expect("zero-length flood bytes must be writable"); drop(writer); - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Abridged, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -470,10 +496,9 @@ async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() { let mut saw_proxy_close = false; for _ in 0..300 { - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Abridged, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -527,10 +552,9 @@ async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { .await .expect("nonzero frame must be writable"); - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Abridged, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -548,3 +572,227 @@ async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { assert!(!result.1); assert_eq!(frame_counter, 1); } + +#[tokio::test] +async fn abridged_quickack_tiny_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(21, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0x80u8; 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "quickack-marked zero-length flood must fail closed" + ); +} + +#[tokio::test] +async fn abridged_extended_zero_len_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(22, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut flood_plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]); + } + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "extended zero-length abridged flood must fail closed" + ); +} + +#[tokio::test] +async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() { + let mut plaintext = Vec::with_capacity(9 * 300); + for idx in 0..300usize { + plaintext.push(0x00); + for _ in 0..8 { + let b = idx as u8; + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]); + } + } + + // Keep the test single-task and deterministic: make duplex capacity larger than the + // generated ciphertext so write_all cannot block waiting for a concurrent reader. + let duplex_capacity = plaintext.len().saturating_add(1024); + let (reader, mut writer) = duplex(duplex_capacity); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(23, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..3000 { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Err(other) => panic!("unexpected error in 1:8 wire test: {other}"), + } + } + + assert!( + !closed, + "wire-level 1:8 tiny-to-real pattern should not trigger debt close" + ); +} + +#[tokio::test] +async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() { + let mut seed = 0xD1CE_BAAD_2026_0322u64; + + for case_idx in 0..32u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let events = 300 + ((seed as usize) & 0xff); + let mut pattern = Vec::with_capacity(events); + let mut local = seed; + for _ in 0..events { + local ^= local << 7; + local ^= local >> 9; + local ^= local << 8; + pattern.push((local & 0x03) == 0); + } + + let mut plaintext = Vec::with_capacity(events * 6); + for (idx, tiny) in pattern.iter().copied().enumerate() { + if tiny { + plaintext.push(0x00); + } else { + let b = (idx as u8) ^ (case_idx as u8); + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]); + } + } + + let (reader, mut writer) = duplex(16 * 1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(500 + case_idx, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + let mut observed_close = false; + + for _ in 0..(events + 8) { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + observed_close = true; + break; + } + Err(other) => panic!("unexpected fuzz error: {other}"), + } + } + + assert_eq!( + observed_close, + expected_close.is_some(), + "wire parser behavior must match debt model for case {case_idx}" + ); + } +} diff --git a/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs new file mode 100644 index 0000000..9ea921c --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs @@ -0,0 +1,267 @@ +use super::relay_bidirectional; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() { + let _guard = quota_test_guard(); + + let user = format!("relay-pipeline-stall-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[0xA1]) + .await + .expect("server write should enqueue while relay is stalled"); + + let mut one = [0u8; 1]; + let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await; + assert!( + blocked_read.is_err(), + "same-user relay must remain blocked while cross-mode lock is held" + ); + + drop(held_guard); + + timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) + .await + .expect("blocked relay must resume after cross-mode lock release") + .expect("resumed relay must deliver queued byte"); + assert_eq!(one, [0xA1]); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must complete") + .expect("relay task must not panic"); + assert!(relay_result.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() { + let _guard = quota_test_guard(); + + let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id()); + let free_user = format!("relay-pipeline-free-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); + let held_guard = held + .try_lock() + .expect("test must hold blocked user's shared cross-mode lock"); + + let stats_blocked = Arc::new(Stats::new()); + let stats_free = Arc::new(Stats::new()); + + let (mut blocked_client, blocked_relay_client) = duplex(1024); + let (blocked_relay_server, mut blocked_server) = duplex(1024); + let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client); + let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server); + + let (mut free_client, free_relay_client) = duplex(1024); + let (free_relay_server, mut free_server) = duplex(1024); + let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client); + let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server); + + let blocked_task = { + let user = blocked_user.clone(); + let stats = Arc::clone(&stats_blocked); + tokio::spawn(async move { + relay_bidirectional( + blocked_client_reader, + blocked_client_writer, + blocked_server_reader, + blocked_server_writer, + 256, + 256, + &user, + stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }) + }; + + let free_task = { + let user = free_user.clone(); + let stats = Arc::clone(&stats_free); + tokio::spawn(async move { + relay_bidirectional( + free_client_reader, + free_client_writer, + free_server_reader, + free_server_writer, + 256, + 256, + &user, + stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }) + }; + + blocked_server + .write_all(&[0xB1]) + .await + .expect("blocked user server write should queue"); + free_server + .write_all(&[0xC1]) + .await + .expect("free user server write should queue"); + + let mut blocked_buf = [0u8; 1]; + let mut free_buf = [0u8; 1]; + + let blocked_stalled = timeout( + Duration::from_millis(40), + blocked_client.read_exact(&mut blocked_buf), + ) + .await; + assert!( + blocked_stalled.is_err(), + "blocked user must remain stalled while its lock is held" + ); + + timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf)) + .await + .expect("free user must make progress while other user is blocked") + .expect("free user read must succeed"); + assert_eq!(free_buf, [0xC1]); + + drop(held_guard); + + timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf)) + .await + .expect("blocked user must resume after release") + .expect("blocked user resumed read must succeed"); + assert_eq!(blocked_buf, [0xB1]); + + drop(blocked_client); + drop(blocked_server); + drop(free_client); + drop(free_server); + + assert!( + timeout(Duration::from_secs(1), blocked_task) + .await + .expect("blocked relay task must complete") + .expect("blocked relay task must not panic") + .is_ok() + ); + assert!( + timeout(Duration::from_secs(1), free_task) + .await + .expect("free relay task must complete") + .expect("free relay task must not panic") + .is_ok() + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() { + let _guard = quota_test_guard(); + + let mut seed = 0x5EED_C0DE_2026_0323u64; + for round in 0..24u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_ms = 2 + (seed % 10); + let user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock during fuzz round"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[0xD1]) + .await + .expect("server write should queue in fuzz round"); + + let mut one = [0u8; 1]; + let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await; + assert!(stalled.is_err(), "held phase must stall same-user relay"); + + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + drop(held_guard); + + timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) + .await + .expect("released phase must resume same-user relay") + .expect("released phase read must succeed"); + assert_eq!(one, [0xD1]); + + drop(client_peer); + drop(server_peer); + + assert!( + timeout(Duration::from_secs(1), relay_task) + .await + .expect("fuzz relay task must complete") + .expect("fuzz relay task must not panic") + .is_ok() + ); + } +} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs new file mode 100644 index 0000000..c967861 --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs @@ -0,0 +1,213 @@ +use super::relay_bidirectional; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::{Arc, Mutex}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::sync::{Barrier, watch}; +use tokio::time::{Duration, Instant, timeout}; + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn percentile_index(len: usize, percentile: usize) -> usize { + ((len * percentile) / 100).min(len.saturating_sub(1)) +} + +#[tokio::test] +async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() { + let _guard = quota_test_guard(); + + let rounds = 64usize; + let user = format!("relay-pipeline-latency-single-{}", std::process::id()); + let mut samples_ms = Vec::with_capacity(rounds); + + for round in 0..rounds { + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock before round"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(2048), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[(round as u8) ^ 0xA5]) + .await + .expect("server write should queue before release"); + + let release_at = Instant::now(); + drop(held_guard); + + let mut one = [0u8; 1]; + timeout(Duration::from_millis(450), client_peer.read_exact(&mut one)) + .await + .expect("client must receive queued byte after release") + .expect("queued byte read must succeed"); + samples_ms.push(release_at.elapsed().as_millis() as u64); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must complete") + .expect("relay task must not panic"); + assert!(relay_result.is_ok()); + } + + samples_ms.sort_unstable(); + let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; + let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; + + assert!( + p50_ms <= 45, + "single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}" + ); + assert!( + p95_ms <= 130, + "single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() { + let _guard = quota_test_guard(); + + let waiters = 128usize; + let user = format!("relay-pipeline-latency-fanout-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared lock before fanout release benchmark"); + + let ready_barrier = Arc::new(Barrier::new(waiters + 1)); + let release_at = Arc::new(Mutex::new(None::)); + let (release_tx, release_rx) = watch::channel(false); + let mut tasks = Vec::with_capacity(waiters); + + for idx in 0..waiters { + let user = user.clone(); + let barrier = Arc::clone(&ready_barrier); + let release_at = Arc::clone(&release_at); + let mut release_rx = release_rx.clone(); + + tasks.push(tokio::spawn(async move { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(512); + let (relay_server, mut server_peer) = duplex(512); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user; + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(2048), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[(idx as u8) ^ 0x5A]) + .await + .expect("fanout server write should queue before release"); + + barrier.wait().await; + release_rx + .changed() + .await + .expect("release signal should remain available"); + + let started = { + let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); + guard.expect("release timestamp must be populated before signal") + }; + + let mut one = [0u8; 1]; + timeout(Duration::from_millis(900), client_peer.read_exact(&mut one)) + .await + .expect("fanout waiter must receive queued byte after release") + .expect("fanout waiter read must succeed"); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("fanout relay task must complete") + .expect("fanout relay task must not panic"); + assert!(relay_result.is_ok()); + + started.elapsed().as_millis() as u64 + })); + } + + ready_barrier.wait().await; + { + let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); + *guard = Some(Instant::now()); + } + drop(held_guard); + release_tx + .send(true) + .expect("release broadcast must succeed"); + + let mut samples_ms = Vec::with_capacity(waiters); + timeout(Duration::from_secs(8), async { + for task in tasks { + let elapsed = task.await.expect("fanout waiter must not panic"); + samples_ms.push(elapsed); + } + }) + .await + .expect("fanout benchmark must complete in bounded time"); + + samples_ms.sort_unstable(); + let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; + let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; + let max_ms = *samples_ms.last().unwrap_or(&0); + + assert!( + p50_ms <= 120, + "fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); + assert!( + p95_ms <= 260, + "fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); + assert!( + max_ms <= 700, + "fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs index 87944ba..adbdb22 100644 --- a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs +++ b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs @@ -3,8 +3,9 @@ use crate::stats::Stats; use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; +use std::task::{Context, Poll, Waker}; use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::sync::Barrier; use tokio::time::{Duration, timeout}; #[derive(Default)] @@ -26,6 +27,13 @@ fn quota_test_guard() -> impl Drop { super::quota_user_lock_test_scope() } +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + #[tokio::test] async fn positive_cross_mode_uncontended_writer_progresses() { let _guard = quota_test_guard(); @@ -223,3 +231,374 @@ async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() { assert!(write_done.is_ok()); } } + +#[tokio::test] +async fn integration_middle_lock_blocks_relay_reader_for_same_user() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-middle-reader-block-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold middle-relay shared lock"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let mut one = [0u8; 1]; + let mut buf = ReadBuf::new(&mut one); + let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn integration_middle_lock_release_unblocks_relay_reader() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-middle-reader-release-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold middle-relay shared lock"); + + let task = tokio::spawn({ + let user = user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut one = [0u8; 1]; + io.read(&mut one).await + } + }); + + tokio::time::sleep(Duration::from_millis(5)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(300), task) + .await + .expect("reader task must complete after release") + .expect("reader task must not panic"); + assert!(done.is_ok()); +} + +#[tokio::test] +async fn business_different_user_middle_lock_does_not_block_relay_writer() { + let _guard = quota_test_guard(); + + let held_user = format!("cross-mode-middle-held-{}", std::process::id()); + let active_user = format!("cross-mode-middle-active-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user); + let _held_guard = held + .try_lock() + .expect("test must hold middle-relay lock for other user"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + active_user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]); + assert!(matches!(poll, Poll::Ready(Ok(1)))); +} + +#[tokio::test] +async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-none-limit-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold lock while quota is disabled"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + None, + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]); + assert!(matches!(poll, Poll::Ready(Ok(2)))); +} + +#[tokio::test] +async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-pre-exceeded-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold shared lock before poll"); + + let quota_exceeded = Arc::new(AtomicBool::new(true)); + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::clone("a_exceeded), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]); + assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e))); +} + +#[tokio::test] +async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-repoll-held-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold lock for repoll sequence"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + for _ in 0..8 { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]); + assert!(poll.is_pending()); + } + + assert_eq!(stats.get_user_total_octets(&user), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_same_user_mixed_read_write_waiters_resume_after_release() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-mixed-resume-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock before spawning mixed waiters"); + + let mut tasks = Vec::new(); + for i in 0..12usize { + let user = user.clone(); + tasks.push(tokio::spawn(async move { + if i % 2 == 0 { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut b = [0u8; 1]; + io.read(&mut b).await.map(|_| ()) + } else { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x66]).await + } + })); + } + + tokio::time::sleep(Duration::from_millis(8)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for task in tasks { + let result = task.await.expect("mixed waiter task must not panic"); + assert!(result.is_ok()); + } + }) + .await + .expect("all mixed waiters must finish after release"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() { + let _guard = quota_test_guard(); + + let blocked_user = format!("cross-mode-blocked-{}", std::process::id()); + let free_user = format!("cross-mode-free-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); + let held_guard = held + .try_lock() + .expect("test must hold blocked user lock"); + + let blocked_task = tokio::spawn({ + let blocked_user = blocked_user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + blocked_user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x77]).await + } + }); + + let free_task = tokio::spawn({ + let free_user = free_user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + free_user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x78]).await + } + }); + + let free_done = timeout(Duration::from_millis(250), free_task) + .await + .expect("free user must not be blocked") + .expect("free user task must not panic"); + assert!(free_done.is_ok()); + + drop(held_guard); + let blocked_done = timeout(Duration::from_secs(1), blocked_task) + .await + .expect("blocked user must resume after release") + .expect("blocked user task must not panic"); + assert!(blocked_done.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-fanout-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock before fanout"); + + let waiters = 48usize; + let gate = Arc::new(Barrier::new(waiters + 1)); + let mut tasks = Vec::new(); + for _ in 0..waiters { + let user = user.clone(); + let gate = Arc::clone(&gate); + tasks.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x79]).await + })); + } + + gate.wait().await; + tokio::time::sleep(Duration::from_millis(10)).await; + drop(held_guard); + + timeout(Duration::from_secs(2), async { + for task in tasks { + let result = task.await.expect("fanout task must not panic"); + assert!(result.is_ok()); + } + }) + .await + .expect("fanout waiters must complete after release"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() { + let _guard = quota_test_guard(); + + let mut seed = 0xA11C_EE55_2026_0323u64; + for round in 0..20u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_ms = 2 + (seed % 10); + let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock in fuzz round"); + + let writer = tokio::spawn({ + let user = user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x7A]).await + } + }); + + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(400), writer) + .await + .expect("writer must complete after lock release") + .expect("writer task must not panic"); + assert!(done.is_ok()); + } +} diff --git a/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs b/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs new file mode 100644 index 0000000..9ac4621 --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs @@ -0,0 +1,340 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::AsyncWriteExt; +use tokio::time::{Duration, Instant, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() { + let _guard = quota_test_guard(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + format!("dual-lock-alt-positive-{}", std::process::id()), + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let write = io.write_all(&[0xAA, 0xBB]).await; + assert!(write.is_ok(), "uncontended write must complete"); + assert_eq!( + io.quota_write_retry_attempt, 0, + "uncontended write must not advance retry backoff" + ); +} + +#[tokio::test] +async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-adversarial-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("test must hold local quota lock initially"), + ); + let mut cross_guard = None; + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(first.is_pending(), "held local lock must block first poll"); + + let mut observed_wakes = 0usize; + for idx in 0..18usize { + tokio::time::sleep(Duration::from_millis(6)).await; + + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = Some( + cross_mode_lock + .try_lock() + .expect("cross-mode lock should be acquirable while local lock released"), + ); + } else { + drop(cross_guard.take()); + local_guard = Some( + local_lock + .try_lock() + .expect("local lock should be acquirable while cross lock released"), + ); + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed_wakes { + observed_wakes = wakes; + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); + assert!( + pending.is_pending(), + "alternating contention must keep write pending while one lock is held" + ); + } + } + + assert!( + io.quota_write_retry_attempt >= 2, + "alternating contention must still ramp retry backoff; got {}", + io.quota_write_retry_attempt + ); + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 32, + "alternating contention must stay wake-rate-limited" + ); + + drop(local_guard); + drop(cross_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]); + assert!(ready.is_ready(), "writer must resume after both locks released"); +} + +#[tokio::test] +async fn edge_retry_scheduler_resets_after_alternating_contention_clears() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-edge-reset-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let local_guard = local_lock + .try_lock() + .expect("test must hold local lock for edge scenario"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]); + assert!(first.is_pending()); + tokio::time::sleep(Duration::from_millis(15)).await; + if wake_counter.wakes.load(Ordering::Relaxed) > 0 { + let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); + assert!(next.is_pending()); + } + + drop(local_guard); + + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]); + assert!(ready.is_ready()); + assert_eq!( + io.quota_write_retry_attempt, 0, + "successful dual-lock acquisition must reset retry scheduler" + ); + assert!(!io.quota_write_wake_scheduled); + assert!(io.quota_write_retry_sleep.is_none()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-integration-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut waiters = Vec::new(); + for _ in 0..16usize { + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + timeout(Duration::from_secs(2), io.write_all(&[0x31])).await + })); + } + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("integration toggle must acquire local lock first"), + ); + let mut cross_guard = None; + + for idx in 0..24usize { + tokio::time::sleep(Duration::from_millis(4)).await; + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = cross_mode_lock.try_lock().ok(); + } else { + drop(cross_guard.take()); + local_guard = local_lock.try_lock().ok(); + } + } + + drop(local_guard); + drop(cross_guard); + + for waiter in waiters { + let done = waiter.await.expect("waiter task must not panic"); + assert!( + done.is_ok(), + "waiter must finish once alternating contention window ends" + ); + assert!(done.expect("waiter timeout must not fire").is_ok()); + } +} + +#[tokio::test] +async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-fuzz-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let mut seed = 0xD00D_BAAD_F00D_2026u64; + + for _round in 0..64u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_mode = (seed % 3) as u8; + let local_guard = if hold_mode == 0 { + Some( + local_lock + .try_lock() + .expect("fuzz local lock should be acquirable"), + ) + } else { + None + }; + let cross_guard = if hold_mode == 1 { + Some( + cross_mode_lock + .try_lock() + .expect("fuzz cross lock should be acquirable"), + ) + } else { + None + }; + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await; + if hold_mode == 2 { + assert!(write.is_ok(), "unheld fuzz round must make progress"); + assert!(write.expect("unheld round timeout").is_ok()); + } else { + assert!( + write.is_err(), + "held-lock fuzz round must remain pending inside bounded window" + ); + } + + drop(local_guard); + drop(cross_guard); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_fanout_alternating_contention_recovers_without_hanging() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-stress-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut waiters = Vec::new(); + for _ in 0..48usize { + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await + })); + } + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("stress toggle must acquire local lock first"), + ); + let mut cross_guard = None; + for idx in 0..40usize { + tokio::time::sleep(Duration::from_millis(3)).await; + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = cross_mode_lock.try_lock().ok(); + } else { + drop(cross_guard.take()); + local_guard = local_lock.try_lock().ok(); + } + } + + drop(local_guard); + drop(cross_guard); + + for waiter in waiters { + let done = waiter.await.expect("stress waiter task must not panic"); + assert!(done.is_ok(), "stress waiter timed out under alternating contention"); + assert!(done.expect("stress waiter timeout should not fire").is_ok()); + } +} diff --git a/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs b/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs new file mode 100644 index 0000000..ce26941 --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs @@ -0,0 +1,74 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::time::{Duration, Instant}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn adversarial_cross_mode_only_contention_backoff_attempt_must_ramp() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-backoff-{}", std::process::id()); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_cross_mode_guard = cross_mode_lock + .try_lock() + .expect("test must hold cross-mode lock before polling"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); + assert!(first.is_pending(), "held cross-mode lock must block writer"); + + let started = Instant::now(); + let mut last_wakes = 0usize; + while started.elapsed() < Duration::from_millis(120) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > last_wakes { + last_wakes = wakes; + let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]); + assert!(next.is_pending(), "writer must remain blocked while lock is held"); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!( + io.quota_write_retry_attempt >= 2, + "retry attempt must ramp under sustained second-lock contention; got {}", + io.quota_write_retry_attempt + ); + + drop(held_cross_mode_guard); +} diff --git a/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs b/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs new file mode 100644 index 0000000..513d92b --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs @@ -0,0 +1,325 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::time::{Duration, Instant, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +#[tokio::test] +async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() { + let _guard = quota_test_guard(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + format!("dual-lock-positive-{}", std::process::id()), + Some(4096), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]); + assert!(poll.is_ready()); + assert_eq!(io.quota_write_retry_attempt, 0); + assert!(!io.quota_write_wake_scheduled); + assert!(io.quota_write_retry_sleep.is_none()); +} + +#[tokio::test] +async fn negative_local_lock_contention_read_retry_attempt_ramps() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-local-contention-{}", std::process::id()); + let held = quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold local quota lock before polling"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (wake_counter, mut cx) = build_context(); + let mut one = [0u8; 1]; + let mut buf = ReadBuf::new(&mut one); + let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(first.is_pending()); + + let started = Instant::now(); + let mut observed = 0usize; + while started.elapsed() < Duration::from_millis(120) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed { + observed = wakes; + let mut step_buf = ReadBuf::new(&mut one); + let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf); + assert!(next.is_pending()); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!( + io.quota_read_retry_attempt >= 2, + "retry attempt must ramp under sustained local-lock contention; got {}", + io.quota_read_retry_attempt + ); + + drop(held_guard); +} + +#[tokio::test] +async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-reset-{}", std::process::id()); + let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = cross_mode + .try_lock() + .expect("test must hold cross-mode lock before polling"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (wake_counter, mut cx) = build_context(); + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]); + assert!(first.is_pending()); + + tokio::time::sleep(Duration::from_millis(20)).await; + if wake_counter.wakes.load(Ordering::Relaxed) > 0 { + let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(next.is_pending()); + } + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); + assert!(ready.is_ready()); + assert_eq!(io.quota_write_retry_attempt, 0); + assert!(!io.quota_write_wake_scheduled); + assert!(io.quota_write_retry_sleep.is_none()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-adversarial-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before launching waiters"); + + let mut tasks = Vec::new(); + for _ in 0..24usize { + let stats = Arc::clone(&stats); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + stats, + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + timeout(Duration::from_millis(40), io.write_all(&[0x33])).await + })); + } + + for task in tasks { + let timed = task.await.expect("waiter task must not panic"); + assert!(timed.is_err(), "held cross-mode lock must keep waiter pending"); + } + + assert_eq!(stats.get_user_total_octets(&user), 0); + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_waiters_resume_after_cross_mode_release() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-integration-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before starting waiter"); + + let task = tokio::spawn({ + let user = user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + io.write_all(&[0x44]).await + } + }); + + tokio::time::sleep(Duration::from_millis(10)).await; + drop(held_guard); + + let done = timeout(Duration::from_secs(1), task) + .await + .expect("waiter task must complete after release") + .expect("waiter task must not panic"); + assert!(done.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-fuzz-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let mut seed = 0xA55A_55AA_C3D2_E1F0u64; + + for _round in 0..48u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_mode = (seed % 3) as u8; + let mut local_lock = None; + let mut cross_lock = None; + let mut local_guard = None; + let mut cross_guard = None; + + if hold_mode == 0 { + local_lock = Some(quota_user_lock(&user)); + local_guard = Some( + local_lock + .as_ref() + .expect("local lock should be present") + .try_lock() + .expect("local lock should be acquirable in fuzz round"), + ); + } else if hold_mode == 1 { + cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock( + &user, + )); + cross_guard = Some( + cross_lock + .as_ref() + .expect("cross lock should be present") + .try_lock() + .expect("cross lock should be acquirable in fuzz round"), + ); + } + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(4096), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await; + if hold_mode == 2 { + assert!(write.is_ok(), "unheld round must make progress"); + } else { + assert!(write.is_err(), "held-lock round must stay blocked within timeout"); + } + + drop(local_guard); + drop(cross_guard); + drop(local_lock); + drop(cross_lock); + } + + assert!(stats.get_user_total_octets(&user) <= 4096); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_fanout_waiters_complete_after_release_without_panics() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-stress-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before stress fanout"); + + let waiters = 64usize; + let mut tasks = Vec::new(); + for _ in 0..waiters { + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + let mut one = [0u8; 1]; + io.read(&mut one).await + })); + } + + tokio::time::sleep(Duration::from_millis(12)).await; + drop(held_guard); + + timeout(Duration::from_secs(2), async { + for task in tasks { + let result = task.await.expect("stress waiter task must not panic"); + assert!(result.is_ok()); + } + }) + .await + .expect("all stress waiters must complete after release"); +} diff --git a/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs b/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs new file mode 100644 index 0000000..ec180e8 --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs @@ -0,0 +1,128 @@ +use super::*; +use crate::stats::Stats; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use tokio::io::AsyncWriteExt; +use tokio::time::{Duration, timeout}; + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn make_stats_io(user: String) -> StatsIo { + StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-race-fuzz-{}", std::process::id()); + let mut seed = 0xD1CE_BAAD_5EED_1234u64; + + for round in 0..1024u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold = (seed & 1) == 0; + let hold_ms = (seed % 3) as u64; + + let maybe_lock = if hold { + Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock( + &user, + )) + } else { + None + }; + + let maybe_guard = maybe_lock.as_ref().map(|lock| { + lock.try_lock() + .expect("cross-mode lock must be acquirable in fuzz round") + }); + + if hold { + let mut blocked_io = make_stats_io(user.clone()); + let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await; + assert!( + blocked.is_err(), + "held round must block waiter before lock release (round={round})" + ); + + if hold_ms > 0 { + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + } + } else { + let mut free_io = make_stats_io(user.clone()); + let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await; + assert!( + free.is_ok(), + "unheld round must complete promptly (round={round})" + ); + assert!(free.expect("unheld round should complete").is_ok()); + } + + drop(maybe_guard); + + let done = timeout(Duration::from_millis(350), async { + let user = user.clone(); + let mut io = make_stats_io(user); + io.write_all(&[0xA6]).await + }) + .await + .expect("post-release write must complete in bounded time"); + assert!(done.is_ok()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-race-stress-{}", std::process::id()); + let mut seed = 0xC0FF_EE77_4444_9999u64; + + for round in 0..256u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_ms = (seed % 4) as u64; + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let guard = lock + .try_lock() + .expect("cross-mode lock must be acquirable at round start"); + + let mut waiters = Vec::new(); + for _ in 0..3usize { + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = make_stats_io(user); + io.write_all(&[0x55]).await + })); + } + + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + drop(guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let done = waiter.await.expect("waiter task must not panic"); + assert!( + done.is_ok(), + "waiter must complete after release (round={round})" + ); + } + }) + .await + .expect("all waiters must complete in bounded time after release"); + } +} diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..5ee6522 --- /dev/null +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -0,0 +1,332 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{Duration, timeout}; + +async fn read_available(reader: &mut R, budget: Duration) -> usize { + let start = tokio::time::Instant::now(); + let mut total = 0usize; + let mut buf = [0u8; 128]; + + loop { + let elapsed = start.elapsed(); + if elapsed >= budget { + break; + } + let remaining = budget.saturating_sub(elapsed); + match timeout(remaining, reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + + total +} + +#[tokio::test] +async fn positive_quota_path_forwards_both_directions_within_limit() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-positive-user"; + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(16), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap(); + server_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap(); + client_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(relay_result.is_ok()); + assert!(stats.get_user_total_octets(user) <= 16); +} + +#[tokio::test] +async fn negative_preloaded_quota_forbids_any_forwarding() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-negative-user"; + stats.add_user_octets_from(user, 8); + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(8), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0xAA]).await.unwrap(); + server_peer.write_all(&[0xBB]).await.unwrap(); + + assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0); + assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(stats.get_user_total_octets(user) <= 8); +} + +#[tokio::test] +async fn edge_quota_one_ensures_at_most_one_byte_across_directions() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-edge-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer.write_all(&[0xFE]), + server_peer.write_all(&[0xEF]), + ); + + let mut buf = [0u8; 1]; + let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await.unwrap().unwrap_or(0); + let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await.unwrap().unwrap_or(0); + + assert!(delivered_s2c + delivered_c2s <= 1); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-blackhat-user"; + let quota = 24u64; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + let mut total_forwarded = 0usize; + + for i in 0..256usize { + if relay.is_finished() { + break; + } + if (i & 1) == 0 { + let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await { + total_forwarded += n; + } + } else { + let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await { + total_forwarded += n; + } + } + + tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; + } + + let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap(); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(total_forwarded <= quota as usize); + assert!(stats.get_user_total_octets(user) <= quota); +} + +#[tokio::test] +async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { + let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE); + + for case in 0..32u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-extended-fuzz-{case}"); + let quota = rng.random_range(1u64..=35u64); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut total_forwarded = 0usize; + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + if rng.random::() { + let _ = client_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await { + total_forwarded += n; + } + } else { + let _ = server_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await { + total_forwarded += n; + } + } + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(total_forwarded <= quota as usize); + assert!(stats.get_user_total_octets(&user) <= quota); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_relays_for_one_user_obey_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-stress-user".to_string(); + let quota = 64u64; + + let mut tasks = Vec::new(); + + for worker in 0..4u8 { + let stats = Arc::clone(&stats); + let user = user.clone(); + + tasks.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut total = 0usize; + for step in 0..64u8 { + if relay.is_finished() { + break; + } + if (step as usize + worker as usize) % 2 == 0 { + let _ = client_peer.write_all(&[(step ^ 0x5A)]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await { + total += n; + } + } else { + let _ = server_peer.write_all(&[(step ^ 0xA5)]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await { + total += n; + } + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + total + })); + } + + let mut delivered = 0usize; + for task in tasks { + delivered += task.await.unwrap(); + } + + assert!(stats.get_user_total_octets(&user) <= quota); + assert!(delivered <= quota as usize); +} diff --git a/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs new file mode 100644 index 0000000..806efb6 --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs @@ -0,0 +1,79 @@ +use super::*; +use dashmap::DashMap; +use std::sync::Arc; +use tokio::time::{Duration, timeout}; + +#[test] +fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let held_user = format!("quota-evict-held-{}", std::process::id()); + let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id()); + let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id()); + + let held = quota_user_lock(&held_user); + let stale_a = quota_user_lock(&stale_a_user); + let stale_b = quota_user_lock(&stale_b_user); + + assert!(map.get(&held_user).is_some()); + assert!(map.get(&stale_a_user).is_some()); + assert!(map.get(&stale_b_user).is_some()); + + drop(stale_a); + drop(stale_b); + + quota_user_lock_evict(); + + assert!( + map.get(&held_user).is_some(), + "held entry must survive eviction" + ); + assert!( + map.get(&stale_a_user).is_none(), + "unheld stale entry must be reclaimed" + ); + assert!( + map.get(&stale_b_user).is_none(), + "unheld stale entry must be reclaimed" + ); + + drop(held); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let held_user = format!("quota-evict-loop-held-{}", std::process::id()); + let stale_user = format!("quota-evict-loop-stale-{}", std::process::id()); + + let held = quota_user_lock(&held_user); + let stale = quota_user_lock(&stale_user); + + assert_eq!(map.len(), 2); + drop(stale); + + let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5)); + + timeout(Duration::from_millis(200), async { + loop { + if map.get(&stale_user).is_none() { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("periodic quota lock evictor must reclaim stale entry"); + + evictor.abort(); + + assert!(map.get(&held_user).is_some()); + assert!(map.get(&stale_user).is_none()); + + drop(held); +} diff --git a/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs new file mode 100644 index 0000000..251582a --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs @@ -0,0 +1,153 @@ +use super::*; +use dashmap::DashMap; +use std::sync::Arc; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5)); + + let mut tasks = JoinSet::new(); + for worker in 0..24u32 { + tasks.spawn(async move { + for round in 0..320u32 { + let user = format!( + "quota-evict-stress-user-{}-{}-{}", + std::process::id(), + worker, + round + ); + let lock = quota_user_lock(&user); + if round % 19 == 0 { + tokio::task::yield_now().await; + } + drop(lock); + } + }); + } + + while let Some(done) = tasks.join_next().await { + done.expect("stress worker must not panic"); + } + + quota_user_lock_evict(); + tokio::time::sleep(Duration::from_millis(20)).await; + + assert!( + map.len() <= QUOTA_USER_LOCKS_MAX, + "quota lock map must remain bounded after churn + eviction" + ); + + let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id()); + let sanity_lock = quota_user_lock(&sanity_user); + assert!( + map.get(&sanity_user).is_some(), + "sanity user should be cacheable after eviction reclaimed stale entries" + ); + + drop(sanity_lock); + evictor.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let held_user = format!("quota-evict-held-survive-{}", std::process::id()); + let held = quota_user_lock(&held_user); + + let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3)); + + for idx in 0..512u32 { + let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx); + let temp = quota_user_lock(&user); + drop(temp); + if idx % 32 == 0 { + tokio::task::yield_now().await; + } + } + + let reacquired = quota_user_lock(&held_user); + assert!( + Arc::ptr_eq(&held, &reacquired), + "held user lock identity must remain stable across repeated evictions" + ); + assert!( + map.get(&held_user).is_some(), + "held user entry must not be reclaimed while externally referenced" + ); + + drop(reacquired); + drop(held); + + timeout(Duration::from_millis(300), async { + loop { + if map.get(&held_user).is_none() { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("released held lock must be reclaimed by periodic evictor"); + + evictor.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + let prefix = format!("quota-evict-saturated-{}", std::process::id()); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); + + let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id()); + let overflow_before = quota_user_lock(&overflow_user); + assert!( + map.get(&overflow_user).is_none(), + "saturated map must initially route new user to overflow stripe" + ); + + drop(retained); + + let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4)); + + timeout(Duration::from_millis(400), async { + loop { + if map.len() < QUOTA_USER_LOCKS_MAX { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("periodic evictor must reclaim stale saturated entries"); + + let overflow_after = quota_user_lock(&overflow_user); + assert!( + map.get(&overflow_user).is_some(), + "after eviction, overflow user should become cacheable again" + ); + assert!( + Arc::strong_count(&overflow_after) >= 2, + "cacheable lock should be held by map and caller" + ); + + drop(overflow_before); + drop(overflow_after); + evictor.abort(); +} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs index e29e86e..5687965 100644 --- a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs @@ -127,7 +127,7 @@ fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() { } #[test] -fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { +fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() { let _guard = super::quota_user_lock_test_scope(); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); map.clear(); @@ -142,6 +142,8 @@ fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { drop(retained); + quota_user_lock_evict(); + let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); let overflow = quota_user_lock(&overflow_user); diff --git a/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs b/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs new file mode 100644 index 0000000..447a090 --- /dev/null +++ b/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs @@ -0,0 +1,249 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::AsyncWriteExt; +use tokio::time::{Duration, Instant, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +fn sleep_slot_ptr(slot: &Option>>) -> usize { + slot.as_ref() + .map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize) + .unwrap_or(0) +} + +#[tokio::test] +async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() { + let _guard = quota_test_guard(); + + let user = format!("retry-alloc-single-pending-{}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold local lock to force retry scheduling"); + + reset_quota_retry_sleep_allocs_for_tests(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); + assert!(first.is_pending()); + let allocs_after_first = quota_retry_sleep_allocs_for_tests(); + let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep); + + let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); + assert!(second.is_pending()); + let allocs_after_second = quota_retry_sleep_allocs_for_tests(); + let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep); + + assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer"); + assert_eq!( + allocs_after_second, 1, + "repoll while the same timer is pending must not allocate again" + ); + assert_eq!( + ptr_after_first, ptr_after_second, + "repoll while pending should retain the same timer allocation" + ); + + drop(held_guard); +} + +#[tokio::test] +async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() { + let _guard = quota_test_guard(); + + let user = format!("retry-alloc-per-cycle-{}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold local lock to keep write path pending"); + + reset_quota_retry_sleep_allocs_for_tests(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (wake_counter, mut cx) = build_context(); + + let mut polls = 0u64; + let mut observed_wakes = 0usize; + let started = Instant::now(); + while started.elapsed() < Duration::from_millis(70) { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]); + polls = polls.saturating_add(1); + assert!(poll.is_pending()); + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed_wakes { + observed_wakes = wakes; + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let allocs = quota_retry_sleep_allocs_for_tests(); + assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers"); + assert!( + allocs < polls, + "timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})" + ); + + drop(held_guard); +} + +#[tokio::test] +async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() { + let _guard = quota_test_guard(); + + let user = format!("retry-latency-envelope-{}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold local lock for sustained contention"); + + reset_quota_retry_sleep_allocs_for_tests(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (wake_counter, mut cx) = build_context(); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]); + assert!(first.is_pending()); + + let started = Instant::now(); + let mut last_wakes = 0usize; + let mut wake_instants = Vec::new(); + + while started.elapsed() < Duration::from_millis(120) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > last_wakes { + last_wakes = wakes; + wake_instants.push(Instant::now()); + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]); + assert!(pending.is_pending()); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let mut max_gap = Duration::from_millis(0); + for idx in 1..wake_instants.len() { + let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]); + if gap > max_gap { + max_gap = gap; + } + } + + assert!( + max_gap <= Duration::from_millis(35), + "retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}" + ); + assert!( + quota_retry_sleep_allocs_for_tests() <= 16, + "allocation cycles must remain bounded during a short contention window" + ); + + drop(held_guard); +} + +#[tokio::test] +async fn micro_benchmark_release_to_completion_latency_stays_bounded() { + let _guard = quota_test_guard(); + + let rounds = 96usize; + let mut samples_ms = Vec::with_capacity(rounds); + + for round in 0..rounds { + let user = format!("retry-release-latency-{}-{round}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold local lock before spawning blocked writer"); + + let writer = tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + io.write_all(&[0xD1]).await + }); + + tokio::time::sleep(Duration::from_millis(2)).await; + let release_at = Instant::now(); + drop(held_guard); + + let done = timeout(Duration::from_millis(120), writer) + .await + .expect("blocked writer must complete after release") + .expect("writer task must not panic"); + assert!(done.is_ok()); + + samples_ms.push(release_at.elapsed().as_millis() as u64); + } + + samples_ms.sort_unstable(); + let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1)); + let p95_ms = samples_ms[p95_idx]; + + assert!( + p95_ms <= 40, + "contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" + ); +}