From b930ea1ec5cab67ffde228f82c81f0d439c68f78 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Sat, 21 Mar 2026 15:16:20 +0400 Subject: [PATCH] Add regression and security tests for relay quota and TLS stream handling - Introduced regression tests for relay quota wake liveness to ensure proper handling of contention and wake events. - Added adversarial tests to validate the behavior of the quota system under stress and contention scenarios. - Implemented security tests for the TLS stream to verify the preservation of pending plaintext during state transitions. - Enhanced the pool writer tests to ensure proper quarantine behavior and validate the removal of writers from the registry. - Included fuzz testing to assess the robustness of the quota and TLS handling mechanisms against unexpected inputs and states. --- src/proxy/client.rs | 19 +- src/proxy/handshake.rs | 32 +- src/proxy/middle_relay.rs | 4 + src/proxy/relay.rs | 76 ++++- ...nt_beobachten_ttl_bounds_security_tests.rs | 126 ++++++++ ...ent_tls_mtproto_fallback_security_tests.rs | 106 ++++++ ..._auth_probe_hardening_adversarial_tests.rs | 187 +++++++++++ ...dshake_saturation_poison_security_tests.rs | 71 ++++ ...ay_desync_all_full_dedup_security_tests.rs | 179 ++++++++++ ...ay_quota_wake_liveness_regression_tests.rs | 290 +++++++++++++++++ ...lay_quota_waker_storm_adversarial_tests.rs | 306 ++++++++++++++++++ .../relay_watchdog_delta_security_tests.rs | 61 ++++ src/stream/tls_stream.rs | 9 + ...stream_pending_plaintext_security_tests.rs | 143 ++++++++ src/transport/middle_proxy/pool_writer.rs | 7 +- .../tests/pool_writer_security_tests.rs | 208 +++++++++++- 16 files changed, 1790 insertions(+), 34 deletions(-) create mode 100644 src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs create mode 100644 src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs create mode 100644 src/proxy/tests/handshake_saturation_poison_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs create mode 100644 src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs create mode 100644 src/proxy/tests/relay_watchdog_delta_security_tests.rs create mode 100644 src/stream/tls_stream_pending_plaintext_security_tests.rs diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 65b893d..d0aa3a2 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -87,6 +87,7 @@ use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; fn beobachten_ttl(config: &ProxyConfig) -> Duration { + const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60; let minutes = config.general.beobachten_minutes; if minutes == 0 { static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock = OnceLock::new(); @@ -99,7 +100,19 @@ fn beobachten_ttl(config: &ProxyConfig) -> Duration { return Duration::from_secs(60); } - Duration::from_secs(minutes.saturating_mul(60)) + if minutes > BEOBACHTEN_TTL_MAX_MINUTES { + static BEOBACHTEN_OVERSIZED_MINUTES_WARNED: OnceLock = OnceLock::new(); + let warned = BEOBACHTEN_OVERSIZED_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + configured_minutes = minutes, + max_minutes = BEOBACHTEN_TTL_MAX_MINUTES, + "general.beobachten_minutes is too large; clamping to secure maximum" + ); + } + } + + Duration::from_secs(minutes.min(BEOBACHTEN_TTL_MAX_MINUTES).saturating_mul(60)) } fn wrap_tls_application_record(payload: &[u8]) -> Vec { @@ -1277,3 +1290,7 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests; #[cfg(test)] #[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"] mod masking_probe_evasion_blackhat_tests; + +#[cfg(test)] +#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] +mod beobachten_ttl_bounds_security_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 8751436..0ac3c0d 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -78,6 +78,13 @@ fn auth_probe_saturation_state() -> &'static Mutex std::sync::MutexGuard<'static, Option> { + auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr { match peer_ip { IpAddr::V4(ip) => IpAddr::V4(ip), @@ -155,11 +162,7 @@ fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bo } fn auth_probe_saturation_is_throttled(now: Instant) -> bool { - let saturation = auth_probe_saturation_state(); - let mut guard = match saturation.lock() { - Ok(guard) => guard, - Err(_) => return false, - }; + let mut guard = auth_probe_saturation_state_lock(); let Some(state) = guard.as_mut() else { return false; @@ -178,11 +181,7 @@ fn auth_probe_saturation_is_throttled(now: Instant) -> bool { } fn auth_probe_note_saturation(now: Instant) { - let saturation = auth_probe_saturation_state(); - let mut guard = match saturation.lock() { - Ok(guard) => guard, - Err(_) => return, - }; + let mut guard = auth_probe_saturation_state_lock(); match guard.as_mut() { Some(state) @@ -356,9 +355,8 @@ fn clear_auth_probe_state_for_testing() { if let Some(state) = AUTH_PROBE_STATE.get() { state.clear(); } - if let Some(saturation) = AUTH_PROBE_SATURATION_STATE.get() - && let Ok(mut guard) = saturation.lock() - { + if AUTH_PROBE_SATURATION_STATE.get().is_some() { + let mut guard = auth_probe_saturation_state_lock(); *guard = None; } } @@ -975,6 +973,14 @@ mod adversarial_tests; #[path = "tests/handshake_fuzz_security_tests.rs"] mod fuzz_security_tests; +#[cfg(test)] +#[path = "tests/handshake_saturation_poison_security_tests.rs"] +mod saturation_poison_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] +mod auth_probe_hardening_adversarial_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index d212a43..2000977 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1653,3 +1653,7 @@ mod security_tests; #[cfg(test)] #[path = "tests/middle_relay_idle_policy_security_tests.rs"] mod idle_policy_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_desync_all_full_dedup_security_tests.rs"] +mod desync_all_full_dedup_security_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index c0cf3d4..6b71ace 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -81,6 +81,11 @@ const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800); /// without measurable overhead from atomic reads. const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10); +#[inline] +fn watchdog_delta(current: u64, previous: u64) -> u64 { + current.saturating_sub(previous) +} + // ============= CombinedStream ============= /// Combines separate read and write halves into a single bidirectional stream. @@ -210,6 +215,8 @@ struct StatsIo { quota_exceeded: Arc, quota_read_wake_scheduled: bool, quota_write_wake_scheduled: bool, + quota_read_retry_active: Arc, + quota_write_retry_active: Arc, epoch: Instant, } @@ -234,11 +241,20 @@ impl StatsIo { quota_exceeded, quota_read_wake_scheduled: false, quota_write_wake_scheduled: false, + quota_read_retry_active: Arc::new(AtomicBool::new(false)), + quota_write_retry_active: Arc::new(AtomicBool::new(false)), epoch, } } } +impl Drop for StatsIo { + fn drop(&mut self) { + self.quota_read_retry_active.store(false, Ordering::Relaxed); + self.quota_write_retry_active.store(false, Ordering::Relaxed); + } +} + #[derive(Debug)] struct QuotaIoSentinel; @@ -262,6 +278,26 @@ fn is_quota_io_error(err: &io::Error) -> bool { .is_some() } +#[cfg(test)] +const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1); +#[cfg(not(test))] +const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2); + +fn spawn_quota_retry_waker(retry_active: Arc, waker: std::task::Waker) { + tokio::task::spawn(async move { + loop { + if !retry_active.load(Ordering::Relaxed) { + break; + } + tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await; + if !retry_active.load(Ordering::Relaxed) { + break; + } + waker.wake_by_ref(); + } + }); +} + static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); @@ -334,16 +370,17 @@ impl AsyncRead for StatsIo { match lock.try_lock() { Ok(guard) => { this.quota_read_wake_scheduled = false; + this.quota_read_retry_active.store(false, Ordering::Relaxed); Some(guard) } Err(_) => { if !this.quota_read_wake_scheduled { this.quota_read_wake_scheduled = true; - let waker = cx.waker().clone(); - tokio::task::spawn(async move { - tokio::task::yield_now().await; - waker.wake(); - }); + this.quota_read_retry_active.store(true, Ordering::Relaxed); + spawn_quota_retry_waker( + Arc::clone(&this.quota_read_retry_active), + cx.waker().clone(), + ); } return Poll::Pending; } @@ -423,16 +460,17 @@ impl AsyncWrite for StatsIo { match lock.try_lock() { Ok(guard) => { this.quota_write_wake_scheduled = false; + this.quota_write_retry_active.store(false, Ordering::Relaxed); Some(guard) } Err(_) => { if !this.quota_write_wake_scheduled { this.quota_write_wake_scheduled = true; - let waker = cx.waker().clone(); - tokio::task::spawn(async move { - tokio::task::yield_now().await; - waker.wake(); - }); + this.quota_write_retry_active.store(true, Ordering::Relaxed); + spawn_quota_retry_waker( + Arc::clone(&this.quota_write_retry_active), + cx.waker().clone(), + ); } return Poll::Pending; } @@ -591,8 +629,8 @@ where // ── Periodic rate logging ─────────────────────────────── let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed); let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed); - let c2s_delta = c2s - prev_c2s; - let s2c_delta = s2c - prev_s2c; + let c2s_delta = watchdog_delta(c2s, prev_c2s); + let s2c_delta = watchdog_delta(s2c, prev_s2c); if c2s_delta > 0 || s2c_delta > 0 { let secs = WATCHDOG_INTERVAL.as_secs_f64(); @@ -729,4 +767,16 @@ mod relay_quota_model_adversarial_tests; #[cfg(test)] #[path = "tests/relay_quota_overflow_regression_tests.rs"] -mod relay_quota_overflow_regression_tests; \ No newline at end of file +mod relay_quota_overflow_regression_tests; + +#[cfg(test)] +#[path = "tests/relay_watchdog_delta_security_tests.rs"] +mod relay_watchdog_delta_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"] +mod relay_quota_waker_storm_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] +mod relay_quota_wake_liveness_regression_tests; \ No newline at end of file diff --git a/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs b/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs new file mode 100644 index 0000000..80f9834 --- /dev/null +++ b/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs @@ -0,0 +1,126 @@ +use super::*; + +const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60; + +#[test] +fn beobachten_ttl_exact_upper_bound_is_preserved() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "upper-bound TTL should remain unchanged" + ); +} + +#[test] +fn beobachten_ttl_above_upper_bound_is_clamped() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES + 1; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "TTL above security cap must be clamped" + ); +} + +#[test] +fn beobachten_ttl_u64_max_is_clamped_fail_safe() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = u64::MAX; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "extreme configured TTL must not become multi-century retention" + ); +} + +#[test] +fn positive_one_minute_maps_to_exact_60_seconds() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + assert_eq!(beobachten_ttl(&config), Duration::from_secs(60)); +} + +#[test] +fn adversarial_boundary_triplet_behaves_deterministically() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES - 1; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs((BEOBACHTEN_TTL_MAX_MINUTES - 1) * 60) + ); + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60) + ); + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES + 1; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60) + ); +} + +#[test] +fn light_fuzz_random_minutes_match_fail_safe_model() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + let mut seed = 0xD15E_A5E5_F00D_BAADu64; + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + config.general.beobachten_minutes = seed; + let ttl = beobachten_ttl(&config); + let expected = if seed == 0 { + Duration::from_secs(60) + } else { + Duration::from_secs(seed.min(BEOBACHTEN_TTL_MAX_MINUTES) * 60) + }; + + assert_eq!(ttl, expected, "ttl mismatch for minutes={seed}"); + assert!(ttl <= Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60)); + } +} + +#[test] +fn stress_monotonic_minutes_remain_monotonic_until_cap_then_flat() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + let mut prev = Duration::from_secs(0); + for minutes in 0..=(BEOBACHTEN_TTL_MAX_MINUTES + 4096) { + config.general.beobachten_minutes = minutes; + let ttl = beobachten_ttl(&config); + + assert!(ttl >= prev, "ttl must be non-decreasing as minutes grow"); + assert!(ttl <= Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60)); + + if minutes > BEOBACHTEN_TTL_MAX_MINUTES { + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "ttl must stay clamped once cap is exceeded" + ); + } + prev = ttl; + } +} diff --git a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs index 94732f5..920c013 100644 --- a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs +++ b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs @@ -4,7 +4,10 @@ use crate::crypto::sha256_hmac; use crate::protocol::constants::{ HANDSHAKE_LEN, MAX_TLS_CIPHERTEXT_SIZE, + TLS_RECORD_ALERT, TLS_RECORD_APPLICATION, + TLS_RECORD_CHANGE_CIPHER, + TLS_RECORD_HANDSHAKE, TLS_VERSION, }; use crate::protocol::tls; @@ -2753,3 +2756,106 @@ async fn blackhat_coalesced_tail_zero_following_record_after_coalesced_is_not_in .unwrap() .unwrap(); } + +#[tokio::test] +async fn blackhat_coalesced_tail_light_fuzz_mixed_followup_records_stay_byte_exact() { + let mut seed = 0xA11C_E2E5_F00D_BAADu64; + + for case in 0..24u32 { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let tail_len = (seed as usize % 1536) + 1; + let mut tail = vec![0u8; tail_len]; + for (i, b) in tail.iter_mut().enumerate() { + *b = (seed as u8).wrapping_add(i as u8).wrapping_mul(13); + } + + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let follow_type = match seed & 0x3 { + 0 => TLS_RECORD_APPLICATION, + 1 => TLS_RECORD_ALERT, + 2 => TLS_RECORD_CHANGE_CIPHER, + _ => TLS_RECORD_HANDSHAKE, + }; + let follow_len = (seed as usize % 96) + (case as usize % 3); + let mut follow_payload = vec![0u8; follow_len]; + for (i, b) in follow_payload.iter_mut().enumerate() { + *b = (case as u8).wrapping_mul(29).wrapping_add(i as u8); + } + + let secret = [0xD1u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 600 + case, 600, 0x33); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let follow_record = wrap_tls_record(follow_type, &follow_payload); + let expected_wire = [expected_tail.clone(), follow_record.clone()].concat(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_wire.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_wire); + }); + + let harness = build_harness("d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = format!("198.51.100.250:{}", 57000 + case as u16) + .parse() + .unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + let mut local_seed = seed ^ 0x55AA_55AA_1234_5678; + for data in [&coalesced_record, &follow_record] { + let mut pos = 0usize; + while pos < data.len() { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + let step = ((local_seed as usize % 17) + 1).min(data.len() - pos); + let end = pos + step; + client_side.write_all(&data[pos..end]).await.unwrap(); + pos = end; + } + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } +} diff --git a/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs b/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs new file mode 100644 index 0000000..d8fac4f --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs @@ -0,0 +1,187 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn positive_preauth_throttle_activates_after_failure_threshold() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 20)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, now); + } + + assert!( + auth_probe_is_throttled(ip, now), + "peer must be throttled once fail streak reaches threshold" + ); +} + +#[test] +fn negative_unrelated_peer_remains_unthrottled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let attacker = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 12)); + let benign = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 13)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(attacker, now); + } + + assert!(auth_probe_is_throttled(attacker, now)); + assert!( + !auth_probe_is_throttled(benign, now), + "throttle state must stay scoped to normalized peer key" + ); +} + +#[test] +fn edge_expired_entry_is_pruned_and_no_longer_throttled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 41)); + let base = Instant::now(); + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, base); + } + + let expired_at = base + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1); + assert!( + !auth_probe_is_throttled(ip, expired_at), + "expired entries must not keep throttling peers" + ); + + let state = auth_probe_state_map(); + assert!( + state.get(&normalize_auth_probe_ip(ip)).is_none(), + "expired lookup should prune stale state" + ); +} + +#[test] +fn adversarial_saturation_grace_requires_extra_failures_before_preauth_throttle() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(198, 18, 0, 7)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, now); + } + auth_probe_note_saturation(now); + + assert!( + !auth_probe_should_apply_preauth_throttle(ip, now), + "during global saturation, peer must receive configured grace window" + ); + + for _ in 0..AUTH_PROBE_SATURATION_GRACE_FAILS { + auth_probe_record_failure(ip, now + Duration::from_millis(1)); + } + + assert!( + auth_probe_should_apply_preauth_throttle(ip, now + Duration::from_millis(1)), + "after grace failures are exhausted, preauth throttle must activate" + ); +} + +#[test] +fn integration_over_cap_insertion_keeps_probe_map_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 1024) { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + ((idx / 65_536) % 256) as u8, + ((idx / 256) % 256) as u8, + (idx % 256) as u8, + )); + auth_probe_record_failure(ip, now); + } + + let tracked = auth_probe_state_map().len(); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "probe map must remain hard bounded under insertion storm" + ); +} + +#[test] +fn light_fuzz_randomized_failures_preserve_cap_and_nonzero_streaks() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut seed = 0x4D53_5854_6F66_6175u64; + let now = Instant::now(); + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + auth_probe_record_failure(ip, now + Duration::from_millis((seed & 0x3f) as u64)); + } + + let state = auth_probe_state_map(); + assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES); + for entry in state.iter() { + assert!(entry.value().fail_streak > 0); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_failure_flood_keeps_state_hard_capped() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let start = Instant::now(); + let mut tasks = Vec::new(); + + for worker in 0..8u8 { + tasks.push(tokio::spawn(async move { + for i in 0..4096u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + worker, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + auth_probe_record_failure(ip, start + Duration::from_millis((i % 4) as u64)); + } + })); + } + + for task in tasks { + task.await.expect("stress worker must not panic"); + } + + let tracked = auth_probe_state_map().len(); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "parallel failure flood must not exceed cap" + ); + + let probe = IpAddr::V4(Ipv4Addr::new(172, 3, 4, 5)); + let _ = auth_probe_is_throttled(probe, start + Duration::from_millis(2)); +} diff --git a/src/proxy/tests/handshake_saturation_poison_security_tests.rs b/src/proxy/tests/handshake_saturation_poison_security_tests.rs new file mode 100644 index 0000000..4c2ca5d --- /dev/null +++ b/src/proxy/tests/handshake_saturation_poison_security_tests.rs @@ -0,0 +1,71 @@ +use super::*; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn poison_saturation_mutex() { + let saturation = auth_probe_saturation_state(); + let poison_thread = std::thread::spawn(move || { + let _guard = saturation + .lock() + .expect("saturation mutex must be lockable for poison setup"); + panic!("intentional poison for saturation mutex resilience test"); + }); + let _ = poison_thread.join(); +} + +#[test] +fn auth_probe_saturation_note_recovers_after_mutex_poison() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + let now = Instant::now(); + auth_probe_note_saturation(now); + + assert!( + auth_probe_saturation_is_throttled_at_for_testing(now), + "poisoned saturation mutex must not disable saturation throttling" + ); +} + +#[test] +fn auth_probe_saturation_check_recovers_after_mutex_poison() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + { + let mut guard = auth_probe_saturation_state_lock(); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: Instant::now() + Duration::from_millis(10), + last_seen: Instant::now(), + }); + } + + assert!( + auth_probe_saturation_is_throttled_for_testing(), + "throttle check must recover poisoned saturation mutex and stay fail-closed" + ); +} + +#[test] +fn clear_auth_probe_state_clears_saturation_even_if_poisoned() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + auth_probe_note_saturation(Instant::now()); + assert!(auth_probe_saturation_is_throttled_for_testing()); + + clear_auth_probe_state_for_testing(); + assert!( + !auth_probe_saturation_is_throttled_for_testing(), + "clear helper must clear saturation state even after poison" + ); +} diff --git a/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs b/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs new file mode 100644 index 0000000..574a3f9 --- /dev/null +++ b/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs @@ -0,0 +1,179 @@ +use super::*; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::thread; + +#[test] +fn desync_all_full_bypass_does_not_initialize_or_grow_dedup_cache() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let initial_len = DESYNC_DEDUP.get().map(|dedup| dedup.len()).unwrap_or(0); + let now = Instant::now(); + + for i in 0..20_000u64 { + assert!( + should_emit_full_desync(0xD35E_D000_0000_0000u64 ^ i, true, now), + "desync_all_full path must always emit" + ); + } + + let after_len = DESYNC_DEDUP.get().map(|dedup| dedup.len()).unwrap_or(0); + assert_eq!( + after_len, initial_len, + "desync_all_full bypass must not allocate or accumulate dedup entries" + ); +} + +#[test] +fn desync_all_full_bypass_keeps_existing_dedup_entries_unchanged() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let seed_time = Instant::now() - Duration::from_secs(7); + dedup.insert(0xAAAABBBBCCCCDDDD, seed_time); + dedup.insert(0x1111222233334444, seed_time); + + let now = Instant::now(); + for i in 0..2048u64 { + assert!( + should_emit_full_desync(0xF011_F000_0000_0000u64 ^ i, true, now), + "desync_all_full must bypass suppression and dedup refresh" + ); + } + + assert_eq!(dedup.len(), 2, "bypass path must not mutate dedup cardinality"); + assert_eq!( + *dedup + .get(&0xAAAABBBBCCCCDDDD) + .expect("seed key must remain"), + seed_time, + "bypass path must not refresh existing dedup timestamps" + ); + assert_eq!( + *dedup + .get(&0x1111222233334444) + .expect("seed key must remain"), + seed_time, + "bypass path must not touch unrelated dedup entries" + ); +} + +#[test] +fn edge_all_full_burst_does_not_poison_later_false_path_tracking() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for i in 0..8192u64 { + assert!(should_emit_full_desync(0xABCD_0000_0000_0000 ^ i, true, now)); + } + + let tracked_key = 0xDEAD_BEEF_0000_0001u64; + assert!( + should_emit_full_desync(tracked_key, false, now), + "first false-path event after all_full burst must still be tracked and emitted" + ); + + let dedup = DESYNC_DEDUP + .get() + .expect("false path should initialize dedup"); + assert!(dedup.get(&tracked_key).is_some()); +} + +#[test] +fn adversarial_mixed_sequence_true_steps_never_change_cache_len() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + for i in 0..256u64 { + dedup.insert(0x1000_0000_0000_0000 ^ i, Instant::now()); + } + + let mut seed = 0xC0DE_CAFE_BAAD_F00Du64; + for i in 0..4096u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let flag_all_full = (seed & 0x1) == 1; + let key = 0x7000_0000_0000_0000u64 ^ i ^ seed; + let before = dedup.len(); + let _ = should_emit_full_desync(key, flag_all_full, Instant::now()); + let after = dedup.len(); + + if flag_all_full { + assert_eq!(after, before, "all_full step must not mutate dedup length"); + } + } +} + +#[test] +fn light_fuzz_all_full_mode_always_emits_and_stays_bounded() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let mut seed = 0x1234_5678_9ABC_DEF0u64; + let before = DESYNC_DEDUP.get().map(|d| d.len()).unwrap_or(0); + + for _ in 0..20_000 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let key = seed ^ 0x55AA_55AA_55AA_55AAu64; + assert!(should_emit_full_desync(key, true, Instant::now())); + } + + let after = DESYNC_DEDUP.get().map(|d| d.len()).unwrap_or(0); + assert_eq!(after, before); + assert!(after <= DESYNC_DEDUP_MAX_ENTRIES); +} + +#[test] +fn stress_parallel_all_full_storm_does_not_grow_or_mutate_cache() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let seed_time = Instant::now() - Duration::from_secs(2); + for i in 0..1024u64 { + dedup.insert(0x8888_0000_0000_0000 ^ i, seed_time); + } + let before_len = dedup.len(); + + let emits = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + for worker in 0..16u64 { + let emits = Arc::clone(&emits); + workers.push(thread::spawn(move || { + let now = Instant::now(); + for i in 0..4096u64 { + let key = 0xFACE_0000_0000_0000u64 ^ (worker << 20) ^ i; + if should_emit_full_desync(key, true, now) { + emits.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("worker must not panic"); + } + + assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096); + assert_eq!(dedup.len(), before_len, "parallel all_full storm must not mutate cache len"); +} diff --git a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs new file mode 100644 index 0000000..f68609a --- /dev/null +++ b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs @@ -0,0 +1,290 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Barrier; +use tokio::time::{Duration, timeout}; + +fn saturate_lock_cache() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}"))); + } + retained +} + +fn quota_test_guard() -> std::sync::MutexGuard<'static, ()> { + super::quota_user_lock_test_guard() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[tokio::test] +async fn positive_writer_progresses_after_contention_release_without_external_wake() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-writer-positive"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before write"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x11]).await }); + + // Let the initial deferred wake fire while contention is still active. + tokio::time::sleep(Duration::from_millis(4)).await; + + drop(held_guard); + + let completed = timeout(Duration::from_millis(250), writer) + .await + .expect("writer must be re-polled and complete after lock release") + .expect("writer task must not panic"); + assert!(completed.is_ok(), "writer must complete after lock release"); +} + +#[tokio::test] +async fn edge_reader_progresses_after_contention_release_without_external_wake() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-reader-edge"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before read"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let reader = tokio::spawn(async move { + let mut one = [0u8; 1]; + io.read(&mut one).await + }); + + tokio::time::sleep(Duration::from_millis(4)).await; + drop(held_guard); + + let completed = timeout(Duration::from_millis(250), reader) + .await + .expect("reader must be re-polled and complete after lock release") + .expect("reader task must not panic"); + assert!(completed.is_ok(), "reader must complete after lock release"); +} + +#[tokio::test] +async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-adversarial"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before adversarial write"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x22]).await }); + + // Force multiple scheduler rounds while lock remains held so the first + // deferred wake has already been consumed under contention. + for _ in 0..32 { + tokio::task::yield_now().await; + } + + drop(held_guard); + + let completed = timeout(Duration::from_millis(300), writer) + .await + .expect("writer must not stay parked forever after release") + .expect("writer task must not panic"); + assert!(completed.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_parallel_waiters_resume_after_single_release_event() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = format!("quota-liveness-integration-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let barrier = Arc::new(Barrier::new(13)); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before launching waiters"); + + let mut waiters = Vec::new(); + for _ in 0..12 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let barrier = Arc::clone(&barrier); + waiters.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(4096), + quota_exceeded, + tokio::time::Instant::now(), + ); + barrier.wait().await; + io.write_all(&[0x33]).await + })); + } + + barrier.wait().await; + tokio::time::sleep(Duration::from_millis(4)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let outcome = waiter.await.expect("waiter must not panic"); + assert!(outcome.is_ok(), "waiter must resume and complete after release"); + } + }) + .await + .expect("all waiters must complete in bounded time"); +} + +#[tokio::test] +async fn light_fuzz_release_timing_matrix_preserves_liveness() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let stats = Arc::new(Stats::new()); + + let mut seed = 0xD1CE_F00D_0123_4567u64; + for round in 0..64u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let delay_ms = 1 + (seed & 0x7) as u64; + let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock in fuzz round"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x44]).await }); + + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(300), writer) + .await + .expect("fuzz round writer must complete") + .expect("fuzz writer task must not panic"); + assert!(done.is_ok(), "fuzz round writer must not stall after release"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_repeated_contention_cycles_remain_live() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let stats = Arc::new(Stats::new()); + + for cycle in 0..40u32 { + let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold lock before stress cycle"); + + let mut tasks = Vec::new(); + for _ in 0..6 { + let stats = Arc::clone(&stats); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + io.write_all(&[0x55]).await + })); + } + + tokio::task::yield_now().await; + drop(held_guard); + + timeout(Duration::from_millis(700), async { + for task in tasks { + let outcome = task.await.expect("stress task must not panic"); + assert!(outcome.is_ok(), "stress writer must complete"); + } + }) + .await + .expect("stress cycle must finish in bounded time"); + } +} diff --git a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs new file mode 100644 index 0000000..666d90c --- /dev/null +++ b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs @@ -0,0 +1,306 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::{ReadBuf, AsyncWriteExt}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> std::sync::MutexGuard<'static, ()> { + super::quota_user_lock_test_guard() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn saturate_quota_user_locks() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}"))); + } + retained +} + +#[tokio::test] +async fn positive_contended_writer_emits_deferred_wake_for_liveness() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-positive-user"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling writer"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); + assert!(pending.is_pending()); + + timeout(Duration::from_millis(100), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("contended writer must receive deferred wake"); + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); + assert!(ready.is_ready(), "writer must progress after contention release"); +} + +#[tokio::test] +async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-blackhat-writer"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling writer"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + for _ in 0..512 { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]); + assert!(poll.is_pending(), "writer must stay pending while lock is held"); + tokio::task::yield_now().await; + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes <= 128, + "pending writer retries must not trigger wake storm; observed wakes={wakes}" + ); + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn edge_read_path_contention_keeps_wake_budget_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-read-edge"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling reader"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + + for _ in 0..512 { + let mut buf = ReadBuf::new(&mut storage); + let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); + tokio::task::yield_now().await; + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes <= 128, + "pending reader retries must not trigger wake storm; observed wakes={wakes}" + ); + + drop(held_guard); + let mut buf = ReadBuf::new(&mut storage); + let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-fuzz-user"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before fuzz polling"); + + let counters_w = Arc::new(SharedCounters::new()); + let mut writer_io = StatsIo::new( + tokio::io::sink(), + counters_w, + Arc::clone(&stats), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let counters_r = Arc::new(SharedCounters::new()); + let mut reader_io = StatsIo::new( + tokio::io::empty(), + counters_r, + Arc::clone(&stats), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut seed = 0xBADC_0FFE_EE11_2211u64; + let mut storage = [0u8; 1]; + + for _ in 0..1024 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + if (seed & 1) == 0 { + let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]); + assert!(poll.is_pending()); + } else { + let mut buf = ReadBuf::new(&mut storage); + let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); + } + tokio::task::yield_now().await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 192, + "mixed contention fuzz must keep deferred wake count tightly bounded" + ); + + drop(held_guard); + let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]); + assert!(ready_w.is_ready()); + + let mut buf = ReadBuf::new(&mut storage); + let ready_r = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); + assert!(ready_r.is_ready()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"] +async fn stress_many_contended_writers_complete_after_release() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-waker-stress-user".to_string(); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before launching contended tasks"); + + let mut tasks = Vec::new(); + for _ in 0..32 { + let stats = Arc::clone(&stats); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + + io.write_all(&[0xAA]).await + })); + } + + for _ in 0..8 { + tokio::task::yield_now().await; + } + + drop(held_guard); + + timeout(Duration::from_secs(2), async { + for task in tasks { + let result = task.await.expect("stress task must not panic"); + assert!(result.is_ok(), "task must complete after lock release"); + } + }) + .await + .expect("all contended writer tasks must finish in bounded time after release"); +} diff --git a/src/proxy/tests/relay_watchdog_delta_security_tests.rs b/src/proxy/tests/relay_watchdog_delta_security_tests.rs new file mode 100644 index 0000000..f05ee62 --- /dev/null +++ b/src/proxy/tests/relay_watchdog_delta_security_tests.rs @@ -0,0 +1,61 @@ +use super::watchdog_delta; + +#[test] +fn positive_monotonic_growth_returns_exact_delta() { + assert_eq!(watchdog_delta(42, 40), 2); + assert_eq!(watchdog_delta(4096, 1024), 3072); +} + +#[test] +fn edge_equal_values_return_zero_delta() { + assert_eq!(watchdog_delta(0, 0), 0); + assert_eq!(watchdog_delta(777, 777), 0); +} + +#[test] +fn adversarial_wrap_like_regression_saturates_to_zero() { + // Simulates a wrapped or reset counter observation where current < previous. + assert_eq!(watchdog_delta(0, 1), 0); + assert_eq!(watchdog_delta(12, 4096), 0); +} + +#[test] +fn adversarial_blackhat_large_previous_value_never_underflows() { + let current = 3u64; + let previous = u64::MAX - 1; + assert_eq!(watchdog_delta(current, previous), 0); +} + +#[test] +fn light_fuzz_mixed_pairs_match_saturating_sub_contract() { + // Deterministic xorshift64* generator for reproducible pseudo-fuzzing. + let mut seed = 0xA51C_ED42_D00D_F00Du64; + + for _ in 0..10_000 { + seed ^= seed >> 12; + seed ^= seed << 25; + seed ^= seed >> 27; + let current = seed.wrapping_mul(0x2545_F491_4F6C_DD1D); + + seed ^= seed >> 12; + seed ^= seed << 25; + seed ^= seed >> 27; + let previous = seed.wrapping_mul(0x2545_F491_4F6C_DD1D); + + let expected = current.saturating_sub(previous); + let actual = watchdog_delta(current, previous); + assert_eq!(actual, expected, "delta mismatch for ({current}, {previous})"); + } +} + +#[test] +fn stress_long_running_monotonic_sequence_remains_exact() { + let mut prev = 0u64; + + for step in 1u64..=200_000 { + let curr = prev.saturating_add(step & 0x7); + let delta = watchdog_delta(curr, prev); + assert_eq!(delta, curr - prev); + prev = curr; + } +} diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs index d405cda..7053a7b 100644 --- a/src/stream/tls_stream.rs +++ b/src/stream/tls_stream.rs @@ -297,6 +297,11 @@ impl FakeTlsReader { pub fn into_inner_with_pending_plaintext(mut self) -> (R, Vec) { let pending = match std::mem::replace(&mut self.state, TlsReaderState::Idle) { TlsReaderState::Yielding { buffer } => buffer.as_slice().to_vec(), + TlsReaderState::ReadingBody { record_type, buffer, .. } + if record_type == TLS_RECORD_APPLICATION => + { + buffer.to_vec() + } _ => Vec::new(), }; (self.upstream, pending) @@ -1293,3 +1298,7 @@ mod tests { assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]); } } + +#[cfg(test)] +#[path = "tls_stream_pending_plaintext_security_tests.rs"] +mod pending_plaintext_security_tests; diff --git a/src/stream/tls_stream_pending_plaintext_security_tests.rs b/src/stream/tls_stream_pending_plaintext_security_tests.rs new file mode 100644 index 0000000..30a11ad --- /dev/null +++ b/src/stream/tls_stream_pending_plaintext_security_tests.rs @@ -0,0 +1,143 @@ +use super::*; +use bytes::{Bytes, BytesMut}; + +#[test] +fn reading_body_pending_application_plaintext_is_preserved_on_into_inner() { + let sample = b"coalesced-tail-after-mtproto"; + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type: TLS_RECORD_APPLICATION, + length: sample.len(), + buffer: BytesMut::from(&sample[..]), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert_eq!( + pending, + sample, + "partial application-data body must survive into fallback path" + ); +} + +#[test] +fn yielding_pending_plaintext_is_preserved_on_into_inner() { + let sample = b"already-decoded-buffer"; + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::Yielding { + buffer: YieldBuffer::new(Bytes::copy_from_slice(sample)), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert_eq!(pending, sample); +} + +#[test] +fn reading_body_non_application_record_does_not_produce_plaintext() { + let sample = b"unexpected-handshake-fragment"; + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type: TLS_RECORD_HANDSHAKE, + length: sample.len(), + buffer: BytesMut::from(&sample[..]), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert!( + pending.is_empty(), + "non-application partial body must not be surfaced as plaintext" + ); +} + +#[test] +fn partial_header_state_does_not_produce_plaintext() { + let mut header = HeaderBuffer::::new(); + let unfilled = header.unfilled_mut(); + unfilled[0] = TLS_RECORD_APPLICATION; + header.advance(1); + + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingHeader { header }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert!(pending.is_empty(), "partial header bytes are not plaintext payload"); +} + +#[test] +fn edge_zero_length_application_fragment_remains_empty_without_panics() { + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type: TLS_RECORD_APPLICATION, + length: 0, + buffer: BytesMut::new(), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert!(pending.is_empty()); +} + +#[test] +fn adversarial_poisoned_state_never_leaks_pending_bytes() { + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::Poisoned { + error: Some(std::io::Error::other("poisoned by adversarial input")), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert!(pending.is_empty(), "poisoned state must fail-closed for fallback payload"); +} + +#[test] +fn stress_large_application_fragment_survives_state_extraction() { + let mut payload = vec![0u8; 96 * 1024]; + for (i, b) in payload.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(17).wrapping_add(3); + } + + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type: TLS_RECORD_APPLICATION, + length: payload.len(), + buffer: BytesMut::from(&payload[..]), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert_eq!(pending, payload, "large pending application plaintext must be preserved exactly"); +} + +#[test] +fn light_fuzz_state_matrix_preserves_pending_contract() { + let mut seed = 0x9E37_79B9_7F4A_7C15u64; + + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = (seed & 0x1ff) as usize; + let mut payload = vec![0u8; len]; + for (idx, b) in payload.iter_mut().enumerate() { + *b = (seed as u8).wrapping_add(idx as u8); + } + + let record_type = match seed & 0x3 { + 0 => TLS_RECORD_APPLICATION, + 1 => TLS_RECORD_HANDSHAKE, + 2 => TLS_RECORD_ALERT, + _ => TLS_RECORD_CHANGE_CIPHER, + }; + + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type, + length: payload.len(), + buffer: BytesMut::from(&payload[..]), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + if record_type == TLS_RECORD_APPLICATION { + assert_eq!(pending, payload); + } else { + assert!(pending.is_empty()); + } + } +} diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 054b5ed..2394992 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -591,14 +591,9 @@ impl MePool { if let Some(tx) = close_tx { let _ = tx.send(WriterCommand::Close).await; } - if let Some(addr) = removed_addr - && let Some(uptime) = removed_uptime - { - // Quarantine flapping endpoints regardless of draining state. - self.maybe_quarantine_flapping_endpoint(addr, uptime).await; - } if let Some(addr) = removed_addr { if let Some(uptime) = removed_uptime { + // Quarantine flapping endpoints regardless of draining state. self.maybe_quarantine_flapping_endpoint(addr, uptime).await; } if trigger_refill diff --git a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs index bbc9790..27b9635 100644 --- a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; @@ -9,6 +9,7 @@ use tokio_util::sync::CancellationToken; use super::codec::WriterCommand; use super::pool::{MePool, MeWriter, WriterContour}; +use super::registry::ConnMeta; use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; use crate::crypto::SecureRandom; use crate::network::probe::NetworkDecision; @@ -141,6 +142,34 @@ async fn insert_writer( pool.conn_count.fetch_add(1, Ordering::Relaxed); } +async fn current_writer_ids(pool: &Arc) -> HashSet { + pool.writers + .read() + .await + .iter() + .map(|writer| writer.id) + .collect() +} + +async fn bind_conn_to_writer(pool: &Arc, writer_id: u64, port: u16) -> u64 { + let (conn_id, _rx) = pool.registry.register().await; + let bound = pool + .registry + .bind_writer( + conn_id, + writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await; + assert!(bound, "writer binding must succeed"); + conn_id +} + #[tokio::test] async fn remove_draining_writer_still_quarantines_flapping_endpoint() { let pool = make_pool().await; @@ -174,3 +203,180 @@ async fn remove_draining_writer_still_quarantines_flapping_endpoint() { ); assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); } + +#[tokio::test] +async fn positive_remove_writer_cleans_bound_registry_routes() { + let pool = make_pool().await; + let writer_id = 88; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 88)), 443); + insert_writer(&pool, writer_id, 2, addr, false, Instant::now()).await; + + let conn_id = bind_conn_to_writer(&pool, writer_id, 7301).await; + assert!(pool.registry.get_writer(conn_id).await.is_some()); + + pool.remove_writer_and_close_clients(writer_id).await; + + assert!(pool.registry.get_writer(conn_id).await.is_none()); + assert!(!current_writer_ids(&pool).await.contains(&writer_id)); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn negative_unknown_writer_removal_is_noop() { + let pool = make_pool().await; + let before_quarantine = pool.stats.get_me_endpoint_quarantine_total(); + + pool.remove_writer_and_close_clients(9_999_001).await; + + assert!(pool.writers.read().await.is_empty()); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); + assert_eq!(pool.stats.get_me_endpoint_quarantine_total(), before_quarantine); +} + +#[tokio::test] +async fn edge_draining_only_detach_rejects_active_writer() { + let pool = make_pool().await; + let writer_id = 91; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 91)), 443); + insert_writer(&pool, writer_id, 2, addr, false, Instant::now()).await; + + let removed = pool.remove_draining_writer_hard_detach(writer_id).await; + assert!(!removed, "active writer must not be detached by draining-only path"); + assert!(current_writer_ids(&pool).await.contains(&writer_id)); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 1); + + pool.remove_writer_and_close_clients(writer_id).await; +} + +#[tokio::test] +async fn adversarial_blackhat_single_remove_establishes_single_quarantine_entry() { + let pool = make_pool().await; + let writer_id = 93; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 93)), 443); + insert_writer( + &pool, + writer_id, + 2, + addr, + true, + Instant::now() - Duration::from_secs(1), + ) + .await; + + pool.remove_writer_and_close_clients(writer_id).await; + assert!(pool.is_endpoint_quarantined(addr).await); + assert_eq!(pool.endpoint_quarantine.lock().await.len(), 1); +} + +#[tokio::test] +async fn integration_old_uptime_writer_does_not_trigger_flap_quarantine() { + let pool = make_pool().await; + let writer_id = 94; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 94)), 443); + insert_writer( + &pool, + writer_id, + 2, + addr, + false, + Instant::now() - Duration::from_secs(30), + ) + .await; + + let before = pool.stats.get_me_endpoint_quarantine_total(); + pool.remove_writer_and_close_clients(writer_id).await; + let after = pool.stats.get_me_endpoint_quarantine_total(); + + assert_eq!(after, before); + assert!(!pool.is_endpoint_quarantined(addr).await); +} + +#[tokio::test] +async fn light_fuzz_insert_remove_schedule_preserves_pool_invariants() { + let pool = make_pool().await; + let mut seed = 0xA11C_E551_D00D_BAADu64; + let mut model = HashSet::::new(); + + for _ in 0..240 { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let writer_id = 1 + (seed % 64); + let op_insert = ((seed >> 17) & 1) == 0; + + if op_insert { + if !model.contains(&writer_id) { + let ip_octet = (writer_id % 250) as u8; + let addr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 13, 0, ip_octet.max(1))), + 4000 + writer_id as u16, + ); + let draining = ((seed >> 33) & 1) == 1; + let created_at = if draining { + Instant::now() - Duration::from_secs(1) + } else { + Instant::now() - Duration::from_secs(30) + }; + insert_writer(&pool, writer_id, 2, addr, draining, created_at).await; + model.insert(writer_id); + } + } else { + pool.remove_writer_and_close_clients(writer_id).await; + model.remove(&writer_id); + } + + let actual_ids = current_writer_ids(&pool).await; + assert_eq!(actual_ids, model, "writer-id set must match model under fuzz schedule"); + assert_eq!(pool.conn_count.load(Ordering::Relaxed) as usize, model.len()); + } + + for writer_id in model { + pool.remove_writer_and_close_clients(writer_id).await; + } + assert!(pool.writers.read().await.is_empty()); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_duplicate_removals_are_idempotent() { + let pool = make_pool().await; + + for writer_id in 1..=48u64 { + let addr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 14, (writer_id / 250) as u8, (writer_id % 250) as u8)), + 5000 + writer_id as u16, + ); + insert_writer( + &pool, + writer_id, + 2, + addr, + true, + Instant::now() - Duration::from_secs(1), + ) + .await; + } + + let mut tasks = Vec::new(); + for worker in 0..8u64 { + let pool = Arc::clone(&pool); + tasks.push(tokio::spawn(async move { + for writer_id in 1..=48u64 { + if ((writer_id + worker) & 1) == 0 { + pool.remove_writer_and_close_clients(writer_id).await; + } else { + pool.remove_writer_and_close_clients(100_000 + writer_id).await; + } + } + })); + } + + for task in tasks { + task.await.expect("stress remover task must not panic"); + } + + for writer_id in 1..=48u64 { + pool.remove_writer_and_close_clients(writer_id).await; + } + + assert!(pool.writers.read().await.is_empty()); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); +}