From ead23608f0535d1215f6622e87979de2af7b5b37 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Sun, 22 Mar 2026 15:30:02 +0400 Subject: [PATCH 01/29] Add stress and manual benchmark tests for handshake protocols - Introduced `handshake_real_bug_stress_tests.rs` to validate TLS and MTProto handshake behaviors under various conditions, including ALPN rejection and session ID handling. - Implemented tests to ensure replay cache integrity and proper handling of malicious input without panicking. - Added `handshake_timing_manual_bench_tests.rs` for performance benchmarking of user authentication paths, comparing preferred user handling against full user scans in both MTProto and TLS contexts. - Included timing-sensitive tests to measure the impact of SNI on handshake performance. --- docs/CONFIG_PARAMS.en.md | 3 + src/proxy/client.rs | 12 + src/proxy/handshake.rs | 36 +- .../tests/client_clever_advanced_tests.rs | 409 +++++++++++ .../tests/client_deep_invariants_tests.rs | 196 ++++++ src/proxy/tests/client_more_advanced_tests.rs | 257 +++++++ src/proxy/tests/client_security_tests.rs | 199 ++++++ .../tests/handshake_advanced_clever_tests.rs | 647 ++++++++++++++++++ .../tests/handshake_more_clever_tests.rs | 614 +++++++++++++++++ .../tests/handshake_real_bug_stress_tests.rs | 337 +++++++++ .../handshake_timing_manual_bench_tests.rs | 318 +++++++++ 11 files changed, 3018 insertions(+), 10 deletions(-) create mode 100644 src/proxy/tests/client_clever_advanced_tests.rs create mode 100644 src/proxy/tests/client_deep_invariants_tests.rs create mode 100644 src/proxy/tests/client_more_advanced_tests.rs create mode 100644 src/proxy/tests/handshake_advanced_clever_tests.rs create mode 100644 src/proxy/tests/handshake_more_clever_tests.rs create mode 100644 src/proxy/tests/handshake_real_bug_stress_tests.rs create mode 100644 src/proxy/tests/handshake_timing_manual_bench_tests.rs diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index 3eee3a7..f8c56a0 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -198,11 +198,14 @@ This document lists all configuration keys accepted by `config.toml`. | listen_tcp | `bool \| null` | `null` (auto) | — | Explicit TCP listener enable/disable override. | | proxy_protocol | `bool` | `false` | — | Enables HAProxy PROXY protocol parsing on incoming client connections. | | proxy_protocol_header_timeout_ms | `u64` | `500` | Must be `> 0`. | Timeout for PROXY protocol header read/parse (ms). | +| proxy_protocol_trusted_cidrs | `IpNetwork[]` | `[]` | — | When non-empty, only connections from these proxy source CIDRs are allowed to provide PROXY protocol headers. If empty, PROXY headers are rejected by default (security hardening). | | metrics_port | `u16 \| null` | `null` | — | Metrics endpoint port (enables metrics listener). | | metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. | | metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. | | max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). | +Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted. + ## [server.api] | Parameter | Type | Default | Constraints / validation | Description | diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 4b7f57e..1941f36 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1328,3 +1328,15 @@ mod beobachten_ttl_bounds_security_tests; #[cfg(test)] #[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"] mod tls_record_wrap_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/client_clever_advanced_tests.rs"] +mod client_clever_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_more_advanced_tests.rs"] +mod client_more_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_deep_invariants_tests.rs"] +mod client_deep_invariants_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 5632977..7e4b62c 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -605,16 +605,6 @@ where } }; - // Replay tracking is applied only after successful authentication to avoid - // letting unauthenticated probes evict valid entries from the replay cache. - let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_and_add_tls_digest(digest_half) { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => { @@ -670,6 +660,16 @@ where None }; + // Replay tracking is applied only after full policy validation (including + // ALPN checks) so rejected handshakes cannot poison replay state. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_and_add_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( secret, @@ -979,6 +979,22 @@ mod saturation_poison_security_tests; #[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] mod auth_probe_hardening_adversarial_tests; +#[cfg(test)] +#[path = "tests/handshake_advanced_clever_tests.rs"] +mod advanced_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_more_clever_tests.rs"] +mod more_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_real_bug_stress_tests.rs"] +mod real_bug_stress_tests; + +#[cfg(test)] +#[path = "tests/handshake_timing_manual_bench_tests.rs"] +mod timing_manual_bench_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. diff --git a/src/proxy/tests/client_clever_advanced_tests.rs b/src/proxy/tests/client_clever_advanced_tests.rs new file mode 100644 index 0000000..da2e703 --- /dev/null +++ b/src/proxy/tests/client_clever_advanced_tests.rs @@ -0,0 +1,409 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType, ProxyConfig}; +use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf, duplex}; +use tokio::net::TcpListener; + +#[test] +fn edge_mask_reject_delay_min_greater_than_max_does_not_panic() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 1000; + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + let elapsed = start.elapsed(); + + assert!(elapsed >= Duration::from_millis(1000)); + assert!(elapsed < Duration::from_millis(1500)); + }); +} + +#[test] +fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert_eq!(timeout.as_secs(), u64::MAX); +} + +#[test] +fn edge_tls_clienthello_len_in_bounds_exact_boundaries() { + assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); + assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1)); + assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); + assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); +} + +#[test] +fn edge_synthetic_local_addr_boundaries() { + assert_eq!(synthetic_local_addr(0).port(), 0); + assert_eq!(synthetic_local_addr(80).port(), 80); + assert_eq!(synthetic_local_addr(u16::MAX).port(), u16::MAX); +} + +#[test] +fn edge_beobachten_record_handshake_failure_class_stream_error_eof() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof_err = ProxyError::Stream(crate::error::StreamError::UnexpectedEof); + let peer_ip: IpAddr = "198.51.100.100".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof_err); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn adversarial_tls_handshake_timeout_during_masking_delay() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.server_hello_delay_min_ms = 3000; + cfg.censorship.server_hello_delay_max_ms = 3000; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.1:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout))); + assert_eq!(stats.get_handshake_timeouts(), 1); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_slowloris_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 200; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.2:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(b"PROXY TCP4 192.").await.unwrap(); + tokio::time::sleep(Duration::from_millis(300)).await; + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[test] +fn blackhat_ipv4_mapped_ipv6_proxy_source_bypass_attempt() { + let trusted = vec!["192.0.2.0/24".parse().unwrap()]; + let peer_ip = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + assert!(!is_trusted_proxy_source(peer_ip, &trusted)); +} + +#[tokio::test] +async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 500; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.3:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_client_stream_exactly_4_bytes_eof() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.4:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.5:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side.write_all(&vec![0x41; 99]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn integration_non_tls_modes_disabled_immediately_masks() { + let mut cfg = ProxyConfig::default(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + cfg.censorship.mask = true; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.6:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(b"GET / HTTP/1.1\r\n").await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + assert_eq!(stats.get_connects_bad(), 1); +} + +struct YieldingReader { + data: Vec, + pos: usize, + yields_left: usize, +} + +impl AsyncRead for YieldingReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + if this.yields_left > 0 { + this.yields_left -= 1; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + if this.pos >= this.data.len() { + return Poll::Ready(Ok(())); + } + buf.put_slice(&this.data[this.pos..this.pos + 1]); + this.pos += 1; + this.yields_left = 2; + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn fuzz_read_with_progress_heavy_yielding() { + let expected_data = b"HEAVY_YIELD_TEST_DATA".to_vec(); + let mut reader = YieldingReader { + data: expected_data.clone(), + pos: 0, + yields_left: 2, + }; + + let mut buf = vec![0u8; expected_data.len()]; + let read_bytes = read_with_progress(&mut reader, &mut buf).await.unwrap(); + + assert_eq!(read_bytes, expected_data.len()); + assert_eq!(buf, expected_data); +} + +#[test] +fn edge_wrap_tls_application_record_exactly_u16_max() { + let payload = vec![0u8; 65535]; + let wrapped = wrap_tls_application_record(&payload); + assert_eq!(wrapped.len(), 65540); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); +} + +#[test] +fn fuzz_wrap_tls_application_record_lengths() { + let lengths = [0, 1, 65534, 65535, 65536, 131070, 131071, 131072]; + for len in lengths { + let payload = vec![0u8; len]; + let wrapped = wrap_tls_application_record(&payload); + let expected_chunks = len.div_ceil(65535).max(1); + assert_eq!(wrapped.len(), len + 5 * expected_chunks); + } +} + +#[tokio::test] +async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() { + let user = "stress-same-ip-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 5); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 10).await; + + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)), 55000); + + let mut tasks = tokio::task::JoinSet::new(); + let mut reservations = Vec::new(); + + for _ in 0..10 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut successes = 0; + let mut failures = 0; + + while let Some(res) = tasks.join_next().await { + match res.unwrap() { + Ok(r) => { + successes += 1; + reservations.push(r); + } + Err(_) => failures += 1, + } + } + + assert_eq!(successes, 5); + assert_eq!(failures, 5); + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for reservation in reservations { + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs new file mode 100644 index 0000000..97c55c6 --- /dev/null +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -0,0 +1,196 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncWriteExt, duplex}; + +#[test] +fn invariant_wrap_tls_application_record_exact_multiples() { + let chunk_size = u16::MAX as usize; + let payload = vec![0xAA; chunk_size * 2]; + + let wrapped = wrap_tls_application_record(&payload); + + assert_eq!(wrapped.len(), 2 * (5 + chunk_size)); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); + + let second_header_idx = 5 + chunk_size; + assert_eq!(wrapped[second_header_idx], TLS_RECORD_APPLICATION); + assert_eq!( + &wrapped[second_header_idx + 3..second_header_idx + 5], + &65535u16.to_be_bytes() + ); +} + +#[tokio::test] +async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.20:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let claimed_len = MIN_TLS_CLIENT_HELLO_SIZE as u16; + let mut header = vec![0x16, 0x03, 0x01]; + header.extend_from_slice(&claimed_len.to_be_bytes()); + + client_side.write_all(&header).await.unwrap(); + client_side + .write_all(&vec![0x42; MIN_TLS_CLIENT_HELLO_SIZE - 1]) + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn invariant_acquire_reservation_ip_limit_rollback() { + let user = "rollback-test-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 10); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer_a = "198.51.100.21:55000".parse().unwrap(); + let _res_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_a, + ip_tracker.clone(), + ) + .await + .unwrap(); + + assert_eq!(stats.get_user_curr_connects(user), 1); + + let peer_b = "203.0.113.22:55000".parse().unwrap(); + let res_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_b, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + res_b, + Err(ProxyError::ConnectionLimitExceeded { .. }) + )); + assert_eq!(stats.get_user_curr_connects(user), 1); +} + +#[tokio::test] +async fn invariant_quota_exact_boundary_inclusive() { + let user = "quota-strict-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1000); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.23:55000".parse().unwrap(); + + stats.add_user_octets_from(user, 999); + let res1 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(res1.is_ok()); + res1.unwrap().release().await; + + stats.add_user_octets_from(user, 1); + let res2 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!(res2, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.25:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0xEF, 0xEF, 0xEF]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + + assert!(result.is_err()); + assert_eq!(stats.get_connects_bad(), 0); + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn invariant_route_mode_snapshot_picks_up_latest_mode() { + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + assert!(matches!( + route_runtime.snapshot().mode, + RelayRouteMode::Direct + )); + + route_runtime.set_mode(RelayRouteMode::Middle); + assert!(matches!( + route_runtime.snapshot().mode, + RelayRouteMode::Middle + )); +} diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs new file mode 100644 index 0000000..021848a --- /dev/null +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -0,0 +1,257 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + +#[tokio::test] +async fn edge_mask_delay_bypassed_if_max_is_zero() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 10_000; + config.censorship.server_hello_delay_max_ms = 0; + + let start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + assert!(start.elapsed() < Duration::from_millis(50)); +} + +#[test] +fn edge_beobachten_ttl_clamps_exactly_to_24_hours() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 100_000; + + let ttl = beobachten_ttl(&config); + assert_eq!(ttl.as_secs(), 24 * 60 * 60); +} + +#[test] +fn edge_wrap_tls_application_record_empty_payload() { + let wrapped = wrap_tls_application_record(&[]); + assert_eq!(wrapped.len(), 5); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &[0, 0]); +} + +#[tokio::test] +async fn boundary_user_data_quota_exact_match_rejects() { + let user = "quota-boundary-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1024); + + let stats = Arc::new(Stats::new()); + stats.add_user_octets_from(user, 1024); + + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.10:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn boundary_user_expiration_in_past_rejects() { + let user = "expired-boundary-user"; + let mut config = ProxyConfig::default(); + let expired_time = chrono::Utc::now() - chrono::Duration::milliseconds(1); + config + .access + .user_expirations + .insert(user.to_string(), expired_time); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.11:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::UserExpired { .. }))); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 300; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.12:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(&vec![b'A'; 2000]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.13:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn security_classic_mode_disabled_masks_valid_length_payload() { + let mut cfg = ProxyConfig::default(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + cfg.censorship.mask = true; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.15:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&vec![0xEF; 64]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() { + let user = "rapid-churn-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 10); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer = "198.51.100.16:55000".parse().unwrap(); + + for _ in 0..500 { + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn quirk_read_with_progress_zero_length_buffer_returns_zero_immediately() { + let (mut server_side, _client_side) = duplex(4096); + let mut empty_buf = &mut [][..]; + + let result = tokio::time::timeout( + Duration::from_millis(50), + read_with_progress(&mut server_side, &mut empty_buf), + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().unwrap(), 0); +} + +#[tokio::test] +async fn stress_read_with_progress_cancellation_safety() { + let (mut server_side, mut client_side) = duplex(4096); + + client_side.write_all(b"12345").await.unwrap(); + + let mut buf = [0u8; 10]; + let result = tokio::time::timeout( + Duration::from_millis(50), + read_with_progress(&mut server_side, &mut buf), + ) + .await; + + assert!(result.is_err()); + + client_side.write_all(b"67890").await.unwrap(); + let mut buf2 = [0u8; 5]; + server_side.read_exact(&mut buf2).await.unwrap(); + assert_eq!(&buf2, b"67890"); +} diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 6338e23..2b1fae6 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -7,6 +7,9 @@ use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use rand::rngs::StdRng; +use rand::Rng; +use rand::SeedableRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; @@ -25,6 +28,202 @@ fn synthetic_local_addr_uses_configured_port_for_max() { assert_eq!(addr.port(), u16::MAX); } +#[test] +fn handshake_timeout_with_mask_grace_includes_mask_margin() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 2; + + config.censorship.mask = false; + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(2)); + + config.censorship.mask = true; + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_millis(2750), + "mask mode extends handshake timeout by 750 ms" + ); +} + +#[tokio::test] +async fn read_with_progress_reads_partial_buffers_before_eof() { + let data = vec![0xAA, 0xBB, 0xCC]; + let mut reader = std::io::Cursor::new(data); + let mut buf = [0u8; 5]; + + let read = read_with_progress(&mut reader, &mut buf).await.unwrap(); + assert_eq!(read, 3); + assert_eq!(&buf[..3], &[0xAA, 0xBB, 0xCC]); +} + +#[test] +fn is_trusted_proxy_source_respects_cidr_list_and_empty_rejects_all() { + let peer: IpAddr = "10.10.10.10".parse().unwrap(); + assert!(!is_trusted_proxy_source(peer, &[])); + + let trusted = vec!["10.0.0.0/8".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer, &trusted)); + + let not_trusted = vec!["192.0.2.0/24".parse().unwrap()]; + assert!(!is_trusted_proxy_source(peer, ¬_trusted)); +} + +#[test] +fn is_trusted_proxy_source_accepts_cidr_zero_zero_as_global_cidr() { + let peer: IpAddr = "203.0.113.42".parse().unwrap(); + let trust_all = vec!["0.0.0.0/0".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer, &trust_all)); + + let peer_v6: IpAddr = "2001:db8::1".parse().unwrap(); + let trust_all_v6 = vec!["::/0".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer_v6, &trust_all_v6)); +} + +struct ErrorReader; + +impl tokio::io::AsyncRead for ErrorReader { + fn poll_read( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "fake error"))) + } +} + +#[tokio::test] +async fn read_with_progress_returns_error_from_failed_reader() { + let mut reader = ErrorReader; + let mut buf = [0u8; 8]; + let err = read_with_progress(&mut reader, &mut buf).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); +} + +#[test] +fn handshake_timeout_with_mask_grace_handles_maximum_values_without_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert!(timeout >= Duration::from_secs(u64::MAX)); +} + +#[tokio::test] +async fn read_with_progress_zero_length_buffer_returns_zero() { + let data = vec![1, 2, 3]; + let mut reader = std::io::Cursor::new(data); + let mut buf = []; + + let read = read_with_progress(&mut reader, &mut buf).await.unwrap(); + assert_eq!(read, 0); +} + +#[test] +fn handshake_timeout_without_mask_is_exact_base() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 7; + config.censorship.mask = false; + + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(7)); +} + +#[test] +fn handshake_timeout_mask_enabled_adds_750ms() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 3; + config.censorship.mask = true; + + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_millis(3750)); +} + +#[tokio::test] +async fn read_with_progress_full_then_empty_transition() { + let data = vec![0x10, 0x20]; + let mut cursor = std::io::Cursor::new(data); + let mut buf = [0u8; 2]; + + assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 2); + assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn read_with_progress_fragmented_io_works_over_multiple_calls() { + let mut cursor = std::io::Cursor::new(vec![1, 2, 3, 4, 5]); + let mut result = Vec::new(); + + for chunk_size in 1..=5 { + let mut b = vec![0u8; chunk_size]; + let n = read_with_progress(&mut cursor, &mut b).await.unwrap(); + result.extend_from_slice(&b[..n]); + if n == 0 { break; } + } + + assert_eq!(result, vec![1,2,3,4,5]); +} + +#[tokio::test] +async fn read_with_progress_stress_randomized_chunk_sizes() { + for i in 0..128 { + let mut rng = StdRng::seed_from_u64(i as u64 + 1); + let mut input: Vec = (0..(i % 41)).map(|_| rng.next_u32() as u8).collect(); + let mut cursor = std::io::Cursor::new(input.clone()); + let mut collected = Vec::new(); + + while cursor.position() < cursor.get_ref().len() as u64 { + let chunk = 1 + (rng.next_u32() as usize % 8); + let mut b = vec![0u8; chunk]; + let read = read_with_progress(&mut cursor, &mut b).await.unwrap(); + collected.extend_from_slice(&b[..read]); + if read == 0 { break; } + } + + assert_eq!(collected, input); + } +} + +#[test] +fn is_trusted_proxy_source_boundary_narrow_ipv4() { + let matching = "172.16.0.1".parse().unwrap(); + let not_matching = "172.15.255.255".parse().unwrap(); + let cidr = vec!["172.16.0.0/12".parse().unwrap()]; + assert!(is_trusted_proxy_source(matching, &cidr)); + assert!(!is_trusted_proxy_source(not_matching, &cidr)); +} + +#[test] +fn is_trusted_proxy_source_rejects_out_of_family_ipv6_v4_cidr() { + let peer = "2001:db8::1".parse().unwrap(); + let cidr = vec!["10.0.0.0/8".parse().unwrap()]; + assert!(!is_trusted_proxy_source(peer, &cidr)); +} + +#[test] +fn wrap_tls_application_record_reserved_chunks_look_reasonable() { + let payload = vec![0xAA; 1 + (u16::MAX as usize) + 2]; + let wrapped = wrap_tls_application_record(&payload); + assert!(wrapped.len() > payload.len()); + assert!(wrapped.contains(&0x17)); +} + +#[test] +fn wrap_tls_application_record_roundtrip_size_check() { + let payload_len = 3000; + let payload = vec![0x55; payload_len]; + let wrapped = wrap_tls_application_record(&payload); + + let mut idx = 0; + let mut consumed = 0; + while idx + 5 <= wrapped.len() { + assert_eq!(wrapped[idx], 0x17); + let len = u16::from_be_bytes([wrapped[idx+3], wrapped[idx+4]]) as usize; + consumed += len; + idx += 5 + len; + if idx >= wrapped.len() { break; } + } + + assert_eq!(consumed, payload_len); +} + fn make_crypto_reader(reader: R) -> CryptoReader where R: tokio::io::AsyncRead + Unpin, diff --git a/src/proxy/tests/handshake_advanced_clever_tests.rs b/src/proxy/tests/handshake_advanced_clever_tests.rs new file mode 100644 index 0000000..9b12f21 --- /dev/null +++ b/src/proxy/tests/handshake_advanced_clever_tests.rs @@ -0,0 +1,647 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac, AesCtr}; +use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +// --- Helpers --- + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +// --- Category 1: Edge Cases & Protocol Boundaries --- + +#[tokio::test] +async fn tls_minimum_viable_length_boundary() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x11u8; 16]; + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + let mut exact_min_handshake = vec![0x42u8; min_len]; + exact_min_handshake[min_len - 1] = 0; + exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let digest = sha256_hmac(&secret, &exact_min_handshake); + exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + let res = handle_tls_handshake( + &exact_min_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::Success(_)), "Exact minimum length TLS handshake must succeed"); + + let short_handshake = vec![0x42u8; min_len - 1]; + let res_short = handle_tls_handshake( + &short_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_short, HandshakeResult::BadClient { .. }), "Handshake 1 byte shorter than minimum must fail closed"); +} + +#[tokio::test] +async fn mtproto_extreme_dc_index_serialization() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "22222222222222222222222222222222"; + let config = test_config_with_secret_hex(secret_hex); + for (idx, extreme_dc) in [i16::MIN, i16::MAX, -1, 0].into_iter().enumerate() { + // Keep replay state independent per case so we validate dc_idx encoding, + // not duplicate-handshake rejection behavior. + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2)), 12345 + idx as u16); + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, extreme_dc); + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!(success.dc_idx, extreme_dc, "Extreme DC index {} must serialize/deserialize perfectly", extreme_dc); + } + _ => panic!("MTProto handshake with extreme DC index {} failed", extreme_dc), + } + } +} + +#[tokio::test] +async fn alpn_strict_case_and_padding_rejection() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x33u8; 16]; + let mut config = test_config_with_secret_hex("33333333333333333333333333333333"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); + + let bad_alpns: &[&[u8]] = &[b"H2", b"h2\0", b" http/1.1", b"http/1.1\n"]; + + for bad_alpn in bad_alpns { + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[*bad_alpn]); + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::BadClient { .. }), "ALPN strict enforcement must reject {:?}", bad_alpn); + } +} + +#[test] +fn ipv4_mapped_ipv6_bucketing_anomaly() { + let ipv4_mapped_1 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + let ipv4_mapped_2 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc633, 0x6402)); + + let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1); + let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2); + + assert_eq!(norm_1, norm_2, "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)"); + assert_eq!(norm_1, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), "The bucket must be exactly ::0"); +} + +// --- Category 2: Adversarial & Black Hat --- + +#[tokio::test] +async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "55555555555555555555555555555555"; + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.5:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let mut invalid_handshake = valid_handshake; + invalid_handshake[SKIP_LEN + PREKEY_LEN + IV_LEN + 1] ^= 0xFF; + + let res_invalid = handle_mtproto_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. })); + + let res_valid = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid MTProto ciphertext must not poison the replay cache"); +} + +#[tokio::test] +async fn tls_invalid_session_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x66u8; 16]; + let config = test_config_with_secret_hex("66666666666666666666666666666666"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.6:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + let mut invalid_handshake = valid_handshake.clone(); + let session_idx = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + invalid_handshake[session_idx] ^= 0xFF; + + let res_invalid = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. })); + + let res_valid = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid TLS payload must not poison the replay cache"); +} + +#[tokio::test] +async fn server_hello_delay_timing_neutrality_on_hmac_failure() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x77u8; 16]; + let mut config = test_config_with_secret_hex("77777777777777777777777777777777"); + config.censorship.server_hello_delay_min_ms = 50; + config.censorship.server_hello_delay_max_ms = 50; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.7:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let start = Instant::now(); + let res = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); + assert!(elapsed >= Duration::from_millis(45), "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels"); +} + +#[tokio::test] +async fn server_hello_delay_inversion_resilience() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x88u8; 16]; + let mut config = test_config_with_secret_hex("88888888888888888888888888888888"); + config.censorship.server_hello_delay_min_ms = 100; + config.censorship.server_hello_delay_max_ms = 10; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.8:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let start = Instant::now(); + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::Success(_))); + assert!(elapsed >= Duration::from_millis(90), "Delay logic must gracefully handle min > max inversions via max.max(min)"); +} + +#[tokio::test] +async fn mixed_valid_and_invalid_user_secrets_configuration() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + let _warn_guard = warned_secrets_test_lock().lock().unwrap(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.ignore_time_skew = true; + + for i in 0..9 { + let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" }; + config.access.users.insert(format!("bad_user_{}", i), bad_secret.to_string()); + } + let valid_secret_hex = "99999999999999999999999999999999"; + config.access.users.insert("good_user".to_string(), valid_secret_hex.to_string()); + config.general.modes.secure = true; + config.general.modes.classic = true; + config.general.modes.tls = true; + + let secret = [0x99u8; 16]; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.9:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "Proxy must gracefully skip invalid secrets and authenticate the valid one"); +} + +#[tokio::test] +async fn tls_emulation_fallback_when_cache_missing() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0xAAu8; 16]; + let mut config = test_config_with_secret_hex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + config.censorship.tls_emulation = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "TLS emulation must gracefully fall back to standard ServerHello if cache is missing"); +} + +#[tokio::test] +async fn classic_mode_over_tls_transport_protocol_confusion() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"; + let mut config = test_config_with_secret_hex(secret_hex); + config.general.modes.classic = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.11:12345".parse().unwrap(); + + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Intermediate, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + true, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior"); +} + +#[test] +fn generate_tg_nonce_never_emits_reserved_bytes() { + let client_enc_key = [0xCCu8; 32]; + let client_enc_iv = 123456789u128; + let rng = SecureRandom::new(); + + for _ in 0..10_000 { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 1, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert!(!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), "Nonce must never start with reserved bytes"); + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; + assert!(!RESERVED_NONCE_BEGINNINGS.contains(&first_four), "Nonce must never match reserved 4-byte beginnings"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn dashmap_concurrent_saturation_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip_a: IpAddr = "192.0.2.13".parse().unwrap(); + let ip_b: IpAddr = "198.51.100.13".parse().unwrap(); + let mut tasks = Vec::new(); + + for i in 0..100 { + let target_ip = if i % 2 == 0 { ip_a } else { ip_b }; + tasks.push(tokio::spawn(async move { + for _ in 0..50 { + auth_probe_record_failure(target_ip, Instant::now()); + } + })); + } + + for task in tasks { + task.await.expect("Task panicked during concurrent DashMap stress"); + } + + assert!(auth_probe_is_throttled_for_testing(ip_a), "IP A must be throttled after concurrent stress"); + assert!(auth_probe_is_throttled_for_testing(ip_b), "IP B must be throttled after concurrent stress"); +} + +#[test] +fn prototag_invalid_bytes_fail_closed() { + let invalid_tags: [[u8; 4]; 5] = [ + [0, 0, 0, 0], + [0xFF, 0xFF, 0xFF, 0xFF], + [0xDE, 0xAD, 0xBE, 0xEF], + [0xDD, 0xDD, 0xDD, 0xDE], + [0x11, 0x22, 0x33, 0x44], + ]; + + for tag in invalid_tags { + assert_eq!(ProtoTag::from_bytes(tag), None, "Invalid ProtoTag bytes {:?} must fail closed", tag); + } +} + +#[test] +fn auth_probe_eviction_hash_collision_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let now = Instant::now(); + + for i in 0..10_000u32 { + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, (i >> 8) as u8, (i & 0xFF) as u8)); + auth_probe_record_failure_with_state(state, ip, now); + } + + assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, "Eviction logic must successfully bound the map size under heavy insertion stress"); +} + +#[test] +fn encrypt_tg_nonce_with_ciphers_advances_counter_correctly() { + let client_enc_key = [0xDDu8; 32]; + let client_enc_iv = 987654321u128; + let rng = SecureRandom::new(); + + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let (_, mut returned_encryptor, _) = encrypt_tg_nonce_with_ciphers(&nonce); + let zeros = [0u8; 64]; + let returned_keystream = returned_encryptor.encrypt(&zeros); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let mut expected_enc_key = [0u8; 32]; + expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_enc_iv_arr = [0u8; IV_LEN]; + expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr); + + let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv); + + let mut manual_input = Vec::new(); + manual_input.extend_from_slice(&nonce); + manual_input.extend_from_slice(&zeros); + let manual_output = manual_encryptor.encrypt(&manual_input); + + assert_eq!( + returned_keystream, + &manual_output[64..128], + "encrypt_tg_nonce_with_ciphers must correctly advance the AES-CTR counter by exactly the nonce length" + ); +} diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs new file mode 100644 index 0000000..77df442 --- /dev/null +++ b/src/proxy/tests/handshake_more_clever_tests.rs @@ -0,0 +1,614 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac, AesCtr}; +use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Barrier; + +// --- Helpers --- + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +// --- Category 1: Timing & Delay Invariants --- + +#[tokio::test] +async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x1Au8; 16]; + let mut config = test_config_with_secret_hex("1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a"); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 0; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.101:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let fut = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ); + + // Deterministic assertion: with max_ms == 0 there must be no sleep path, + // so the handshake should complete promptly under a generous timeout budget. + let res = tokio::time::timeout(Duration::from_millis(250), fut) + .await + .expect("max_ms=0 should bypass artificial delay and complete quickly"); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); +} + +#[test] +fn auth_probe_backoff_extreme_fail_streak_clamps_safely() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 99)); + let now = Instant::now(); + + state.insert( + peer_ip, + AuthProbeState { + fail_streak: u32::MAX - 1, + blocked_until: now, + last_seen: now, + }, + ); + + auth_probe_record_failure_with_state(&state, peer_ip, now); + + let updated = state.get(&peer_ip).unwrap(); + assert_eq!(updated.fail_streak, u32::MAX); + + let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS); + assert_eq!(updated.blocked_until, expected_blocked_until, "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS"); +} + +#[test] +fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() { + let client_enc_key = [0x2Bu8; 32]; + let client_enc_iv = 1337u128; + let rng = SecureRandom::new(); + + let mut nonces = HashSet::new(); + let mut total_set_bits = 0usize; + let iterations = 5_000; + + for _ in 0..iterations { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + for byte in nonce.iter() { + total_set_bits += byte.count_ones() as usize; + } + + assert!(nonces.insert(nonce), "generate_tg_nonce emitted a duplicate nonce! RNG is stuck."); + } + + let total_bits = iterations * HANDSHAKE_LEN * 8; + let ratio = (total_set_bits as f64) / (total_bits as f64); + assert!(ratio > 0.48 && ratio < 0.52, "Nonce entropy is degraded. Set bit ratio: {}", ratio); +} + +#[tokio::test] +async fn mtproto_multi_user_decryption_isolation() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + config.access.users.insert("user_a".to_string(), "11111111111111111111111111111111".to_string()); + config.access.users.insert("user_b".to_string(), "22222222222222222222222222222222".to_string()); + let good_secret_hex = "33333333333333333333333333333333"; + config.access.users.insert("user_c".to_string(), good_secret_hex.to_string()); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(good_secret_hex, ProtoTag::Secure, 1); + + let res = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!(success.user, "user_c", "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user"); + } + _ => panic!("Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."), + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn invalid_secret_warning_lock_contention_and_bound() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + let tasks = 50; + let iterations_per_task = 100; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for t in 0..tasks { + let b = barrier.clone(); + handles.push(tokio::spawn(async move { + b.wait().await; + for i in 0..iterations_per_task { + let user_name = format!("contention_user_{}_{}", t, i); + warn_invalid_secret_once(&user_name, "invalid_hex", ACCESS_SECRET_BYTES, None); + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let warned = INVALID_SECRET_WARNED.get().unwrap(); + let guard = warned.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + + assert_eq!( + guard.len(), + WARNED_SECRET_MAX_ENTRIES, + "Concurrent spam of invalid secrets must strictly bound the HashSet memory to WARNED_SECRET_MAX_ENTRIES" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn mtproto_strict_concurrent_replay_race_condition() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A"; + let config = Arc::new(test_config_with_secret_hex(secret_hex)); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let valid_handshake = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1)); + + let tasks = 100; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for i in 0..tasks { + let b = barrier.clone(); + let cfg = config.clone(); + let rc = replay_checker.clone(); + let hs = valid_handshake.clone(); + + handles.push(tokio::spawn(async move { + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), 10000 + i as u16); + b.wait().await; + handle_mtproto_handshake( + &hs, + tokio::io::empty(), + tokio::io::sink(), + peer, + &cfg, + &rc, + false, + None, + ) + .await + })); + } + + let mut successes = 0; + let mut failures = 0; + + for handle in handles { + match handle.await.unwrap() { + HandshakeResult::Success(_) => successes += 1, + HandshakeResult::BadClient { .. } => failures += 1, + _ => panic!("Unexpected error result in concurrent MTProto replay test"), + } + } + + assert_eq!(successes, 1, "Replay cache race condition allowed multiple identical MTProto handshakes to succeed"); + assert_eq!(failures, tasks - 1, "Replay cache failed to forcefully reject concurrent duplicates"); +} + +#[tokio::test] +async fn tls_alpn_zero_length_protocol_handled_safely() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x5Bu8; 16]; + let mut config = test_config_with_secret_hex("5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap(); + + let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::BadClient { .. }), "0-length ALPN must be safely rejected without panicking"); +} + +#[tokio::test] +async fn tls_sni_massive_hostname_does_not_panic() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x6Cu8; 16]; + let config = test_config_with_secret_hex("6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap(); + + let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }), "Massive SNI hostname must be processed or ignored without stack overflow or panic"); +} + +#[tokio::test] +async fn tls_progressive_truncation_fuzzing_no_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x7Du8; 16]; + let config = test_config_with_secret_hex("7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]); + let full_len = valid_handshake.len(); + + // Truncated corpus only: full_len is a valid baseline and should not be + // asserted as BadClient in a truncation-specific test. + for i in (0..full_len).rev() { + let truncated = &valid_handshake[..i]; + let res = handle_tls_handshake( + truncated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::BadClient { .. }), "Truncated TLS handshake at len {} must fail safely without panicking", i); + } +} + +#[tokio::test] +async fn mtproto_pure_entropy_fuzzing_no_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.110:12345".parse().unwrap(); + + let mut seeded = StdRng::seed_from_u64(0xDEADBEEFCAFE); + + for _ in 0..10_000 { + let mut noise = [0u8; HANDSHAKE_LEN]; + seeded.fill_bytes(&mut noise); + + let res = handle_mtproto_handshake( + &noise, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::BadClient { .. }), "Pure entropy MTProto payload must fail closed and never panic"); + } +} + +#[test] +fn decode_user_secret_odd_length_hex_rejection() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.users.insert("odd_user".to_string(), "1234567890123456789012345678901".to_string()); + + let decoded = decode_user_secrets(&config, None); + assert!(decoded.is_empty(), "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping"); +} + +#[test] +fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 112)); + let now = Instant::now(); + + let extreme_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + 5; + state.insert( + peer_ip, + AuthProbeState { + fail_streak: extreme_streak, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + + { + let mut guard = auth_probe_saturation_state_lock(); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now); + assert!(is_throttled, "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period"); +} + +#[test] +fn auth_probe_saturation_note_resets_retention_window() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let base_time = Instant::now(); + + auth_probe_note_saturation(base_time); + let later = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS - 1); + auth_probe_note_saturation(later); + + let check_time = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 5); + + // This call may return false if backoff has elapsed, but it must not clear + // the saturation state because `later` refreshed last_seen. + let _ = auth_probe_saturation_is_throttled_at_for_testing(check_time); + let guard = auth_probe_saturation_state_lock(); + assert!( + guard.is_some(), + "Ongoing saturation notes must refresh last_seen so saturation state remains retained past the original window" + ); +} + +#[test] +fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = true; + config.general.modes.tls = false; + + assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false)); + assert!(!mode_enabled_for_proto(&config, ProtoTag::Intermediate, false)); +} + +#[test] +fn mtproto_secure_tag_rejected_when_only_classic_mode_enabled() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = true; + config.general.modes.secure = false; + config.general.modes.tls = false; + + assert!(!mode_enabled_for_proto(&config, ProtoTag::Secure, false)); +} + +#[test] +fn ipv6_localhost_and_unspecified_normalization() { + let localhost = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); + let unspecified = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + + let norm_local = normalize_auth_probe_ip(localhost); + let norm_unspec = normalize_auth_probe_ip(unspecified); + + let expected_bucket = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + + assert_eq!(norm_local, expected_bucket); + assert_eq!(norm_unspec, expected_bucket); +} diff --git a/src/proxy/tests/handshake_real_bug_stress_tests.rs b/src/proxy/tests/handshake_real_bug_stress_tests.rs new file mode 100644 index 0000000..d7234ff --- /dev/null +++ b/src/proxy/tests/handshake_real_bug_stress_tests.rs @@ -0,0 +1,337 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom}; +use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Barrier; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + record +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +#[tokio::test] +async fn tls_alpn_reject_does_not_pollute_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x11u8; 16]; + let mut config = test_config_with_secret_hex("11111111111111111111111111111111"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.201:12345".parse().unwrap(); + + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let before = replay_checker.stats(); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + let after = replay_checker.stats(); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); + assert_eq!( + before.total_additions, after.total_additions, + "ALPN policy reject must not add TLS digest into replay cache" + ); +} + +#[tokio::test] +async fn tls_truncated_session_id_len_fails_closed_without_panic() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("33333333333333333333333333333333"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.203:12345".parse().unwrap(); + + let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + let mut malicious = vec![0x42u8; min_len]; + malicious[min_len - 1] = u8::MAX; + + let res = handle_tls_handshake( + &malicious, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::BadClient { .. })); +} + +#[test] +fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let same = Instant::now(); + + for i in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new(10, 1, (i >> 8) as u8, (i & 0xFF) as u8)); + state.insert( + ip, + AuthProbeState { + fail_streak: 7, + blocked_until: same, + last_seen: same, + }, + ); + } + + let new_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 21, 21)); + auth_probe_record_failure_with_state(state, new_ip, same + Duration::from_millis(1)); + + assert_eq!(state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES); + assert!(state.contains_key(&new_ip)); +} + +#[test] +fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let saturation = auth_probe_saturation_state(); + let poison_thread = std::thread::spawn(move || { + let _hold = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + panic!("intentional poison for regression coverage"); + }); + let _ = poison_thread.join(); + + clear_auth_probe_state_for_testing(); + + let guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + assert!(guard.is_none()); +} + +#[tokio::test] +async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() { + let _probe_guard = auth_probe_test_guard(); + let _warn_guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + config.access.users.insert( + "short_user".to_string(), + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(), + ); + + let valid_secret_hex = "77777777777777777777777777777777"; + config + .access + .users + .insert("good_user".to_string(), valid_secret_hex.to_string()); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.207:12345".parse().unwrap(); + let handshake = make_valid_mtproto_handshake(valid_secret_hex, ProtoTag::Secure, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_))); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 80)); + let now = Instant::now(); + + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let state = auth_probe_state_map(); + state.insert( + peer_ip, + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1, + blocked_until: now, + last_seen: now, + }, + ); + + let tasks = 32; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for _ in 0..tasks { + let b = barrier.clone(); + handles.push(tokio::spawn(async move { + b.wait().await; + auth_probe_record_failure(peer_ip, Instant::now()); + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let final_state = state.get(&peer_ip).expect("state must exist"); + assert!( + final_state.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + ); + assert!(auth_probe_should_apply_preauth_throttle(peer_ip, Instant::now())); +} diff --git a/src/proxy/tests/handshake_timing_manual_bench_tests.rs b/src/proxy/tests/handshake_timing_manual_bench_tests.rs new file mode 100644 index 0000000..95e9f49 --- /dev/null +++ b/src/proxy/tests/handshake_timing_manual_bench_tests.rs @@ -0,0 +1,318 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom}; +use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, + salt: u8, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1).wrapping_add(salt); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +fn median_ns(samples: &mut [u128]) -> u128 { + samples.sort_unstable(); + samples[samples.len() / 2] +} + +#[tokio::test] +#[ignore = "manual benchmark: timing-sensitive and host-dependent"] +async fn mtproto_user_scan_timing_manual_benchmark() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + const DECOY_USERS: usize = 8_000; + const ITERATIONS: usize = 250; + + let preferred_user = "target_user"; + let target_secret_hex = "dededededededededededededededede"; + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + for i in 0..DECOY_USERS { + config.access.users.insert( + format!("decoy_{i}"), + "00000000000000000000000000000000".to_string(), + ); + } + + config.access.users.insert( + preferred_user.to_string(), + target_secret_hex.to_string(), + ); + + let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60)); + let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60)); + let peer_a: SocketAddr = "192.0.2.241:12345".parse().unwrap(); + let peer_b: SocketAddr = "192.0.2.242:12345".parse().unwrap(); + + let mut preferred_samples = Vec::with_capacity(ITERATIONS); + let mut full_scan_samples = Vec::with_capacity(ITERATIONS); + + for i in 0..ITERATIONS { + let handshake = make_valid_mtproto_handshake( + target_secret_hex, + ProtoTag::Secure, + 1 + i as i16, + (i % 251) as u8, + ); + + let started_preferred = Instant::now(); + let preferred = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer_a, + &config, + &replay_checker_preferred, + false, + Some(preferred_user), + ) + .await; + preferred_samples.push(started_preferred.elapsed().as_nanos()); + assert!(matches!(preferred, HandshakeResult::Success(_))); + + let started_scan = Instant::now(); + let full_scan = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer_b, + &config, + &replay_checker_full_scan, + false, + None, + ) + .await; + full_scan_samples.push(started_scan.elapsed().as_nanos()); + assert!(matches!(full_scan, HandshakeResult::Success(_))); + } + + let preferred_median = median_ns(&mut preferred_samples); + let full_scan_median = median_ns(&mut full_scan_samples); + + let ratio = if preferred_median == 0 { + 0.0 + } else { + full_scan_median as f64 / preferred_median as f64 + }; + + println!( + "manual timing benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, preferred_median_ns={preferred_median}, full_scan_median_ns={full_scan_median}, ratio={ratio:.3}" + ); + + assert!( + full_scan_median >= preferred_median, + "full user scan should not be faster than preferred-user path in this benchmark" + ); +} + +#[tokio::test] +#[ignore = "manual benchmark: timing-sensitive and host-dependent"] +async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() { + let _guard = auth_probe_test_guard(); + + const DECOY_USERS: usize = 8_000; + const ITERATIONS: usize = 250; + + let preferred_user = "user-b"; + let target_secret_hex = "abababababababababababababababab"; + let target_secret = [0xABu8; 16]; + + let mut config = ProxyConfig::default(); + config.general.modes.tls = true; + config.access.ignore_time_skew = true; + + for i in 0..DECOY_USERS { + config.access.users.insert( + format!("decoy_{i}"), + "00000000000000000000000000000000".to_string(), + ); + } + + config + .access + .users + .insert(preferred_user.to_string(), target_secret_hex.to_string()); + + let mut sni_samples = Vec::with_capacity(ITERATIONS); + let mut no_sni_samples = Vec::with_capacity(ITERATIONS); + + for i in 0..ITERATIONS { + let with_sni = make_valid_tls_client_hello_with_sni_and_alpn( + &target_secret, + i as u32, + preferred_user, + &[b"h2"], + ); + let no_sni = make_valid_tls_handshake(&target_secret, (i as u32).wrapping_add(10_000)); + + let started_sni = Instant::now(); + let sni_secrets = decode_user_secrets(&config, Some(preferred_user)); + let sni_result = tls::validate_tls_handshake_with_replay_window( + &with_sni, + &sni_secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ); + sni_samples.push(started_sni.elapsed().as_nanos()); + assert!(sni_result.is_some()); + + let started_no_sni = Instant::now(); + let no_sni_secrets = decode_user_secrets(&config, None); + let no_sni_result = tls::validate_tls_handshake_with_replay_window( + &no_sni, + &no_sni_secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ); + no_sni_samples.push(started_no_sni.elapsed().as_nanos()); + assert!(no_sni_result.is_some()); + } + + let sni_median = median_ns(&mut sni_samples); + let no_sni_median = median_ns(&mut no_sni_samples); + + let ratio = if sni_median == 0 { + 0.0 + } else { + no_sni_median as f64 / sni_median as f64 + }; + + println!( + "manual tls benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, sni_median_ns={sni_median}, no_sni_median_ns={no_sni_median}, ratio_no_sni_over_sni={ratio:.3}" + ); +} From 5c9fea5850d24b6086341a2f60b243da6bf56606 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Sun, 22 Mar 2026 17:08:16 +0400 Subject: [PATCH 02/29] Update src/proxy/tests/client_security_tests.rs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/proxy/tests/client_security_tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 2b1fae6..35f517a 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use rand::rngs::StdRng; -use rand::Rng; +use rand::RngCore; use rand::SeedableRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; From 6fc188f0c4c8f9d676bdd5cfec86179aca88d044 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Sun, 22 Mar 2026 17:08:23 +0400 Subject: [PATCH 03/29] Update src/proxy/tests/handshake_more_clever_tests.rs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/proxy/tests/handshake_more_clever_tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs index 77df442..b3da4df 100644 --- a/src/proxy/tests/handshake_more_clever_tests.rs +++ b/src/proxy/tests/handshake_more_clever_tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::crypto::{sha256, sha256_hmac, AesCtr}; use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; -use rand::{Rng, SeedableRng}; +use rand::{RngExt, SeedableRng}; use rand::rngs::StdRng; use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; From e46d2cfc52d23f0078484135c0ffda51b4ca9137 Mon Sep 17 00:00:00 2001 From: Alexander <32452033+avbor@users.noreply.github.com> Date: Sun, 22 Mar 2026 21:59:20 +0300 Subject: [PATCH 04/29] Update VPS_DOUBLE_HOP.ru.md Fix typo --- docs/VPS_DOUBLE_HOP.ru.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/VPS_DOUBLE_HOP.ru.md b/docs/VPS_DOUBLE_HOP.ru.md index 625c64c..689a7c0 100644 --- a/docs/VPS_DOUBLE_HOP.ru.md +++ b/docs/VPS_DOUBLE_HOP.ru.md @@ -272,7 +272,7 @@ backend telemt_nodes ``` >[!WARNING] ->**Файл должен заканчиваться пустой строкой, иначе HAProxy не запуститься!** +>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запустится!** #### Разрешаем порт 443\tcp в фаерволе (если включен) ```bash From 91be148b72b2bdc61ac1da0f7e5ce370018ca096 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Sun, 22 Mar 2026 23:06:26 +0400 Subject: [PATCH 05/29] Security hardening, concurrency fixes, and expanded test coverage This commit introduces a comprehensive set of improvements to enhance the security, reliability, and configurability of the proxy server, specifically targeting adversarial resilience and high-load concurrency. Security & Cryptography: - Zeroize MTProto cryptographic key material (`dec_key`, `enc_key`) immediately after use to prevent memory leakage on early returns. - Move TLS handshake replay tracking after full policy/ALPN validation to prevent cache poisoning by unauthenticated probes. - Add `proxy_protocol_trusted_cidrs` configuration to restrict PROXY protocol headers to trusted networks, rejecting spoofed IPs. Adversarial Resilience & DoS Mitigation: - Implement "Tiny Frame Debt" tracking in the middle-relay to prevent CPU exhaustion from malicious 0-byte or 1-byte frame floods. - Add `mask_relay_max_bytes` to strictly bound unauthenticated fallback connections, preventing the proxy from being abused as an open relay. - Add a 5ms prefetch window (`mask_classifier_prefetch_timeout_ms`) to correctly assemble and classify fragmented HTTP/1.1 and HTTP/2 probes (e.g., `PRI * HTTP/2.0`) before routing them to masking heuristics. - Prevent recursive masking loops (FD exhaustion) by verifying the mask target is not the proxy's own listener via local interface enumeration. Concurrency & Reliability: - Eliminate executor waker storms during quota lock contention by replacing the spin-waker task with inline `Sleep` and exponential backoff. - Roll back user quota reservations (`rollback_me2c_quota_reservation`) if a network write fails, preventing Head-of-Line (HoL) blocking from permanently burning data quotas. - Recover gracefully from idle-registry `Mutex` poisoning instead of panicking, ensuring isolated thread failures do not break the proxy. - Fix `auth_probe_scan_start_offset` modulo logic to ensure bounds safety. Testing: - Add extensive adversarial, timing, fuzzing, and invariant test suites for both the client and handshake modules. --- Cargo.lock | 34 +- Cargo.toml | 5 +- docs/CONFIG_PARAMS.en.md | 2 + src/config/defaults.rs | 14 + src/config/hot_reload.rs | 3 + src/config/load.rs | 23 + ...ssifier_prefetch_timeout_security_tests.rs | 75 +++ .../tests/load_mask_shape_security_tests.rs | 54 ++ src/config/types.rs | 10 + src/proxy/client.rs | 102 ++++ src/proxy/handshake.rs | 40 +- src/proxy/masking.rs | 409 +++++++++++-- src/proxy/middle_relay.rs | 234 ++++++-- src/proxy/mod.rs | 1 + src/proxy/quota_lock_registry.rs | 53 ++ src/proxy/relay.rs | 219 +++++-- ...ng_fragmented_classifier_security_tests.rs | 100 ++++ ...http2_fragmented_preface_security_tests.rs | 129 ++++ ...fig_pipeline_integration_security_tests.rs | 150 +++++ ..._prefetch_config_runtime_security_tests.rs | 82 +++ ...sking_prefetch_invariant_security_tests.rs | 261 +++++++++ ...prefetch_strict_boundary_security_tests.rs | 70 +++ ...g_prefetch_timing_matrix_security_tests.rs | 95 +++ ...nt_masking_replay_timing_security_tests.rs | 161 +++++ ...e_auth_probe_scan_budget_security_tests.rs | 90 +++ ...ake_auth_probe_scan_offset_stress_tests.rs | 113 ++++ ...key_material_zeroization_security_tests.rs | 42 ++ ...nvelope_blur_integration_security_tests.rs | 44 +- ...ing_additional_hardening_security_tests.rs | 122 ++++ ...ssification_completeness_security_tests.rs | 16 + ...ect_failure_close_matrix_security_tests.rs | 127 ++++ ...ing_consume_idle_timeout_security_tests.rs | 85 +++ ...asking_consume_stress_adversarial_tests.rs | 64 ++ ...ttp2_preface_integration_security_tests.rs | 55 ++ ...tp2_probe_classification_security_tests.rs | 92 +++ ...king_http_probe_boundary_security_tests.rs | 79 +++ ...e_cache_defense_in_depth_security_tests.rs | 51 ++ .../masking_interface_cache_security_tests.rs | 46 ++ ...line_target_redteam_expected_fail_tests.rs | 178 ++++++ ...sking_padding_timeout_adversarial_tests.rs | 51 ++ ...masking_relay_guardrails_security_tests.rs | 105 ++++ ...masking_rng_hoist_perf_regression_tests.rs | 100 ++++ src/proxy/tests/masking_security_tests.rs | 2 + ...masking_self_target_loop_security_tests.rs | 354 +++++++++++ .../masking_shape_guard_adversarial_tests.rs | 1 + ...sking_shape_hardening_adversarial_tests.rs | 1 + ...lay_blackhat_campaign_integration_tests.rs | 1 + .../middle_relay_hol_quota_security_tests.rs | 229 ++++++++ ...middle_relay_idle_policy_security_tests.rs | 40 +- ...lay_idle_registry_poison_security_tests.rs | 59 ++ ...lay_quota_reservation_adversarial_tests.rs | 192 ++++++ ...y_frame_debt_concurrency_security_tests.rs | 365 ++++++++++++ ...rame_debt_proto_chunking_security_tests.rs | 418 +++++++++++++ ...le_relay_tiny_frame_debt_security_tests.rs | 550 ++++++++++++++++++ ..._relay_zero_length_frame_security_tests.rs | 121 ++++ ...k_registry_cross_mode_adversarial_tests.rs | 108 ++++ ...lay_cross_mode_quota_fairness_tdd_tests.rs | 225 +++++++ ...ay_cross_mode_quota_lock_security_tests.rs | 81 +++ ...elay_quota_lock_identity_security_tests.rs | 135 +++++ ..._retry_backoff_benchmark_security_tests.rs | 241 ++++++++ ...elay_quota_retry_backoff_security_tests.rs | 339 +++++++++++ .../relay_quota_retry_scheduler_tdd_tests.rs | 246 ++++++++ src/proxy/tests/relay_security_tests.rs | 8 +- src/stats/mod.rs | 30 + .../tests/user_octets_sub_security_tests.rs | 151 +++++ 65 files changed, 7473 insertions(+), 210 deletions(-) create mode 100644 src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs create mode 100644 src/proxy/quota_lock_registry.rs create mode 100644 src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs create mode 100644 src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs create mode 100644 src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs create mode 100644 src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs create mode 100644 src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs create mode 100644 src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs create mode 100644 src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs create mode 100644 src/proxy/tests/client_masking_replay_timing_security_tests.rs create mode 100644 src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs create mode 100644 src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs create mode 100644 src/proxy/tests/handshake_key_material_zeroization_security_tests.rs create mode 100644 src/proxy/tests/masking_additional_hardening_security_tests.rs create mode 100644 src/proxy/tests/masking_classification_completeness_security_tests.rs create mode 100644 src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs create mode 100644 src/proxy/tests/masking_consume_idle_timeout_security_tests.rs create mode 100644 src/proxy/tests/masking_consume_stress_adversarial_tests.rs create mode 100644 src/proxy/tests/masking_http2_preface_integration_security_tests.rs create mode 100644 src/proxy/tests/masking_http2_probe_classification_security_tests.rs create mode 100644 src/proxy/tests/masking_http_probe_boundary_security_tests.rs create mode 100644 src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs create mode 100644 src/proxy/tests/masking_interface_cache_security_tests.rs create mode 100644 src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs create mode 100644 src/proxy/tests/masking_padding_timeout_adversarial_tests.rs create mode 100644 src/proxy/tests/masking_relay_guardrails_security_tests.rs create mode 100644 src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs create mode 100644 src/proxy/tests/masking_self_target_loop_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_hol_quota_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs create mode 100644 src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs create mode 100644 src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs create mode 100644 src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs create mode 100644 src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_lock_identity_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_retry_backoff_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs create mode 100644 src/stats/tests/user_octets_sub_security_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 8159a22..c4cde39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1486,7 +1486,7 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.1", "log", "thiserror 1.0.69", "walkdir", @@ -1495,9 +1495,31 @@ dependencies = [ [[package]] name = "jni-sys" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" +dependencies = [ + "jni-sys 0.4.1", +] + +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn", +] [[package]] name = "jobserver" @@ -1659,9 +1681,9 @@ dependencies = [ [[package]] name = "moka" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" dependencies = [ "crossbeam-channel", "crossbeam-epoch", @@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.3.29" +version = "3.3.30" dependencies = [ "aes", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 53082db..1e06b7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,11 @@ [package] name = "telemt" -version = "3.3.29" +version = "3.3.30" edition = "2024" +[features] +redteam_offline_expected_fail = [] + [dependencies] # C libc = "0.2" diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index f8c56a0..73d36e1 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -269,6 +269,8 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a | mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | | mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | | mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. | +| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. | +| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. | | mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. | | mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. | | mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. | diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 66ffeda..09d146a 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -553,6 +553,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize { 512 } +#[cfg(not(test))] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 5 * 1024 * 1024 +} + +#[cfg(test)] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 32 * 1024 +} + +pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 { + 5 +} + pub(crate) fn default_mask_timing_normalization_enabled() -> bool { false } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index e580b7f..a3f795a 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -600,6 +600,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur || old.censorship.mask_shape_above_cap_blur_max_bytes != new.censorship.mask_shape_above_cap_blur_max_bytes + || old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes + || old.censorship.mask_classifier_prefetch_timeout_ms + != new.censorship.mask_classifier_prefetch_timeout_ms || old.censorship.mask_timing_normalization_enabled != new.censorship.mask_timing_normalization_enabled || old.censorship.mask_timing_normalization_floor_ms diff --git a/src/config/load.rs b/src/config/load.rs index bf6d036..fc54ec2 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -430,6 +430,25 @@ impl ProxyConfig { )); } + if config.censorship.mask_relay_max_bytes == 0 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be > 0".to_string(), + )); + } + + if config.censorship.mask_relay_max_bytes > 67_108_864 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be <= 67108864".to_string(), + )); + } + + if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) { + return Err(ProxyError::Config( + "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]" + .to_string(), + )); + } + if config.censorship.mask_timing_normalization_ceiling_ms < config.censorship.mask_timing_normalization_floor_ms { @@ -1134,6 +1153,10 @@ mod load_security_tests; #[path = "tests/load_mask_shape_security_tests.rs"] mod load_mask_shape_security_tests; +#[cfg(test)] +#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"] +mod load_mask_classifier_prefetch_timeout_security_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs new file mode 100644 index 0000000..49ee953 --- /dev/null +++ b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs @@ -0,0 +1,75 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir() + .join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 4 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout below minimum security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 51 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout above max security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 20 +"#, + ); + + let cfg = ProxyConfig::load(&path) + .expect("prefetch timeout within security bounds must be accepted"); + assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index 8986a49..2e4aa41 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8 remove_temp_config(&path); } + +#[test] +fn load_rejects_zero_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be > 0"), + "error must explain non-zero relay cap invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_relay_max_bytes_above_upper_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 67108865 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("mask_relay_max_bytes above hard cap must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"), + "error must explain relay cap upper bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_valid_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 8388608 +"#, + ); + + let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted"); + assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608); + + remove_temp_config(&path); +} diff --git a/src/config/types.rs b/src/config/types.rs index aa58dc1..5dc9719 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1450,6 +1450,14 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_mask_shape_above_cap_blur_max_bytes")] pub mask_shape_above_cap_blur_max_bytes: usize, + /// Maximum bytes relayed per direction on unauthenticated masking fallback paths. + #[serde(default = "default_mask_relay_max_bytes")] + pub mask_relay_max_bytes: usize, + + /// Prefetch timeout (ms) for extending fragmented masking classifier window. + #[serde(default = "default_mask_classifier_prefetch_timeout_ms")] + pub mask_classifier_prefetch_timeout_ms: u64, + /// Enable outcome-time normalization envelope for masking fallback. #[serde(default = "default_mask_timing_normalization_enabled")] pub mask_timing_normalization_enabled: bool, @@ -1488,6 +1496,8 @@ impl Default for AntiCensorshipConfig { mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(), mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(), mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(), + mask_relay_max_bytes: default_mask_relay_max_bytes(), + mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(), mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(), mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(), mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(), diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 1941f36..a804a2c 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -186,6 +186,67 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration { } } +const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16; +#[cfg(test)] +const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5); + +fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration { + Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms) +} + +fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool { + if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW { + return false; + } + + if initial_data.is_empty() { + // Empty initial_data means there is no client probe prefix to refine. + // Prefetching in this case can consume fallback relay payload bytes and + // accidentally route them through shaping heuristics. + return false; + } + + if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") { + return false; + } + + initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ') +} + +#[cfg(test)] +async fn extend_masking_initial_window(reader: &mut R, initial_data: &mut Vec) +where + R: AsyncRead + Unpin, +{ + extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT) + .await; +} + +async fn extend_masking_initial_window_with_timeout( + reader: &mut R, + initial_data: &mut Vec, + prefetch_timeout: Duration, +) +where + R: AsyncRead + Unpin, +{ + if !should_prefetch_mask_classifier_window(initial_data) { + return; + } + + let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len()); + if need == 0 { + return; + } + + let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW]; + if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await + && n > 0 + { + initial_data.extend_from_slice(&extra[..n]); + } +} + fn masking_outcome( reader: R, writer: W, @@ -200,6 +261,15 @@ where W: AsyncWrite + Unpin + Send + 'static, { HandshakeOutcome::NeedsMasking(Box::pin(async move { + let mut reader = reader; + let mut initial_data = initial_data; + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + mask_classifier_prefetch_timeout(&config), + ) + .await; + handle_bad_client( reader, writer, @@ -1321,6 +1391,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests; #[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"] mod masking_probe_evasion_blackhat_tests; +#[cfg(test)] +#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"] +mod masking_fragmented_classifier_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_replay_timing_security_tests.rs"] +mod masking_replay_timing_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"] +mod masking_http2_fragmented_preface_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"] +mod masking_prefetch_invariant_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"] +mod masking_prefetch_timing_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"] +mod masking_prefetch_config_runtime_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"] +mod masking_prefetch_config_pipeline_integration_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"] +mod masking_prefetch_strict_boundary_security_tests; + #[cfg(test)] #[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] mod beobachten_ttl_bounds_security_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 7e4b62c..3444a88 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -121,6 +121,20 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { hasher.finish() as usize } +fn auth_probe_scan_start_offset( + peer_ip: IpAddr, + now: Instant, + state_len: usize, + scan_limit: usize, +) -> usize { + if state_len == 0 || scan_limit == 0 { + return 0; + } + + let window = state_len.min(scan_limit); + auth_probe_eviction_offset(peer_ip, now) % window +} + fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); @@ -269,11 +283,7 @@ fn auth_probe_record_failure_with_state( let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let state_len = state.len(); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); - let start_offset = if state_len == 0 { - 0 - } else { - auth_probe_eviction_offset(peer_ip, now) % state_len - }; + let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); let mut scanned = 0usize; for entry in state.iter().skip(start_offset) { @@ -769,7 +779,7 @@ where let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); - let dec_key = sha256(&dec_key_input); + let dec_key = Zeroizing::new(sha256(&dec_key_input)); let mut dec_iv_arr = [0u8; IV_LEN]; dec_iv_arr.copy_from_slice(dec_iv_bytes); @@ -805,7 +815,7 @@ where let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); - let enc_key = sha256(&enc_key_input); + let enc_key = Zeroizing::new(sha256(&enc_key_input)); let mut enc_iv_arr = [0u8; IV_LEN]; enc_iv_arr.copy_from_slice(enc_iv_bytes); @@ -830,9 +840,9 @@ where user: user.clone(), dc_idx, proto_tag, - dec_key, + dec_key: *dec_key, dec_iv, - enc_key, + enc_key: *enc_key, enc_iv, peer, is_tls, @@ -979,6 +989,14 @@ mod saturation_poison_security_tests; #[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] mod auth_probe_hardening_adversarial_tests; +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"] +mod auth_probe_scan_budget_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"] +mod auth_probe_scan_offset_stress_tests; + #[cfg(test)] #[path = "tests/handshake_advanced_clever_tests.rs"] mod advanced_clever_tests; @@ -995,6 +1013,10 @@ mod real_bug_stress_tests; #[path = "tests/handshake_timing_manual_bench_tests.rs"] mod timing_manual_bench_tests; +#[cfg(test)] +#[path = "tests/handshake_key_material_zeroization_security_tests.rs"] +mod handshake_key_material_zeroization_security_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 3639db1..7d970c2 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -4,10 +4,17 @@ use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; -use rand::{Rng, RngExt}; -use std::net::SocketAddr; +#[cfg(unix)] +use nix::ifaddrs::getifaddrs; +use rand::rngs::StdRng; +use rand::{Rng, RngExt, SeedableRng}; +use std::net::{IpAddr, SocketAddr}; use std::str; -use std::time::Duration; +#[cfg(unix)] +use std::sync::{Mutex, OnceLock}; +#[cfg(test)] +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::{Duration, Instant as StdInstant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; #[cfg(unix)] @@ -30,13 +37,23 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5); #[cfg(test)] const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +#[cfg(unix)] +#[cfg(not(test))] +const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300); +#[cfg(all(unix, test))] +const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1); struct CopyOutcome { total: usize, ended_by_eof: bool, } -async fn copy_with_idle_timeout(reader: &mut R, writer: &mut W) -> CopyOutcome +async fn copy_with_idle_timeout( + reader: &mut R, + writer: &mut W, + byte_cap: usize, + shutdown_on_eof: bool, +) -> CopyOutcome where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, @@ -44,14 +61,31 @@ where let mut buf = [0u8; MASK_BUFFER_SIZE]; let mut total = 0usize; let mut ended_by_eof = false; + + if byte_cap == 0 { + return CopyOutcome { + total, + ended_by_eof, + }; + } + loop { - let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await; let n = match read_res { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, }; if n == 0 { ended_by_eof = true; + if shutdown_on_eof { + let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await; + } break; } total = total.saturating_add(n); @@ -61,6 +95,10 @@ where Ok(Ok(())) => {} Ok(Err(_)) | Err(_) => break, } + + if total >= byte_cap { + break; + } } CopyOutcome { total, @@ -68,6 +106,39 @@ where } } +fn is_http_probe(data: &[u8]) -> bool { + // RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ". + const HTTP_METHODS: [&[u8]; 10] = [ + b"GET ", + b"POST", + b"HEAD", + b"PUT ", + b"DELETE", + b"OPTIONS", + b"CONNECT", + b"TRACE", + b"PATCH", + b"PRI ", + ]; + + if data.is_empty() { + return false; + } + + let window = &data[..data.len().min(16)]; + for method in HTTP_METHODS { + if data.len() >= method.len() && window.starts_with(method) { + return true; + } + + if (2..=3).contains(&window.len()) && method.starts_with(window) { + return true; + } + } + + false +} + fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize { if total == 0 || floor == 0 || cap < floor { return total; @@ -125,6 +196,11 @@ async fn maybe_write_shape_padding( let mut remaining = target_total - total_sent; let mut pad_chunk = [0u8; 1024]; let deadline = Instant::now() + MASK_TIMEOUT; + // Use a Send RNG so relay futures remain spawn-safe under Tokio. + let mut rng = { + let mut seed_source = rand::rng(); + StdRng::from_rng(&mut seed_source) + }; while remaining > 0 { let now = Instant::now(); @@ -133,10 +209,7 @@ async fn maybe_write_shape_padding( } let write_len = remaining.min(pad_chunk.len()); - { - let mut rng = rand::rng(); - rng.fill_bytes(&mut pad_chunk[..write_len]); - } + rng.fill_bytes(&mut pad_chunk[..write_len]); let write_budget = deadline.saturating_duration_since(now); match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await { Ok(Ok(())) => {} @@ -167,11 +240,11 @@ where } } -async fn consume_client_data_with_timeout(reader: R) +async fn consume_client_data_with_timeout_and_cap(reader: R, byte_cap: usize) where R: AsyncRead + Unpin, { - if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)) + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap)) .await .is_err() { @@ -190,6 +263,9 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { if config.censorship.mask_timing_normalization_enabled { let floor = config.censorship.mask_timing_normalization_floor_ms; let ceiling = config.censorship.mask_timing_normalization_ceiling_ms; + if floor == 0 { + return MASK_TIMEOUT; + } if ceiling > floor { let mut rng = rand::rng(); return Duration::from_millis(rng.random_range(floor..=ceiling)); @@ -219,14 +295,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) { /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request - if data.len() > 4 - && (data.starts_with(b"GET ") - || data.starts_with(b"POST") - || data.starts_with(b"HEAD") - || data.starts_with(b"PUT ") - || data.starts_with(b"DELETE") - || data.starts_with(b"OPTIONS")) - { + if is_http_probe(data) { return "HTTP"; } @@ -248,6 +317,172 @@ fn detect_client_type(data: &[u8]) -> &'static str { "unknown" } +fn parse_mask_host_ip_literal(host: &str) -> Option { + if host.starts_with('[') && host.ends_with(']') { + return host[1..host.len() - 1].parse::().ok(); + } + host.parse::().ok() +} + +fn canonical_ip(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)), + IpAddr::V4(v4) => IpAddr::V4(v4), + } +} + +#[cfg(unix)] +fn collect_local_interface_ips() -> Vec { + #[cfg(test)] + LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed); + + let mut out = Vec::new(); + if let Ok(addrs) = getifaddrs() { + for iface in addrs { + if let Some(address) = iface.address { + if let Some(v4) = address.as_sockaddr_in() { + out.push(canonical_ip(IpAddr::V4(v4.ip()))); + } else if let Some(v6) = address.as_sockaddr_in6() { + out.push(canonical_ip(IpAddr::V6(v6.ip()))); + } + } + } + } + out +} + +fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec) -> Vec { + if refreshed.is_empty() && !previous.is_empty() { + return previous.to_vec(); + } + + refreshed +} + +#[cfg(unix)] +#[derive(Default)] +struct LocalInterfaceCache { + ips: Vec, + refreshed_at: Option, +} + +#[cfg(unix)] +static LOCAL_INTERFACE_CACHE: OnceLock> = OnceLock::new(); + +#[cfg(unix)] +fn local_interface_ips() -> Vec { + let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default())); + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if stale { + let refreshed = collect_local_interface_ips(); + guard.ips = choose_interface_snapshot(&guard.ips, refreshed); + guard.refreshed_at = Some(StdInstant::now()); + } + + guard.ips.clone() +} + +#[cfg(not(unix))] +fn local_interface_ips() -> Vec { + Vec::new() +} + +#[cfg(test)] +static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0); + +#[cfg(test)] +fn reset_local_interface_enumerations_for_tests() { + LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed); + + #[cfg(unix)] + if let Some(cache) = LOCAL_INTERFACE_CACHE.get() { + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + guard.ips.clear(); + guard.refreshed_at = None; + } +} + +#[cfg(test)] +fn local_interface_enumerations_for_tests() -> usize { + LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed) +} + +fn is_mask_target_local_listener_with_interfaces( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, + interface_ips: &[IpAddr], +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let local_ip = canonical_ip(local_addr.ip()); + let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip); + + if let Some(addr) = resolved_override { + let resolved_ip = canonical_ip(addr.ip()); + if resolved_ip == local_ip { + return true; + } + + if local_ip.is_unspecified() + && (resolved_ip.is_loopback() + || resolved_ip.is_unspecified() + || interface_ips.contains(&resolved_ip)) + { + return true; + } + } + + if let Some(mask_ip) = literal_mask_ip { + if mask_ip == local_ip { + return true; + } + + if local_ip.is_unspecified() + && (mask_ip.is_loopback() + || mask_ip.is_unspecified() + || interface_ips.contains(&mask_ip)) + { + return true; + } + } + + false +} + +fn is_mask_target_local_listener( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let interfaces = local_interface_ips(); + is_mask_target_local_listener_with_interfaces( + mask_host, + mask_port, + local_addr, + resolved_override, + &interfaces, + ) +} + +fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration { + let minutes = config.general.beobachten_minutes; + let clamped = minutes.clamp(1, 24 * 60); + Duration::from_secs(clamped.saturating_mul(60)) +} + fn build_mask_proxy_header( version: u8, peer: SocketAddr, @@ -290,13 +525,14 @@ pub async fn handle_bad_client( { let client_type = detect_client_type(initial_data); if config.general.beobachten { - let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)); + let ttl = masking_beobachten_ttl(config); beobachten.record(client_type, peer.ip(), ttl); } if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes) + .await; return; } @@ -341,6 +577,7 @@ pub async fn handle_bad_client( config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_hardening_aggressive_mode, + config.censorship.mask_relay_max_bytes, ), ) .await @@ -353,12 +590,12 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -371,6 +608,24 @@ pub async fn handle_bad_client( .as_deref() .unwrap_or(&config.censorship.tls_domain); let mask_port = config.censorship.mask_port; + let outcome_started = Instant::now(); + + // Fail closed when fallback points at our own listener endpoint. + // Self-referential masking can create recursive proxy loops under + // misconfiguration and leak distinguishable load spikes to adversaries. + let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port); + if is_mask_target_local_listener(mask_host, mask_port, local_addr, resolved_mask_addr) { + debug!( + client_type = client_type, + host = %mask_host, + port = mask_port, + local = %local_addr, + "Mask target resolves to local listener; refusing self-referential masking fallback" + ); + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + wait_mask_outcome_budget(outcome_started, config).await; + return; + } debug!( client_type = client_type, @@ -381,10 +636,9 @@ pub async fn handle_bad_client( ); // Apply runtime DNS override for mask target when configured. - let mask_addr = resolve_socket_addr(mask_host, mask_port) + let mask_addr = resolved_mask_addr .map(|addr| addr.to_string()) .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); - let outcome_started = Instant::now(); let connect_started = Instant::now(); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { @@ -413,6 +667,7 @@ pub async fn handle_bad_client( config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_hardening_aggressive_mode, + config.censorship.mask_relay_max_bytes, ), ) .await @@ -425,12 +680,12 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -449,6 +704,7 @@ async fn relay_to_mask( shape_above_cap_blur: bool, shape_above_cap_blur_max_bytes: usize, shape_hardening_aggressive_mode: bool, + mask_relay_max_bytes: usize, ) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, @@ -464,8 +720,18 @@ async fn relay_to_mask( } let (upstream_copy, downstream_copy) = tokio::join!( - async { copy_with_idle_timeout(&mut reader, &mut mask_write).await }, - async { copy_with_idle_timeout(&mut mask_read, &mut writer).await } + async { + copy_with_idle_timeout( + &mut reader, + &mut mask_write, + mask_relay_max_bytes, + !shape_hardening_enabled, + ) + .await + }, + async { + copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await + } ); let total_sent = initial_data.len().saturating_add(upstream_copy.total); @@ -491,13 +757,30 @@ async fn relay_to_mask( let _ = writer.shutdown().await; } -/// Just consume all data from client without responding -async fn consume_client_data(mut reader: R) { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - while let Ok(n) = reader.read(&mut buf).await { +/// Just consume all data from client without responding. +async fn consume_client_data(mut reader: R, byte_cap: usize) { + if byte_cap == 0 { + return; + } + + // Keep drain path fail-closed under slow-loris stalls. + let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut total = 0usize; + + loop { + let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { break; } + + total = total.saturating_add(n); + if total >= byte_cap { + break; + } } } @@ -548,3 +831,63 @@ mod masking_aggressive_mode_security_tests; #[cfg(test)] #[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] mod masking_timing_sidechannel_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/masking_self_target_loop_security_tests.rs"] +mod masking_self_target_loop_security_tests; + +#[cfg(test)] +#[path = "tests/masking_classification_completeness_security_tests.rs"] +mod masking_classification_completeness_security_tests; + +#[cfg(test)] +#[path = "tests/masking_relay_guardrails_security_tests.rs"] +mod masking_relay_guardrails_security_tests; + +#[cfg(test)] +#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"] +mod masking_connect_failure_close_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/masking_additional_hardening_security_tests.rs"] +mod masking_additional_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_idle_timeout_security_tests.rs"] +mod masking_consume_idle_timeout_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_probe_classification_security_tests.rs"] +mod masking_http2_probe_classification_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http_probe_boundary_security_tests.rs"] +mod masking_http_probe_boundary_security_tests; + +#[cfg(test)] +#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"] +mod masking_rng_hoist_perf_regression_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_preface_integration_security_tests.rs"] +mod masking_http2_preface_integration_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_stress_adversarial_tests.rs"] +mod masking_consume_stress_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_security_tests.rs"] +mod masking_interface_cache_security_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"] +mod masking_interface_cache_defense_in_depth_security_tests; + +#[cfg(test)] +#[path = "tests/masking_padding_timeout_adversarial_tests.rs"] +mod masking_padding_timeout_adversarial_tests; + +#[cfg(all(test, feature = "redteam_offline_expected_fail"))] +#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"] +mod masking_offline_target_redteam_expected_fail_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 8b8d3dc..0d2a748 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -39,6 +39,8 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); +const TINY_FRAME_DEBT_PER_TINY: u32 = 8; +const TINY_FRAME_DEBT_LIMIT: u32 = 512; #[cfg(test)] const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); #[cfg(not(test))] @@ -94,10 +96,23 @@ fn relay_idle_candidate_registry() -> &'static Mutex RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) } +fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> { + let registry = relay_idle_candidate_registry(); + match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + // Fail closed after panic while holding registry lock: drop all + // candidates and pressure cursors to avoid stale cross-session state. + *guard = RelayIdleCandidateRegistry::default(); + registry.clear_poison(); + guard + } + } +} + fn mark_relay_idle_candidate(conn_id: u64) -> bool { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return false; - }; + let mut guard = relay_idle_candidate_registry_lock(); if guard.by_conn_id.contains_key(&conn_id) { return false; @@ -116,9 +131,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool { } fn clear_relay_idle_candidate(conn_id: u64) { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return; - }; + let mut guard = relay_idle_candidate_registry_lock(); if let Some(meta) = guard.by_conn_id.remove(&conn_id) { guard.ordered.remove(&(meta.mark_order_seq, conn_id)); @@ -127,23 +140,17 @@ fn clear_relay_idle_candidate(conn_id: u64) { #[cfg(test)] fn oldest_relay_idle_candidate() -> Option { - let Ok(guard) = relay_idle_candidate_registry().lock() else { - return None; - }; + let guard = relay_idle_candidate_registry_lock(); guard.ordered.iter().next().map(|(_, conn_id)| *conn_id) } fn note_relay_pressure_event() { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return; - }; + let mut guard = relay_idle_candidate_registry_lock(); guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1); } fn relay_pressure_event_seq() -> u64 { - let Ok(guard) = relay_idle_candidate_registry().lock() else { - return 0; - }; + let guard = relay_idle_candidate_registry_lock(); guard.pressure_event_seq } @@ -152,9 +159,7 @@ fn maybe_evict_idle_candidate_on_pressure( seen_pressure_seq: &mut u64, stats: &Stats, ) -> bool { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return false; - }; + let mut guard = relay_idle_candidate_registry_lock(); let latest_pressure_seq = guard.pressure_event_seq; if latest_pressure_seq == *seen_pressure_seq { @@ -199,13 +204,9 @@ fn maybe_evict_idle_candidate_on_pressure( #[cfg(test)] fn clear_relay_idle_pressure_state_for_testing() { - if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get() - && let Ok(mut guard) = registry.lock() - { - guard.by_conn_id.clear(); - guard.ordered.clear(); - guard.pressure_event_seq = 0; - guard.pressure_consumed_seq = 0; + if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() { + let mut guard = relay_idle_candidate_registry_lock(); + *guard = RelayIdleCandidateRegistry::default(); } RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed); } @@ -259,6 +260,7 @@ impl RelayClientIdlePolicy { struct RelayClientIdleState { last_client_frame_at: Instant, soft_idle_marked: bool, + tiny_frame_debt: u32, } impl RelayClientIdleState { @@ -266,6 +268,7 @@ impl RelayClientIdleState { Self { last_client_frame_at: now, soft_idle_marked: false, + tiny_frame_debt: 0, } } @@ -551,15 +554,6 @@ fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { limit.saturating_add(overshoot) } -fn quota_exceeded_for_user_soft( - stats: &Stats, - user: &str, - quota_limit: Option, - overshoot: u64, -) -> bool { - quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot)) -} - fn quota_would_be_exceeded_for_user_soft( stats: &Stats, user: &str, @@ -617,6 +611,16 @@ fn observe_me_d2c_flush_event( } } +fn rollback_me2c_quota_reservation( + stats: &Stats, + user: &str, + bytes_me2c: &AtomicU64, + reserved_bytes: u64, +) { + stats.sub_user_octets_to(user, reserved_bytes); + bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed); +} + #[cfg(test)] fn quota_user_lock_test_guard() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); @@ -630,6 +634,19 @@ fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { .unwrap_or_else(|poisoned| poisoned.into_inner()) } +#[cfg(test)] +fn relay_idle_pressure_test_guard() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> { + relay_idle_pressure_test_guard() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + fn quota_overflow_user_lock(user: &str) -> Arc> { let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { (0..QUOTA_OVERFLOW_LOCK_STRIPES) @@ -665,6 +682,11 @@ fn quota_user_lock(user: &str) -> Arc> { } } +#[cfg(test)] +pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) +} + async fn enqueue_c2me_command( tx: &mpsc::Sender, cmd: C2MeCommand, @@ -710,6 +732,8 @@ where { let user = success.user.clone(); let quota_limit = config.access.user_data_quota.get(&user).copied(); + let cross_mode_quota_lock = + quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); let peer = success.peer; let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); @@ -1221,6 +1245,17 @@ where if let Some(limit) = quota_limit { let quota_lock = quota_user_lock(&user); let _quota_guard = quota_lock.lock().await; + let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else { + main_result = Err(ProxyError::Proxy( + "cross-mode quota lock missing for quota-limited session" + .to_string(), + )); + break; + }; + let _cross_mode_quota_guard = match cross_mode_lock.lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + }; stats.add_user_octets_from(&user, payload.len() as u64); if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { main_result = Err(ProxyError::DataQuotaExceeded { @@ -1320,6 +1355,8 @@ async fn read_client_payload_with_idle_policy( where R: AsyncRead + Unpin + Send + 'static, { + const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4; + async fn read_exact_with_policy( client_reader: &mut CryptoReader, buf: &mut [u8], @@ -1458,6 +1495,7 @@ where Ok(()) } + let mut consecutive_zero_len_frames = 0u32; loop { let (len, quickack, raw_len_bytes) = match proto_tag { ProtoTag::Abridged => { @@ -1538,6 +1576,27 @@ where }; if len == 0 { + idle_state.tiny_frame_debt = idle_state + .tiny_frame_debt + .saturating_add(TINY_FRAME_DEBT_PER_TINY); + if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT { + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy(format!( + "Tiny frame overhead limit exceeded: debt={}, conn_id={}", + idle_state.tiny_frame_debt, forensics.conn_id + ))); + } + + if !idle_policy.enabled { + consecutive_zero_len_frames = + consecutive_zero_len_frames.saturating_add(1); + if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES { + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy( + "Excessive zero-length abridged frames".to_string(), + )); + } + } continue; } if len < 4 && proto_tag != ProtoTag::Abridged { @@ -1606,6 +1665,7 @@ where } *frame_counter += 1; idle_state.on_client_frame(Instant::now()); + idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1); clear_relay_idle_candidate(forensics.conn_id); return Ok(Some((payload, quickack))); } @@ -1707,39 +1767,57 @@ where trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } let data_len = data.len() as u64; - if quota_would_be_exceeded_for_user_soft( - stats, - user, - quota_limit, - data_len, - quota_soft_overshoot_bytes, - ) { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); - } + if let Some(limit) = quota_limit { + let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); + if quota_would_be_exceeded_for_user(stats, user, Some(soft_limit), data_len) { + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } - let write_mode = - write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await?; - stats.increment_me_d2c_write_mode(write_mode); + // Reserve quota before awaiting network I/O to avoid same-user HoL stalls. + // If reservation loses a race or write fails, we roll back immediately. + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + stats.add_user_octets_to(user, data_len); - bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); - stats.add_user_octets_to(user, data.len() as u64); - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data.len() as u64); + if stats.get_user_total_octets(user) > soft_limit { + rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } - if quota_exceeded_for_user_soft( - stats, - user, - quota_limit, - quota_soft_overshoot_bytes, - ) { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); + let write_mode = + match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await + { + Ok(mode) => mode, + Err(err) => { + rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); + return Err(err); + } + }; + + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); + + // Do not fail immediately on exact boundary after a successful write. + // Returning an error here can bypass batch flush in the caller and risk + // dropping buffered ciphertext from CryptoWriter. The next frame is + // rejected by the pre-check at function entry. + } else { + let write_mode = + write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await?; + + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + stats.add_user_octets_to(user, data_len); + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); } Ok(MeWriterResponseOutcome::Continue { @@ -1978,3 +2056,31 @@ mod length_cast_hardening_security_tests; #[cfg(test)] #[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] mod blackhat_campaign_integration_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_hol_quota_security_tests.rs"] +mod hol_quota_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"] +mod quota_reservation_adversarial_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"] +mod middle_relay_idle_registry_poison_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"] +mod middle_relay_zero_length_frame_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"] +mod middle_relay_tiny_frame_debt_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"] +mod middle_relay_tiny_frame_debt_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] +mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index eebc188..519f1b3 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -64,6 +64,7 @@ pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; +pub mod quota_lock_registry; pub mod relay; pub mod route_mode; pub mod session_eviction; diff --git a/src/proxy/quota_lock_registry.rs b/src/proxy/quota_lock_registry.rs new file mode 100644 index 0000000..ac64a57 --- /dev/null +++ b/src/proxy/quota_lock_registry.rs @@ -0,0 +1,53 @@ +use dashmap::DashMap; +use std::sync::{Arc, Mutex, OnceLock}; + +#[cfg(test)] +const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64; +#[cfg(not(test))] +const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096; +#[cfg(test)] +const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; +#[cfg(not(test))] +const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; + +static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); +static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); + +fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { + let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { + (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES) + .map(|_| Arc::new(Mutex::new(()))) + .collect() + }); + + let hash = crc32fast::hash(user.as_bytes()) as usize; + Arc::clone(&stripes[hash % stripes.len()]) +} + +pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + if let Some(existing) = locks.get(user) { + return Arc::clone(existing.value()); + } + + if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } + + if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { + return cross_mode_quota_overflow_user_lock(user); + } + + let created = Arc::new(Mutex::new(())); + match locks.entry(user.to_string()) { + dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(Arc::clone(&created)); + created + } + } +} + +#[cfg(test)] +#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"] +mod quota_lock_registry_cross_mode_adversarial_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 2431ff4..dcacedd 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -62,7 +62,7 @@ use std::sync::{Arc, Mutex, OnceLock}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; -use tokio::time::Instant; +use tokio::time::{Instant, Sleep}; use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -209,12 +209,16 @@ struct StatsIo { counters: Arc, stats: Arc, user: String, + quota_lock: Option>>, + cross_mode_quota_lock: Option>>, quota_limit: Option, quota_exceeded: Arc, quota_read_wake_scheduled: bool, quota_write_wake_scheduled: bool, - quota_read_retry_active: Arc, - quota_write_retry_active: Arc, + quota_read_retry_sleep: Option>>, + quota_write_retry_sleep: Option>>, + quota_read_retry_attempt: u8, + quota_write_retry_attempt: u8, epoch: Instant, } @@ -230,30 +234,29 @@ impl StatsIo { ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); + let quota_lock = quota_limit.map(|_| quota_user_lock(&user)); + let cross_mode_quota_lock = quota_limit + .map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); Self { inner, counters, stats, user, + quota_lock, + cross_mode_quota_lock, quota_limit, quota_exceeded, quota_read_wake_scheduled: false, quota_write_wake_scheduled: false, - quota_read_retry_active: Arc::new(AtomicBool::new(false)), - quota_write_retry_active: Arc::new(AtomicBool::new(false)), + quota_read_retry_sleep: None, + quota_write_retry_sleep: None, + quota_read_retry_attempt: 0, + quota_write_retry_attempt: 0, epoch, } } } -impl Drop for StatsIo { - fn drop(&mut self) { - self.quota_read_retry_active.store(false, Ordering::Relaxed); - self.quota_write_retry_active - .store(false, Ordering::Relaxed); - } -} - #[derive(Debug)] struct QuotaIoSentinel; @@ -281,20 +284,52 @@ fn is_quota_io_error(err: &io::Error) -> bool { const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1); #[cfg(not(test))] const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2); +#[cfg(test)] +const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16); +#[cfg(not(test))] +const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64); -fn spawn_quota_retry_waker(retry_active: Arc, waker: std::task::Waker) { - tokio::task::spawn(async move { - loop { - if !retry_active.load(Ordering::Relaxed) { - break; - } - tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await; - if !retry_active.load(Ordering::Relaxed) { - break; - } - waker.wake_by_ref(); - } - }); +#[inline] +fn quota_contention_retry_delay(retry_attempt: u8) -> Duration { + let shift = u32::from(retry_attempt.min(5)); + let multiplier = 1_u32 << shift; + QUOTA_CONTENTION_RETRY_INTERVAL + .saturating_mul(multiplier) + .min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL) +} + +#[inline] +fn reset_quota_retry_scheduler( + sleep_slot: &mut Option>>, + wake_scheduled: &mut bool, + retry_attempt: &mut u8, +) { + *wake_scheduled = false; + *sleep_slot = None; + *retry_attempt = 0; +} + +fn poll_quota_retry_sleep( + sleep_slot: &mut Option>>, + wake_scheduled: &mut bool, + retry_attempt: &mut u8, + cx: &mut Context<'_>, +) { + if !*wake_scheduled { + *wake_scheduled = true; + *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay( + *retry_attempt, + )))); + } + + if let Some(sleep) = sleep_slot.as_mut() + && sleep.as_mut().poll(cx).is_ready() + { + *sleep_slot = None; + *wake_scheduled = false; + *retry_attempt = retry_attempt.saturating_add(1); + cx.waker().wake_by_ref(); + } } static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); @@ -357,6 +392,11 @@ fn quota_user_lock(user: &str) -> Arc> { } } +#[cfg(test)] +pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) +} + impl AsyncRead for StatsIo { fn poll_read( self: Pin<&mut Self>, @@ -368,26 +408,47 @@ impl AsyncRead for StatsIo { return Poll::Ready(Err(quota_io_error())); } - let quota_lock = this - .quota_limit - .is_some() - .then(|| quota_user_lock(&this.user)); - let _quota_guard = if let Some(lock) = quota_lock.as_ref() { + let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { match lock.try_lock() { Ok(guard) => { - this.quota_read_wake_scheduled = false; - this.quota_read_retry_active.store(false, Ordering::Relaxed); + reset_quota_retry_scheduler( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + ); Some(guard) } Err(_) => { - if !this.quota_read_wake_scheduled { - this.quota_read_wake_scheduled = true; - this.quota_read_retry_active.store(true, Ordering::Relaxed); - spawn_quota_retry_waker( - Arc::clone(&this.quota_read_retry_active), - cx.waker().clone(), - ); - } + poll_quota_retry_sleep( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + cx, + ); + return Poll::Pending; + } + } + } else { + None + }; + + let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { + match lock.try_lock() { + Ok(guard) => { + reset_quota_retry_scheduler( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + ); + Some(guard) + } + Err(_) => { + poll_quota_retry_sleep( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + cx, + ); return Poll::Pending; } } @@ -460,27 +521,47 @@ impl AsyncWrite for StatsIo { return Poll::Ready(Err(quota_io_error())); } - let quota_lock = this - .quota_limit - .is_some() - .then(|| quota_user_lock(&this.user)); - let _quota_guard = if let Some(lock) = quota_lock.as_ref() { + let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { match lock.try_lock() { Ok(guard) => { - this.quota_write_wake_scheduled = false; - this.quota_write_retry_active - .store(false, Ordering::Relaxed); + reset_quota_retry_scheduler( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + ); Some(guard) } Err(_) => { - if !this.quota_write_wake_scheduled { - this.quota_write_wake_scheduled = true; - this.quota_write_retry_active.store(true, Ordering::Relaxed); - spawn_quota_retry_waker( - Arc::clone(&this.quota_write_retry_active), - cx.waker().clone(), - ); - } + poll_quota_retry_sleep( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + cx, + ); + return Poll::Pending; + } + } + } else { + None + }; + + let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { + match lock.try_lock() { + Ok(guard) => { + reset_quota_retry_scheduler( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + ); + Some(guard) + } + Err(_) => { + poll_quota_retry_sleep( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + cx, + ); return Poll::Pending; } } @@ -791,3 +872,27 @@ mod relay_quota_waker_storm_adversarial_tests; #[cfg(test)] #[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] mod relay_quota_wake_liveness_regression_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_identity_security_tests.rs"] +mod relay_quota_lock_identity_security_tests; + +#[cfg(test)] +#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"] +mod relay_cross_mode_quota_lock_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"] +mod relay_quota_retry_scheduler_tdd_tests; + +#[cfg(test)] +#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"] +mod relay_cross_mode_quota_fairness_tdd_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_backoff_security_tests.rs"] +mod relay_quota_retry_backoff_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"] +mod relay_quota_retry_backoff_benchmark_security_tests; diff --git a/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs new file mode 100644 index 0000000..d7ac4ef --- /dev/null +++ b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs @@ -0,0 +1,100 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +#[tokio::test] +async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"CONNE").await.unwrap(); + client_side + .write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"), + "mask backend must receive the full fragmented CONNECT probe" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.251-1")); +} diff --git a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs new file mode 100644 index 0000000..fcf51ab --- /dev/null +++ b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs @@ -0,0 +1,129 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + let first = split_at.min(preface.len()); + client_side.write_all(&preface[..first]).await.unwrap(); + if first < preface.len() { + sleep(Duration::from_millis(delay_ms)).await; + client_side.write_all(&preface[first..]).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(&preface), + "mask backend must receive an intact HTTP/2 preface prefix" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains(&format!("{}-1", peer.ip()))); +} + +#[tokio::test] +async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() { + let cases = [ + (2usize, 0u64), + (3, 0), + (4, 0), + (2, 7), + (3, 7), + (8, 1), + ]; + + for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() { + let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i) + .parse() + .unwrap(); + run_http2_fragment_case(split_at, delay_ms, peer).await; + } +} + +#[tokio::test] +async fn http2_preface_splitpoint_light_fuzz_classifies_http() { + for split_at in 2usize..=12 { + let delay_ms = if split_at % 3 == 0 { 7 } else { 1 }; + let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at) + .parse() + .unwrap(); + run_http2_fragment_case(split_at, delay_ms, peer).await; + } +} diff --git a/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs new file mode 100644 index 0000000..e64dc03 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs @@ -0,0 +1,150 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_pipeline_prefetch_case( + prefetch_timeout_ms: u64, + delayed_tail_ms: u64, + peer: SocketAddr, +) -> (Vec, String) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms; + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"C").await.unwrap(); + sleep(Duration::from_millis(delayed_tail_ms)).await; + + client_side + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + (forwarded, snapshot) +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() { + let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must still receive full payload bytes in-order" + ); + assert!( + snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"), + "unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}" + ); +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() { + let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must receive full CONNECT payload" + ); + assert!( + snapshot.contains("[HTTP]"), + "20ms budget should recover delayed fragmented prefix and classify as HTTP" + ); +} + +#[tokio::test] +async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() { + let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap(); + let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap(); + let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap(); + + let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await; + let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await; + let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await; + + assert!( + snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"), + "unexpected 5ms snapshot: {snap5}" + ); + assert!( + snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"), + "unexpected 20ms snapshot: {snap20}" + ); + assert!(snap50.contains("[HTTP]")); +} diff --git a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs new file mode 100644 index 0000000..cdf2136 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs @@ -0,0 +1,82 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, sleep}; + +#[test] +fn prefetch_timeout_budget_reads_from_config() { + let mut cfg = ProxyConfig::default(); + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(5), + "default prefetch timeout budget must remain 5ms" + ); + + cfg.censorship.mask_classifier_prefetch_timeout_ms = 20; + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(20), + "runtime prefetch timeout budget must follow configured value" + ); +} + +#[tokio::test] +async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(15)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer.shutdown().await.expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(20), + ) + .await; + + writer_task + .await + .expect("writer task must not panic in runtime timeout test"); + + assert!( + initial_data.starts_with(b"CONNECT"), + "20ms configured prefetch budget should recover 15ms delayed CONNECT tail" + ); +} + +#[tokio::test] +async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(15)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer.shutdown().await.expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(5), + ) + .await; + + writer_task + .await + .expect("writer task must not panic in runtime timeout test"); + + assert!( + !initial_data.starts_with(b"CONNECT"), + "5ms configured prefetch budget should miss 15ms delayed CONNECT tail" + ); +} diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs new file mode 100644 index 0000000..2e03ce9 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs @@ -0,0 +1,261 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[test] +fn empty_initial_data_prefetch_gate_is_fail_closed() { + assert!( + !should_prefetch_mask_classifier_window(&[]), + "empty initial_data must not trigger classifier prefetch" + ); +} + +#[tokio::test] +async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() { + let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec(); + let (mut reader, mut writer) = duplex(1024); + + writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + + assert!( + initial_data.is_empty(), + "empty initial_data must remain empty after prefetch stage" + ); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!( + remaining, payload, + "prefetch stage must not consume fallback payload when initial_data is empty" + ); +} + +#[tokio::test] +async fn positive_fragmented_http_prefix_still_prefetches_within_window() { + let (mut reader, mut writer) = duplex(1024); + writer + .write_all(b"NECT example.org:443 HTTP/1.1\r\n") + .await + .unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = b"CON".to_vec(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + + assert!( + initial_data.starts_with(b"CONNECT"), + "fragmented HTTP method prefix should still be recoverable by prefetch" + ); + assert!( + initial_data.len() <= 16, + "prefetch window must remain bounded" + ); +} + +#[tokio::test] +async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() { + let mut seed = 0xD15C_A11E_2026_0322u64; + + for _ in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = ((seed & 0x3f) as usize).saturating_add(1); + let mut payload = vec![0u8; len]; + for (idx, byte) in payload.iter_mut().enumerate() { + *byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17); + } + + let (mut reader, mut writer) = duplex(1024); + writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + assert!(initial_data.is_empty()); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!(remaining, payload); + } +} + +#[tokio::test] +async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xD3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B); + let mut invalid_payload = vec![0u8; HANDSHAKE_LEN]; + invalid_payload[0] = 0xFF; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload); + let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant"); + let expected = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + + let mut one = [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!( + n, 0, + "fallback stream must not append synthetic bytes on empty initial_data path" + ); + }); + + let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.245:56145".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs new file mode 100644 index 0000000..9ece258 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs @@ -0,0 +1,70 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, advance, sleep}; + +async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(tail_delay_ms)).await; + let _ = writer.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n").await; + let _ = writer.shutdown().await; + }); + + let mut initial_data = b"C".to_vec(); + let mut prefetch_task = tokio::spawn(async move { + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(prefetch_ms), + ) + .await; + initial_data + }); + + tokio::task::yield_now().await; + + if tail_delay_ms > 0 { + advance(Duration::from_millis(tail_delay_ms)).await; + tokio::task::yield_now().await; + } + + if prefetch_ms > tail_delay_ms { + advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await; + tokio::task::yield_now().await; + } + + let result = prefetch_task.await.expect("prefetch task must not panic"); + writer_task.await.expect("writer task must not panic"); + result +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_5ms_misses_15ms_tail() { + let got = run_strict_prefetch_case(5, 15).await; + assert_eq!(got, b"C".to_vec()); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_20ms_recovers_15ms_tail() { + let got = run_strict_prefetch_case(20, 15).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_50ms_recovers_35ms_tail() { + let got = run_strict_prefetch_case(50, 35).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_equal_budget_and_delay_recovers_tail() { + let got = run_strict_prefetch_case(20, 20).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_one_ms_after_budget_misses_tail() { + let got = run_strict_prefetch_case(20, 21).await; + assert_eq!(got, b"C".to_vec()); +} diff --git a/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs new file mode 100644 index 0000000..3f4ab17 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs @@ -0,0 +1,95 @@ +use super::*; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, sleep, timeout}; + +async fn extend_masking_initial_window_with_budget( + reader: &mut R, + initial_data: &mut Vec, + prefetch_timeout: Duration, +) where + R: AsyncRead + Unpin, +{ + if !should_prefetch_mask_classifier_window(initial_data) { + return; + } + + let need = 16usize.saturating_sub(initial_data.len()); + if need == 0 { + return; + } + + let mut extra = [0u8; 16]; + if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await + && n > 0 + { + initial_data.extend_from_slice(&extra[..n]); + } +} + +async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(delayed_tail_ms)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer.shutdown().await.expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_budget( + &mut reader, + &mut initial_data, + Duration::from_millis(prefetch_budget_ms), + ) + .await; + + writer_task + .await + .expect("writer task must not panic during matrix case"); + + initial_data.starts_with(b"CONNECT") +} + +#[tokio::test] +async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() { + let cases = [ + // (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50]) + (2u64, [true, true, true]), + (15u64, [false, true, true]), + (35u64, [false, false, true]), + ]; + + for (tail_delay_ms, expected) in cases { + let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await; + let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await; + let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await; + + assert_eq!( + got_5, expected[0], + "5ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + assert_eq!( + got_20, expected[1], + "20ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + assert_eq!( + got_50, expected[2], + "50ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + } +} + +#[tokio::test] +async fn control_current_runtime_prefetch_budget_is_5ms() { + assert_eq!( + MASK_CLASSIFIER_PREFETCH_TIMEOUT, + Duration::from_millis(5), + "matrix assumptions require current runtime prefetch budget to stay at 5ms" + ); +} diff --git a/src/proxy/tests/client_masking_replay_timing_security_tests.rs b/src/proxy/tests/client_masking_replay_timing_security_tests.rs new file mode 100644 index 0000000..225ce50 --- /dev/null +++ b/src/proxy/tests/client_masking_replay_timing_security_tests.rs @@ -0,0 +1,161 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +async fn run_replay_candidate_session( + replay_checker: Arc, + hello: &[u8], + peer: SocketAddr, + drive_mtproto_fail: bool, +) -> Duration { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.censorship.mask_timing_normalization_enabled = false; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "abababababababababababababababab".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(65536); + let started = Instant::now(); + + let task = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + replay_checker, + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten, + false, + )); + + client_side.write_all(hello).await.unwrap(); + + if drive_mtproto_fail { + let mut server_hello_head = [0u8; 5]; + client_side.read_exact(&mut server_hello_head).await.unwrap(); + assert_eq!(server_hello_head[0], 0x16); + let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize; + let mut body = vec![0u8; body_len]; + client_side.read_exact(&mut body).await.unwrap(); + + let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN); + invalid_mtproto_record.push(0x17); + invalid_mtproto_record.extend_from_slice(&TLS_VERSION); + invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes()); + invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]); + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n") + .await + .unwrap(); + } + + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + started.elapsed() +} + +#[tokio::test] +async fn replay_reject_still_honors_masking_timing_budget() { + let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60))); + let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51); + + let seed_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.201:58001".parse().unwrap(), + true, + ) + .await; + + assert!( + seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250), + "seed replay-candidate run must honor masking timing budget without unbounded delay" + ); + + let replay_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.202:58002".parse().unwrap(), + false, + ) + .await; + + assert!( + replay_elapsed >= Duration::from_millis(40) + && replay_elapsed < Duration::from_millis(250), + "replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay" + ); +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs new file mode 100644 index 0000000..c5e57d7 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -0,0 +1,90 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn edge_zero_state_len_yields_zero_start_offset() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44)); + let now = Instant::now(); + + assert_eq!( + auth_probe_scan_start_offset(ip, now, 0, 16), + 0, + "empty map must not produce non-zero scan offset" + ); +} + +#[test] +fn adversarial_large_state_must_bound_start_offset_to_scan_budget() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let scan_limit = 16usize; + let state_len = 65_536usize; + + for i in 0..2048u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 203, + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + let now = base + Duration::from_micros(i as u64); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + assert!( + start < scan_limit, + "start offset must stay within scan window; start={start}, limit={scan_limit}" + ); + } +} + +#[test] +fn positive_state_smaller_than_scan_limit_caps_to_state_len() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17)); + let now = Instant::now(); + + for state_len in 1..32usize { + let start = auth_probe_scan_start_offset(ip, now, state_len, 64); + assert!( + start < state_len, + "start offset must never exceed state length when scan limit is larger" + ); + } +} + +#[test] +fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() { + let _guard = auth_probe_test_guard(); + let mut seed = 0x5A41_5356_4C32_3236u64; + let base = Instant::now(); + + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1); + let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0xffff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + let effective_window = state_len.min(scan_limit); + + assert!( + start < effective_window, + "scan offset must stay inside effective window" + ); + } +} \ No newline at end of file diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs new file mode 100644 index 0000000..cdaf498 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs @@ -0,0 +1,113 @@ +use super::*; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn positive_same_ip_moving_time_yields_diverse_scan_offsets() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)); + let base = Instant::now(); + let mut uniq = HashSet::new(); + + for i in 0..512u64 { + let now = base + Duration::from_nanos(i); + let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16); + uniq.insert(offset); + } + + assert_eq!( + uniq.len(), + 16, + "offset randomization must cover the entire scan window over 512 samples" + ); +} + +#[test] +fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() { + let _guard = auth_probe_test_guard(); + let now = Instant::now(); + let mut uniq = HashSet::new(); + + for i in 0..1024u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + (i >> 16) as u8, + (i >> 8) as u8, + i as u8, + (255 - (i as u8)), + )); + uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16)); + } + + assert_eq!( + uniq.len(), + 16, + "scan offset distribution collapsed unexpectedly across peer set" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let start = Instant::now(); + let mut workers = Vec::new(); + for worker in 0..8u8 { + workers.push(tokio::spawn(async move { + for i in 0..8192u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + worker, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64)); + } + })); + } + + for worker in workers { + worker.await.expect("saturation worker must not panic"); + } + + assert!( + auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "state must remain hard-capped under parallel saturation churn" + ); + + let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1)); + let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1)); +} + +#[test] +fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xA55A_1357_2468_9BDFu64; + let base = Instant::now(); + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x1fff); + + let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + assert!(offset < state_len.min(scan_limit)); + } +} \ No newline at end of file diff --git a/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs new file mode 100644 index 0000000..7176b1c --- /dev/null +++ b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs @@ -0,0 +1,42 @@ +use super::*; + +fn handshake_source() -> &'static str { + include_str!("../handshake.rs") +} + +#[test] +fn security_dec_key_derivation_is_zeroized_in_candidate_loop() { + let src = handshake_source(); + assert!( + src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"), + "candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths" + ); +} + +#[test] +fn security_enc_key_derivation_is_zeroized_in_candidate_loop() { + let src = handshake_source(); + assert!( + src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"), + "candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths" + ); +} + +#[test] +fn security_aes_ctr_initialization_uses_zeroizing_references() { + let src = handshake_source(); + assert!( + src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);") + && src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"), + "AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables" + ); +} + +#[test] +fn security_success_struct_copies_out_of_zeroizing_wrappers() { + let src = handshake_source(); + assert!( + src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"), + "HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized" + ); +} diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs index 3e860e8..84c904f 100644 --- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u ]; let mut meaningful_improvement_seen = false; - let mut baseline_sum = 0.0f64; - let mut hardened_sum = 0.0f64; - let mut pair_count = 0usize; + let mut informative_baseline_sum = 0.0f64; + let mut informative_hardened_sum = 0.0f64; + let mut informative_pair_count = 0usize; + let mut low_info_baseline_sum = 0.0f64; + let mut low_info_hardened_sum = 0.0f64; + let mut low_info_pair_count = 0usize; let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64; let tolerated_pair_regression = acc_quant_step + 0.03; @@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u hardened_acc <= baseline_acc + tolerated_pair_regression, "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}" ); + informative_baseline_sum += baseline_acc; + informative_hardened_sum += hardened_acc; + informative_pair_count += 1; + } else { + // Low-information pairs (near-random baseline separability) are expected + // to exhibit quantized jitter at low sample counts; do not fold them into + // strict average-regression checks used for informative side-channel signal. + low_info_baseline_sum += baseline_acc; + low_info_hardened_sum += hardened_acc; + low_info_pair_count += 1; } println!( @@ -532,19 +545,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u meaningful_improvement_seen = true; } - baseline_sum += baseline_acc; - hardened_sum += hardened_acc; - pair_count += 1; } - let baseline_avg = baseline_sum / pair_count as f64; - let hardened_avg = hardened_sum / pair_count as f64; + assert!( + informative_pair_count > 0, + "expected at least one informative pair for timing-separability guard" + ); + + let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64; + let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64; assert!( - hardened_avg <= baseline_avg + 0.10, - "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}" + informative_hardened_avg <= informative_baseline_avg + 0.10, + "normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}" ); + if low_info_pair_count > 0 { + let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64; + let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64; + assert!( + low_info_hardened_avg <= low_info_baseline_avg + 0.40, + "normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}" + ); + } + // Optional signal only: do not require improvement on every run because // noisy CI schedulers can flatten pairwise differences at low sample counts. let _ = meaningful_improvement_seen; diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs new file mode 100644 index 0000000..29170c1 --- /dev/null +++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs @@ -0,0 +1,122 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; +use tokio::time::{Duration, timeout}; + +struct EndlessReader { + produced: Arc, +} + +impl AsyncRead for EndlessReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let len = buf.remaining().max(1); + let fill = vec![0xAA; len]; + buf.put_slice(&fill); + self.produced.fetch_add(len, Ordering::Relaxed); + Poll::Ready(Ok(())) + } +} + +#[test] +fn loop_guard_unspecified_bind_uses_interface_inventory() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap(); + let interfaces = vec!["192.168.44.10".parse().unwrap()]; + + assert!(is_mask_target_local_listener_with_interfaces( + "mask.example", + 443, + local, + Some(resolved), + &interfaces, + )); +} + +#[tokio::test] +async fn consume_client_data_stops_after_byte_cap_without_eof() { + let produced = Arc::new(AtomicUsize::new(0)); + let reader = EndlessReader { + produced: Arc::clone(&produced), + }; + let cap = 10_000usize; + + consume_client_data(reader, cap).await; + + let total = produced.load(Ordering::Relaxed); + assert!( + total >= cap, + "consume path must read at least up to cap before stopping" + ); + assert!( + total <= cap + 8192, + "consume path must stop within one read chunk above cap" + ); +} + +#[test] +fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 0; + + let ttl = masking_beobachten_ttl(&config); + assert_eq!(ttl, std::time::Duration::from_secs(60)); +} + +#[test] +fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() { + let mut config = ProxyConfig::default(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 0; + config.censorship.mask_timing_normalization_ceiling_ms = 0; + + let budget = mask_outcome_target_budget(&config); + assert_eq!(budget, MASK_TIMEOUT); +} + +#[tokio::test] +async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap(); + let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "loop guard must fail closed before any recursive PROXY protocol amplification" + ); +} diff --git a/src/proxy/tests/masking_classification_completeness_security_tests.rs b/src/proxy/tests/masking_classification_completeness_security_tests.rs new file mode 100644 index 0000000..35bf87b --- /dev/null +++ b/src/proxy/tests/masking_classification_completeness_security_tests.rs @@ -0,0 +1,16 @@ +use super::*; + +#[test] +fn detect_client_type_recognizes_extended_http_probe_verbs() { + assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP"); +} + +#[test] +fn detect_client_type_recognizes_fragmented_http_method_prefixes() { + assert_eq!(detect_client_type(b"CO"), "HTTP"); + assert_eq!(detect_client_type(b"CON"), "HTTP"); + assert_eq!(detect_client_type(b"TR"), "HTTP"); + assert_eq!(detect_client_type(b"PAT"), "HTTP"); +} diff --git a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs new file mode 100644 index 0000000..614af9b --- /dev/null +++ b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs @@ -0,0 +1,127 @@ +use super::*; +use crate::network::dns_overrides::install_entries; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +async fn run_connect_failure_case( + host: &str, + port: u16, + timing_normalization_enabled: bool, + peer: SocketAddr, +) -> Duration { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(host.to_string()); + config.censorship.mask_port = port; + config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n"; + + let (mut client_writer, client_reader) = duplex(1024); + let (mut client_visible_reader, client_visible_writer) = duplex(1024); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0, "connect-failure path must close client-visible writer"); + + started.elapsed() +} + +#[tokio::test] +async fn connect_failure_refusal_close_behavior_matrix() { + let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() { + let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16) + .parse() + .unwrap(); + let elapsed = run_connect_failure_case( + "127.0.0.1", + unused_port, + timing_normalization_enabled, + peer, + ) + .await; + + if timing_normalization_enabled { + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250), + "normalized refusal path must honor configured timing envelope without stalling" + ); + } else { + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150), + "non-normalized refusal path must honor baseline connect budget without stalling" + ); + } + } +} + +#[tokio::test] +async fn connect_failure_overridden_hostname_close_behavior_matrix() { + let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + // Make hostname resolution deterministic in tests so timing ceilings are meaningful. + install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap(); + + for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() { + let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16) + .parse() + .unwrap(); + let elapsed = run_connect_failure_case( + "mask.invalid", + unused_port, + timing_normalization_enabled, + peer, + ) + .await; + + if timing_normalization_enabled { + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250), + "normalized overridden-host path must honor configured timing envelope without stalling" + ); + } else { + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150), + "non-normalized overridden-host path must honor baseline connect budget without stalling" + ); + } + } + + install_entries(&[]).unwrap(); +} diff --git a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs new file mode 100644 index 0000000..b52af35 --- /dev/null +++ b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs @@ -0,0 +1,85 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::{AsyncRead, ReadBuf}; + +struct OneByteThenStall { + sent: bool, +} + +impl AsyncRead for OneByteThenStall { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if !self.sent { + self.sent = true; + buf.put_slice(&[0x42]); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +#[tokio::test] +async fn stalling_client_terminates_at_idle_not_relay_timeout() { + let reader = OneByteThenStall { sent: false }; + let started = Instant::now(); + + let result = tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(reader, MASK_BUFFER_SIZE * 4), + ) + .await; + + assert!( + result.is_ok(), + "consume_client_data should complete by per-read idle timeout, not hit relay timeout" + ); + + let elapsed = started.elapsed(); + assert!( + elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2), + "consume_client_data returned too quickly for idle-timeout path: {elapsed:?}" + ); + assert!( + elapsed < MASK_RELAY_TIMEOUT, + "consume_client_data waited full relay timeout ({elapsed:?}); \ + per-read idle timeout is missing" + ); +} + +#[tokio::test] +async fn fast_reader_drains_to_eof() { + let data = vec![0xAAu8; 32 * 1024]; + let reader = std::io::Cursor::new(data); + + tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX)) + .await + .expect("consume_client_data did not complete for fast EOF reader"); +} + +#[tokio::test] +async fn io_error_terminates_cleanly() { + struct ErrReader; + + impl AsyncRead for ErrReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "simulated reset", + ))) + } + } + + tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX)) + .await + .expect("consume_client_data did not return on I/O error"); +} diff --git a/src/proxy/tests/masking_consume_stress_adversarial_tests.rs b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs new file mode 100644 index 0000000..12287b5 --- /dev/null +++ b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs @@ -0,0 +1,64 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::task::JoinSet; + +struct OneByteThenStall { + sent: bool, +} + +impl AsyncRead for OneByteThenStall { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if !self.sent { + self.sent = true; + buf.put_slice(&[0xAA]); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +#[tokio::test] +async fn consume_stall_stress_finishes_within_idle_budget() { + let mut set = JoinSet::new(); + let started = Instant::now(); + + for _ in 0..64 { + set.spawn(async { + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(OneByteThenStall { sent: false }, usize::MAX), + ) + .await + .expect("consume_client_data exceeded relay timeout under stall load"); + }); + } + + while let Some(res) = set.join_next().await { + res.unwrap(); + } + + // Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling + // for 100ms should complete well under a strict 600ms boundary. + assert!( + started.elapsed() < MASK_RELAY_TIMEOUT * 3, + "stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking" + ); +} + +#[tokio::test] +async fn consume_zero_cap_returns_immediately() { + let started = Instant::now(); + consume_client_data(tokio::io::empty(), 0).await; + assert!( + started.elapsed() < MASK_RELAY_IDLE_TIMEOUT, + "zero byte cap must return immediately" + ); +} diff --git a/src/proxy/tests/masking_http2_preface_integration_security_tests.rs b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs new file mode 100644 index 0000000..7f1c03f --- /dev/null +++ b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs @@ -0,0 +1,55 @@ +use super::*; +use tokio::net::TcpListener; +use tokio::time::Duration; + +#[tokio::test] +async fn http2_preface_is_forwarded_and_recorded_as_http() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let preface = preface.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; preface.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, preface); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let (client_reader, _client_writer) = tokio::io::duplex(512); + let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + client_reader, + client_visible_writer, + &preface, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.130-1")); +} diff --git a/src/proxy/tests/masking_http2_probe_classification_security_tests.rs b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs new file mode 100644 index 0000000..34e04a9 --- /dev/null +++ b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs @@ -0,0 +1,92 @@ +use super::*; + +#[test] +fn full_http2_preface_classified_as_http_probe() { + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + assert!( + is_http_probe(preface), + "HTTP/2 connection preface must be classified as HTTP probe" + ); +} + +#[test] +fn partial_http2_preface_3_bytes_classified() { + assert!( + is_http_probe(b"PRI"), + "3-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn partial_http2_preface_2_bytes_classified() { + assert!( + is_http_probe(b"PR"), + "2-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn existing_http1_methods_unaffected() { + for prefix in [ + b"GET / HTTP/1.1\r\n".as_ref(), + b"POST /api HTTP/1.1\r\n".as_ref(), + b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(), + b"TRACE / HTTP/1.1\r\n".as_ref(), + b"PATCH / HTTP/1.1\r\n".as_ref(), + ] { + assert!(is_http_probe(prefix)); + } +} + +#[test] +fn non_http_data_not_classified() { + for data in [ + b"\x16\x03\x01\x00\xf1".as_ref(), + b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(), + b"\x00\x01\x02\x03".as_ref(), + b"".as_ref(), + b"P".as_ref(), + ] { + assert!(!is_http_probe(data)); + } +} + +#[test] +fn light_fuzz_non_http_prefixes_not_misclassified() { + // Deterministic pseudo-fuzz to exercise classifier edges while avoiding + // known HTTP method and partial windows. + let mut x = 0x1234_5678u32; + for _ in 0..1024 { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + let len = 4 + ((x >> 8) as usize % 12); + let mut data = vec![0u8; len]; + for byte in &mut data { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = (x & 0xFF) as u8; + } + + if [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"DELETE".as_ref(), + b"OPTIONS".as_ref(), + b"CONNECT".as_ref(), + b"TRACE".as_ref(), + b"PATCH".as_ref(), + b"PRI ".as_ref(), + ] + .iter() + .any(|m| data.starts_with(m)) + { + continue; + } + + assert!( + !is_http_probe(&data), + "non-http pseudo-fuzz input misclassified: {:?}", + &data[..data.len().min(8)] + ); + } +} diff --git a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs new file mode 100644 index 0000000..47b6dc6 --- /dev/null +++ b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs @@ -0,0 +1,79 @@ +use super::*; + +#[test] +fn exact_four_byte_http_tokens_are_classified() { + for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] { + assert!( + is_http_probe(token), + "exact 4-byte token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn exact_four_byte_non_http_tokens_are_not_classified() { + for token in [ + b"GEX ".as_ref(), + b"POXT".as_ref(), + b"HEA/".as_ref(), + b"PU\0 ".as_ref(), + b"PRI/".as_ref(), + ] { + assert!( + !is_http_probe(token), + "non-HTTP 4-byte token must not be classified: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"PRI "), "HTTP"); +} + +#[test] +fn exact_long_http_tokens_are_classified() { + for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] { + assert!( + is_http_probe(token), + "exact long HTTP token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() { + assert_eq!(detect_client_type(b"CONNECT"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH"), "HTTP"); +} + +#[test] +fn light_fuzz_four_byte_ascii_noise_not_misclassified() { + // Deterministic pseudo-fuzz over 4-byte printable ASCII inputs. + let mut x = 0xA17C_93E5u32; + for _ in 0..2048 { + let mut token = [0u8; 4]; + for byte in &mut token { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset + } + + if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "] + .iter() + .any(|m| token.as_slice() == *m) + { + continue; + } + + assert!( + !is_http_probe(&token), + "pseudo-fuzz noise misclassified as HTTP probe: {:?}", + token + ); + } +} \ No newline at end of file diff --git a/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs new file mode 100644 index 0000000..d82cf82 --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs @@ -0,0 +1,51 @@ +#![cfg(unix)] + +use super::*; + +#[test] +fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() { + let previous = vec![ + "192.168.100.7" + .parse::() + .expect("must parse interface ip"), + ]; + let refreshed = Vec::new(); + + let next = choose_interface_snapshot(&previous, refreshed); + + assert_eq!( + next, previous, + "empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions" + ); +} + +#[test] +fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() { + let previous = vec![ + "192.168.100.7" + .parse::() + .expect("must parse interface ip"), + ]; + let refreshed = vec![ + "10.55.0.3" + .parse::() + .expect("must parse refreshed interface ip"), + ]; + + let next = choose_interface_snapshot(&previous, refreshed.clone()); + + assert_eq!(next, refreshed); +} + +#[test] +fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() { + let previous = Vec::new(); + let refreshed = Vec::new(); + + let next = choose_interface_snapshot(&previous, refreshed); + + assert!( + next.is_empty(), + "empty refresh with no previous snapshot should remain empty" + ); +} diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs new file mode 100644 index 0000000..b14d7c3 --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -0,0 +1,46 @@ +#![cfg(unix)] + +use super::*; +use std::sync::{Mutex, OnceLock}; + +fn interface_cache_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[test] +fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + + let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None); + let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None); + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "interface enumeration must be cached across repeated bad-client checks" + ); +} + +#[test] +fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let is_local = is_mask_target_local_listener("127.0.0.1", 8443, local_addr, None); + + assert!(!is_local, "different port must not be treated as local listener"); + assert_eq!( + local_interface_enumerations_for_tests(), + 0, + "port mismatch should bypass interface enumeration entirely" + ); +} diff --git a/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..efa4529 --- /dev/null +++ b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs @@ -0,0 +1,178 @@ +use super::*; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[tokio::test] +#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"] +async fn redteam_offline_target_should_drop_idle_client_early() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(150)).await; + let write_res = client_write.write_all(b"probe-should-be-closed").await; + assert!( + write_res.is_err(), + "offline target path still keeps client writable before consume timeout" + ); + + handler.abort(); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"] +async fn redteam_offline_target_should_not_sleep_to_mask_refusal() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"\x16\x03\x01\x00\x05hello", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + let elapsed = started.elapsed(); + + assert!( + elapsed < Duration::from_millis(10), + "offline target path still applies coarse masking sleep and is fingerprintable" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"] +async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() { + let mut samples = Vec::new(); + + for i in 0..12u16 { + let (client_read, mut client_write) = duplex(1024); + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + samples.push(started.elapsed()); + } + + let min = samples.iter().copied().min().unwrap_or_default(); + let max = samples.iter().copied().max().unwrap_or_default(); + let spread = max.saturating_sub(min); + + assert!( + spread <= Duration::from_millis(5), + "offline refusal timing spread too wide for strict red-team envelope: {:?}", + spread + ); +} + +#[tokio::test] +#[ignore = "manual red-team: host resolver failure should complete without panic"] +async fn redteam_dns_resolution_failure_must_not_panic() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string()); + cfg.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(2), handler).await; + assert!( + result.is_ok(), + "dns failure path stalled or panicked instead of terminating" + ); +} diff --git a/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs new file mode 100644 index 0000000..b99b4bc --- /dev/null +++ b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs @@ -0,0 +1,51 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::AsyncWrite; + +struct NeverWritable; + +impl AsyncWrite for NeverWritable { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() { + let mut writer = NeverWritable; + let started = Instant::now(); + + maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await; + + assert!( + started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30), + "shape padding blocked past timeout budget" + ); +} + +#[tokio::test] +async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() { + let mut output = Vec::new(); + { + let mut writer = tokio::io::BufWriter::new(&mut output); + maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await; + use tokio::io::AsyncWriteExt; + writer.flush().await.unwrap(); + } + + assert!(output.is_empty()); +} diff --git a/src/proxy/tests/masking_relay_guardrails_security_tests.rs b/src/proxy/tests/masking_relay_guardrails_security_tests.rs new file mode 100644 index 0000000..257c0f8 --- /dev/null +++ b/src/proxy/tests/masking_relay_guardrails_security_tests.rs @@ -0,0 +1,105 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink}; +use tokio::time::{Duration, timeout}; + +#[tokio::test] +async fn relay_to_mask_enforces_masking_session_byte_cap() { + let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01]; + let extra = vec![0xAB; 96 * 1024]; + + let (client_reader, mut client_writer) = duplex(128 * 1024); + let (mask_read, _mask_read_peer) = duplex(1024); + let (mut mask_observer, mask_write) = duplex(256 * 1024); + let initial_for_task = initial.clone(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + mask_read, + mask_write, + &initial_for_task, + false, + 512, + 4096, + false, + 0, + false, + 32 * 1024, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + + let mut observed = Vec::new(); + timeout( + Duration::from_secs(2), + mask_observer.read_to_end(&mut observed), + ) + .await + .unwrap() + .unwrap(); + + // In this deterministic test, relay must stop exactly at the configured cap. + assert_eq!( + observed.len(), + initial.len() + (32 * 1024), + "masked relay must forward exactly up to the cap (observed={} initial={} cap={})", + observed.len(), + initial.len(), + 32 * 1024 + ); +} + +#[tokio::test] +async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() { + let initial = b"GET /half-close HTTP/1.1\r\n".to_vec(); + + let (client_reader, mut client_writer) = duplex(8 * 1024); + let (mask_read, _mask_read_peer) = duplex(8 * 1024); + let (mut mask_observer, mask_write) = duplex(8 * 1024); + let initial_for_task = initial.clone(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + mask_read, + mask_write, + &initial_for_task, + false, + 512, + 4096, + false, + 0, + false, + 32 * 1024, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + + let mut observed = Vec::new(); + timeout( + Duration::from_millis(80), + mask_observer.read_to_end(&mut observed), + ) + .await + .expect("mask backend write side should be half-closed promptly") + .unwrap(); + + assert_eq!(&observed[..initial.len()], initial.as_slice()); + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs new file mode 100644 index 0000000..627c48b --- /dev/null +++ b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs @@ -0,0 +1,100 @@ +use super::*; +use tokio::io::AsyncReadExt; +use tokio::time::{Duration, timeout}; + +async fn collect_padding( + total_sent: usize, + enabled: bool, + floor: usize, + cap: usize, + above_cap_blur: bool, + blur_max: usize, + aggressive: bool, +) -> Vec { + let (mut tx, mut rx) = tokio::io::duplex(256 * 1024); + + maybe_write_shape_padding( + &mut tx, + total_sent, + enabled, + floor, + cap, + above_cap_blur, + blur_max, + aggressive, + ) + .await; + + drop(tx); + + let mut output = Vec::new(); + timeout(Duration::from_secs(1), rx.read_to_end(&mut output)) + .await + .expect("reading padded output timed out") + .expect("failed reading padded output"); + output +} + +#[tokio::test] +async fn padding_output_is_not_all_zero() { + let output = collect_padding(1, true, 256, 4096, false, 0, false).await; + + assert!( + output.len() >= 255, + "expected at least 255 padding bytes, got {}", + output.len() + ); + + let nonzero = output.iter().filter(|&&b| b != 0).count(); + // In 255 bytes of uniform randomness, the expected number of zero bytes is ~1. + // A weak nonzero check can miss severe entropy collapse. + assert!( + nonzero >= 240, + "RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}", + nonzero, + output.len(), + ); +} + +#[tokio::test] +async fn padding_reaches_first_bucket_boundary() { + let output = collect_padding(1, true, 64, 4096, false, 0, false).await; + assert_eq!(output.len(), 63); +} + +#[tokio::test] +async fn disabled_padding_produces_no_output() { + let output = collect_padding(0, false, 256, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn at_cap_without_blur_produces_no_output() { + let output = collect_padding(4096, true, 64, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() { + let output = collect_padding(4096, true, 64, 4096, true, 128, true).await; + assert!(!output.is_empty()); + assert!(output.len() <= 128, "blur exceeded max: {}", output.len()); +} + +#[tokio::test] +async fn stress_padding_runs_are_not_constant_pattern() { + // Stress and sanity-check: repeated runs should not collapse to identical + // first 16 bytes across all samples. + let mut first_chunks = Vec::new(); + for _ in 0..64 { + let out = collect_padding(1, true, 64, 4096, false, 0, false).await; + first_chunks.push(out[..16].to_vec()); + } + + let first = &first_chunks[0]; + let all_same = first_chunks.iter().all(|chunk| chunk == first); + assert!( + !all_same, + "all stress samples had identical prefix, rng output appears degenerate" + ); +} diff --git a/src/proxy/tests/masking_security_tests.rs b/src/proxy/tests/masking_security_tests.rs index 4519d85..c698b55 100644 --- a/src/proxy/tests/masking_security_tests.rs +++ b/src/proxy/tests/masking_security_tests.rs @@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall false, 0, false, + 5 * 1024 * 1024, ) .await; }); @@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { false, 0, false, + 5 * 1024 * 1024, ), ) .await; diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs new file mode 100644 index 0000000..b92ce3d --- /dev/null +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -0,0 +1,354 @@ +use super::*; +use std::net::TcpListener as StdTcpListener; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant, timeout}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[test] +fn self_target_detection_matches_literal_ipv4_listener() { + let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); + assert!(is_mask_target_local_listener( + "198.51.100.40", + 443, + local, + None, + )); +} + +#[test] +fn self_target_detection_matches_bracketed_ipv6_listener() { + let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); + assert!(is_mask_target_local_listener( + "[2001:db8::44]", + 8443, + local, + None, + )); +} + +#[test] +fn self_target_detection_keeps_same_ip_different_port_forwardable() { + let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener( + "203.0.113.44", + 8443, + local, + None, + )); +} + +#[test] +fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + assert!(is_mask_target_local_listener( + "::ffff:127.0.0.1", + 443, + local, + None, + )); +} + +#[test] +fn self_target_detection_unspecified_bind_blocks_loopback_target() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + assert!(is_mask_target_local_listener( + "127.0.0.1", + 443, + local, + None, + )); +} + +#[test] +fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener( + "mask.example", + 443, + local, + Some(remote), + )); +} + +#[tokio::test] +async fn self_target_fallback_refuses_recursive_loopback_connect() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(local_addr.ip().to_string()); + config.censorship.mask_port = local_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + b"GET /", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "self-target masking must fail closed without connecting to local listener" + ); +} + +#[tokio::test] +async fn same_ip_different_port_still_forwards_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /".to_vec(); + let accept_task = tokio::spawn({ + let expected = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn detect_client_type_http_boundary_get_and_post() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"GET /"), "HTTP"); + + assert_eq!(detect_client_type(b"POST"), "HTTP"); + assert_eq!(detect_client_type(b"POST "), "HTTP"); + assert_eq!(detect_client_type(b"POSTX"), "HTTP"); +} + +#[test] +fn detect_client_type_tls_and_length_boundaries() { + assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner"); + assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner"); + + assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + +#[test] +fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() { + let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap(); + let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); + let header = build_mask_proxy_header(1, peer, local).unwrap(); + assert_eq!(header, b"PROXY UNKNOWN\r\n"); +} + +#[test] +fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() { + let floor = usize::MAX / 2 + 1; + let cap = usize::MAX; + let total = floor + 1; + assert_eq!(next_mask_shape_bucket(total, floor, cap), total); +} + +#[tokio::test] +async fn self_target_reject_path_keeps_timing_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (client, server) = duplex(1024); + drop(client); + + let started = Instant::now(); + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250), + "self-target reject path must keep coarse timing budget without stalling" + ); +} + +#[tokio::test] +async fn relay_path_idle_timeout_eviction_remains_effective() { + let (client_read, mut client_write) = duplex(1024); + let (mask_read, mask_write) = duplex(1024); + + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + client_write.write_all(b"a").await.unwrap(); + tokio::time::sleep(Duration::from_millis(180)).await; + let _ = client_write.write_all(b"b").await; + }); + + let started = Instant::now(); + relay_to_mask( + client_read, + tokio::io::sink(), + mask_read, + mask_write, + b"init", + false, + 0, + 0, + false, + 0, + false, + 5 * 1024 * 1024, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180), + "idle-timeout eviction must occur before late trickle write" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_respects_timing_normalization_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client.shutdown().await.unwrap(); + timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220), + "offline-refusal path must honor normalization budget without unbounded drift" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = false; + + let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(120)).await; + client + .write_all(b"still-open-before-timeout") + .await + .expect("connection should still be open before consume timeout expires"); + + timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350), + "offline-refusal path must not retain idle client indefinitely" + ); +} diff --git a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs index 982fd26..4fa8da7 100644 --- a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs @@ -43,6 +43,7 @@ async fn run_relay_case( above_cap_blur, above_cap_blur_max_bytes, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs index 3c886ba..9abf3c0 100644 --- a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs @@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() { false, 0, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs index 2c9f3f6..6f0e91a 100644 --- a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs +++ b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs @@ -9,6 +9,7 @@ use tokio::time::{Duration, timeout}; #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() { let _guard = super::quota_user_lock_test_scope(); + let _pressure_guard = super::relay_idle_pressure_test_scope(); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); map.clear(); diff --git a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs new file mode 100644 index 0000000..3d7929b --- /dev/null +++ b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs @@ -0,0 +1,229 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::task::{Context, Poll, Waker}; +use tokio::io::AsyncWrite; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[derive(Default)] +struct GateState { + open: AtomicBool, + parked_waker: std::sync::Mutex>, +} + +impl GateState { + fn open(&self) { + self.open.store(true, Ordering::Relaxed); + if let Ok(mut guard) = self.parked_waker.lock() + && let Some(w) = guard.take() + { + w.wake(); + } + } + + fn has_waiter(&self) -> bool { + self.parked_waker + .lock() + .map(|guard| guard.is_some()) + .unwrap_or(false) + } +} + +#[derive(Default)] +struct GateWriter { + gate: Arc, +} + +impl GateWriter { + fn new(gate: Arc) -> Self { + Self { gate } + } +} + +impl AsyncWrite for GateWriter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.gate.open.load(Ordering::Relaxed) { + return Poll::Ready(Ok(buf.len())); + } + + if let Ok(mut guard) = self.gate.parked_waker.lock() { + *guard = Some(cx.waker().clone()); + } + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct FailingWriter; + +impl AsyncWrite for FailingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "injected writer failure", + ))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let rng = SecureRandom::new(); + let quota_limit = Some(1024); + let user = "hol-quota-user"; + + let gate = Arc::new(GateState::default()); + + let mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate))); + let slow_task = tokio::spawn(async move { + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]), + }, + &mut blocked_writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + user, + quota_limit, + &bytes_me2c, + 7001, + false, + false, + ) + .await + }); + + timeout(Duration::from_millis(100), async { + loop { + if gate.has_waiter() { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("first writer must reach backpressure and park"); + + let stats_fast = Stats::new(); + let bytes_fast = AtomicU64::new(0); + let rng_fast = SecureRandom::new(); + let mut fast_writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_fast = Vec::new(); + + timeout( + Duration::from_millis(50), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut fast_writer, + ProtoTag::Intermediate, + &rng_fast, + &mut frame_buf_fast, + &stats_fast, + user, + quota_limit, + &bytes_fast, + 7002, + false, + false, + ), + ) + .await + .expect("peer connection must not be blocked by same-user stalled write") + .expect("fast peer write must succeed"); + + gate.open(); + let slow_result = timeout(Duration::from_secs(1), slow_task) + .await + .expect("stalled task must complete once gate opens") + .expect("stalled task must not panic"); + assert!(slow_result.is_ok()); +} + +#[tokio::test] +async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() { + let stats = Stats::new(); + let user = "rollback-user"; + stats.add_user_octets_from(user, 7); + + let bytes_me2c = AtomicU64::new(0); + let rng = SecureRandom::new(); + let mut writer = make_crypto_writer(FailingWriter); + let mut frame_buf = Vec::new(); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + user, + Some(64), + &bytes_me2c, + 7003, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Io(_)))); + assert_eq!( + stats.get_user_total_octets(user), + 7, + "failed client write must not overcharge user quota accounting" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + 0, + "failed client write must not inflate ME->C forensic byte counter" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs index 3e0b30f..6ea182b 100644 --- a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -3,7 +3,7 @@ use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; use std::sync::atomic::AtomicU64; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::Arc; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; @@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl } } -fn idle_pressure_test_lock() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> { - match idle_pressure_test_lock().lock() { - Ok(guard) => guard, - Err(poisoned) => poisoned.into_inner(), - } -} - #[tokio::test] async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() { let (reader, _writer) = duplex(1024); @@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() { #[test] fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { #[test] fn pressure_does_not_evict_without_new_pressure_signal() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() { #[test] fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { #[test] fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { #[test] fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { #[test] fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { #[test] fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { #[test] fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated( #[test] fn blackhat_stale_pressure_must_not_survive_candidate_churn() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() { #[test] fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting( #[test] fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); @@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); diff --git a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs new file mode 100644 index 0000000..112d926 --- /dev/null +++ b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs @@ -0,0 +1,59 @@ +use super::*; +use std::panic::{AssertUnwindSafe, catch_unwind}; + +#[test] +fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let mut guard = registry + .lock() + .expect("registry lock must be acquired before poison"); + guard.by_conn_id.insert( + 999, + RelayIdleCandidateMeta { + mark_order_seq: 1, + mark_pressure_seq: 0, + }, + ); + guard.ordered.insert((1, 999)); + panic!("intentional poison for idle-registry recovery"); + })); + + // Helper lock must recover from poison, reset stale state, and continue. + assert!(mark_relay_idle_candidate(42)); + assert_eq!(oldest_relay_idle_candidate(), Some(42)); + + let before = relay_pressure_event_seq(); + note_relay_pressure_event(); + let after = relay_pressure_event_seq(); + assert!(after > before, "pressure accounting must still advance after poison"); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let _guard = registry + .lock() + .expect("registry lock must be acquired before poison"); + panic!("intentional poison while lock held"); + })); + + clear_relay_idle_pressure_state_for_testing(); + + assert_eq!(oldest_relay_idle_candidate(), None); + assert_eq!(relay_pressure_event_seq(), 0); + + assert!(mark_relay_idle_candidate(7)); + assert_eq!(oldest_relay_idle_candidate(), Some(7)); + + clear_relay_idle_pressure_state_for_testing(); +} diff --git a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs new file mode 100644 index 0000000..717a375 --- /dev/null +++ b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs @@ -0,0 +1,192 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use tokio::task::JoinSet; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { + let stats = Stats::new(); + let user = "quota-boundary-user"; + let bytes_me2c = AtomicU64::new(0); + + stats.add_user_octets_from(user, 5); + + let mut writer_one = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_one = Vec::new(); + let first = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer_one, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_one, + &stats, + user, + Some(8), + &bytes_me2c, + 7101, + false, + false, + ) + .await; + + assert!(first.is_ok(), "frame that reaches boundary must be allowed"); + assert_eq!(stats.get_user_total_octets(user), 8); + + let mut writer_two = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_two = Vec::new(); + let second = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[9]), + }, + &mut writer_two, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(8), + &bytes_me2c, + 7102, + false, + false, + ) + .await; + + assert!( + matches!(second, Err(ProxyError::DataQuotaExceeded { .. })), + "frame after boundary must be rejected" + ); + assert_eq!(stats.get_user_total_octets(user), 8); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_counters() { + let stats = Arc::new(Stats::new()); + let user = "reservation-stress-user"; + let quota_limit = 64u64; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAB]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(quota_limit), + bytes_ref.as_ref(), + 7200 + idx, + false, + false, + ) + .await + }); + } + + let mut ok = 0usize; + let mut denied = 0usize; + while let Some(joined) = tasks.join_next().await { + match joined.expect("reservation stress task must not panic") { + Ok(_) => ok += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => denied += 1, + Err(other) => panic!("unexpected error in stress case: {other:?}"), + } + } + + let total = stats.get_user_total_octets(user); + assert_eq!( + total, quota_limit, + "quota must be exactly exhausted without overshoot" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + total, + "ME->C forensic bytes must track committed quota usage" + ); + assert_eq!(ok, quota_limit as usize, "exactly quota_limit tasks must succeed"); + assert_eq!( + denied, + 256usize - (quota_limit as usize), + "remaining tasks must be exactly denied without silently swallowing state" + ); +} + +#[tokio::test] +async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() { + let stats = Stats::new(); + let user = "reservation-fuzz-user"; + let quota_limit = 128u64; + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xC0FE_EE11_8899_2211u64; + + for conn in 0..512u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let len = ((seed & 0x0f) + 1) as usize; + let payload = vec![0x5A; len]; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + &bytes_me2c, + 7300 + conn, + false, + false, + ) + .await; + + if let Err(err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. }), + "fuzz run produced unexpected error variant: {err:?}" + ); + } + } + + let total = stats.get_user_total_octets(user); + assert!(total <= quota_limit); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs new file mode 100644 index 0000000..1bf3123 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs @@ -0,0 +1,365 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use tokio::task::JoinSet; +use tokio::time::{Duration as TokioDuration, sleep, timeout}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB200_0000 + conn_id, + conn_id, + user: format!("tiny-frame-debt-concurrency-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_secs(30), + hard_idle: Duration::from_secs(60), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_secs(30), + } +} + +async fn read_once( + crypto_reader: &mut CryptoReader, + proto: ProtoTag, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + idle_state: &mut RelayClientIdleState, +) -> Result> { + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + read_client_payload_with_idle_policy( + crypto_reader, + proto, + 1024, + &buffer_pool, + forensics, + frame_counter, + &stats, + &idle_policy, + idle_state, + &last_downstream_activity_ms, + forensics.started_at, + ) + .await +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_pure_tiny_floods_all_fail_closed() { + let mut set = JoinSet::new(); + + for idx in 0..32u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(1000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = timeout( + TokioDuration::from_secs(1), + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("tiny flood task must complete"); + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert_eq!(frame_counter, 0); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("parallel tiny flood worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_benign_tiny_burst_then_real_all_pass() { + let mut set = JoinSet::new(); + + for idx in 0..24u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(2048); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(2000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [idx as u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(20); + for _ in 0..6 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = timeout( + TokioDuration::from_secs(1), + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("benign task must complete") + .expect("benign payload must parse") + .expect("benign payload must return frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("parallel benign worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { + let mut set = JoinSet::new(); + + for idx in 0..12u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(3000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(2000); + for n in 0..180u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + for chunk in encrypted.chunks(17) { + writer.write_all(chunk).await.unwrap(); + sleep(TokioDuration::from_millis(1)).await; + } + drop(writer); + }); + + let mut closed = false; + for _ in 0..220 { + let result = timeout( + TokioDuration::from_secs(1), + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("alternating reader step must complete"); + + match result { + Ok(Some((_payload, _))) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error in alternating jitter case: {other}"), + } + } + + writer_task.await.expect("writer jitter task must not panic"); + assert!(closed, "alternating attack must close before EOF"); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("alternating jitter worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_mixed_population_attackers_close_benign_survive() { + let mut set = JoinSet::new(); + + for idx in 0..20u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(4000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + if idx % 2 == 0 { + let mut plaintext = Vec::with_capacity(1280); + for n in 0..140u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n, n, n]); + } + writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..200 { + match read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected attacker error: {other}"), + } + } + assert!(closed, "attacker session must fail closed"); + } else { + let payload = [1u8, 9, 8, 7]; + let mut plaintext = Vec::new(); + for _ in 0..4 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + + let got = read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("benign session must parse") + .expect("benign session must return a frame"); + assert_eq!(got.0.as_ref(), &payload); + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("mixed-population worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_parallel_patterns_no_hang_or_panic() { + let mut set = JoinSet::new(); + + for case in 0..40u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(5000 + case, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut seed = 0x9E37_79B9u64 ^ (case << 8); + let mut plaintext = Vec::with_capacity(2048); + for _ in 0..256 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let is_tiny = (seed & 1) == 0; + if is_tiny { + plaintext.push(0x00); + } else { + plaintext.push(0x01); + plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]); + } + } + + writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + drop(writer); + + for _ in 0..320 { + let step = timeout( + TokioDuration::from_secs(1), + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("fuzz case read step must complete"); + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => break, + Ok(None) => break, + Err(other) => panic!("unexpected fuzz case error: {other}"), + } + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("fuzz worker must not panic"); + } +} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs new file mode 100644 index 0000000..0ff46a2 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs @@ -0,0 +1,418 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, PooledBuffer}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use tokio::time::{Duration as TokioDuration, sleep, timeout}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB300_0000 + conn_id, + conn_id, + user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_secs(30), + hard_idle: Duration::from_secs(60), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_secs(30), + } +} + +fn append_tiny_frame(plaintext: &mut Vec, proto: ProtoTag) { + match proto { + ProtoTag::Abridged => plaintext.push(0x00), + ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()), + } +} + +fn append_real_frame(plaintext: &mut Vec, proto: ProtoTag, payload: [u8; 4]) { + match proto { + ProtoTag::Abridged => { + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + } + ProtoTag::Intermediate | ProtoTag::Secure => { + plaintext.extend_from_slice(&4u32.to_le_bytes()); + plaintext.extend_from_slice(&payload); + } + } +} + +async fn write_chunked_with_jitter( + writer: &mut tokio::io::DuplexStream, + bytes: &[u8], + mut seed: u64, +) { + let mut offset = 0usize; + while offset < bytes.len() { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let chunk_len = 1 + ((seed as usize) & 0x1f); + let end = (offset + chunk_len).min(bytes.len()); + writer.write_all(&bytes[offset..end]).await.unwrap(); + + let delay_ms = ((seed >> 16) % 3) as u64; + if delay_ms > 0 { + sleep(TokioDuration::from_millis(delay_ms)).await; + } + offset = end; + } +} + +async fn read_once_with_state( + crypto_reader: &mut CryptoReader, + proto: ProtoTag, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + idle_state: &mut RelayClientIdleState, +) -> Result> { + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + read_client_payload_with_idle_policy( + crypto_reader, + proto, + 1024, + &buffer_pool, + forensics, + frame_counter, + &stats, + &idle_policy, + idle_state, + &last_downstream_activity_ms, + forensics.started_at, + ) + .await +} + +#[tokio::test] +async fn intermediate_chunked_zero_flood_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6101, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + } + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await; + drop(writer); + + let result = timeout( + TokioDuration::from_secs(2), + read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("intermediate flood read must complete"); + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert_eq!(frame_counter, 0); +} + +#[tokio::test] +async fn secure_chunked_zero_flood_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6102, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + } + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await; + drop(writer); + + let result = timeout( + TokioDuration::from_secs(2), + read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("secure flood read must complete"); + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert_eq!(frame_counter, 0); +} + +#[tokio::test] +async fn intermediate_chunked_alternating_attack_closes_before_eof() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6103, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = timeout( + TokioDuration::from_secs(1), + read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("intermediate alternating read step must complete"); + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected intermediate alternating error: {other}"), + } + } + + writer_task.await.expect("intermediate writer task must not panic"); + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn secure_chunked_alternating_attack_closes_before_eof() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6104, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = timeout( + TokioDuration::from_secs(1), + read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("secure alternating read step must complete"); + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected secure alternating error: {other}"), + } + } + + writer_task.await.expect("secure writer task must not panic"); + assert!(closed, "secure alternating attack must fail closed"); +} + +#[tokio::test] +async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6105, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [9u8, 8, 7, 6]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + } + append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("intermediate safe burst should parse") + .expect("intermediate safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn secure_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6106, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [3u8, 1, 4, 1]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + } + append_real_frame(&mut plaintext, ProtoTag::Secure, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("secure safe burst should parse") + .expect("secure safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn light_fuzz_proto_chunking_outcomes_are_bounded() { + let mut seed = 0xDEAD_BEEF_2026_0322u64; + + for case in 0..48u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let proto = if (seed & 1) == 0 { + ProtoTag::Intermediate + } else { + ProtoTag::Secure + }; + + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6200 + case, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut stream = Vec::new(); + let mut local_seed = seed ^ case; + for _ in 0..220 { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + if (local_seed & 1) == 0 { + append_tiny_frame(&mut stream, proto); + } else { + let b = (local_seed >> 8) as u8; + append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]); + } + } + + let encrypted = encrypt_for_reader(&stream); + write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await; + drop(writer); + + for _ in 0..260 { + let step = timeout( + TokioDuration::from_secs(1), + read_once_with_state( + &mut crypto_reader, + proto, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("fuzz proto read step must complete"); + + match step { + Ok(Some((_payload, _))) => {} + Err(ProxyError::Proxy(_)) => break, + Ok(None) => break, + Err(other) => panic!("unexpected proto chunking fuzz error: {other}"), + } + } + } +} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs new file mode 100644 index 0000000..d0719c8 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs @@ -0,0 +1,550 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB100_0000 + conn_id, + conn_id, + user: format!("tiny-frame-debt-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_secs(30), + hard_idle: Duration::from_secs(60), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_secs(30), + } +} + +fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option, u32, usize) { + let mut debt = 0u32; + let mut reals = 0usize; + for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() { + if is_tiny { + debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY); + if debt >= TINY_FRAME_DEBT_LIMIT { + return (Some(idx + 1), debt, reals); + } + } else { + reals = reals.saturating_add(1); + debt = debt.saturating_sub(1); + } + } + (None, debt, reals) +} + +#[test] +fn tiny_frame_debt_constants_match_security_budget_expectations() { + assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8); + assert_eq!(TINY_FRAME_DEBT_LIMIT, 512); +} + +#[test] +fn relay_client_idle_state_initial_debt_is_zero() { + let state = RelayClientIdleState::new(Instant::now()); + assert_eq!(state.tiny_frame_debt, 0); +} + +#[test] +fn on_client_frame_does_not_reset_tiny_frame_debt() { + let now = Instant::now(); + let mut state = RelayClientIdleState::new(now); + state.tiny_frame_debt = 77; + state.on_client_frame(now); + assert_eq!(state.tiny_frame_debt, 77); +} + +#[test] +fn tiny_frame_debt_increment_is_saturating() { + let mut debt = u32::MAX - 1; + debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY); + assert_eq!(debt, u32::MAX); +} + +#[test] +fn tiny_frame_debt_decrement_is_saturating() { + let mut debt = 0u32; + debt = debt.saturating_sub(1); + assert_eq!(debt, 0); +} + +#[test] +fn consecutive_tiny_frames_close_exactly_at_threshold() { + let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize; + let pattern = vec![true; max_tiny_without_close]; + let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, Some(max_tiny_without_close)); +} + +#[test] +fn one_less_than_threshold_tiny_frames_do_not_close() { + let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1; + let pattern = vec![true; tiny_count]; + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert!(debt < TINY_FRAME_DEBT_LIMIT); +} + +#[test] +fn alternating_one_to_one_closes_with_bounded_real_frame_count() { + let mut pattern = Vec::with_capacity(512); + for _ in 0..256 { + pattern.push(true); + pattern.push(false); + } + let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(closed_at.is_some()); + assert!(reals <= 80, "expected bounded real frames before close, got {reals}"); +} + +#[test] +fn alternating_one_to_eight_is_stable_for_long_runs() { + let mut pattern = Vec::with_capacity(9 * 5000); + for _ in 0..5000 { + pattern.push(true); + for _ in 0..8 { + pattern.push(false); + } + } + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert!(debt <= TINY_FRAME_DEBT_PER_TINY); +} + +#[test] +fn alternating_one_to_seven_eventually_closes() { + let mut pattern = Vec::with_capacity(8 * 2000); + for _ in 0..2000 { + pattern.push(true); + for _ in 0..7 { + pattern.push(false); + } + } + let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close"); +} + +#[test] +fn two_tiny_one_real_closes_faster_than_one_to_one() { + let mut one_to_one = Vec::with_capacity(512); + for _ in 0..256 { + one_to_one.push(true); + one_to_one.push(false); + } + + let mut two_to_one = Vec::with_capacity(768); + for _ in 0..256 { + two_to_one.push(true); + two_to_one.push(true); + two_to_one.push(false); + } + + let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len()); + let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len()); + assert!(a_close.is_some() && b_close.is_some()); + assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0)); +} + +#[test] +fn burst_then_drain_can_recover_without_close() { + let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize; + let mut pattern = Vec::with_capacity(burst_tiny + 600); + for _ in 0..burst_tiny { + pattern.push(true); + } + pattern.extend(std::iter::repeat_n(false, 600)); + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert_eq!(debt, 0); +} + +#[test] +fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() { + let mut seed = 0xA5A5_91C3_2026_0322u64; + for _case in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = 512 + ((seed as usize) & 0x3ff); + let mut pattern = Vec::with_capacity(len); + let mut local_seed = seed; + for _ in 0..len { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + pattern.push((local_seed & 1) == 0); + } + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + if closed_at.is_none() { + assert!(debt < TINY_FRAME_DEBT_LIMIT); + } + assert!(debt <= u32::MAX); + } +} + +#[test] +fn stress_many_independent_simulations_keep_isolated_debt_state() { + for idx in 0..2048usize { + let mut pattern = Vec::with_capacity(64); + for j in 0..64usize { + pattern.push(((idx ^ j) & 3) == 0); + } + let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY)); + } +} + +#[tokio::test] +async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(11, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(12, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn intermediate_alternating_zero_and_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(13, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(3000); + for idx in 0..160u8 { + plaintext.extend_from_slice(&0u32.to_le_bytes()); + plaintext.extend_from_slice(&4u32.to_le_bytes()); + plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]); + } + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..220 { + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error while probing alternating close: {other}"), + } + } + + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(14, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(64); + for _ in 0..8 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&[1, 2, 3, 4]); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let first = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match first { + Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]), + Err(e) => panic!("unexpected close after small tiny burst: {e}"), + Ok(None) => panic!("unexpected EOF before real frame"), + } +} + +#[tokio::test] +async fn idle_policy_enabled_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "idle policy enabled must fail closed for pure zero-length flood" + ); +} + +#[tokio::test] +async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(256 * 6); + for idx in 0..=255u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]); + } + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("alternating flood bytes must be writable"); + drop(writer); + + let mut saw_proxy_close = false; + for _ in 0..300 { + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some((_payload, _quickack))) => {} + Err(ProxyError::Proxy(_)) => { + saw_proxy_close = true; + break; + } + Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"), + Ok(None) => panic!("unexpected EOF before debt-based closure"), + Err(other) => panic!("unexpected error before close: {other}"), + } + } + + assert!( + saw_proxy_close, + "alternating tiny/real sequence must eventually fail closed" + ); +} + +#[tokio::test] +async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(3, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let payload = [7u8, 8, 9, 10]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero frame must be writable"); + + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + .expect("valid frame should decode") + .expect("valid frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1); + assert_eq!(frame_counter, 1); +} diff --git a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs new file mode 100644 index 0000000..765c253 --- /dev/null +++ b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs @@ -0,0 +1,121 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use std::time::Instant; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB000_0000 + conn_id, + conn_id, + user: format!("zero-len-test-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +#[tokio::test] +async fn adversarial_legacy_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + + let flood_plaintext = vec![0u8; 128]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + match result { + Err(ProxyError::Proxy(msg)) => { + assert!( + msg.contains("Excessive zero-length"), + "legacy mode must close flood with explicit zero-length reason, got: {msg}" + ); + } + Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"), + Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"), + Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"), + } +} + +#[tokio::test] +async fn business_abridged_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + + let payload = [1u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero abridged frame must be writable"); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("valid abridged frame should decode") + .expect("valid abridged frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1, "quickack flag must remain false"); + assert_eq!(frame_counter, 1); +} diff --git a/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs b/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs new file mode 100644 index 0000000..fb0cf93 --- /dev/null +++ b/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs @@ -0,0 +1,108 @@ +use super::*; +use std::sync::Arc; +use std::sync::{Mutex, OnceLock}; + +fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK + .get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn same_user_returns_same_lock_identity() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let a = cross_mode_quota_user_lock("cross-mode-same-user"); + let b = cross_mode_quota_user_lock("cross-mode-same-user"); + + assert!( + Arc::ptr_eq(&a, &b), + "same user must reuse a stable lock identity" + ); +} + +#[test] +fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let prefix = format!("cross-mode-saturated-{}", std::process::id()); + let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX); + for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX { + retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "lock cache must be saturated for overflow check" + ); + + let overflow_user = format!("cross-mode-overflow-{}", std::process::id()); + let overflow_a = cross_mode_quota_user_lock(&overflow_user); + let overflow_b = cross_mode_quota_user_lock(&overflow_user); + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "overflow path must not grow bounded lock cache" + ); + assert!( + locks.get(&overflow_user).is_none(), + "overflow user must stay on striped fallback while cache is saturated" + ); + assert!( + Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user must receive a stable striped lock across repeated lookups" + ); + + drop(retained); +} + +#[test] +fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let prefix = format!("cross-mode-reclaim-{}", std::process::id()); + let protected_user = format!("{prefix}-protected"); + + let protected_lock = cross_mode_quota_user_lock(&protected_user); + let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)); + for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) { + retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "fixture must saturate lock cache before reclaim path is exercised" + ); + + drop(retained); + + let newcomer_user = format!("{prefix}-newcomer"); + let _newcomer = cross_mode_quota_user_lock(&newcomer_user); + + assert!( + locks.get(&protected_user).is_some(), + "active protected user must remain cache-resident after reclaim" + ); + let locked = locks + .get(&protected_user) + .expect("protected user must remain in map after reclaim"); + assert!( + Arc::ptr_eq(locked.value(), &protected_lock), + "reclaim must not swap active user lock identity" + ); + assert!( + locks.get(&newcomer_user).is_some(), + "newcomer should become cacheable after stale entries are reclaimed" + ); +} diff --git a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs new file mode 100644 index 0000000..87944ba --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs @@ -0,0 +1,225 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_cross_mode_uncontended_writer_progresses() { + let _guard = quota_test_guard(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + "cross-mode-tdd-uncontended".to_string(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let result = io.write_all(&[0x11, 0x22]).await; + assert!(result.is_ok(), "uncontended writer must progress"); +} + +#[tokio::test] +async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-held-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before polling writer"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); + assert!(poll.is_pending(), "writer must not bypass held cross-mode lock"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_parallel_waiters_resume_after_cross_mode_release() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-resume-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before launching waiters"); + + let stats = Arc::new(Stats::new()); + let mut waiters = Vec::new(); + for _ in 0..16 { + let stats = Arc::clone(&stats); + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + stats, + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x7F]).await + })); + } + + tokio::time::sleep(Duration::from_millis(5)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let result = waiter.await.expect("waiter task must not panic"); + assert!(result.is_ok(), "waiter must complete after cross-mode release"); + } + }) + .await + .expect("all waiters must complete in bounded time"); +} + +#[tokio::test] +async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-wakes-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before polling"); + + let stats = Arc::new(Stats::new()); + let mut ios = Vec::new(); + let mut counters = Vec::new(); + for _ in 0..20 { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let poll = Pin::new(io).poll_write(&mut cx, &[0x33]); + assert!(poll.is_pending()); + counters.push(wake_counter); + } + + tokio::time::sleep(Duration::from_millis(25)).await; + let total_wakes: usize = counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= 20 * 4, + "cross-mode contention should not create wake storms; wakes={total_wakes}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() { + let _guard = quota_test_guard(); + + let mut seed = 0xC0DE_BAAD_2026_0322u64; + for round in 0..16u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let sleep_ms = 2 + (seed as u64 % 8); + let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock in fuzz round"); + + let stats = Arc::new(Stats::new()); + let user_reader = user.clone(); + let reader_task = tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user_reader, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut one = [0u8; 1]; + io.read(&mut one).await + }); + + let user_writer = user.clone(); + let writer_task = tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user_writer, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x44]).await + }); + + tokio::time::sleep(Duration::from_millis(sleep_ms)).await; + drop(held_guard); + + let read_done = timeout(Duration::from_millis(350), reader_task) + .await + .expect("reader task must complete after release") + .expect("reader task must not panic"); + assert!(read_done.is_ok()); + + let write_done = timeout(Duration::from_millis(350), writer_task) + .await + .expect("writer task must complete after release") + .expect("writer task must not panic"); + assert!(write_done.is_ok()); + } +} diff --git a/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs b/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs new file mode 100644 index 0000000..5ea806a --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs @@ -0,0 +1,81 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::Waker; +use std::task::{Context, Poll}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +#[tokio::test] +async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() { + let _guard = quota_user_lock_test_scope(); + + let user = "cross-mode-lock-shared-user"; + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user); + let _held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock before relay poll"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(crate::stats::Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]); + + assert!( + matches!(poll, Poll::Pending), + "relay writer must not bypass cross-mode lock held by middle-relay path" + ); +} + +#[tokio::test] +async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() { + let _guard = quota_user_lock_test_scope(); + + let user = "cross-mode-lock-progress-user"; + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(crate::stats::Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]); + + assert!( + matches!(poll, Poll::Ready(Ok(2))), + "relay writer should progress when shared cross-mode lock is uncontended" + ); +} diff --git a/src/proxy/tests/relay_quota_lock_identity_security_tests.rs b/src/proxy/tests/relay_quota_lock_identity_security_tests.rs new file mode 100644 index 0000000..f717f54 --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_identity_security_tests.rs @@ -0,0 +1,135 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::Waker; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + // Context stores a reference; leak one Waker for deterministic test scope. + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +#[tokio::test] +async fn adversarial_map_churn_cannot_bypass_held_writer_lock() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let user = "quota-identity-writer-user"; + let held_lock = quota_user_lock(user); + let _held_guard = held_lock + .try_lock() + .expect("test must hold initial user lock before StatsIo poll"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + map.clear(); + let churned_lock = quota_user_lock(user); + assert!( + !Arc::ptr_eq(&held_lock, &churned_lock), + "precondition: map churn should produce a distinct lock identity" + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]); + + assert!( + matches!(poll, Poll::Pending), + "writer must remain pending on the originally-held lock identity" + ); +} + +#[tokio::test] +async fn adversarial_map_churn_cannot_bypass_held_reader_lock() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let user = "quota-identity-reader-user"; + let held_lock = quota_user_lock(user); + let _held_guard = held_lock + .try_lock() + .expect("test must hold initial user lock before StatsIo poll"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + map.clear(); + let churned_lock = quota_user_lock(user); + assert!( + !Arc::ptr_eq(&held_lock, &churned_lock), + "precondition: map churn should produce a distinct lock identity" + ); + + let (_wake_counter, mut cx) = build_context(); + let mut storage = [0u8; 8]; + let mut read_buf = ReadBuf::new(&mut storage); + let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf); + + assert!( + matches!(poll, Poll::Pending), + "reader must remain pending on the originally-held lock identity" + ); +} + +#[tokio::test] +async fn business_no_lock_contention_keeps_writer_progress() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let user = "quota-identity-progress-user"; + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]); + + assert!( + matches!(poll, Poll::Ready(Ok(2))), + "writer should progress immediately without contention" + ); +} diff --git a/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs new file mode 100644 index 0000000..7083eb2 --- /dev/null +++ b/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs @@ -0,0 +1,241 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::ReadBuf; +use tokio::time::{Duration, Instant}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn saturate_quota_user_locks() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}"))); + } + retained +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_contention_wake_rate_decays_with_backoff_curve() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = format!("quota-backoff-bench-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before benchmark run"); + + let waiters = 64usize; + let mut ios = Vec::with_capacity(waiters); + let mut wake_counters = Vec::with_capacity(waiters); + + for _ in 0..waiters { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let pending = Pin::new(io).poll_write(&mut cx, &[0x71]); + assert!(pending.is_pending()); + wake_counters.push(counter); + } + + let mut observed = vec![0usize; waiters]; + let start = Instant::now(); + let mut wakes_at_40ms = 0usize; + let mut wakes_at_160ms = 0usize; + + while start.elapsed() < Duration::from_millis(200) { + for (idx, counter) in wake_counters.iter().enumerate() { + let wakes = counter.wakes.load(Ordering::Relaxed); + if wakes > observed[idx] { + observed[idx] = wakes; + let waker = Waker::from(Arc::clone(counter)); + let mut cx = Context::from_waker(&waker); + let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]); + assert!(pending.is_pending()); + } + } + + let elapsed = start.elapsed(); + if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 { + wakes_at_40ms = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + } + if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 { + wakes_at_160ms = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + let wakes_at_200ms = total_wakes; + let early_window_wakes = wakes_at_40ms; + let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms); + + assert!( + total_wakes <= waiters * 28, + "backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}" + ); + + assert!( + early_window_wakes > 0, + "benchmark failed to observe early contention wakes" + ); + + assert!( + late_window_wakes * 4 <= early_window_wakes * 3, + "wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}" + ); + + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_read_contention_wake_rate_decays_with_backoff_curve() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = format!("quota-backoff-read-bench-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before read benchmark run"); + + let waiters = 64usize; + let mut ios = Vec::with_capacity(waiters); + let mut wake_counters = Vec::with_capacity(waiters); + + for _ in 0..waiters { + ios.push(StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + let pending = Pin::new(io).poll_read(&mut cx, &mut buf); + assert!(pending.is_pending()); + wake_counters.push(counter); + } + + let mut observed = vec![0usize; waiters]; + let start = Instant::now(); + let mut wakes_at_40ms = 0usize; + let mut wakes_at_160ms = 0usize; + + while start.elapsed() < Duration::from_millis(200) { + for (idx, counter) in wake_counters.iter().enumerate() { + let wakes = counter.wakes.load(Ordering::Relaxed); + if wakes > observed[idx] { + observed[idx] = wakes; + let waker = Waker::from(Arc::clone(counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf); + assert!(pending.is_pending()); + } + } + + let elapsed = start.elapsed(); + if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 { + wakes_at_40ms = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + } + if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 { + wakes_at_160ms = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + let wakes_at_200ms = total_wakes; + let early_window_wakes = wakes_at_40ms; + let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms); + + assert!( + total_wakes <= waiters * 28, + "read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}" + ); + + assert!( + early_window_wakes > 0, + "read benchmark failed to observe early contention wakes" + ); + + assert!( + late_window_wakes * 4 <= early_window_wakes * 3, + "read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}" + ); + + drop(held_guard); +} diff --git a/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs new file mode 100644 index 0000000..7f1e451 --- /dev/null +++ b/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs @@ -0,0 +1,339 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::ReadBuf; +use tokio::time::{Duration, Instant}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn saturate_quota_user_locks() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}"))); + } + retained +} + +#[tokio::test] +async fn positive_uncontended_writer_keeps_retry_wakes_zero() { + let _guard = quota_test_guard(); + + let stats = Arc::new(Stats::new()); + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + "quota-backoff-positive".to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]); + assert!(poll.is_ready(), "uncontended writer must complete immediately"); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "uncontended path must not schedule deferred contention wakes" + ); +} + +#[tokio::test] +async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-backoff-adversarial-writer"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before polling writer"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); + assert!(first.is_pending()); + + let start = Instant::now(); + let mut observed = 0usize; + while start.elapsed() < Duration::from_millis(80) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed { + observed = wakes; + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]); + assert!(pending.is_pending()); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 16, + "sustained contention must be rate limited; observed wakes={} in 80ms", + wake_counter.wakes.load(Ordering::Relaxed) + ); + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-backoff-adversarial-reader"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before polling reader"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + + let mut buf = ReadBuf::new(&mut storage); + let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(first.is_pending()); + + let start = Instant::now(); + let mut observed = 0usize; + while start.elapsed() < Duration::from_millis(80) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed { + observed = wakes; + let mut next = ReadBuf::new(&mut storage); + let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next); + assert!(pending.is_pending()); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 16, + "sustained contention must be rate limited; observed wakes={} in 80ms", + wake_counter.wakes.load(Ordering::Relaxed) + ); + + drop(held_guard); + let mut done = ReadBuf::new(&mut storage); + let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn edge_backoff_attempt_resets_after_contention_release() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-backoff-edge-reset"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before polling writer"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]); + assert!(initial.is_pending()); + + tokio::time::sleep(Duration::from_millis(15)).await; + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > 0 { + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]); + assert!(pending.is_pending()); + } + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); + assert!(ready.is_ready()); + assert!( + !io.quota_write_wake_scheduled, + "successful write must clear deferred wake scheduling flag" + ); + assert!( + io.quota_write_retry_sleep.is_none(), + "successful write must clear deferred sleep slot" + ); +} + +#[tokio::test] +async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-backoff-fuzz-writer"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before fuzz loop"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let mut seed = 0x5EED_CAFE_7788_9900u64; + for _ in 0..64 { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]); + assert!(poll.is_pending()); + + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let sleep_ms = (seed % 4) as u64; + tokio::time::sleep(Duration::from_millis(sleep_ms)).await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 24, + "fuzzed repoll schedule must keep wake budget bounded; observed wakes={}", + wake_counter.wakes.load(Ordering::Relaxed) + ); + + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = format!("quota-backoff-stress-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before launching stress waiters"); + + let waiters = 48usize; + let mut ios = Vec::with_capacity(waiters); + let mut wake_counters = Vec::with_capacity(waiters); + + for _ in 0..waiters { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let pending = Pin::new(io).poll_write(&mut cx, &[0x61]); + assert!(pending.is_pending()); + wake_counters.push(counter); + } + + let start = Instant::now(); + while start.elapsed() < Duration::from_millis(120) { + for (idx, counter) in wake_counters.iter().enumerate() { + if counter.wakes.load(Ordering::Relaxed) > 0 { + let waker = Waker::from(Arc::clone(counter)); + let mut cx = Context::from_waker(&waker); + let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]); + assert!(pending.is_pending()); + } + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= waiters * 20, + "stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}" + ); + + drop(held_guard); +} diff --git a/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs b/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs new file mode 100644 index 0000000..35a6b6e --- /dev/null +++ b/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs @@ -0,0 +1,246 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Poll, Waker}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_uncontended_quota_limited_writer_completes() { + let _guard = quota_test_guard(); + + let stats = Arc::new(Stats::new()); + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + "tdd-uncontended".to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let result = io.write_all(&[0x41, 0x42, 0x43]).await; + assert!(result.is_ok(), "uncontended writer must complete"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() { + let _guard = quota_test_guard(); + + let user = format!("tdd-writer-storm-{}", std::process::id()); + let held = quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold quota lock before polling writers"); + + let stats = Arc::new(Stats::new()); + let writers = 24usize; + let mut ios = Vec::with_capacity(writers); + let mut wake_counters = Vec::with_capacity(writers); + + for _ in 0..writers { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]); + assert!(poll.is_pending(), "writer must be pending under held lock"); + wake_counters.push(counter); + } + + tokio::time::sleep(Duration::from_millis(25)).await; + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= writers * 4, + "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() { + let _guard = quota_test_guard(); + + let user = format!("tdd-reader-storm-{}", std::process::id()); + let held = quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold quota lock before polling readers"); + + let stats = Arc::new(Stats::new()); + let readers = 24usize; + let mut ios = Vec::with_capacity(readers); + let mut wake_counters = Vec::with_capacity(readers); + + for _ in 0..readers { + ios.push(StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + let poll = Pin::new(io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending(), "reader must be pending under held lock"); + wake_counters.push(counter); + } + + tokio::time::sleep(Duration::from_millis(25)).await; + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= readers * 4, + "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_contended_waiters_resume_after_lock_release() { + let _guard = quota_test_guard(); + + let user = format!("tdd-resume-{}", std::process::id()); + let held = quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold quota lock before launching waiters"); + + let stats = Arc::new(Stats::new()); + let mut waiters = Vec::new(); + for _ in 0..12 { + let stats = Arc::clone(&stats); + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + stats, + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x5A]).await + })); + } + + tokio::time::sleep(Duration::from_millis(5)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let result = waiter.await.expect("waiter task must not panic"); + assert!(result.is_ok(), "waiter must complete after release"); + } + }) + .await + .expect("all waiters must complete in bounded time"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() { + let _guard = quota_test_guard(); + + let mut seed = 0x9E37_79B9_AA55_1234u64; + for round in 0..20u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let writers = 8 + (seed as usize % 12); + let sleep_ms = 10 + (seed as u64 % 15); + let user = format!("tdd-fuzz-{}-{round}", std::process::id()); + + let held = quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold quota lock in fuzz round"); + + let stats = Arc::new(Stats::new()); + let mut ios = Vec::with_capacity(writers); + let mut wake_counters = Vec::with_capacity(writers); + + for _ in 0..writers { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]); + assert!(matches!(poll, Poll::Pending)); + wake_counters.push(counter); + } + + tokio::time::sleep(Duration::from_millis(sleep_ms)).await; + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= writers * 4, + "fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}" + ); + } +} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs index 50cdfa3..7375192 100644 --- a/src/proxy/tests/relay_security_tests.rs +++ b/src/proxy/tests/relay_security_tests.rs @@ -137,10 +137,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_ for _ in 0..8 { tokio::task::yield_now().await; } - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - wakes_after_first_yield, - "writer contention should not schedule unbounded wake storms before lock acquisition" + let wakes_after_second_window = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes_after_second_window <= wakes_after_first_yield.saturating_add(2), + "writer contention should keep retry wakes bounded before lock acquisition: before={wakes_after_first_yield}, after={wakes_after_second_window}" ); drop(held_lock); diff --git a/src/stats/mod.rs b/src/stats/mod.rs index d13d834..dc455a1 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -1884,6 +1884,32 @@ impl Stats { stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); } + pub fn sub_user_octets_to(&self, user: &str, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + self.maybe_cleanup_user_stats(); + let Some(stats) = self.user_stats.get(user) else { + return; + }; + + Self::touch_user_stats(stats.value()); + let counter = &stats.octets_to_client; + let mut current = counter.load(Ordering::Relaxed); + loop { + let next = current.saturating_sub(bytes); + match counter.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -2440,3 +2466,7 @@ mod connection_lease_security_tests; #[cfg(test)] #[path = "tests/replay_checker_security_tests.rs"] mod replay_checker_security_tests; + +#[cfg(test)] +#[path = "tests/user_octets_sub_security_tests.rs"] +mod user_octets_sub_security_tests; diff --git a/src/stats/tests/user_octets_sub_security_tests.rs b/src/stats/tests/user_octets_sub_security_tests.rs new file mode 100644 index 0000000..d4e7580 --- /dev/null +++ b/src/stats/tests/user_octets_sub_security_tests.rs @@ -0,0 +1,151 @@ +use super::*; +use std::sync::Arc; +use std::thread; + +#[test] +fn sub_user_octets_to_underflow_saturates_at_zero() { + let stats = Stats::new(); + let user = "sub-underflow-user"; + + stats.add_user_octets_to(user, 3); + stats.sub_user_octets_to(user, 100); + + assert_eq!(stats.get_user_total_octets(user), 0); +} + +#[test] +fn sub_user_octets_to_does_not_affect_octets_from_client() { + let stats = Stats::new(); + let user = "sub-isolation-user"; + + stats.add_user_octets_from(user, 17); + stats.add_user_octets_to(user, 5); + stats.sub_user_octets_to(user, 3); + + assert_eq!(stats.get_user_total_octets(user), 19); +} + +#[test] +fn light_fuzz_add_sub_model_matches_saturating_reference() { + let stats = Stats::new(); + let user = "sub-fuzz-user"; + let mut seed = 0x91D2_4CB8_EE77_1101u64; + let mut model_to = 0u64; + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let amt = ((seed >> 8) & 0x3f) + 1; + if (seed & 1) == 0 { + stats.add_user_octets_to(user, amt); + model_to = model_to.saturating_add(amt); + } else { + stats.sub_user_octets_to(user, amt); + model_to = model_to.saturating_sub(amt); + } + } + + assert_eq!(stats.get_user_total_octets(user), model_to); +} + +#[test] +fn stress_parallel_add_sub_never_underflows_or_panics() { + let stats = Arc::new(Stats::new()); + let user = "sub-stress-user"; + // Pre-fund with a large offset so subtractions never saturate at zero. + // This guarantees commutative updates, making the final state deterministic. + let base_offset = 10_000_000u64; + stats.add_user_octets_to(user, base_offset); + + let mut workers = Vec::new(); + + for tid in 0..16u64 { + let stats_for_thread = Arc::clone(&stats); + workers.push(thread::spawn(move || { + let mut seed = 0xD00D_1000_0000_0000u64 ^ tid; + let mut net_delta = 0i64; + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let amt = ((seed >> 8) & 0x1f) + 1; + + if (seed & 1) == 0 { + stats_for_thread.add_user_octets_to(user, amt); + net_delta += amt as i64; + } else { + stats_for_thread.sub_user_octets_to(user, amt); + net_delta -= amt as i64; + } + } + + net_delta + })); + } + + let mut expected_net_delta = 0i64; + for worker in workers { + expected_net_delta += worker + .join() + .expect("sub-user stress worker must not panic"); + } + + let expected_total = (base_offset as i64 + expected_net_delta) as u64; + let total = stats.get_user_total_octets(user); + assert_eq!( + total, expected_total, + "concurrent add/sub lost updates or suffered ABA races" + ); +} + +#[test] +fn sub_user_octets_to_missing_user_is_noop() { + let stats = Stats::new(); + stats.sub_user_octets_to("missing-user", 1024); + assert_eq!(stats.get_user_total_octets("missing-user"), 0); +} + +#[test] +fn stress_parallel_per_user_models_remain_exact() { + let stats = Arc::new(Stats::new()); + let mut workers = Vec::new(); + + for tid in 0..16u64 { + let stats_for_thread = Arc::clone(&stats); + workers.push(thread::spawn(move || { + let user = format!("sub-per-user-{tid}"); + let mut seed = 0xFACE_0000_0000_0000u64 ^ tid; + let mut model = 0u64; + + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let amt = ((seed >> 8) & 0x3f) + 1; + + if (seed & 1) == 0 { + stats_for_thread.add_user_octets_to(&user, amt); + model = model.saturating_add(amt); + } else { + stats_for_thread.sub_user_octets_to(&user, amt); + model = model.saturating_sub(amt); + } + } + + (user, model) + })); + } + + for worker in workers { + let (user, model) = worker + .join() + .expect("per-user subtract stress worker must not panic"); + assert_eq!( + stats.get_user_total_octets(&user), + model, + "per-user parallel model diverged" + ); + } +} \ No newline at end of file From 6f17d4d2316889359266e6e5e2e07dea27c1ecb4 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Mon, 23 Mar 2026 12:04:41 +0400 Subject: [PATCH 06/29] Add comprehensive security tests for quota management and relay functionality - Introduced `relay_dual_lock_race_harness_security_tests.rs` to validate user liveness during lock hold and release cycles. - Added `relay_quota_extended_attack_surface_security_tests.rs` to cover various quota scenarios including positive, negative, edge cases, and adversarial conditions. - Implemented `relay_quota_lock_eviction_lifecycle_tdd_tests.rs` to ensure proper eviction of stale entries and lifecycle management of quota locks. - Created `relay_quota_lock_eviction_stress_security_tests.rs` to stress test the eviction mechanism under high churn conditions. - Enhanced `relay_quota_lock_pressure_adversarial_tests.rs` to verify reclaiming of unreferenced entries after explicit eviction. - Developed `relay_quota_retry_allocation_latency_security_tests.rs` to benchmark and validate latency and allocation behavior under contention. --- Cargo.lock | 4 +- src/maestro/runtime_tasks.rs | 31 + src/proxy/handshake.rs | 7 +- src/proxy/masking.rs | 112 ++- src/proxy/middle_relay.rs | 125 ++- src/proxy/quota_lock_registry.rs | 37 +- src/proxy/relay.rs | 149 ++- src/proxy/tests/client_security_tests.rs | 2 +- ...auth_probe_eviction_bias_security_tests.rs | 93 ++ ...e_auth_probe_scan_budget_security_tests.rs | 21 +- ...ake_auth_probe_scan_offset_stress_tests.rs | 21 +- .../tests/handshake_more_clever_tests.rs | 2 +- ..._extended_attack_surface_security_tests.rs | 217 +++++ ...erface_cache_concurrency_security_tests.rs | 41 + .../masking_interface_cache_security_tests.rs | 14 +- ...roduction_cap_regression_security_tests.rs | 289 ++++++ ...masking_self_target_loop_security_tests.rs | 54 +- ...g_timing_budget_coupling_security_tests.rs | 55 ++ ...relay_coverage_high_risk_security_tests.rs | 69 ++ ..._lock_release_regression_security_tests.rs | 295 ++++++ ...s_mode_lookup_efficiency_security_tests.rs | 116 +++ ...s_mode_quota_lock_matrix_security_tests.rs | 376 ++++++++ ...s_mode_quota_reservation_security_tests.rs | 254 +++++ .../middle_relay_hol_quota_security_tests.rs | 3 + ..._extended_attack_surface_security_tests.rs | 372 ++++++++ ...lay_quota_reservation_adversarial_tests.rs | 874 ++++++++++++++++++ ...uota_reservation_extreme_security_tests.rs | 399 ++++++++ ...y_frame_debt_concurrency_security_tests.rs | 34 +- ...rame_debt_proto_chunking_security_tests.rs | 59 +- ...le_relay_tiny_frame_debt_security_tests.rs | 282 +++++- ...pipeline_hol_integration_security_tests.rs | 267 ++++++ ...peline_latency_benchmark_security_tests.rs | 213 +++++ ...lay_cross_mode_quota_fairness_tdd_tests.rs | 381 +++++++- ...k_alternating_contention_security_tests.rs | 340 +++++++ ..._lock_backoff_regression_security_tests.rs | 74 ++ ...l_lock_contention_matrix_security_tests.rs | 325 +++++++ ...y_dual_lock_race_harness_security_tests.rs | 128 +++ ..._extended_attack_surface_security_tests.rs | 332 +++++++ ...quota_lock_eviction_lifecycle_tdd_tests.rs | 79 ++ ...ota_lock_eviction_stress_security_tests.rs | 153 +++ ...y_quota_lock_pressure_adversarial_tests.rs | 4 +- ...retry_allocation_latency_security_tests.rs | 249 +++++ 42 files changed, 6774 insertions(+), 178 deletions(-) create mode 100644 src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs create mode 100644 src/proxy/tests/masking_extended_attack_surface_security_tests.rs create mode 100644 src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs create mode 100644 src/proxy/tests/masking_production_cap_regression_security_tests.rs create mode 100644 src/proxy/tests/masking_timing_budget_coupling_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs create mode 100644 src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs create mode 100644 src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs create mode 100644 src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs create mode 100644 src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs create mode 100644 src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs create mode 100644 src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs create mode 100644 src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs create mode 100644 src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs create mode 100644 src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs diff --git a/Cargo.lock b/Cargo.lock index c4cde39..92da630 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1454,9 +1454,9 @@ dependencies = [ [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb" dependencies = [ "memchr", "serde", diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index d553eb9..066c853 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -32,6 +32,14 @@ pub(crate) struct RuntimeWatches { pub(crate) detected_ip_v6: Option, } +const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60; + +fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> { + crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs( + QUOTA_USER_LOCK_EVICT_INTERVAL_SECS, + )) +} + #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, @@ -69,6 +77,8 @@ pub(crate) async fn spawn_runtime_tasks( rc_clone.run_periodic_cleanup().await; }); + spawn_quota_lock_maintenance_task(); + let detected_ip_v4: Option = probe.detected_ipv4.map(IpAddr::V4); let detected_ip_v6: Option = probe.detected_ipv6.map(IpAddr::V6); debug!( @@ -360,3 +370,24 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc) { .await; startup_tracker.mark_ready().await; } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() { + crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests(); + + let handle = spawn_quota_lock_maintenance_task(); + tokio::time::sleep(std::time::Duration::from_millis(5)).await; + + assert_eq!( + crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(), + 1, + "runtime maintenance path must spawn exactly one quota lock evictor task per call" + ); + + handle.abort(); + } +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 3444a88..96994c7 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -131,8 +131,7 @@ fn auth_probe_scan_start_offset( return 0; } - let window = state_len.min(scan_limit); - auth_probe_eviction_offset(peer_ip, now) % window + auth_probe_eviction_offset(peer_ip, now) % state_len } fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { @@ -997,6 +996,10 @@ mod auth_probe_scan_budget_security_tests; #[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"] mod auth_probe_scan_offset_stress_tests; +#[cfg(test)] +#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"] +mod auth_probe_eviction_bias_security_tests; + #[cfg(test)] #[path = "tests/handshake_advanced_clever_tests.rs"] mod advanced_clever_tests; diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 7d970c2..841749c 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -19,6 +19,8 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; +#[cfg(unix)] +use tokio::sync::Mutex as AsyncMutex; use tokio::time::{Instant, timeout}; use tracing::debug; @@ -95,10 +97,6 @@ where Ok(Ok(())) => {} Ok(Err(_)) | Err(_) => break, } - - if total >= byte_cap { - break; - } } CopyOutcome { total, @@ -370,6 +368,9 @@ struct LocalInterfaceCache { static LOCAL_INTERFACE_CACHE: OnceLock> = OnceLock::new(); #[cfg(unix)] +static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock> = OnceLock::new(); + +#[cfg(all(unix, test))] fn local_interface_ips() -> Vec { let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default())); let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); @@ -386,11 +387,59 @@ fn local_interface_ips() -> Vec { guard.ips.clone() } -#[cfg(not(unix))] +#[cfg(unix)] +async fn local_interface_ips_async() -> Vec { + let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default())); + + { + let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if !stale { + return guard.ips.clone(); + } + } + + let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(())); + let _refresh_guard = refresh_lock.lock().await; + + { + let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if !stale { + return guard.ips.clone(); + } + } + + let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips) + .await + .unwrap_or_default(); + + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if stale { + guard.ips = choose_interface_snapshot(&guard.ips, refreshed); + guard.refreshed_at = Some(StdInstant::now()); + } + + guard.ips.clone() +} + +#[cfg(all(not(unix), test))] fn local_interface_ips() -> Vec { Vec::new() } +#[cfg(not(unix))] +async fn local_interface_ips_async() -> Vec { + Vec::new() +} + #[cfg(test)] static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0); @@ -457,6 +506,7 @@ fn is_mask_target_local_listener_with_interfaces( false } +#[cfg(test)] fn is_mask_target_local_listener( mask_host: &str, mask_port: u16, @@ -477,6 +527,26 @@ fn is_mask_target_local_listener( ) } +async fn is_mask_target_local_listener_async( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let interfaces = local_interface_ips_async().await; + is_mask_target_local_listener_with_interfaces( + mask_host, + mask_port, + local_addr, + resolved_override, + &interfaces, + ) +} + fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration { let minutes = config.general.beobachten_minutes; let clamped = minutes.clamp(1, 24 * 60); @@ -608,13 +678,15 @@ pub async fn handle_bad_client( .as_deref() .unwrap_or(&config.censorship.tls_domain); let mask_port = config.censorship.mask_port; - let outcome_started = Instant::now(); // Fail closed when fallback points at our own listener endpoint. // Self-referential masking can create recursive proxy loops under // misconfiguration and leak distinguishable load spikes to adversaries. let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port); - if is_mask_target_local_listener(mask_host, mask_port, local_addr, resolved_mask_addr) { + if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr) + .await + { + let outcome_started = Instant::now(); debug!( client_type = client_type, host = %mask_host, @@ -627,6 +699,8 @@ pub async fn handle_bad_client( return; } + let outcome_started = Instant::now(); + debug!( client_type = client_type, host = %mask_host, @@ -768,7 +842,13 @@ async fn consume_client_data(mut reader: R, byte_cap: usiz let mut total = 0usize; loop { - let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await { + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, }; @@ -804,6 +884,10 @@ mod masking_shape_above_cap_blur_security_tests; #[path = "tests/masking_timing_normalization_security_tests.rs"] mod masking_timing_normalization_security_tests; +#[cfg(test)] +#[path = "tests/masking_timing_budget_coupling_security_tests.rs"] +mod masking_timing_budget_coupling_security_tests; + #[cfg(test)] #[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"] mod masking_ab_envelope_blur_integration_security_tests; @@ -884,6 +968,18 @@ mod masking_interface_cache_security_tests; #[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"] mod masking_interface_cache_defense_in_depth_security_tests; +#[cfg(test)] +#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"] +mod masking_interface_cache_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/masking_production_cap_regression_security_tests.rs"] +mod masking_production_cap_regression_security_tests; + +#[cfg(test)] +#[path = "tests/masking_extended_attack_surface_security_tests.rs"] +mod masking_extended_attack_surface_security_tests; + #[cfg(test)] #[path = "tests/masking_padding_timeout_adversarial_tests.rs"] mod masking_padding_timeout_adversarial_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 0d2a748..b6b198c 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,5 +1,7 @@ use std::collections::hash_map::RandomState; use std::collections::{BTreeSet, HashMap}; +#[cfg(test)] +use std::future::Future; use std::hash::{BuildHasher, Hash}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; @@ -45,6 +47,8 @@ const TINY_FRAME_DEBT_LIMIT: u32 = 512; const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); #[cfg(not(test))] const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1); const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; @@ -561,11 +565,8 @@ fn quota_would_be_exceeded_for_user_soft( bytes: u64, overshoot: u64, ) -> bool { - quota_limit.is_some_and(|quota| { - let cap = quota_soft_cap(quota, overshoot); - let used = stats.get_user_total_octets(user); - used >= cap || bytes > cap.saturating_sub(used) - }) + let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot)); + quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes) } fn classify_me_d2c_flush_reason( @@ -683,7 +684,7 @@ fn quota_user_lock(user: &str) -> Arc> { } #[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { +pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) } @@ -712,6 +713,16 @@ async fn enqueue_c2me_command( } } +#[cfg(test)] +async fn run_relay_test_step_timeout(context: &'static str, fut: F) -> T +where + F: Future, +{ + timeout(RELAY_TEST_STEP_TIMEOUT, fut) + .await + .unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs())) +} + pub(crate) async fn handle_via_middle_proxy( mut crypto_reader: CryptoReader, crypto_writer: CryptoWriter, @@ -860,6 +871,7 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); + let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); @@ -881,7 +893,7 @@ where let first_is_downstream_activity = matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_cross_mode_lock( first, &mut writer, proto_tag, @@ -891,6 +903,7 @@ where &user_clone, quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -939,7 +952,7 @@ where let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_cross_mode_lock( next, &mut writer, proto_tag, @@ -949,6 +962,7 @@ where &user_clone, quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1000,7 +1014,7 @@ where Ok(Some(next)) => { let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_cross_mode_lock( next, &mut writer, proto_tag, @@ -1010,6 +1024,7 @@ where &user_clone, quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1063,7 +1078,7 @@ where let extra_is_downstream_activity = matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_cross_mode_lock( extra, &mut writer, proto_tag, @@ -1073,6 +1088,7 @@ where &user_clone, quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1252,10 +1268,7 @@ where )); break; }; - let _cross_mode_quota_guard = match cross_mode_lock.lock() { - Ok(guard) => guard, - Err(poisoned) => poisoned.into_inner(), - }; + let _cross_mode_quota_guard = cross_mode_lock.lock().await; stats.add_user_octets_from(&user, payload.len() as u64); if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { main_result = Err(ProxyError::DataQuotaExceeded { @@ -1741,6 +1754,7 @@ enum MeWriterResponseOutcome { Close, } +#[cfg(test)] async fn process_me_writer_response( response: MeResponse, client_writer: &mut CryptoWriter, @@ -1756,6 +1770,44 @@ async fn process_me_writer_response( ack_flush_immediate: bool, batched: bool, ) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + process_me_writer_response_with_cross_mode_lock( + response, + client_writer, + proto_tag, + rng, + frame_buf, + stats, + user, + quota_limit, + quota_soft_overshoot_bytes, + None, + bytes_me2c, + conn_id, + ack_flush_immediate, + batched, + ) + .await +} + +async fn process_me_writer_response_with_cross_mode_lock( + response: MeResponse, + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + rng: &SecureRandom, + frame_buf: &mut Vec, + stats: &Stats, + user: &str, + quota_limit: Option, + quota_soft_overshoot_bytes: u64, + cross_mode_quota_lock: Option<&Arc>>, + bytes_me2c: &AtomicU64, + conn_id: u64, + ack_flush_immediate: bool, + batched: bool, +) -> Result where W: AsyncWrite + Unpin + Send + 'static, { @@ -1768,8 +1820,23 @@ where } let data_len = data.len() as u64; if let Some(limit) = quota_limit { + let owned_cross_mode_lock; + let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock { + lock + } else { + owned_cross_mode_lock = + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user); + &owned_cross_mode_lock + }; + let cross_mode_quota_guard = cross_mode_lock.lock().await; let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); - if quota_would_be_exceeded_for_user(stats, user, Some(soft_limit), data_len) { + if quota_would_be_exceeded_for_user_soft( + stats, + user, + Some(limit), + data_len, + quota_soft_overshoot_bytes, + ) { stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1789,6 +1856,10 @@ where }); } + // Keep cross-mode lock scope explicit and minimal: quota reservation is serialized, + // but socket I/O proceeds without holding same-user cross-mode admission lock. + drop(cross_mode_quota_guard); + let write_mode = match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) .await @@ -2084,3 +2155,27 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests; #[cfg(test)] #[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"] +mod middle_relay_cross_mode_quota_reservation_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"] +mod middle_relay_cross_mode_quota_lock_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"] +mod middle_relay_cross_mode_lookup_efficiency_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"] +mod middle_relay_cross_mode_lock_release_regression_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"] +mod middle_relay_quota_extended_attack_surface_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"] +mod middle_relay_quota_reservation_extreme_security_tests; diff --git a/src/proxy/quota_lock_registry.rs b/src/proxy/quota_lock_registry.rs index ac64a57..7798b09 100644 --- a/src/proxy/quota_lock_registry.rs +++ b/src/proxy/quota_lock_registry.rs @@ -1,5 +1,9 @@ use dashmap::DashMap; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::{Arc, OnceLock}; +use tokio::sync::Mutex; + +#[cfg(test)] +use std::sync::atomic::{AtomicUsize, Ordering}; #[cfg(test)] const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64; @@ -13,6 +17,11 @@ const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); +#[cfg(test)] +static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0); +#[cfg(test)] +static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock> = OnceLock::new(); + fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES) @@ -25,6 +34,14 @@ fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { } pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { + #[cfg(test)] + { + CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed); + let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); + let mut entry = lookups.entry(user.to_string()).or_insert(0); + *entry += 1; + } + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); if let Some(existing) = locks.get(user) { return Arc::clone(existing.value()); @@ -48,6 +65,24 @@ pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { } } +#[cfg(test)] +pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() { + CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed); + let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); + lookups.clear(); +} + +#[cfg(test)] +pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize { + CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed) +} + +#[cfg(test)] +pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize { + let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); + lookups.get(user).map(|entry| *entry).unwrap_or(0) +} + #[cfg(test)] #[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"] mod quota_lock_registry_cross_mode_adversarial_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index dcacedd..55f1385 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -62,6 +62,7 @@ use std::sync::{Arc, Mutex, OnceLock}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; +use tokio::sync::Mutex as AsyncMutex; use tokio::time::{Instant, Sleep}; use tracing::{debug, trace, warn}; @@ -210,7 +211,7 @@ struct StatsIo { stats: Arc, user: String, quota_lock: Option>>, - cross_mode_quota_lock: Option>>, + cross_mode_quota_lock: Option>>, quota_limit: Option, quota_exceeded: Arc, quota_read_wake_scheduled: bool, @@ -289,6 +290,21 @@ const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16); #[cfg(not(test))] const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64); +#[cfg(test)] +static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0); +#[cfg(test)] +static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0); + +#[cfg(test)] +pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() { + QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed); +} + +#[cfg(test)] +pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 { + QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed) +} + #[inline] fn quota_contention_retry_delay(retry_attempt: u8) -> Duration { let shift = u32::from(retry_attempt.min(5)); @@ -317,6 +333,8 @@ fn poll_quota_retry_sleep( ) { if !*wake_scheduled { *wake_scheduled = true; + #[cfg(test)] + QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed); *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay( *retry_attempt, )))); @@ -368,16 +386,47 @@ fn quota_overflow_user_lock(user: &str) -> Arc> { Arc::clone(&stripes[hash % stripes.len()]) } +pub(crate) fn quota_user_lock_evict() { + if let Some(locks) = QUOTA_USER_LOCKS.get() { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } +} + +pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> { + let interval = interval.max(Duration::from_millis(1)); + #[cfg(test)] + QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed); + tokio::spawn(async move { + loop { + tokio::time::sleep(interval).await; + quota_user_lock_evict(); + } + }) +} + +#[cfg(test)] +pub(crate) fn spawn_quota_user_lock_evictor_for_tests( + interval: Duration, +) -> tokio::task::JoinHandle<()> { + spawn_quota_user_lock_evictor(interval) +} + +#[cfg(test)] +pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() { + QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed); +} + +#[cfg(test)] +pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 { + QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed) +} + fn quota_user_lock(user: &str) -> Arc> { let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); if let Some(existing) = locks.get(user) { return Arc::clone(existing.value()); } - if locks.len() >= QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - if locks.len() >= QUOTA_USER_LOCKS_MAX { return quota_overflow_user_lock(user); } @@ -393,7 +442,7 @@ fn quota_user_lock(user: &str) -> Arc> { } #[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { +pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) } @@ -410,14 +459,7 @@ impl AsyncRead for StatsIo { let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => { - reset_quota_retry_scheduler( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - ); - Some(guard) - } + Ok(guard) => Some(guard), Err(_) => { poll_quota_retry_sleep( &mut this.quota_read_retry_sleep, @@ -434,14 +476,7 @@ impl AsyncRead for StatsIo { let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => { - reset_quota_retry_scheduler( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - ); - Some(guard) - } + Ok(guard) => Some(guard), Err(_) => { poll_quota_retry_sleep( &mut this.quota_read_retry_sleep, @@ -456,6 +491,12 @@ impl AsyncRead for StatsIo { None }; + reset_quota_retry_scheduler( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + ); + if let Some(limit) = this.quota_limit && this.stats.get_user_total_octets(&this.user) >= limit { @@ -523,14 +564,7 @@ impl AsyncWrite for StatsIo { let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => { - reset_quota_retry_scheduler( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - ); - Some(guard) - } + Ok(guard) => Some(guard), Err(_) => { poll_quota_retry_sleep( &mut this.quota_write_retry_sleep, @@ -547,14 +581,7 @@ impl AsyncWrite for StatsIo { let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => { - reset_quota_retry_scheduler( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - ); - Some(guard) - } + Ok(guard) => Some(guard), Err(_) => { poll_quota_retry_sleep( &mut this.quota_write_retry_sleep, @@ -569,6 +596,12 @@ impl AsyncWrite for StatsIo { None }; + reset_quota_retry_scheduler( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + ); + let write_buf = if let Some(limit) = this.quota_limit { let used = this.stats.get_user_total_octets(&this.user); if used >= limit { @@ -861,6 +894,10 @@ mod relay_quota_model_adversarial_tests; #[path = "tests/relay_quota_overflow_regression_tests.rs"] mod relay_quota_overflow_regression_tests; +#[cfg(test)] +#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"] +mod relay_quota_extended_attack_surface_security_tests; + #[cfg(test)] #[path = "tests/relay_watchdog_delta_security_tests.rs"] mod relay_watchdog_delta_security_tests; @@ -889,6 +926,14 @@ mod relay_quota_retry_scheduler_tdd_tests; #[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"] mod relay_cross_mode_quota_fairness_tdd_tests; +#[cfg(test)] +#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"] +mod relay_cross_mode_pipeline_hol_integration_security_tests; + +#[cfg(test)] +#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"] +mod relay_cross_mode_pipeline_latency_benchmark_security_tests; + #[cfg(test)] #[path = "tests/relay_quota_retry_backoff_security_tests.rs"] mod relay_quota_retry_backoff_security_tests; @@ -896,3 +941,31 @@ mod relay_quota_retry_backoff_security_tests; #[cfg(test)] #[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"] mod relay_quota_retry_backoff_benchmark_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"] +mod relay_dual_lock_backoff_regression_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"] +mod relay_dual_lock_contention_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"] +mod relay_dual_lock_race_harness_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"] +mod relay_dual_lock_alternating_contention_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"] +mod relay_quota_retry_allocation_latency_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"] +mod relay_quota_lock_eviction_lifecycle_tdd_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"] +mod relay_quota_lock_eviction_stress_security_tests; diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 35f517a..2b1fae6 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use rand::rngs::StdRng; -use rand::RngCore; +use rand::Rng; use rand::SeedableRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; diff --git a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs new file mode 100644 index 0000000..6c48cc1 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs @@ -0,0 +1,93 @@ +use super::*; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn adversarial_large_state_offsets_escape_first_scan_window() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut saw_offset_outside_first_window = false; + for i in 0..8_192u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(131)) & 0xff) as u8, + )); + let now = base + Duration::from_nanos(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + if start >= scan_limit { + saw_offset_outside_first_window = true; + break; + } + } + + assert!( + saw_offset_outside_first_window, + "scan start offset must cover the full auth-probe state, not only the first scan window" + ); +} + +#[test] +fn stress_large_state_offsets_cover_many_scan_windows() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut covered_windows = HashSet::new(); + for i in 0..16_384u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(17)) & 0xff) as u8, + )); + let now = base + Duration::from_micros(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + covered_windows.insert(start / scan_limit); + } + + assert!( + covered_windows.len() >= 16, + "eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}", + covered_windows.len(), + state_len / scan_limit + ); +} + +#[test] +fn light_fuzz_offset_always_stays_inside_state_len() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xC0FF_EE12_3456_789Au64; + let base = Instant::now(); + + for _ in 0..8_192usize { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x0fff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + + assert!(start < state_len, "scan offset must stay inside state length"); + } +} \ No newline at end of file diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs index c5e57d7..ece6ff5 100644 --- a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -22,12 +22,13 @@ fn edge_zero_state_len_yields_zero_start_offset() { } #[test] -fn adversarial_large_state_must_bound_start_offset_to_scan_budget() { +fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() { let _guard = auth_probe_test_guard(); let base = Instant::now(); let scan_limit = 16usize; let state_len = 65_536usize; + let mut saw_offset_outside_window = false; for i in 0..2048u32 { let ip = IpAddr::V4(Ipv4Addr::new( 203, @@ -38,10 +39,19 @@ fn adversarial_large_state_must_bound_start_offset_to_scan_budget() { let now = base + Duration::from_micros(i as u64); let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); assert!( - start < scan_limit, - "start offset must stay within scan window; start={start}, limit={scan_limit}" + start < state_len, + "start offset must stay within state length; start={start}, len={state_len}" ); + if start >= scan_limit { + saw_offset_outside_window = true; + break; + } } + + assert!( + saw_offset_outside_window, + "large-state eviction must sample beyond the first scan window" + ); } #[test] @@ -80,11 +90,10 @@ fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() { let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1); let now = base + Duration::from_nanos(seed & 0xffff); let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); - let effective_window = state_len.min(scan_limit); assert!( - start < effective_window, - "scan offset must stay inside effective window" + start < state_len, + "scan offset must stay inside state length" ); } } \ No newline at end of file diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs index cdaf498..260a1b9 100644 --- a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs @@ -22,10 +22,10 @@ fn positive_same_ip_moving_time_yields_diverse_scan_offsets() { uniq.insert(offset); } - assert_eq!( - uniq.len(), - 16, - "offset randomization must cover the entire scan window over 512 samples" + assert!( + uniq.len() >= 256, + "offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})", + uniq.len() ); } @@ -45,10 +45,10 @@ fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() { uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16)); } - assert_eq!( - uniq.len(), - 16, - "scan offset distribution collapsed unexpectedly across peer set" + assert!( + uniq.len() >= 512, + "scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})", + uniq.len() ); } @@ -108,6 +108,9 @@ fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() { let now = base + Duration::from_nanos(seed & 0x1fff); let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); - assert!(offset < state_len.min(scan_limit)); + assert!( + offset < state_len, + "scan offset must always remain inside state length" + ); } } \ No newline at end of file diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs index b3da4df..77df442 100644 --- a/src/proxy/tests/handshake_more_clever_tests.rs +++ b/src/proxy/tests/handshake_more_clever_tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::crypto::{sha256, sha256_hmac, AesCtr}; use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; -use rand::{RngExt, SeedableRng}; +use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; diff --git a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..040f567 --- /dev/null +++ b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs @@ -0,0 +1,217 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +fn make_self_target_config( + timing_normalization_enabled: bool, + floor_ms: u64, + ceiling_ms: u64, + beobachten_enabled: bool, +) -> ProxyConfig { + let mut config = ProxyConfig::default(); + config.general.beobachten = beobachten_enabled; + config.general.beobachten_minutes = 5; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled; + config.censorship.mask_timing_normalization_floor_ms = floor_ms; + config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms; + config +} + +async fn run_self_target_refusal( + config: ProxyConfig, + peer: SocketAddr, + initial: &'static [u8], +) -> Duration { + let beobachten = BeobachtenStore::new(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten) + .await; + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + timeout(Duration::from_secs(3), task) + .await + .expect("self-target refusal must complete in bounded time") + .expect("self-target refusal task must not panic"); + + started.elapsed() +} + +#[tokio::test] +async fn positive_self_target_refusal_honors_normalization_floor() { + let config = make_self_target_config(true, 120, 120, false); + let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260), + "normalized self-target refusal must stay within expected envelope" + ); +} + +#[tokio::test] +async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() { + let config = make_self_target_config(false, 240, 240, false); + let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(180), + "non-normalized path must not inherit normalization floor delays" + ); +} + +#[tokio::test] +async fn edge_ceiling_below_floor_uses_floor_fail_closed() { + let config = make_self_target_config(true, 140, 80, false); + let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280), + "ceiling max { + max = elapsed; + } + assert!( + elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320), + "parallel probe latency must stay bounded under normalization" + ); + } + + assert!( + max.saturating_sub(min) <= Duration::from_millis(130), + "normalization should limit path variance across adversarial parallel probes" + ); +} + +#[tokio::test] +async fn integration_beobachten_records_probe_classification_on_refusal() { + let config = make_self_target_config(false, 0, 0, true); + let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer"); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + beobachten.snapshot_text(Duration::from_secs(60)) + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + let snapshot = timeout(Duration::from_secs(3), task) + .await + .expect("integration task must complete") + .expect("integration task must not panic"); + + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.71-1")); +} + +#[tokio::test] +async fn light_fuzz_timing_configuration_matrix_is_bounded() { + let mut seed = 0xA17E_55AA_2026_0323u64; + + for case in 0..48u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let enabled = (seed & 1) == 0; + let floor = (seed >> 8) % 180; + let ceiling = (seed >> 24) % 180; + let config = make_self_target_config(enabled, floor, ceiling, false); + let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16)) + .parse() + .expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(420), + "fuzz case must stay bounded and never hang" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() { + let workers = 64usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let config = make_self_target_config(false, 0, 0, false); + let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16) + .parse() + .expect("valid peer"); + run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await + })); + } + + timeout(Duration::from_secs(5), async { + for task in tasks { + let elapsed = task.await.expect("stress task must not panic"); + assert!( + elapsed < Duration::from_millis(260), + "stress refusal must remain bounded without normalization" + ); + } + }) + .await + .expect("high-fanout refusal workload must complete without deadlock"); +} \ No newline at end of file diff --git a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs new file mode 100644 index 0000000..8d99b8f --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs @@ -0,0 +1,41 @@ +#![cfg(unix)] + +use super::*; +use std::sync::{Mutex, OnceLock}; +use tokio::sync::Barrier; + +fn interface_cache_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let workers = 32usize; + let barrier = std::sync::Arc::new(Barrier::new(workers)); + let mut tasks = Vec::with_capacity(workers); + + for _ in 0..workers { + let barrier = std::sync::Arc::clone(&barrier); + tasks.push(tokio::spawn(async move { + barrier.wait().await; + is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await + })); + } + + for task in tasks { + let _ = task.await.expect("parallel cache task must not panic"); + } + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "parallel cold misses must coalesce into a single interface enumeration" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs index b14d7c3..6be99d0 100644 --- a/src/proxy/tests/masking_interface_cache_security_tests.rs +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -8,8 +8,8 @@ fn interface_cache_test_lock() -> &'static Mutex<()> { LOCK.get_or_init(|| Mutex::new(())) } -#[test] -fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { +#[tokio::test] +async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { let _guard = interface_cache_test_lock() .lock() .unwrap_or_else(|poison| poison.into_inner()); @@ -17,8 +17,8 @@ fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); - let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None); - let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None); + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; assert_eq!( local_interface_enumerations_for_tests(), @@ -27,15 +27,15 @@ fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within ); } -#[test] -fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { +#[tokio::test] +async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { let _guard = interface_cache_test_lock() .lock() .unwrap_or_else(|poison| poison.into_inner()); reset_local_interface_enumerations_for_tests(); let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); - let is_local = is_mask_target_local_listener("127.0.0.1", 8443, local_addr, None); + let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; assert!(!is_local, "different port must not be treated as local listener"); assert_eq!( diff --git a/src/proxy/tests/masking_production_cap_regression_security_tests.rs b/src/proxy/tests/masking_production_cap_regression_security_tests.rs new file mode 100644 index 0000000..f2368a1 --- /dev/null +++ b/src/proxy/tests/masking_production_cap_regression_security_tests.rs @@ -0,0 +1,289 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::{Duration, Instant, timeout}; + +const PROD_CAP_BYTES: usize = 5 * 1024 * 1024; + +struct FinitePatternReader { + remaining: usize, + chunk: usize, + read_calls: Arc, +} + +impl FinitePatternReader { + fn new(total: usize, chunk: usize, read_calls: Arc) -> Self { + Self { + remaining: total, + chunk, + read_calls, + } + } +} + +impl AsyncRead for FinitePatternReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + self.read_calls.fetch_add(1, Ordering::Relaxed); + + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(self.chunk).min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + let fill = vec![0x5Au8; take]; + buf.put_slice(&fill); + self.remaining -= take; + Poll::Ready(Ok(())) + } +} + +#[derive(Default)] +struct CountingWriter { + written: usize, +} + +impl AsyncWrite for CountingWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.written = self.written.saturating_add(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct NeverReadyReader; + +impl AsyncRead for NeverReadyReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Poll::Pending + } +} + +struct BudgetProbeReader { + remaining: usize, + total_read: Arc, +} + +impl BudgetProbeReader { + fn new(total: usize, total_read: Arc) -> Self { + Self { + remaining: total, + total_read, + } + } +} + +impl AsyncRead for BudgetProbeReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + let fill = vec![0xA5u8; take]; + buf.put_slice(&fill); + self.remaining -= take; + self.total_read.fetch_add(take, Ordering::Relaxed); + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn positive_copy_with_production_cap_stops_exactly_at_budget() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await; + + assert_eq!( + outcome.total, PROD_CAP_BYTES, + "copy path must stop at explicit production cap" + ); + assert_eq!(writer.written, PROD_CAP_BYTES); + assert!( + !outcome.ended_by_eof, + "byte-cap stop must not be misclassified as EOF" + ); +} + +#[tokio::test] +async fn negative_consume_with_zero_cap_performs_no_reads() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls)); + + consume_client_data_with_timeout_and_cap(reader, 0).await; + + assert_eq!( + read_calls.load(Ordering::Relaxed), + 0, + "zero cap must return before reading attacker-controlled bytes" + ); +} + +#[tokio::test] +async fn edge_copy_below_cap_reports_eof_without_overread() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let payload = 73 * 1024; + let mut reader = FinitePatternReader::new(payload, 3072, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await; + + assert_eq!(outcome.total, payload); + assert_eq!(writer.written, payload); + assert!( + outcome.ended_by_eof, + "finite upstream below cap must terminate via EOF path" + ); +} + +#[tokio::test] +async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() { + let started = Instant::now(); + + consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await; + + assert!( + started.elapsed() < Duration::from_millis(350), + "never-ready reader must be bounded by idle/relay timeout protections" + ); +} + +#[tokio::test] +async fn integration_consume_path_honors_production_cap_for_large_payload() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls); + + let bounded = timeout( + Duration::from_millis(350), + consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES), + ) + .await; + + assert!( + bounded.is_ok(), + "consume path with production cap must finish within bounded time" + ); +} + +#[tokio::test] +async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() { + let byte_cap = 5usize; + let total_read = Arc::new(AtomicUsize::new(0)); + let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read)); + + consume_client_data_with_timeout_and_cap(reader, byte_cap).await; + + assert!( + total_read.load(Ordering::Relaxed) <= byte_cap, + "consume path must not read more than configured byte cap" + ); +} + +#[tokio::test] +async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() { + let mut seed = 0x1234_5678_9ABC_DEF0u64; + + for _case in 0..96u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let cap = ((seed & 0x3ffff) as usize).saturating_add(1); + let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1); + let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1); + + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new(payload, chunk, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await; + let expected = payload.min(cap); + + assert_eq!( + outcome.total, expected, + "copy total must match min(payload, cap) under fuzzed inputs" + ); + assert_eq!(writer.written, expected); + if payload <= cap { + assert!(outcome.ended_by_eof); + } else { + assert!(!outcome.ended_by_eof); + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() { + let workers = 8usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new( + PROD_CAP_BYTES + (idx + 1) * 4096, + 4096 + (idx * 257), + read_calls, + ); + let mut writer = CountingWriter::default(); + copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await + })); + } + + timeout(Duration::from_secs(3), async { + for task in tasks { + let outcome = task.await.expect("stress task must not panic"); + assert_eq!( + outcome.total, PROD_CAP_BYTES, + "stress copy task must stay within production cap" + ); + assert!( + !outcome.ended_by_eof, + "stress task should end due to cap, not EOF" + ); + } + }) + .await + .expect("stress suite must complete in bounded time"); +} diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs index b92ce3d..18cb0d7 100644 --- a/src/proxy/tests/masking_self_target_loop_security_tests.rs +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -12,71 +12,77 @@ fn closed_local_port() -> u16 { port } -#[test] -fn self_target_detection_matches_literal_ipv4_listener() { +#[tokio::test] +async fn self_target_detection_matches_literal_ipv4_listener() { let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); - assert!(is_mask_target_local_listener( + assert!(is_mask_target_local_listener_async( "198.51.100.40", 443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_matches_bracketed_ipv6_listener() { +#[tokio::test] +async fn self_target_detection_matches_bracketed_ipv6_listener() { let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); - assert!(is_mask_target_local_listener( + assert!(is_mask_target_local_listener_async( "[2001:db8::44]", 8443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_keeps_same_ip_different_port_forwardable() { +#[tokio::test] +async fn self_target_detection_keeps_same_ip_different_port_forwardable() { let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener( + assert!(!is_mask_target_local_listener_async( "203.0.113.44", 8443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { +#[tokio::test] +async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); - assert!(is_mask_target_local_listener( + assert!(is_mask_target_local_listener_async( "::ffff:127.0.0.1", 443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_unspecified_bind_blocks_loopback_target() { +#[tokio::test] +async fn self_target_detection_unspecified_bind_blocks_loopback_target() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); - assert!(is_mask_target_local_listener( + assert!(is_mask_target_local_listener_async( "127.0.0.1", 443, local, None, - )); + ) + .await); } -#[test] -fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { +#[tokio::test] +async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener( + assert!(!is_mask_target_local_listener_async( "mask.example", 443, local, Some(remote), - )); + ) + .await); } #[tokio::test] diff --git a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs new file mode 100644 index 0000000..1c342ea --- /dev/null +++ b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs @@ -0,0 +1,55 @@ +#![cfg(unix)] + +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer"); + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(())); + let held_refresh_guard = refresh_lock.lock().await; + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(80)).await; + drop(held_refresh_guard); + client.shutdown().await.expect("client shutdown must succeed"); + + timeout(Duration::from_secs(2), task) + .await + .expect("task must finish in bounded time") + .expect("task must not panic"); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350), + "timing normalization floor must start after pre-outcome self-target checks" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs index fff26b4..44c201f 100644 --- a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs +++ b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs @@ -645,6 +645,75 @@ fn quota_exceeded_boundary_is_inclusive() { assert!(!quota_exceeded_for_user(&stats, user, Some(51))); } +#[test] +fn quota_soft_helper_matches_capped_generic_helper_matrix() { + let stats = Stats::new(); + let user = "quota-soft-parity"; + + for used in [0u64, 1, 7, 63, 127, 255] { + stats.sub_user_octets_to(user, stats.get_user_total_octets(user)); + stats.add_user_octets_to(user, used); + + for quota in [8u64, 64, 128, 256] { + for overshoot in [0u64, 1, 5, 32] { + for bytes in [0u64, 1, 2, 7, 31, 64] { + let soft = quota_would_be_exceeded_for_user_soft( + &stats, + user, + Some(quota), + bytes, + overshoot, + ); + let capped = quota_would_be_exceeded_for_user( + &stats, + user, + Some(quota_soft_cap(quota, overshoot)), + bytes, + ); + assert_eq!( + soft, capped, + "soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}" + ); + } + } + } + } +} + +#[test] +fn quota_soft_helper_none_limit_never_rejects() { + let stats = Stats::new(); + let user = "quota-soft-none"; + stats.add_user_octets_to(user, u64::MAX); + + assert!(!quota_would_be_exceeded_for_user_soft( + &stats, + user, + None, + u64::MAX, + u64::MAX, + )); +} + +#[test] +fn quota_soft_cap_saturates_and_stays_fail_closed() { + let stats = Stats::new(); + let user = "quota-soft-saturating"; + let quota = u64::MAX - 2; + let overshoot = 100; + + assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX); + + stats.add_user_octets_to(user, u64::MAX - 1); + assert!(quota_would_be_exceeded_for_user_soft( + &stats, + user, + Some(quota), + 2, + overshoot, + )); +} + #[tokio::test] async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() { let (tx, mut rx) = mpsc::channel::(4); diff --git a/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs new file mode 100644 index 0000000..a787aa6 --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs @@ -0,0 +1,295 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use tokio::io::AsyncWrite; +use tokio::sync::Notify; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[derive(Default)] +struct BlockingWriteState { + write_entered: AtomicBool, + released: AtomicBool, + write_waker: Mutex>, + write_entered_notify: Notify, +} + +struct BlockingWrite { + state: Arc, +} + +impl BlockingWrite { + fn new(state: Arc) -> Self { + Self { state } + } +} + +impl AsyncWrite for BlockingWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.state.write_entered.store(true, Ordering::Release); + self.state.write_entered_notify.notify_waiters(); + + if self.state.released.load(Ordering::Acquire) { + return Poll::Ready(Ok(buf.len())); + } + + if let Ok(mut slot) = self.state.write_waker.lock() { + *slot = Some(cx.waker().clone()); + } + + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +async fn wait_until_blocking_write_entered(state: &Arc) { + for _ in 0..8 { + if state.write_entered.load(Ordering::Acquire) { + return; + } + let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; + } + + panic!("blocking writer did not enter poll_write in bounded time"); +} + +fn release_blocking_write(state: &Arc) { + state.released.store(true, Ordering::Release); + if let Ok(mut slot) = state.write_waker.lock() + && let Some(waker) = slot.take() + { + waker.wake(); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-release-regression-{}", std::process::id()); + let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let writer_state = Arc::new(BlockingWriteState::default()); + + let first = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + let writer_state = Arc::clone(&writer_state); + tokio::spawn(async move { + let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); + let mut frame_buf = Vec::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(4), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_000, + false, + false, + ) + .await + }) + }; + + wait_until_blocking_write_entered(&writer_state).await; + + let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) + .await + .expect("cross-mode lock must be released while first write is pending"); + drop(guard); + + let second = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + tokio::spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + timeout( + Duration::from_millis(150), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xEE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(4), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_001, + false, + false, + ), + ) + .await + }) + }; + + let second_result = second + .await + .expect("second task must not panic") + .expect("second write must not block on cross-mode lock"); + assert!( + matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })), + "second write must fail closed due to first write reservation" + ); + + release_blocking_write(&writer_state); + + let first_result = timeout(Duration::from_millis(300), first) + .await + .expect("first task timed out") + .expect("first task must not panic"); + assert!(first_result.is_ok()); + + assert_eq!(stats.get_user_total_octets(&user), 4); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-release-stress-{}", std::process::id()); + let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let writer_state = Arc::new(BlockingWriteState::default()); + + let first = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + let writer_state = Arc::clone(&writer_state); + tokio::spawn(async move { + let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); + let mut frame_buf = Vec::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x01, 0x02]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(3), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_100, + false, + false, + ) + .await + }) + }; + + wait_until_blocking_write_entered(&writer_state).await; + + let mut set = JoinSet::new(); + for idx in 0..48u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + timeout( + Duration::from_millis(200), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x10]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(3), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_200 + idx, + false, + false, + ), + ) + .await + }); + } + + let mut ok = 0usize; + let mut quota_exceeded = 0usize; + while let Some(done) = set.join_next().await { + let timed = done.expect("waiter task must not panic"); + let result = timed.expect("waiter must not block behind pending first write"); + match result { + Ok(_) => ok += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1, + Err(other) => panic!("unexpected error in waiter: {other:?}"), + } + } + + assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota"); + assert_eq!(quota_exceeded, 47); + + release_blocking_write(&writer_state); + + let first_result = timeout(Duration::from_millis(300), first) + .await + .expect("first task timed out") + .expect("first task must not panic"); + assert!(first_result.is_ok()); + + assert_eq!(stats.get_user_total_octets(&user), 3); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); +} diff --git a/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs new file mode 100644 index 0000000..37e1b87 --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs @@ -0,0 +1,116 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Mutex, OnceLock}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +fn lookup_counter_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test] +async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("middle-cross-mode-lookup-{}", std::process::id()); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + for idx in 0..8u64 { + let outcome = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAB]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + Some(&cross_mode_lock), + &bytes_me2c, + 20_000 + idx, + false, + false, + ) + .await; + + assert!(outcome.is_ok()); + } + + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 0, + "prefetched lock path must not re-query lock registry per frame" + ); + assert_eq!(stats.get_user_total_octets(&user), 8); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8); +} + +#[tokio::test] +async fn control_without_prefetched_lock_still_uses_registry_lookup_path() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("middle-cross-mode-lookup-control-{}", std::process::id()); + + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let outcome = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xCD]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + None, + &bytes_me2c, + 20_100, + false, + false, + ) + .await; + + assert!(outcome.is_ok()); + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 1, + "fallback path without prefetched lock should perform a registry lookup" + ); +} diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs new file mode 100644 index 0000000..bc7c857 --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs @@ -0,0 +1,376 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-positive-{}", std::process::id()); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(128), + 0, + &bytes_me2c, + 10_001, + false, + false, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(&user), 4); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); +} + +#[tokio::test] +async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-negative-{}", std::process::id()); + let held = cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock before ME->C call"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(256), + 0, + &bytes_me2c, + 10_002, + false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + drop(held_guard); +} + +#[tokio::test] +async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-edge-none-{}", std::process::id()); + let held = cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock while quota is disabled"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let outcome = timeout( + Duration::from_millis(80), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x11, 0x22]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + None, + 0, + &bytes_me2c, + 10_003, + false, + false, + ), + ) + .await + .expect("quota-none path must not wait on cross-mode lock"); + + assert!(outcome.is_ok()); + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-matrix-adversarial-{}", std::process::id()); + let limit = 64u64; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = Vec::new(); + + for idx in 0..256u64 { + let stats = Arc::clone(&stats); + let bytes_me2c = Arc::clone(&bytes_me2c); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xEE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(limit), + 0, + bytes_me2c.as_ref(), + 11_000 + idx, + false, + false, + ) + .await + })); + } + + let mut ok = 0usize; + for task in tasks { + match task.await.expect("task must not panic") { + Ok(_) => ok += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"), + } + } + + assert_eq!(ok, limit as usize); + assert_eq!(stats.get_user_total_octets(&user), limit); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() { + let user = format!("middle-cross-matrix-integration-{}", std::process::id()); + let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user); + let middle_lock = cross_mode_quota_user_lock_for_tests(&user); + assert!( + Arc::ptr_eq(&relay_lock, &middle_lock), + "relay and middle-relay must share the same cross-mode lock identity" + ); + + let held_guard = relay_lock + .try_lock() + .expect("test must hold shared cross-mode lock"); + + let stats = Stats::new(); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let middle_blocked = timeout( + Duration::from_millis(25), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x92]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 12_001, + false, + false, + ), + ) + .await; + assert!(middle_blocked.is_err()); + + drop(held_guard); + + let middle_ready = timeout( + Duration::from_millis(250), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x94]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 12_002, + false, + false, + ), + ) + .await + .expect("middle path must complete after release"); + + assert!(middle_ready.is_ok()); +} + +#[tokio::test] +async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-fuzz-{}", std::process::id()); + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xC0DE_1234_55AA_9988u64; + + for case in 0..96u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold = (seed & 0x03) == 0; + let mut held_lock = None; + let maybe_guard = if hold { + held_lock = Some(cross_mode_quota_user_lock_for_tests(&user)); + Some( + held_lock + .as_ref() + .expect("held lock should be present") + .try_lock() + .expect("cross-mode lock should be acquirable in fuzz round"), + ) + } else { + None + }; + + let payload_len = ((seed >> 8) as usize % 8) + 1; + let payload = vec![(seed & 0xff) as u8; payload_len]; + let before = stats.get_user_total_octets(&user); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let timed = timeout( + Duration::from_millis(20), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 13_000 + case as u64, + false, + false, + ), + ) + .await; + + if hold { + assert!(timed.is_err(), "held-lock fuzz round must block within timeout"); + assert_eq!(stats.get_user_total_octets(&user), before); + } else { + let done = timed.expect("unheld fuzz round must complete in time"); + assert!(done.is_ok()); + } + + drop(maybe_guard); + drop(held_lock); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), stats.get_user_total_octets(&user)); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() { + let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id()); + let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id()); + + let held = cross_mode_quota_user_lock_for_tests(&held_user); + let held_guard = held + .try_lock() + .expect("test must hold lock for blocked user"); + + let mut tasks = Vec::new(); + for idx in 0..64u64 { + let user = free_user.clone(); + tasks.push(tokio::spawn(async move { + let stats = Stats::new(); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA0]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1), + 0, + &bytes_me2c, + 14_000 + idx, + false, + false, + ) + .await + })); + } + + timeout(Duration::from_secs(2), async { + for task in tasks { + let done = task.await.expect("free-user task must not panic"); + assert!(done.is_ok()); + } + }) + .await + .expect("free-user tasks should complete without waiting for held user's lock"); + + drop(held_guard); +} diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs new file mode 100644 index 0000000..51092bd --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs @@ -0,0 +1,254 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use tokio::io::AsyncWrite; +use tokio::sync::Notify; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[derive(Default)] +struct BlockingWriteState { + write_entered: AtomicBool, + released: AtomicBool, + write_waker: Mutex>, + write_entered_notify: Notify, +} + +struct BlockingWrite { + state: Arc, +} + +impl BlockingWrite { + fn new(state: Arc) -> Self { + Self { state } + } +} + +impl AsyncWrite for BlockingWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.state.write_entered.store(true, Ordering::Release); + self.state.write_entered_notify.notify_waiters(); + + if self.state.released.load(Ordering::Acquire) { + return Poll::Ready(Ok(buf.len())); + } + + if let Ok(mut slot) = self.state.write_waker.lock() { + *slot = Some(cx.waker().clone()); + } + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +async fn wait_until_blocking_write_entered(state: &Arc) { + for _ in 0..8 { + if state.write_entered.load(Ordering::Acquire) { + return; + } + let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; + } + + panic!("blocking writer did not enter poll_write in bounded time"); +} + +fn release_blocking_write(state: &Arc) { + state.released.store(true, Ordering::Release); + if let Ok(mut slot) = state.write_waker.lock() + && let Some(waker) = slot.take() + { + waker.wake(); + } +} + +#[tokio::test] +async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() { + let stats = Stats::new(); + let user = format!("middle-me2c-cross-mode-held-{}", std::process::id()); + let held = cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock before ME->C write path"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 9901, + false, + false, + ), + ) + .await; + + assert!( + blocked.is_err(), + "ME->C quota reservation path must be serialized by held shared cross-mode lock" + ); + + drop(held_guard); + + let released = timeout( + Duration::from_millis(250), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x42]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 9902, + false, + false, + ), + ) + .await + .expect("ME->C write must complete after cross-mode lock release"); + + assert!(released.is_ok()); +} + +#[tokio::test] +async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() { + let stats = Stats::new(); + let user = format!("middle-me2c-cross-mode-free-{}", std::process::id()); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let outcome = timeout( + Duration::from_millis(250), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x55, 0x66]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 9903, + false, + false, + ), + ) + .await + .expect("uncontended ME->C path should not stall"); + + assert!(outcome.is_ok()); + assert_eq!(stats.get_user_total_octets(&user), 2); + assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id()); + let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let writer_state = Arc::new(BlockingWriteState::default()); + + let worker = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + let writer_state = Arc::clone(&writer_state); + tokio::spawn(async move { + let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); + let mut frame_buf = Vec::new(); + let rng = SecureRandom::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + stats.as_ref(), + &user, + Some(1024), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 9910, + false, + false, + ) + .await + }) + }; + + wait_until_blocking_write_entered(&writer_state).await; + + let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) + .await + .expect("cross-mode lock must be free while ME->C write is pending"); + drop(acquired_guard); + + release_blocking_write(&writer_state); + + let result = timeout(Duration::from_millis(300), worker) + .await + .expect("ME->C worker timed out after releasing blocking writer") + .expect("ME->C worker must not panic"); + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(&user), 4); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); +} diff --git a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs index 3d7929b..3ce0235 100644 --- a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs +++ b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs @@ -128,6 +128,7 @@ async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() &stats, user, quota_limit, + 0, &bytes_me2c, 7001, false, @@ -167,6 +168,7 @@ async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() &stats_fast, user, quota_limit, + 0, &bytes_fast, 7002, false, @@ -208,6 +210,7 @@ async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_byt &stats, user, Some(64), + 0, &bytes_me2c, 7003, false, diff --git a/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..29384e0 --- /dev/null +++ b/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs @@ -0,0 +1,372 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, OnceLock, Mutex}; +use tokio::sync::Mutex as AsyncMutex; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +fn lookup_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test] +async fn positive_me2c_quota_counts_bytes_exactly_once() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-positive-{}", std::process::id()); + let lock = Arc::new(AsyncMutex::new(())); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3, 4, 5]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(64), + 0, + Some(&lock), + &bytes_me2c, + 70_001, + false, + false, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(&user), 5); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); +} + +#[tokio::test] +async fn negative_held_crossmode_lock_blocks_me2c_write() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-negative-{}", std::process::id()); + + let lock = Arc::new(AsyncMutex::new(())); + let _held = lock.try_lock().expect("lock must be held"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xFE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(16), + 0, + Some(&lock), + &bytes_me2c, + 70_101, + false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn edge_zero_quota_zero_payload_is_fail_closed() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-edge-{}", std::process::id()); + + let lock = Arc::new(AsyncMutex::new(())); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(0), + 0, + Some(&lock), + &bytes_me2c, + 70_201, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(&user), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Arc::new(Stats::new()); + let user = format!("quota-middle-ext-blackhat-{}", std::process::id()); + let quota = 64u64; + let lock = Arc::new(AsyncMutex::new(())); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + + let mut set = JoinSet::new(); + for i in 0..256u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let lock = Arc::clone(&lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize]; + + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(quota), + 0, + Some(&lock), + bytes_me2c.as_ref(), + 70_301 + i, + false, + false, + ) + .await + }); + } + + let mut succeeded = 0usize; + while let Some(done) = set.join_next().await { + match done.expect("task must not panic") { + Ok(_) => succeeded += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error {other:?}"), + } + } + + assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed)); + assert!(stats.get_user_total_octets(&user) <= quota); + assert!(succeeded <= quota as usize); +} + +#[tokio::test] +async fn integration_shared_prefetched_lock_blocks_then_releases_writer() { + let stats = Stats::new(); + let user = format!("quota-middle-ext-integration-{}", std::process::id()); + let lock = Arc::new(AsyncMutex::new(())); + let held = lock + .try_lock() + .expect("integration test must hold prefetched lock first"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA1]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(8), + 0, + Some(&lock), + &bytes_me2c, + 70_360, + false, + false, + ), + ) + .await; + assert!(blocked.is_err()); + + drop(held); + + let after_release = timeout( + Duration::from_millis(150), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA2]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(8), + 0, + Some(&lock), + &bytes_me2c, + 70_361, + false, + false, + ), + ) + .await + .expect("writer should progress once the shared lock is released"); + + assert!(after_release.is_ok()); +} + +#[tokio::test] +async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-fuzz-{}", std::process::id()); + let mut seed = 0xCAFE_BABE_1234u64; + let bytes_me2c = AtomicU64::new(0); + + for case in 0..48u32 { + seed ^= seed << 5; + seed ^= seed >> 12; + seed ^= seed << 13; + let hold = (seed & 0x1) == 0; + + let lock = Arc::new(AsyncMutex::new(())); + let maybe_guard = if hold { + Some(lock.try_lock().unwrap()) + } else { + None + }; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let result = timeout( + Duration::from_millis(30), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(128), + 0, + Some(&lock), + &bytes_me2c, + 70_401 + case as u64, + false, + false, + ), + ) + .await; + + if hold { + assert!(result.is_err()); + } else { + assert!(result.unwrap().is_ok()); + } + + drop(maybe_guard); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() { + let _guard = lookup_test_lock().lock().unwrap(); + let held = Arc::new(AsyncMutex::new(())); + let _held_guard = held.try_lock().unwrap(); + + let mut set = JoinSet::new(); + for i in 0..48u64 { + set.spawn(async move { + let stats = Stats::new(); + let user = format!("quota-middle-ext-stress-free-{i}"); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let free_lock = Arc::new(AsyncMutex::new(())); + + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xEE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1), + 0, + Some(&free_lock), + &bytes_me2c, + 70_500 + i, + false, + false, + ) + .await + }); + } + + timeout(Duration::from_secs(2), async { + while let Some(task) = set.join_next().await { + task.unwrap().unwrap(); + } + }) + .await + .unwrap(); +} diff --git a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs index 717a375..963b3e0 100644 --- a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs +++ b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs @@ -5,6 +5,8 @@ use crate::stream::CryptoWriter; use bytes::Bytes; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; use tokio::task::JoinSet; fn make_crypto_writer(writer: W) -> CryptoWriter @@ -16,6 +18,77 @@ where CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +struct FailingWriter; + +impl AsyncWrite for FailingWriter { + fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Err(std::io::Error::other("forced writer failure"))) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct FailAfterBudgetWriter { + remaining: usize, + written: usize, +} + +impl FailAfterBudgetWriter { + fn new(remaining: usize) -> Self { + Self { + remaining, + written: 0, + } + } +} + +impl AsyncWrite for FailAfterBudgetWriter { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.remaining == 0 { + return Poll::Ready(Err(std::io::Error::other("forced short-write exhaustion"))); + } + + let n = self.remaining.min(buf.len()); + self.remaining -= n; + self.written += n; + Poll::Ready(Ok(n)) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + #[tokio::test] async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { let stats = Stats::new(); @@ -38,6 +111,7 @@ async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { &stats, user, Some(8), + 0, &bytes_me2c, 7101, false, @@ -62,6 +136,7 @@ async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { &stats, user, Some(8), + 0, &bytes_me2c, 7102, false, @@ -105,6 +180,7 @@ async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_count stats_ref.as_ref(), &user_owned, Some(quota_limit), + 0, bytes_ref.as_ref(), 7200 + idx, false, @@ -171,6 +247,7 @@ async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() &stats, user, Some(quota_limit), + 0, &bytes_me2c, 7300 + conn, false, @@ -189,4 +266,801 @@ async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() let total = stats.get_user_total_octets(user); assert!(total <= quota_limit); assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn positive_soft_overshoot_allows_burst_inside_soft_cap_then_blocks() { + let stats = Stats::new(); + let user = "soft-cap-boundary-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 10u64; + let overshoot = 3u64; + + stats.add_user_octets_from(user, 10); + + let mut writer_one = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_one = Vec::new(); + let first = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer_one, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_one, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7401, + false, + false, + ) + .await; + assert!(first.is_ok(), "soft-cap buffer should allow reaching limit+overshoot"); + assert_eq!(stats.get_user_total_octets(user), 13); + + let mut writer_two = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_two = Vec::new(); + let second = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[9]), + }, + &mut writer_two, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7402, + false, + false, + ) + .await; + assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 13); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); +} + +#[tokio::test] +async fn negative_soft_overshoot_rejects_when_payload_exceeds_remaining_soft_budget() { + let stats = Stats::new(); + let user = "soft-cap-remaining-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 10u64; + let overshoot = 4u64; + + stats.add_user_octets_from(user, 12); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7501, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 12); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn negative_write_failure_rolls_back_reservation_under_soft_cap_mode() { + let stats = Stats::new(); + let user = "soft-cap-rollback-user"; + let bytes_me2c = AtomicU64::new(0); + let mut writer = make_crypto_writer(FailingWriter); + let mut frame_buf = Vec::new(); + + stats.add_user_octets_from(user, 9); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(10), + 8, + &bytes_me2c, + 7601, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Io(_)))); + assert_eq!(stats.get_user_total_octets(user), 9); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_soft_cap_stress_never_exceeds_soft_limit() { + let stats = Arc::new(Stats::new()); + let user = "soft-cap-stress-user"; + let quota_limit = 40u64; + let overshoot = 5u64; + let soft_limit = quota_limit + overshoot; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x42]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(quota_limit), + overshoot, + bytes_ref.as_ref(), + 7700 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + match joined.expect("soft-cap stress task must not panic") { + Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error in soft-cap stress case: {other:?}"), + } + } + + let total = stats.get_user_total_octets(user); + assert!(total <= soft_limit, "soft-cap stress must never overshoot soft limit"); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn light_fuzz_soft_cap_matrix_keeps_counters_and_limits_consistent() { + let stats = Stats::new(); + let user = "soft-cap-fuzz-user"; + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0x9E37_79B9_7F4A_7C15u64; + + for conn in 0..1024u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let quota_limit = 32 + (seed & 0x3f); + let overshoot = seed.rotate_left(13) & 0x0f; + let len = ((seed >> 3) & 0x07) + 1; + let payload = vec![0xA5; len as usize]; + let before = stats.get_user_total_octets(user); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7800 + conn, + false, + false, + ) + .await; + + if let Err(ref err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. }), + "soft-cap fuzz produced unexpected error variant: {err:?}" + ); + } + + let after = stats.get_user_total_octets(user); + let soft_limit = quota_limit.saturating_add(overshoot); + match result { + Ok(_) => { + assert_eq!(after, before.saturating_add(len)); + assert!(after <= soft_limit, "accepted write must stay within active soft cap"); + } + Err(_) => { + assert_eq!(after, before, "rejected write must not mutate quota state"); + } + } + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + after, + "soft-cap fuzz must keep counters synchronized" + ); + } +} + +#[tokio::test] +async fn positive_no_quota_limit_accumulates_data_octets_exactly() { + let stats = Stats::new(); + let user = "no-quota-user"; + let bytes_me2c = AtomicU64::new(0); + let mut expected = 0u64; + + for (idx, len) in [1usize, 2, 3, 5, 8, 13, 21].iter().copied().enumerate() { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let payload = vec![0x41; len]; + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + None, + 0, + &bytes_me2c, + 7900 + idx as u64, + false, + false, + ) + .await; + + assert!(result.is_ok()); + expected += len as u64; + } + + assert_eq!(stats.get_user_total_octets(user), expected); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), expected); +} + +#[tokio::test] +async fn negative_zero_quota_rejects_non_empty_payload() { + let stats = Stats::new(); + let user = "zero-quota-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(0), + 0, + &bytes_me2c, + 8001, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn edge_zero_length_payload_with_zero_quota_is_fail_closed() { + let stats = Stats::new(); + let user = "zero-len-zero-quota-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(0), + 0, + &bytes_me2c, + 8002, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn positive_ack_response_does_not_touch_quota_counters() { + let stats = Stats::new(); + let user = "ack-accounting-user"; + let bytes_me2c = AtomicU64::new(11); + stats.add_user_octets_to(user, 23); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Ack(0x33445566), + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(24), + 0, + &bytes_me2c, + 8003, + true, + true, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(user), 23); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 11); +} + +#[tokio::test] +async fn edge_close_response_is_accounting_noop() { + let stats = Stats::new(); + let user = "close-accounting-user"; + let bytes_me2c = AtomicU64::new(19); + stats.add_user_octets_to(user, 31); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Close, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(40), + 3, + &bytes_me2c, + 8004, + false, + true, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(stats.get_user_total_octets(user), 31); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 19); +} + +#[tokio::test] +async fn negative_preloaded_above_soft_cap_rejects_even_single_byte() { + let stats = Stats::new(); + let user = "preloaded-over-soft-cap-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 20u64; + let overshoot = 2u64; + stats.add_user_octets_to(user, quota_limit + overshoot + 1); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 8005, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); + assert_eq!(stats.get_user_total_octets(user), quota_limit + overshoot + 1); +} + +#[tokio::test] +async fn adversarial_fail_writer_path_never_desynchronizes_quota_accounting() { + let stats = Stats::new(); + let user = "partial-write-rollback-user"; + let bytes_me2c = AtomicU64::new(0); + let mut writer = make_crypto_writer(FailAfterBudgetWriter::new(7)); + let mut frame_buf = Vec::new(); + let payload_len = 16 * 1024u64; + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![0x42; 16 * 1024]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(payload_len), + 0, + &bytes_me2c, + 8006, + false, + false, + ) + .await; + + let total_after = stats.get_user_total_octets(user); + let forensic_after = bytes_me2c.load(Ordering::Relaxed); + assert_eq!(forensic_after, total_after); + assert!( + total_after == 0 || total_after == payload_len, + "writer failure path must either roll back fully or commit exactly one payload" + ); + + // Regardless of whether I/O failure surfaced immediately or was deferred, + // accounting must remain fail-closed and prevent silent overshoot. + let mut writer_two = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_two = Vec::new(); + let second = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x99]), + }, + &mut writer_two, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(payload_len), + 0, + &bytes_me2c, + 8007, + false, + false, + ) + .await; + + if total_after == payload_len { + assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); + } else { + assert!(second.is_ok()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_oversized_frames_fail_closed_without_counter_leak() { + let stats = Arc::new(Stats::new()); + let user = "parallel-fail-rollback-user"; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![0xEE; 12 * 1024]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(512), + 0, + bytes_ref.as_ref(), + 8100 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + let result = joined.expect("parallel fail writer task must not panic"); + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + } + + assert_eq!(stats.get_user_total_octets(user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn integration_mixed_data_ack_close_sequence_preserves_data_only_accounting() { + let stats = Stats::new(); + let user = "mixed-sequence-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let data_one = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8201, + false, + false, + ) + .await; + assert!(data_one.is_ok()); + + let ack = process_me_writer_response( + MeResponse::Ack(0x0102_0304), + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8202, + true, + true, + ) + .await; + assert!(ack.is_ok()); + + let data_two = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[4, 5]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8203, + false, + true, + ) + .await; + assert!(data_two.is_ok()); + + let close = process_me_writer_response( + MeResponse::Close, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8204, + false, + true, + ) + .await; + assert!(close.is_ok()); + + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_multi_user_quota_isolation_no_cross_user_leakage() { + let stats = Arc::new(Stats::new()); + let user_a = "quota-isolation-a"; + let user_b = "quota-isolation-b"; + let limit_a = 50u64; + let limit_b = 80u64; + let bytes_a = Arc::new(AtomicU64::new(0)); + let bytes_b = Arc::new(AtomicU64::new(0)); + + let mut tasks = JoinSet::new(); + for idx in 0..200u64 { + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_a); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA1]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + user_a, + Some(limit_a), + 0, + bytes_ref.as_ref(), + 8300 + idx, + false, + false, + ) + .await + }); + } + + for idx in 0..220u64 { + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_b); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xB2]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + user_b, + Some(limit_b), + 0, + bytes_ref.as_ref(), + 8500 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + let result = joined.expect("quota isolation task must not panic"); + assert!(result.is_ok() || matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + } + + assert_eq!(stats.get_user_total_octets(user_a), limit_a); + assert_eq!(stats.get_user_total_octets(user_b), limit_b); + assert_eq!(bytes_a.load(Ordering::Relaxed), limit_a); + assert_eq!(bytes_b.load(Ordering::Relaxed), limit_b); +} + +#[tokio::test] +async fn light_fuzz_mixed_me_responses_preserve_quota_and_counter_invariants() { + let stats = Stats::new(); + let user = "mixed-fuzz-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 96u64; + let mut seed = 0xDEAD_BEEF_2026_0323u64; + + for idx in 0..2048u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let choice = (seed & 0x03) as u8; + let response = if choice == 0 { + MeResponse::Ack((seed >> 8) as u32) + } else if choice == 1 { + MeResponse::Close + } else { + let len = ((seed >> 16) & 0x07) as usize; + let mut payload = vec![0u8; len]; + payload.fill((seed & 0xff) as u8); + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + } + }; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + response, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + 0, + &bytes_me2c, + 8800 + idx, + (idx & 1) == 0, + (idx & 2) == 0, + ) + .await; + + if let Err(err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. }), + "mixed fuzz produced unexpected error variant: {err:?}" + ); + } + + let total = stats.get_user_total_octets(user); + assert!( + total <= quota_limit, + "mixed fuzz must keep usage at or below quota limit" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); + } } \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs new file mode 100644 index 0000000..e4d0c6e --- /dev/null +++ b/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs @@ -0,0 +1,399 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; +use tokio::sync::Mutex as AsyncMutex; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +fn lookup_counter_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test] +async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-positive-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + for idx in 0..12u64 { + let payload = vec![0x5A; ((idx % 4) + 1) as usize]; + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(512), + 0, + Some(&lock), + &bytes_me2c, + 31_000 + idx, + false, + false, + ) + .await; + + assert!(result.is_ok()); + } + + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 0, + "prefetched lock path must avoid hot-path registry lookups" + ); + assert_eq!( + stats.get_user_total_octets(&user), + bytes_me2c.load(Ordering::Relaxed), + "forensics and quota accounting must remain synchronized" + ); +} + +#[tokio::test] +async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-negative-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold lock before calling ME->C writer"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(64), + 0, + Some(&lock), + &bytes_me2c, + 31_100, + false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); + + drop(held_guard); +} + +#[tokio::test] +async fn edge_zero_quota_and_zero_payload_is_fail_closed() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-edge-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(0), + 0, + Some(&lock), + &bytes_me2c, + 31_200, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Arc::new(Stats::new()); + let user = format!("quota-extreme-blackhat-{}", std::process::id()); + let quota = 80u64; + let overshoot = 7u64; + let soft_limit = quota + overshoot; + let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + + let mut set = JoinSet::new(); + for idx in 0..256u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let lock = Arc::clone(&lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let len = ((idx % 5) + 1) as usize; + let payload = vec![0xAA; len]; + + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(quota), + overshoot, + Some(&lock), + bytes_me2c.as_ref(), + 31_300 + idx, + false, + false, + ) + .await + }); + } + + while let Some(done) = set.join_next().await { + match done.expect("task must not panic") { + Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"), + } + } + + let total = stats.get_user_total_octets(&user); + assert!( + total <= soft_limit, + "parallel adversarial race must stay under soft cap" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn integration_without_prefetched_lock_uses_registry_lookup_path() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-integration-{}", std::process::id()); + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + for idx in 0..3u64 { + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(16), + 0, + None, + &bytes_me2c, + 31_400 + idx, + false, + false, + ) + .await; + + assert!(result.is_ok()); + } + + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 3, + "control path should perform one lock-registry lookup per call" + ); +} + +#[tokio::test] +async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-fuzz-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xA11C_55EE_2026_0323u64; + + for idx in 0..512u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let quota = 24 + (seed & 0x3f); + let overshoot = (seed >> 13) & 0x0f; + let len = ((seed >> 19) & 0x07) + 1; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let before = stats.get_user_total_octets(&user); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![0x11; len as usize]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(quota), + overshoot, + Some(&lock), + &bytes_me2c, + 31_500 + idx, + false, + false, + ) + .await; + + let after = stats.get_user_total_octets(&user); + if result.is_ok() { + assert!(after >= before); + } else { + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(after, before); + } + assert_eq!(bytes_me2c.load(Ordering::Relaxed), after); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Arc::new(Stats::new()); + let user = format!("quota-extreme-stress-{}", std::process::id()); + let quota = 96u64; + let lock: Arc> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut set = JoinSet::new(); + for idx in 0..384u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let lock = Arc::clone(&lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xFF]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(quota), + 0, + Some(&lock), + bytes_me2c.as_ref(), + 31_600 + idx, + false, + false, + ) + .await + }); + } + + let mut success = 0usize; + while let Some(done) = set.join_next().await { + match done.expect("task must not panic") { + Ok(_) => success += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"), + } + } + + assert_eq!(success, quota as usize); + assert_eq!(stats.get_user_total_octets(&user), quota); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota); + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 0, + "stress prefetched path must not use lock registry lookups" + ); +} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs index 1bf3123..34fc454 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs @@ -7,7 +7,7 @@ use std::sync::atomic::AtomicU64; use std::time::Instant; use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; use tokio::task::JoinSet; -use tokio::time::{Duration as TokioDuration, sleep, timeout}; +use tokio::time::{Duration as TokioDuration, sleep}; fn make_crypto_reader(reader: T) -> CryptoReader where @@ -42,10 +42,10 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { fn make_enabled_idle_policy() -> RelayClientIdlePolicy { RelayClientIdlePolicy { enabled: true, - soft_idle: Duration::from_secs(30), - hard_idle: Duration::from_secs(60), + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), grace_after_downstream_activity: Duration::from_secs(0), - legacy_frame_read_timeout: Duration::from_secs(30), + legacy_frame_read_timeout: Duration::from_millis(50), } } @@ -94,8 +94,8 @@ async fn stress_parallel_pure_tiny_floods_all_fail_closed() { writer.write_all(&flood_encrypted).await.unwrap(); drop(writer); - let result = timeout( - TokioDuration::from_secs(1), + let result = run_relay_test_step_timeout( + "tiny flood task", read_once( &mut crypto_reader, ProtoTag::Abridged, @@ -104,8 +104,7 @@ async fn stress_parallel_pure_tiny_floods_all_fail_closed() { &mut idle_state, ), ) - .await - .expect("tiny flood task must complete"); + .await; assert!(matches!(result, Err(ProxyError::Proxy(_)))); assert_eq!(frame_counter, 0); @@ -140,8 +139,8 @@ async fn stress_parallel_benign_tiny_burst_then_real_all_pass() { let encrypted = encrypt_for_reader(&plaintext); writer.write_all(&encrypted).await.unwrap(); - let result = timeout( - TokioDuration::from_secs(1), + let result = run_relay_test_step_timeout( + "benign tiny burst read", read_once( &mut crypto_reader, ProtoTag::Abridged, @@ -151,7 +150,6 @@ async fn stress_parallel_benign_tiny_burst_then_real_all_pass() { ), ) .await - .expect("benign task must complete") .expect("benign payload must parse") .expect("benign payload must return frame"); @@ -196,8 +194,8 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { let mut closed = false; for _ in 0..220 { - let result = timeout( - TokioDuration::from_secs(1), + let result = run_relay_test_step_timeout( + "alternating jitter read step", read_once( &mut crypto_reader, ProtoTag::Abridged, @@ -206,8 +204,7 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { &mut idle_state, ), ) - .await - .expect("alternating reader step must complete"); + .await; match result { Ok(Some((_payload, _))) => {} @@ -336,8 +333,8 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() { drop(writer); for _ in 0..320 { - let step = timeout( - TokioDuration::from_secs(1), + let step = run_relay_test_step_timeout( + "fuzz case read step", read_once( &mut crypto_reader, ProtoTag::Abridged, @@ -346,8 +343,7 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() { &mut idle_state, ), ) - .await - .expect("fuzz case read step must complete"); + .await; match step { Ok(Some(_)) => {} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs index 0ff46a2..853b381 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use std::sync::atomic::AtomicU64; use std::time::Instant; use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; -use tokio::time::{Duration as TokioDuration, sleep, timeout}; +use tokio::time::{Duration as TokioDuration, sleep}; fn make_crypto_reader(reader: T) -> CryptoReader where @@ -41,10 +41,10 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { fn make_enabled_idle_policy() -> RelayClientIdlePolicy { RelayClientIdlePolicy { enabled: true, - soft_idle: Duration::from_secs(30), - hard_idle: Duration::from_secs(60), + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), grace_after_downstream_activity: Duration::from_secs(0), - legacy_frame_read_timeout: Duration::from_secs(30), + legacy_frame_read_timeout: Duration::from_millis(50), } } @@ -117,6 +117,11 @@ async fn read_once_with_state( .await } +fn is_fail_closed_outcome(result: &Result>) -> bool { + matches!(result, Err(ProxyError::Proxy(_))) + || matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut) +} + #[tokio::test] async fn intermediate_chunked_zero_flood_fail_closed() { let (reader, mut writer) = duplex(4096); @@ -134,8 +139,8 @@ async fn intermediate_chunked_zero_flood_fail_closed() { write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await; drop(writer); - let result = timeout( - TokioDuration::from_secs(2), + let result = run_relay_test_step_timeout( + "intermediate flood read", read_once_with_state( &mut crypto_reader, ProtoTag::Intermediate, @@ -144,10 +149,12 @@ async fn intermediate_chunked_zero_flood_fail_closed() { &mut idle_state, ), ) - .await - .expect("intermediate flood read must complete"); + .await; - assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert!( + is_fail_closed_outcome(&result), + "zero-length flood must fail closed via debt guard or idle timeout" + ); assert_eq!(frame_counter, 0); } @@ -168,8 +175,8 @@ async fn secure_chunked_zero_flood_fail_closed() { write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await; drop(writer); - let result = timeout( - TokioDuration::from_secs(2), + let result = run_relay_test_step_timeout( + "secure flood read", read_once_with_state( &mut crypto_reader, ProtoTag::Secure, @@ -178,10 +185,12 @@ async fn secure_chunked_zero_flood_fail_closed() { &mut idle_state, ), ) - .await - .expect("secure flood read must complete"); + .await; - assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert!( + is_fail_closed_outcome(&result), + "secure zero-length flood must fail closed via debt guard or idle timeout" + ); assert_eq!(frame_counter, 0); } @@ -208,8 +217,8 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() { let mut closed = false; for _ in 0..240 { - let step = timeout( - TokioDuration::from_secs(1), + let step = run_relay_test_step_timeout( + "intermediate alternating read step", read_once_with_state( &mut crypto_reader, ProtoTag::Intermediate, @@ -218,8 +227,7 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() { &mut idle_state, ), ) - .await - .expect("intermediate alternating read step must complete"); + .await; match step { Ok(Some(_)) => {} @@ -259,8 +267,8 @@ async fn secure_chunked_alternating_attack_closes_before_eof() { let mut closed = false; for _ in 0..240 { - let step = timeout( - TokioDuration::from_secs(1), + let step = run_relay_test_step_timeout( + "secure alternating read step", read_once_with_state( &mut crypto_reader, ProtoTag::Secure, @@ -269,8 +277,7 @@ async fn secure_chunked_alternating_attack_closes_before_eof() { &mut idle_state, ), ) - .await - .expect("secure alternating read step must complete"); + .await; match step { Ok(Some(_)) => {} @@ -394,8 +401,8 @@ async fn light_fuzz_proto_chunking_outcomes_are_bounded() { drop(writer); for _ in 0..260 { - let step = timeout( - TokioDuration::from_secs(1), + let step = run_relay_test_step_timeout( + "fuzz proto read step", read_once_with_state( &mut crypto_reader, proto, @@ -404,12 +411,12 @@ async fn light_fuzz_proto_chunking_outcomes_are_bounded() { &mut idle_state, ), ) - .await - .expect("fuzz proto read step must complete"); + .await; match step { Ok(Some((_payload, _))) => {} Err(ProxyError::Proxy(_)) => break, + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break, Ok(None) => break, Err(other) => panic!("unexpected proto chunking fuzz error: {other}"), } diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs index d0719c8..dee5dd9 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs @@ -40,13 +40,44 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { fn make_enabled_idle_policy() -> RelayClientIdlePolicy { RelayClientIdlePolicy { enabled: true, - soft_idle: Duration::from_secs(30), - hard_idle: Duration::from_secs(60), + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), grace_after_downstream_activity: Duration::from_secs(0), - legacy_frame_read_timeout: Duration::from_secs(30), + legacy_frame_read_timeout: Duration::from_millis(50), } } +async fn read_bounded( + crypto_reader: &mut CryptoReader, + proto_tag: ProtoTag, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, +) -> Result> { + run_relay_test_step_timeout( + "tiny-frame debt read step", + read_client_payload_with_idle_policy( + crypto_reader, + proto_tag, + 1024, + buffer_pool, + forensics, + frame_counter, + stats, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + ), + ) + .await +} + fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option, u32, usize) { let mut debt = 0u32; let mut reals = 0usize; @@ -246,10 +277,9 @@ async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() { writer.write_all(&flood_encrypted).await.unwrap(); drop(writer); - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Intermediate, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -282,10 +312,9 @@ async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() { writer.write_all(&flood_encrypted).await.unwrap(); drop(writer); - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Secure, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -325,10 +354,9 @@ async fn intermediate_alternating_zero_and_real_eventually_closes() { let mut closed = false; for _ in 0..220 { - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Intermediate, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -377,10 +405,9 @@ async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() { let encrypted = encrypt_for_reader(&plaintext); writer.write_all(&encrypted).await.unwrap(); - let first = read_client_payload_with_idle_policy( + let first = read_bounded( &mut crypto_reader, ProtoTag::Abridged, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -420,10 +447,9 @@ async fn idle_policy_enabled_zero_length_flood_is_fail_closed() { .expect("zero-length flood bytes must be writable"); drop(writer); - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Abridged, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -470,10 +496,9 @@ async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() { let mut saw_proxy_close = false; for _ in 0..300 { - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Abridged, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -527,10 +552,9 @@ async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { .await .expect("nonzero frame must be writable"); - let result = read_client_payload_with_idle_policy( + let result = read_bounded( &mut crypto_reader, ProtoTag::Abridged, - 1024, &buffer_pool, &forensics, &mut frame_counter, @@ -548,3 +572,227 @@ async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { assert!(!result.1); assert_eq!(frame_counter, 1); } + +#[tokio::test] +async fn abridged_quickack_tiny_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(21, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0x80u8; 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "quickack-marked zero-length flood must fail closed" + ); +} + +#[tokio::test] +async fn abridged_extended_zero_len_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(22, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut flood_plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]); + } + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "extended zero-length abridged flood must fail closed" + ); +} + +#[tokio::test] +async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() { + let mut plaintext = Vec::with_capacity(9 * 300); + for idx in 0..300usize { + plaintext.push(0x00); + for _ in 0..8 { + let b = idx as u8; + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]); + } + } + + // Keep the test single-task and deterministic: make duplex capacity larger than the + // generated ciphertext so write_all cannot block waiting for a concurrent reader. + let duplex_capacity = plaintext.len().saturating_add(1024); + let (reader, mut writer) = duplex(duplex_capacity); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(23, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..3000 { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Err(other) => panic!("unexpected error in 1:8 wire test: {other}"), + } + } + + assert!( + !closed, + "wire-level 1:8 tiny-to-real pattern should not trigger debt close" + ); +} + +#[tokio::test] +async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() { + let mut seed = 0xD1CE_BAAD_2026_0322u64; + + for case_idx in 0..32u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let events = 300 + ((seed as usize) & 0xff); + let mut pattern = Vec::with_capacity(events); + let mut local = seed; + for _ in 0..events { + local ^= local << 7; + local ^= local >> 9; + local ^= local << 8; + pattern.push((local & 0x03) == 0); + } + + let mut plaintext = Vec::with_capacity(events * 6); + for (idx, tiny) in pattern.iter().copied().enumerate() { + if tiny { + plaintext.push(0x00); + } else { + let b = (idx as u8) ^ (case_idx as u8); + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]); + } + } + + let (reader, mut writer) = duplex(16 * 1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(500 + case_idx, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + let mut observed_close = false; + + for _ in 0..(events + 8) { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + observed_close = true; + break; + } + Err(other) => panic!("unexpected fuzz error: {other}"), + } + } + + assert_eq!( + observed_close, + expected_close.is_some(), + "wire parser behavior must match debt model for case {case_idx}" + ); + } +} diff --git a/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs new file mode 100644 index 0000000..9ea921c --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs @@ -0,0 +1,267 @@ +use super::relay_bidirectional; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() { + let _guard = quota_test_guard(); + + let user = format!("relay-pipeline-stall-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[0xA1]) + .await + .expect("server write should enqueue while relay is stalled"); + + let mut one = [0u8; 1]; + let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await; + assert!( + blocked_read.is_err(), + "same-user relay must remain blocked while cross-mode lock is held" + ); + + drop(held_guard); + + timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) + .await + .expect("blocked relay must resume after cross-mode lock release") + .expect("resumed relay must deliver queued byte"); + assert_eq!(one, [0xA1]); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must complete") + .expect("relay task must not panic"); + assert!(relay_result.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() { + let _guard = quota_test_guard(); + + let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id()); + let free_user = format!("relay-pipeline-free-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); + let held_guard = held + .try_lock() + .expect("test must hold blocked user's shared cross-mode lock"); + + let stats_blocked = Arc::new(Stats::new()); + let stats_free = Arc::new(Stats::new()); + + let (mut blocked_client, blocked_relay_client) = duplex(1024); + let (blocked_relay_server, mut blocked_server) = duplex(1024); + let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client); + let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server); + + let (mut free_client, free_relay_client) = duplex(1024); + let (free_relay_server, mut free_server) = duplex(1024); + let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client); + let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server); + + let blocked_task = { + let user = blocked_user.clone(); + let stats = Arc::clone(&stats_blocked); + tokio::spawn(async move { + relay_bidirectional( + blocked_client_reader, + blocked_client_writer, + blocked_server_reader, + blocked_server_writer, + 256, + 256, + &user, + stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }) + }; + + let free_task = { + let user = free_user.clone(); + let stats = Arc::clone(&stats_free); + tokio::spawn(async move { + relay_bidirectional( + free_client_reader, + free_client_writer, + free_server_reader, + free_server_writer, + 256, + 256, + &user, + stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }) + }; + + blocked_server + .write_all(&[0xB1]) + .await + .expect("blocked user server write should queue"); + free_server + .write_all(&[0xC1]) + .await + .expect("free user server write should queue"); + + let mut blocked_buf = [0u8; 1]; + let mut free_buf = [0u8; 1]; + + let blocked_stalled = timeout( + Duration::from_millis(40), + blocked_client.read_exact(&mut blocked_buf), + ) + .await; + assert!( + blocked_stalled.is_err(), + "blocked user must remain stalled while its lock is held" + ); + + timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf)) + .await + .expect("free user must make progress while other user is blocked") + .expect("free user read must succeed"); + assert_eq!(free_buf, [0xC1]); + + drop(held_guard); + + timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf)) + .await + .expect("blocked user must resume after release") + .expect("blocked user resumed read must succeed"); + assert_eq!(blocked_buf, [0xB1]); + + drop(blocked_client); + drop(blocked_server); + drop(free_client); + drop(free_server); + + assert!( + timeout(Duration::from_secs(1), blocked_task) + .await + .expect("blocked relay task must complete") + .expect("blocked relay task must not panic") + .is_ok() + ); + assert!( + timeout(Duration::from_secs(1), free_task) + .await + .expect("free relay task must complete") + .expect("free relay task must not panic") + .is_ok() + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() { + let _guard = quota_test_guard(); + + let mut seed = 0x5EED_C0DE_2026_0323u64; + for round in 0..24u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_ms = 2 + (seed % 10); + let user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock during fuzz round"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[0xD1]) + .await + .expect("server write should queue in fuzz round"); + + let mut one = [0u8; 1]; + let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await; + assert!(stalled.is_err(), "held phase must stall same-user relay"); + + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + drop(held_guard); + + timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) + .await + .expect("released phase must resume same-user relay") + .expect("released phase read must succeed"); + assert_eq!(one, [0xD1]); + + drop(client_peer); + drop(server_peer); + + assert!( + timeout(Duration::from_secs(1), relay_task) + .await + .expect("fuzz relay task must complete") + .expect("fuzz relay task must not panic") + .is_ok() + ); + } +} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs new file mode 100644 index 0000000..c967861 --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs @@ -0,0 +1,213 @@ +use super::relay_bidirectional; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::{Arc, Mutex}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::sync::{Barrier, watch}; +use tokio::time::{Duration, Instant, timeout}; + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn percentile_index(len: usize, percentile: usize) -> usize { + ((len * percentile) / 100).min(len.saturating_sub(1)) +} + +#[tokio::test] +async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() { + let _guard = quota_test_guard(); + + let rounds = 64usize; + let user = format!("relay-pipeline-latency-single-{}", std::process::id()); + let mut samples_ms = Vec::with_capacity(rounds); + + for round in 0..rounds { + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock before round"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(2048), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[(round as u8) ^ 0xA5]) + .await + .expect("server write should queue before release"); + + let release_at = Instant::now(); + drop(held_guard); + + let mut one = [0u8; 1]; + timeout(Duration::from_millis(450), client_peer.read_exact(&mut one)) + .await + .expect("client must receive queued byte after release") + .expect("queued byte read must succeed"); + samples_ms.push(release_at.elapsed().as_millis() as u64); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must complete") + .expect("relay task must not panic"); + assert!(relay_result.is_ok()); + } + + samples_ms.sort_unstable(); + let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; + let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; + + assert!( + p50_ms <= 45, + "single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}" + ); + assert!( + p95_ms <= 130, + "single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() { + let _guard = quota_test_guard(); + + let waiters = 128usize; + let user = format!("relay-pipeline-latency-fanout-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared lock before fanout release benchmark"); + + let ready_barrier = Arc::new(Barrier::new(waiters + 1)); + let release_at = Arc::new(Mutex::new(None::)); + let (release_tx, release_rx) = watch::channel(false); + let mut tasks = Vec::with_capacity(waiters); + + for idx in 0..waiters { + let user = user.clone(); + let barrier = Arc::clone(&ready_barrier); + let release_at = Arc::clone(&release_at); + let mut release_rx = release_rx.clone(); + + tasks.push(tokio::spawn(async move { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(512); + let (relay_server, mut server_peer) = duplex(512); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user; + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(2048), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[(idx as u8) ^ 0x5A]) + .await + .expect("fanout server write should queue before release"); + + barrier.wait().await; + release_rx + .changed() + .await + .expect("release signal should remain available"); + + let started = { + let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); + guard.expect("release timestamp must be populated before signal") + }; + + let mut one = [0u8; 1]; + timeout(Duration::from_millis(900), client_peer.read_exact(&mut one)) + .await + .expect("fanout waiter must receive queued byte after release") + .expect("fanout waiter read must succeed"); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("fanout relay task must complete") + .expect("fanout relay task must not panic"); + assert!(relay_result.is_ok()); + + started.elapsed().as_millis() as u64 + })); + } + + ready_barrier.wait().await; + { + let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); + *guard = Some(Instant::now()); + } + drop(held_guard); + release_tx + .send(true) + .expect("release broadcast must succeed"); + + let mut samples_ms = Vec::with_capacity(waiters); + timeout(Duration::from_secs(8), async { + for task in tasks { + let elapsed = task.await.expect("fanout waiter must not panic"); + samples_ms.push(elapsed); + } + }) + .await + .expect("fanout benchmark must complete in bounded time"); + + samples_ms.sort_unstable(); + let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; + let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; + let max_ms = *samples_ms.last().unwrap_or(&0); + + assert!( + p50_ms <= 120, + "fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); + assert!( + p95_ms <= 260, + "fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); + assert!( + max_ms <= 700, + "fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs index 87944ba..adbdb22 100644 --- a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs +++ b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs @@ -3,8 +3,9 @@ use crate::stats::Stats; use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; +use std::task::{Context, Poll, Waker}; use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::sync::Barrier; use tokio::time::{Duration, timeout}; #[derive(Default)] @@ -26,6 +27,13 @@ fn quota_test_guard() -> impl Drop { super::quota_user_lock_test_scope() } +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + #[tokio::test] async fn positive_cross_mode_uncontended_writer_progresses() { let _guard = quota_test_guard(); @@ -223,3 +231,374 @@ async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() { assert!(write_done.is_ok()); } } + +#[tokio::test] +async fn integration_middle_lock_blocks_relay_reader_for_same_user() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-middle-reader-block-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold middle-relay shared lock"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let mut one = [0u8; 1]; + let mut buf = ReadBuf::new(&mut one); + let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn integration_middle_lock_release_unblocks_relay_reader() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-middle-reader-release-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold middle-relay shared lock"); + + let task = tokio::spawn({ + let user = user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut one = [0u8; 1]; + io.read(&mut one).await + } + }); + + tokio::time::sleep(Duration::from_millis(5)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(300), task) + .await + .expect("reader task must complete after release") + .expect("reader task must not panic"); + assert!(done.is_ok()); +} + +#[tokio::test] +async fn business_different_user_middle_lock_does_not_block_relay_writer() { + let _guard = quota_test_guard(); + + let held_user = format!("cross-mode-middle-held-{}", std::process::id()); + let active_user = format!("cross-mode-middle-active-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user); + let _held_guard = held + .try_lock() + .expect("test must hold middle-relay lock for other user"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + active_user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]); + assert!(matches!(poll, Poll::Ready(Ok(1)))); +} + +#[tokio::test] +async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-none-limit-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold lock while quota is disabled"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + None, + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]); + assert!(matches!(poll, Poll::Ready(Ok(2)))); +} + +#[tokio::test] +async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-pre-exceeded-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold shared lock before poll"); + + let quota_exceeded = Arc::new(AtomicBool::new(true)); + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::clone("a_exceeded), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]); + assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e))); +} + +#[tokio::test] +async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-repoll-held-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold lock for repoll sequence"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + for _ in 0..8 { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]); + assert!(poll.is_pending()); + } + + assert_eq!(stats.get_user_total_octets(&user), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_same_user_mixed_read_write_waiters_resume_after_release() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-mixed-resume-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock before spawning mixed waiters"); + + let mut tasks = Vec::new(); + for i in 0..12usize { + let user = user.clone(); + tasks.push(tokio::spawn(async move { + if i % 2 == 0 { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut b = [0u8; 1]; + io.read(&mut b).await.map(|_| ()) + } else { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x66]).await + } + })); + } + + tokio::time::sleep(Duration::from_millis(8)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for task in tasks { + let result = task.await.expect("mixed waiter task must not panic"); + assert!(result.is_ok()); + } + }) + .await + .expect("all mixed waiters must finish after release"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() { + let _guard = quota_test_guard(); + + let blocked_user = format!("cross-mode-blocked-{}", std::process::id()); + let free_user = format!("cross-mode-free-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); + let held_guard = held + .try_lock() + .expect("test must hold blocked user lock"); + + let blocked_task = tokio::spawn({ + let blocked_user = blocked_user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + blocked_user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x77]).await + } + }); + + let free_task = tokio::spawn({ + let free_user = free_user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + free_user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x78]).await + } + }); + + let free_done = timeout(Duration::from_millis(250), free_task) + .await + .expect("free user must not be blocked") + .expect("free user task must not panic"); + assert!(free_done.is_ok()); + + drop(held_guard); + let blocked_done = timeout(Duration::from_secs(1), blocked_task) + .await + .expect("blocked user must resume after release") + .expect("blocked user task must not panic"); + assert!(blocked_done.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-fanout-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock before fanout"); + + let waiters = 48usize; + let gate = Arc::new(Barrier::new(waiters + 1)); + let mut tasks = Vec::new(); + for _ in 0..waiters { + let user = user.clone(); + let gate = Arc::clone(&gate); + tasks.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x79]).await + })); + } + + gate.wait().await; + tokio::time::sleep(Duration::from_millis(10)).await; + drop(held_guard); + + timeout(Duration::from_secs(2), async { + for task in tasks { + let result = task.await.expect("fanout task must not panic"); + assert!(result.is_ok()); + } + }) + .await + .expect("fanout waiters must complete after release"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() { + let _guard = quota_test_guard(); + + let mut seed = 0xA11C_EE55_2026_0323u64; + for round in 0..20u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_ms = 2 + (seed % 10); + let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock in fuzz round"); + + let writer = tokio::spawn({ + let user = user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x7A]).await + } + }); + + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(400), writer) + .await + .expect("writer must complete after lock release") + .expect("writer task must not panic"); + assert!(done.is_ok()); + } +} diff --git a/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs b/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs new file mode 100644 index 0000000..9ac4621 --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs @@ -0,0 +1,340 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::AsyncWriteExt; +use tokio::time::{Duration, Instant, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() { + let _guard = quota_test_guard(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + format!("dual-lock-alt-positive-{}", std::process::id()), + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let write = io.write_all(&[0xAA, 0xBB]).await; + assert!(write.is_ok(), "uncontended write must complete"); + assert_eq!( + io.quota_write_retry_attempt, 0, + "uncontended write must not advance retry backoff" + ); +} + +#[tokio::test] +async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-adversarial-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("test must hold local quota lock initially"), + ); + let mut cross_guard = None; + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(first.is_pending(), "held local lock must block first poll"); + + let mut observed_wakes = 0usize; + for idx in 0..18usize { + tokio::time::sleep(Duration::from_millis(6)).await; + + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = Some( + cross_mode_lock + .try_lock() + .expect("cross-mode lock should be acquirable while local lock released"), + ); + } else { + drop(cross_guard.take()); + local_guard = Some( + local_lock + .try_lock() + .expect("local lock should be acquirable while cross lock released"), + ); + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed_wakes { + observed_wakes = wakes; + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); + assert!( + pending.is_pending(), + "alternating contention must keep write pending while one lock is held" + ); + } + } + + assert!( + io.quota_write_retry_attempt >= 2, + "alternating contention must still ramp retry backoff; got {}", + io.quota_write_retry_attempt + ); + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 32, + "alternating contention must stay wake-rate-limited" + ); + + drop(local_guard); + drop(cross_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]); + assert!(ready.is_ready(), "writer must resume after both locks released"); +} + +#[tokio::test] +async fn edge_retry_scheduler_resets_after_alternating_contention_clears() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-edge-reset-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let local_guard = local_lock + .try_lock() + .expect("test must hold local lock for edge scenario"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]); + assert!(first.is_pending()); + tokio::time::sleep(Duration::from_millis(15)).await; + if wake_counter.wakes.load(Ordering::Relaxed) > 0 { + let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); + assert!(next.is_pending()); + } + + drop(local_guard); + + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]); + assert!(ready.is_ready()); + assert_eq!( + io.quota_write_retry_attempt, 0, + "successful dual-lock acquisition must reset retry scheduler" + ); + assert!(!io.quota_write_wake_scheduled); + assert!(io.quota_write_retry_sleep.is_none()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-integration-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut waiters = Vec::new(); + for _ in 0..16usize { + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + timeout(Duration::from_secs(2), io.write_all(&[0x31])).await + })); + } + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("integration toggle must acquire local lock first"), + ); + let mut cross_guard = None; + + for idx in 0..24usize { + tokio::time::sleep(Duration::from_millis(4)).await; + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = cross_mode_lock.try_lock().ok(); + } else { + drop(cross_guard.take()); + local_guard = local_lock.try_lock().ok(); + } + } + + drop(local_guard); + drop(cross_guard); + + for waiter in waiters { + let done = waiter.await.expect("waiter task must not panic"); + assert!( + done.is_ok(), + "waiter must finish once alternating contention window ends" + ); + assert!(done.expect("waiter timeout must not fire").is_ok()); + } +} + +#[tokio::test] +async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-fuzz-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let mut seed = 0xD00D_BAAD_F00D_2026u64; + + for _round in 0..64u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_mode = (seed % 3) as u8; + let local_guard = if hold_mode == 0 { + Some( + local_lock + .try_lock() + .expect("fuzz local lock should be acquirable"), + ) + } else { + None + }; + let cross_guard = if hold_mode == 1 { + Some( + cross_mode_lock + .try_lock() + .expect("fuzz cross lock should be acquirable"), + ) + } else { + None + }; + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await; + if hold_mode == 2 { + assert!(write.is_ok(), "unheld fuzz round must make progress"); + assert!(write.expect("unheld round timeout").is_ok()); + } else { + assert!( + write.is_err(), + "held-lock fuzz round must remain pending inside bounded window" + ); + } + + drop(local_guard); + drop(cross_guard); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_fanout_alternating_contention_recovers_without_hanging() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-stress-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut waiters = Vec::new(); + for _ in 0..48usize { + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await + })); + } + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("stress toggle must acquire local lock first"), + ); + let mut cross_guard = None; + for idx in 0..40usize { + tokio::time::sleep(Duration::from_millis(3)).await; + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = cross_mode_lock.try_lock().ok(); + } else { + drop(cross_guard.take()); + local_guard = local_lock.try_lock().ok(); + } + } + + drop(local_guard); + drop(cross_guard); + + for waiter in waiters { + let done = waiter.await.expect("stress waiter task must not panic"); + assert!(done.is_ok(), "stress waiter timed out under alternating contention"); + assert!(done.expect("stress waiter timeout should not fire").is_ok()); + } +} diff --git a/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs b/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs new file mode 100644 index 0000000..ce26941 --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs @@ -0,0 +1,74 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::time::{Duration, Instant}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn adversarial_cross_mode_only_contention_backoff_attempt_must_ramp() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-backoff-{}", std::process::id()); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_cross_mode_guard = cross_mode_lock + .try_lock() + .expect("test must hold cross-mode lock before polling"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); + assert!(first.is_pending(), "held cross-mode lock must block writer"); + + let started = Instant::now(); + let mut last_wakes = 0usize; + while started.elapsed() < Duration::from_millis(120) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > last_wakes { + last_wakes = wakes; + let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]); + assert!(next.is_pending(), "writer must remain blocked while lock is held"); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!( + io.quota_write_retry_attempt >= 2, + "retry attempt must ramp under sustained second-lock contention; got {}", + io.quota_write_retry_attempt + ); + + drop(held_cross_mode_guard); +} diff --git a/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs b/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs new file mode 100644 index 0000000..513d92b --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs @@ -0,0 +1,325 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::time::{Duration, Instant, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +#[tokio::test] +async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() { + let _guard = quota_test_guard(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + format!("dual-lock-positive-{}", std::process::id()), + Some(4096), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]); + assert!(poll.is_ready()); + assert_eq!(io.quota_write_retry_attempt, 0); + assert!(!io.quota_write_wake_scheduled); + assert!(io.quota_write_retry_sleep.is_none()); +} + +#[tokio::test] +async fn negative_local_lock_contention_read_retry_attempt_ramps() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-local-contention-{}", std::process::id()); + let held = quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold local quota lock before polling"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (wake_counter, mut cx) = build_context(); + let mut one = [0u8; 1]; + let mut buf = ReadBuf::new(&mut one); + let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(first.is_pending()); + + let started = Instant::now(); + let mut observed = 0usize; + while started.elapsed() < Duration::from_millis(120) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed { + observed = wakes; + let mut step_buf = ReadBuf::new(&mut one); + let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf); + assert!(next.is_pending()); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!( + io.quota_read_retry_attempt >= 2, + "retry attempt must ramp under sustained local-lock contention; got {}", + io.quota_read_retry_attempt + ); + + drop(held_guard); +} + +#[tokio::test] +async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-reset-{}", std::process::id()); + let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = cross_mode + .try_lock() + .expect("test must hold cross-mode lock before polling"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (wake_counter, mut cx) = build_context(); + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]); + assert!(first.is_pending()); + + tokio::time::sleep(Duration::from_millis(20)).await; + if wake_counter.wakes.load(Ordering::Relaxed) > 0 { + let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(next.is_pending()); + } + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); + assert!(ready.is_ready()); + assert_eq!(io.quota_write_retry_attempt, 0); + assert!(!io.quota_write_wake_scheduled); + assert!(io.quota_write_retry_sleep.is_none()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-adversarial-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before launching waiters"); + + let mut tasks = Vec::new(); + for _ in 0..24usize { + let stats = Arc::clone(&stats); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + stats, + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + timeout(Duration::from_millis(40), io.write_all(&[0x33])).await + })); + } + + for task in tasks { + let timed = task.await.expect("waiter task must not panic"); + assert!(timed.is_err(), "held cross-mode lock must keep waiter pending"); + } + + assert_eq!(stats.get_user_total_octets(&user), 0); + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_waiters_resume_after_cross_mode_release() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-integration-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before starting waiter"); + + let task = tokio::spawn({ + let user = user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + io.write_all(&[0x44]).await + } + }); + + tokio::time::sleep(Duration::from_millis(10)).await; + drop(held_guard); + + let done = timeout(Duration::from_secs(1), task) + .await + .expect("waiter task must complete after release") + .expect("waiter task must not panic"); + assert!(done.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-fuzz-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let mut seed = 0xA55A_55AA_C3D2_E1F0u64; + + for _round in 0..48u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_mode = (seed % 3) as u8; + let mut local_lock = None; + let mut cross_lock = None; + let mut local_guard = None; + let mut cross_guard = None; + + if hold_mode == 0 { + local_lock = Some(quota_user_lock(&user)); + local_guard = Some( + local_lock + .as_ref() + .expect("local lock should be present") + .try_lock() + .expect("local lock should be acquirable in fuzz round"), + ); + } else if hold_mode == 1 { + cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock( + &user, + )); + cross_guard = Some( + cross_lock + .as_ref() + .expect("cross lock should be present") + .try_lock() + .expect("cross lock should be acquirable in fuzz round"), + ); + } + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(4096), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await; + if hold_mode == 2 { + assert!(write.is_ok(), "unheld round must make progress"); + } else { + assert!(write.is_err(), "held-lock round must stay blocked within timeout"); + } + + drop(local_guard); + drop(cross_guard); + drop(local_lock); + drop(cross_lock); + } + + assert!(stats.get_user_total_octets(&user) <= 4096); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_fanout_waiters_complete_after_release_without_panics() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-stress-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before stress fanout"); + + let waiters = 64usize; + let mut tasks = Vec::new(); + for _ in 0..waiters { + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + let mut one = [0u8; 1]; + io.read(&mut one).await + })); + } + + tokio::time::sleep(Duration::from_millis(12)).await; + drop(held_guard); + + timeout(Duration::from_secs(2), async { + for task in tasks { + let result = task.await.expect("stress waiter task must not panic"); + assert!(result.is_ok()); + } + }) + .await + .expect("all stress waiters must complete after release"); +} diff --git a/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs b/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs new file mode 100644 index 0000000..ec180e8 --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs @@ -0,0 +1,128 @@ +use super::*; +use crate::stats::Stats; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use tokio::io::AsyncWriteExt; +use tokio::time::{Duration, timeout}; + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn make_stats_io(user: String) -> StatsIo { + StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-race-fuzz-{}", std::process::id()); + let mut seed = 0xD1CE_BAAD_5EED_1234u64; + + for round in 0..1024u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold = (seed & 1) == 0; + let hold_ms = (seed % 3) as u64; + + let maybe_lock = if hold { + Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock( + &user, + )) + } else { + None + }; + + let maybe_guard = maybe_lock.as_ref().map(|lock| { + lock.try_lock() + .expect("cross-mode lock must be acquirable in fuzz round") + }); + + if hold { + let mut blocked_io = make_stats_io(user.clone()); + let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await; + assert!( + blocked.is_err(), + "held round must block waiter before lock release (round={round})" + ); + + if hold_ms > 0 { + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + } + } else { + let mut free_io = make_stats_io(user.clone()); + let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await; + assert!( + free.is_ok(), + "unheld round must complete promptly (round={round})" + ); + assert!(free.expect("unheld round should complete").is_ok()); + } + + drop(maybe_guard); + + let done = timeout(Duration::from_millis(350), async { + let user = user.clone(); + let mut io = make_stats_io(user); + io.write_all(&[0xA6]).await + }) + .await + .expect("post-release write must complete in bounded time"); + assert!(done.is_ok()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-race-stress-{}", std::process::id()); + let mut seed = 0xC0FF_EE77_4444_9999u64; + + for round in 0..256u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_ms = (seed % 4) as u64; + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let guard = lock + .try_lock() + .expect("cross-mode lock must be acquirable at round start"); + + let mut waiters = Vec::new(); + for _ in 0..3usize { + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = make_stats_io(user); + io.write_all(&[0x55]).await + })); + } + + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + drop(guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let done = waiter.await.expect("waiter task must not panic"); + assert!( + done.is_ok(), + "waiter must complete after release (round={round})" + ); + } + }) + .await + .expect("all waiters must complete in bounded time after release"); + } +} diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..5ee6522 --- /dev/null +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -0,0 +1,332 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{Duration, timeout}; + +async fn read_available(reader: &mut R, budget: Duration) -> usize { + let start = tokio::time::Instant::now(); + let mut total = 0usize; + let mut buf = [0u8; 128]; + + loop { + let elapsed = start.elapsed(); + if elapsed >= budget { + break; + } + let remaining = budget.saturating_sub(elapsed); + match timeout(remaining, reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + + total +} + +#[tokio::test] +async fn positive_quota_path_forwards_both_directions_within_limit() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-positive-user"; + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(16), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap(); + server_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap(); + client_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(relay_result.is_ok()); + assert!(stats.get_user_total_octets(user) <= 16); +} + +#[tokio::test] +async fn negative_preloaded_quota_forbids_any_forwarding() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-negative-user"; + stats.add_user_octets_from(user, 8); + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(8), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0xAA]).await.unwrap(); + server_peer.write_all(&[0xBB]).await.unwrap(); + + assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0); + assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(stats.get_user_total_octets(user) <= 8); +} + +#[tokio::test] +async fn edge_quota_one_ensures_at_most_one_byte_across_directions() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-edge-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer.write_all(&[0xFE]), + server_peer.write_all(&[0xEF]), + ); + + let mut buf = [0u8; 1]; + let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await.unwrap().unwrap_or(0); + let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await.unwrap().unwrap_or(0); + + assert!(delivered_s2c + delivered_c2s <= 1); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-blackhat-user"; + let quota = 24u64; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + let mut total_forwarded = 0usize; + + for i in 0..256usize { + if relay.is_finished() { + break; + } + if (i & 1) == 0 { + let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await { + total_forwarded += n; + } + } else { + let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await { + total_forwarded += n; + } + } + + tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; + } + + let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap(); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(total_forwarded <= quota as usize); + assert!(stats.get_user_total_octets(user) <= quota); +} + +#[tokio::test] +async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { + let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE); + + for case in 0..32u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-extended-fuzz-{case}"); + let quota = rng.random_range(1u64..=35u64); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut total_forwarded = 0usize; + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + if rng.random::() { + let _ = client_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await { + total_forwarded += n; + } + } else { + let _ = server_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await { + total_forwarded += n; + } + } + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(total_forwarded <= quota as usize); + assert!(stats.get_user_total_octets(&user) <= quota); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_relays_for_one_user_obey_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-stress-user".to_string(); + let quota = 64u64; + + let mut tasks = Vec::new(); + + for worker in 0..4u8 { + let stats = Arc::clone(&stats); + let user = user.clone(); + + tasks.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut total = 0usize; + for step in 0..64u8 { + if relay.is_finished() { + break; + } + if (step as usize + worker as usize) % 2 == 0 { + let _ = client_peer.write_all(&[(step ^ 0x5A)]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await { + total += n; + } + } else { + let _ = server_peer.write_all(&[(step ^ 0xA5)]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await { + total += n; + } + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + total + })); + } + + let mut delivered = 0usize; + for task in tasks { + delivered += task.await.unwrap(); + } + + assert!(stats.get_user_total_octets(&user) <= quota); + assert!(delivered <= quota as usize); +} diff --git a/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs new file mode 100644 index 0000000..806efb6 --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs @@ -0,0 +1,79 @@ +use super::*; +use dashmap::DashMap; +use std::sync::Arc; +use tokio::time::{Duration, timeout}; + +#[test] +fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let held_user = format!("quota-evict-held-{}", std::process::id()); + let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id()); + let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id()); + + let held = quota_user_lock(&held_user); + let stale_a = quota_user_lock(&stale_a_user); + let stale_b = quota_user_lock(&stale_b_user); + + assert!(map.get(&held_user).is_some()); + assert!(map.get(&stale_a_user).is_some()); + assert!(map.get(&stale_b_user).is_some()); + + drop(stale_a); + drop(stale_b); + + quota_user_lock_evict(); + + assert!( + map.get(&held_user).is_some(), + "held entry must survive eviction" + ); + assert!( + map.get(&stale_a_user).is_none(), + "unheld stale entry must be reclaimed" + ); + assert!( + map.get(&stale_b_user).is_none(), + "unheld stale entry must be reclaimed" + ); + + drop(held); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let held_user = format!("quota-evict-loop-held-{}", std::process::id()); + let stale_user = format!("quota-evict-loop-stale-{}", std::process::id()); + + let held = quota_user_lock(&held_user); + let stale = quota_user_lock(&stale_user); + + assert_eq!(map.len(), 2); + drop(stale); + + let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5)); + + timeout(Duration::from_millis(200), async { + loop { + if map.get(&stale_user).is_none() { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("periodic quota lock evictor must reclaim stale entry"); + + evictor.abort(); + + assert!(map.get(&held_user).is_some()); + assert!(map.get(&stale_user).is_none()); + + drop(held); +} diff --git a/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs new file mode 100644 index 0000000..251582a --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs @@ -0,0 +1,153 @@ +use super::*; +use dashmap::DashMap; +use std::sync::Arc; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5)); + + let mut tasks = JoinSet::new(); + for worker in 0..24u32 { + tasks.spawn(async move { + for round in 0..320u32 { + let user = format!( + "quota-evict-stress-user-{}-{}-{}", + std::process::id(), + worker, + round + ); + let lock = quota_user_lock(&user); + if round % 19 == 0 { + tokio::task::yield_now().await; + } + drop(lock); + } + }); + } + + while let Some(done) = tasks.join_next().await { + done.expect("stress worker must not panic"); + } + + quota_user_lock_evict(); + tokio::time::sleep(Duration::from_millis(20)).await; + + assert!( + map.len() <= QUOTA_USER_LOCKS_MAX, + "quota lock map must remain bounded after churn + eviction" + ); + + let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id()); + let sanity_lock = quota_user_lock(&sanity_user); + assert!( + map.get(&sanity_user).is_some(), + "sanity user should be cacheable after eviction reclaimed stale entries" + ); + + drop(sanity_lock); + evictor.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let held_user = format!("quota-evict-held-survive-{}", std::process::id()); + let held = quota_user_lock(&held_user); + + let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3)); + + for idx in 0..512u32 { + let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx); + let temp = quota_user_lock(&user); + drop(temp); + if idx % 32 == 0 { + tokio::task::yield_now().await; + } + } + + let reacquired = quota_user_lock(&held_user); + assert!( + Arc::ptr_eq(&held, &reacquired), + "held user lock identity must remain stable across repeated evictions" + ); + assert!( + map.get(&held_user).is_some(), + "held user entry must not be reclaimed while externally referenced" + ); + + drop(reacquired); + drop(held); + + timeout(Duration::from_millis(300), async { + loop { + if map.get(&held_user).is_none() { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("released held lock must be reclaimed by periodic evictor"); + + evictor.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + let prefix = format!("quota-evict-saturated-{}", std::process::id()); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); + + let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id()); + let overflow_before = quota_user_lock(&overflow_user); + assert!( + map.get(&overflow_user).is_none(), + "saturated map must initially route new user to overflow stripe" + ); + + drop(retained); + + let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4)); + + timeout(Duration::from_millis(400), async { + loop { + if map.len() < QUOTA_USER_LOCKS_MAX { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("periodic evictor must reclaim stale saturated entries"); + + let overflow_after = quota_user_lock(&overflow_user); + assert!( + map.get(&overflow_user).is_some(), + "after eviction, overflow user should become cacheable again" + ); + assert!( + Arc::strong_count(&overflow_after) >= 2, + "cacheable lock should be held by map and caller" + ); + + drop(overflow_before); + drop(overflow_after); + evictor.abort(); +} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs index e29e86e..5687965 100644 --- a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs @@ -127,7 +127,7 @@ fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() { } #[test] -fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { +fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() { let _guard = super::quota_user_lock_test_scope(); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); map.clear(); @@ -142,6 +142,8 @@ fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { drop(retained); + quota_user_lock_evict(); + let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); let overflow = quota_user_lock(&overflow_user); diff --git a/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs b/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs new file mode 100644 index 0000000..447a090 --- /dev/null +++ b/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs @@ -0,0 +1,249 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::AsyncWriteExt; +use tokio::time::{Duration, Instant, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +fn sleep_slot_ptr(slot: &Option>>) -> usize { + slot.as_ref() + .map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize) + .unwrap_or(0) +} + +#[tokio::test] +async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() { + let _guard = quota_test_guard(); + + let user = format!("retry-alloc-single-pending-{}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold local lock to force retry scheduling"); + + reset_quota_retry_sleep_allocs_for_tests(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); + assert!(first.is_pending()); + let allocs_after_first = quota_retry_sleep_allocs_for_tests(); + let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep); + + let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); + assert!(second.is_pending()); + let allocs_after_second = quota_retry_sleep_allocs_for_tests(); + let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep); + + assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer"); + assert_eq!( + allocs_after_second, 1, + "repoll while the same timer is pending must not allocate again" + ); + assert_eq!( + ptr_after_first, ptr_after_second, + "repoll while pending should retain the same timer allocation" + ); + + drop(held_guard); +} + +#[tokio::test] +async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() { + let _guard = quota_test_guard(); + + let user = format!("retry-alloc-per-cycle-{}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold local lock to keep write path pending"); + + reset_quota_retry_sleep_allocs_for_tests(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (wake_counter, mut cx) = build_context(); + + let mut polls = 0u64; + let mut observed_wakes = 0usize; + let started = Instant::now(); + while started.elapsed() < Duration::from_millis(70) { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]); + polls = polls.saturating_add(1); + assert!(poll.is_pending()); + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed_wakes { + observed_wakes = wakes; + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let allocs = quota_retry_sleep_allocs_for_tests(); + assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers"); + assert!( + allocs < polls, + "timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})" + ); + + drop(held_guard); +} + +#[tokio::test] +async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() { + let _guard = quota_test_guard(); + + let user = format!("retry-latency-envelope-{}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold local lock for sustained contention"); + + reset_quota_retry_sleep_allocs_for_tests(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let (wake_counter, mut cx) = build_context(); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]); + assert!(first.is_pending()); + + let started = Instant::now(); + let mut last_wakes = 0usize; + let mut wake_instants = Vec::new(); + + while started.elapsed() < Duration::from_millis(120) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > last_wakes { + last_wakes = wakes; + wake_instants.push(Instant::now()); + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]); + assert!(pending.is_pending()); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let mut max_gap = Duration::from_millis(0); + for idx in 1..wake_instants.len() { + let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]); + if gap > max_gap { + max_gap = gap; + } + } + + assert!( + max_gap <= Duration::from_millis(35), + "retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}" + ); + assert!( + quota_retry_sleep_allocs_for_tests() <= 16, + "allocation cycles must remain bounded during a short contention window" + ); + + drop(held_guard); +} + +#[tokio::test] +async fn micro_benchmark_release_to_completion_latency_stays_bounded() { + let _guard = quota_test_guard(); + + let rounds = 96usize; + let mut samples_ms = Vec::with_capacity(rounds); + + for round in 0..rounds { + let user = format!("retry-release-latency-{}-{round}", std::process::id()); + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold local lock before spawning blocked writer"); + + let writer = tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + io.write_all(&[0xD1]).await + }); + + tokio::time::sleep(Duration::from_millis(2)).await; + let release_at = Instant::now(); + drop(held_guard); + + let done = timeout(Duration::from_millis(120), writer) + .await + .expect("blocked writer must complete after release") + .expect("writer task must not panic"); + assert!(done.is_ok()); + + samples_ms.push(release_at.elapsed().as_millis() as u64); + } + + samples_ms.sort_unstable(); + let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1)); + let p95_ms = samples_ms[p95_idx]; + + assert!( + p95_ms <= 40, + "contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" + ); +} From 1abf9bd05c61c3b7f07552d881c31e200b54c2f2 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Mon, 23 Mar 2026 12:27:57 +0400 Subject: [PATCH 07/29] Refactor CI workflows: rename build job and streamline stress testing setup --- .github/workflows/rust.yml | 22 ++++---------- .github/workflows/stress.yml | 57 ++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 17 deletions(-) create mode 100644 .github/workflows/stress.yml diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 799f2ce..b245679 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -11,7 +11,7 @@ env: jobs: build: - name: Build + name: Compile, Test, Lint runs-on: ubuntu-latest permissions: @@ -39,23 +39,11 @@ jobs: restore-keys: | ${{ runner.os }}-cargo- - - name: Build Release - run: cargo build --release --verbose + - name: Compile (no tests) + run: cargo check --workspace --all-features --lib --bins --verbose - - name: Run tests - run: cargo test --verbose - - - name: Stress quota-lock suites (PR only) - if: github.event_name == 'pull_request' - env: - RUST_TEST_THREADS: 16 - run: | - set -euo pipefail - for i in $(seq 1 12); do - echo "[quota-lock-stress] iteration ${i}/12" - cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16 - cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16 - done + - name: Run tests (single pass) + run: cargo test --workspace --all-features --verbose # clippy dont fail on warnings because of active development of telemt # and many warnings diff --git a/.github/workflows/stress.yml b/.github/workflows/stress.yml new file mode 100644 index 0000000..96b9a1b --- /dev/null +++ b/.github/workflows/stress.yml @@ -0,0 +1,57 @@ +name: Stress Tests + +on: + workflow_dispatch: + schedule: + - cron: '0 2 * * *' + pull_request: + branches: ["*"] + paths: + - src/proxy/** + - src/transport/** + - src/stream/** + - src/protocol/** + - src/tls_front/** + - Cargo.toml + - Cargo.lock + +env: + CARGO_TERM_COLOR: always + +jobs: + quota-lock-stress: + name: Quota-lock stress loop + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install latest stable Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry and build artifacts + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-stress-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-stress- + ${{ runner.os }}-cargo- + + - name: Run quota-lock stress suites + env: + RUST_TEST_THREADS: 16 + run: | + set -euo pipefail + for i in $(seq 1 12); do + echo "[quota-lock-stress] iteration ${i}/12" + cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16 + cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16 + done From 6f4356f72a212e5b0621417a9f28c15e9978657c Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 15:53:44 +0300 Subject: [PATCH 08/29] Redesign Quotas on Atomics --- src/maestro/runtime_tasks.rs | 31 -- src/proxy/client.rs | 4 +- src/proxy/handshake.rs | 20 +- src/proxy/masking.rs | 10 +- src/proxy/middle_relay.rs | 347 +++---------- src/proxy/mod.rs | 1 - src/proxy/quota_lock_registry.rs | 88 ---- src/proxy/relay.rs | 488 ++++-------------- src/stats/mod.rs | 311 +++++++---- .../tests/user_octets_sub_security_tests.rs | 151 ------ 10 files changed, 408 insertions(+), 1043 deletions(-) delete mode 100644 src/proxy/quota_lock_registry.rs delete mode 100644 src/stats/tests/user_octets_sub_security_tests.rs diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index 066c853..d553eb9 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -32,14 +32,6 @@ pub(crate) struct RuntimeWatches { pub(crate) detected_ip_v6: Option, } -const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60; - -fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> { - crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs( - QUOTA_USER_LOCK_EVICT_INTERVAL_SECS, - )) -} - #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, @@ -77,8 +69,6 @@ pub(crate) async fn spawn_runtime_tasks( rc_clone.run_periodic_cleanup().await; }); - spawn_quota_lock_maintenance_task(); - let detected_ip_v4: Option = probe.detected_ipv4.map(IpAddr::V4); let detected_ip_v6: Option = probe.detected_ipv6.map(IpAddr::V6); debug!( @@ -370,24 +360,3 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc) { .await; startup_tracker.mark_ready().await; } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() { - crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests(); - - let handle = spawn_quota_lock_maintenance_task(); - tokio::time::sleep(std::time::Duration::from_millis(5)).await; - - assert_eq!( - crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(), - 1, - "runtime maintenance path must spawn exactly one quota lock evictor task per call" - ); - - handle.abort(); - } -} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index a804a2c..1567caf 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1223,7 +1223,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1282,7 +1282,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 96994c7..55a8a21 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -614,6 +614,15 @@ where } }; + // Reject known replay digests before expensive cache/domain/ALPN policy work. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => { @@ -669,15 +678,8 @@ where None }; - // Replay tracking is applied only after full policy validation (including - // ALPN checks) so rejected handshakes cannot poison replay state. - let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_and_add_tls_digest(digest_half) { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } + // Add replay digest only for policy-valid handshakes. + replay_checker.add_tls_digest(digest_half); let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 841749c..241a48f 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -60,7 +60,7 @@ where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, { - let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; let mut ended_by_eof = false; @@ -262,7 +262,11 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { let floor = config.censorship.mask_timing_normalization_floor_ms; let ceiling = config.censorship.mask_timing_normalization_ceiling_ms; if floor == 0 { - return MASK_TIMEOUT; + if ceiling == 0 { + return Duration::from_millis(0); + } + let mut rng = rand::rng(); + return Duration::from_millis(rng.random_range(0..=ceiling)); } if ceiling > floor { let mut rng = rand::rng(); @@ -838,7 +842,7 @@ async fn consume_client_data(mut reader: R, byte_cap: usiz } // Keep drain path fail-closed under slow-loris stalls. - let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; loop { diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 14ea001..2a84353 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -10,7 +10,7 @@ use std::time::{Duration, Instant}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::timeout; use tracing::{debug, info, trace, warn}; @@ -23,7 +23,7 @@ use crate::proxy::route_mode::{ ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats}; +use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; @@ -53,20 +53,11 @@ const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; +const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); static DESYNC_HASHER: OnceLock = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); static DESYNC_DEDUP_EVER_SATURATED: OnceLock = OnceLock::new(); -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new(); static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); @@ -538,36 +529,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET } -fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option) -> bool { - quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota) -} - -#[cfg_attr(not(test), allow(dead_code))] -fn quota_would_be_exceeded_for_user( - stats: &Stats, - user: &str, - quota_limit: Option, - bytes: u64, -) -> bool { - quota_limit.is_some_and(|quota| { - let used = stats.get_user_total_octets(user); - used >= quota || bytes > quota.saturating_sub(used) - }) -} - fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { limit.saturating_add(overshoot) } -fn quota_would_be_exceeded_for_user_soft( - stats: &Stats, - user: &str, - quota_limit: Option, +async fn reserve_user_quota_with_yield( + user_stats: &UserStats, bytes: u64, - overshoot: u64, -) -> bool { - let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot)); - quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes) + limit: u64, +) -> std::result::Result { + loop { + for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { + match user_stats.quota_try_reserve(bytes, limit) { + Ok(total) => return Ok(total), + Err(QuotaReserveError::LimitExceeded) => { + return Err(QuotaReserveError::LimitExceeded); + } + Err(QuotaReserveError::Contended) => std::hint::spin_loop(), + } + } + + tokio::task::yield_now().await; + } } fn classify_me_d2c_flush_reason( @@ -613,29 +596,6 @@ fn observe_me_d2c_flush_event( } } -fn rollback_me2c_quota_reservation( - stats: &Stats, - user: &str, - bytes_me2c: &AtomicU64, - reserved_bytes: u64, -) { - stats.sub_user_octets_to(user, reserved_bytes); - bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed); -} - -#[cfg(test)] -fn quota_user_lock_test_guard() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -#[cfg(test)] -fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { - quota_user_lock_test_guard() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - #[cfg(test)] fn relay_idle_pressure_test_guard() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); @@ -649,46 +609,6 @@ pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, .unwrap_or_else(|poisoned| poisoned.into_inner()) } -fn quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(AsyncMutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -fn quota_user_lock(user: &str) -> Arc> { - let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - return quota_overflow_user_lock(user); - } - - let created = Arc::new(AsyncMutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) -} - async fn enqueue_c2me_command( tx: &mpsc::Sender, cmd: C2MeCommand, @@ -744,8 +664,7 @@ where { let user = success.user.clone(); let quota_limit = config.access.user_data_quota.get(&user).copied(); - let cross_mode_quota_lock = - quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); let peer = success.peer; let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); @@ -872,7 +791,7 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); - let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone(); + let quota_user_stats_me_writer = quota_user_stats.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); @@ -894,7 +813,7 @@ where let first_is_downstream_activity = matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( first, &mut writer, proto_tag, @@ -902,9 +821,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -953,7 +872,7 @@ where let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( next, &mut writer, proto_tag, @@ -961,9 +880,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1015,7 +934,7 @@ where Ok(Some(next)) => { let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( next, &mut writer, proto_tag, @@ -1023,9 +942,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1079,7 +998,7 @@ where let extra_is_downstream_activity = matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( extra, &mut writer, proto_tag, @@ -1087,9 +1006,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1259,24 +1178,23 @@ where forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); - if let Some(limit) = quota_limit { - let quota_lock = quota_user_lock(&user); - let _quota_guard = quota_lock.lock().await; - let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else { - main_result = Err(ProxyError::Proxy( - "cross-mode quota lock missing for quota-limited session" - .to_string(), - )); - break; - }; - let _cross_mode_quota_guard = cross_mode_lock.lock().await; - stats.add_user_octets_from(&user, payload.len() as u64); - if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { + if let (Some(limit), Some(user_stats)) = + (quota_limit, quota_user_stats.as_deref()) + { + if reserve_user_quota_with_yield( + user_stats, + payload.len() as u64, + limit, + ) + .await + .is_err() + { main_result = Err(ProxyError::DataQuotaExceeded { user: user.clone(), }); break; } + stats.add_user_octets_from_handle(user_stats, payload.len() as u64); } else { stats.add_user_octets_from(&user, payload.len() as u64); } @@ -1755,7 +1673,6 @@ enum MeWriterResponseOutcome { Close, } -#[cfg(test)] async fn process_me_writer_response( response: MeResponse, client_writer: &mut CryptoWriter, @@ -1764,6 +1681,7 @@ async fn process_me_writer_response( frame_buf: &mut Vec, stats: &Stats, user: &str, + quota_user_stats: Option<&UserStats>, quota_limit: Option, quota_soft_overshoot_bytes: u64, bytes_me2c: &AtomicU64, @@ -1771,44 +1689,6 @@ async fn process_me_writer_response( ack_flush_immediate: bool, batched: bool, ) -> Result -where - W: AsyncWrite + Unpin + Send + 'static, -{ - process_me_writer_response_with_cross_mode_lock( - response, - client_writer, - proto_tag, - rng, - frame_buf, - stats, - user, - quota_limit, - quota_soft_overshoot_bytes, - None, - bytes_me2c, - conn_id, - ack_flush_immediate, - batched, - ) - .await -} - -async fn process_me_writer_response_with_cross_mode_lock( - response: MeResponse, - client_writer: &mut CryptoWriter, - proto_tag: ProtoTag, - rng: &SecureRandom, - frame_buf: &mut Vec, - stats: &Stats, - user: &str, - quota_limit: Option, - quota_soft_overshoot_bytes: u64, - cross_mode_quota_lock: Option<&Arc>>, - bytes_me2c: &AtomicU64, - conn_id: u64, - ack_flush_immediate: bool, - batched: bool, -) -> Result where W: AsyncWrite + Unpin + Send + 'static, { @@ -1820,78 +1700,43 @@ where trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } let data_len = data.len() as u64; - if let Some(limit) = quota_limit { - let owned_cross_mode_lock; - let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock { - lock - } else { - owned_cross_mode_lock = - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user); - &owned_cross_mode_lock - }; - let cross_mode_quota_guard = cross_mode_lock.lock().await; + if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) { let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); - if quota_would_be_exceeded_for_user_soft( - stats, - user, - Some(limit), - data_len, - quota_soft_overshoot_bytes, - ) { + if reserve_user_quota_with_yield(user_stats, data_len, soft_limit) + .await + .is_err() + { stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), }); } - - // Reserve quota before awaiting network I/O to avoid same-user HoL stalls. - // If reservation loses a race or write fails, we roll back immediately. - bytes_me2c.fetch_add(data_len, Ordering::Relaxed); - stats.add_user_octets_to(user, data_len); - - if stats.get_user_total_octets(user) > soft_limit { - rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); - } - - // Keep cross-mode lock scope explicit and minimal: quota reservation is serialized, - // but socket I/O proceeds without holding same-user cross-mode admission lock. - drop(cross_mode_quota_guard); - - let write_mode = - match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await - { - Ok(mode) => mode, - Err(err) => { - rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); - return Err(err); - } - }; - - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data_len); - stats.increment_me_d2c_write_mode(write_mode); - - // Do not fail immediately on exact boundary after a successful write. - // Returning an error here can bypass batch flush in the caller and risk - // dropping buffered ciphertext from CryptoWriter. The next frame is - // rejected by the pre-check at function entry. - } else { - let write_mode = - write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await?; - - bytes_me2c.fetch_add(data_len, Ordering::Relaxed); - stats.add_user_octets_to(user, data_len); - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data_len); - stats.increment_me_d2c_write_mode(write_mode); } + let write_mode = + match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await + { + Ok(mode) => mode, + Err(err) => { + if quota_limit.is_some() { + stats.add_quota_write_fail_bytes_total(data_len); + stats.increment_quota_write_fail_events_total(); + } + return Err(err); + } + }; + + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + if let Some(user_stats) = quota_user_stats { + stats.add_user_octets_to_handle(user_stats, data_len); + } else { + stats.add_user_octets_to(user, data_len); + } + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); + Ok(MeWriterResponseOutcome::Continue { frames: 1, bytes: data.len(), @@ -2097,10 +1942,6 @@ where .map_err(ProxyError::Io) } -#[cfg(test)] -#[path = "tests/middle_relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/middle_relay_idle_policy_security_tests.rs"] mod idle_policy_security_tests; @@ -2113,30 +1954,10 @@ mod desync_all_full_dedup_security_tests; #[path = "tests/middle_relay_stub_completion_security_tests.rs"] mod stub_completion_security_tests; -#[cfg(test)] -#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"] -mod coverage_high_risk_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"] -mod quota_overflow_lock_security_tests; - #[cfg(test)] #[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"] mod length_cast_hardening_security_tests; -#[cfg(test)] -#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] -mod blackhat_campaign_integration_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_hol_quota_security_tests.rs"] -mod hol_quota_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"] -mod quota_reservation_adversarial_tests; - #[cfg(test)] #[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"] mod middle_relay_idle_registry_poison_security_tests; @@ -2156,27 +1977,3 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests; #[cfg(test)] #[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"] -mod middle_relay_cross_mode_quota_reservation_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"] -mod middle_relay_cross_mode_quota_lock_matrix_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"] -mod middle_relay_cross_mode_lookup_efficiency_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"] -mod middle_relay_cross_mode_lock_release_regression_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"] -mod middle_relay_quota_extended_attack_surface_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"] -mod middle_relay_quota_reservation_extreme_security_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 519f1b3..eebc188 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -64,7 +64,6 @@ pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; -pub mod quota_lock_registry; pub mod relay; pub mod route_mode; pub mod session_eviction; diff --git a/src/proxy/quota_lock_registry.rs b/src/proxy/quota_lock_registry.rs deleted file mode 100644 index 7798b09..0000000 --- a/src/proxy/quota_lock_registry.rs +++ /dev/null @@ -1,88 +0,0 @@ -use dashmap::DashMap; -use std::sync::{Arc, OnceLock}; -use tokio::sync::Mutex; - -#[cfg(test)] -use std::sync::atomic::{AtomicUsize, Ordering}; - -#[cfg(test)] -const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; - -static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); - -#[cfg(test)] -static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0); -#[cfg(test)] -static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock> = OnceLock::new(); - -fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(Mutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { - #[cfg(test)] - { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed); - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - let mut entry = lookups.entry(user.to_string()).or_insert(0); - *entry += 1; - } - - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - - if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { - return cross_mode_quota_overflow_user_lock(user); - } - - let created = Arc::new(Mutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed); - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - lookups.clear(); -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed) -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize { - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - lookups.get(user).map(|entry| *entry).unwrap_or(0) -} - -#[cfg(test)] -#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"] -mod quota_lock_registry_cross_mode_adversarial_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 55f1385..cc8b088 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -52,18 +52,16 @@ //! - `SharedCounters` (atomics) let the watchdog read stats without locking use crate::error::{ProxyError, Result}; -use crate::stats::Stats; +use crate::stats::{Stats, UserStats}; use crate::stream::BufferPool; -use dashmap::DashMap; use std::io; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; -use tokio::sync::Mutex as AsyncMutex; -use tokio::time::{Instant, Sleep}; +use tokio::time::Instant; use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -210,16 +208,10 @@ struct StatsIo { counters: Arc, stats: Arc, user: String, - quota_lock: Option>>, - cross_mode_quota_lock: Option>>, + user_stats: Arc, quota_limit: Option, quota_exceeded: Arc, - quota_read_wake_scheduled: bool, - quota_write_wake_scheduled: bool, - quota_read_retry_sleep: Option>>, - quota_write_retry_sleep: Option>>, - quota_read_retry_attempt: u8, - quota_write_retry_attempt: u8, + quota_bytes_since_check: u64, epoch: Instant, } @@ -235,24 +227,16 @@ impl StatsIo { ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); - let quota_lock = quota_limit.map(|_| quota_user_lock(&user)); - let cross_mode_quota_lock = quota_limit - .map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let user_stats = stats.get_or_create_user_stats_handle(&user); Self { inner, counters, stats, user, - quota_lock, - cross_mode_quota_lock, + user_stats, quota_limit, quota_exceeded, - quota_read_wake_scheduled: false, - quota_write_wake_scheduled: false, - quota_read_retry_sleep: None, - quota_write_retry_sleep: None, - quota_read_retry_attempt: 0, - quota_write_retry_attempt: 0, + quota_bytes_since_check: 0, epoch, } } @@ -281,169 +265,24 @@ fn is_quota_io_error(err: &io::Error) -> bool { .is_some() } -#[cfg(test)] -const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1); -#[cfg(not(test))] -const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2); -#[cfg(test)] -const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16); -#[cfg(not(test))] -const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64); +const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024; +const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; -#[cfg(test)] -static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0); -#[cfg(test)] -static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0); - -#[cfg(test)] -pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() { - QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed); -} - -#[cfg(test)] -pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 { - QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed) +#[inline] +fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { + remaining_before + .saturating_div(2) + .clamp( + QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES, + QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES, + ) } #[inline] -fn quota_contention_retry_delay(retry_attempt: u8) -> Duration { - let shift = u32::from(retry_attempt.min(5)); - let multiplier = 1_u32 << shift; - QUOTA_CONTENTION_RETRY_INTERVAL - .saturating_mul(multiplier) - .min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL) -} - -#[inline] -fn reset_quota_retry_scheduler( - sleep_slot: &mut Option>>, - wake_scheduled: &mut bool, - retry_attempt: &mut u8, -) { - *wake_scheduled = false; - *sleep_slot = None; - *retry_attempt = 0; -} - -fn poll_quota_retry_sleep( - sleep_slot: &mut Option>>, - wake_scheduled: &mut bool, - retry_attempt: &mut u8, - cx: &mut Context<'_>, -) { - if !*wake_scheduled { - *wake_scheduled = true; - #[cfg(test)] - QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed); - *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay( - *retry_attempt, - )))); - } - - if let Some(sleep) = sleep_slot.as_mut() - && sleep.as_mut().poll(cx).is_ready() - { - *sleep_slot = None; - *wake_scheduled = false; - *retry_attempt = retry_attempt.saturating_add(1); - cx.waker().wake_by_ref(); - } -} - -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); - -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; - -#[cfg(test)] -fn quota_user_lock_test_guard() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -#[cfg(test)] -fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { - quota_user_lock_test_guard() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -fn quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(Mutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -pub(crate) fn quota_user_lock_evict() { - if let Some(locks) = QUOTA_USER_LOCKS.get() { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } -} - -pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> { - let interval = interval.max(Duration::from_millis(1)); - #[cfg(test)] - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed); - tokio::spawn(async move { - loop { - tokio::time::sleep(interval).await; - quota_user_lock_evict(); - } - }) -} - -#[cfg(test)] -pub(crate) fn spawn_quota_user_lock_evictor_for_tests( - interval: Duration, -) -> tokio::task::JoinHandle<()> { - spawn_quota_user_lock_evictor(interval) -} - -#[cfg(test)] -pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() { - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed); -} - -#[cfg(test)] -pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 { - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed) -} - -fn quota_user_lock(user: &str) -> Arc> { - let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - return quota_overflow_user_lock(user); - } - - let created = Arc::new(Mutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) +fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool { + remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES } impl AsyncRead for StatsIo { @@ -453,93 +292,60 @@ impl AsyncRead for StatsIo { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Relaxed) { + if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } - let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - cx, - ); - return Poll::Pending; - } + let mut remaining_before = None; + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); } - } else { - None - }; - - let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - reset_quota_retry_scheduler( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - ); - - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + remaining_before = Some(remaining); } + let before = buf.filled().len(); match Pin::new(&mut this.inner).poll_read(cx, buf) { Poll::Ready(Ok(())) => { let n = buf.filled().len() - before; if n > 0 { - let mut reached_quota_boundary = false; - if let Some(limit) = this.quota_limit { - let used = this.stats.get_user_total_octets(&this.user); - if used >= limit { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - - let remaining = limit - used; - if (n as u64) > remaining { - // Fail closed: when a single read chunk would cross quota, - // stop relay immediately without accounting beyond the cap. - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - - reached_quota_boundary = (n as u64) == remaining; - } + let n_to_charge = n as u64; // C→S: client sent data this.counters .c2s_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_from(&this.user, n as u64); - this.stats.increment_user_msgs_from(&this.user); + this.stats + .add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_from_handle(this.user_stats.as_ref()); - if reached_quota_boundary { - this.quota_exceeded.store(true, Ordering::Relaxed); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "C->S"); @@ -558,87 +364,57 @@ impl AsyncWrite for StatsIo { buf: &[u8], ) -> Poll> { let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Relaxed) { + if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } - let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - reset_quota_retry_scheduler( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - ); - - let write_buf = if let Some(limit) = this.quota_limit { - let used = this.stats.get_user_total_octets(&this.user); - if used >= limit { - this.quota_exceeded.store(true, Ordering::Relaxed); + let mut remaining_before = None; + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } + remaining_before = Some(remaining); + } - let remaining = (limit - used) as usize; - if buf.len() > remaining { - // Fail closed: do not emit partial S->C payload when remaining - // quota cannot accommodate the pending write request. - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - buf - } else { - buf - }; - - match Pin::new(&mut this.inner).poll_write(cx, write_buf) { + match Pin::new(&mut this.inner).poll_write(cx, buf) { Poll::Ready(Ok(n)) => { if n > 0 { + let n_to_charge = n as u64; + // S→C: data written to client this.counters .s2c_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_to(&this.user, n as u64); - this.stats.increment_user_msgs_to(&this.user); + this.stats + .add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_to_handle(this.user_stats.as_ref()); - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "S->C"); @@ -732,7 +508,7 @@ where let now = Instant::now(); let idle = wd_counters.idle_duration(now, epoch); - if wd_quota_exceeded.load(Ordering::Relaxed) { + if wd_quota_exceeded.load(Ordering::Acquire) { warn!(user = %wd_user, "User data quota reached, closing relay"); return; } @@ -870,18 +646,10 @@ where } } -#[cfg(test)] -#[path = "tests/relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/relay_adversarial_tests.rs"] mod adversarial_tests; -#[cfg(test)] -#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"] -mod relay_quota_lock_pressure_adversarial_tests; - #[cfg(test)] #[path = "tests/relay_quota_boundary_blackhat_tests.rs"] mod relay_quota_boundary_blackhat_tests; @@ -901,71 +669,3 @@ mod relay_quota_extended_attack_surface_security_tests; #[cfg(test)] #[path = "tests/relay_watchdog_delta_security_tests.rs"] mod relay_watchdog_delta_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"] -mod relay_quota_waker_storm_adversarial_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] -mod relay_quota_wake_liveness_regression_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_identity_security_tests.rs"] -mod relay_quota_lock_identity_security_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"] -mod relay_cross_mode_quota_lock_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"] -mod relay_quota_retry_scheduler_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"] -mod relay_cross_mode_quota_fairness_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"] -mod relay_cross_mode_pipeline_hol_integration_security_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"] -mod relay_cross_mode_pipeline_latency_benchmark_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_backoff_security_tests.rs"] -mod relay_quota_retry_backoff_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"] -mod relay_quota_retry_backoff_benchmark_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"] -mod relay_dual_lock_backoff_regression_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"] -mod relay_dual_lock_contention_matrix_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"] -mod relay_dual_lock_race_harness_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"] -mod relay_dual_lock_alternating_contention_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"] -mod relay_quota_retry_allocation_latency_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"] -mod relay_quota_lock_eviction_lifecycle_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"] -mod relay_quota_lock_eviction_stress_security_tests; diff --git a/src/stats/mod.rs b/src/stats/mod.rs index dc455a1..7d8aef3 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -238,10 +238,12 @@ pub struct Stats { me_inline_recovery_total: AtomicU64, ip_reservation_rollback_tcp_limit_total: AtomicU64, ip_reservation_rollback_quota_limit_total: AtomicU64, + quota_write_fail_bytes_total: AtomicU64, + quota_write_fail_events_total: AtomicU64, telemetry_core_enabled: AtomicBool, telemetry_user_enabled: AtomicBool, telemetry_me_level: AtomicU8, - user_stats: DashMap, + user_stats: DashMap>, user_stats_last_cleanup_epoch_secs: AtomicU64, start_time: parking_lot::RwLock>, } @@ -254,9 +256,51 @@ pub struct UserStats { pub octets_to_client: AtomicU64, pub msgs_from_client: AtomicU64, pub msgs_to_client: AtomicU64, + /// Total bytes charged against per-user quota admission. + /// + /// This counter is the single source of truth for quota enforcement and + /// intentionally tracks attempted traffic, not guaranteed delivery. + pub quota_used: AtomicU64, pub last_seen_epoch_secs: AtomicU64, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QuotaReserveError { + LimitExceeded, + Contended, +} + +impl UserStats { + #[inline] + pub fn quota_used(&self) -> u64 { + self.quota_used.load(Ordering::Relaxed) + } + + /// Attempts one CAS reservation step against the quota counter. + /// + /// Callers control retry/yield policy. This primitive intentionally does + /// not block or sleep so both sync poll paths and async paths can wrap it + /// with their own contention strategy. + #[inline] + pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result { + let current = self.quota_used.load(Ordering::Relaxed); + if bytes > limit.saturating_sub(current) { + return Err(QuotaReserveError::LimitExceeded); + } + + let next = current.saturating_add(bytes); + match self.quota_used.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => Ok(next), + Err(_) => Err(QuotaReserveError::Contended), + } + } +} + impl Stats { pub fn new() -> Self { let stats = Self::default(); @@ -316,6 +360,70 @@ impl Stats { .store(Self::now_epoch_secs(), Ordering::Relaxed); } + pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc { + self.maybe_cleanup_user_stats(); + if let Some(existing) = self.user_stats.get(user) { + let handle = Arc::clone(existing.value()); + Self::touch_user_stats(handle.as_ref()); + return handle; + } + + let entry = self.user_stats.entry(user.to_string()).or_default(); + if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 { + Self::touch_user_stats(entry.value().as_ref()); + } + Arc::clone(entry.value()) + } + + #[inline] + pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + } + + /// Charges already committed bytes in a post-I/O path. + /// + /// This helper is intentionally separate from `quota_try_reserve` to avoid + /// mixing reserve and post-charge on a single I/O event. + #[inline] + pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 { + Self::touch_user_stats(user_stats); + user_stats + .quota_used + .fetch_add(bytes, Ordering::Relaxed) + .saturating_add(bytes) + } + fn maybe_cleanup_user_stats(&self) { const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; @@ -1114,6 +1222,18 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) { + if self.telemetry_core_enabled() { + self.quota_write_fail_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_quota_write_fail_events_total(&self) { + if self.telemetry_core_enabled() { + self.quota_write_fail_events_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_endpoint_quarantine_total(&self) { if self.telemetry_me_allows_normal() { self.me_endpoint_quarantine_total @@ -1764,19 +1884,19 @@ impl Stats { self.ip_reservation_rollback_quota_limit_total .load(Ordering::Relaxed) } + pub fn get_quota_write_fail_bytes_total(&self) -> u64 { + self.quota_write_fail_bytes_total.load(Ordering::Relaxed) + } + pub fn get_quota_write_fail_events_total(&self) -> u64 { + self.quota_write_fail_events_total.load(Ordering::Relaxed) + } pub fn increment_user_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.connects.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); stats.connects.fetch_add(1, Ordering::Relaxed); } @@ -1784,14 +1904,8 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.curr_connects.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); } @@ -1800,9 +1914,8 @@ impl Stats { return true; } - self.maybe_cleanup_user_stats(); - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); @@ -1827,7 +1940,7 @@ impl Stats { pub fn decrement_user_curr_connects(&self, user: &str) { self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); + Self::touch_user_stats(stats.value().as_ref()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); loop { @@ -1858,86 +1971,32 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_from_handle(stats.as_ref(), bytes); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); - } - - pub fn sub_user_octets_to(&self, user: &str, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - self.maybe_cleanup_user_stats(); - let Some(stats) = self.user_stats.get(user) else { - return; - }; - - Self::touch_user_stats(stats.value()); - let counter = &stats.octets_to_client; - let mut current = counter.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(bytes); - match counter.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(actual) => current = actual, - } - } + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_to_handle(stats.as_ref(), bytes); } pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_from_handle(stats.as_ref()); } pub fn increment_user_msgs_to(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_to_handle(stats.as_ref()); } pub fn get_user_total_octets(&self, user: &str) -> u64 { @@ -1950,6 +2009,13 @@ impl Stats { .unwrap_or(0) } + pub fn get_user_quota_used(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| s.quota_used.load(Ordering::Relaxed)) + .unwrap_or(0) + } + pub fn get_handshake_timeouts(&self) -> u64 { self.handshake_timeouts.load(Ordering::Relaxed) } @@ -2015,7 +2081,7 @@ impl Stats { .load(Ordering::Relaxed) } - pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> { + pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc> { self.user_stats.iter() } @@ -2163,6 +2229,22 @@ impl ReplayChecker { found } + fn check_only_internal( + &self, + data: &[u8], + shards: &[Mutex], + window: Duration, + ) -> bool { + self.checks.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = shards[idx].lock(); + let found = shard.check(data, Instant::now(), window); + if found { + self.hits.fetch_add(1, Ordering::Relaxed); + } + found + } + fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { self.additions.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); @@ -2186,7 +2268,7 @@ impl ReplayChecker { self.add_only(data, &self.handshake_shards, self.window) } pub fn check_tls_digest(&self, data: &[u8]) -> bool { - self.check_and_add_tls_digest(data) + self.check_only_internal(data, &self.tls_shards, self.tls_window) } pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data, &self.tls_shards, self.tls_window) @@ -2289,6 +2371,7 @@ impl ReplayStats { mod tests { use super::*; use crate::config::MeTelemetryLevel; + use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; #[test] @@ -2457,6 +2540,60 @@ mod tests { } assert_eq!(checker.stats().total_entries, 500); } + + #[test] + fn test_quota_reserve_under_contention_hits_limit_exactly() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let limit = 8_192u64; + let mut workers = Vec::new(); + + for _ in 0..8 { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(1, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + Err(QuotaReserveError::LimitExceeded) => { + break; + } + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + limit, + "successful reservations must stop exactly at limit" + ); + assert_eq!(user_stats.quota_used(), limit); + } + + #[test] + fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() { + let stats = Stats::new(); + let user = "quota-authoritative-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + + stats.add_user_octets_to_handle(&user_stats, 5); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 0); + + stats.quota_charge_post_write(&user_stats, 7); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 7); + } } #[cfg(test)] @@ -2466,7 +2603,3 @@ mod connection_lease_security_tests; #[cfg(test)] #[path = "tests/replay_checker_security_tests.rs"] mod replay_checker_security_tests; - -#[cfg(test)] -#[path = "tests/user_octets_sub_security_tests.rs"] -mod user_octets_sub_security_tests; diff --git a/src/stats/tests/user_octets_sub_security_tests.rs b/src/stats/tests/user_octets_sub_security_tests.rs deleted file mode 100644 index d4e7580..0000000 --- a/src/stats/tests/user_octets_sub_security_tests.rs +++ /dev/null @@ -1,151 +0,0 @@ -use super::*; -use std::sync::Arc; -use std::thread; - -#[test] -fn sub_user_octets_to_underflow_saturates_at_zero() { - let stats = Stats::new(); - let user = "sub-underflow-user"; - - stats.add_user_octets_to(user, 3); - stats.sub_user_octets_to(user, 100); - - assert_eq!(stats.get_user_total_octets(user), 0); -} - -#[test] -fn sub_user_octets_to_does_not_affect_octets_from_client() { - let stats = Stats::new(); - let user = "sub-isolation-user"; - - stats.add_user_octets_from(user, 17); - stats.add_user_octets_to(user, 5); - stats.sub_user_octets_to(user, 3); - - assert_eq!(stats.get_user_total_octets(user), 19); -} - -#[test] -fn light_fuzz_add_sub_model_matches_saturating_reference() { - let stats = Stats::new(); - let user = "sub-fuzz-user"; - let mut seed = 0x91D2_4CB8_EE77_1101u64; - let mut model_to = 0u64; - - for _ in 0..8192 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let amt = ((seed >> 8) & 0x3f) + 1; - if (seed & 1) == 0 { - stats.add_user_octets_to(user, amt); - model_to = model_to.saturating_add(amt); - } else { - stats.sub_user_octets_to(user, amt); - model_to = model_to.saturating_sub(amt); - } - } - - assert_eq!(stats.get_user_total_octets(user), model_to); -} - -#[test] -fn stress_parallel_add_sub_never_underflows_or_panics() { - let stats = Arc::new(Stats::new()); - let user = "sub-stress-user"; - // Pre-fund with a large offset so subtractions never saturate at zero. - // This guarantees commutative updates, making the final state deterministic. - let base_offset = 10_000_000u64; - stats.add_user_octets_to(user, base_offset); - - let mut workers = Vec::new(); - - for tid in 0..16u64 { - let stats_for_thread = Arc::clone(&stats); - workers.push(thread::spawn(move || { - let mut seed = 0xD00D_1000_0000_0000u64 ^ tid; - let mut net_delta = 0i64; - for _ in 0..4096 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let amt = ((seed >> 8) & 0x1f) + 1; - - if (seed & 1) == 0 { - stats_for_thread.add_user_octets_to(user, amt); - net_delta += amt as i64; - } else { - stats_for_thread.sub_user_octets_to(user, amt); - net_delta -= amt as i64; - } - } - - net_delta - })); - } - - let mut expected_net_delta = 0i64; - for worker in workers { - expected_net_delta += worker - .join() - .expect("sub-user stress worker must not panic"); - } - - let expected_total = (base_offset as i64 + expected_net_delta) as u64; - let total = stats.get_user_total_octets(user); - assert_eq!( - total, expected_total, - "concurrent add/sub lost updates or suffered ABA races" - ); -} - -#[test] -fn sub_user_octets_to_missing_user_is_noop() { - let stats = Stats::new(); - stats.sub_user_octets_to("missing-user", 1024); - assert_eq!(stats.get_user_total_octets("missing-user"), 0); -} - -#[test] -fn stress_parallel_per_user_models_remain_exact() { - let stats = Arc::new(Stats::new()); - let mut workers = Vec::new(); - - for tid in 0..16u64 { - let stats_for_thread = Arc::clone(&stats); - workers.push(thread::spawn(move || { - let user = format!("sub-per-user-{tid}"); - let mut seed = 0xFACE_0000_0000_0000u64 ^ tid; - let mut model = 0u64; - - for _ in 0..4096 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let amt = ((seed >> 8) & 0x3f) + 1; - - if (seed & 1) == 0 { - stats_for_thread.add_user_octets_to(&user, amt); - model = model.saturating_add(amt); - } else { - stats_for_thread.sub_user_octets_to(&user, amt); - model = model.saturating_sub(amt); - } - } - - (user, model) - })); - } - - for worker in workers { - let (user, model) = worker - .join() - .expect("per-user subtract stress worker must not panic"); - assert_eq!( - stats.get_user_total_octets(&user), - model, - "per-user parallel model diverged" - ); - } -} \ No newline at end of file From 2f9fddfa6fd51580ce0571afe76c3dc71956df4f Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:21:53 +0300 Subject: [PATCH 09/29] Old Test Deletion --- ...lay_blackhat_campaign_integration_tests.rs | 113 - ...relay_coverage_high_risk_security_tests.rs | 777 ----- ..._lock_release_regression_security_tests.rs | 295 -- ...s_mode_lookup_efficiency_security_tests.rs | 116 - ...s_mode_quota_lock_matrix_security_tests.rs | 376 --- ...s_mode_quota_reservation_security_tests.rs | 254 -- .../middle_relay_hol_quota_security_tests.rs | 232 -- ..._extended_attack_surface_security_tests.rs | 372 --- ...elay_quota_overflow_lock_security_tests.rs | 131 - ...lay_quota_reservation_adversarial_tests.rs | 1066 ------- ...uota_reservation_extreme_security_tests.rs | 399 --- .../tests/middle_relay_security_tests.rs | 2517 ----------------- ...k_registry_cross_mode_adversarial_tests.rs | 108 - ...pipeline_hol_integration_security_tests.rs | 267 -- ...peline_latency_benchmark_security_tests.rs | 213 -- ...lay_cross_mode_quota_fairness_tdd_tests.rs | 604 ---- ...ay_cross_mode_quota_lock_security_tests.rs | 81 - ...k_alternating_contention_security_tests.rs | 340 --- ..._lock_backoff_regression_security_tests.rs | 74 - ...l_lock_contention_matrix_security_tests.rs | 325 --- ...y_dual_lock_race_harness_security_tests.rs | 128 - ...quota_lock_eviction_lifecycle_tdd_tests.rs | 79 - ...ota_lock_eviction_stress_security_tests.rs | 153 - ...elay_quota_lock_identity_security_tests.rs | 135 - ...y_quota_lock_pressure_adversarial_tests.rs | 440 --- ...retry_allocation_latency_security_tests.rs | 249 -- ..._retry_backoff_benchmark_security_tests.rs | 241 -- ...elay_quota_retry_backoff_security_tests.rs | 339 --- .../relay_quota_retry_scheduler_tdd_tests.rs | 246 -- ...ay_quota_wake_liveness_regression_tests.rs | 294 -- ...lay_quota_waker_storm_adversarial_tests.rs | 310 -- src/proxy/tests/relay_security_tests.rs | 1284 --------- 32 files changed, 12558 deletions(-) delete mode 100644 src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs delete mode 100644 src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_hol_quota_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs delete mode 100644 src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs delete mode 100644 src/proxy/tests/middle_relay_security_tests.rs delete mode 100644 src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs delete mode 100644 src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs delete mode 100644 src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs delete mode 100644 src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs delete mode 100644 src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs delete mode 100644 src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs delete mode 100644 src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs delete mode 100644 src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs delete mode 100644 src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs delete mode 100644 src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs delete mode 100644 src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs delete mode 100644 src/proxy/tests/relay_quota_lock_identity_security_tests.rs delete mode 100644 src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs delete mode 100644 src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs delete mode 100644 src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs delete mode 100644 src/proxy/tests/relay_quota_retry_backoff_security_tests.rs delete mode 100644 src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs delete mode 100644 src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs delete mode 100644 src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs delete mode 100644 src/proxy/tests/relay_security_tests.rs diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs deleted file mode 100644 index 6f0e91a..0000000 --- a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs +++ /dev/null @@ -1,113 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() { - let _guard = super::quota_user_lock_test_scope(); - let _pressure_guard = super::relay_idle_pressure_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "middle-blackhat-held-{}-{idx}", - std::process::id() - ))); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: bounded lock cache must be saturated" - ); - - let (tx, _rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Close) - .await - .expect("queue prefill should succeed"); - - let pressure_seq_before = relay_pressure_event_seq(); - let pressure_errors = Arc::new(AtomicUsize::new(0)); - let mut pressure_workers = Vec::new(); - for _ in 0..16 { - let tx = tx.clone(); - let pressure_errors = Arc::clone(&pressure_errors); - pressure_workers.push(tokio::spawn(async move { - if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() { - pressure_errors.fetch_add(1, Ordering::Relaxed); - } - })); - } - - let stats = Arc::new(Stats::new()); - let user = format!("middle-blackhat-quota-race-{}", std::process::id()); - let gate = Arc::new(Barrier::new(16)); - - let mut quota_workers = Vec::new(); - for _ in 0..16u8 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - quota_workers.push(tokio::spawn(async move { - gate.wait().await; - let user_lock = quota_user_lock(&user); - let _quota_guard = user_lock.lock().await; - - if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) { - return false; - } - stats.add_user_octets_to(&user, 1); - true - })); - } - - let mut ok_count = 0usize; - let mut denied_count = 0usize; - for worker in quota_workers { - let result = timeout(Duration::from_secs(2), worker) - .await - .expect("quota worker must finish") - .expect("quota worker must not panic"); - if result { - ok_count += 1; - } else { - denied_count += 1; - } - } - - for worker in pressure_workers { - timeout(Duration::from_secs(2), worker) - .await - .expect("pressure worker must finish") - .expect("pressure worker must not panic"); - } - - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "black-hat campaign must not overshoot same-user quota under saturation" - ); - assert!(ok_count <= 1, "at most one quota contender may succeed"); - assert!( - denied_count >= 15, - "all remaining contenders must be quota-denied" - ); - - let pressure_seq_after = relay_pressure_event_seq(); - assert!( - pressure_seq_after > pressure_seq_before, - "queue pressure leg must trigger pressure accounting" - ); - assert!( - pressure_errors.load(Ordering::Relaxed) >= 1, - "at least one pressure worker should fail from persistent backpressure" - ); - - drop(retained); -} diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs deleted file mode 100644 index 44c201f..0000000 --- a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs +++ /dev/null @@ -1,777 +0,0 @@ -use super::*; -use crate::crypto::AesCtr; -use crate::crypto::SecureRandom; -use crate::stats::Stats; -use crate::stream::{BufferPool, PooledBuffer}; -use std::sync::Arc; -use tokio::io::AsyncReadExt; -use tokio::io::duplex; -use tokio::sync::mpsc; -use tokio::time::{Duration as TokioDuration, timeout}; - -fn make_pooled_payload(data: &[u8]) -> PooledBuffer { - let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -#[tokio::test] -async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("abridged quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 1 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read serialized abridged frame"); - let plaintext = decryptor.decrypt(&encrypted); - - assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8)); - assert_eq!(&plaintext[1..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_extended_header_is_encoded_correctly() { - let (mut read_side, write_side) = duplex(16 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - // Boundary where abridged switches to extended length encoding. - let payload = vec![0x5Au8; 0x7f * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("extended abridged payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read serialized extended abridged frame"); - let plaintext = decryptor.decrypt(&encrypted); - - assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set"); - assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]); - assert_eq!(&plaintext[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let err = write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &[1, 2, 3], - &rng, - &mut frame_buf, - ) - .await - .expect_err("misaligned abridged payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("4-byte aligned"), - "error should explain alignment contract, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let err = write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &[9, 8, 7, 6, 5], - &rng, - &mut frame_buf, - ) - .await - .expect_err("misaligned secure payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("Secure payload must be 4-byte aligned"), - "error should be explicit for fail-closed triage, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_payload_intermediate_quickack_sets_length_msb() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = b"hello-middle-relay"; - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - RPC_FLAG_QUICKACK, - payload, - &rng, - &mut frame_buf, - ) - .await - .expect("intermediate quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read intermediate frame"); - let plaintext = decryptor.decrypt(&encrypted); - - let mut len_bytes = [0u8; 4]; - len_bytes.copy_from_slice(&plaintext[..4]); - let len_with_flags = u32::from_le_bytes(len_bytes); - assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set"); - assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len()); - assert_eq!(&plaintext[4..], payload); -} - -#[tokio::test] -async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode. - - write_client_payload( - &mut writer, - ProtoTag::Secure, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("secure quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - // Secure mode adds 1..=3 bytes of randomized tail padding. - let mut encrypted_header = [0u8; 4]; - read_side - .read_exact(&mut encrypted_header) - .await - .expect("must read secure header"); - let decrypted_header = decryptor.decrypt(&encrypted_header); - let header: [u8; 4] = decrypted_header - .try_into() - .expect("decrypted secure header must be 4 bytes"); - let wire_len_raw = u32::from_le_bytes(header); - - assert_ne!( - wire_len_raw & 0x8000_0000, - 0, - "secure quickack bit must be set" - ); - - let wire_len = (wire_len_raw & 0x7fff_ffff) as usize; - assert!(wire_len >= payload.len()); - let padding_len = wire_len - payload.len(); - assert!( - (1..=3).contains(&padding_len), - "secure writer must add bounded random tail padding, got {padding_len}" - ); - - let mut encrypted_body = vec![0u8; wire_len]; - read_side - .read_exact(&mut encrypted_body) - .await - .expect("must read secure body"); - let decrypted_body = decryptor.decrypt(&encrypted_body); - assert_eq!(&decrypted_body[..payload.len()], payload.as_slice()); -} - -#[tokio::test] -#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"] -async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - // Exactly one 4-byte word above the encodable 24-bit abridged length range. - let payload = vec![0x00u8; (1 << 24) * 4]; - let err = write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect_err("oversized abridged payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("Abridged frame too large"), - "error must clearly indicate oversize fail-close path, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_ack_intermediate_is_little_endian() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - - write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44) - .await - .expect("ack serialization should succeed"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read ack bytes"); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes()); -} - -#[tokio::test] -async fn write_client_ack_abridged_is_big_endian() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - - write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF) - .await - .expect("ack serialization should succeed"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read ack bytes"); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes()); -} - -#[tokio::test] -async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() { - let (mut read_side, write_side) = duplex(1024 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0xABu8; 0x7e * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("boundary payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 1 + payload.len()]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain[0], 0x7e); - assert_eq!(&plain[1..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() { - let (mut read_side, write_side) = duplex(16 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0x42u8; 0x80 * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("extended payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain[0], 0x7f); - assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]); - assert_eq!(&plain[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_intermediate_zero_length_emits_header_only() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 0, - &[], - &rng, - &mut frame_buf, - ) - .await - .expect("zero-length intermediate payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &[0, 0, 0, 0]); -} - -#[tokio::test] -async fn write_client_payload_intermediate_ignores_unrelated_flags() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [7u8; 12]; - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 0x4000_0000, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 16]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - let len = u32::from_le_bytes(plain[0..4].try_into().unwrap()); - assert_eq!(len, payload.len() as u32, "only quickack bit may affect header"); - assert_eq!(&plain[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_secure_without_quickack_keeps_msb_clear() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [0x1Du8; 64]; - - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted_header = [0u8; 4]; - read_side.read_exact(&mut encrypted_header).await.unwrap(); - let plain_header = decryptor.decrypt(&encrypted_header); - let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); - let wire_len_raw = u32::from_le_bytes(h); - assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear"); -} - -#[tokio::test] -async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() { - let (mut read_side, write_side) = duplex(256 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [0x55u8; 100]; - let mut seen = [false; 4]; - - for _ in 0..96 { - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("secure payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted_header = [0u8; 4]; - read_side.read_exact(&mut encrypted_header).await.unwrap(); - let plain_header = decryptor.decrypt(&encrypted_header); - let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); - let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize; - let padding_len = wire_len - payload.len(); - assert!((1..=3).contains(&padding_len)); - seen[padding_len] = true; - - let mut encrypted_body = vec![0u8; wire_len]; - read_side.read_exact(&mut encrypted_body).await.unwrap(); - let _ = decryptor.decrypt(&encrypted_body); - } - - let distinct = (1..=3).filter(|idx| seen[*idx]).count(); - assert!( - distinct >= 2, - "padding generator should not collapse to a single outcome under campaign" - ); -} - -#[tokio::test] -async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() { - let (mut read_side, write_side) = duplex(128 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let p1 = vec![1u8; 8]; - let p2 = vec![2u8; 16]; - let p3 = vec![3u8; 20]; - - write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf) - .await - .unwrap(); - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - RPC_FLAG_QUICKACK, - &p2, - &rng, - &mut frame_buf, - ) - .await - .unwrap(); - write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf) - .await - .unwrap(); - writer.flush().await.unwrap(); - - // Frame 1: abridged short. - let mut e1 = vec![0u8; 1 + p1.len()]; - read_side.read_exact(&mut e1).await.unwrap(); - let d1 = decryptor.decrypt(&e1); - assert_eq!(d1[0], (p1.len() / 4) as u8); - assert_eq!(&d1[1..], p1.as_slice()); - - // Frame 2: intermediate with quickack. - let mut e2 = vec![0u8; 4 + p2.len()]; - read_side.read_exact(&mut e2).await.unwrap(); - let d2 = decryptor.decrypt(&e2); - let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap()); - assert_ne!(l2 & 0x8000_0000, 0); - assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len()); - assert_eq!(&d2[4..], p2.as_slice()); - - // Frame 3: secure with bounded tail. - let mut e3h = [0u8; 4]; - read_side.read_exact(&mut e3h).await.unwrap(); - let d3h = decryptor.decrypt(&e3h); - let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize; - assert!(l3 >= p3.len()); - assert!((1..=3).contains(&(l3 - p3.len()))); - let mut e3b = vec![0u8; l3]; - read_side.read_exact(&mut e3b).await.unwrap(); - let d3b = decryptor.decrypt(&e3b); - assert_eq!(&d3b[..p3.len()], p3.as_slice()); -} - -#[test] -fn should_yield_sender_boundary_matrix_blackhat() { - assert!(!should_yield_c2me_sender(0, false)); - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); - assert!(should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024), - true - )); -} - -#[test] -fn should_yield_sender_light_fuzz_matches_oracle() { - let mut s: u64 = 0xD00D_BAAD_F00D_CAFE; - for _ in 0..5000 { - s ^= s << 7; - s ^= s >> 9; - s ^= s << 8; - let sent = (s as usize) & 0x1fff; - let backlog = (s & 1) != 0; - - let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET; - assert_eq!(should_yield_c2me_sender(sent, backlog), expected); - } -} - -#[test] -fn quota_would_be_exceeded_exact_remaining_one_byte() { - let stats = Stats::new(); - let user = "quota-edge"; - let quota = 100u64; - stats.add_user_octets_to(user, 99); - - assert!( - !quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1), - "exactly remaining budget should be allowed" - ); - assert!( - quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), - "one byte beyond remaining budget must be rejected" - ); -} - -#[test] -fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() { - let stats = Stats::new(); - let user = "quota-saturating-edge"; - let quota = u64::MAX - 3; - stats.add_user_octets_to(user, u64::MAX - 4); - - assert!( - quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), - "saturating arithmetic edge must stay fail-closed" - ); -} - -#[test] -fn quota_exceeded_boundary_is_inclusive() { - let stats = Stats::new(); - let user = "quota-inclusive-boundary"; - stats.add_user_octets_to(user, 50); - - assert!(quota_exceeded_for_user(&stats, user, Some(50))); - assert!(!quota_exceeded_for_user(&stats, user, Some(51))); -} - -#[test] -fn quota_soft_helper_matches_capped_generic_helper_matrix() { - let stats = Stats::new(); - let user = "quota-soft-parity"; - - for used in [0u64, 1, 7, 63, 127, 255] { - stats.sub_user_octets_to(user, stats.get_user_total_octets(user)); - stats.add_user_octets_to(user, used); - - for quota in [8u64, 64, 128, 256] { - for overshoot in [0u64, 1, 5, 32] { - for bytes in [0u64, 1, 2, 7, 31, 64] { - let soft = quota_would_be_exceeded_for_user_soft( - &stats, - user, - Some(quota), - bytes, - overshoot, - ); - let capped = quota_would_be_exceeded_for_user( - &stats, - user, - Some(quota_soft_cap(quota, overshoot)), - bytes, - ); - assert_eq!( - soft, capped, - "soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}" - ); - } - } - } - } -} - -#[test] -fn quota_soft_helper_none_limit_never_rejects() { - let stats = Stats::new(); - let user = "quota-soft-none"; - stats.add_user_octets_to(user, u64::MAX); - - assert!(!quota_would_be_exceeded_for_user_soft( - &stats, - user, - None, - u64::MAX, - u64::MAX, - )); -} - -#[test] -fn quota_soft_cap_saturates_and_stays_fail_closed() { - let stats = Stats::new(); - let user = "quota-soft-saturating"; - let quota = u64::MAX - 2; - let overshoot = 100; - - assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX); - - stats.add_user_octets_to(user, u64::MAX - 1); - assert!(quota_would_be_exceeded_for_user_soft( - &stats, - user, - Some(quota), - 2, - overshoot, - )); -} - -#[tokio::test] -async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() { - let (tx, mut rx) = mpsc::channel::(4); - enqueue_c2me_command(&tx, C2MeCommand::Close) - .await - .expect("close should enqueue on fast path"); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .expect("must receive close command") - .expect("close command should be present"); - assert!(matches!(recv, C2MeCommand::Close)); -} - -#[tokio::test] -async fn enqueue_c2me_data_full_then_drain_preserves_order() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[1]), - flags: 10, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload(&[2, 2]), - flags: 20, - }, - ) - .await - }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - - let first = rx.recv().await.expect("first item should exist"); - match first { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1]); - assert_eq!(flags, 10); - } - C2MeCommand::Close => panic!("unexpected close as first item"), - } - - producer.await.unwrap().expect("producer should complete"); - - let second = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("second item should exist"); - match second { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[2, 2]); - assert_eq!(flags, 20); - } - C2MeCommand::Close => panic!("unexpected close as second item"), - } -} diff --git a/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs deleted file mode 100644 index a787aa6..0000000 --- a/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs +++ /dev/null @@ -1,295 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll, Waker}; -use tokio::io::AsyncWrite; -use tokio::sync::Notify; -use tokio::task::JoinSet; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -#[derive(Default)] -struct BlockingWriteState { - write_entered: AtomicBool, - released: AtomicBool, - write_waker: Mutex>, - write_entered_notify: Notify, -} - -struct BlockingWrite { - state: Arc, -} - -impl BlockingWrite { - fn new(state: Arc) -> Self { - Self { state } - } -} - -impl AsyncWrite for BlockingWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.state.write_entered.store(true, Ordering::Release); - self.state.write_entered_notify.notify_waiters(); - - if self.state.released.load(Ordering::Acquire) { - return Poll::Ready(Ok(buf.len())); - } - - if let Ok(mut slot) = self.state.write_waker.lock() { - *slot = Some(cx.waker().clone()); - } - - Poll::Pending - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -async fn wait_until_blocking_write_entered(state: &Arc) { - for _ in 0..8 { - if state.write_entered.load(Ordering::Acquire) { - return; - } - let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; - } - - panic!("blocking writer did not enter poll_write in bounded time"); -} - -fn release_blocking_write(state: &Arc) { - state.released.store(true, Ordering::Release); - if let Ok(mut slot) = state.write_waker.lock() - && let Some(waker) = slot.take() - { - waker.wake(); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() { - let stats = Arc::new(Stats::new()); - let user = format!("middle-cross-release-regression-{}", std::process::id()); - let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let writer_state = Arc::new(BlockingWriteState::default()); - - let first = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - let writer_state = Arc::clone(&writer_state); - tokio::spawn(async move { - let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); - let mut frame_buf = Vec::new(); - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(4), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 41_000, - false, - false, - ) - .await - }) - }; - - wait_until_blocking_write_entered(&writer_state).await; - - let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) - .await - .expect("cross-mode lock must be released while first write is pending"); - drop(guard); - - let second = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - tokio::spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - timeout( - Duration::from_millis(150), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xEE]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(4), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 41_001, - false, - false, - ), - ) - .await - }) - }; - - let second_result = second - .await - .expect("second task must not panic") - .expect("second write must not block on cross-mode lock"); - assert!( - matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })), - "second write must fail closed due to first write reservation" - ); - - release_blocking_write(&writer_state); - - let first_result = timeout(Duration::from_millis(300), first) - .await - .expect("first task timed out") - .expect("first task must not panic"); - assert!(first_result.is_ok()); - - assert_eq!(stats.get_user_total_octets(&user), 4); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() { - let stats = Arc::new(Stats::new()); - let user = format!("middle-cross-release-stress-{}", std::process::id()); - let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let writer_state = Arc::new(BlockingWriteState::default()); - - let first = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - let writer_state = Arc::clone(&writer_state); - tokio::spawn(async move { - let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); - let mut frame_buf = Vec::new(); - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x01, 0x02]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(3), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 41_100, - false, - false, - ) - .await - }) - }; - - wait_until_blocking_write_entered(&writer_state).await; - - let mut set = JoinSet::new(); - for idx in 0..48u64 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - set.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - timeout( - Duration::from_millis(200), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x10]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(3), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 41_200 + idx, - false, - false, - ), - ) - .await - }); - } - - let mut ok = 0usize; - let mut quota_exceeded = 0usize; - while let Some(done) = set.join_next().await { - let timed = done.expect("waiter task must not panic"); - let result = timed.expect("waiter must not block behind pending first write"); - match result { - Ok(_) => ok += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1, - Err(other) => panic!("unexpected error in waiter: {other:?}"), - } - } - - assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota"); - assert_eq!(quota_exceeded, 47); - - release_blocking_write(&writer_state); - - let first_result = timeout(Duration::from_millis(300), first) - .await - .expect("first task timed out") - .expect("first task must not panic"); - assert!(first_result.is_ok()); - - assert_eq!(stats.get_user_total_octets(&user), 3); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); -} diff --git a/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs deleted file mode 100644 index 37e1b87..0000000 --- a/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs +++ /dev/null @@ -1,116 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Mutex, OnceLock}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -fn lookup_counter_test_lock() -> &'static Mutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) -} - -#[tokio::test] -async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("middle-cross-mode-lookup-{}", std::process::id()); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - for idx in 0..8u64 { - let outcome = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAB]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - Some(&cross_mode_lock), - &bytes_me2c, - 20_000 + idx, - false, - false, - ) - .await; - - assert!(outcome.is_ok()); - } - - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 0, - "prefetched lock path must not re-query lock registry per frame" - ); - assert_eq!(stats.get_user_total_octets(&user), 8); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8); -} - -#[tokio::test] -async fn control_without_prefetched_lock_still_uses_registry_lookup_path() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("middle-cross-mode-lookup-control-{}", std::process::id()); - - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let outcome = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xCD]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - None, - &bytes_me2c, - 20_100, - false, - false, - ) - .await; - - assert!(outcome.is_ok()); - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 1, - "fallback path without prefetched lock should perform a registry lookup" - ); -} diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs deleted file mode 100644 index bc7c857..0000000 --- a/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs +++ /dev/null @@ -1,376 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -#[tokio::test] -async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() { - let stats = Stats::new(); - let user = format!("middle-cross-matrix-positive-{}", std::process::id()); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(128), - 0, - &bytes_me2c, - 10_001, - false, - false, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(&user), 4); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); -} - -#[tokio::test] -async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() { - let stats = Stats::new(); - let user = format!("middle-cross-matrix-negative-{}", std::process::id()); - let held = cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock before ME->C call"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x41]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(256), - 0, - &bytes_me2c, - 10_002, - false, - false, - ), - ) - .await; - - assert!(blocked.is_err()); - drop(held_guard); -} - -#[tokio::test] -async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() { - let stats = Stats::new(); - let user = format!("middle-cross-matrix-edge-none-{}", std::process::id()); - let held = cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock while quota is disabled"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let outcome = timeout( - Duration::from_millis(80), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x11, 0x22]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - None, - 0, - &bytes_me2c, - 10_003, - false, - false, - ), - ) - .await - .expect("quota-none path must not wait on cross-mode lock"); - - assert!(outcome.is_ok()); - drop(held_guard); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() { - let stats = Arc::new(Stats::new()); - let user = format!("middle-cross-matrix-adversarial-{}", std::process::id()); - let limit = 64u64; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut tasks = Vec::new(); - - for idx in 0..256u64 { - let stats = Arc::clone(&stats); - let bytes_me2c = Arc::clone(&bytes_me2c); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xEE]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(limit), - 0, - bytes_me2c.as_ref(), - 11_000 + idx, - false, - false, - ) - .await - })); - } - - let mut ok = 0usize; - for task in tasks { - match task.await.expect("task must not panic") { - Ok(_) => ok += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"), - } - } - - assert_eq!(ok, limit as usize); - assert_eq!(stats.get_user_total_octets(&user), limit); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() { - let user = format!("middle-cross-matrix-integration-{}", std::process::id()); - let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user); - let middle_lock = cross_mode_quota_user_lock_for_tests(&user); - assert!( - Arc::ptr_eq(&relay_lock, &middle_lock), - "relay and middle-relay must share the same cross-mode lock identity" - ); - - let held_guard = relay_lock - .try_lock() - .expect("test must hold shared cross-mode lock"); - - let stats = Stats::new(); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let middle_blocked = timeout( - Duration::from_millis(25), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x92]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 12_001, - false, - false, - ), - ) - .await; - assert!(middle_blocked.is_err()); - - drop(held_guard); - - let middle_ready = timeout( - Duration::from_millis(250), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x94]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 12_002, - false, - false, - ), - ) - .await - .expect("middle path must complete after release"); - - assert!(middle_ready.is_ok()); -} - -#[tokio::test] -async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() { - let stats = Stats::new(); - let user = format!("middle-cross-matrix-fuzz-{}", std::process::id()); - let bytes_me2c = AtomicU64::new(0); - let mut seed = 0xC0DE_1234_55AA_9988u64; - - for case in 0..96u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold = (seed & 0x03) == 0; - let mut held_lock = None; - let maybe_guard = if hold { - held_lock = Some(cross_mode_quota_user_lock_for_tests(&user)); - Some( - held_lock - .as_ref() - .expect("held lock should be present") - .try_lock() - .expect("cross-mode lock should be acquirable in fuzz round"), - ) - } else { - None - }; - - let payload_len = ((seed >> 8) as usize % 8) + 1; - let payload = vec![(seed & 0xff) as u8; payload_len]; - let before = stats.get_user_total_octets(&user); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - - let timed = timeout( - Duration::from_millis(20), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 13_000 + case as u64, - false, - false, - ), - ) - .await; - - if hold { - assert!(timed.is_err(), "held-lock fuzz round must block within timeout"); - assert_eq!(stats.get_user_total_octets(&user), before); - } else { - let done = timed.expect("unheld fuzz round must complete in time"); - assert!(done.is_ok()); - } - - drop(maybe_guard); - drop(held_lock); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), stats.get_user_total_octets(&user)); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() { - let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id()); - let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id()); - - let held = cross_mode_quota_user_lock_for_tests(&held_user); - let held_guard = held - .try_lock() - .expect("test must hold lock for blocked user"); - - let mut tasks = Vec::new(); - for idx in 0..64u64 { - let user = free_user.clone(); - tasks.push(tokio::spawn(async move { - let stats = Stats::new(); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xA0]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1), - 0, - &bytes_me2c, - 14_000 + idx, - false, - false, - ) - .await - })); - } - - timeout(Duration::from_secs(2), async { - for task in tasks { - let done = task.await.expect("free-user task must not panic"); - assert!(done.is_ok()); - } - }) - .await - .expect("free-user tasks should complete without waiting for held user's lock"); - - drop(held_guard); -} diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs deleted file mode 100644 index 51092bd..0000000 --- a/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs +++ /dev/null @@ -1,254 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll, Waker}; -use tokio::io::AsyncWrite; -use tokio::sync::Notify; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -#[derive(Default)] -struct BlockingWriteState { - write_entered: AtomicBool, - released: AtomicBool, - write_waker: Mutex>, - write_entered_notify: Notify, -} - -struct BlockingWrite { - state: Arc, -} - -impl BlockingWrite { - fn new(state: Arc) -> Self { - Self { state } - } -} - -impl AsyncWrite for BlockingWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.state.write_entered.store(true, Ordering::Release); - self.state.write_entered_notify.notify_waiters(); - - if self.state.released.load(Ordering::Acquire) { - return Poll::Ready(Ok(buf.len())); - } - - if let Ok(mut slot) = self.state.write_waker.lock() { - *slot = Some(cx.waker().clone()); - } - Poll::Pending - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -async fn wait_until_blocking_write_entered(state: &Arc) { - for _ in 0..8 { - if state.write_entered.load(Ordering::Acquire) { - return; - } - let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; - } - - panic!("blocking writer did not enter poll_write in bounded time"); -} - -fn release_blocking_write(state: &Arc) { - state.released.store(true, Ordering::Release); - if let Ok(mut slot) = state.write_waker.lock() - && let Some(waker) = slot.take() - { - waker.wake(); - } -} - -#[tokio::test] -async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() { - let stats = Stats::new(); - let user = format!("middle-me2c-cross-mode-held-{}", std::process::id()); - let held = cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold shared cross-mode lock before ME->C write path"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x41]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 9901, - false, - false, - ), - ) - .await; - - assert!( - blocked.is_err(), - "ME->C quota reservation path must be serialized by held shared cross-mode lock" - ); - - drop(held_guard); - - let released = timeout( - Duration::from_millis(250), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x42]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 9902, - false, - false, - ), - ) - .await - .expect("ME->C write must complete after cross-mode lock release"); - - assert!(released.is_ok()); -} - -#[tokio::test] -async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() { - let stats = Stats::new(); - let user = format!("middle-me2c-cross-mode-free-{}", std::process::id()); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let outcome = timeout( - Duration::from_millis(250), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x55, 0x66]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 9903, - false, - false, - ), - ) - .await - .expect("uncontended ME->C path should not stall"); - - assert!(outcome.is_ok()); - assert_eq!(stats.get_user_total_octets(&user), 2); - assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() { - let stats = Arc::new(Stats::new()); - let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id()); - let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let writer_state = Arc::new(BlockingWriteState::default()); - - let worker = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - let writer_state = Arc::clone(&writer_state); - tokio::spawn(async move { - let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); - let mut frame_buf = Vec::new(); - let rng = SecureRandom::new(); - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - stats.as_ref(), - &user, - Some(1024), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 9910, - false, - false, - ) - .await - }) - }; - - wait_until_blocking_write_entered(&writer_state).await; - - let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) - .await - .expect("cross-mode lock must be free while ME->C write is pending"); - drop(acquired_guard); - - release_blocking_write(&writer_state); - - let result = timeout(Duration::from_millis(300), worker) - .await - .expect("ME->C worker timed out after releasing blocking writer") - .expect("ME->C worker must not panic"); - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(&user), 4); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); -} diff --git a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs deleted file mode 100644 index 3ce0235..0000000 --- a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs +++ /dev/null @@ -1,232 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::task::{Context, Poll, Waker}; -use tokio::io::AsyncWrite; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -#[derive(Default)] -struct GateState { - open: AtomicBool, - parked_waker: std::sync::Mutex>, -} - -impl GateState { - fn open(&self) { - self.open.store(true, Ordering::Relaxed); - if let Ok(mut guard) = self.parked_waker.lock() - && let Some(w) = guard.take() - { - w.wake(); - } - } - - fn has_waiter(&self) -> bool { - self.parked_waker - .lock() - .map(|guard| guard.is_some()) - .unwrap_or(false) - } -} - -#[derive(Default)] -struct GateWriter { - gate: Arc, -} - -impl GateWriter { - fn new(gate: Arc) -> Self { - Self { gate } - } -} - -impl AsyncWrite for GateWriter { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if self.gate.open.load(Ordering::Relaxed) { - return Poll::Ready(Ok(buf.len())); - } - - if let Ok(mut guard) = self.gate.parked_waker.lock() { - *guard = Some(cx.waker().clone()); - } - Poll::Pending - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -struct FailingWriter; - -impl AsyncWrite for FailingWriter { - fn poll_write( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "injected writer failure", - ))) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let rng = SecureRandom::new(); - let quota_limit = Some(1024); - let user = "hol-quota-user"; - - let gate = Arc::new(GateState::default()); - - let mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate))); - let slow_task = tokio::spawn(async move { - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]), - }, - &mut blocked_writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - user, - quota_limit, - 0, - &bytes_me2c, - 7001, - false, - false, - ) - .await - }); - - timeout(Duration::from_millis(100), async { - loop { - if gate.has_waiter() { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("first writer must reach backpressure and park"); - - let stats_fast = Stats::new(); - let bytes_fast = AtomicU64::new(0); - let rng_fast = SecureRandom::new(); - let mut fast_writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_fast = Vec::new(); - - timeout( - Duration::from_millis(50), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x41]), - }, - &mut fast_writer, - ProtoTag::Intermediate, - &rng_fast, - &mut frame_buf_fast, - &stats_fast, - user, - quota_limit, - 0, - &bytes_fast, - 7002, - false, - false, - ), - ) - .await - .expect("peer connection must not be blocked by same-user stalled write") - .expect("fast peer write must succeed"); - - gate.open(); - let slow_result = timeout(Duration::from_secs(1), slow_task) - .await - .expect("stalled task must complete once gate opens") - .expect("stalled task must not panic"); - assert!(slow_result.is_ok()); -} - -#[tokio::test] -async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() { - let stats = Stats::new(); - let user = "rollback-user"; - stats.add_user_octets_from(user, 7); - - let bytes_me2c = AtomicU64::new(0); - let rng = SecureRandom::new(); - let mut writer = make_crypto_writer(FailingWriter); - let mut frame_buf = Vec::new(); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - user, - Some(64), - 0, - &bytes_me2c, - 7003, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::Io(_)))); - assert_eq!( - stats.get_user_total_octets(user), - 7, - "failed client write must not overcharge user quota accounting" - ); - assert_eq!( - bytes_me2c.load(Ordering::Relaxed), - 0, - "failed client write must not inflate ME->C forensic byte counter" - ); -} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs deleted file mode 100644 index 29384e0..0000000 --- a/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs +++ /dev/null @@ -1,372 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, OnceLock, Mutex}; -use tokio::sync::Mutex as AsyncMutex; -use tokio::task::JoinSet; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -fn lookup_test_lock() -> &'static Mutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) -} - -#[tokio::test] -async fn positive_me2c_quota_counts_bytes_exactly_once() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Stats::new(); - let user = format!("quota-middle-ext-positive-{}", std::process::id()); - let lock = Arc::new(AsyncMutex::new(())); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3, 4, 5]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(64), - 0, - Some(&lock), - &bytes_me2c, - 70_001, - false, - false, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(&user), 5); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); -} - -#[tokio::test] -async fn negative_held_crossmode_lock_blocks_me2c_write() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Stats::new(); - let user = format!("quota-middle-ext-negative-{}", std::process::id()); - - let lock = Arc::new(AsyncMutex::new(())); - let _held = lock.try_lock().expect("lock must be held"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xFE]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(16), - 0, - Some(&lock), - &bytes_me2c, - 70_101, - false, - false, - ), - ) - .await; - - assert!(blocked.is_err()); - assert_eq!(stats.get_user_total_octets(&user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn edge_zero_quota_zero_payload_is_fail_closed() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Stats::new(); - let user = format!("quota-middle-ext-edge-{}", std::process::id()); - - let lock = Arc::new(AsyncMutex::new(())); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::new(), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(0), - 0, - Some(&lock), - &bytes_me2c, - 70_201, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(&user), 0); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Arc::new(Stats::new()); - let user = format!("quota-middle-ext-blackhat-{}", std::process::id()); - let quota = 64u64; - let lock = Arc::new(AsyncMutex::new(())); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - - let mut set = JoinSet::new(); - for i in 0..256u64 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let lock = Arc::clone(&lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - - set.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize]; - - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(quota), - 0, - Some(&lock), - bytes_me2c.as_ref(), - 70_301 + i, - false, - false, - ) - .await - }); - } - - let mut succeeded = 0usize; - while let Some(done) = set.join_next().await { - match done.expect("task must not panic") { - Ok(_) => succeeded += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error {other:?}"), - } - } - - assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed)); - assert!(stats.get_user_total_octets(&user) <= quota); - assert!(succeeded <= quota as usize); -} - -#[tokio::test] -async fn integration_shared_prefetched_lock_blocks_then_releases_writer() { - let stats = Stats::new(); - let user = format!("quota-middle-ext-integration-{}", std::process::id()); - let lock = Arc::new(AsyncMutex::new(())); - let held = lock - .try_lock() - .expect("integration test must hold prefetched lock first"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xA1]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(8), - 0, - Some(&lock), - &bytes_me2c, - 70_360, - false, - false, - ), - ) - .await; - assert!(blocked.is_err()); - - drop(held); - - let after_release = timeout( - Duration::from_millis(150), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xA2]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(8), - 0, - Some(&lock), - &bytes_me2c, - 70_361, - false, - false, - ), - ) - .await - .expect("writer should progress once the shared lock is released"); - - assert!(after_release.is_ok()); -} - -#[tokio::test] -async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Stats::new(); - let user = format!("quota-middle-ext-fuzz-{}", std::process::id()); - let mut seed = 0xCAFE_BABE_1234u64; - let bytes_me2c = AtomicU64::new(0); - - for case in 0..48u32 { - seed ^= seed << 5; - seed ^= seed >> 12; - seed ^= seed << 13; - let hold = (seed & 0x1) == 0; - - let lock = Arc::new(AsyncMutex::new(())); - let maybe_guard = if hold { - Some(lock.try_lock().unwrap()) - } else { - None - }; - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - - let result = timeout( - Duration::from_millis(30), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(128), - 0, - Some(&lock), - &bytes_me2c, - 70_401 + case as u64, - false, - false, - ), - ) - .await; - - if hold { - assert!(result.is_err()); - } else { - assert!(result.unwrap().is_ok()); - } - - drop(maybe_guard); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() { - let _guard = lookup_test_lock().lock().unwrap(); - let held = Arc::new(AsyncMutex::new(())); - let _held_guard = held.try_lock().unwrap(); - - let mut set = JoinSet::new(); - for i in 0..48u64 { - set.spawn(async move { - let stats = Stats::new(); - let user = format!("quota-middle-ext-stress-free-{i}"); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - let free_lock = Arc::new(AsyncMutex::new(())); - - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xEE]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1), - 0, - Some(&free_lock), - &bytes_me2c, - 70_500 + i, - false, - false, - ) - .await - }); - } - - timeout(Duration::from_secs(2), async { - while let Some(task) = set.join_next().await { - task.unwrap().unwrap(); - } - }) - .await - .unwrap(); -} diff --git a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs b/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs deleted file mode 100644 index d06e103..0000000 --- a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs +++ /dev/null @@ -1,131 +0,0 @@ -use super::*; -use dashmap::DashMap; -use std::sync::Arc; - -#[test] -fn saturation_uses_stable_overflow_lock_without_cache_growth() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); - - let user = format!("middle-quota-overflow-{}", std::process::id()); - let first = quota_user_lock(&user); - let second = quota_user_lock(&user); - - assert!( - Arc::ptr_eq(&first, &second), - "overflow user must get deterministic same lock while cache is saturated" - ); - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow path must not grow bounded lock map" - ); - assert!( - map.get(&user).is_none(), - "overflow user should stay outside bounded lock map under saturation" - ); - - drop(retained); -} - -#[test] -fn overflow_striping_keeps_different_users_distributed() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-dist-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - let a = quota_user_lock("middle-overflow-user-a"); - let b = quota_user_lock("middle-overflow-user-b"); - let c = quota_user_lock("middle-overflow-user-c"); - - let distinct = [ - Arc::as_ptr(&a) as usize, - Arc::as_ptr(&b) as usize, - Arc::as_ptr(&c) as usize, - ] - .iter() - .copied() - .collect::>() - .len(); - - assert!( - distinct >= 2, - "striped overflow lock set should avoid collapsing all users to one lock" - ); - - drop(retained); -} - -#[test] -fn reclaim_path_caches_new_user_after_stale_entries_drop() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-reclaim-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - drop(retained); - - let user = format!("middle-quota-reclaim-user-{}", std::process::id()); - let got = quota_user_lock(&user); - assert!(map.get(&user).is_some()); - assert!( - Arc::strong_count(&got) >= 2, - "after reclaim, lock should be held both by caller and map" - ); -} - -#[test] -fn overflow_path_same_user_is_stable_across_parallel_threads() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "middle-quota-thread-held-{}-{idx}", - std::process::id() - ))); - } - - let user = format!("middle-quota-overflow-thread-user-{}", std::process::id()); - let mut workers = Vec::new(); - for _ in 0..32 { - let user = user.clone(); - workers.push(std::thread::spawn(move || quota_user_lock(&user))); - } - - let first = workers - .remove(0) - .join() - .expect("thread must return lock handle"); - for worker in workers { - let got = worker.join().expect("thread must return lock handle"); - assert!( - Arc::ptr_eq(&first, &got), - "same overflow user should resolve to one striped lock even under contention" - ); - } - - drop(retained); -} diff --git a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs deleted file mode 100644 index 963b3e0..0000000 --- a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs +++ /dev/null @@ -1,1066 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::task::{Context, Poll}; -use tokio::io::AsyncWrite; -use tokio::task::JoinSet; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -struct FailingWriter; - -impl AsyncWrite for FailingWriter { - fn poll_write( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - Poll::Ready(Err(std::io::Error::other("forced writer failure"))) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } -} - -struct FailAfterBudgetWriter { - remaining: usize, - written: usize, -} - -impl FailAfterBudgetWriter { - fn new(remaining: usize) -> Self { - Self { - remaining, - written: 0, - } - } -} - -impl AsyncWrite for FailAfterBudgetWriter { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if self.remaining == 0 { - return Poll::Ready(Err(std::io::Error::other("forced short-write exhaustion"))); - } - - let n = self.remaining.min(buf.len()); - self.remaining -= n; - self.written += n; - Poll::Ready(Ok(n)) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } -} - -#[tokio::test] -async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { - let stats = Stats::new(); - let user = "quota-boundary-user"; - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_from(user, 5); - - let mut writer_one = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_one = Vec::new(); - let first = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer_one, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_one, - &stats, - user, - Some(8), - 0, - &bytes_me2c, - 7101, - false, - false, - ) - .await; - - assert!(first.is_ok(), "frame that reaches boundary must be allowed"); - assert_eq!(stats.get_user_total_octets(user), 8); - - let mut writer_two = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_two = Vec::new(); - let second = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[9]), - }, - &mut writer_two, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_two, - &stats, - user, - Some(8), - 0, - &bytes_me2c, - 7102, - false, - false, - ) - .await; - - assert!( - matches!(second, Err(ProxyError::DataQuotaExceeded { .. })), - "frame after boundary must be rejected" - ); - assert_eq!(stats.get_user_total_octets(user), 8); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_counters() { - let stats = Arc::new(Stats::new()); - let user = "reservation-stress-user"; - let quota_limit = 64u64; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut tasks = JoinSet::new(); - - for idx in 0..256u64 { - let user_owned = user.to_string(); - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_me2c); - - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAB]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - &user_owned, - Some(quota_limit), - 0, - bytes_ref.as_ref(), - 7200 + idx, - false, - false, - ) - .await - }); - } - - let mut ok = 0usize; - let mut denied = 0usize; - while let Some(joined) = tasks.join_next().await { - match joined.expect("reservation stress task must not panic") { - Ok(_) => ok += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => denied += 1, - Err(other) => panic!("unexpected error in stress case: {other:?}"), - } - } - - let total = stats.get_user_total_octets(user); - assert_eq!( - total, quota_limit, - "quota must be exactly exhausted without overshoot" - ); - assert_eq!( - bytes_me2c.load(Ordering::Relaxed), - total, - "ME->C forensic bytes must track committed quota usage" - ); - assert_eq!(ok, quota_limit as usize, "exactly quota_limit tasks must succeed"); - assert_eq!( - denied, - 256usize - (quota_limit as usize), - "remaining tasks must be exactly denied without silently swallowing state" - ); -} - -#[tokio::test] -async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() { - let stats = Stats::new(); - let user = "reservation-fuzz-user"; - let quota_limit = 128u64; - let bytes_me2c = AtomicU64::new(0); - let mut seed = 0xC0FE_EE11_8899_2211u64; - - for conn in 0..512u64 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let len = ((seed & 0x0f) + 1) as usize; - let payload = vec![0x5A; len]; - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - 0, - &bytes_me2c, - 7300 + conn, - false, - false, - ) - .await; - - if let Err(err) = result { - assert!( - matches!(err, ProxyError::DataQuotaExceeded { .. }), - "fuzz run produced unexpected error variant: {err:?}" - ); - } - } - - let total = stats.get_user_total_octets(user); - assert!(total <= quota_limit); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); -} - -#[tokio::test] -async fn positive_soft_overshoot_allows_burst_inside_soft_cap_then_blocks() { - let stats = Stats::new(); - let user = "soft-cap-boundary-user"; - let bytes_me2c = AtomicU64::new(0); - let quota_limit = 10u64; - let overshoot = 3u64; - - stats.add_user_octets_from(user, 10); - - let mut writer_one = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_one = Vec::new(); - let first = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer_one, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_one, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 7401, - false, - false, - ) - .await; - assert!(first.is_ok(), "soft-cap buffer should allow reaching limit+overshoot"); - assert_eq!(stats.get_user_total_octets(user), 13); - - let mut writer_two = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_two = Vec::new(); - let second = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[9]), - }, - &mut writer_two, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_two, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 7402, - false, - false, - ) - .await; - assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(user), 13); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); -} - -#[tokio::test] -async fn negative_soft_overshoot_rejects_when_payload_exceeds_remaining_soft_budget() { - let stats = Stats::new(); - let user = "soft-cap-remaining-user"; - let bytes_me2c = AtomicU64::new(0); - let quota_limit = 10u64; - let overshoot = 4u64; - - stats.add_user_octets_from(user, 12); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 7501, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(user), 12); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn negative_write_failure_rolls_back_reservation_under_soft_cap_mode() { - let stats = Stats::new(); - let user = "soft-cap-rollback-user"; - let bytes_me2c = AtomicU64::new(0); - let mut writer = make_crypto_writer(FailingWriter); - let mut frame_buf = Vec::new(); - - stats.add_user_octets_from(user, 9); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(10), - 8, - &bytes_me2c, - 7601, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::Io(_)))); - assert_eq!(stats.get_user_total_octets(user), 9); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_parallel_soft_cap_stress_never_exceeds_soft_limit() { - let stats = Arc::new(Stats::new()); - let user = "soft-cap-stress-user"; - let quota_limit = 40u64; - let overshoot = 5u64; - let soft_limit = quota_limit + overshoot; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut tasks = JoinSet::new(); - - for idx in 0..256u64 { - let user_owned = user.to_string(); - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_me2c); - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x42]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - &user_owned, - Some(quota_limit), - overshoot, - bytes_ref.as_ref(), - 7700 + idx, - false, - false, - ) - .await - }); - } - - while let Some(joined) = tasks.join_next().await { - match joined.expect("soft-cap stress task must not panic") { - Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error in soft-cap stress case: {other:?}"), - } - } - - let total = stats.get_user_total_octets(user); - assert!(total <= soft_limit, "soft-cap stress must never overshoot soft limit"); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); -} - -#[tokio::test] -async fn light_fuzz_soft_cap_matrix_keeps_counters_and_limits_consistent() { - let stats = Stats::new(); - let user = "soft-cap-fuzz-user"; - let bytes_me2c = AtomicU64::new(0); - let mut seed = 0x9E37_79B9_7F4A_7C15u64; - - for conn in 0..1024u64 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let quota_limit = 32 + (seed & 0x3f); - let overshoot = seed.rotate_left(13) & 0x0f; - let len = ((seed >> 3) & 0x07) + 1; - let payload = vec![0xA5; len as usize]; - let before = stats.get_user_total_octets(user); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 7800 + conn, - false, - false, - ) - .await; - - if let Err(ref err) = result { - assert!( - matches!(err, ProxyError::DataQuotaExceeded { .. }), - "soft-cap fuzz produced unexpected error variant: {err:?}" - ); - } - - let after = stats.get_user_total_octets(user); - let soft_limit = quota_limit.saturating_add(overshoot); - match result { - Ok(_) => { - assert_eq!(after, before.saturating_add(len)); - assert!(after <= soft_limit, "accepted write must stay within active soft cap"); - } - Err(_) => { - assert_eq!(after, before, "rejected write must not mutate quota state"); - } - } - assert_eq!( - bytes_me2c.load(Ordering::Relaxed), - after, - "soft-cap fuzz must keep counters synchronized" - ); - } -} - -#[tokio::test] -async fn positive_no_quota_limit_accumulates_data_octets_exactly() { - let stats = Stats::new(); - let user = "no-quota-user"; - let bytes_me2c = AtomicU64::new(0); - let mut expected = 0u64; - - for (idx, len) in [1usize, 2, 3, 5, 8, 13, 21].iter().copied().enumerate() { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let payload = vec![0x41; len]; - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - None, - 0, - &bytes_me2c, - 7900 + idx as u64, - false, - false, - ) - .await; - - assert!(result.is_ok()); - expected += len as u64; - } - - assert_eq!(stats.get_user_total_octets(user), expected); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), expected); -} - -#[tokio::test] -async fn negative_zero_quota_rejects_non_empty_payload() { - let stats = Stats::new(); - let user = "zero-quota-user"; - let bytes_me2c = AtomicU64::new(0); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAA]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(0), - 0, - &bytes_me2c, - 8001, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn edge_zero_length_payload_with_zero_quota_is_fail_closed() { - let stats = Stats::new(); - let user = "zero-len-zero-quota-user"; - let bytes_me2c = AtomicU64::new(0); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::new(), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(0), - 0, - &bytes_me2c, - 8002, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn positive_ack_response_does_not_touch_quota_counters() { - let stats = Stats::new(); - let user = "ack-accounting-user"; - let bytes_me2c = AtomicU64::new(11); - stats.add_user_octets_to(user, 23); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Ack(0x33445566), - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(24), - 0, - &bytes_me2c, - 8003, - true, - true, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(user), 23); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 11); -} - -#[tokio::test] -async fn edge_close_response_is_accounting_noop() { - let stats = Stats::new(); - let user = "close-accounting-user"; - let bytes_me2c = AtomicU64::new(19); - stats.add_user_octets_to(user, 31); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Close, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(40), - 3, - &bytes_me2c, - 8004, - false, - true, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(user), 31); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 19); -} - -#[tokio::test] -async fn negative_preloaded_above_soft_cap_rejects_even_single_byte() { - let stats = Stats::new(); - let user = "preloaded-over-soft-cap-user"; - let bytes_me2c = AtomicU64::new(0); - let quota_limit = 20u64; - let overshoot = 2u64; - stats.add_user_octets_to(user, quota_limit + overshoot + 1); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 8005, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); - assert_eq!(stats.get_user_total_octets(user), quota_limit + overshoot + 1); -} - -#[tokio::test] -async fn adversarial_fail_writer_path_never_desynchronizes_quota_accounting() { - let stats = Stats::new(); - let user = "partial-write-rollback-user"; - let bytes_me2c = AtomicU64::new(0); - let mut writer = make_crypto_writer(FailAfterBudgetWriter::new(7)); - let mut frame_buf = Vec::new(); - let payload_len = 16 * 1024u64; - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![0x42; 16 * 1024]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(payload_len), - 0, - &bytes_me2c, - 8006, - false, - false, - ) - .await; - - let total_after = stats.get_user_total_octets(user); - let forensic_after = bytes_me2c.load(Ordering::Relaxed); - assert_eq!(forensic_after, total_after); - assert!( - total_after == 0 || total_after == payload_len, - "writer failure path must either roll back fully or commit exactly one payload" - ); - - // Regardless of whether I/O failure surfaced immediately or was deferred, - // accounting must remain fail-closed and prevent silent overshoot. - let mut writer_two = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_two = Vec::new(); - let second = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x99]), - }, - &mut writer_two, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_two, - &stats, - user, - Some(payload_len), - 0, - &bytes_me2c, - 8007, - false, - false, - ) - .await; - - if total_after == payload_len { - assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); - } else { - assert!(second.is_ok()); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_parallel_oversized_frames_fail_closed_without_counter_leak() { - let stats = Arc::new(Stats::new()); - let user = "parallel-fail-rollback-user"; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut tasks = JoinSet::new(); - - for idx in 0..256u64 { - let user_owned = user.to_string(); - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_me2c); - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![0xEE; 12 * 1024]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - &user_owned, - Some(512), - 0, - bytes_ref.as_ref(), - 8100 + idx, - false, - false, - ) - .await - }); - } - - while let Some(joined) = tasks.join_next().await { - let result = joined.expect("parallel fail writer task must not panic"); - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - } - - assert_eq!(stats.get_user_total_octets(user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn integration_mixed_data_ack_close_sequence_preserves_data_only_accounting() { - let stats = Stats::new(); - let user = "mixed-sequence-user"; - let bytes_me2c = AtomicU64::new(0); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - - let data_one = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(32), - 0, - &bytes_me2c, - 8201, - false, - false, - ) - .await; - assert!(data_one.is_ok()); - - let ack = process_me_writer_response( - MeResponse::Ack(0x0102_0304), - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(32), - 0, - &bytes_me2c, - 8202, - true, - true, - ) - .await; - assert!(ack.is_ok()); - - let data_two = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[4, 5]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(32), - 0, - &bytes_me2c, - 8203, - false, - true, - ) - .await; - assert!(data_two.is_ok()); - - let close = process_me_writer_response( - MeResponse::Close, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(32), - 0, - &bytes_me2c, - 8204, - false, - true, - ) - .await; - assert!(close.is_ok()); - - assert_eq!(stats.get_user_total_octets(user), 5); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_parallel_multi_user_quota_isolation_no_cross_user_leakage() { - let stats = Arc::new(Stats::new()); - let user_a = "quota-isolation-a"; - let user_b = "quota-isolation-b"; - let limit_a = 50u64; - let limit_b = 80u64; - let bytes_a = Arc::new(AtomicU64::new(0)); - let bytes_b = Arc::new(AtomicU64::new(0)); - - let mut tasks = JoinSet::new(); - for idx in 0..200u64 { - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_a); - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xA1]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - user_a, - Some(limit_a), - 0, - bytes_ref.as_ref(), - 8300 + idx, - false, - false, - ) - .await - }); - } - - for idx in 0..220u64 { - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_b); - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xB2]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - user_b, - Some(limit_b), - 0, - bytes_ref.as_ref(), - 8500 + idx, - false, - false, - ) - .await - }); - } - - while let Some(joined) = tasks.join_next().await { - let result = joined.expect("quota isolation task must not panic"); - assert!(result.is_ok() || matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - } - - assert_eq!(stats.get_user_total_octets(user_a), limit_a); - assert_eq!(stats.get_user_total_octets(user_b), limit_b); - assert_eq!(bytes_a.load(Ordering::Relaxed), limit_a); - assert_eq!(bytes_b.load(Ordering::Relaxed), limit_b); -} - -#[tokio::test] -async fn light_fuzz_mixed_me_responses_preserve_quota_and_counter_invariants() { - let stats = Stats::new(); - let user = "mixed-fuzz-user"; - let bytes_me2c = AtomicU64::new(0); - let quota_limit = 96u64; - let mut seed = 0xDEAD_BEEF_2026_0323u64; - - for idx in 0..2048u64 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let choice = (seed & 0x03) as u8; - let response = if choice == 0 { - MeResponse::Ack((seed >> 8) as u32) - } else if choice == 1 { - MeResponse::Close - } else { - let len = ((seed >> 16) & 0x07) as usize; - let mut payload = vec![0u8; len]; - payload.fill((seed & 0xff) as u8); - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - } - }; - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - response, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - 0, - &bytes_me2c, - 8800 + idx, - (idx & 1) == 0, - (idx & 2) == 0, - ) - .await; - - if let Err(err) = result { - assert!( - matches!(err, ProxyError::DataQuotaExceeded { .. }), - "mixed fuzz produced unexpected error variant: {err:?}" - ); - } - - let total = stats.get_user_total_octets(user); - assert!( - total <= quota_limit, - "mixed fuzz must keep usage at or below quota limit" - ); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); - } -} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs deleted file mode 100644 index e4d0c6e..0000000 --- a/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs +++ /dev/null @@ -1,399 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; -use tokio::sync::Mutex as AsyncMutex; -use tokio::task::JoinSet; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -fn lookup_counter_test_lock() -> &'static Mutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) -} - -#[tokio::test] -async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-positive-{}", std::process::id()); - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - for idx in 0..12u64 { - let payload = vec![0x5A; ((idx % 4) + 1) as usize]; - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(512), - 0, - Some(&lock), - &bytes_me2c, - 31_000 + idx, - false, - false, - ) - .await; - - assert!(result.is_ok()); - } - - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 0, - "prefetched lock path must avoid hot-path registry lookups" - ); - assert_eq!( - stats.get_user_total_octets(&user), - bytes_me2c.load(Ordering::Relaxed), - "forensics and quota accounting must remain synchronized" - ); -} - -#[tokio::test] -async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-negative-{}", std::process::id()); - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold lock before calling ME->C writer"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(64), - 0, - Some(&lock), - &bytes_me2c, - 31_100, - false, - false, - ), - ) - .await; - - assert!(blocked.is_err()); - assert_eq!(stats.get_user_total_octets(&user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); - - drop(held_guard); -} - -#[tokio::test] -async fn edge_zero_quota_and_zero_payload_is_fail_closed() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-edge-{}", std::process::id()); - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::new(), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(0), - 0, - Some(&lock), - &bytes_me2c, - 31_200, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(&user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Arc::new(Stats::new()); - let user = format!("quota-extreme-blackhat-{}", std::process::id()); - let quota = 80u64; - let overshoot = 7u64; - let soft_limit = quota + overshoot; - let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - - let mut set = JoinSet::new(); - for idx in 0..256u64 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let lock = Arc::clone(&lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - - set.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let len = ((idx % 5) + 1) as usize; - let payload = vec![0xAA; len]; - - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(quota), - overshoot, - Some(&lock), - bytes_me2c.as_ref(), - 31_300 + idx, - false, - false, - ) - .await - }); - } - - while let Some(done) = set.join_next().await { - match done.expect("task must not panic") { - Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"), - } - } - - let total = stats.get_user_total_octets(&user); - assert!( - total <= soft_limit, - "parallel adversarial race must stay under soft cap" - ); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); -} - -#[tokio::test] -async fn integration_without_prefetched_lock_uses_registry_lookup_path() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-integration-{}", std::process::id()); - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - for idx in 0..3u64 { - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x41]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(16), - 0, - None, - &bytes_me2c, - 31_400 + idx, - false, - false, - ) - .await; - - assert!(result.is_ok()); - } - - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 3, - "control path should perform one lock-registry lookup per call" - ); -} - -#[tokio::test] -async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-fuzz-{}", std::process::id()); - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let bytes_me2c = AtomicU64::new(0); - let mut seed = 0xA11C_55EE_2026_0323u64; - - for idx in 0..512u64 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let quota = 24 + (seed & 0x3f); - let overshoot = (seed >> 13) & 0x0f; - let len = ((seed >> 19) & 0x07) + 1; - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let before = stats.get_user_total_octets(&user); - - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![0x11; len as usize]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(quota), - overshoot, - Some(&lock), - &bytes_me2c, - 31_500 + idx, - false, - false, - ) - .await; - - let after = stats.get_user_total_octets(&user); - if result.is_ok() { - assert!(after >= before); - } else { - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(after, before); - } - assert_eq!(bytes_me2c.load(Ordering::Relaxed), after); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Arc::new(Stats::new()); - let user = format!("quota-extreme-stress-{}", std::process::id()); - let quota = 96u64; - let lock: Arc> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut set = JoinSet::new(); - for idx in 0..384u64 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let lock = Arc::clone(&lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - - set.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xFF]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(quota), - 0, - Some(&lock), - bytes_me2c.as_ref(), - 31_600 + idx, - false, - false, - ) - .await - }); - } - - let mut success = 0usize; - while let Some(done) = set.join_next().await { - match done.expect("task must not panic") { - Ok(_) => success += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"), - } - } - - assert_eq!(success, quota as usize); - assert_eq!(stats.get_user_total_octets(&user), quota); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota); - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 0, - "stress prefetched path must not use lock registry lookups" - ); -} diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs deleted file mode 100644 index 1d3b736..0000000 --- a/src/proxy/tests/middle_relay_security_tests.rs +++ /dev/null @@ -1,2517 +0,0 @@ -use super::*; -use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; -use crate::crypto::AesCtr; -use crate::crypto::SecureRandom; -use crate::network::probe::NetworkDecision; -use crate::proxy::handshake::HandshakeSuccess; -use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; -use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; -use crate::transport::middle_proxy::MePool; -use bytes::Bytes; -use rand::rngs::StdRng; -use rand::{RngExt, SeedableRng}; -use std::collections::{HashMap, HashSet}; -use std::net::SocketAddr; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; -use std::sync::Mutex; -use std::thread; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::duplex; -use tokio::sync::Barrier; -use tokio::time::{Duration as TokioDuration, timeout}; - -fn make_pooled_payload(data: &[u8]) -> PooledBuffer { - let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer { - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -#[test] -fn should_yield_sender_only_on_budget_with_backlog() { - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET - 1, - true - )); - assert!(!should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET, - false - )); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); -} - -#[tokio::test] -async fn enqueue_c2me_command_uses_try_send_fast_path() { - let (tx, mut rx) = mpsc::channel::(2); - enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: make_pooled_payload(&[1, 2, 3]), - flags: 0, - }, - ) - .await - .unwrap(); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1, 2, 3]); - assert_eq!(flags, 0); - } - C2MeCommand::Close => panic!("unexpected close command"), - } -} - -#[tokio::test] -async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[9]), - flags: 9, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload(&[7, 7]), - flags: 7, - }, - ) - .await - .unwrap(); - }); - - let _ = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap(); - producer.await.unwrap(); - - let recv = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[7, 7]); - assert_eq!(flags, 7); - } - C2MeCommand::Close => panic!("unexpected close command"), - } -} - -#[tokio::test] -async fn enqueue_c2me_command_closed_channel_recycles_payload() { - let pool = Arc::new(BufferPool::with_config(64, 4)); - let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]); - let (tx, rx) = mpsc::channel::(1); - drop(rx); - - let result = enqueue_c2me_command(&tx, C2MeCommand::Data { payload, flags: 0 }).await; - - assert!(result.is_err(), "closed queue must fail enqueue"); - drop(result); - assert!( - pool.stats().pooled >= 1, - "payload must return to pool when enqueue fails on closed channel" - ); -} - -#[tokio::test] -async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() { - let pool = Arc::new(BufferPool::with_config(64, 4)); - let (tx, rx) = mpsc::channel::(1); - - tx.send(C2MeCommand::Data { - payload: make_pooled_payload_from(&pool, &[9]), - flags: 1, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let pool2 = pool.clone(); - let blocked_send = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload_from(&pool2, &[7, 7, 7]), - flags: 2, - }, - ) - .await - }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - drop(rx); - - let result = timeout(TokioDuration::from_secs(1), blocked_send) - .await - .expect("blocked send task must finish") - .expect("blocked send task must not panic"); - - assert!( - result.is_err(), - "closing receiver while sender is blocked must fail enqueue" - ); - drop(result); - assert!( - pool.stats().pooled >= 2, - "both queued and blocked payloads must return to pool after channel close" - ); -} - -#[tokio::test] -async fn enqueue_c2me_command_full_queue_times_out_without_receiver_progress() { - let (tx, _rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[1]), - flags: 0, - }) - .await - .unwrap(); - - let started = Instant::now(); - let result = enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: make_pooled_payload(&[2, 2]), - flags: 1, - }, - ) - .await; - - assert!( - result.is_err(), - "enqueue must fail when queue stays full beyond bounded timeout" - ); - assert!( - started.elapsed() < TokioDuration::from_millis(400), - "full-queue timeout must resolve promptly" - ); -} - -#[test] -fn desync_dedup_cache_is_bounded() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!( - should_emit_full_desync(key, false, now), - "unique keys up to cap must be tracked" - ); - } - - assert!( - should_emit_full_desync(u64::MAX, false, now), - "new key above cap must emit once after bounded eviction for forensic visibility" - ); - - assert!( - !should_emit_full_desync(u64::MAX, false, now), - "already tracked key inside dedup window must stay suppressed" - ); -} - -#[test] -fn quota_user_lock_cache_reuses_entry_for_same_user() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-user-a"); - let b = quota_user_lock("quota-user-a"); - assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock"); -} - -#[test] -fn quota_user_lock_cache_is_bounded_under_unique_churn() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - for idx in 0..(QUOTA_USER_LOCKS_MAX + 128) { - let user = format!("quota-user-{idx}"); - let lock = quota_user_lock(&user); - drop(lock); - } - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock cache must stay within configured bound" - ); -} - -#[test] -fn quota_user_lock_cache_saturation_returns_stable_overflow_lock_without_growth() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - for attempt in 0..8u32 { - map.clear(); - - let prefix = format!("quota-held-user-{}-{attempt}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("{prefix}-{idx}"); - retained.push(quota_user_lock(&user)); - } - - if map.len() != QUOTA_USER_LOCKS_MAX { - drop(retained); - continue; - } - - let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id()); - let overflow_a = quota_user_lock(&overflow_user); - let overflow_b = quota_user_lock(&overflow_user); - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow acquisition must not grow cache past hard limit" - ); - assert!( - map.get(&overflow_user).is_none(), - "overflow path should not cache new user lock when map is saturated and all entries are retained" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user lock should use deterministic striping under saturation" - ); - - drop(retained); - return; - } - - panic!("unable to observe stable saturated lock-cache precondition after bounded retries"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_quota_race_under_lock_cache_saturation_still_allows_only_one_winner() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-saturated-user-{idx}"); - retained.push(quota_user_lock(&user)); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: cache must be saturated for overflow-user race test" - ); - - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "gap-t04-saturated-lock-race-user"; - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x55, 9101, barrier.clone()); - let two = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x66, 9102, barrier); - let (r1, r2) = tokio::join!(one, two); - - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "both racers must resolve cleanly without unexpected errors" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "at least one racer must be quota-rejected even when lock cache is saturated" - ); - assert_eq!( - stats.get_user_total_octets(user), - 1, - "saturated lock cache must not permit double-success quota overshoot" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_success() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-saturated-stress-holder-{idx}"); - retained.push(quota_user_lock(&user)); - } - - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - for round in 0..128u64 { - let user = format!("gap-t04-saturated-race-round-{round}"); - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x71, - 12_000 + round, - barrier.clone(), - ); - let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x72, 13_000 + round, barrier); - - let (r1, r2) = tokio::join!(one, two); - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: racers must resolve cleanly" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "round {round}: saturated cache must still enforce exactly one forwarded byte" - ); - } - - drop(retained); -} - -#[test] -fn adversarial_forensics_trace_id_should_not_alias_conn_id() { - let now = Instant::now(); - let trace_id = 0x1122_3344_5566_7788; - let conn_id = 0x8877_6655_4433_2211; - let state = RelayForensicsState { - trace_id, - conn_id, - user: "trace-user".to_string(), - peer: "198.51.100.17:443".parse().unwrap(), - peer_hash: 0x8877_6655_4433_2211, - started_at: now, - bytes_c2me: 0, - bytes_me2c: Arc::new(AtomicU64::new(0)), - desync_all_full: false, - }; - - assert_ne!( - state.trace_id, state.conn_id, - "security expectation: trace correlation should be independent of connection identity" - ); - assert_eq!(state.trace_id, trace_id); - assert_eq!(state.conn_id, conn_id); -} - -#[tokio::test] -async fn abridged_ack_uses_big_endian_confirm_bytes_after_decryption() { - let (mut writer_side, reader_side) = duplex(8); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(reader_side, AesCtr::new(&key, iv), 8 * 1024); - - write_client_ack(&mut writer, ProtoTag::Abridged, 0x11_22_33_44) - .await - .expect("ack write must succeed"); - - let mut observed = [0u8; 4]; - writer_side - .read_exact(&mut observed) - .await - .expect("ack bytes must be readable"); - let mut decryptor = AesCtr::new(&key, iv); - let decrypted = decryptor.decrypt(&observed); - - assert_eq!( - decrypted, - 0x11_22_33_44u32.to_be_bytes(), - "abridged ACK should encode confirm bytes in big-endian order" - ); -} - -#[test] -fn desync_dedup_full_cache_churn_stays_suppressed() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!(should_emit_full_desync(key, false, now)); - } - - for offset in 0..2048u64 { - let emitted = should_emit_full_desync(u64::MAX - offset, false, now); - if offset == 0 { - assert!( - emitted, - "first full-cache newcomer should emit for forensic visibility" - ); - } else { - assert!( - !emitted, - "full-cache newcomer churn inside emit interval must stay suppressed" - ); - } - } -} - -#[test] -fn dedup_hash_is_stable_for_same_input_within_process() { - let sample = ( - "scope_user", - hash_ip("198.51.100.7".parse().unwrap()), - ProtoTag::Secure, - ); - let first = hash_value(&sample); - let second = hash_value(&sample); - assert_eq!( - first, second, - "dedup hash must be stable within a process for cache lookups" - ); -} - -#[test] -fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() { - let mut seen = HashSet::new(); - - for octet in 1u16..=2048 { - let third = ((octet / 256) & 0xff) as u8; - let fourth = (octet & 0xff) as u8; - let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth)); - let key = hash_value(&( - "scope_user", - hash_ip(ip), - ProtoTag::Secure, - DESYNC_ERROR_CLASS, - )); - seen.insert(key); - } - - assert_eq!( - seen.len(), - 2048, - "adversarial peer-IP burst should not collapse dedup keys via trivial collisions" - ); -} - -#[test] -fn light_fuzz_dedup_hash_collision_rate_stays_negligible() { - let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4); - let mut seen = HashSet::new(); - let samples = 8192usize; - - for _ in 0..samples { - let user_seed: u64 = rng.random(); - let peer_seed: u64 = rng.random(); - let proto = if (peer_seed & 1) == 0 { - ProtoTag::Secure - } else { - ProtoTag::Intermediate - }; - let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS)); - seen.insert(key); - } - - let collisions = samples - seen.len(); - assert!( - collisions <= 1, - "light fuzz collision count should remain negligible for 64-bit dedup keys" - ); -} - -#[test] -fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - let total = DESYNC_DEDUP_MAX_ENTRIES + 8192; - - let mut emitted_count = 0usize; - for key in 0..total as u64 { - let emitted = should_emit_full_desync(key, false, now); - if emitted { - emitted_count += 1; - } - } - - assert_eq!( - emitted_count, - DESYNC_DEDUP_MAX_ENTRIES + 1, - "after capacity is reached, same-tick newcomer churn must be rate-limited" - ); - - let len = DESYNC_DEDUP - .get() - .expect("dedup cache must be initialized by stress run") - .len(); - assert!( - len <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must stay bounded under stress churn" - ); -} - -#[test] -fn full_cache_newcomer_emission_is_rate_limited_but_periodic() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // Same-tick newcomer storm: only the first should emit full forensic record. - let mut burst_emits = 0usize; - for i in 0..1024u64 { - if should_emit_full_desync(10_000_000 + i, false, base_now) { - burst_emits += 1; - } - } - assert_eq!( - burst_emits, 1, - "full-cache newcomer burst must be bounded to a single full emit per interval" - ); - - // After each interval elapses, one newcomer may emit again. - for step in 1..=6u64 { - let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32; - assert!( - should_emit_full_desync(20_000_000 + step, false, t), - "full-cache newcomer should re-emit once interval has elapsed" - ); - assert!( - !should_emit_full_desync(30_000_000 + step, false, t), - "additional newcomers in the same interval tick must remain suppressed" - ); - } -} - -#[test] -fn full_cache_mode_override_emits_every_event() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for i in 0..10_000u64 { - assert!( - should_emit_full_desync(100_000_000 + i, true, now), - "desync_all_full override must bypass dedup and rate-limit suppression" - ); - } -} - -#[test] -fn report_desync_stats_follow_rate_limited_full_cache_policy() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let stats = Stats::new(); - let mut state = make_forensics_state(); - state.started_at = base_now; - - for i in 0..128u64 { - state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i; - let _ = report_desync_frame_too_large( - &state, - ProtoTag::Secure, - 3, - 1024, - 4096, - Some([0x16, 0x03, 0x03, 0x00]), - &stats, - ); - } - - assert_eq!( - stats.get_desync_total(), - 128, - "every detected desync must increment total counter" - ); - assert_eq!( - stats.get_desync_full_logged(), - 1, - "same-interval full-cache newcomer storm must allow only one full forensic emit" - ); - assert_eq!( - stats.get_desync_suppressed(), - 127, - "remaining same-interval full-cache newcomer events must be suppressed" - ); - - // After one full interval in real wall clock, a newcomer should emit again. - thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20)); - state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64; - let _ = report_desync_frame_too_large( - &state, - ProtoTag::Secure, - 4, - 1024, - 4097, - Some([0x16, 0x03, 0x03, 0x01]), - &stats, - ); - - assert_eq!( - stats.get_desync_full_logged(), - 2, - "full forensic emission must recover after rate-limit interval" - ); -} - -#[test] -fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let emits = Arc::new(AtomicUsize::new(0)); - let mut workers = Vec::new(); - for worker_id in 0..32u64 { - let emits = Arc::clone(&emits); - workers.push(thread::spawn(move || { - for i in 0..512u64 { - let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i; - if should_emit_full_desync(key, false, base_now) { - emits.fetch_add(1, Ordering::Relaxed); - } - } - })); - } - - for worker in workers { - worker.join().expect("worker thread must not panic"); - } - - assert_eq!( - emits.load(Ordering::Relaxed), - 1, - "concurrent same-interval full-cache storm must allow only one full forensic emit" - ); -} - -#[test] -fn light_fuzz_full_cache_rate_limit_oracle_matches_model() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD); - let mut model_last_emit: Option = None; - - for i in 0..4096u64 { - let jitter_ms: u64 = rng.random_range(0..=3000); - let t = base_now + TokioDuration::from_millis(jitter_ms); - let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::(); - let actual = should_emit_full_desync(key, false, t); - - let expected = match model_last_emit { - None => { - model_last_emit = Some(t); - true - } - Some(last) => { - match t.checked_duration_since(last) { - Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => { - model_last_emit = Some(t); - true - } - Some(_) => false, - None => { - // Match production fail-open behavior for non-monotonic synthetic input. - model_last_emit = Some(t); - true - } - } - } - }; - - assert_eq!( - actual, expected, - "full-cache rate-limit gate diverged from reference model under light fuzz" - ); - } -} - -#[test] -fn full_cache_gate_lock_poison_is_fail_closed_without_panic() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // Poison the full-cache gate lock intentionally. - let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); - let _ = std::panic::catch_unwind(|| { - let _lock = gate - .lock() - .expect("gate lock must be lockable before poison"); - panic!("intentional gate poison for fail-closed regression"); - }); - - let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now); - assert!( - !emitted, - "poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open" - ); - assert!( - dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must remain bounded even when gate lock is poisoned" - ); -} - -#[test] -fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // First event seeds the gate. - assert!(should_emit_full_desync( - 0xABCD_0000_0000_0001, - false, - base_now + TokioDuration::from_millis(900) - )); - - // Synthetic earlier timestamp must not panic; it should fail-open and reset gate. - assert!(should_emit_full_desync( - 0xABCD_0000_0000_0002, - false, - base_now + TokioDuration::from_millis(100) - )); - - // Same instant again remains suppressed after reset. - assert!(!should_emit_full_desync( - 0xABCD_0000_0000_0003, - false, - base_now + TokioDuration::from_millis(100) - )); -} - -#[test] -fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - - // Fill with fresh entries so stale-pruning does not apply. - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let before_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); - - let newcomer_key = u64::MAX; - let emitted = should_emit_full_desync(newcomer_key, false, base_now); - assert!( - emitted, - "new entry under full fresh cache must emit after bounded eviction" - ); - assert!( - dedup.get(&newcomer_key).is_some(), - "new key must be inserted after bounded eviction" - ); - - let after_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); - let removed_count = before_keys.difference(&after_keys).count(); - let added_count = after_keys.difference(&before_keys).count(); - - assert_eq!( - removed_count, 1, - "full-cache insertion must evict exactly one prior key" - ); - assert_eq!( - added_count, 1, - "full-cache insertion must add exactly one newcomer key" - ); - assert!( - dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must remain hard-bounded after full-cache churn" - ); -} - -#[test] -fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let key = 0xC0DE_CAFE_u64; - let start = Instant::now(); - - assert!( - should_emit_full_desync(key, false, start), - "first event for key must emit full forensic record" - ); - - // Deterministic pseudo-random time deltas around dedup window edge. - let mut s: u64 = 0x1234_5678_9ABC_DEF0; - for _ in 0..2048 { - s ^= s << 7; - s ^= s >> 9; - s ^= s << 8; - - let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1); - let now = start + TokioDuration::from_millis(delta_ms); - let emitted = should_emit_full_desync(key, false, now); - - if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 { - assert!( - !emitted, - "events inside dedup window must remain suppressed" - ); - } else { - // Once window elapsed for this key, at least one sample should re-emit and refresh. - if emitted { - return; - } - } - } - - panic!("expected at least one post-window sample to re-emit forensic record"); -} - -fn make_forensics_state() -> RelayForensicsState { - RelayForensicsState { - trace_id: 1, - conn_id: 2, - user: "test-user".to_string(), - peer: "127.0.0.1:50000".parse::().unwrap(), - peer_hash: 3, - started_at: Instant::now(), - bytes_c2me: 0, - bytes_me2c: Arc::new(AtomicU64::new(0)), - desync_all_full: false, - } -} - -fn make_crypto_reader(reader: R) -> CryptoReader -where - R: tokio::io::AsyncRead + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoReader::new(reader, AesCtr::new(&key, iv)) -} - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -async fn make_me_pool_for_abort_test(stats: Arc) -> Arc { - let general = GeneralConfig::default(); - - MePool::new( - None, - vec![1u8; 32], - None, - false, - None, - Vec::new(), - 1, - None, - 12, - 1200, - HashMap::new(), - HashMap::new(), - None, - NetworkDecision::default(), - None, - Arc::new(SecureRandom::new()), - stats, - general.me_keepalive_enabled, - general.me_keepalive_interval_secs, - general.me_keepalive_jitter_secs, - general.me_keepalive_payload_random, - general.rpc_proxy_req_every, - general.me_warmup_stagger_enabled, - general.me_warmup_step_delay_ms, - general.me_warmup_step_jitter_ms, - general.me_reconnect_max_concurrent_per_dc, - general.me_reconnect_backoff_base_ms, - general.me_reconnect_backoff_cap_ms, - general.me_reconnect_fast_retry_count, - general.me_single_endpoint_shadow_writers, - general.me_single_endpoint_outage_mode_enabled, - general.me_single_endpoint_outage_disable_quarantine, - general.me_single_endpoint_outage_backoff_min_ms, - general.me_single_endpoint_outage_backoff_max_ms, - general.me_single_endpoint_shadow_rotate_every_secs, - general.me_floor_mode, - general.me_adaptive_floor_idle_secs, - general.me_adaptive_floor_min_writers_single_endpoint, - general.me_adaptive_floor_min_writers_multi_endpoint, - general.me_adaptive_floor_recover_grace_secs, - general.me_adaptive_floor_writers_per_core_total, - general.me_adaptive_floor_cpu_cores_override, - general.me_adaptive_floor_max_extra_writers_single_per_core, - general.me_adaptive_floor_max_extra_writers_multi_per_core, - general.me_adaptive_floor_max_active_writers_per_core, - general.me_adaptive_floor_max_warm_writers_per_core, - general.me_adaptive_floor_max_active_writers_global, - general.me_adaptive_floor_max_warm_writers_global, - general.hardswap, - general.me_pool_drain_ttl_secs, - general.me_instadrain, - general.me_pool_drain_threshold, - general.me_pool_drain_soft_evict_enabled, - general.me_pool_drain_soft_evict_grace_secs, - general.me_pool_drain_soft_evict_per_writer, - general.me_pool_drain_soft_evict_budget_per_core, - general.me_pool_drain_soft_evict_cooldown_ms, - general.effective_me_pool_force_close_secs(), - general.me_pool_min_fresh_ratio, - general.me_hardswap_warmup_delay_min_ms, - general.me_hardswap_warmup_delay_max_ms, - general.me_hardswap_warmup_extra_passes, - general.me_hardswap_warmup_pass_backoff_base_ms, - general.me_bind_stale_mode, - general.me_bind_stale_ttl_secs, - general.me_secret_atomic_snapshot, - general.me_deterministic_writer_sort, - MeWriterPickMode::default(), - general.me_writer_pick_sample_size, - MeSocksKdfPolicy::default(), - general.me_writer_cmd_channel_capacity, - general.me_route_channel_capacity, - general.me_route_backpressure_base_timeout_ms, - general.me_route_backpressure_high_timeout_ms, - general.me_route_backpressure_high_watermark_pct, - general.me_reader_route_data_wait_ms, - general.me_health_interval_ms_unhealthy, - general.me_health_interval_ms_healthy, - general.me_warn_rate_limit_ms, - MeRouteNoWriterMode::default(), - general.me_route_no_writer_wait_ms, - general.me_route_inline_recovery_attempts, - general.me_route_inline_recovery_wait_ms, - ) -} - -fn encrypt_for_reader(plaintext: &[u8]) -> Vec { - let key = [0u8; 32]; - let iv = 0u128; - let mut cipher = AesCtr::new(&key, iv); - cipher.encrypt(plaintext) -} - -#[tokio::test] -async fn read_client_payload_times_out_on_header_stall() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - let (reader, _writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_millis(25), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), - "stalled header read must time out" - ); -} - -#[tokio::test] -async fn read_client_payload_times_out_on_payload_stall() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - let (reader, mut writer) = duplex(1024); - let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]); - writer.write_all(&encrypted_len).await.unwrap(); - - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_millis(25), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), - "stalled payload body read must time out" - ); -} - -#[tokio::test] -async fn read_client_payload_large_intermediate_frame_is_exact() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(262_144); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537); - let mut plaintext = Vec::with_capacity(4 + payload_len); - plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - payload_len + 16, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!( - frame.len(), - payload_len, - "payload size must match wire length" - ); - for (idx, byte) in frame.iter().enumerate() { - assert_eq!(*byte, (idx as u8).wrapping_mul(31)); - } - assert_eq!(frame_counter, 1, "exactly one frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_secure_strips_tail_padding_bytes() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd]; - let tail = [0xeeu8, 0xff, 0x99]; - let wire_len = payload.len() + tail.len(); - - let mut plaintext = Vec::with_capacity(4 + wire_len); - plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - plaintext.extend_from_slice(&tail); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Secure, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("secure payload read must succeed") - .expect("secure frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.as_ref(), &payload); - assert_eq!(frame_counter, 1, "one secure frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_secure_rejects_wire_len_below_4() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let mut plaintext = Vec::with_capacity(7); - plaintext.extend_from_slice(&3u32.to_le_bytes()); - plaintext.extend_from_slice(&[1u8, 2, 3]); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Secure, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")), - "secure wire length below 4 must be fail-closed by the frame-too-small guard" - ); -} - -#[tokio::test] -async fn read_client_payload_intermediate_skips_zero_len_frame() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [7u8, 6, 5, 4, 3, 2, 1, 0]; - let mut plaintext = Vec::with_capacity(4 + 4 + payload.len()); - plaintext.extend_from_slice(&0u32.to_le_bytes()); - plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("intermediate payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.as_ref(), &payload); - assert_eq!(frame_counter, 1, "zero-length frame must be skipped"); -} - -#[tokio::test] -async fn read_client_payload_abridged_extended_len_sets_quickack() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(4096); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload_len = 4 * 130; - let len_words = (payload_len / 4) as u32; - let mut plaintext = Vec::with_capacity(1 + 3 + payload_len); - plaintext.push(0xff | 0x80); - let lw = len_words.to_le_bytes(); - plaintext.extend_from_slice(&lw[..3]); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Abridged, - payload_len + 16, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("abridged payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!( - quickack, - "quickack bit must be propagated from abridged header" - ); - assert_eq!(frame.len(), payload_len); - assert_eq!(frame_counter, 1, "one abridged frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_returns_buffer_to_pool_after_emit() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let pool = Arc::new(BufferPool::with_config(64, 8)); - pool.preallocate(1); - assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer"); - - let (reader, mut writer) = duplex(4096); - let mut crypto_reader = make_crypto_reader(reader); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - // Force growth beyond default pool buffer size to catch ownership-take regressions. - let payload_len = 257usize; - let mut plaintext = Vec::with_capacity(4 + payload_len); - plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let _ = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - payload_len + 8, - TokioDuration::from_secs(1), - &pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - assert_eq!(frame_counter, 1); - let pool_stats = pool.stats(); - assert!( - pool_stats.pooled >= 1, - "emitted payload buffer must be returned to pool to avoid pool drain" - ); -} - -#[tokio::test] -async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let pool = Arc::new(BufferPool::with_config(64, 2)); - pool.preallocate(1); - assert_eq!( - pool.stats().pooled, - 1, - "one pooled buffer must be available" - ); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48]; - let mut plaintext = Vec::with_capacity(4 + payload.len()); - plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let (frame, quickack) = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_secs(1), - &pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - assert!(!quickack); - assert_eq!(frame.as_ref(), &payload); - assert_eq!( - pool.stats().pooled, - 0, - "buffer must stay checked out while frame payload is alive" - ); - - drop(frame); - assert!( - pool.stats().pooled >= 1, - "buffer must return to pool only after frame drop" - ); -} - -#[tokio::test] -async fn enqueue_c2me_close_unblocks_after_queue_drain() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[0x41]), - flags: 0, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let close_task = - tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - - let first = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("first queued item must be present"); - assert!(matches!(first, C2MeCommand::Data { .. })); - - close_task - .await - .unwrap() - .expect("close enqueue must succeed after drain"); - - let second = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("close command must follow after queue drain"); - assert!(matches!(second, C2MeCommand::Close)); -} - -#[tokio::test] -async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { - let (tx, rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[0x42]), - flags: 0, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let close_task = - tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - drop(rx); - - let result = timeout(TokioDuration::from_secs(1), close_task) - .await - .expect("close task must finish") - .expect("close task must not panic"); - assert!( - result.is_err(), - "close enqueue must fail cleanly when receiver is dropped under pressure" - ); -} - -#[tokio::test] -async fn process_me_writer_response_ack_obeys_flush_policy() { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - let immediate = process_me_writer_response( - MeResponse::Ack(0x11223344), - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 77, - true, - false, - ) - .await - .expect("ack response must be processed"); - - assert!(matches!( - immediate, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: true, - } - )); - - let delayed = process_me_writer_response( - MeResponse::Ack(0x55667788), - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 77, - false, - false, - ) - .await - .expect("ack response must be processed"); - - assert!(matches!( - delayed, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: false, - } - )); -} - -#[tokio::test] -async fn process_me_writer_response_data_updates_byte_accounting() { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; - let outcome = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload.clone()), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 88, - false, - false, - ) - .await - .expect("data response must be processed"); - - assert!(matches!( - outcome, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes, - flush_immediately: false, - } if bytes == payload.len() - )); - assert_eq!( - bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), - payload.len() as u64, - "ME->C byte accounting must increase by emitted payload size" - ); -} - -#[tokio::test] -async fn process_me_writer_response_data_enforces_live_user_quota() { - let (writer_side, mut reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_from("quota-user", 10); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![1u8, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "quota-user", - Some(12), - 0, - &bytes_me2c, - 89, - false, - false, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"), - "ME->client runtime path must terminate when live user quota is crossed" - ); - - let mut raw = [0u8; 1]; - assert!( - timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) - .await - .is_err(), - "quota exhaustion must not write any ciphertext to the client stream" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "quota-race-user"; - - let (writer_side_a, _reader_side_a) = duplex(1024); - let (writer_side_b, _reader_side_b) = duplex(1024); - let mut writer_a = make_crypto_writer(writer_side_a); - let mut writer_b = make_crypto_writer(writer_side_b); - let mut frame_buf_a = Vec::new(); - let mut frame_buf_b = Vec::new(); - let rng_a = SecureRandom::new(); - let rng_b = SecureRandom::new(); - - let fut_a = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x11]), - }, - &mut writer_a, - ProtoTag::Intermediate, - &rng_a, - &mut frame_buf_a, - &stats, - user, - Some(1), - 0, - &bytes_me2c, - 91, - false, - false, - ); - let fut_b = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x22]), - }, - &mut writer_b, - ProtoTag::Intermediate, - &rng_b, - &mut frame_buf_b, - &stats, - user, - Some(1), - 0, - &bytes_me2c, - 92, - false, - false, - ); - - let (result_a, result_b) = tokio::join!(fut_a, fut_b); - - assert!( - matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") - || matches!(result_a, Ok(_)), - "concurrent quota test must complete without panicking" - ); - assert!( - matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") - || matches!(result_b, Ok(_)), - "concurrent quota test must complete without panicking" - ); - assert!( - stats.get_user_total_octets(user) <= 1, - "same-user concurrent middle-relay responses must not overshoot the configured quota" - ); -} - -#[tokio::test] -async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() - { - let (writer_side, mut reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_to("partial-quota-user", 3); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![1u8, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "partial-quota-user", - Some(4), - 0, - &bytes_me2c, - 90, - false, - false, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"), - "ME->client runtime path must reject oversized payloads before writing" - ); - - let mut raw = [0u8; 1]; - assert!( - timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) - .await - .is_err(), - "oversized payloads must not leak any partial ciphertext to the client stream" - ); -} - -#[tokio::test] -async fn middle_relay_abort_midflight_releases_route_gauge() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let (server_side, client_side) = duplex(64 * 1024); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: "abort-middle-user".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: "127.0.0.1:50001".parse().unwrap(), - is_tls: false, - }; - - let relay_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool, - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_snapshot, - 0xdecafbad, - )); - - let started = tokio::time::timeout(TokioDuration::from_secs(2), async { - loop { - if stats.get_current_connections_me() == 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await; - assert!( - started.is_ok(), - "middle relay must increment route gauge before abort" - ); - - relay_task.abort(); - let joined = relay_task.await; - assert!( - joined.is_err(), - "aborted middle relay task must return join error" - ); - - tokio::time::sleep(TokioDuration::from_millis(20)).await; - assert_eq!( - stats.get_current_connections_me(), - 0, - "route gauge must be released when middle relay task is aborted mid-flight" - ); - - drop(client_side); -} - -#[tokio::test] -async fn middle_relay_cutover_midflight_releases_route_gauge() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let (server_side, client_side) = duplex(64 * 1024); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: "cutover-middle-user".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: "127.0.0.1:50003".parse().unwrap(), - is_tls: false, - }; - - let relay_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool, - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_snapshot, - 0xfeed_beef, - )); - - tokio::time::timeout(TokioDuration::from_secs(2), async { - loop { - if stats.get_current_connections_me() == 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("middle relay must increment route gauge before cutover"); - - assert!( - route_runtime.set_mode(RelayRouteMode::Direct).is_some(), - "cutover must advance route generation" - ); - - let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) - .await - .expect("middle relay must terminate after cutover") - .expect("middle relay task must not panic"); - assert!( - relay_result.is_err(), - "cutover should terminate middle relay session" - ); - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "client-visible cutover error must stay generic and avoid route-internal metadata" - ); - - assert_eq!( - stats.get_current_connections_me(), - 0, - "route gauge must be released when middle relay exits on cutover" - ); - - drop(client_side); -} - -async fn run_quota_race_attempt( - stats: &Stats, - bytes_me2c: &AtomicU64, - user: &str, - payload: u8, - conn_id: u64, - barrier: Arc, -) -> Result { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - barrier.wait().await; - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![payload]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - stats, - user, - Some(1), - 0, - bytes_me2c, - conn_id, - false, - false, - ) - .await -} - -#[tokio::test] -async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(256); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let plaintext = vec![0x7f, 0xff, 0xff, 0xff]; - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Abridged, - 4096, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - result.is_err(), - "oversized abridged length must fail closed" - ); - assert_eq!( - frame_counter, 0, - "oversized frame must not be counted as accepted" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn deterministic_quota_race_exactly_one_succeeds_and_one_is_rejected() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "gap-t04-race-user"; - let barrier = Arc::new(Barrier::new(2)); - - let f1 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x11, 5001, barrier.clone()); - let f2 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x22, 5002, barrier); - - let (r1, r2) = tokio::join!(f1, f2); - - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "first racer must either finish or fail closed on quota" - ); - assert!( - matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "second racer must either finish or fail closed on quota" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(user), - 1, - "same-user race must forward/account exactly one payload byte" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_quota_race_bursts_never_allow_double_success_per_round() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - for round in 0..128u64 { - let user = format!("gap-t04-race-burst-{round}"); - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x33, - 6000 + round, - barrier.clone(), - ); - let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x44, 7000 + round, barrier); - - let (r1, r2) = tokio::join!(one, two); - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: racers must resolve cleanly without unexpected errors" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "round {round}: same-user total octets must remain exactly 1 (single forwarded winner)" - ); - } -} - -#[tokio::test] -async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { - let session_count = 6usize; - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let mut relay_tasks = Vec::with_capacity(session_count); - let mut client_sides = Vec::with_capacity(session_count); - - for idx in 0..session_count { - let (server_side, client_side) = duplex(64 * 1024); - client_sides.push(client_side); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: format!("cutover-storm-middle-user-{idx}"), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: SocketAddr::new( - std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), - 52000 + idx as u16, - ), - is_tls: false, - }; - - relay_tasks.push(tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config.clone(), - buffer_pool.clone(), - "127.0.0.1:443".parse().unwrap(), - rng.clone(), - route_runtime.subscribe(), - route_snapshot, - 0xB000_0000 + idx as u64, - ))); - } - - tokio::time::timeout(TokioDuration::from_secs(4), async { - loop { - if stats.get_current_connections_me() == session_count as u64 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("all middle sessions must become active before cutover storm"); - - let route_runtime_flipper = route_runtime.clone(); - let flipper = tokio::spawn(async move { - for step in 0..64u32 { - let mode = if (step & 1) == 0 { - RelayRouteMode::Direct - } else { - RelayRouteMode::Middle - }; - let _ = route_runtime_flipper.set_mode(mode); - tokio::time::sleep(TokioDuration::from_millis(15)).await; - } - }); - - for relay_task in relay_tasks { - let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task) - .await - .expect("middle relay task must finish under cutover storm") - .expect("middle relay task must not panic"); - - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "storm-cutover termination must remain generic for all middle sessions" - ); - } - - flipper.abort(); - let _ = flipper.await; - - assert_eq!( - stats.get_current_connections_me(), - 0, - "middle route gauge must return to zero after cutover storm" - ); - - drop(client_sides); -} - -#[tokio::test] -async fn secure_padding_distribution_in_relay_writer() { - timeout(TokioDuration::from_secs(10), async { - let (mut client_side, relay_side) = duplex(512 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = Arc::new(SecureRandom::new()); - let mut frame_buf = Vec::new(); - let mut decryptor = AesCtr::new(&key, iv); - - let mut padding_counts = [0usize; 4]; - let iterations = 180usize; - let payload = vec![0xAAu8; 100]; // 4-byte aligned - - for _ in 0..iterations { - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload write must succeed"); - writer - .flush() - .await - .expect("writer flush must complete so encrypted frame becomes readable"); - - let mut len_buf = [0u8; 4]; - client_side - .read_exact(&mut len_buf) - .await - .expect("must read encrypted secure length"); - let decrypted_len_bytes = decryptor.decrypt(&len_buf); - let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes - .try_into() - .expect("decrypted length must be 4 bytes"); - let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize; - - assert!( - wire_len >= payload.len(), - "wire length must include at least payload bytes" - ); - let padding_len = wire_len - payload.len(); - assert!(padding_len >= 1 && padding_len <= 3); - padding_counts[padding_len] += 1; - - // Drain and decrypt frame bytes so CTR state stays aligned across writes. - let mut trash = vec![0u8; wire_len]; - client_side - .read_exact(&mut trash) - .await - .expect("must read encrypted secure frame body"); - let _ = decryptor.decrypt(&trash); - } - - for p in 1..=3 { - let count = padding_counts[p]; - assert!( - count > iterations / 8, - "padding length {p} is under-represented ({count}/{iterations})" - ); - } - }) - .await - .expect("secure padding distribution test exceeded runtime budget"); -} - -#[tokio::test] -async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { - let (client_reader_side, client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let stats = Arc::new(Stats::new()); - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - - // Create an ME pool. - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - // ConnRegistry ids are monotonic; reserve one id so we can predict the - // next session conn_id and close it deterministically without relying on - // writer-bound views such as active_conn_ids(). - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: "test-user".to_string(), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config.clone(), - buffer_pool.clone(), - "127.0.0.1:443".parse().unwrap(), - rng.clone(), - route_runtime.subscribe(), - route_runtime.snapshot(), - 0x1234_5678, - )); - - // Wait until session startup is visible, then unregister the predicted - // conn_id to close the per-session ME response channel. - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("ME session must start before channel close simulation"); - - me_pool.registry().unregister(target_conn_id).await; - - drop(client_writer_side); - - let result = timeout(TokioDuration::from_secs(2), session_task) - .await - .expect("Session task must terminate after ME drop and client EOF") - .expect("Session task must not panic"); - - assert!( - result.is_ok(), - "Session should complete cleanly after ME drop when client closes, got: {:?}", - result - ); -} - -#[tokio::test] -async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() { - let (client_reader_side, _client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let stats = Arc::new(Stats::new()); - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - // Predict the next conn_id so we can force-drop its ME channel deterministically. - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: "test-user-cutover".to_string(), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let runtime_clone = route_runtime.clone(); - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - runtime_clone.subscribe(), - runtime_clone.snapshot(), - 0xC001_CAFE, - )); - - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("ME session must start before race trigger"); - - // Race ME channel drop with route cutover and assert generic client-visible outcome. - me_pool.registry().unregister(target_conn_id).await; - assert!( - route_runtime.set_mode(RelayRouteMode::Direct).is_some(), - "cutover must advance generation" - ); - - let relay_result = timeout(TokioDuration::from_secs(6), session_task) - .await - .expect("session must terminate under ME-drop + cutover race") - .expect("session task must not panic"); - - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "race outcome must remain generic and not leak ME internals, got: {:?}", - relay_result - ); -} - -#[tokio::test] -async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - for round in 0..32u64 { - let (client_reader_side, client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: format!("stress-me-drop-eof-{round}"), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_runtime.snapshot(), - 0xD00D_0000 + round, - )); - - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("session must start before forced drop in burst round"); - - me_pool.registry().unregister(target_conn_id).await; - drop(client_writer_side); - - let result = timeout(TokioDuration::from_secs(2), session_task) - .await - .expect("burst round session must terminate quickly") - .expect("burst round session must not panic"); - - assert!( - result.is_ok(), - "burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}", - result - ); - } -} diff --git a/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs b/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs deleted file mode 100644 index fb0cf93..0000000 --- a/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs +++ /dev/null @@ -1,108 +0,0 @@ -use super::*; -use std::sync::Arc; -use std::sync::{Mutex, OnceLock}; - -fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK - .get_or_init(|| Mutex::new(())) - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -#[test] -fn same_user_returns_same_lock_identity() { - let _guard = cross_mode_lock_test_guard(); - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - locks.clear(); - - let a = cross_mode_quota_user_lock("cross-mode-same-user"); - let b = cross_mode_quota_user_lock("cross-mode-same-user"); - - assert!( - Arc::ptr_eq(&a, &b), - "same user must reuse a stable lock identity" - ); -} - -#[test] -fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() { - let _guard = cross_mode_lock_test_guard(); - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - locks.clear(); - - let prefix = format!("cross-mode-saturated-{}", std::process::id()); - let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX); - for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX { - retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!( - locks.len(), - CROSS_MODE_QUOTA_USER_LOCKS_MAX, - "lock cache must be saturated for overflow check" - ); - - let overflow_user = format!("cross-mode-overflow-{}", std::process::id()); - let overflow_a = cross_mode_quota_user_lock(&overflow_user); - let overflow_b = cross_mode_quota_user_lock(&overflow_user); - - assert_eq!( - locks.len(), - CROSS_MODE_QUOTA_USER_LOCKS_MAX, - "overflow path must not grow bounded lock cache" - ); - assert!( - locks.get(&overflow_user).is_none(), - "overflow user must stay on striped fallback while cache is saturated" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user must receive a stable striped lock across repeated lookups" - ); - - drop(retained); -} - -#[test] -fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() { - let _guard = cross_mode_lock_test_guard(); - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - locks.clear(); - - let prefix = format!("cross-mode-reclaim-{}", std::process::id()); - let protected_user = format!("{prefix}-protected"); - - let protected_lock = cross_mode_quota_user_lock(&protected_user); - let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)); - for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) { - retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!( - locks.len(), - CROSS_MODE_QUOTA_USER_LOCKS_MAX, - "fixture must saturate lock cache before reclaim path is exercised" - ); - - drop(retained); - - let newcomer_user = format!("{prefix}-newcomer"); - let _newcomer = cross_mode_quota_user_lock(&newcomer_user); - - assert!( - locks.get(&protected_user).is_some(), - "active protected user must remain cache-resident after reclaim" - ); - let locked = locks - .get(&protected_user) - .expect("protected user must remain in map after reclaim"); - assert!( - Arc::ptr_eq(locked.value(), &protected_lock), - "reclaim must not swap active user lock identity" - ); - assert!( - locks.get(&newcomer_user).is_some(), - "newcomer should become cacheable after stale entries are reclaimed" - ); -} diff --git a/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs deleted file mode 100644 index 9ea921c..0000000 --- a/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs +++ /dev/null @@ -1,267 +0,0 @@ -use super::relay_bidirectional; -use crate::stats::Stats; -use crate::stream::BufferPool; -use std::sync::Arc; -use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; -use tokio::time::{Duration, timeout}; - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() { - let _guard = quota_test_guard(); - - let user = format!("relay-pipeline-stall-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold shared cross-mode lock"); - - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_user = user.clone(); - let relay_stats = Arc::clone(&stats); - let relay_task = tokio::spawn(async move { - relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 256, - 256, - &relay_user, - relay_stats, - Some(1024), - Arc::new(BufferPool::new()), - ) - .await - }); - - server_peer - .write_all(&[0xA1]) - .await - .expect("server write should enqueue while relay is stalled"); - - let mut one = [0u8; 1]; - let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await; - assert!( - blocked_read.is_err(), - "same-user relay must remain blocked while cross-mode lock is held" - ); - - drop(held_guard); - - timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) - .await - .expect("blocked relay must resume after cross-mode lock release") - .expect("resumed relay must deliver queued byte"); - assert_eq!(one, [0xA1]); - - drop(client_peer); - drop(server_peer); - - let relay_result = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must complete") - .expect("relay task must not panic"); - assert!(relay_result.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() { - let _guard = quota_test_guard(); - - let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id()); - let free_user = format!("relay-pipeline-free-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); - let held_guard = held - .try_lock() - .expect("test must hold blocked user's shared cross-mode lock"); - - let stats_blocked = Arc::new(Stats::new()); - let stats_free = Arc::new(Stats::new()); - - let (mut blocked_client, blocked_relay_client) = duplex(1024); - let (blocked_relay_server, mut blocked_server) = duplex(1024); - let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client); - let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server); - - let (mut free_client, free_relay_client) = duplex(1024); - let (free_relay_server, mut free_server) = duplex(1024); - let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client); - let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server); - - let blocked_task = { - let user = blocked_user.clone(); - let stats = Arc::clone(&stats_blocked); - tokio::spawn(async move { - relay_bidirectional( - blocked_client_reader, - blocked_client_writer, - blocked_server_reader, - blocked_server_writer, - 256, - 256, - &user, - stats, - Some(1024), - Arc::new(BufferPool::new()), - ) - .await - }) - }; - - let free_task = { - let user = free_user.clone(); - let stats = Arc::clone(&stats_free); - tokio::spawn(async move { - relay_bidirectional( - free_client_reader, - free_client_writer, - free_server_reader, - free_server_writer, - 256, - 256, - &user, - stats, - Some(1024), - Arc::new(BufferPool::new()), - ) - .await - }) - }; - - blocked_server - .write_all(&[0xB1]) - .await - .expect("blocked user server write should queue"); - free_server - .write_all(&[0xC1]) - .await - .expect("free user server write should queue"); - - let mut blocked_buf = [0u8; 1]; - let mut free_buf = [0u8; 1]; - - let blocked_stalled = timeout( - Duration::from_millis(40), - blocked_client.read_exact(&mut blocked_buf), - ) - .await; - assert!( - blocked_stalled.is_err(), - "blocked user must remain stalled while its lock is held" - ); - - timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf)) - .await - .expect("free user must make progress while other user is blocked") - .expect("free user read must succeed"); - assert_eq!(free_buf, [0xC1]); - - drop(held_guard); - - timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf)) - .await - .expect("blocked user must resume after release") - .expect("blocked user resumed read must succeed"); - assert_eq!(blocked_buf, [0xB1]); - - drop(blocked_client); - drop(blocked_server); - drop(free_client); - drop(free_server); - - assert!( - timeout(Duration::from_secs(1), blocked_task) - .await - .expect("blocked relay task must complete") - .expect("blocked relay task must not panic") - .is_ok() - ); - assert!( - timeout(Duration::from_secs(1), free_task) - .await - .expect("free relay task must complete") - .expect("free relay task must not panic") - .is_ok() - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() { - let _guard = quota_test_guard(); - - let mut seed = 0x5EED_C0DE_2026_0323u64; - for round in 0..24u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_ms = 2 + (seed % 10); - let user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock during fuzz round"); - - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_user = user.clone(); - let relay_stats = Arc::clone(&stats); - let relay_task = tokio::spawn(async move { - relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 256, - 256, - &relay_user, - relay_stats, - Some(1024), - Arc::new(BufferPool::new()), - ) - .await - }); - - server_peer - .write_all(&[0xD1]) - .await - .expect("server write should queue in fuzz round"); - - let mut one = [0u8; 1]; - let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await; - assert!(stalled.is_err(), "held phase must stall same-user relay"); - - tokio::time::sleep(Duration::from_millis(hold_ms)).await; - drop(held_guard); - - timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) - .await - .expect("released phase must resume same-user relay") - .expect("released phase read must succeed"); - assert_eq!(one, [0xD1]); - - drop(client_peer); - drop(server_peer); - - assert!( - timeout(Duration::from_secs(1), relay_task) - .await - .expect("fuzz relay task must complete") - .expect("fuzz relay task must not panic") - .is_ok() - ); - } -} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs deleted file mode 100644 index c967861..0000000 --- a/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs +++ /dev/null @@ -1,213 +0,0 @@ -use super::relay_bidirectional; -use crate::stats::Stats; -use crate::stream::BufferPool; -use std::sync::{Arc, Mutex}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; -use tokio::sync::{Barrier, watch}; -use tokio::time::{Duration, Instant, timeout}; - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn percentile_index(len: usize, percentile: usize) -> usize { - ((len * percentile) / 100).min(len.saturating_sub(1)) -} - -#[tokio::test] -async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() { - let _guard = quota_test_guard(); - - let rounds = 64usize; - let user = format!("relay-pipeline-latency-single-{}", std::process::id()); - let mut samples_ms = Vec::with_capacity(rounds); - - for round in 0..rounds { - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold shared cross-mode lock before round"); - - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_user = user.clone(); - let relay_stats = Arc::clone(&stats); - let relay_task = tokio::spawn(async move { - relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 256, - 256, - &relay_user, - relay_stats, - Some(2048), - Arc::new(BufferPool::new()), - ) - .await - }); - - server_peer - .write_all(&[(round as u8) ^ 0xA5]) - .await - .expect("server write should queue before release"); - - let release_at = Instant::now(); - drop(held_guard); - - let mut one = [0u8; 1]; - timeout(Duration::from_millis(450), client_peer.read_exact(&mut one)) - .await - .expect("client must receive queued byte after release") - .expect("queued byte read must succeed"); - samples_ms.push(release_at.elapsed().as_millis() as u64); - - drop(client_peer); - drop(server_peer); - - let relay_result = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must complete") - .expect("relay task must not panic"); - assert!(relay_result.is_ok()); - } - - samples_ms.sort_unstable(); - let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; - let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; - - assert!( - p50_ms <= 45, - "single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}" - ); - assert!( - p95_ms <= 130, - "single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() { - let _guard = quota_test_guard(); - - let waiters = 128usize; - let user = format!("relay-pipeline-latency-fanout-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold shared lock before fanout release benchmark"); - - let ready_barrier = Arc::new(Barrier::new(waiters + 1)); - let release_at = Arc::new(Mutex::new(None::)); - let (release_tx, release_rx) = watch::channel(false); - let mut tasks = Vec::with_capacity(waiters); - - for idx in 0..waiters { - let user = user.clone(); - let barrier = Arc::clone(&ready_barrier); - let release_at = Arc::clone(&release_at); - let mut release_rx = release_rx.clone(); - - tasks.push(tokio::spawn(async move { - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(512); - let (relay_server, mut server_peer) = duplex(512); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_user = user; - let relay_stats = Arc::clone(&stats); - let relay_task = tokio::spawn(async move { - relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 256, - 256, - &relay_user, - relay_stats, - Some(2048), - Arc::new(BufferPool::new()), - ) - .await - }); - - server_peer - .write_all(&[(idx as u8) ^ 0x5A]) - .await - .expect("fanout server write should queue before release"); - - barrier.wait().await; - release_rx - .changed() - .await - .expect("release signal should remain available"); - - let started = { - let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); - guard.expect("release timestamp must be populated before signal") - }; - - let mut one = [0u8; 1]; - timeout(Duration::from_millis(900), client_peer.read_exact(&mut one)) - .await - .expect("fanout waiter must receive queued byte after release") - .expect("fanout waiter read must succeed"); - - drop(client_peer); - drop(server_peer); - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("fanout relay task must complete") - .expect("fanout relay task must not panic"); - assert!(relay_result.is_ok()); - - started.elapsed().as_millis() as u64 - })); - } - - ready_barrier.wait().await; - { - let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); - *guard = Some(Instant::now()); - } - drop(held_guard); - release_tx - .send(true) - .expect("release broadcast must succeed"); - - let mut samples_ms = Vec::with_capacity(waiters); - timeout(Duration::from_secs(8), async { - for task in tasks { - let elapsed = task.await.expect("fanout waiter must not panic"); - samples_ms.push(elapsed); - } - }) - .await - .expect("fanout benchmark must complete in bounded time"); - - samples_ms.sort_unstable(); - let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; - let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; - let max_ms = *samples_ms.last().unwrap_or(&0); - - assert!( - p50_ms <= 120, - "fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" - ); - assert!( - p95_ms <= 260, - "fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" - ); - assert!( - max_ms <= 700, - "fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" - ); -} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs deleted file mode 100644 index adbdb22..0000000 --- a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs +++ /dev/null @@ -1,604 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Poll, Waker}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -#[tokio::test] -async fn positive_cross_mode_uncontended_writer_progresses() { - let _guard = quota_test_guard(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - "cross-mode-tdd-uncontended".to_string(), - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let result = io.write_all(&[0x11, 0x22]).await; - assert!(result.is_ok(), "uncontended writer must progress"); -} - -#[tokio::test] -async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-tdd-held-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before polling writer"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); - assert!(poll.is_pending(), "writer must not bypass held cross-mode lock"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_parallel_waiters_resume_after_cross_mode_release() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-tdd-resume-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before launching waiters"); - - let stats = Arc::new(Stats::new()); - let mut waiters = Vec::new(); - for _ in 0..16 { - let stats = Arc::clone(&stats); - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - stats, - user, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x7F]).await - })); - } - - tokio::time::sleep(Duration::from_millis(5)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let result = waiter.await.expect("waiter task must not panic"); - assert!(result.is_ok(), "waiter must complete after cross-mode release"); - } - }) - .await - .expect("all waiters must complete in bounded time"); -} - -#[tokio::test] -async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-tdd-wakes-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before polling"); - - let stats = Arc::new(Stats::new()); - let mut ios = Vec::new(); - let mut counters = Vec::new(); - for _ in 0..20 { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let poll = Pin::new(io).poll_write(&mut cx, &[0x33]); - assert!(poll.is_pending()); - counters.push(wake_counter); - } - - tokio::time::sleep(Duration::from_millis(25)).await; - let total_wakes: usize = counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= 20 * 4, - "cross-mode contention should not create wake storms; wakes={total_wakes}" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() { - let _guard = quota_test_guard(); - - let mut seed = 0xC0DE_BAAD_2026_0322u64; - for round in 0..16u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let sleep_ms = 2 + (seed as u64 % 8); - let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock in fuzz round"); - - let stats = Arc::new(Stats::new()); - let user_reader = user.clone(); - let reader_task = tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user_reader, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - let mut one = [0u8; 1]; - io.read(&mut one).await - }); - - let user_writer = user.clone(); - let writer_task = tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user_writer, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x44]).await - }); - - tokio::time::sleep(Duration::from_millis(sleep_ms)).await; - drop(held_guard); - - let read_done = timeout(Duration::from_millis(350), reader_task) - .await - .expect("reader task must complete after release") - .expect("reader task must not panic"); - assert!(read_done.is_ok()); - - let write_done = timeout(Duration::from_millis(350), writer_task) - .await - .expect("writer task must complete after release") - .expect("writer task must not panic"); - assert!(write_done.is_ok()); - } -} - -#[tokio::test] -async fn integration_middle_lock_blocks_relay_reader_for_same_user() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-middle-reader-block-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let _held_guard = held - .try_lock() - .expect("test must hold middle-relay shared lock"); - - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let mut one = [0u8; 1]; - let mut buf = ReadBuf::new(&mut one); - let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn integration_middle_lock_release_unblocks_relay_reader() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-middle-reader-release-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold middle-relay shared lock"); - - let task = tokio::spawn({ - let user = user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - let mut one = [0u8; 1]; - io.read(&mut one).await - } - }); - - tokio::time::sleep(Duration::from_millis(5)).await; - drop(held_guard); - - let done = timeout(Duration::from_millis(300), task) - .await - .expect("reader task must complete after release") - .expect("reader task must not panic"); - assert!(done.is_ok()); -} - -#[tokio::test] -async fn business_different_user_middle_lock_does_not_block_relay_writer() { - let _guard = quota_test_guard(); - - let held_user = format!("cross-mode-middle-held-{}", std::process::id()); - let active_user = format!("cross-mode-middle-active-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user); - let _held_guard = held - .try_lock() - .expect("test must hold middle-relay lock for other user"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - active_user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]); - assert!(matches!(poll, Poll::Ready(Ok(1)))); -} - -#[tokio::test] -async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-none-limit-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let _held_guard = held - .try_lock() - .expect("test must hold lock while quota is disabled"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - None, - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]); - assert!(matches!(poll, Poll::Ready(Ok(2)))); -} - -#[tokio::test] -async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-pre-exceeded-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let _held_guard = held - .try_lock() - .expect("test must hold shared lock before poll"); - - let quota_exceeded = Arc::new(AtomicBool::new(true)); - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::clone("a_exceeded), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]); - assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e))); -} - -#[tokio::test] -async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-repoll-held-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let _held_guard = held - .try_lock() - .expect("test must hold lock for repoll sequence"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - for _ in 0..8 { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]); - assert!(poll.is_pending()); - } - - assert_eq!(stats.get_user_total_octets(&user), 0); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_same_user_mixed_read_write_waiters_resume_after_release() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-mixed-resume-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock before spawning mixed waiters"); - - let mut tasks = Vec::new(); - for i in 0..12usize { - let user = user.clone(); - tasks.push(tokio::spawn(async move { - if i % 2 == 0 { - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - let mut b = [0u8; 1]; - io.read(&mut b).await.map(|_| ()) - } else { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x66]).await - } - })); - } - - tokio::time::sleep(Duration::from_millis(8)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for task in tasks { - let result = task.await.expect("mixed waiter task must not panic"); - assert!(result.is_ok()); - } - }) - .await - .expect("all mixed waiters must finish after release"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() { - let _guard = quota_test_guard(); - - let blocked_user = format!("cross-mode-blocked-{}", std::process::id()); - let free_user = format!("cross-mode-free-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); - let held_guard = held - .try_lock() - .expect("test must hold blocked user lock"); - - let blocked_task = tokio::spawn({ - let blocked_user = blocked_user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - blocked_user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x77]).await - } - }); - - let free_task = tokio::spawn({ - let free_user = free_user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - free_user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x78]).await - } - }); - - let free_done = timeout(Duration::from_millis(250), free_task) - .await - .expect("free user must not be blocked") - .expect("free user task must not panic"); - assert!(free_done.is_ok()); - - drop(held_guard); - let blocked_done = timeout(Duration::from_secs(1), blocked_task) - .await - .expect("blocked user must resume after release") - .expect("blocked user task must not panic"); - assert!(blocked_done.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-fanout-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock before fanout"); - - let waiters = 48usize; - let gate = Arc::new(Barrier::new(waiters + 1)); - let mut tasks = Vec::new(); - for _ in 0..waiters { - let user = user.clone(); - let gate = Arc::clone(&gate); - tasks.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x79]).await - })); - } - - gate.wait().await; - tokio::time::sleep(Duration::from_millis(10)).await; - drop(held_guard); - - timeout(Duration::from_secs(2), async { - for task in tasks { - let result = task.await.expect("fanout task must not panic"); - assert!(result.is_ok()); - } - }) - .await - .expect("fanout waiters must complete after release"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() { - let _guard = quota_test_guard(); - - let mut seed = 0xA11C_EE55_2026_0323u64; - for round in 0..20u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_ms = 2 + (seed % 10); - let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock in fuzz round"); - - let writer = tokio::spawn({ - let user = user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x7A]).await - } - }); - - tokio::time::sleep(Duration::from_millis(hold_ms)).await; - drop(held_guard); - - let done = timeout(Duration::from_millis(400), writer) - .await - .expect("writer must complete after lock release") - .expect("writer task must not panic"); - assert!(done.is_ok()); - } -} diff --git a/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs b/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs deleted file mode 100644 index 5ea806a..0000000 --- a/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs +++ /dev/null @@ -1,81 +0,0 @@ -use super::*; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::Waker; -use std::task::{Context, Poll}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -#[tokio::test] -async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() { - let _guard = quota_user_lock_test_scope(); - - let user = "cross-mode-lock-shared-user"; - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user); - let _held_guard = held - .try_lock() - .expect("test must hold shared cross-mode lock before relay poll"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(crate::stats::Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]); - - assert!( - matches!(poll, Poll::Pending), - "relay writer must not bypass cross-mode lock held by middle-relay path" - ); -} - -#[tokio::test] -async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() { - let _guard = quota_user_lock_test_scope(); - - let user = "cross-mode-lock-progress-user"; - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(crate::stats::Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]); - - assert!( - matches!(poll, Poll::Ready(Ok(2))), - "relay writer should progress when shared cross-mode lock is uncontended" - ); -} diff --git a/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs b/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs deleted file mode 100644 index 9ac4621..0000000 --- a/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs +++ /dev/null @@ -1,340 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::AsyncWriteExt; -use tokio::time::{Duration, Instant, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() { - let _guard = quota_test_guard(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - format!("dual-lock-alt-positive-{}", std::process::id()), - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let write = io.write_all(&[0xAA, 0xBB]).await; - assert!(write.is_ok(), "uncontended write must complete"); - assert_eq!( - io.quota_write_retry_attempt, 0, - "uncontended write must not advance retry backoff" - ); -} - -#[tokio::test] -async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-adversarial-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - let mut local_guard = Some( - local_lock - .try_lock() - .expect("test must hold local quota lock initially"), - ); - let mut cross_guard = None; - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!(first.is_pending(), "held local lock must block first poll"); - - let mut observed_wakes = 0usize; - for idx in 0..18usize { - tokio::time::sleep(Duration::from_millis(6)).await; - - if idx % 2 == 0 { - drop(local_guard.take()); - cross_guard = Some( - cross_mode_lock - .try_lock() - .expect("cross-mode lock should be acquirable while local lock released"), - ); - } else { - drop(cross_guard.take()); - local_guard = Some( - local_lock - .try_lock() - .expect("local lock should be acquirable while cross lock released"), - ); - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed_wakes { - observed_wakes = wakes; - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); - assert!( - pending.is_pending(), - "alternating contention must keep write pending while one lock is held" - ); - } - } - - assert!( - io.quota_write_retry_attempt >= 2, - "alternating contention must still ramp retry backoff; got {}", - io.quota_write_retry_attempt - ); - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 32, - "alternating contention must stay wake-rate-limited" - ); - - drop(local_guard); - drop(cross_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]); - assert!(ready.is_ready(), "writer must resume after both locks released"); -} - -#[tokio::test] -async fn edge_retry_scheduler_resets_after_alternating_contention_clears() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-edge-reset-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let local_guard = local_lock - .try_lock() - .expect("test must hold local lock for edge scenario"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]); - assert!(first.is_pending()); - tokio::time::sleep(Duration::from_millis(15)).await; - if wake_counter.wakes.load(Ordering::Relaxed) > 0 { - let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); - assert!(next.is_pending()); - } - - drop(local_guard); - - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]); - assert!(ready.is_ready()); - assert_eq!( - io.quota_write_retry_attempt, 0, - "successful dual-lock acquisition must reset retry scheduler" - ); - assert!(!io.quota_write_wake_scheduled); - assert!(io.quota_write_retry_sleep.is_none()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-integration-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - let mut waiters = Vec::new(); - for _ in 0..16usize { - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - timeout(Duration::from_secs(2), io.write_all(&[0x31])).await - })); - } - - let mut local_guard = Some( - local_lock - .try_lock() - .expect("integration toggle must acquire local lock first"), - ); - let mut cross_guard = None; - - for idx in 0..24usize { - tokio::time::sleep(Duration::from_millis(4)).await; - if idx % 2 == 0 { - drop(local_guard.take()); - cross_guard = cross_mode_lock.try_lock().ok(); - } else { - drop(cross_guard.take()); - local_guard = local_lock.try_lock().ok(); - } - } - - drop(local_guard); - drop(cross_guard); - - for waiter in waiters { - let done = waiter.await.expect("waiter task must not panic"); - assert!( - done.is_ok(), - "waiter must finish once alternating contention window ends" - ); - assert!(done.expect("waiter timeout must not fire").is_ok()); - } -} - -#[tokio::test] -async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-fuzz-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let mut seed = 0xD00D_BAAD_F00D_2026u64; - - for _round in 0..64u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_mode = (seed % 3) as u8; - let local_guard = if hold_mode == 0 { - Some( - local_lock - .try_lock() - .expect("fuzz local lock should be acquirable"), - ) - } else { - None - }; - let cross_guard = if hold_mode == 1 { - Some( - cross_mode_lock - .try_lock() - .expect("fuzz cross lock should be acquirable"), - ) - } else { - None - }; - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user.clone(), - Some(1024), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await; - if hold_mode == 2 { - assert!(write.is_ok(), "unheld fuzz round must make progress"); - assert!(write.expect("unheld round timeout").is_ok()); - } else { - assert!( - write.is_err(), - "held-lock fuzz round must remain pending inside bounded window" - ); - } - - drop(local_guard); - drop(cross_guard); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_fanout_alternating_contention_recovers_without_hanging() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-stress-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - let mut waiters = Vec::new(); - for _ in 0..48usize { - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(4096), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await - })); - } - - let mut local_guard = Some( - local_lock - .try_lock() - .expect("stress toggle must acquire local lock first"), - ); - let mut cross_guard = None; - for idx in 0..40usize { - tokio::time::sleep(Duration::from_millis(3)).await; - if idx % 2 == 0 { - drop(local_guard.take()); - cross_guard = cross_mode_lock.try_lock().ok(); - } else { - drop(cross_guard.take()); - local_guard = local_lock.try_lock().ok(); - } - } - - drop(local_guard); - drop(cross_guard); - - for waiter in waiters { - let done = waiter.await.expect("stress waiter task must not panic"); - assert!(done.is_ok(), "stress waiter timed out under alternating contention"); - assert!(done.expect("stress waiter timeout should not fire").is_ok()); - } -} diff --git a/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs b/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs deleted file mode 100644 index ce26941..0000000 --- a/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs +++ /dev/null @@ -1,74 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::time::{Duration, Instant}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn adversarial_cross_mode_only_contention_backoff_attempt_must_ramp() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-backoff-{}", std::process::id()); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_cross_mode_guard = cross_mode_lock - .try_lock() - .expect("test must hold cross-mode lock before polling"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); - assert!(first.is_pending(), "held cross-mode lock must block writer"); - - let started = Instant::now(); - let mut last_wakes = 0usize; - while started.elapsed() < Duration::from_millis(120) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > last_wakes { - last_wakes = wakes; - let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]); - assert!(next.is_pending(), "writer must remain blocked while lock is held"); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - assert!( - io.quota_write_retry_attempt >= 2, - "retry attempt must ramp under sustained second-lock contention; got {}", - io.quota_write_retry_attempt - ); - - drop(held_cross_mode_guard); -} diff --git a/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs b/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs deleted file mode 100644 index 513d92b..0000000 --- a/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs +++ /dev/null @@ -1,325 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; -use tokio::time::{Duration, Instant, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -#[tokio::test] -async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() { - let _guard = quota_test_guard(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - format!("dual-lock-positive-{}", std::process::id()), - Some(4096), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]); - assert!(poll.is_ready()); - assert_eq!(io.quota_write_retry_attempt, 0); - assert!(!io.quota_write_wake_scheduled); - assert!(io.quota_write_retry_sleep.is_none()); -} - -#[tokio::test] -async fn negative_local_lock_contention_read_retry_attempt_ramps() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-local-contention-{}", std::process::id()); - let held = quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold local quota lock before polling"); - - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (wake_counter, mut cx) = build_context(); - let mut one = [0u8; 1]; - let mut buf = ReadBuf::new(&mut one); - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(first.is_pending()); - - let started = Instant::now(); - let mut observed = 0usize; - while started.elapsed() < Duration::from_millis(120) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed { - observed = wakes; - let mut step_buf = ReadBuf::new(&mut one); - let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf); - assert!(next.is_pending()); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - assert!( - io.quota_read_retry_attempt >= 2, - "retry attempt must ramp under sustained local-lock contention; got {}", - io.quota_read_retry_attempt - ); - - drop(held_guard); -} - -#[tokio::test] -async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-reset-{}", std::process::id()); - let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = cross_mode - .try_lock() - .expect("test must hold cross-mode lock before polling"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (wake_counter, mut cx) = build_context(); - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]); - assert!(first.is_pending()); - - tokio::time::sleep(Duration::from_millis(20)).await; - if wake_counter.wakes.load(Ordering::Relaxed) > 0 { - let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!(next.is_pending()); - } - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); - assert!(ready.is_ready()); - assert_eq!(io.quota_write_retry_attempt, 0); - assert!(!io.quota_write_wake_scheduled); - assert!(io.quota_write_retry_sleep.is_none()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-adversarial-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before launching waiters"); - - let mut tasks = Vec::new(); - for _ in 0..24usize { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - stats, - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - timeout(Duration::from_millis(40), io.write_all(&[0x33])).await - })); - } - - for task in tasks { - let timed = task.await.expect("waiter task must not panic"); - assert!(timed.is_err(), "held cross-mode lock must keep waiter pending"); - } - - assert_eq!(stats.get_user_total_octets(&user), 0); - drop(held_guard); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_waiters_resume_after_cross_mode_release() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-integration-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before starting waiter"); - - let task = tokio::spawn({ - let user = user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - io.write_all(&[0x44]).await - } - }); - - tokio::time::sleep(Duration::from_millis(10)).await; - drop(held_guard); - - let done = timeout(Duration::from_secs(1), task) - .await - .expect("waiter task must complete after release") - .expect("waiter task must not panic"); - assert!(done.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-fuzz-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let mut seed = 0xA55A_55AA_C3D2_E1F0u64; - - for _round in 0..48u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_mode = (seed % 3) as u8; - let mut local_lock = None; - let mut cross_lock = None; - let mut local_guard = None; - let mut cross_guard = None; - - if hold_mode == 0 { - local_lock = Some(quota_user_lock(&user)); - local_guard = Some( - local_lock - .as_ref() - .expect("local lock should be present") - .try_lock() - .expect("local lock should be acquirable in fuzz round"), - ); - } else if hold_mode == 1 { - cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock( - &user, - )); - cross_guard = Some( - cross_lock - .as_ref() - .expect("cross lock should be present") - .try_lock() - .expect("cross lock should be acquirable in fuzz round"), - ); - } - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(4096), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await; - if hold_mode == 2 { - assert!(write.is_ok(), "unheld round must make progress"); - } else { - assert!(write.is_err(), "held-lock round must stay blocked within timeout"); - } - - drop(local_guard); - drop(cross_guard); - drop(local_lock); - drop(cross_lock); - } - - assert!(stats.get_user_total_octets(&user) <= 4096); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_fanout_waiters_complete_after_release_without_panics() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-stress-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before stress fanout"); - - let waiters = 64usize; - let mut tasks = Vec::new(); - for _ in 0..waiters { - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - let mut one = [0u8; 1]; - io.read(&mut one).await - })); - } - - tokio::time::sleep(Duration::from_millis(12)).await; - drop(held_guard); - - timeout(Duration::from_secs(2), async { - for task in tasks { - let result = task.await.expect("stress waiter task must not panic"); - assert!(result.is_ok()); - } - }) - .await - .expect("all stress waiters must complete after release"); -} diff --git a/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs b/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs deleted file mode 100644 index ec180e8..0000000 --- a/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs +++ /dev/null @@ -1,128 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use tokio::io::AsyncWriteExt; -use tokio::time::{Duration, timeout}; - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn make_stats_io(user: String) -> StatsIo { - StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-race-fuzz-{}", std::process::id()); - let mut seed = 0xD1CE_BAAD_5EED_1234u64; - - for round in 0..1024u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold = (seed & 1) == 0; - let hold_ms = (seed % 3) as u64; - - let maybe_lock = if hold { - Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock( - &user, - )) - } else { - None - }; - - let maybe_guard = maybe_lock.as_ref().map(|lock| { - lock.try_lock() - .expect("cross-mode lock must be acquirable in fuzz round") - }); - - if hold { - let mut blocked_io = make_stats_io(user.clone()); - let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await; - assert!( - blocked.is_err(), - "held round must block waiter before lock release (round={round})" - ); - - if hold_ms > 0 { - tokio::time::sleep(Duration::from_millis(hold_ms)).await; - } - } else { - let mut free_io = make_stats_io(user.clone()); - let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await; - assert!( - free.is_ok(), - "unheld round must complete promptly (round={round})" - ); - assert!(free.expect("unheld round should complete").is_ok()); - } - - drop(maybe_guard); - - let done = timeout(Duration::from_millis(350), async { - let user = user.clone(); - let mut io = make_stats_io(user); - io.write_all(&[0xA6]).await - }) - .await - .expect("post-release write must complete in bounded time"); - assert!(done.is_ok()); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-race-stress-{}", std::process::id()); - let mut seed = 0xC0FF_EE77_4444_9999u64; - - for round in 0..256u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_ms = (seed % 4) as u64; - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let guard = lock - .try_lock() - .expect("cross-mode lock must be acquirable at round start"); - - let mut waiters = Vec::new(); - for _ in 0..3usize { - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = make_stats_io(user); - io.write_all(&[0x55]).await - })); - } - - tokio::time::sleep(Duration::from_millis(hold_ms)).await; - drop(guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let done = waiter.await.expect("waiter task must not panic"); - assert!( - done.is_ok(), - "waiter must complete after release (round={round})" - ); - } - }) - .await - .expect("all waiters must complete in bounded time after release"); - } -} diff --git a/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs deleted file mode 100644 index 806efb6..0000000 --- a/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs +++ /dev/null @@ -1,79 +0,0 @@ -use super::*; -use dashmap::DashMap; -use std::sync::Arc; -use tokio::time::{Duration, timeout}; - -#[test] -fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let held_user = format!("quota-evict-held-{}", std::process::id()); - let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id()); - let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id()); - - let held = quota_user_lock(&held_user); - let stale_a = quota_user_lock(&stale_a_user); - let stale_b = quota_user_lock(&stale_b_user); - - assert!(map.get(&held_user).is_some()); - assert!(map.get(&stale_a_user).is_some()); - assert!(map.get(&stale_b_user).is_some()); - - drop(stale_a); - drop(stale_b); - - quota_user_lock_evict(); - - assert!( - map.get(&held_user).is_some(), - "held entry must survive eviction" - ); - assert!( - map.get(&stale_a_user).is_none(), - "unheld stale entry must be reclaimed" - ); - assert!( - map.get(&stale_b_user).is_none(), - "unheld stale entry must be reclaimed" - ); - - drop(held); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let held_user = format!("quota-evict-loop-held-{}", std::process::id()); - let stale_user = format!("quota-evict-loop-stale-{}", std::process::id()); - - let held = quota_user_lock(&held_user); - let stale = quota_user_lock(&stale_user); - - assert_eq!(map.len(), 2); - drop(stale); - - let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5)); - - timeout(Duration::from_millis(200), async { - loop { - if map.get(&stale_user).is_none() { - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - }) - .await - .expect("periodic quota lock evictor must reclaim stale entry"); - - evictor.abort(); - - assert!(map.get(&held_user).is_some()); - assert!(map.get(&stale_user).is_none()); - - drop(held); -} diff --git a/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs deleted file mode 100644 index 251582a..0000000 --- a/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs +++ /dev/null @@ -1,153 +0,0 @@ -use super::*; -use dashmap::DashMap; -use std::sync::Arc; -use tokio::task::JoinSet; -use tokio::time::{Duration, timeout}; - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5)); - - let mut tasks = JoinSet::new(); - for worker in 0..24u32 { - tasks.spawn(async move { - for round in 0..320u32 { - let user = format!( - "quota-evict-stress-user-{}-{}-{}", - std::process::id(), - worker, - round - ); - let lock = quota_user_lock(&user); - if round % 19 == 0 { - tokio::task::yield_now().await; - } - drop(lock); - } - }); - } - - while let Some(done) = tasks.join_next().await { - done.expect("stress worker must not panic"); - } - - quota_user_lock_evict(); - tokio::time::sleep(Duration::from_millis(20)).await; - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock map must remain bounded after churn + eviction" - ); - - let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id()); - let sanity_lock = quota_user_lock(&sanity_user); - assert!( - map.get(&sanity_user).is_some(), - "sanity user should be cacheable after eviction reclaimed stale entries" - ); - - drop(sanity_lock); - evictor.abort(); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let held_user = format!("quota-evict-held-survive-{}", std::process::id()); - let held = quota_user_lock(&held_user); - - let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3)); - - for idx in 0..512u32 { - let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx); - let temp = quota_user_lock(&user); - drop(temp); - if idx % 32 == 0 { - tokio::task::yield_now().await; - } - } - - let reacquired = quota_user_lock(&held_user); - assert!( - Arc::ptr_eq(&held, &reacquired), - "held user lock identity must remain stable across repeated evictions" - ); - assert!( - map.get(&held_user).is_some(), - "held user entry must not be reclaimed while externally referenced" - ); - - drop(reacquired); - drop(held); - - timeout(Duration::from_millis(300), async { - loop { - if map.get(&held_user).is_none() { - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - }) - .await - .expect("released held lock must be reclaimed by periodic evictor"); - - evictor.abort(); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - let prefix = format!("quota-evict-saturated-{}", std::process::id()); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); - - let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id()); - let overflow_before = quota_user_lock(&overflow_user); - assert!( - map.get(&overflow_user).is_none(), - "saturated map must initially route new user to overflow stripe" - ); - - drop(retained); - - let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4)); - - timeout(Duration::from_millis(400), async { - loop { - if map.len() < QUOTA_USER_LOCKS_MAX { - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - }) - .await - .expect("periodic evictor must reclaim stale saturated entries"); - - let overflow_after = quota_user_lock(&overflow_user); - assert!( - map.get(&overflow_user).is_some(), - "after eviction, overflow user should become cacheable again" - ); - assert!( - Arc::strong_count(&overflow_after) >= 2, - "cacheable lock should be held by map and caller" - ); - - drop(overflow_before); - drop(overflow_after); - evictor.abort(); -} diff --git a/src/proxy/tests/relay_quota_lock_identity_security_tests.rs b/src/proxy/tests/relay_quota_lock_identity_security_tests.rs deleted file mode 100644 index f717f54..0000000 --- a/src/proxy/tests/relay_quota_lock_identity_security_tests.rs +++ /dev/null @@ -1,135 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::Waker; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - // Context stores a reference; leak one Waker for deterministic test scope. - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -#[tokio::test] -async fn adversarial_map_churn_cannot_bypass_held_writer_lock() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-identity-writer-user"; - let held_lock = quota_user_lock(user); - let _held_guard = held_lock - .try_lock() - .expect("test must hold initial user lock before StatsIo poll"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - map.clear(); - let churned_lock = quota_user_lock(user); - assert!( - !Arc::ptr_eq(&held_lock, &churned_lock), - "precondition: map churn should produce a distinct lock identity" - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]); - - assert!( - matches!(poll, Poll::Pending), - "writer must remain pending on the originally-held lock identity" - ); -} - -#[tokio::test] -async fn adversarial_map_churn_cannot_bypass_held_reader_lock() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-identity-reader-user"; - let held_lock = quota_user_lock(user); - let _held_guard = held_lock - .try_lock() - .expect("test must hold initial user lock before StatsIo poll"); - - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - map.clear(); - let churned_lock = quota_user_lock(user); - assert!( - !Arc::ptr_eq(&held_lock, &churned_lock), - "precondition: map churn should produce a distinct lock identity" - ); - - let (_wake_counter, mut cx) = build_context(); - let mut storage = [0u8; 8]; - let mut read_buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf); - - assert!( - matches!(poll, Poll::Pending), - "reader must remain pending on the originally-held lock identity" - ); -} - -#[tokio::test] -async fn business_no_lock_contention_keeps_writer_progress() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-identity-progress-user"; - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]); - - assert!( - matches!(poll, Poll::Ready(Ok(2))), - "writer should progress immediately without contention" - ); -} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs deleted file mode 100644 index 5687965..0000000 --- a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs +++ /dev/null @@ -1,440 +0,0 @@ -use super::*; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::BufferPool; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; -use tokio::sync::Barrier; -use tokio::time::Instant; - -#[test] -fn quota_lock_same_user_returns_same_arc_instance() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-lock-same-user"); - let b = quota_user_lock("quota-lock-same-user"); - assert!(Arc::ptr_eq(&a, &b)); -} - -#[test] -fn quota_lock_parallel_same_user_reuses_single_lock() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-lock-parallel-same"; - let mut handles = Vec::new(); - - for _ in 0..64 { - handles.push(std::thread::spawn(move || quota_user_lock(user))); - } - - let first = handles - .remove(0) - .join() - .expect("thread must return lock handle"); - - for handle in handles { - let got = handle.join().expect("thread must return lock handle"); - assert!(Arc::ptr_eq(&first, &got)); - } -} - -#[test] -fn quota_lock_unique_users_materialize_distinct_entries() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - - map.clear(); - - let base = format!("quota-lock-distinct-{}", std::process::id()); - let users: Vec = (0..(QUOTA_USER_LOCKS_MAX / 2)) - .map(|idx| format!("{base}-{idx}")) - .collect(); - - for user in &users { - let _ = quota_user_lock(user); - } - - for user in &users { - assert!( - map.get(user).is_some(), - "lock cache must contain entry for {user}" - ); - } -} - -#[test] -fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - - map.clear(); - - let base = format!("quota-lock-churn-{}", std::process::id()); - for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) { - let _ = quota_user_lock(&format!("{base}-{idx}")); - } - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock cache must stay bounded under unique-user churn" - ); -} - -#[test] -fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("quota-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "cache must be saturated for overflow check" - ); - - let overflow_user = format!("quota-overflow-{}", std::process::id()); - let overflow_a = quota_user_lock(&overflow_user); - let overflow_b = quota_user_lock(&overflow_user); - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow path must not grow lock cache" - ); - assert!( - map.get(&overflow_user).is_none(), - "overflow user lock must stay outside bounded cache under saturation" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user must receive stable striped overflow lock while saturated" - ); - - drop(retained); -} - -#[test] -fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - // Saturate with retained strong references first so parallel tests cannot - // reclaim our fixture entries before we validate the reclaim path. - let prefix = format!("quota-reclaim-drop-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - drop(retained); - - quota_user_lock_evict(); - - let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); - let overflow = quota_user_lock(&overflow_user); - - assert!( - map.get(&overflow_user).is_some(), - "after reclaiming stale entries, overflow user should become cacheable" - ); - assert!( - Arc::strong_count(&overflow) >= 2, - "cacheable overflow lock should be held by both map and caller" - ); -} - -#[test] -fn quota_lock_saturated_same_user_must_not_return_distinct_locks() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-held-{}-{idx}", - std::process::id() - ))); - } - - let overflow_user = format!("quota-saturated-same-user-{}", std::process::id()); - let a = quota_user_lock(&overflow_user); - let b = quota_user_lock(&overflow_user); - - assert!( - Arc::ptr_eq(&a, &b), - "same user must not receive distinct locks under saturation because that enables quota race bypass" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-race-held-{}-{idx}", - std::process::id() - ))); - } - - let stats = Arc::new(Stats::new()); - let user = format!("quota-saturated-race-user-{}", std::process::id()); - let gate = Arc::new(Barrier::new(2)); - - let worker = |label: u8, stats: Arc, user: String, gate: Arc| { - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[label]).await - }) - }; - - let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); - let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); - - let _ = tokio::time::timeout(Duration::from_secs(2), async { - let _ = one.await.expect("task one must not panic"); - let _ = two.await.expect("task two must not panic"); - }) - .await - .expect("quota race workers must complete"); - - assert!( - stats.get_user_total_octets(&user) <= 1, - "saturated lock path must never overshoot quota for same user" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-stress-held-{}-{idx}", - std::process::id() - ))); - } - - for round in 0..128u32 { - let stats = Arc::new(Stats::new()); - let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id()); - let gate = Arc::new(Barrier::new(2)); - - let one = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x31]).await - }) - }; - - let two = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x32]).await - }) - }; - - let _ = one.await.expect("stress task one must not panic"); - let _ = two.await.expect("stress task two must not panic"); - - assert!( - stats.get_user_total_octets(&user) <= 1, - "round {round}: saturated path must not overshoot quota" - ); - } - - drop(retained); -} - -#[test] -fn quota_error_classifier_accepts_internal_quota_sentinel_only() { - let err = quota_io_error(); - assert!(is_quota_io_error(&err)); -} - -#[test] -fn quota_error_classifier_rejects_plain_permission_denied() { - let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied"); - assert!(!is_quota_io_error(&err)); -} - -#[test] -fn quota_lock_test_scope_recovers_after_guard_poison() { - let poison_result = std::thread::spawn(|| { - let _guard = super::quota_user_lock_test_scope(); - panic!("intentional test-only guard poison"); - }) - .join(); - assert!(poison_result.is_err(), "poison setup thread must panic"); - - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-lock-poison-recovery-user"); - let b = quota_user_lock("quota-lock-poison-recovery-user"); - assert!(Arc::ptr_eq(&a, &b)); -} - -#[tokio::test] -async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() { - let stats = Arc::new(Stats::new()); - let user = "quota-zero-user"; - - let (mut client_peer, relay_client) = duplex(2048); - let (relay_server, mut server_peer) = duplex(2048); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 512, - 512, - user, - Arc::clone(&stats), - Some(0), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(b"x") - .await - .expect("client write must succeed"); - - let mut probe = [0u8; 1]; - let forwarded = - tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; - if let Ok(Ok(n)) = forwarded { - assert_eq!(n, 0, "zero quota path must not forward payload bytes"); - } - - let result = tokio::time::timeout(Duration::from_secs(2), relay) - .await - .expect("relay must terminate under zero quota") - .expect("relay task must not panic"); - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); -} - -#[tokio::test] -async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() { - let stats = Arc::new(Stats::new()); - - let (mut client_peer, relay_client) = duplex(8192); - let (relay_server, mut server_peer) = duplex(8192); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "quota-none-burst-user", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - let c2s = vec![0xA5; 2048]; - let s2c = vec![0x5A; 1536]; - - client_peer - .write_all(&c2s) - .await - .expect("client burst write must succeed"); - let mut got_c2s = vec![0u8; c2s.len()]; - server_peer - .read_exact(&mut got_c2s) - .await - .expect("server must receive c2s burst"); - assert_eq!(got_c2s, c2s); - - server_peer - .write_all(&s2c) - .await - .expect("server burst write must succeed"); - let mut got_s2c = vec![0u8; s2c.len()]; - client_peer - .read_exact(&mut got_s2c) - .await - .expect("client must receive s2c burst"); - assert_eq!(got_s2c, s2c); - - drop(client_peer); - drop(server_peer); - - let done = tokio::time::timeout(Duration::from_secs(2), relay) - .await - .expect("relay must terminate after peers close") - .expect("relay task must not panic"); - assert!(done.is_ok()); -} diff --git a/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs b/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs deleted file mode 100644 index 447a090..0000000 --- a/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs +++ /dev/null @@ -1,249 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::AsyncWriteExt; -use tokio::time::{Duration, Instant, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -fn sleep_slot_ptr(slot: &Option>>) -> usize { - slot.as_ref() - .map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize) - .unwrap_or(0) -} - -#[tokio::test] -async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() { - let _guard = quota_test_guard(); - - let user = format!("retry-alloc-single-pending-{}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold local lock to force retry scheduling"); - - reset_quota_retry_sleep_allocs_for_tests(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); - assert!(first.is_pending()); - let allocs_after_first = quota_retry_sleep_allocs_for_tests(); - let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep); - - let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); - assert!(second.is_pending()); - let allocs_after_second = quota_retry_sleep_allocs_for_tests(); - let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep); - - assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer"); - assert_eq!( - allocs_after_second, 1, - "repoll while the same timer is pending must not allocate again" - ); - assert_eq!( - ptr_after_first, ptr_after_second, - "repoll while pending should retain the same timer allocation" - ); - - drop(held_guard); -} - -#[tokio::test] -async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() { - let _guard = quota_test_guard(); - - let user = format!("retry-alloc-per-cycle-{}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold local lock to keep write path pending"); - - reset_quota_retry_sleep_allocs_for_tests(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (wake_counter, mut cx) = build_context(); - - let mut polls = 0u64; - let mut observed_wakes = 0usize; - let started = Instant::now(); - while started.elapsed() < Duration::from_millis(70) { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]); - polls = polls.saturating_add(1); - assert!(poll.is_pending()); - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed_wakes { - observed_wakes = wakes; - } - - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let allocs = quota_retry_sleep_allocs_for_tests(); - assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers"); - assert!( - allocs < polls, - "timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})" - ); - - drop(held_guard); -} - -#[tokio::test] -async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() { - let _guard = quota_test_guard(); - - let user = format!("retry-latency-envelope-{}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold local lock for sustained contention"); - - reset_quota_retry_sleep_allocs_for_tests(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (wake_counter, mut cx) = build_context(); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]); - assert!(first.is_pending()); - - let started = Instant::now(); - let mut last_wakes = 0usize; - let mut wake_instants = Vec::new(); - - while started.elapsed() < Duration::from_millis(120) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > last_wakes { - last_wakes = wakes; - wake_instants.push(Instant::now()); - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]); - assert!(pending.is_pending()); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let mut max_gap = Duration::from_millis(0); - for idx in 1..wake_instants.len() { - let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]); - if gap > max_gap { - max_gap = gap; - } - } - - assert!( - max_gap <= Duration::from_millis(35), - "retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}" - ); - assert!( - quota_retry_sleep_allocs_for_tests() <= 16, - "allocation cycles must remain bounded during a short contention window" - ); - - drop(held_guard); -} - -#[tokio::test] -async fn micro_benchmark_release_to_completion_latency_stays_bounded() { - let _guard = quota_test_guard(); - - let rounds = 96usize; - let mut samples_ms = Vec::with_capacity(rounds); - - for round in 0..rounds { - let user = format!("retry-release-latency-{}-{round}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold local lock before spawning blocked writer"); - - let writer = tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - io.write_all(&[0xD1]).await - }); - - tokio::time::sleep(Duration::from_millis(2)).await; - let release_at = Instant::now(); - drop(held_guard); - - let done = timeout(Duration::from_millis(120), writer) - .await - .expect("blocked writer must complete after release") - .expect("writer task must not panic"); - assert!(done.is_ok()); - - samples_ms.push(release_at.elapsed().as_millis() as u64); - } - - samples_ms.sort_unstable(); - let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1)); - let p95_ms = samples_ms[p95_idx]; - - assert!( - p95_ms <= 40, - "contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" - ); -} diff --git a/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs deleted file mode 100644 index 7083eb2..0000000 --- a/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs +++ /dev/null @@ -1,241 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::ReadBuf; -use tokio::time::{Duration, Instant}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn saturate_quota_user_locks() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}"))); - } - retained -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_contention_wake_rate_decays_with_backoff_curve() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = format!("quota-backoff-bench-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before benchmark run"); - - let waiters = 64usize; - let mut ios = Vec::with_capacity(waiters); - let mut wake_counters = Vec::with_capacity(waiters); - - for _ in 0..waiters { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let pending = Pin::new(io).poll_write(&mut cx, &[0x71]); - assert!(pending.is_pending()); - wake_counters.push(counter); - } - - let mut observed = vec![0usize; waiters]; - let start = Instant::now(); - let mut wakes_at_40ms = 0usize; - let mut wakes_at_160ms = 0usize; - - while start.elapsed() < Duration::from_millis(200) { - for (idx, counter) in wake_counters.iter().enumerate() { - let wakes = counter.wakes.load(Ordering::Relaxed); - if wakes > observed[idx] { - observed[idx] = wakes; - let waker = Waker::from(Arc::clone(counter)); - let mut cx = Context::from_waker(&waker); - let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]); - assert!(pending.is_pending()); - } - } - - let elapsed = start.elapsed(); - if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 { - wakes_at_40ms = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - } - if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 { - wakes_at_160ms = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - } - - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - let wakes_at_200ms = total_wakes; - let early_window_wakes = wakes_at_40ms; - let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms); - - assert!( - total_wakes <= waiters * 28, - "backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}" - ); - - assert!( - early_window_wakes > 0, - "benchmark failed to observe early contention wakes" - ); - - assert!( - late_window_wakes * 4 <= early_window_wakes * 3, - "wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}" - ); - - drop(held_guard); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_read_contention_wake_rate_decays_with_backoff_curve() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = format!("quota-backoff-read-bench-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before read benchmark run"); - - let waiters = 64usize; - let mut ios = Vec::with_capacity(waiters); - let mut wake_counters = Vec::with_capacity(waiters); - - for _ in 0..waiters { - ios.push(StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - let pending = Pin::new(io).poll_read(&mut cx, &mut buf); - assert!(pending.is_pending()); - wake_counters.push(counter); - } - - let mut observed = vec![0usize; waiters]; - let start = Instant::now(); - let mut wakes_at_40ms = 0usize; - let mut wakes_at_160ms = 0usize; - - while start.elapsed() < Duration::from_millis(200) { - for (idx, counter) in wake_counters.iter().enumerate() { - let wakes = counter.wakes.load(Ordering::Relaxed); - if wakes > observed[idx] { - observed[idx] = wakes; - let waker = Waker::from(Arc::clone(counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf); - assert!(pending.is_pending()); - } - } - - let elapsed = start.elapsed(); - if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 { - wakes_at_40ms = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - } - if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 { - wakes_at_160ms = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - } - - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - let wakes_at_200ms = total_wakes; - let early_window_wakes = wakes_at_40ms; - let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms); - - assert!( - total_wakes <= waiters * 28, - "read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}" - ); - - assert!( - early_window_wakes > 0, - "read benchmark failed to observe early contention wakes" - ); - - assert!( - late_window_wakes * 4 <= early_window_wakes * 3, - "read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}" - ); - - drop(held_guard); -} diff --git a/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs deleted file mode 100644 index 7f1e451..0000000 --- a/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs +++ /dev/null @@ -1,339 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::ReadBuf; -use tokio::time::{Duration, Instant}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn saturate_quota_user_locks() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}"))); - } - retained -} - -#[tokio::test] -async fn positive_uncontended_writer_keeps_retry_wakes_zero() { - let _guard = quota_test_guard(); - - let stats = Arc::new(Stats::new()); - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - "quota-backoff-positive".to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]); - assert!(poll.is_ready(), "uncontended writer must complete immediately"); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "uncontended path must not schedule deferred contention wakes" - ); -} - -#[tokio::test] -async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-backoff-adversarial-writer"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before polling writer"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); - assert!(first.is_pending()); - - let start = Instant::now(); - let mut observed = 0usize; - while start.elapsed() < Duration::from_millis(80) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed { - observed = wakes; - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]); - assert!(pending.is_pending()); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 16, - "sustained contention must be rate limited; observed wakes={} in 80ms", - wake_counter.wakes.load(Ordering::Relaxed) - ); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-backoff-adversarial-reader"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before polling reader"); - - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - - let mut buf = ReadBuf::new(&mut storage); - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(first.is_pending()); - - let start = Instant::now(); - let mut observed = 0usize; - while start.elapsed() < Duration::from_millis(80) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed { - observed = wakes; - let mut next = ReadBuf::new(&mut storage); - let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next); - assert!(pending.is_pending()); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 16, - "sustained contention must be rate limited; observed wakes={} in 80ms", - wake_counter.wakes.load(Ordering::Relaxed) - ); - - drop(held_guard); - let mut done = ReadBuf::new(&mut storage); - let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn edge_backoff_attempt_resets_after_contention_release() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-backoff-edge-reset"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before polling writer"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]); - assert!(initial.is_pending()); - - tokio::time::sleep(Duration::from_millis(15)).await; - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > 0 { - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]); - assert!(pending.is_pending()); - } - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); - assert!(ready.is_ready()); - assert!( - !io.quota_write_wake_scheduled, - "successful write must clear deferred wake scheduling flag" - ); - assert!( - io.quota_write_retry_sleep.is_none(), - "successful write must clear deferred sleep slot" - ); -} - -#[tokio::test] -async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-backoff-fuzz-writer"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before fuzz loop"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let mut seed = 0x5EED_CAFE_7788_9900u64; - for _ in 0..64 { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]); - assert!(poll.is_pending()); - - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let sleep_ms = (seed % 4) as u64; - tokio::time::sleep(Duration::from_millis(sleep_ms)).await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 24, - "fuzzed repoll schedule must keep wake budget bounded; observed wakes={}", - wake_counter.wakes.load(Ordering::Relaxed) - ); - - drop(held_guard); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = format!("quota-backoff-stress-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before launching stress waiters"); - - let waiters = 48usize; - let mut ios = Vec::with_capacity(waiters); - let mut wake_counters = Vec::with_capacity(waiters); - - for _ in 0..waiters { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let pending = Pin::new(io).poll_write(&mut cx, &[0x61]); - assert!(pending.is_pending()); - wake_counters.push(counter); - } - - let start = Instant::now(); - while start.elapsed() < Duration::from_millis(120) { - for (idx, counter) in wake_counters.iter().enumerate() { - if counter.wakes.load(Ordering::Relaxed) > 0 { - let waker = Waker::from(Arc::clone(counter)); - let mut cx = Context::from_waker(&waker); - let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]); - assert!(pending.is_pending()); - } - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= waiters * 20, - "stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}" - ); - - drop(held_guard); -} diff --git a/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs b/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs deleted file mode 100644 index 35a6b6e..0000000 --- a/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs +++ /dev/null @@ -1,246 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Poll, Waker}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn positive_uncontended_quota_limited_writer_completes() { - let _guard = quota_test_guard(); - - let stats = Arc::new(Stats::new()); - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - "tdd-uncontended".to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let result = io.write_all(&[0x41, 0x42, 0x43]).await; - assert!(result.is_ok(), "uncontended writer must complete"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() { - let _guard = quota_test_guard(); - - let user = format!("tdd-writer-storm-{}", std::process::id()); - let held = quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold quota lock before polling writers"); - - let stats = Arc::new(Stats::new()); - let writers = 24usize; - let mut ios = Vec::with_capacity(writers); - let mut wake_counters = Vec::with_capacity(writers); - - for _ in 0..writers { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]); - assert!(poll.is_pending(), "writer must be pending under held lock"); - wake_counters.push(counter); - } - - tokio::time::sleep(Duration::from_millis(25)).await; - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= writers * 4, - "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() { - let _guard = quota_test_guard(); - - let user = format!("tdd-reader-storm-{}", std::process::id()); - let held = quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold quota lock before polling readers"); - - let stats = Arc::new(Stats::new()); - let readers = 24usize; - let mut ios = Vec::with_capacity(readers); - let mut wake_counters = Vec::with_capacity(readers); - - for _ in 0..readers { - ios.push(StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending(), "reader must be pending under held lock"); - wake_counters.push(counter); - } - - tokio::time::sleep(Duration::from_millis(25)).await; - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= readers * 4, - "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_contended_waiters_resume_after_lock_release() { - let _guard = quota_test_guard(); - - let user = format!("tdd-resume-{}", std::process::id()); - let held = quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold quota lock before launching waiters"); - - let stats = Arc::new(Stats::new()); - let mut waiters = Vec::new(); - for _ in 0..12 { - let stats = Arc::clone(&stats); - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - stats, - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x5A]).await - })); - } - - tokio::time::sleep(Duration::from_millis(5)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let result = waiter.await.expect("waiter task must not panic"); - assert!(result.is_ok(), "waiter must complete after release"); - } - }) - .await - .expect("all waiters must complete in bounded time"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() { - let _guard = quota_test_guard(); - - let mut seed = 0x9E37_79B9_AA55_1234u64; - for round in 0..20u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let writers = 8 + (seed as usize % 12); - let sleep_ms = 10 + (seed as u64 % 15); - let user = format!("tdd-fuzz-{}-{round}", std::process::id()); - - let held = quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold quota lock in fuzz round"); - - let stats = Arc::new(Stats::new()); - let mut ios = Vec::with_capacity(writers); - let mut wake_counters = Vec::with_capacity(writers); - - for _ in 0..writers { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]); - assert!(matches!(poll, Poll::Pending)); - wake_counters.push(counter); - } - - tokio::time::sleep(Duration::from_millis(sleep_ms)).await; - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= writers * 4, - "fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}" - ); - } -} diff --git a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs deleted file mode 100644 index 9f68258..0000000 --- a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs +++ /dev/null @@ -1,294 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -fn saturate_lock_cache() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}"))); - } - retained -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn positive_writer_progresses_after_contention_release_without_external_wake() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-writer-positive"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before write"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x11]).await }); - - // Let the initial deferred wake fire while contention is still active. - tokio::time::sleep(Duration::from_millis(4)).await; - - drop(held_guard); - - let completed = timeout(Duration::from_millis(250), writer) - .await - .expect("writer must be re-polled and complete after lock release") - .expect("writer task must not panic"); - assert!(completed.is_ok(), "writer must complete after lock release"); -} - -#[tokio::test] -async fn edge_reader_progresses_after_contention_release_without_external_wake() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-reader-edge"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before read"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let reader = tokio::spawn(async move { - let mut one = [0u8; 1]; - io.read(&mut one).await - }); - - tokio::time::sleep(Duration::from_millis(4)).await; - drop(held_guard); - - let completed = timeout(Duration::from_millis(250), reader) - .await - .expect("reader must be re-polled and complete after lock release") - .expect("reader task must not panic"); - assert!(completed.is_ok(), "reader must complete after lock release"); -} - -#[tokio::test] -async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-adversarial"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before adversarial write"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x22]).await }); - - // Force multiple scheduler rounds while lock remains held so the first - // deferred wake has already been consumed under contention. - for _ in 0..32 { - tokio::task::yield_now().await; - } - - drop(held_guard); - - let completed = timeout(Duration::from_millis(300), writer) - .await - .expect("writer must not stay parked forever after release") - .expect("writer task must not panic"); - assert!(completed.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_parallel_waiters_resume_after_single_release_event() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = format!("quota-liveness-integration-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let barrier = Arc::new(Barrier::new(13)); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before launching waiters"); - - let mut waiters = Vec::new(); - for _ in 0..12 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let barrier = Arc::clone(&barrier); - waiters.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(4096), - quota_exceeded, - tokio::time::Instant::now(), - ); - barrier.wait().await; - io.write_all(&[0x33]).await - })); - } - - barrier.wait().await; - tokio::time::sleep(Duration::from_millis(4)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let outcome = waiter.await.expect("waiter must not panic"); - assert!( - outcome.is_ok(), - "waiter must resume and complete after release" - ); - } - }) - .await - .expect("all waiters must complete in bounded time"); -} - -#[tokio::test] -async fn light_fuzz_release_timing_matrix_preserves_liveness() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let stats = Arc::new(Stats::new()); - - let mut seed = 0xD1CE_F00D_0123_4567u64; - for round in 0..64u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let delay_ms = 1 + (seed & 0x7) as u64; - let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock in fuzz round"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x44]).await }); - - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - drop(held_guard); - - let done = timeout(Duration::from_millis(300), writer) - .await - .expect("fuzz round writer must complete") - .expect("fuzz writer task must not panic"); - assert!( - done.is_ok(), - "fuzz round writer must not stall after release" - ); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_repeated_contention_cycles_remain_live() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let stats = Arc::new(Stats::new()); - - for cycle in 0..40u32 { - let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold lock before stress cycle"); - - let mut tasks = Vec::new(); - for _ in 0..6 { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - io.write_all(&[0x55]).await - })); - } - - tokio::task::yield_now().await; - drop(held_guard); - - timeout(Duration::from_millis(700), async { - for task in tasks { - let outcome = task.await.expect("stress task must not panic"); - assert!(outcome.is_ok(), "stress writer must complete"); - } - }) - .await - .expect("stress cycle must finish in bounded time"); - } -} diff --git a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs deleted file mode 100644 index fa4878a..0000000 --- a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs +++ /dev/null @@ -1,310 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::{AsyncWriteExt, ReadBuf}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn saturate_quota_user_locks() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}"))); - } - retained -} - -#[tokio::test] -async fn positive_contended_writer_emits_deferred_wake_for_liveness() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-positive-user"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling writer"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); - assert!(pending.is_pending()); - - timeout(Duration::from_millis(100), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("contended writer must receive deferred wake"); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); - assert!( - ready.is_ready(), - "writer must progress after contention release" - ); -} - -#[tokio::test] -async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-blackhat-writer"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling writer"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - for _ in 0..512 { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]); - assert!( - poll.is_pending(), - "writer must stay pending while lock is held" - ); - tokio::task::yield_now().await; - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes <= 128, - "pending writer retries must not trigger wake storm; observed wakes={wakes}" - ); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn edge_read_path_contention_keeps_wake_budget_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-read-edge"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling reader"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - - for _ in 0..512 { - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); - tokio::task::yield_now().await; - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes <= 128, - "pending reader retries must not trigger wake storm; observed wakes={wakes}" - ); - - drop(held_guard); - let mut buf = ReadBuf::new(&mut storage); - let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-fuzz-user"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before fuzz polling"); - - let counters_w = Arc::new(SharedCounters::new()); - let mut writer_io = StatsIo::new( - tokio::io::sink(), - counters_w, - Arc::clone(&stats), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let counters_r = Arc::new(SharedCounters::new()); - let mut reader_io = StatsIo::new( - tokio::io::empty(), - counters_r, - Arc::clone(&stats), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut seed = 0xBADC_0FFE_EE11_2211u64; - let mut storage = [0u8; 1]; - - for _ in 0..1024 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - if (seed & 1) == 0 { - let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]); - assert!(poll.is_pending()); - } else { - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); - } - tokio::task::yield_now().await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 192, - "mixed contention fuzz must keep deferred wake count tightly bounded" - ); - - drop(held_guard); - let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]); - assert!(ready_w.is_ready()); - - let mut buf = ReadBuf::new(&mut storage); - let ready_r = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); - assert!(ready_r.is_ready()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"] -async fn stress_many_contended_writers_complete_after_release() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-waker-stress-user".to_string(); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before launching contended tasks"); - - let mut tasks = Vec::new(); - for _ in 0..32 { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - - io.write_all(&[0xAA]).await - })); - } - - for _ in 0..8 { - tokio::task::yield_now().await; - } - - drop(held_guard); - - timeout(Duration::from_secs(2), async { - for task in tasks { - let result = task.await.expect("stress task must not panic"); - assert!(result.is_ok(), "task must complete after lock release"); - } - }) - .await - .expect("all contended writer tasks must finish in bounded time after release"); -} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs deleted file mode 100644 index 7375192..0000000 --- a/src/proxy/tests/relay_security_tests.rs +++ /dev/null @@ -1,1284 +0,0 @@ -use super::relay_bidirectional; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::BufferPool; -use std::future::poll_fn; -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::Mutex; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::task::Waker; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, ReadBuf}; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -#[tokio::test] -async fn quota_lock_contention_does_not_self_wake_pending_writer() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-contention-user"; - - let lock = super::quota_user_lock(user); - let _held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!( - poll.is_pending(), - "writer must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "contended quota lock must not self-wake immediately and spin the executor" - ); -} - -#[tokio::test] -async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-writer-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!( - first.is_pending(), - "writer must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "deferred wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("contended writer must schedule a deferred wake in bounded time"); - let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes_after_first_yield >= 1, - "contended writer must schedule at least one deferred wake for liveness" - ); - - let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); - assert!( - second.is_pending(), - "writer remains pending while lock is still held" - ); - - for _ in 0..8 { - tokio::task::yield_now().await; - } - let wakes_after_second_window = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes_after_second_window <= wakes_after_first_yield.saturating_add(2), - "writer contention should keep retry wakes bounded before lock acquisition: before={wakes_after_first_yield}, after={wakes_after_second_window}" - ); - - drop(held_lock); - let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); - assert!( - released.is_ready(), - "writer must make progress once quota lock is released" - ); -} - -#[tokio::test] -async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-read-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling reader"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!( - first.is_pending(), - "reader must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "read contention wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("read contention must schedule a deferred wake in bounded time"); - - drop(held_lock); - let mut buf_after_release = ReadBuf::new(&mut storage); - let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); - assert!( - released.is_ready(), - "reader must make progress once quota lock is released" - ); -} - -#[tokio::test] -async fn relay_bidirectional_enforces_live_user_quota() { - let stats = Arc::new(Stats::new()); - let user = "quota-user"; - stats.add_user_octets_from(user, 6); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - Some(8), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x10, 0x20, 0x30, 0x40]) - .await - .expect("client write must succeed"); - - let mut forwarded = [0u8; 4]; - let _ = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut forwarded), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"), - "relay must surface a typed quota error once live quota is exceeded" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0xde, 0xad, 0xbe, 0xef]) - .await - .expect("server write must succeed"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "no full server payload should be forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() - { - let stats = Arc::new(Stats::new()); - let quota_user = "partial-leak-user"; - stats.add_user_octets_from(quota_user, 3); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) - .await - .expect("server write must succeed"); - - let mut observed = [0u8; 8]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() { - let stats = Arc::new(Stats::new()); - let quota_user = "zero-quota-user"; - - for payload_len in [1usize, 16, 512, 4096] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(0), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0x7f; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under zero-quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "zero quota must not forward any server bytes for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "zero quota must terminate with the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] -async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() { - let stats = Arc::new(Stats::new()); - let quota_user = "exact-boundary-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x91, 0x92, 0x93, 0x94]) - .await - .expect("server write must succeed at exact quota boundary"); - - let mut observed = [0u8; 4]; - client_peer - .read_exact(&mut observed) - .await - .expect("client must receive the full payload at the exact quota boundary"); - assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]); - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish after exact boundary delivery") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must close with a typed quota error after reaching the exact boundary" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "client-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x51, 0x52, 0x53, 0x54]) - .await - .expect("client write must succeed even when quota is already exhausted"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "client payload must not be fully forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-fuzz-user"; - stats.add_user_octets_from(quota_user, 2); - - for payload_len in [1usize, 32, 1024, 8192] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(2), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xaa; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == payload_len), - "quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must keep returning the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] -async fn relay_bidirectional_terminates_on_activity_timeout() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let user = "timeout-user"; - - let (client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - None, // No quota - Arc::new(BufferPool::new()), - )); - - // Wait past the activity timeout threshold (1800 seconds) + buffer - tokio::time::sleep(Duration::from_secs(1805)).await; - - // Resume time to process timeouts - tokio::time::resume(); - - let relay_result = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must finish inside bounded timeout due to inactivity cutoff") - .expect("relay task must not panic"); - - assert!( - relay_result.is_ok(), - "relay should complete successfully on scheduled inactivity timeout" - ); - - // Verify client/server sockets are closed - drop(client_peer); - drop(server_peer); -} - -#[tokio::test] -async fn relay_bidirectional_watchdog_resists_premature_execution() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let user = "activity-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let mut relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - // Advance by half the timeout - tokio::time::sleep(Duration::from_secs(900)).await; - - // Provide activity - client_peer - .write_all(&[0xaa, 0xbb]) - .await - .expect("client write must succeed"); - client_peer.flush().await.unwrap(); - - // Advance by another half (total time since start is 1800, but since last activity is 900) - tokio::time::sleep(Duration::from_secs(900)).await; - - tokio::time::resume(); - - // Re-evaluating the task, it should NOT have timed out and still be pending - let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await; - assert!( - relay_result.is_err(), - "Relay must not exit prematurely as long as activity was received before timeout" - ); - - // Explicitly drop sockets to cleanly shut down relay loop - drop(client_peer); - drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must complete securely after client disconnection") - .expect("relay task must not panic"); - assert!(completion.is_ok(), "relay exits clean"); -} - -#[tokio::test] -async fn relay_bidirectional_half_closure_terminates_cleanly() { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "half-close", - stats, - None, - Arc::new(BufferPool::new()), - )); - - // Half closure: drop the client completely but leave the server active. - drop(client_peer); - - // Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush. - // Eventually dropping the server cleanly closes the task. - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -#[tokio::test] -async fn relay_bidirectional_zero_length_noise_fuzzing() { - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "fuzz", - stats, - None, - Arc::new(BufferPool::new()), - )); - - // Flood with zero-length payloads (edge cases in stream framing logic sometimes loop) - for _ in 0..100 { - client_peer.write_all(&[]).await.unwrap(); - } - client_peer.write_all(&[1, 2, 3]).await.unwrap(); - client_peer.flush().await.unwrap(); - - let mut buf = [0u8; 3]; - server_peer.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, &[1, 2, 3]); - - drop(client_peer); - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -#[tokio::test] -async fn relay_bidirectional_asymmetric_backpressure() { - let stats = Arc::new(Stats::new()); - // Give the client stream an extremely narrow throughput limit explicitly - let (client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "slowloris", - stats, - None, - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xba; 65536]; // 64k payload - - // Server attempts to shove 64KB into a relay whose client pipe only holds 1KB! - let write_res = - tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await; - - assert!( - write_res.is_err(), - "Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!" - ); - - drop(client_peer); - drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap(); - assert!( - completion.is_ok() || completion.is_err(), - "Task must unwind reliably (either Ok or BrokenPipe Err) when dropped despite active backpressure locks" - ); -} - -use rand::{RngExt, SeedableRng, rngs::StdRng}; - -#[tokio::test] -async fn relay_bidirectional_light_fuzzing_temporal_jitter() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let mut relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "fuzz-user", - stats, - None, - Arc::new(BufferPool::new()), - )); - - let mut rng = StdRng::seed_from_u64(0xDEADBEEF); - - for _ in 0..10 { - // Vary timing significantly up to 1600 seconds (limit is 1800s) - let jitter = rng.random_range(100..1600); - tokio::time::sleep(Duration::from_secs(jitter)).await; - - client_peer.write_all(&[0x11]).await.unwrap(); - client_peer.flush().await.unwrap(); - - // Ensure task has not died - let res = timeout(Duration::from_millis(10), &mut relay_task).await; - assert!( - res.is_err(), - "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses" - ); - } - - drop(client_peer); - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -struct FaultyReader { - error_once: Option, -} - -struct TwoPartyGate { - arrivals: AtomicUsize, - total_bytes: AtomicUsize, - wakers: Mutex>, -} - -impl TwoPartyGate { - fn new() -> Self { - Self { - arrivals: AtomicUsize::new(0), - total_bytes: AtomicUsize::new(0), - wakers: Mutex::new(Vec::new()), - } - } - - fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool { - if self.arrivals.load(Ordering::Relaxed) >= 2 { - return true; - } - - let prev = self.arrivals.fetch_add(1, Ordering::AcqRel); - if prev + 1 >= 2 { - let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); - for waker in wakers.drain(..) { - waker.wake(); - } - true - } else { - let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); - wakers.push(cx.waker().clone()); - false - } - } - - fn total_bytes(&self) -> usize { - self.total_bytes.load(Ordering::Relaxed) - } -} - -struct GateWriter { - gate: Arc, - entered: bool, -} - -impl GateWriter { - fn new(gate: Arc) -> Self { - Self { - gate, - entered: false, - } - } -} - -impl AsyncWrite for GateWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if !self.entered { - self.entered = true; - } - - if !self.gate.arrive_or_park(cx) { - return Poll::Pending; - } - - self.gate - .total_bytes - .fetch_add(buf.len(), Ordering::Relaxed); - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -struct GateReader { - gate: Arc, - entered: bool, - emitted: bool, -} - -impl GateReader { - fn new(gate: Arc) -> Self { - Self { - gate, - entered: false, - emitted: false, - } - } -} - -impl AsyncRead for GateReader { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if self.emitted { - return Poll::Ready(Ok(())); - } - - if !self.entered { - self.entered = true; - } - - if !self.gate.arrive_or_park(cx) { - return Poll::Pending; - } - - buf.put_slice(&[0x42]); - self.gate.total_bytes.fetch_add(1, Ordering::Relaxed); - self.emitted = true; - Poll::Ready(Ok(())) - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { - let stats = Arc::new(Stats::new()); - let gate = Arc::new(TwoPartyGate::new()); - let user = "concurrent-quota-write".to_string(); - - let writer_a = super::StatsIo::new( - GateWriter::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let writer_b = super::StatsIo::new( - GateWriter::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let task_a = tokio::spawn(async move { - let mut w = writer_a; - AsyncWriteExt::write_all(&mut w, &[0x01]).await - }); - let task_b = tokio::spawn(async move { - let mut w = writer_b; - AsyncWriteExt::write_all(&mut w, &[0x02]).await - }); - - let (res_a, res_b) = tokio::join!(task_a, task_b); - let _ = res_a.expect("task a must join"); - let _ = res_b.expect("task b must join"); - - assert!( - gate.total_bytes() <= 1, - "concurrent same-user writes must not forward more than one byte under quota=1" - ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user writes must not account over limit" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { - let stats = Arc::new(Stats::new()); - let gate = Arc::new(TwoPartyGate::new()); - let user = "concurrent-quota-read".to_string(); - - let reader_a = super::StatsIo::new( - GateReader::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let reader_b = super::StatsIo::new( - GateReader::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let task_a = tokio::spawn(async move { - let mut r = reader_a; - let mut one = [0u8; 1]; - AsyncReadExt::read_exact(&mut r, &mut one).await - }); - let task_b = tokio::spawn(async move { - let mut r = reader_b; - let mut one = [0u8; 1]; - AsyncReadExt::read_exact(&mut r, &mut one).await - }); - - let (res_a, res_b) = tokio::join!(task_a, task_b); - let _ = res_a.expect("task a must join"); - let _ = res_b.expect("task b must join"); - - assert!( - gate.total_bytes() <= 1, - "concurrent same-user reads must not consume more than one byte under quota=1" - ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user reads must not account over limit" - ); -} - -#[tokio::test] -async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { - let stats = Arc::new(Stats::new()); - let user = "parallel-quota-user"; - - for _ in 0..128 { - let (mut client_peer_a, relay_client_a) = duplex(256); - let (relay_server_a, mut server_peer_a) = duplex(256); - let (mut client_peer_b, relay_client_b) = duplex(256); - let (relay_server_b, mut server_peer_b) = duplex(256); - - let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a); - let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a); - let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b); - let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b); - - let relay_a = tokio::spawn(relay_bidirectional( - client_reader_a, - client_writer_a, - server_reader_a, - server_writer_a, - 64, - 64, - user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - let relay_b = tokio::spawn(relay_bidirectional( - client_reader_b, - client_writer_b, - server_reader_b, - server_writer_b, - 64, - 64, - user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - let _ = tokio::join!( - client_peer_a.write_all(&[0x01]), - server_peer_a.write_all(&[0x02]), - client_peer_b.write_all(&[0x03]), - server_peer_b.write_all(&[0x04]), - ); - - let _ = timeout( - Duration::from_millis(50), - poll_fn(|cx| { - let mut one = [0u8; 1]; - let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one)); - Poll::Ready(()) - }), - ) - .await; - - drop(client_peer_a); - drop(server_peer_a); - drop(client_peer_b); - drop(server_peer_b); - - let _ = timeout(Duration::from_secs(1), relay_a).await; - let _ = timeout(Duration::from_secs(1), relay_b).await; - - assert!( - stats.get_user_total_octets(user) <= 1, - "parallel relays must not exceed configured quota" - ); - } -} - -impl FaultyReader { - fn permission_denied_with_message(message: impl Into) -> Self { - Self { - error_once: Some(io::Error::new( - io::ErrorKind::PermissionDenied, - message.into(), - )), - } - } -} - -impl AsyncRead for FaultyReader { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &mut ReadBuf<'_>, - ) -> Poll> { - if let Some(err) = self.error_once.take() { - return Poll::Ready(Err(err)); - } - Poll::Ready(Ok(())) - } -} - -#[tokio::test] -async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - - let relay_result = relay_bidirectional( - client_reader, - client_writer, - FaultyReader::permission_denied_with_message("user data quota exceeded"), - tokio::io::sink(), - 1024, - 1024, - "non-quota-permission-denied", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - ) - .await; - - drop(client_peer); - - assert!( - matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), - "non-quota transport PermissionDenied errors must remain IO errors" - ); -} - -#[tokio::test] -async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() { - let mut rng = StdRng::seed_from_u64(0xA11CE0B5); - - for i in 0..128u64 { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - - let random_len = rng.random_range(1..=48); - let mut msg = String::with_capacity(random_len); - for _ in 0..random_len { - let ch = (b'a' + (rng.random::() % 26)) as char; - msg.push(ch); - } - // Include the legacy quota string in a subset of fuzz cases to validate - // collision resistance against message-based classification. - if i % 7 == 0 { - msg = "user data quota exceeded".to_string(); - } - - let relay_result = relay_bidirectional( - client_reader, - client_writer, - FaultyReader::permission_denied_with_message(msg), - tokio::io::sink(), - 1024, - 1024, - "fuzz-perm-denied", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - ) - .await; - - drop(client_peer); - - assert!( - matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), - "transport PermissionDenied case must stay typed as IO regardless of message content" - ); - } -} - -#[tokio::test] -async fn relay_half_close_keeps_reverse_direction_progressing() { - let stats = Arc::new(Stats::new()); - let user = "half-close-user"; - - let (client_peer, relay_client) = duplex(1024); - let (relay_server, server_peer) = duplex(1024); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); - let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 8192, - 8192, - user, - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - sp_writer - .write_all(&[0x10, 0x20, 0x30, 0x40]) - .await - .unwrap(); - sp_writer.shutdown().await.unwrap(); - - let mut inbound = [0u8; 4]; - cp_reader.read_exact(&mut inbound).await.unwrap(); - assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); - - cp_writer - .write_all(&[0xaa, 0xbb, 0xcc, 0xdd]) - .await - .unwrap(); - let mut outbound = [0u8; 4]; - sp_reader.read_exact(&mut outbound).await.unwrap(); - assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); - - relay_task.abort(); - let joined = relay_task.await; - assert!(joined.is_err(), "aborted relay task must return join error"); -} From 0a5e8a09fd773b53b7d4636303d68e81365cb7b8 Mon Sep 17 00:00:00 2001 From: Alexander <32452033+avbor@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:29:08 +0300 Subject: [PATCH 10/29] Update VPS_DOUBLE_HOP.ru.md Added S3-S4 parameters for AWG and update AWG generator. --- docs/VPS_DOUBLE_HOP.ru.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/VPS_DOUBLE_HOP.ru.md b/docs/VPS_DOUBLE_HOP.ru.md index 689a7c0..037dfcb 100644 --- a/docs/VPS_DOUBLE_HOP.ru.md +++ b/docs/VPS_DOUBLE_HOP.ru.md @@ -44,7 +44,7 @@ awg genkey | tee private.key | awg pubkey > public.key Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\ Параметры `Jc`, `Jmin` и `Jmax` могут отличатся.\ -Параметры `I1-I5` [(Custom Protocol Signature)](https://docs.amnezia.org/documentation/amnezia-wg/) нужно указывать на стороне _клиента_ (Сервер **А**). +Параметры `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) нужно указывать на стороне _клиента_ (Сервер **А**). Рекомендации по выбору значений: ```text @@ -62,7 +62,7 @@ H1/H2/H3/H4 — должны быть уникальны и отличаться ``` > [!IMPORTANT] > Рекомендуется использовать собственные, уникальные значения.\ -> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html). +> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html). #### Конфигурация Сервера B (_Нидерланды_): @@ -83,6 +83,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 @@ -121,6 +123,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 From 41c2b4de65f51de32cf39d5a1c07935c430058fa Mon Sep 17 00:00:00 2001 From: Alexander <32452033+avbor@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:30:37 +0300 Subject: [PATCH 11/29] Update VPS_DOUBLE_HOP.en.md Added S3-S4 parameters for AWG and update AWG generator. --- docs/VPS_DOUBLE_HOP.en.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/VPS_DOUBLE_HOP.en.md b/docs/VPS_DOUBLE_HOP.en.md index 9463b79..6b6abe5 100644 --- a/docs/VPS_DOUBLE_HOP.en.md +++ b/docs/VPS_DOUBLE_HOP.en.md @@ -63,7 +63,7 @@ recommended range from 5 to 2147483647 inclusive > [!IMPORTANT] > It is recommended to use your own, unique values.\ -> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html) to select parameters. +> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html) to select parameters. #### Server B Configuration (Netherlands): @@ -84,6 +84,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 @@ -121,6 +123,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 From 2d69b9d0aeb8544973608a074c427e558457d099 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:39:23 +0300 Subject: [PATCH 12/29] New wave of tests Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/proxy/middle_relay.rs | 4 + ...ddle_relay_atomic_quota_invariant_tests.rs | 189 ++++++++++++++ .../relay_atomic_quota_invariant_tests.rs | 243 ++++++++++++++++++ ..._extended_attack_surface_security_tests.rs | 17 +- .../relay_quota_model_adversarial_tests.rs | 8 +- .../relay_quota_overflow_regression_tests.rs | 15 +- 6 files changed, 461 insertions(+), 15 deletions(-) create mode 100644 src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs create mode 100644 src/proxy/tests/relay_atomic_quota_invariant_tests.rs diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 2a84353..d833019 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1977,3 +1977,7 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests; #[cfg(test)] #[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"] +mod middle_relay_atomic_quota_invariant_tests; diff --git a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..7c176bc --- /dev/null +++ b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,189 @@ +use super::*; +use crate::crypto::AesCtr; +use bytes::Bytes; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; + +struct CountedWriter { + write_calls: Arc, + fail_writes: bool, +} + +impl CountedWriter { + fn new(write_calls: Arc, fail_writes: bool) -> Self { + Self { + write_calls, + fail_writes, + } + } +} + +impl AsyncWrite for CountedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + if this.fail_writes { + Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "forced write failure", + ))) + } else { + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter { + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { + let stats = Stats::new(); + let user = "middle-me-writer-no-rollback-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), true)); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let payload = Bytes::from_static(&[0x11, 0x22, 0x33, 0x44, 0x55]); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: payload.clone(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(64), + 0, + &bytes_me2c, + 11, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(_))), + "write failure must propagate as I/O error" + ); + assert!( + write_calls.load(Ordering::Relaxed) > 0, + "writer must be attempted after successful quota reservation" + ); + assert_eq!( + stats.get_user_quota_used(user), + payload.len() as u64, + "reserved quota must not roll back on write failure" + ); + assert_eq!( + stats.get_quota_write_fail_bytes_total(), + payload.len() as u64, + "write-fail byte metric must include failed payload size" + ); + assert_eq!( + stats.get_quota_write_fail_events_total(), + 1, + "write-fail events metric must increment once" + ); + assert_eq!( + stats.get_user_total_octets(user), + 0, + "telemetry octets_to should not advance when write fails" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + 0, + "ME->C committed byte counter must not advance on write failure" + ); +} + +#[tokio::test] +async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() { + let stats = Stats::new(); + let user = "middle-me-writer-precheck-user"; + let limit = 8u64; + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), limit); + + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), false)); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(limit), + 0, + &bytes_me2c, + 12, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { .. })), + "pre-write quota rejection must return typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 0, + "writer must not be polled when pre-write quota reservation fails" + ); + assert_eq!( + stats.get_me_d2c_quota_reject_pre_write_total(), + 1, + "pre-write quota reject metric must increment" + ); + assert_eq!( + stats.get_user_quota_used(user), + limit, + "failed pre-write reservation must keep previous quota usage unchanged" + ); + assert_eq!( + stats.get_quota_write_fail_bytes_total(), + 0, + "write-fail bytes metric must stay unchanged on pre-write reject" + ); + assert_eq!( + stats.get_quota_write_fail_events_total(), + 0, + "write-fail events metric must stay unchanged on pre-write reject" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} diff --git a/src/proxy/tests/relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..1bb00a6 --- /dev/null +++ b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,243 @@ +use super::*; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::time::Instant; + +struct ScriptedWriter { + scripted_writes: Arc>>, + write_calls: Arc, +} + +impl ScriptedWriter { + fn new(script: &[usize], write_calls: Arc) -> Self { + Self { + scripted_writes: Arc::new(Mutex::new(script.iter().copied().collect())), + write_calls, + } + } +} + +impl AsyncWrite for ScriptedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + let planned = this + .scripted_writes + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .pop_front() + .unwrap_or(buf.len()); + Poll::Ready(Ok(planned.min(buf.len()))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_stats_io_with_script( + user: &str, + quota_limit: u64, + precharged_quota: u64, + script: &[usize], +) -> ( + StatsIo, + Arc, + Arc, + Arc, +) { + let stats = Arc::new(Stats::new()); + if precharged_quota > 0 { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota); + } + + let write_calls = Arc::new(AtomicUsize::new(0)); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let io = StatsIo::new( + ScriptedWriter::new(script, write_calls.clone()), + Arc::new(SharedCounters::new()), + stats.clone(), + user.to_string(), + Some(quota_limit), + quota_exceeded.clone(), + Instant::now(), + ); + + (io, stats, write_calls, quota_exceeded) +} + +#[tokio::test] +async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() { + let user = "direct-partial-charge-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1_048_576, 0, &[8 * 1024, 8 * 1024, 48 * 1024]); + let payload = vec![0xAB; 64 * 1024]; + + let n1 = io + .write(&payload) + .await + .expect("first partial write must succeed"); + let n2 = io + .write(&payload) + .await + .expect("second partial write must succeed"); + let n3 = io.write(&payload).await.expect("tail write must succeed"); + + assert_eq!(n1, 8 * 1024); + assert_eq!(n2, 8 * 1024); + assert_eq!(n3, 48 * 1024); + assert_eq!(write_calls.load(Ordering::Relaxed), 3); + assert_eq!( + stats.get_user_quota_used(user), + (n1 + n2 + n3) as u64, + "quota accounting must follow committed bytes only" + ); + assert_eq!( + stats.get_user_total_octets(user), + (n1 + n2 + n3) as u64, + "telemetry octets should match committed bytes on successful writes" + ); + assert!( + !quota_exceeded.load(Ordering::Acquire), + "quota flag should stay false under large remaining budget" + ); +} + +#[tokio::test] +async fn direct_hybrid_branch_selection_matches_contract() { + let near_limit = 256 * 1024u64; + let near_remaining = 32 * 1024u64; + let (mut near_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-near-limit-hard-check-user", + near_limit, + near_limit - near_remaining, + &[4 * 1024], + ); + let near_payload = vec![0x11; 4 * 1024]; + let near_written = near_io + .write(&near_payload) + .await + .expect("near-limit write must succeed"); + assert_eq!(near_written, 4 * 1024); + assert_eq!( + near_io.quota_bytes_since_check, 0, + "near-limit branch must go through immediate hard check" + ); + + let (mut far_small_io, _stats, _calls, _flag) = + make_stats_io_with_script("direct-far-small-amortized-user", 1_048_576, 0, &[4 * 1024]); + let far_small_payload = vec![0x22; 4 * 1024]; + let far_small_written = far_small_io + .write(&far_small_payload) + .await + .expect("small far-from-limit write must succeed"); + assert_eq!(far_small_written, 4 * 1024); + assert_eq!( + far_small_io.quota_bytes_since_check, + 4 * 1024, + "small far-from-limit write must go through amortized path" + ); + + let (mut far_large_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-far-large-hard-check-user", + 1_048_576, + 0, + &[32 * 1024], + ); + let far_large_payload = vec![0x33; 32 * 1024]; + let far_large_written = far_large_io + .write(&far_large_payload) + .await + .expect("large write must succeed"); + assert_eq!(far_large_written, 32 * 1024); + assert_eq!( + far_large_io.quota_bytes_since_check, 0, + "large write must force immediate hard check even far from limit" + ); +} + +#[tokio::test] +async fn remaining_before_zero_rejects_without_calling_inner_writer() { + let user = "direct-zero-remaining-user"; + let limit = 8u64; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, limit, limit, &[1]); + + let err = io + .write(&[0x44]) + .await + .expect_err("write must fail when remaining quota is zero"); + + assert!( + is_quota_io_error(&err), + "zero-remaining gate must return typed quota I/O error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 0, + "inner poll_write must not be called when remaining quota is zero" + ); + assert!( + quota_exceeded.load(Ordering::Acquire), + "zero-remaining gate must set exceeded flag" + ); + assert_eq!(stats.get_user_quota_used(user), limit); +} + +#[tokio::test] +async fn exceeded_flag_blocks_following_poll_before_inner_write() { + let user = "direct-exceeded-visibility-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1, 0, &[1, 1]); + + let first = io + .write(&[0x55]) + .await + .expect("first byte should consume remaining quota"); + assert_eq!(first, 1); + assert!( + quota_exceeded.load(Ordering::Acquire), + "hard check should store quota_exceeded after boundary hit" + ); + + let second = io + .write(&[0x66]) + .await + .expect_err("next write must be rejected by early exceeded gate"); + assert!( + is_quota_io_error(&second), + "following write must fail with typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 1, + "second write must be cut before touching inner writer" + ); + assert_eq!(stats.get_user_quota_used(user), 1); +} + +#[test] +fn adaptive_interval_clamp_matches_contract() { + assert_eq!(quota_adaptive_interval_bytes(0), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(2 * 1024), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(32 * 1024), 16 * 1024); + assert_eq!(quota_adaptive_interval_bytes(256 * 1024), 64 * 1024); + + assert!(should_immediate_quota_check(32 * 1024, 4 * 1024)); + assert!(should_immediate_quota_check(1_048_576, 32 * 1024)); + assert!(!should_immediate_quota_check(1_048_576, 4 * 1024)); +} diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs index 5ee6522..e80690b 100644 --- a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -29,6 +29,11 @@ async fn read_available(reader: &mut R, budget: total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn positive_quota_path_forwards_both_directions_within_limit() { let stats = Arc::new(Stats::new()); @@ -63,14 +68,14 @@ async fn positive_quota_path_forwards_both_directions_within_limit() { let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); assert!(relay_result.is_ok()); - assert!(stats.get_user_total_octets(user) <= 16); + assert!(stats.get_user_quota_used(user) <= 16); } #[tokio::test] async fn negative_preloaded_quota_forbids_any_forwarding() { let stats = Arc::new(Stats::new()); let user = "quota-extended-negative-user"; - stats.add_user_octets_from(user, 8); + preload_user_quota(stats.as_ref(), user, 8); let (mut client_peer, relay_client) = duplex(1024); let (relay_server, mut server_peer) = duplex(1024); @@ -98,7 +103,7 @@ async fn negative_preloaded_quota_forbids_any_forwarding() { let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert!(stats.get_user_total_octets(user) <= 8); + assert!(stats.get_user_quota_used(user) <= 8); } #[tokio::test] @@ -189,7 +194,7 @@ async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() { let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap(); assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); assert!(total_forwarded <= quota as usize); - assert!(stats.get_user_total_octets(user) <= quota); + assert!(stats.get_user_quota_used(user) <= quota); } #[tokio::test] @@ -252,7 +257,7 @@ async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); assert!(total_forwarded <= quota as usize); - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(stats.get_user_quota_used(&user) <= quota); } } @@ -327,6 +332,6 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() { delivered += task.await.unwrap(); } - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(stats.get_user_quota_used(&user) <= quota); assert!(delivered <= quota as usize); } diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index 5714f48..73fd393 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -96,7 +96,7 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() "fuzz case {case}: delivered bytes exceed quota" ); assert!( - stats.get_user_total_octets(&user) <= quota, + stats.get_user_quota_used(&user) <= quota, "fuzz case {case}: accounted bytes exceed quota" ); } @@ -118,7 +118,7 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(stats.get_user_quota_used(&user) <= quota); } } @@ -209,7 +209,7 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -305,7 +305,7 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode } assert!( - stats.get_user_total_octets(user) <= quota, + stats.get_user_quota_used(user) <= quota, "global per-user quota must never overshoot under concurrent multi-relay model load" ); assert!( diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs index dfbab85..a59954e 100644 --- a/src/proxy/tests/relay_quota_overflow_regression_tests.rs +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -19,13 +19,18 @@ async fn read_available(reader: &mut R, budget_ms: u64) -> total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-client-chunk"; // Leave only 1 byte remaining under quota. - stats.add_user_octets_from(user, 9); + preload_user_quota(stats.as_ref(), user, 9); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -68,7 +73,7 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ Err(ProxyError::DataQuotaExceeded { .. }) )); assert!( - stats.get_user_total_octets(user) <= 10, + stats.get_user_quota_used(user) <= 10, "accounted bytes must never exceed quota after overflowing chunk" ); } @@ -79,7 +84,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of let user = "quota-overflow-regression-boundary"; // Leave exactly 4 bytes remaining. - stats.add_user_octets_from(user, 6); + preload_user_quota(stats.as_ref(), user, 6); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -131,7 +136,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -201,7 +206,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { "aggregate forwarded bytes across relays must stay within global user quota" ); assert!( - stats.get_user_total_octets(user) <= quota, + stats.get_user_quota_used(user) <= quota, "global accounted bytes must stay within quota under overflow stress" ); } From 8cfaab93208c404c5b60a336462d405797e1e11f Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:39:49 +0300 Subject: [PATCH 13/29] Fixes in tests Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/proxy/relay.rs | 4 + .../relay_quota_boundary_blackhat_tests.rs | 19 +++-- src/stats/mod.rs | 77 +++++++++++++++++++ 3 files changed, 93 insertions(+), 7 deletions(-) diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index cc8b088..bf4ad43 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -669,3 +669,7 @@ mod relay_quota_extended_attack_surface_security_tests; #[cfg(test)] #[path = "tests/relay_watchdog_delta_security_tests.rs"] mod relay_watchdog_delta_security_tests; + +#[cfg(test)] +#[path = "tests/relay_atomic_quota_invariant_tests.rs"] +mod relay_atomic_quota_invariant_tests; diff --git a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs index 080240a..9a32b26 100644 --- a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs +++ b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs @@ -29,6 +29,11 @@ async fn read_available(reader: &mut R, budget: Duration) total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn integration_full_duplex_exact_budget_then_hard_cutoff() { let stats = Arc::new(Stats::new()); @@ -102,14 +107,14 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() { relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user" )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test] async fn negative_preloaded_quota_blocks_both_directions_immediately() { let stats = Arc::new(Stats::new()); let user = "quota-preloaded-cutoff-user"; - stats.add_user_octets_from(user, 5); + preload_user_quota(stats.as_ref(), user, 5); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -154,7 +159,7 @@ async fn negative_preloaded_quota_blocks_both_directions_immediately() { relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 5); + assert!(stats.get_user_quota_used(user) <= 5); } #[tokio::test] @@ -212,7 +217,7 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test] @@ -277,7 +282,7 @@ async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_glo delivered_to_server + delivered_to_client <= quota as usize, "combined forwarded bytes must never exceed configured quota" ); - assert!(stats.get_user_total_octets(user) <= quota); + assert!(stats.get_user_quota_used(user) <= quota); } #[tokio::test] @@ -356,7 +361,7 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar "fuzz case {case}: forwarded bytes must not exceed quota" ); assert!( - stats.get_user_total_octets(&user) <= quota, + stats.get_user_quota_used(&user) <= quota, "fuzz case {case}: accounted bytes must not exceed quota" ); } @@ -451,7 +456,7 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo } assert!( - stats.get_user_total_octets(user) <= quota, + stats.get_user_quota_used(user) <= quota, "global per-user quota must hold under concurrent mixed-direction relay stress" ); assert!( diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 7d8aef3..297ff28 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -2580,6 +2580,56 @@ mod tests { assert_eq!(user_stats.quota_used(), limit); } + #[test] + fn test_quota_reserve_200x_1k_reaches_100k_without_overshoot() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let failures = Arc::new(AtomicU64::new(0)); + let attempts = 200usize; + let reserve_bytes = 1_024u64; + let limit = 100 * 1_024u64; + let mut workers = Vec::with_capacity(attempts); + + for _ in 0..attempts { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + let failures = failures.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(reserve_bytes, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + return; + } + Err(QuotaReserveError::LimitExceeded) => { + failures.fetch_add(1, Ordering::Relaxed); + return; + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + } + } + })); + } + + for worker in workers { + worker.join().expect("reservation worker must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + 100, + "exactly 100 reservations of 1 KiB must fit into a 100 KiB quota" + ); + assert_eq!( + failures.load(Ordering::Relaxed), + 100, + "remaining workers must fail once quota is fully reserved" + ); + assert_eq!(user_stats.quota_used(), limit); + } + #[test] fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() { let stats = Stats::new(); @@ -2594,6 +2644,33 @@ mod tests { assert_eq!(stats.get_user_total_octets(user), 5); assert_eq!(stats.get_user_quota_used(user), 7); } + + #[test] + fn test_cached_handle_survives_map_cleanup_until_last_drop() { + let stats = Stats::new(); + let user = "quota-handle-lifetime-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let weak = Arc::downgrade(&user_stats); + + stats.user_stats.remove(user); + assert!( + stats.user_stats.get(user).is_none(), + "map cleanup should remove idle entry" + ); + assert!( + weak.upgrade().is_some(), + "cached handle must keep user stats object alive after map removal" + ); + + stats.quota_charge_post_write(user_stats.as_ref(), 3); + assert_eq!(user_stats.quota_used(), 3); + + drop(user_stats); + assert!( + weak.upgrade().is_none(), + "user stats object must be dropped after the last cached handle is released" + ); + } } #[cfg(test)] From e6b77af9310d6d1b208bac45a99fa851a619accb Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:49:23 +0300 Subject: [PATCH 14/29] Workflows Swap Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- .github/workflows/build.yml | 39 ++++++++++++++++++++++++ .github/workflows/rust.yml | 54 ---------------------------------- .github/workflows/stress.yml | 57 ------------------------------------ .github/workflows/test.yml | 56 +++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 111 deletions(-) create mode 100644 .github/workflows/build.yml delete mode 100644 .github/workflows/rust.yml delete mode 100644 .github/workflows/stress.yml create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1b6e455 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,39 @@ +name: Build + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + name: Build + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install latest stable Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry & build artifacts + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Build Release + run: cargo build --release --verbose \ No newline at end of file diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index b245679..0000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,54 +0,0 @@ -name: Rust - -on: - push: - branches: [ "*" ] - pull_request: - branches: [ "*" ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - name: Compile, Test, Lint - runs-on: ubuntu-latest - - permissions: - contents: read - actions: write - checks: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install latest stable Rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt, clippy - - - name: Cache cargo registry & build artifacts - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo- - - - name: Compile (no tests) - run: cargo check --workspace --all-features --lib --bins --verbose - - - name: Run tests (single pass) - run: cargo test --workspace --all-features --verbose - -# clippy dont fail on warnings because of active development of telemt -# and many warnings - - name: Run clippy - run: cargo clippy -- --cap-lints warn - - - name: Check for unused dependencies - run: cargo udeps || true diff --git a/.github/workflows/stress.yml b/.github/workflows/stress.yml deleted file mode 100644 index 96b9a1b..0000000 --- a/.github/workflows/stress.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: Stress Tests - -on: - workflow_dispatch: - schedule: - - cron: '0 2 * * *' - pull_request: - branches: ["*"] - paths: - - src/proxy/** - - src/transport/** - - src/stream/** - - src/protocol/** - - src/tls_front/** - - Cargo.toml - - Cargo.lock - -env: - CARGO_TERM_COLOR: always - -jobs: - quota-lock-stress: - name: Quota-lock stress loop - runs-on: ubuntu-latest - - permissions: - contents: read - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install latest stable Rust toolchain - uses: dtolnay/rust-toolchain@stable - - - name: Cache cargo registry and build artifacts - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-stress-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo-stress- - ${{ runner.os }}-cargo- - - - name: Run quota-lock stress suites - env: - RUST_TEST_THREADS: 16 - run: | - set -euo pipefail - for i in $(seq 1 12); do - echo "[quota-lock-stress] iteration ${i}/12" - cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16 - cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16 - done diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..d8f7a64 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,56 @@ +name: Test + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + test: + name: Test / Lint / Analysis + runs-on: ubuntu-latest + + permissions: + contents: read + actions: write + checks: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install latest stable Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + + - name: Cache cargo registry & build artifacts + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Run tests + run: cargo test --verbose + + # clippy не валит билд (осознанно) + - name: Run clippy + run: cargo clippy -- --cap-lints warn + + - name: Check formatting + run: cargo fmt -- --check + + - name: Install cargo-udeps + run: cargo install cargo-udeps || true + + - name: Check for unused dependencies + run: cargo udeps || true \ No newline at end of file From 800356c751d1ca04f17edc37ba542f48c8099b2d Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:04:47 +0300 Subject: [PATCH 15/29] Rewiring tests Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- .../tests/client_deep_invariants_tests.rs | 9 +++++-- src/proxy/tests/client_more_advanced_tests.rs | 7 ++++- src/proxy/tests/client_security_tests.rs | 7 ++++- ...ing_additional_hardening_security_tests.rs | 6 ++++- src/proxy/tests/relay_adversarial_tests.rs | 27 +++++++++++++++---- .../relay_quota_model_adversarial_tests.rs | 13 ++++----- 6 files changed, 53 insertions(+), 16 deletions(-) diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs index 97c55c6..0302300 100644 --- a/src/proxy/tests/client_deep_invariants_tests.rs +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -7,6 +7,11 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncWriteExt, duplex}; +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[test] fn invariant_wrap_tls_application_record_exact_multiples() { let chunk_size = u16::MAX as usize; @@ -114,7 +119,7 @@ async fn invariant_quota_exact_boundary_inclusive() { let ip_tracker = Arc::new(UserIpTracker::new()); let peer = "198.51.100.23:55000".parse().unwrap(); - stats.add_user_octets_from(user, 999); + preload_user_quota(stats.as_ref(), user, 999); let res1 = RunningClientHandler::acquire_user_connection_reservation_static( user, &config, @@ -126,7 +131,7 @@ async fn invariant_quota_exact_boundary_inclusive() { assert!(res1.is_ok()); res1.unwrap().release().await; - stats.add_user_octets_from(user, 1); + preload_user_quota(stats.as_ref(), user, 1); let res2 = RunningClientHandler::acquire_user_connection_reservation_static( user, &config, diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs index 021848a..36ffcbb 100644 --- a/src/proxy/tests/client_more_advanced_tests.rs +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -6,6 +6,11 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn edge_mask_delay_bypassed_if_max_is_zero() { let mut config = ProxyConfig::default(); @@ -42,7 +47,7 @@ async fn boundary_user_data_quota_exact_match_rejects() { config.access.user_data_quota.insert(user.to_string(), 1024); let stats = Arc::new(Stats::new()); - stats.add_user_octets_from(user, 1024); + preload_user_quota(stats.as_ref(), user, 1024); let ip_tracker = Arc::new(UserIpTracker::new()); let peer = "198.51.100.10:55000".parse().unwrap(); diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 2b1fae6..bae1ce2 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -242,6 +242,11 @@ where CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); @@ -3040,7 +3045,7 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { .insert("user".to_string(), 1024); let stats = Stats::new(); - stats.add_user_octets_from("user", 1024); + preload_user_quota(&stats, "user", 1024); let ip_tracker = UserIpTracker::new(); let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap(); diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs index 29170c1..a6f6386 100644 --- a/src/proxy/tests/masking_additional_hardening_security_tests.rs +++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs @@ -78,7 +78,11 @@ fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() { config.censorship.mask_timing_normalization_ceiling_ms = 0; let budget = mask_outcome_target_budget(&config); - assert_eq!(budget, MASK_TIMEOUT); + assert_eq!( + budget, + Duration::from_millis(0), + "zero floor/ceiling must produce zero extra normalization budget" + ); } #[tokio::test] diff --git a/src/proxy/tests/relay_adversarial_tests.rs b/src/proxy/tests/relay_adversarial_tests.rs index 14754cd..38e6fc7 100644 --- a/src/proxy/tests/relay_adversarial_tests.rs +++ b/src/proxy/tests/relay_adversarial_tests.rs @@ -78,7 +78,8 @@ async fn relay_hol_blocking_prevention_regression() { async fn relay_quota_mid_session_cutoff() { let stats = Arc::new(Stats::new()); let user = "quota-mid-user"; - let quota = 5000; + let quota = 5000u64; + let c2s_buf_size = 1024usize; let (client_peer, relay_client) = duplex(8192); let (relay_server, server_peer) = duplex(8192); @@ -93,7 +94,7 @@ async fn relay_quota_mid_session_cutoff() { client_writer, server_reader, server_writer, - 1024, + c2s_buf_size, 1024, user, Arc::clone(&stats), @@ -120,9 +121,25 @@ async fn relay_quota_mid_session_cutoff() { other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), } - let mut small_buf = [0u8; 1]; - let n = sp_reader.read(&mut small_buf).await.unwrap(); - assert_eq!(n, 0, "Server must see EOF after quota reached"); + let mut overshoot_bytes = 0usize; + let mut buf = [0u8; 256]; + loop { + match timeout(Duration::from_millis(20), sp_reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => overshoot_bytes = overshoot_bytes.saturating_add(n), + Ok(Err(e)) => panic!("server read must not fail after relay cutoff: {e}"), + Err(_) => break, + } + } + + assert!( + overshoot_bytes <= c2s_buf_size, + "post-write cutoff may leak at most one C->S chunk after boundary, got {overshoot_bytes}" + ); + assert!( + stats.get_user_quota_used(user) <= quota.saturating_add(c2s_buf_size as u64), + "accounted quota must remain bounded by one in-flight chunk overshoot" + ); } #[tokio::test] diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index 73fd393..83bf731 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -32,6 +32,7 @@ async fn drain_available(reader: &mut R, out: &mut Vec #[tokio::test] async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() { let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D); + const MAX_INPUT_CHUNK: usize = 12; for case in 0..64u64 { let stats = Arc::new(Stats::new()); @@ -92,12 +93,12 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C"); assert!( - recv_at_server.len() + recv_at_client.len() <= quota as usize, - "fuzz case {case}: delivered bytes exceed quota" + recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK, + "fuzz case {case}: delivered bytes exceed bounded post-check overshoot" ); assert!( - stats.get_user_quota_used(&user) <= quota, - "fuzz case {case}: accounted bytes exceed quota" + stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64, + "fuzz case {case}: accounted bytes exceed bounded post-check overshoot" ); } @@ -117,8 +118,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); - assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); - assert!(stats.get_user_quota_used(&user) <= quota); + assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK); + assert!(stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64); } } From 24156b5067dd00c9ef12fd5ca9e02d8f3118fe97 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:42:18 +0300 Subject: [PATCH 16/29] Workflow for Docker and correct binary naming Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- .github/workflows/release.yml | 188 ++++++++++++++-------------------- Dockerfile | 60 ++--------- 2 files changed, 87 insertions(+), 161 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index def299d..d01293e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,6 +26,9 @@ jobs: name: GNU ${{ matrix.target }} runs-on: ubuntu-latest + container: + image: rust:slim-bookworm + strategy: fail-fast: false matrix: @@ -47,8 +50,8 @@ jobs: - name: Install deps run: | - sudo apt-get update - sudo apt-get install -y \ + apt-get update + apt-get install -y \ build-essential \ clang \ lld \ @@ -69,14 +72,10 @@ jobs: if [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then export CC=aarch64-linux-gnu-gcc export CXX=aarch64-linux-gnu-g++ - export CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc - export CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++ export RUSTFLAGS="-C linker=aarch64-linux-gnu-gcc" else export CC=clang export CXX=clang++ - export CC_x86_64_unknown_linux_gnu=clang - export CXX_x86_64_unknown_linux_gnu=clang++ export RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld" fi @@ -85,20 +84,19 @@ jobs: - name: Package run: | mkdir -p dist - BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} - - cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }} + cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt cd dist - tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }} + tar -czf ${{ matrix.asset }}.tar.gz \ + --owner=0 --group=0 --numeric-owner \ + telemt + sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256 - uses: actions/upload-artifact@v4 with: name: ${{ matrix.asset }} - path: | - dist/${{ matrix.asset }}.tar.gz - dist/${{ matrix.asset }}.sha256 + path: dist/* # ========================== # MUSL @@ -125,43 +123,7 @@ jobs: - name: Install deps run: | apt-get update - apt-get install -y \ - musl-tools \ - pkg-config \ - curl - - - uses: actions/cache@v4 - if: matrix.target == 'aarch64-unknown-linux-musl' - with: - path: ~/.musl-aarch64 - key: musl-toolchain-aarch64-v1 - - - name: Install aarch64 musl toolchain - if: matrix.target == 'aarch64-unknown-linux-musl' - run: | - set -e - - TOOLCHAIN_DIR="$HOME/.musl-aarch64" - ARCHIVE="aarch64-linux-musl-cross.tgz" - URL="https://github.com/telemt/telemt/releases/download/toolchains/$ARCHIVE" - - if [ -x "$TOOLCHAIN_DIR/bin/aarch64-linux-musl-gcc" ]; then - echo "✅ MUSL toolchain already installed" - else - echo "⬇️ Downloading musl toolchain from Telemt GitHub Releases..." - - curl -fL \ - --retry 5 \ - --retry-delay 3 \ - --connect-timeout 10 \ - --max-time 120 \ - -o "$ARCHIVE" "$URL" - - mkdir -p "$TOOLCHAIN_DIR" - tar -xzf "$ARCHIVE" --strip-components=1 -C "$TOOLCHAIN_DIR" - fi - - echo "$TOOLCHAIN_DIR/bin" >> $GITHUB_PATH + apt-get install -y musl-tools pkg-config curl - name: Add rust target run: rustup target add ${{ matrix.target }} @@ -178,11 +140,9 @@ jobs: run: | if [ "${{ matrix.target }}" = "aarch64-unknown-linux-musl" ]; then export CC=aarch64-linux-musl-gcc - export CC_aarch64_unknown_linux_musl=aarch64-linux-musl-gcc export RUSTFLAGS="-C target-feature=+crt-static -C linker=aarch64-linux-musl-gcc" else export CC=musl-gcc - export CC_x86_64_unknown_linux_musl=musl-gcc export RUSTFLAGS="-C target-feature=+crt-static" fi @@ -191,69 +151,19 @@ jobs: - name: Package run: | mkdir -p dist - BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} - - cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }} + cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt cd dist - tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }} + tar -czf ${{ matrix.asset }}.tar.gz \ + --owner=0 --group=0 --numeric-owner \ + telemt + sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256 - uses: actions/upload-artifact@v4 with: name: ${{ matrix.asset }} - path: | - dist/${{ matrix.asset }}.tar.gz - dist/${{ matrix.asset }}.sha256 - -# ========================== -# Docker -# ========================== - docker: - name: Docker - runs-on: ubuntu-latest - needs: [build-gnu, build-musl] - continue-on-error: true - - steps: - - uses: actions/checkout@v4 - - - uses: actions/download-artifact@v4 - with: - path: artifacts - - - name: Extract binaries - run: | - mkdir dist - find artifacts -name "*.tar.gz" -exec tar -xzf {} -C dist \; - - cp dist/telemt-x86_64-unknown-linux-musl dist/telemt || true - - - uses: docker/setup-qemu-action@v3 - - uses: docker/setup-buildx-action@v3 - - - name: Login to GHCR - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Extract version - id: vars - run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT - - - name: Build & Push - uses: docker/build-push-action@v6 - with: - context: . - push: true - platforms: linux/amd64,linux/arm64 - tags: | - ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }} - ghcr.io/${{ github.repository }}:latest - build-args: | - BINARY=dist/telemt + path: dist/* # ========================== # Release @@ -271,7 +181,7 @@ jobs: with: path: artifacts - - name: Flatten artifacts + - name: Flatten run: | mkdir dist find artifacts -type f -exec cp {} dist/ \; @@ -281,5 +191,61 @@ jobs: with: files: dist/* generate_release_notes: true - draft: false - prerelease: ${{ contains(github.ref, '-rc') || contains(github.ref, '-beta') || contains(github.ref, '-alpha') }} + prerelease: ${{ contains(github.ref, '-') }} + +# ========================== +# Docker (FROM RELEASE) +# ========================== + docker: + name: Docker (from release) + runs-on: ubuntu-latest + needs: release + + permissions: + contents: read + packages: write + + steps: + - uses: actions/checkout@v4 + + - name: Install gh + run: apt-get update && apt-get install -y gh + + - name: Extract version + id: vars + run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT + + - name: Download binary + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + mkdir dist + + gh release download ${{ steps.vars.outputs.VERSION }} \ + --repo ${{ github.repository }} \ + --pattern "telemt-x86_64-linux-musl.tar.gz" \ + --dir dist + + tar -xzf dist/telemt-x86_64-linux-musl.tar.gz -C dist + chmod +x dist/telemt + + - uses: docker/setup-qemu-action@v3 + - uses: docker/setup-buildx-action@v3 + + - uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build & Push + uses: docker/build-push-action@v6 + with: + context: . + push: true + platforms: linux/amd64,linux/arm64 + tags: | + ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }} + ghcr.io/${{ github.repository }}:latest + build-args: | + BINARY=dist/telemt \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 372f702..eac46f0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,29 +1,9 @@ # syntax=docker/dockerfile:1 -# ========================== -# Stage 1: Build -# ========================== -FROM rust:1.88-slim-bookworm AS builder - -RUN apt-get update && apt-get install -y --no-install-recommends \ - pkg-config \ - ca-certificates \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /build - -# Depcache -COPY Cargo.toml Cargo.lock* ./ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && \ - cargo build --release 2>/dev/null || true && \ - rm -rf src - -# Build -COPY . . -RUN cargo build --release && strip target/release/telemt +ARG BINARY # ========================== -# Stage 2: Compress (strip + UPX) +# Stage: minimal # ========================== FROM debian:12-slim AS minimal @@ -33,7 +13,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ && rm -rf /var/lib/apt/lists/* \ \ - # install UPX from Telemt releases && curl -fL \ --retry 5 \ --retry-delay 3 \ @@ -46,15 +25,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && chmod +x /usr/local/bin/upx \ && rm -rf /tmp/upx* -COPY --from=builder /build/target/release/telemt /telemt +COPY ${BINARY} /telemt RUN strip /telemt || true RUN upx --best --lzma /telemt || true # ========================== -# Stage 3: Debug base +# Debug image # ========================== -FROM debian:12-slim AS debug-base +FROM debian:12-slim AS debug RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ @@ -64,48 +43,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ busybox \ && rm -rf /var/lib/apt/lists/* -# ========================== -# Stage 4: Debug image -# ========================== -FROM debug-base AS debug - WORKDIR /app COPY --from=minimal /telemt /app/telemt COPY config.toml /app/config.toml -USER root - -EXPOSE 443 -EXPOSE 9090 -EXPOSE 9091 +EXPOSE 443 9090 9091 ENTRYPOINT ["/app/telemt"] CMD ["config.toml"] # ========================== -# Stage 5: Production (distroless) +# Production (REAL distroless) # ========================== -FROM gcr.io/distroless/base-debian12 AS prod +FROM gcr.io/distroless/static-debian12 AS prod WORKDIR /app COPY --from=minimal /telemt /app/telemt COPY config.toml /app/config.toml -# TLS + timezone + shell -COPY --from=debug-base /etc/ssl/certs /etc/ssl/certs -COPY --from=debug-base /usr/share/zoneinfo /usr/share/zoneinfo -COPY --from=debug-base /bin/busybox /bin/busybox - -RUN ["/bin/busybox", "--install", "-s", "/bin"] - -# distroless user USER nonroot:nonroot -EXPOSE 443 -EXPOSE 9090 -EXPOSE 9091 +EXPOSE 443 9090 9091 ENTRYPOINT ["/app/telemt"] -CMD ["config.toml"] +CMD ["config.toml"] \ No newline at end of file From a3a6ea288099304de2dc495787ac6fa8738cf687 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:06:11 +0300 Subject: [PATCH 17/29] Update relay_quota_overflow_regression_tests.rs Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- .../relay_quota_overflow_regression_tests.rs | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs index a59954e..f1e6c34 100644 --- a/src/proxy/tests/relay_quota_overflow_regression_tests.rs +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -28,9 +28,13 @@ fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-client-chunk"; + let quota = 10u64; + let preloaded = 9u64; + let attempted_chunk = [0x11, 0x22, 0x33, 0x44]; + let max_post_write_overshoot = attempted_chunk.len() as u64; // Leave only 1 byte remaining under quota. - preload_user_quota(stats.as_ref(), user, 9); + preload_user_quota(stats.as_ref(), user, preloaded); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -46,15 +50,12 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ 512, user, Arc::clone(&stats), - Some(10), + Some(quota), Arc::new(BufferPool::new()), )); // Single chunk attempts to cross remaining budget (4 > 1). - client_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) - .await - .unwrap(); + client_peer.write_all(&attempted_chunk).await.unwrap(); client_peer.shutdown().await.unwrap(); let forwarded = read_available(&mut server_peer, 60).await; @@ -64,17 +65,17 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ .expect("relay must terminate after quota overflow attempt") .expect("relay task must not panic"); - assert_eq!( - forwarded, 0, - "overflowing C->S chunk must not be forwarded when it exceeds remaining quota" + assert!( + forwarded <= attempted_chunk.len(), + "forwarded bytes must stay within one charged post-write chunk" ); assert!(matches!( relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); assert!( - stats.get_user_quota_used(user) <= 10, - "accounted bytes must never exceed quota after overflowing chunk" + stats.get_user_quota_used(user) <= quota + max_post_write_overshoot, + "accounted bytes must stay within bounded post-write overshoot" ); } @@ -144,9 +145,12 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-stress"; let quota = 12u64; + const WORKERS: usize = 4; + const BURST_LEN: usize = 64; + let max_parallel_post_write_overshoot = (WORKERS * BURST_LEN) as u64; let mut handles = Vec::new(); - for _ in 0..4usize { + for _ in 0..WORKERS { let stats = Arc::clone(&stats); let user = user.to_string(); @@ -175,7 +179,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { }); // Aggressive sender tries to overflow shared user quota. - let burst = vec![0x5Au8; 64]; + let burst = vec![0x5Au8; BURST_LEN]; let _ = client_peer.write_all(&burst).await; let _ = client_peer.shutdown().await; @@ -202,11 +206,11 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { } assert!( - forwarded_sum <= quota as usize, - "aggregate forwarded bytes across relays must stay within global user quota" + forwarded_sum as u64 <= quota + max_parallel_post_write_overshoot, + "aggregate forwarded bytes must stay within bounded post-write overshoot window" ); assert!( - stats.get_user_quota_used(user) <= quota, - "global accounted bytes must stay within quota under overflow stress" + stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot, + "global accounted bytes must stay within bounded post-write overshoot window" ); } From 3ceda150736500734b25d02d38b6927374415709 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:18:18 +0300 Subject: [PATCH 18/29] Update relay_quota_model_adversarial_tests.rs Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- .../tests/relay_quota_model_adversarial_tests.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index 83bf731..04a7020 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -218,9 +218,12 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode let stats = Arc::new(Stats::new()); let user = "quota-model-stress-user"; let quota = 96u64; + const WORKERS: usize = 6; + const MAX_WORKER_CHUNK: u64 = 10; + let max_parallel_post_write_overshoot = WORKERS as u64 * MAX_WORKER_CHUNK; let mut workers = Vec::new(); - for worker_id in 0..6u64 { + for worker_id in 0..WORKERS as u64 { let stats = Arc::clone(&stats); let user = user.to_string(); @@ -306,11 +309,11 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode } assert!( - stats.get_user_quota_used(user) <= quota, - "global per-user quota must never overshoot under concurrent multi-relay model load" + stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot, + "global per-user accounted bytes must stay within bounded post-write overshoot" ); assert!( - delivered_sum <= quota as usize, - "aggregate delivered bytes across relays must remain within global quota" + delivered_sum as u64 <= quota + max_parallel_post_write_overshoot, + "aggregate delivered bytes must stay within bounded post-write overshoot" ); } From 814bef9d997cf953ee6ab4c54b8bffa777cd0614 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:32:55 +0300 Subject: [PATCH 19/29] Rustfmt --- src/config/hot_reload.rs | 4 +- src/config/load.rs | 7 +- ...ssifier_prefetch_timeout_security_tests.rs | 9 +- .../tests/load_mask_shape_security_tests.rs | 4 +- src/config/types.rs | 3 +- src/main.rs | 6 +- src/metrics.rs | 5 +- src/proxy/client.rs | 15 ++- src/proxy/masking.rs | 46 ++++--- src/proxy/middle_relay.rs | 40 ++++-- src/proxy/mod.rs | 100 +++++++------- src/proxy/relay.rs | 12 +- .../tests/client_clever_advanced_tests.rs | 92 ++++++++++--- .../tests/client_deep_invariants_tests.rs | 29 +++- ...http2_fragmented_preface_security_tests.rs | 9 +- ..._prefetch_config_runtime_security_tests.rs | 10 +- ...sking_prefetch_invariant_security_tests.rs | 5 +- ...prefetch_strict_boundary_security_tests.rs | 4 +- ...g_prefetch_timing_matrix_security_tests.rs | 5 +- ...nt_masking_replay_timing_security_tests.rs | 20 ++- src/proxy/tests/client_more_advanced_tests.rs | 60 ++++++--- src/proxy/tests/client_security_tests.rs | 38 ++++-- ...ls_record_wrap_hardening_security_tests.rs | 21 ++- .../tests/direct_relay_security_tests.rs | 39 +++--- .../tests/handshake_advanced_clever_tests.rs | 120 +++++++++++++---- ...auth_probe_eviction_bias_security_tests.rs | 7 +- ...e_auth_probe_scan_budget_security_tests.rs | 2 +- ...ake_auth_probe_scan_offset_stress_tests.rs | 2 +- .../tests/handshake_more_clever_tests.rs | 124 ++++++++++++++---- .../tests/handshake_real_bug_stress_tests.rs | 13 +- .../handshake_timing_manual_bench_tests.rs | 10 +- ...nvelope_blur_integration_security_tests.rs | 1 - .../masking_aggressive_mode_security_tests.rs | 5 +- ...ect_failure_close_matrix_security_tests.rs | 15 +-- ...ing_consume_idle_timeout_security_tests.rs | 9 +- ..._extended_attack_surface_security_tests.rs | 14 +- ...king_http_probe_boundary_security_tests.rs | 10 +- ...erface_cache_concurrency_security_tests.rs | 2 +- .../masking_interface_cache_security_tests.rs | 5 +- ...roduction_cap_regression_security_tests.rs | 10 +- ...masking_self_target_loop_security_tests.rs | 60 +++------ ...g_timing_budget_coupling_security_tests.rs | 7 +- ...middle_relay_idle_policy_security_tests.rs | 2 +- ...lay_idle_registry_poison_security_tests.rs | 5 +- ...y_frame_debt_concurrency_security_tests.rs | 19 ++- ...rame_debt_proto_chunking_security_tests.rs | 18 ++- ...le_relay_tiny_frame_debt_security_tests.rs | 12 +- ..._relay_zero_length_frame_security_tests.rs | 4 +- ..._extended_attack_surface_security_tests.rs | 108 +++++++++++---- src/stats/mod.rs | 16 ++- .../frame_stream_padding_security_tests.rs | 5 +- ...tracker_encapsulation_adversarial_tests.rs | 5 +- src/tls_front/fetcher.rs | 4 +- src/transport/middle_proxy/pool_status.rs | 4 +- src/transport/pool.rs | 5 +- 55 files changed, 821 insertions(+), 385 deletions(-) diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index a3f795a..7f7499e 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -228,7 +228,9 @@ impl HotFields { me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us, me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate, me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes, - me_d2c_frame_buf_shrink_threshold_bytes: cfg.general.me_d2c_frame_buf_shrink_threshold_bytes, + me_d2c_frame_buf_shrink_threshold_bytes: cfg + .general + .me_d2c_frame_buf_shrink_threshold_bytes, direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes, direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes, me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy, diff --git a/src/config/load.rs b/src/config/load.rs index fc54ec2..8f12757 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -444,8 +444,7 @@ impl ProxyConfig { if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) { return Err(ProxyError::Config( - "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]" - .to_string(), + "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]".to_string(), )); } @@ -558,7 +557,9 @@ impl ProxyConfig { )); } - if !(4096..=16 * 1024 * 1024).contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) { + if !(4096..=16 * 1024 * 1024) + .contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) + { return Err(ProxyError::Config( "general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]" .to_string(), diff --git a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs index 49ee953..0b3d543 100644 --- a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs +++ b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs @@ -8,8 +8,9 @@ fn write_temp_config(contents: &str) -> PathBuf { .duration_since(UNIX_EPOCH) .expect("system time must be after unix epoch") .as_nanos(); - let path = std::env::temp_dir() - .join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml")); + let path = std::env::temp_dir().join(format!( + "telemt-load-mask-prefetch-timeout-security-{nonce}.toml" + )); fs::write(&path, contents).expect("temp config write must succeed"); path } @@ -67,8 +68,8 @@ mask_classifier_prefetch_timeout_ms = 20 "#, ); - let cfg = ProxyConfig::load(&path) - .expect("prefetch timeout within security bounds must be accepted"); + let cfg = + ProxyConfig::load(&path).expect("prefetch timeout within security bounds must be accepted"); assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20); remove_temp_config(&path); diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index 2e4aa41..bccd36f 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -265,8 +265,8 @@ mask_relay_max_bytes = 67108865 "#, ); - let err = ProxyConfig::load(&path) - .expect_err("mask_relay_max_bytes above hard cap must be rejected"); + let err = + ProxyConfig::load(&path).expect_err("mask_relay_max_bytes above hard cap must be rejected"); let msg = err.to_string(); assert!( msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"), diff --git a/src/config/types.rs b/src/config/types.rs index 5dc9719..240d2f1 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -954,7 +954,8 @@ impl Default for GeneralConfig { me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(), me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(), me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(), - me_d2c_frame_buf_shrink_threshold_bytes: default_me_d2c_frame_buf_shrink_threshold_bytes(), + me_d2c_frame_buf_shrink_threshold_bytes: + default_me_d2c_frame_buf_shrink_threshold_bytes(), direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(), direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(), me_warmup_stagger_enabled: default_true(), diff --git a/src/main.rs b/src/main.rs index c512e6b..406b321 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,12 +7,12 @@ mod crypto; mod error; mod ip_tracker; #[cfg(test)] -#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] -mod ip_tracker_hotpath_adversarial_tests; -#[cfg(test)] #[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"] mod ip_tracker_encapsulation_adversarial_tests; #[cfg(test)] +#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] +mod ip_tracker_hotpath_adversarial_tests; +#[cfg(test)] #[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; mod maestro; diff --git a/src/metrics.rs b/src/metrics.rs index a821d4d..f9475f6 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1233,10 +1233,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets" ); - let _ = writeln!( - out, - "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter" - ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter"); let _ = writeln!( out, "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}", diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 1567caf..0190e8e 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -210,7 +210,9 @@ fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool { return false; } - initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ') + initial_data + .iter() + .all(|b| b.is_ascii_alphabetic() || *b == b' ') } #[cfg(test)] @@ -218,16 +220,19 @@ async fn extend_masking_initial_window(reader: &mut R, initial_data: &mut Vec where R: AsyncRead + Unpin, { - extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT) - .await; + extend_masking_initial_window_with_timeout( + reader, + initial_data, + MASK_CLASSIFIER_PREFETCH_TIMEOUT, + ) + .await; } async fn extend_masking_initial_window_with_timeout( reader: &mut R, initial_data: &mut Vec, prefetch_timeout: Duration, -) -where +) where R: AsyncRead + Unpin, { if !should_prefetch_mask_classifier_window(initial_data) { diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 241a48f..ba9f20a 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -10,10 +10,10 @@ use rand::rngs::StdRng; use rand::{Rng, RngExt, SeedableRng}; use std::net::{IpAddr, SocketAddr}; use std::str; -#[cfg(unix)] -use std::sync::{Mutex, OnceLock}; #[cfg(test)] use std::sync::atomic::{AtomicUsize, Ordering}; +#[cfg(unix)] +use std::sync::{Mutex, OnceLock}; use std::time::{Duration, Instant as StdInstant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; @@ -107,15 +107,7 @@ where fn is_http_probe(data: &[u8]) -> bool { // RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ". const HTTP_METHODS: [&[u8]; 10] = [ - b"GET ", - b"POST", - b"HEAD", - b"PUT ", - b"DELETE", - b"OPTIONS", - b"CONNECT", - b"TRACE", - b"PATCH", + b"GET ", b"POST", b"HEAD", b"PUT ", b"DELETE", b"OPTIONS", b"CONNECT", b"TRACE", b"PATCH", b"PRI ", ]; @@ -328,7 +320,10 @@ fn parse_mask_host_ip_literal(host: &str) -> Option { fn canonical_ip(ip: IpAddr) -> IpAddr { match ip { - IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)), + IpAddr::V6(v6) => v6 + .to_ipv4_mapped() + .map(IpAddr::V4) + .unwrap_or(IpAddr::V6(v6)), IpAddr::V4(v4) => IpAddr::V4(v4), } } @@ -664,12 +659,20 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -698,7 +701,8 @@ pub async fn handle_bad_client( local = %local_addr, "Mask target resolves to local listener; refusing self-referential masking fallback" ); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes) + .await; wait_mask_outcome_budget(outcome_started, config).await; return; } @@ -758,12 +762,20 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } } diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index d833019..3259597 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -23,7 +23,9 @@ use crate::proxy::route_mode::{ ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats}; +use crate::stats::{ + MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats, +}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; @@ -91,7 +93,8 @@ fn relay_idle_candidate_registry() -> &'static Mutex RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) } -fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> { +fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> +{ let registry = relay_idle_candidate_registry(); match registry.lock() { Ok(guard) => guard, @@ -1520,8 +1523,7 @@ where } if !idle_policy.enabled { - consecutive_zero_len_frames = - consecutive_zero_len_frames.saturating_add(1); + consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1); if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES { stats.increment_relay_protocol_desync_close_total(); return Err(ProxyError::Proxy( @@ -1835,8 +1837,14 @@ where MeD2cWriteMode::Coalesced } else { let header = [first]; - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; MeD2cWriteMode::Split } } else if len_words < (1 << 24) { @@ -1858,8 +1866,14 @@ where MeD2cWriteMode::Coalesced } else { let header = [first, lw[0], lw[1], lw[2]]; - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; MeD2cWriteMode::Split } } else { @@ -1901,8 +1915,14 @@ where MeD2cWriteMode::Coalesced } else { let header = len_val.to_le_bytes(); - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; if padding_len > 0 { frame_buf.clear(); if frame_buf.capacity() < padding_len { diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index eebc188..5880558 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -4,58 +4,58 @@ #![cfg_attr(test, allow(warnings))] #![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))] #![cfg_attr( - not(test), - deny( - clippy::unwrap_used, - clippy::expect_used, - clippy::panic, - clippy::todo, - clippy::unimplemented, - clippy::correctness, - clippy::option_if_let_else, - clippy::or_fun_call, - clippy::branches_sharing_code, - clippy::single_option_map, - clippy::useless_let_if_seq, - clippy::redundant_locals, - clippy::cloned_ref_to_slice_refs, - unsafe_code, - clippy::await_holding_lock, - clippy::await_holding_refcell_ref, - clippy::debug_assert_with_mut_call, - clippy::macro_use_imports, - clippy::cast_ptr_alignment, - clippy::cast_lossless, - clippy::ptr_as_ptr, - clippy::large_stack_arrays, - clippy::same_functions_in_if_condition, - trivial_casts, - trivial_numeric_casts, - unused_extern_crates, - unused_import_braces, - rust_2018_idioms - ) + not(test), + deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::todo, + clippy::unimplemented, + clippy::correctness, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::branches_sharing_code, + clippy::single_option_map, + clippy::useless_let_if_seq, + clippy::redundant_locals, + clippy::cloned_ref_to_slice_refs, + unsafe_code, + clippy::await_holding_lock, + clippy::await_holding_refcell_ref, + clippy::debug_assert_with_mut_call, + clippy::macro_use_imports, + clippy::cast_ptr_alignment, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::large_stack_arrays, + clippy::same_functions_in_if_condition, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + rust_2018_idioms + ) )] #![cfg_attr( - not(test), - allow( - clippy::use_self, - clippy::redundant_closure, - clippy::too_many_arguments, - clippy::doc_markdown, - clippy::missing_const_for_fn, - clippy::unnecessary_operation, - clippy::redundant_pub_crate, - clippy::derive_partial_eq_without_eq, - clippy::type_complexity, - clippy::new_ret_no_self, - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - clippy::significant_drop_tightening, - clippy::significant_drop_in_scrutinee, - clippy::float_cmp, - clippy::nursery - ) + not(test), + allow( + clippy::use_self, + clippy::redundant_closure, + clippy::too_many_arguments, + clippy::doc_markdown, + clippy::missing_const_for_fn, + clippy::unnecessary_operation, + clippy::redundant_pub_crate, + clippy::derive_partial_eq_without_eq, + clippy::type_complexity, + clippy::new_ret_no_self, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::significant_drop_tightening, + clippy::significant_drop_in_scrutinee, + clippy::float_cmp, + clippy::nursery + ) )] pub mod adaptive_buffers; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index bf4ad43..6000e18 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -56,8 +56,8 @@ use crate::stats::{Stats, UserStats}; use crate::stream::BufferPool; use std::io; use std::pin::Pin; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; @@ -272,12 +272,10 @@ const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; #[inline] fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { - remaining_before - .saturating_div(2) - .clamp( - QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES, - QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES, - ) + remaining_before.saturating_div(2).clamp( + QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES, + QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES, + ) } #[inline] diff --git a/src/proxy/tests/client_clever_advanced_tests.rs b/src/proxy/tests/client_clever_advanced_tests.rs index da2e703..f462ed8 100644 --- a/src/proxy/tests/client_clever_advanced_tests.rs +++ b/src/proxy/tests/client_clever_advanced_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::config::{UpstreamConfig, UpstreamType, ProxyConfig}; +use crate::config::{ProxyConfig, UpstreamConfig, UpstreamType}; use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; use crate::stats::Stats; use crate::transport::UpstreamManager; @@ -41,7 +41,9 @@ fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() { #[test] fn edge_tls_clienthello_len_in_bounds_exact_boundaries() { assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); - assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1)); + assert!(!tls_clienthello_len_in_bounds( + MIN_TLS_CLIENT_HELLO_SIZE - 1 + )); assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); } @@ -87,7 +89,15 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() { "198.51.100.1:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -99,7 +109,10 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() { false, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]) + .await + .unwrap(); let result = tokio::time::timeout(Duration::from_secs(4), handle) .await @@ -123,7 +136,15 @@ async fn blackhat_proxy_protocol_slowloris_timeout() { "198.51.100.2:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -167,7 +188,15 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { "198.51.100.3:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -179,7 +208,10 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { true, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]) + .await + .unwrap(); let result = tokio::time::timeout(Duration::from_secs(2), handle) .await @@ -202,7 +234,15 @@ async fn edge_client_stream_exactly_4_bytes_eof() { "198.51.100.4:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -214,7 +254,10 @@ async fn edge_client_stream_exactly_4_bytes_eof() { false, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0x00]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00]) + .await + .unwrap(); client_side.shutdown().await.unwrap(); let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; @@ -234,7 +277,15 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { "198.51.100.5:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -246,7 +297,10 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { false, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00, 100]) + .await + .unwrap(); client_side.write_all(&vec![0x41; 99]).await.unwrap(); client_side.shutdown().await.unwrap(); @@ -269,7 +323,15 @@ async fn integration_non_tls_modes_disabled_immediately_masks() { "198.51.100.6:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -372,11 +434,7 @@ async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() { let ip_tracker = ip_tracker.clone(); tasks.spawn(async move { RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await }); diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs index 0302300..e57f817 100644 --- a/src/proxy/tests/client_deep_invariants_tests.rs +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -42,7 +42,15 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() "198.51.100.20:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -65,7 +73,9 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() .unwrap(); client_side.shutdown().await.unwrap(); - let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); assert_eq!(stats.get_connects_bad(), 1); } @@ -73,7 +83,10 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() async fn invariant_acquire_reservation_ip_limit_rollback() { let user = "rollback-test-user"; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 10); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -159,7 +172,15 @@ async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() { "198.51.100.25:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), diff --git a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs index fcf51ab..3036f95 100644 --- a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs +++ b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs @@ -100,14 +100,7 @@ async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAdd #[tokio::test] async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() { - let cases = [ - (2usize, 0u64), - (3, 0), - (4, 0), - (2, 7), - (3, 7), - (8, 1), - ]; + let cases = [(2usize, 0u64), (3, 0), (4, 0), (2, 7), (3, 7), (8, 1)]; for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() { let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i) diff --git a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs index cdf2136..64e7a85 100644 --- a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs @@ -29,7 +29,10 @@ async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() { .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") .await .expect("tail bytes must be writable"); - writer.shutdown().await.expect("writer shutdown must succeed"); + writer + .shutdown() + .await + .expect("writer shutdown must succeed"); }); let mut initial_data = b"C".to_vec(); @@ -60,7 +63,10 @@ async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() { .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") .await .expect("tail bytes must be writable"); - writer.shutdown().await.expect("writer shutdown must succeed"); + writer + .shutdown() + .await + .expect("writer shutdown must succeed"); }); let mut initial_data = b"C".to_vec(); diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs index 2e03ce9..b49db3c 100644 --- a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs @@ -245,7 +245,10 @@ async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clea assert_eq!(head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); client_side.shutdown().await.unwrap(); diff --git a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs index 9ece258..cbb6603 100644 --- a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs @@ -7,7 +7,9 @@ async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec= Duration::from_millis(40) - && replay_elapsed < Duration::from_millis(250), + replay_elapsed >= Duration::from_millis(40) && replay_elapsed < Duration::from_millis(250), "replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay" ); } diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs index 36ffcbb..8f9d832 100644 --- a/src/proxy/tests/client_more_advanced_tests.rs +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -53,11 +53,7 @@ async fn boundary_user_data_quota_exact_match_rejects() { let peer = "198.51.100.10:55000".parse().unwrap(); let result = RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await; @@ -79,11 +75,7 @@ async fn boundary_user_expiration_in_past_rejects() { let peer = "198.51.100.11:55000".parse().unwrap(); let result = RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await; @@ -103,7 +95,15 @@ async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() { "198.51.100.12:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -141,7 +141,15 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { "198.51.100.13:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -153,10 +161,15 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { false, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00, 100]) + .await + .unwrap(); client_side.shutdown().await.unwrap(); - let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); assert_eq!(stats.get_connects_bad(), 1); } @@ -177,7 +190,15 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() { "198.51.100.15:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -192,7 +213,9 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() { client_side.write_all(&vec![0xEF; 64]).await.unwrap(); client_side.shutdown().await.unwrap(); - let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); assert_eq!(stats.get_connects_bad(), 1); } @@ -200,7 +223,10 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() { async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() { let user = "rapid-churn-user"; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 10); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index bae1ce2..1b46c6d 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -7,9 +7,9 @@ use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; -use rand::rngs::StdRng; use rand::Rng; use rand::SeedableRng; +use rand::rngs::StdRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; @@ -34,7 +34,10 @@ fn handshake_timeout_with_mask_grace_includes_mask_margin() { config.timeouts.client_handshake = 2; config.censorship.mask = false; - assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(2)); + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_secs(2) + ); config.censorship.mask = true; assert_eq!( @@ -86,7 +89,10 @@ impl tokio::io::AsyncRead for ErrorReader { _cx: &mut std::task::Context<'_>, _buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "fake error"))) + std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "fake error", + ))) } } @@ -124,7 +130,10 @@ fn handshake_timeout_without_mask_is_exact_base() { config.timeouts.client_handshake = 7; config.censorship.mask = false; - assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(7)); + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_secs(7) + ); } #[test] @@ -133,7 +142,10 @@ fn handshake_timeout_mask_enabled_adds_750ms() { config.timeouts.client_handshake = 3; config.censorship.mask = true; - assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_millis(3750)); + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_millis(3750) + ); } #[tokio::test] @@ -155,10 +167,12 @@ async fn read_with_progress_fragmented_io_works_over_multiple_calls() { let mut b = vec![0u8; chunk_size]; let n = read_with_progress(&mut cursor, &mut b).await.unwrap(); result.extend_from_slice(&b[..n]); - if n == 0 { break; } + if n == 0 { + break; + } } - assert_eq!(result, vec![1,2,3,4,5]); + assert_eq!(result, vec![1, 2, 3, 4, 5]); } #[tokio::test] @@ -174,7 +188,9 @@ async fn read_with_progress_stress_randomized_chunk_sizes() { let mut b = vec![0u8; chunk]; let read = read_with_progress(&mut cursor, &mut b).await.unwrap(); collected.extend_from_slice(&b[..read]); - if read == 0 { break; } + if read == 0 { + break; + } } assert_eq!(collected, input); @@ -215,10 +231,12 @@ fn wrap_tls_application_record_roundtrip_size_check() { let mut consumed = 0; while idx + 5 <= wrapped.len() { assert_eq!(wrapped[idx], 0x17); - let len = u16::from_be_bytes([wrapped[idx+3], wrapped[idx+4]]) as usize; + let len = u16::from_be_bytes([wrapped[idx + 3], wrapped[idx + 4]]) as usize; consumed += len; idx += 5 + len; - if idx >= wrapped.len() { break; } + if idx >= wrapped.len() { + break; + } } assert_eq!(consumed, payload_len); diff --git a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs index 08f52d1..7964cdd 100644 --- a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs +++ b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs @@ -25,13 +25,26 @@ fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize; let body_start = offset + 5; let body_end = body_start + len; - assert!(body_end <= record.len(), "declared TLS record length must be in-bounds"); + assert!( + body_end <= record.len(), + "declared TLS record length must be in-bounds" + ); recovered.extend_from_slice(&record[body_start..body_end]); offset = body_end; frames += 1; } - assert_eq!(offset, record.len(), "record parser must consume exact output size"); - assert_eq!(frames, 2, "oversized payload should split into exactly two records"); - assert_eq!(recovered, payload, "chunked records must preserve full payload"); + assert_eq!( + offset, + record.len(), + "record parser must consume exact output size" + ); + assert_eq!( + frames, 2, + "oversized payload should split into exactly two records" + ); + assert_eq!( + recovered, payload, + "chunked records must preserve full payload" + ); } diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index 16fe8da..a731830 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -773,8 +773,7 @@ fn anchored_open_nix_path_writes_expected_lines() { "target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut first = open_unknown_dc_log_append_anchored(&sanitized) @@ -787,7 +786,10 @@ fn anchored_open_nix_path_writes_expected_lines() { let content = fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); assert_eq!(lines.len(), 2, "expected one line per anchored append call"); assert!( lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"), @@ -811,8 +813,7 @@ fn anchored_open_parallel_appends_preserve_line_integrity() { "target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut workers = Vec::new(); @@ -831,8 +832,15 @@ fn anchored_open_parallel_appends_preserve_line_integrity() { let content = fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); - assert_eq!(lines.len(), 64, "expected one complete line per worker append"); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); + assert_eq!( + lines.len(), + 64, + "expected one complete line per worker append" + ); for line in lines { assert!( line.starts_with("dc_idx="), @@ -867,8 +875,7 @@ fn anchored_open_creates_private_0600_file_permissions() { "target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut file = open_unknown_dc_log_append_anchored(&sanitized) @@ -905,8 +912,7 @@ fn anchored_open_rejects_existing_symlink_target() { "target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let outside = std::env::temp_dir().join(format!( "telemt-unknown-dc-anchored-symlink-outside-{}.log", @@ -943,8 +949,7 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() { "target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let workers = 24usize; @@ -970,7 +975,10 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() { let content = fs::read_to_string(&sanitized.resolved_path) .expect("contention output file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); assert_eq!( lines.len(), workers * rounds, @@ -1014,8 +1022,7 @@ fn append_unknown_dc_line_returns_error_for_read_only_descriptor() { "target/telemt-unknown-dc-append-ro-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable"); let mut readonly = std::fs::OpenOptions::new() diff --git a/src/proxy/tests/handshake_advanced_clever_tests.rs b/src/proxy/tests/handshake_advanced_clever_tests.rs index 9b12f21..76347c4 100644 --- a/src/proxy/tests/handshake_advanced_clever_tests.rs +++ b/src/proxy/tests/handshake_advanced_clever_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::crypto::{sha256, sha256_hmac, AesCtr}; +use crate::crypto::{AesCtr, sha256, sha256_hmac}; use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; @@ -175,7 +175,10 @@ async fn tls_minimum_viable_length_boundary() { None, ) .await; - assert!(matches!(res, HandshakeResult::Success(_)), "Exact minimum length TLS handshake must succeed"); + assert!( + matches!(res, HandshakeResult::Success(_)), + "Exact minimum length TLS handshake must succeed" + ); let short_handshake = vec![0x42u8; min_len - 1]; let res_short = handle_tls_handshake( @@ -189,7 +192,10 @@ async fn tls_minimum_viable_length_boundary() { None, ) .await; - assert!(matches!(res_short, HandshakeResult::BadClient { .. }), "Handshake 1 byte shorter than minimum must fail closed"); + assert!( + matches!(res_short, HandshakeResult::BadClient { .. }), + "Handshake 1 byte shorter than minimum must fail closed" + ); } #[tokio::test] @@ -219,9 +225,16 @@ async fn mtproto_extreme_dc_index_serialization() { match res { HandshakeResult::Success((_, _, success)) => { - assert_eq!(success.dc_idx, extreme_dc, "Extreme DC index {} must serialize/deserialize perfectly", extreme_dc); + assert_eq!( + success.dc_idx, extreme_dc, + "Extreme DC index {} must serialize/deserialize perfectly", + extreme_dc + ); } - _ => panic!("MTProto handshake with extreme DC index {} failed", extreme_dc), + _ => panic!( + "MTProto handshake with extreme DC index {} failed", + extreme_dc + ), } } } @@ -253,7 +266,11 @@ async fn alpn_strict_case_and_padding_rejection() { None, ) .await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "ALPN strict enforcement must reject {:?}", bad_alpn); + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "ALPN strict enforcement must reject {:?}", + bad_alpn + ); } } @@ -265,8 +282,15 @@ fn ipv4_mapped_ipv6_bucketing_anomaly() { let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1); let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2); - assert_eq!(norm_1, norm_2, "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)"); - assert_eq!(norm_1, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), "The bucket must be exactly ::0"); + assert_eq!( + norm_1, norm_2, + "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)" + ); + assert_eq!( + norm_1, + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), + "The bucket must be exactly ::0" + ); } // --- Category 2: Adversarial & Black Hat --- @@ -309,7 +333,10 @@ async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() { None, ) .await; - assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid MTProto ciphertext must not poison the replay cache"); + assert!( + matches!(res_valid, HandshakeResult::Success(_)), + "Invalid MTProto ciphertext must not poison the replay cache" + ); } #[tokio::test] @@ -352,7 +379,10 @@ async fn tls_invalid_session_does_not_poison_replay_cache() { None, ) .await; - assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid TLS payload must not poison the replay cache"); + assert!( + matches!(res_valid, HandshakeResult::Success(_)), + "Invalid TLS payload must not poison the replay cache" + ); } #[tokio::test] @@ -387,7 +417,10 @@ async fn server_hello_delay_timing_neutrality_on_hmac_failure() { let elapsed = start.elapsed(); assert!(matches!(res, HandshakeResult::BadClient { .. })); - assert!(elapsed >= Duration::from_millis(45), "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels"); + assert!( + elapsed >= Duration::from_millis(45), + "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels" + ); } #[tokio::test] @@ -421,7 +454,10 @@ async fn server_hello_delay_inversion_resilience() { let elapsed = start.elapsed(); assert!(matches!(res, HandshakeResult::Success(_))); - assert!(elapsed >= Duration::from_millis(90), "Delay logic must gracefully handle min > max inversions via max.max(min)"); + assert!( + elapsed >= Duration::from_millis(90), + "Delay logic must gracefully handle min > max inversions via max.max(min)" + ); } #[tokio::test] @@ -436,10 +472,16 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() { for i in 0..9 { let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" }; - config.access.users.insert(format!("bad_user_{}", i), bad_secret.to_string()); + config + .access + .users + .insert(format!("bad_user_{}", i), bad_secret.to_string()); } let valid_secret_hex = "99999999999999999999999999999999"; - config.access.users.insert("good_user".to_string(), valid_secret_hex.to_string()); + config + .access + .users + .insert("good_user".to_string(), valid_secret_hex.to_string()); config.general.modes.secure = true; config.general.modes.classic = true; config.general.modes.tls = true; @@ -463,7 +505,10 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() { ) .await; - assert!(matches!(res, HandshakeResult::Success(_)), "Proxy must gracefully skip invalid secrets and authenticate the valid one"); + assert!( + matches!(res, HandshakeResult::Success(_)), + "Proxy must gracefully skip invalid secrets and authenticate the valid one" + ); } #[tokio::test] @@ -494,7 +539,10 @@ async fn tls_emulation_fallback_when_cache_missing() { ) .await; - assert!(matches!(res, HandshakeResult::Success(_)), "TLS emulation must gracefully fall back to standard ServerHello if cache is missing"); + assert!( + matches!(res, HandshakeResult::Success(_)), + "TLS emulation must gracefully fall back to standard ServerHello if cache is missing" + ); } #[tokio::test] @@ -524,7 +572,10 @@ async fn classic_mode_over_tls_transport_protocol_confusion() { ) .await; - assert!(matches!(res, HandshakeResult::Success(_)), "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior"); + assert!( + matches!(res, HandshakeResult::Success(_)), + "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior" + ); } #[test] @@ -543,9 +594,15 @@ fn generate_tg_nonce_never_emits_reserved_bytes() { false, ); - assert!(!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), "Nonce must never start with reserved bytes"); + assert!( + !RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), + "Nonce must never start with reserved bytes" + ); let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; - assert!(!RESERVED_NONCE_BEGINNINGS.contains(&first_four), "Nonce must never match reserved 4-byte beginnings"); + assert!( + !RESERVED_NONCE_BEGINNINGS.contains(&first_four), + "Nonce must never match reserved 4-byte beginnings" + ); } } @@ -568,11 +625,18 @@ async fn dashmap_concurrent_saturation_stress() { } for task in tasks { - task.await.expect("Task panicked during concurrent DashMap stress"); + task.await + .expect("Task panicked during concurrent DashMap stress"); } - assert!(auth_probe_is_throttled_for_testing(ip_a), "IP A must be throttled after concurrent stress"); - assert!(auth_probe_is_throttled_for_testing(ip_b), "IP B must be throttled after concurrent stress"); + assert!( + auth_probe_is_throttled_for_testing(ip_a), + "IP A must be throttled after concurrent stress" + ); + assert!( + auth_probe_is_throttled_for_testing(ip_b), + "IP B must be throttled after concurrent stress" + ); } #[test] @@ -586,7 +650,12 @@ fn prototag_invalid_bytes_fail_closed() { ]; for tag in invalid_tags { - assert_eq!(ProtoTag::from_bytes(tag), None, "Invalid ProtoTag bytes {:?} must fail closed", tag); + assert_eq!( + ProtoTag::from_bytes(tag), + None, + "Invalid ProtoTag bytes {:?} must fail closed", + tag + ); } } @@ -603,7 +672,10 @@ fn auth_probe_eviction_hash_collision_stress() { auth_probe_record_failure_with_state(state, ip, now); } - assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, "Eviction logic must successfully bound the map size under heavy insertion stress"); + assert!( + state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "Eviction logic must successfully bound the map size under heavy insertion stress" + ); } #[test] diff --git a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs index 6c48cc1..77cea19 100644 --- a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs @@ -88,6 +88,9 @@ fn light_fuzz_offset_always_stays_inside_state_len() { let now = base + Duration::from_nanos(seed & 0x0fff); let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); - assert!(start < state_len, "scan offset must stay inside state length"); + assert!( + start < state_len, + "scan offset must stay inside state length" + ); } -} \ No newline at end of file +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs index ece6ff5..c91a215 100644 --- a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -96,4 +96,4 @@ fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() { "scan offset must stay inside state length" ); } -} \ No newline at end of file +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs index 260a1b9..bf97990 100644 --- a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs @@ -113,4 +113,4 @@ fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() { "scan offset must always remain inside state length" ); } -} \ No newline at end of file +} diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs index 77df442..9782469 100644 --- a/src/proxy/tests/handshake_more_clever_tests.rs +++ b/src/proxy/tests/handshake_more_clever_tests.rs @@ -1,8 +1,8 @@ use super::*; -use crate::crypto::{sha256, sha256_hmac, AesCtr}; +use crate::crypto::{AesCtr, sha256, sha256_hmac}; use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; -use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; @@ -223,7 +223,10 @@ fn auth_probe_backoff_extreme_fail_streak_clamps_safely() { assert_eq!(updated.fail_streak, u32::MAX); let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS); - assert_eq!(updated.blocked_until, expected_blocked_until, "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS"); + assert_eq!( + updated.blocked_until, expected_blocked_until, + "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS" + ); } #[test] @@ -250,12 +253,19 @@ fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() { total_set_bits += byte.count_ones() as usize; } - assert!(nonces.insert(nonce), "generate_tg_nonce emitted a duplicate nonce! RNG is stuck."); + assert!( + nonces.insert(nonce), + "generate_tg_nonce emitted a duplicate nonce! RNG is stuck." + ); } let total_bits = iterations * HANDSHAKE_LEN * 8; let ratio = (total_set_bits as f64) / (total_bits as f64); - assert!(ratio > 0.48 && ratio < 0.52, "Nonce entropy is degraded. Set bit ratio: {}", ratio); + assert!( + ratio > 0.48 && ratio < 0.52, + "Nonce entropy is degraded. Set bit ratio: {}", + ratio + ); } #[tokio::test] @@ -267,10 +277,19 @@ async fn mtproto_multi_user_decryption_isolation() { config.general.modes.secure = true; config.access.ignore_time_skew = true; - config.access.users.insert("user_a".to_string(), "11111111111111111111111111111111".to_string()); - config.access.users.insert("user_b".to_string(), "22222222222222222222222222222222".to_string()); + config.access.users.insert( + "user_a".to_string(), + "11111111111111111111111111111111".to_string(), + ); + config.access.users.insert( + "user_b".to_string(), + "22222222222222222222222222222222".to_string(), + ); let good_secret_hex = "33333333333333333333333333333333"; - config.access.users.insert("user_c".to_string(), good_secret_hex.to_string()); + config + .access + .users + .insert("user_c".to_string(), good_secret_hex.to_string()); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap(); @@ -291,9 +310,14 @@ async fn mtproto_multi_user_decryption_isolation() { match res { HandshakeResult::Success((_, _, success)) => { - assert_eq!(success.user, "user_c", "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user"); + assert_eq!( + success.user, "user_c", + "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user" + ); } - _ => panic!("Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."), + _ => panic!( + "Multi-user MTProto handshake failed. Decryption buffer might be mutating in place." + ), } } @@ -325,7 +349,9 @@ async fn invalid_secret_warning_lock_contention_and_bound() { } let warned = INVALID_SECRET_WARNED.get().unwrap(); - let guard = warned.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let guard = warned + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); assert_eq!( guard.len(), @@ -342,7 +368,11 @@ async fn mtproto_strict_concurrent_replay_race_condition() { let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A"; let config = Arc::new(test_config_with_secret_hex(secret_hex)); let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); - let valid_handshake = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1)); + let valid_handshake = Arc::new(make_valid_mtproto_handshake( + secret_hex, + ProtoTag::Secure, + 1, + )); let tasks = 100; let barrier = Arc::new(Barrier::new(tasks)); @@ -355,7 +385,10 @@ async fn mtproto_strict_concurrent_replay_race_condition() { let hs = valid_handshake.clone(); handles.push(tokio::spawn(async move { - let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), 10000 + i as u16); + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), + 10000 + i as u16, + ); b.wait().await; handle_mtproto_handshake( &hs, @@ -382,8 +415,15 @@ async fn mtproto_strict_concurrent_replay_race_condition() { } } - assert_eq!(successes, 1, "Replay cache race condition allowed multiple identical MTProto handshakes to succeed"); - assert_eq!(failures, tasks - 1, "Replay cache failed to forcefully reject concurrent duplicates"); + assert_eq!( + successes, 1, + "Replay cache race condition allowed multiple identical MTProto handshakes to succeed" + ); + assert_eq!( + failures, + tasks - 1, + "Replay cache failed to forcefully reject concurrent duplicates" + ); } #[tokio::test] @@ -398,7 +438,8 @@ async fn tls_alpn_zero_length_protocol_handled_safely() { let rng = SecureRandom::new(); let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap(); - let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]); let res = handle_tls_handshake( &handshake, @@ -412,7 +453,10 @@ async fn tls_alpn_zero_length_protocol_handled_safely() { ) .await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "0-length ALPN must be safely rejected without panicking"); + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "0-length ALPN must be safely rejected without panicking" + ); } #[tokio::test] @@ -427,7 +471,8 @@ async fn tls_sni_massive_hostname_does_not_panic() { let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap(); let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap(); - let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]); let res = handle_tls_handshake( &handshake, @@ -441,7 +486,13 @@ async fn tls_sni_massive_hostname_does_not_panic() { ) .await; - assert!(matches!(res, HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }), "Massive SNI hostname must be processed or ignored without stack overflow or panic"); + assert!( + matches!( + res, + HandshakeResult::Success(_) | HandshakeResult::BadClient { .. } + ), + "Massive SNI hostname must be processed or ignored without stack overflow or panic" + ); } #[tokio::test] @@ -455,7 +506,8 @@ async fn tls_progressive_truncation_fuzzing_no_panics() { let rng = SecureRandom::new(); let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap(); - let valid_handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]); + let valid_handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]); let full_len = valid_handshake.len(); // Truncated corpus only: full_len is a valid baseline and should not be @@ -473,7 +525,11 @@ async fn tls_progressive_truncation_fuzzing_no_panics() { None, ) .await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "Truncated TLS handshake at len {} must fail safely without panicking", i); + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "Truncated TLS handshake at len {} must fail safely without panicking", + i + ); } } @@ -504,7 +560,10 @@ async fn mtproto_pure_entropy_fuzzing_no_panics() { ) .await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "Pure entropy MTProto payload must fail closed and never panic"); + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "Pure entropy MTProto payload must fail closed and never panic" + ); } } @@ -517,10 +576,16 @@ fn decode_user_secret_odd_length_hex_rejection() { let mut config = ProxyConfig::default(); config.access.users.clear(); - config.access.users.insert("odd_user".to_string(), "1234567890123456789012345678901".to_string()); + config.access.users.insert( + "odd_user".to_string(), + "1234567890123456789012345678901".to_string(), + ); let decoded = decode_user_secrets(&config, None); - assert!(decoded.is_empty(), "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping"); + assert!( + decoded.is_empty(), + "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping" + ); } #[test] @@ -552,7 +617,10 @@ fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() { } let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now); - assert!(is_throttled, "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period"); + assert!( + is_throttled, + "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period" + ); } #[test] @@ -586,7 +654,11 @@ fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() { config.general.modes.tls = false; assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false)); - assert!(!mode_enabled_for_proto(&config, ProtoTag::Intermediate, false)); + assert!(!mode_enabled_for_proto( + &config, + ProtoTag::Intermediate, + false + )); } #[test] diff --git a/src/proxy/tests/handshake_real_bug_stress_tests.rs b/src/proxy/tests/handshake_real_bug_stress_tests.rs index d7234ff..1e27ed5 100644 --- a/src/proxy/tests/handshake_real_bug_stress_tests.rs +++ b/src/proxy/tests/handshake_real_bug_stress_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom}; +use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac}; use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; @@ -80,8 +80,7 @@ fn make_valid_tls_client_hello_with_alpn( digest[28 + i] ^= ts[i]; } - record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] - .copy_from_slice(&digest); + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); record } @@ -331,7 +330,11 @@ async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() { let final_state = state.get(&peer_ip).expect("state must exist"); assert!( - final_state.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + final_state.fail_streak + >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS ); - assert!(auth_probe_should_apply_preauth_throttle(peer_ip, Instant::now())); + assert!(auth_probe_should_apply_preauth_throttle( + peer_ip, + Instant::now() + )); } diff --git a/src/proxy/tests/handshake_timing_manual_bench_tests.rs b/src/proxy/tests/handshake_timing_manual_bench_tests.rs index 95e9f49..13d112c 100644 --- a/src/proxy/tests/handshake_timing_manual_bench_tests.rs +++ b/src/proxy/tests/handshake_timing_manual_bench_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom}; +use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac}; use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; use std::net::SocketAddr; use std::time::{Duration, Instant}; @@ -169,10 +169,10 @@ async fn mtproto_user_scan_timing_manual_benchmark() { ); } - config.access.users.insert( - preferred_user.to_string(), - target_secret_hex.to_string(), - ); + config + .access + .users + .insert(preferred_user.to_string(), target_secret_hex.to_string()); let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60)); let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60)); diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs index 84c904f..a977409 100644 --- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -544,7 +544,6 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u if hardened_acc + 0.05 <= baseline_acc { meaningful_improvement_seen = true; } - } assert!( diff --git a/src/proxy/tests/masking_aggressive_mode_security_tests.rs b/src/proxy/tests/masking_aggressive_mode_security_tests.rs index a77fc14..7356dc0 100644 --- a/src/proxy/tests/masking_aggressive_mode_security_tests.rs +++ b/src/proxy/tests/masking_aggressive_mode_security_tests.rs @@ -85,7 +85,10 @@ async fn aggressive_mode_shapes_backend_silent_non_eof_path() { let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await; let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await; - assert!(legacy < floor, "legacy mode should keep timeout path unshaped"); + assert!( + legacy < floor, + "legacy mode should keep timeout path unshaped" + ); assert!( aggressive >= floor, "aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})" diff --git a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs index 614af9b..718189c 100644 --- a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs +++ b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs @@ -52,7 +52,10 @@ async fn run_connect_failure_case( .await .unwrap() .unwrap(); - assert_eq!(n, 0, "connect-failure path must close client-visible writer"); + assert_eq!( + n, 0, + "connect-failure path must close client-visible writer" + ); started.elapsed() } @@ -67,13 +70,9 @@ async fn connect_failure_refusal_close_behavior_matrix() { let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16) .parse() .unwrap(); - let elapsed = run_connect_failure_case( - "127.0.0.1", - unused_port, - timing_normalization_enabled, - peer, - ) - .await; + let elapsed = + run_connect_failure_case("127.0.0.1", unused_port, timing_normalization_enabled, peer) + .await; if timing_normalization_enabled { assert!( diff --git a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs index b52af35..f2c39a2 100644 --- a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs +++ b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs @@ -79,7 +79,10 @@ async fn io_error_terminates_cleanly() { } } - tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX)) - .await - .expect("consume_client_data did not return on I/O error"); + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(ErrReader, usize::MAX), + ) + .await + .expect("consume_client_data did not return on I/O error"); } diff --git a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs index 040f567..650731c 100644 --- a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs +++ b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs @@ -32,8 +32,16 @@ async fn run_self_target_refusal( let (mut client, server) = duplex(1024); let started = Instant::now(); let task = tokio::spawn(async move { - handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten) - .await; + handle_bad_client( + server, + tokio::io::sink(), + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; }); client @@ -214,4 +222,4 @@ async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() { }) .await .expect("high-fanout refusal workload must complete without deadlock"); -} \ No newline at end of file +} diff --git a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs index 47b6dc6..c8f3ec0 100644 --- a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs +++ b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs @@ -2,7 +2,13 @@ use super::*; #[test] fn exact_four_byte_http_tokens_are_classified() { - for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] { + for token in [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"PRI ".as_ref(), + ] { assert!( is_http_probe(token), "exact 4-byte token must be classified as HTTP probe: {:?}", @@ -76,4 +82,4 @@ fn light_fuzz_four_byte_ascii_noise_not_misclassified() { token ); } -} \ No newline at end of file +} diff --git a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs index 8d99b8f..ed6d1ab 100644 --- a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs +++ b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs @@ -38,4 +38,4 @@ async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() { 1, "parallel cold misses must coalesce into a single interface enumeration" ); -} \ No newline at end of file +} diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs index 6be99d0..17debb0 100644 --- a/src/proxy/tests/masking_interface_cache_security_tests.rs +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -37,7 +37,10 @@ async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; - assert!(!is_local, "different port must not be treated as local listener"); + assert!( + !is_local, + "different port must not be treated as local listener" + ); assert_eq!( local_interface_enumerations_for_tests(), 0, diff --git a/src/proxy/tests/masking_production_cap_regression_security_tests.rs b/src/proxy/tests/masking_production_cap_regression_security_tests.rs index f2368a1..9ff51ba 100644 --- a/src/proxy/tests/masking_production_cap_regression_security_tests.rs +++ b/src/proxy/tests/masking_production_cap_regression_security_tests.rs @@ -63,17 +63,11 @@ impl AsyncWrite for CountingWriter { Poll::Ready(Ok(buf.len())) } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs index 18cb0d7..7f6cb29 100644 --- a/src/proxy/tests/masking_self_target_loop_security_tests.rs +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -1,6 +1,6 @@ use super::*; -use std::net::TcpListener as StdTcpListener; use std::net::SocketAddr; +use std::net::TcpListener as StdTcpListener; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::{Duration, Instant, timeout}; @@ -15,74 +15,38 @@ fn closed_local_port() -> u16 { #[tokio::test] async fn self_target_detection_matches_literal_ipv4_listener() { let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async( - "198.51.100.40", - 443, - local, - None, - ) - .await); + assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await); } #[tokio::test] async fn self_target_detection_matches_bracketed_ipv6_listener() { let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); - assert!(is_mask_target_local_listener_async( - "[2001:db8::44]", - 8443, - local, - None, - ) - .await); + assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await); } #[tokio::test] async fn self_target_detection_keeps_same_ip_different_port_forwardable() { let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener_async( - "203.0.113.44", - 8443, - local, - None, - ) - .await); + assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await); } #[tokio::test] async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async( - "::ffff:127.0.0.1", - 443, - local, - None, - ) - .await); + assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await); } #[tokio::test] async fn self_target_detection_unspecified_bind_blocks_loopback_target() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async( - "127.0.0.1", - 443, - local, - None, - ) - .await); + assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await); } #[tokio::test] async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener_async( - "mask.example", - 443, - local, - Some(remote), - ) - .await); + assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await); } #[tokio::test] @@ -306,7 +270,10 @@ async fn offline_mask_target_refusal_respects_timing_normalization_budget() { }); client.shutdown().await.unwrap(); - timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), task) + .await + .unwrap() + .unwrap(); let elapsed = started.elapsed(); assert!( @@ -350,7 +317,10 @@ async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_time .await .expect("connection should still be open before consume timeout expires"); - timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), task) + .await + .unwrap() + .unwrap(); let elapsed = started.elapsed(); assert!( diff --git a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs index 1c342ea..fda6de7 100644 --- a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs +++ b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs @@ -40,7 +40,10 @@ async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_bud tokio::time::sleep(Duration::from_millis(80)).await; drop(held_refresh_guard); - client.shutdown().await.expect("client shutdown must succeed"); + client + .shutdown() + .await + .expect("client shutdown must succeed"); timeout(Duration::from_secs(2), task) .await @@ -52,4 +55,4 @@ async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_bud elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350), "timing normalization floor must start after pre-outcome self-target checks" ); -} \ No newline at end of file +} diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs index 6ea182b..fd3243d 100644 --- a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -2,8 +2,8 @@ use super::*; use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; -use std::sync::atomic::AtomicU64; use std::sync::Arc; +use std::sync::atomic::AtomicU64; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; diff --git a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs index 112d926..b43825c 100644 --- a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs @@ -29,7 +29,10 @@ fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_account let before = relay_pressure_event_seq(); note_relay_pressure_event(); let after = relay_pressure_event_seq(); - assert!(after > before, "pressure accounting must still advance after poison"); + assert!( + after > before, + "pressure accounting must still advance after poison" + ); clear_relay_idle_pressure_state_for_testing(); } diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs index 34fc454..6b1d511 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs @@ -217,7 +217,9 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { } } - writer_task.await.expect("writer jitter task must not panic"); + writer_task + .await + .expect("writer jitter task must not panic"); assert!(closed, "alternating attack must close before EOF"); }); } @@ -247,7 +249,10 @@ async fn integration_mixed_population_attackers_close_benign_survive() { plaintext.push(0x01); plaintext.extend_from_slice(&[n, n, n, n]); } - writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); drop(writer); let mut closed = false; @@ -279,7 +284,10 @@ async fn integration_mixed_population_attackers_close_benign_survive() { } plaintext.push(0x01); plaintext.extend_from_slice(&payload); - writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); let got = read_once( &mut crypto_reader, @@ -329,7 +337,10 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() { } } - writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); drop(writer); for _ in 0..320 { diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs index 853b381..cbbc971 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs @@ -51,7 +51,9 @@ fn make_enabled_idle_policy() -> RelayClientIdlePolicy { fn append_tiny_frame(plaintext: &mut Vec, proto: ProtoTag) { match proto { ProtoTag::Abridged => plaintext.push(0x00), - ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()), + ProtoTag::Intermediate | ProtoTag::Secure => { + plaintext.extend_from_slice(&0u32.to_le_bytes()) + } } } @@ -206,7 +208,11 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() { let mut plaintext = Vec::with_capacity(8 * 200); for n in 0..180u8 { append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); - append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]); + append_real_frame( + &mut plaintext, + ProtoTag::Intermediate, + [n, n ^ 1, n ^ 2, n ^ 3], + ); } let encrypted = encrypt_for_reader(&plaintext); @@ -240,7 +246,9 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() { } } - writer_task.await.expect("intermediate writer task must not panic"); + writer_task + .await + .expect("intermediate writer task must not panic"); assert!(closed, "intermediate alternating attack must fail closed"); } @@ -290,7 +298,9 @@ async fn secure_chunked_alternating_attack_closes_before_eof() { } } - writer_task.await.expect("secure writer task must not panic"); + writer_task + .await + .expect("secure writer task must not panic"); assert!(closed, "secure alternating attack must fail closed"); } diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs index dee5dd9..fad87d0 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs @@ -2,8 +2,8 @@ use super::*; use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; -use std::sync::atomic::AtomicU64; use std::sync::Arc; +use std::sync::atomic::AtomicU64; use std::time::Instant; use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; @@ -156,7 +156,10 @@ fn alternating_one_to_one_closes_with_bounded_real_frame_count() { } let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); assert!(closed_at.is_some()); - assert!(reals <= 80, "expected bounded real frames before close, got {reals}"); + assert!( + reals <= 80, + "expected bounded real frames before close, got {reals}" + ); } #[test] @@ -183,7 +186,10 @@ fn alternating_one_to_seven_eventually_closes() { } } let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); - assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close"); + assert!( + closed_at.is_some(), + "1:7 tiny-to-real must eventually close" + ); } #[test] diff --git a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs index 765c253..dbf6c4c 100644 --- a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs +++ b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs @@ -2,10 +2,10 @@ use super::*; use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; -use std::sync::atomic::AtomicU64; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use std::sync::atomic::AtomicU64; use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; fn make_crypto_reader(reader: T) -> CryptoReader where diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs index e80690b..8ce1c26 100644 --- a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -5,10 +5,13 @@ use crate::stream::BufferPool; use rand::rngs::StdRng; use rand::{RngExt, SeedableRng}; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::time::{Duration, timeout}; -async fn read_available(reader: &mut R, budget: Duration) -> usize { +async fn read_available( + reader: &mut R, + budget: Duration, +) -> usize { let start = tokio::time::Instant::now(); let mut total = 0usize; let mut buf = [0u8; 128]; @@ -57,16 +60,25 @@ async fn positive_quota_path_forwards_both_directions_within_limit() { Arc::new(BufferPool::new()), )); - client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap(); + client_peer + .write_all(&[0xAA, 0xBB, 0xCC, 0xDD]) + .await + .unwrap(); server_peer.read_exact(&mut [0u8; 4]).await.unwrap(); - server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap(); + server_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .unwrap(); client_peer.read_exact(&mut [0u8; 4]).await.unwrap(); drop(client_peer); drop(server_peer); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); assert!(relay_result.is_ok()); assert!(stats.get_user_quota_used(user) <= 16); } @@ -98,11 +110,23 @@ async fn negative_preloaded_quota_forbids_any_forwarding() { client_peer.write_all(&[0xAA]).await.unwrap(); server_peer.write_all(&[0xBB]).await.unwrap(); - assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0); - assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0); + assert_eq!( + read_available(&mut server_peer, Duration::from_millis(120)).await, + 0 + ); + assert_eq!( + read_available(&mut client_peer, Duration::from_millis(120)).await, + 0 + ); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!(stats.get_user_quota_used(user) <= 8); } @@ -135,13 +159,25 @@ async fn edge_quota_one_ensures_at_most_one_byte_across_directions() { ); let mut buf = [0u8; 1]; - let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await.unwrap().unwrap_or(0); - let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await.unwrap().unwrap_or(0); + let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)) + .await + .unwrap() + .unwrap_or(0); + let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)) + .await + .unwrap() + .unwrap_or(0); assert!(delivered_s2c + delivered_c2s <= 1); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); } #[tokio::test] @@ -191,8 +227,14 @@ async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() { tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; } - let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap(); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(3), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!(total_forwarded <= quota as usize); assert!(stats.get_user_quota_used(user) <= quota); } @@ -239,13 +281,17 @@ async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { if rng.random::() { let _ = client_peer.write_all(&[rng.random::()]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(4), server_peer.read(&mut one)).await + { total_forwarded += n; } } else { let _ = server_peer.write_all(&[rng.random::()]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(4), client_peer.read(&mut one)).await + { total_forwarded += n; } } @@ -254,8 +300,14 @@ async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { drop(client_peer); drop(server_peer); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })) + ); assert!(total_forwarded <= quota as usize); assert!(stats.get_user_quota_used(&user) <= quota); } @@ -305,13 +357,17 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() { if (step as usize + worker as usize) % 2 == 0 { let _ = client_peer.write_all(&[(step ^ 0x5A)]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(6), server_peer.read(&mut one)).await + { total += n; } } else { let _ = server_peer.write_all(&[(step ^ 0xA5)]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(6), client_peer.read(&mut one)).await + { total += n; } } @@ -321,8 +377,14 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() { drop(client_peer); drop(server_peer); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })) + ); total })); } diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 297ff28..ff15d4f 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -381,7 +381,9 @@ impl Stats { return; } Self::touch_user_stats(user_stats); - user_stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); + user_stats + .octets_from_client + .fetch_add(bytes, Ordering::Relaxed); } #[inline] @@ -390,7 +392,9 @@ impl Stats { return; } Self::touch_user_stats(user_stats); - user_stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); + user_stats + .octets_to_client + .fetch_add(bytes, Ordering::Relaxed); } #[inline] @@ -812,7 +816,8 @@ impl Stats { } pub fn increment_me_d2c_data_frames_total(&self) { if self.telemetry_me_allows_normal() { - self.me_d2c_data_frames_total.fetch_add(1, Ordering::Relaxed); + self.me_d2c_data_frames_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_d2c_ack_frames_total(&self) { @@ -1708,7 +1713,8 @@ impl Stats { self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed) } pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_4k_16k.load(Ordering::Relaxed) + self.me_d2c_batch_bytes_bucket_4k_16k + .load(Ordering::Relaxed) } pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 { self.me_d2c_batch_bytes_bucket_16k_64k @@ -2371,8 +2377,8 @@ impl ReplayStats { mod tests { use super::*; use crate::config::MeTelemetryLevel; - use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicU64, Ordering}; #[test] fn test_stats_shared_counters() { diff --git a/src/stream/frame_stream_padding_security_tests.rs b/src/stream/frame_stream_padding_security_tests.rs index 83b30f9..1ec787e 100644 --- a/src/stream/frame_stream_padding_security_tests.rs +++ b/src/stream/frame_stream_padding_security_tests.rs @@ -14,7 +14,10 @@ fn padding_rounding_equivalent_for_extensive_safe_domain() { let old = old_padding_round_up_to_4(len).expect("old expression must be safe"); let new = new_padding_round_up_to_4(len).expect("new expression must be safe"); assert_eq!(old, new, "mismatch for len={len}"); - assert!(new >= len, "rounded length must not shrink: len={len}, out={new}"); + assert!( + new >= len, + "rounded length must not shrink: len={len}, out={new}" + ); assert_eq!(new % 4, 0, "rounded length must stay 4-byte aligned"); } } diff --git a/src/tests/ip_tracker_encapsulation_adversarial_tests.rs b/src/tests/ip_tracker_encapsulation_adversarial_tests.rs index cf42e75..3fc9727 100644 --- a/src/tests/ip_tracker_encapsulation_adversarial_tests.rs +++ b/src/tests/ip_tracker_encapsulation_adversarial_tests.rs @@ -44,7 +44,10 @@ async fn encapsulation_repeated_queue_poison_recovery_preserves_forward_progress let ip_primary = ip_from_idx(10_001); let ip_alt = ip_from_idx(10_002); - tracker.check_and_add("encap-poison", ip_primary).await.unwrap(); + tracker + .check_and_add("encap-poison", ip_primary) + .await + .unwrap(); for _ in 0..128 { let queue = tracker.cleanup_queue_mutex_for_tests(); diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 4408b5a..2f39707 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -812,8 +812,8 @@ mod tests { #[test] fn test_encode_tls13_certificate_message_single_cert() { let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01]; - let message = encode_tls13_certificate_message(std::slice::from_ref(&cert)) - .expect("message"); + let message = + encode_tls13_certificate_message(std::slice::from_ref(&cert)).expect("message"); assert_eq!(message[0], 0x0b); assert_eq!(read_u24(&message[1..4]), message.len() - 4); diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 918ccd4..1ef59e1 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -293,9 +293,7 @@ impl MePool { WriterContour::Draining => "draining", }; - if !draining - && let Some(dc_idx) = dc - { + if !draining && let Some(dc_idx) = dc { *live_writers_by_dc_endpoint .entry((dc_idx, endpoint)) .or_insert(0) += 1; diff --git a/src/transport/pool.rs b/src/transport/pool.rs index 60f8a01..bb0baac 100644 --- a/src/transport/pool.rs +++ b/src/transport/pool.rs @@ -201,7 +201,10 @@ impl ConnectionPool { pub async fn close_all(&self) { let pools_snapshot: Vec<(SocketAddr, Arc>)> = { let pools = self.pools.read(); - pools.iter().map(|(addr, pool)| (*addr, Arc::clone(pool))).collect() + pools + .iter() + .map(|(addr, pool)| (*addr, Arc::clone(pool))) + .collect() }; for (addr, pool) in pools_snapshot { From 8e1860f91266f990f04d592c30350571f37a3a67 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:34:59 +0300 Subject: [PATCH 20/29] Update test.yml --- .github/workflows/test.yml | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d8f7a64..0c9544f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,8 +10,33 @@ env: CARGO_TERM_COLOR: always jobs: +# ========================== +# Formatting +# ========================== + fmt: + name: Fmt + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Rust toolchain (rustfmt only) + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt -- --check + +# ========================== +# Tests + Clippy + Udeps +# ========================== test: - name: Test / Lint / Analysis + name: Test / Clippy / Udeps runs-on: ubuntu-latest permissions: @@ -23,10 +48,10 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Install latest stable Rust toolchain + - name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable with: - components: rustfmt, clippy + components: clippy - name: Cache cargo registry & build artifacts uses: actions/cache@v4 @@ -46,9 +71,6 @@ jobs: - name: Run clippy run: cargo clippy -- --cap-lints warn - - name: Check formatting - run: cargo fmt -- --check - - name: Install cargo-udeps run: cargo install cargo-udeps || true From c868eaae747d93515fdb77243c70c53004ccb4d5 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:36:25 +0300 Subject: [PATCH 21/29] Update test.yml --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0c9544f..51565d5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: Test +name: Analyze on: push: From 62a258f8e3baaad7721da78fdf95fde03500b095 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:49:17 +0300 Subject: [PATCH 22/29] Update test.yml --- .github/workflows/test.yml | 95 +++++++++++++++++++++++++++++--------- 1 file changed, 72 insertions(+), 23 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 51565d5..46d3b0a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: Analyze +name: Check on: push: @@ -9,6 +9,10 @@ on: env: CARGO_TERM_COLOR: always +concurrency: + group: test-${{ github.ref }} + cancel-in-progress: true + jobs: # ========================== # Formatting @@ -21,22 +25,19 @@ jobs: contents: read steps: - - name: Checkout repository - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Install Rust toolchain (rustfmt only) - uses: dtolnay/rust-toolchain@stable + - uses: dtolnay/rust-toolchain@stable with: components: rustfmt - - name: Check formatting - run: cargo fmt -- --check + - run: cargo fmt -- --check # ========================== -# Tests + Clippy + Udeps +# Tests # ========================== test: - name: Test / Clippy / Udeps + name: Test runs-on: ubuntu-latest permissions: @@ -45,15 +46,11 @@ jobs: checks: write steps: - - name: Checkout repository - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - components: clippy + - uses: dtolnay/rust-toolchain@stable - - name: Cache cargo registry & build artifacts + - name: Cache cargo uses: actions/cache@v4 with: path: | @@ -64,15 +61,67 @@ jobs: restore-keys: | ${{ runner.os }}-cargo- - - name: Run tests - run: cargo test --verbose + - run: cargo test --verbose - # clippy не валит билд (осознанно) - - name: Run clippy - run: cargo clippy -- --cap-lints warn +# ========================== +# Clippy +# ========================== + clippy: + name: Clippy + runs-on: ubuntu-latest + + permissions: + contents: read + checks: write + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - run: cargo clippy -- --cap-lints warn + +# ========================== +# Udeps +# ========================== + udeps: + name: Udeps + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- - name: Install cargo-udeps run: cargo install cargo-udeps || true - - name: Check for unused dependencies - run: cargo udeps || true \ No newline at end of file + # тоже не валит билд + - run: cargo udeps || true \ No newline at end of file From bb71de0230fbe2db1be34c4f0ee79a356b0a5bd6 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:54:58 +0300 Subject: [PATCH 23/29] Missing proxy_protocol_trusted_cidrs as trust- Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/defaults.rs | 4 ++++ src/config/load.rs | 43 ++++++++++++++++++++++++++++++++++++++++++ src/config/types.rs | 9 +++++---- src/proxy/client.rs | 2 +- 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 09d146a..f02403e 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -185,6 +185,10 @@ pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { 500 } +pub(crate) fn default_proxy_protocol_trusted_cidrs() -> Vec { + vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()] +} + pub(crate) fn default_server_max_connections() -> u32 { 10_000 } diff --git a/src/config/load.rs b/src/config/load.rs index 8f12757..2a501ea 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1263,6 +1263,10 @@ mod tests { assert_eq!(cfg.general.update_every, default_update_every()); assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4()); assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt()); + assert_eq!( + cfg.server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); assert_eq!(cfg.server.api.listen, default_api_listen()); assert_eq!(cfg.server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1395,6 +1399,10 @@ mod tests { let server = ServerConfig::default(); assert_eq!(server.listen_addr_ipv6, Some(default_listen_addr_ipv6())); + assert_eq!( + server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); assert_eq!(server.api.listen, default_api_listen()); assert_eq!(server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1430,6 +1438,41 @@ mod tests { assert_eq!(access.users, default_access_users()); } + #[test] + fn proxy_protocol_trusted_cidrs_missing_uses_trust_all_but_explicit_empty_stays_empty() { + let cfg_missing: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + "#, + ) + .unwrap(); + assert_eq!( + cfg_missing.server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + + let cfg_explicit_empty: ProxyConfig = toml::from_str( + r#" + [server] + proxy_protocol_trusted_cidrs = [] + + [general] + [network] + [access] + "#, + ) + .unwrap(); + assert!( + cfg_explicit_empty + .server + .proxy_protocol_trusted_cidrs + .is_empty() + ); + } + #[test] fn dc_overrides_allow_string_and_array() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index 240d2f1..2d204e2 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1240,9 +1240,10 @@ pub struct ServerConfig { /// Trusted source CIDRs allowed to send incoming PROXY protocol headers. /// - /// When non-empty, connections from addresses outside this allowlist are - /// rejected before `src_addr` is applied. - #[serde(default)] + /// If this field is omitted in config, it defaults to trust-all CIDRs + /// (`0.0.0.0/0` and `::/0`). If it is explicitly set to an empty list, + /// all PROXY protocol headers are rejected. + #[serde(default = "default_proxy_protocol_trusted_cidrs")] pub proxy_protocol_trusted_cidrs: Vec, /// Port for the Prometheus-compatible metrics endpoint. @@ -1287,7 +1288,7 @@ impl Default for ServerConfig { listen_tcp: None, proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), - proxy_protocol_trusted_cidrs: Vec::new(), + proxy_protocol_trusted_cidrs: default_proxy_protocol_trusted_cidrs(), metrics_port: None, metrics_listen: None, metrics_whitelist: default_metrics_whitelist(), diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 0190e8e..0becac7 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -323,7 +323,7 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false)); if !warned.swap(true, Ordering::Relaxed) { warn!( - "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default" + "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers" ); } return false; From 8db566dbe92da3b963e00ca65681352cc0eeed78 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:58:39 +0300 Subject: [PATCH 24/29] TLS Validator Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/load.rs | 36 +++++++++ src/config/types.rs | 13 ++++ src/error.rs | 3 + src/proxy/client.rs | 17 ++++- src/proxy/handshake.rs | 48 +++++++++--- src/proxy/tests/handshake_security_tests.rs | 83 +++++++++++++++++++++ 6 files changed, 187 insertions(+), 13 deletions(-) diff --git a/src/config/load.rs b/src/config/load.rs index 2a501ea..8355fb4 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1267,6 +1267,10 @@ mod tests { cfg.server.proxy_protocol_trusted_cidrs, default_proxy_protocol_trusted_cidrs() ); + assert_eq!( + cfg.censorship.unknown_sni_action, + UnknownSniAction::Drop + ); assert_eq!(cfg.server.api.listen, default_api_listen()); assert_eq!(cfg.server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1403,6 +1407,10 @@ mod tests { server.proxy_protocol_trusted_cidrs, default_proxy_protocol_trusted_cidrs() ); + assert_eq!( + AntiCensorshipConfig::default().unknown_sni_action, + UnknownSniAction::Drop + ); assert_eq!(server.api.listen, default_api_listen()); assert_eq!(server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1473,6 +1481,34 @@ mod tests { ); } + #[test] + fn unknown_sni_action_parses_and_defaults_to_drop() { + let cfg_default: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + [censorship] + "#, + ) + .unwrap(); + assert_eq!(cfg_default.censorship.unknown_sni_action, UnknownSniAction::Drop); + + let cfg_mask: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + [censorship] + unknown_sni_action = "mask" + "#, + ) + .unwrap(); + assert_eq!(cfg_mask.censorship.unknown_sni_action, UnknownSniAction::Mask); + } + #[test] fn dc_overrides_allow_string_and_array() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index 2d204e2..68ba278 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1359,6 +1359,14 @@ impl Default for TimeoutsConfig { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum UnknownSniAction { + #[default] + Drop, + Mask, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] @@ -1368,6 +1376,10 @@ pub struct AntiCensorshipConfig { #[serde(default)] pub tls_domains: Vec, + /// Policy for TLS ClientHello with unknown (non-configured) SNI. + #[serde(default)] + pub unknown_sni_action: UnknownSniAction, + /// Upstream scope used for TLS front metadata fetches. /// Empty value keeps default upstream routing behavior. #[serde(default = "default_tls_fetch_scope")] @@ -1478,6 +1490,7 @@ impl Default for AntiCensorshipConfig { Self { tls_domain: default_tls_domain(), tls_domains: Vec::new(), + unknown_sni_action: UnknownSniAction::Drop, tls_fetch_scope: default_tls_fetch_scope(), mask: default_true(), mask_host: None, diff --git a/src/error.rs b/src/error.rs index d9aeb22..49c8c81 100644 --- a/src/error.rs +++ b/src/error.rs @@ -216,6 +216,9 @@ pub enum ProxyError { #[error("Invalid proxy protocol header")] InvalidProxyProtocol, + #[error("Unknown TLS SNI")] + UnknownTlsSni, + #[error("Proxy error: {0}")] Proxy(String), diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 0becac7..8ce3e96 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -317,6 +317,13 @@ fn record_handshake_failure_class( record_beobachten_class(beobachten, config, peer_ip, class); } +#[inline] +fn increment_bad_on_unknown_tls_sni(stats: &Stats, error: &ProxyError) { + if matches!(error, ProxyError::UnknownTlsSni) { + stats.increment_connects_bad(); + } +} + fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { if trusted.is_empty() { static EMPTY_PROXY_TRUST_WARNED: OnceLock = OnceLock::new(); @@ -508,7 +515,10 @@ where beobachten.clone(), )); } - HandshakeResult::Error(e) => return Err(e), + HandshakeResult::Error(e) => { + increment_bad_on_unknown_tls_sni(stats.as_ref(), &e); + return Err(e); + } }; debug!(peer = %peer, "Reading MTProto handshake through TLS"); @@ -959,7 +969,10 @@ impl RunningClientHandler { self.beobachten.clone(), )); } - HandshakeResult::Error(e) => return Err(e), + HandshakeResult::Error(e) => { + increment_bad_on_unknown_tls_sni(stats.as_ref(), &e); + return Err(e); + } }; debug!(peer = %peer, "Reading MTProto handshake through TLS"); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 55a8a21..8b8c4a3 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -16,7 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, trace, warn}; use zeroize::{Zeroize, Zeroizing}; -use crate::config::ProxyConfig; +use crate::config::{ProxyConfig, UnknownSniAction}; use crate::crypto::{AesCtr, SecureRandom, sha256}; use crate::error::{HandshakeResult, ProxyError}; use crate::protocol::constants::*; @@ -510,6 +510,21 @@ fn decode_user_secrets( secrets } +#[inline] +fn find_matching_tls_domain<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> { + if config.censorship.tls_domain.eq_ignore_ascii_case(sni) { + return Some(config.censorship.tls_domain.as_str()); + } + + for domain in &config.censorship.tls_domains { + if domain.eq_ignore_ascii_case(sni) { + return Some(domain.as_str()); + } + } + + None +} + async fn maybe_apply_server_hello_delay(config: &ProxyConfig) { if config.censorship.server_hello_delay_max_ms == 0 { return; @@ -593,6 +608,25 @@ where } let client_sni = tls::extract_sni_from_client_hello(handshake); + let matched_tls_domain = client_sni + .as_deref() + .and_then(|sni| find_matching_tls_domain(config, sni)); + + if client_sni.is_some() && matched_tls_domain.is_none() { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!( + peer = %peer, + sni = ?client_sni, + action = ?config.censorship.unknown_sni_action, + "TLS handshake rejected by unknown SNI policy" + ); + return match config.censorship.unknown_sni_action { + UnknownSniAction::Drop => HandshakeResult::Error(ProxyError::UnknownTlsSni), + UnknownSniAction::Mask => HandshakeResult::BadClient { reader, writer }, + }; + } + let secrets = decode_user_secrets(config, client_sni.as_deref()); let validation = match tls::validate_tls_handshake_with_replay_window( @@ -633,16 +667,8 @@ where let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { - let selected_domain = if let Some(sni) = client_sni.as_ref() { - if cache.contains_domain(sni).await { - sni.clone() - } else { - config.censorship.tls_domain.clone() - } - } else { - config.censorship.tls_domain.clone() - }; - let cached_entry = cache.get(&selected_domain).await; + let selected_domain = matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str()); + let cached_entry = cache.get(selected_domain).await; let use_full_cert_payload = cache .take_full_cert_budget_for_ip( peer.ip(), diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs index d06f63e..6796c5c 100644 --- a/src/proxy/tests/handshake_security_tests.rs +++ b/src/proxy/tests/handshake_security_tests.rs @@ -956,6 +956,89 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { } } +#[tokio::test] +async fn tls_unknown_sni_drop_policy_returns_hard_error() { + let secret = [0x48u8; 16]; + let mut config = test_config_with_secret_hex("48484848484848484848484848484848"); + config.censorship.unknown_sni_action = UnknownSniAction::Drop; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.190:44326".parse().unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!( + result, + HandshakeResult::Error(ProxyError::UnknownTlsSni) + )); +} + +#[tokio::test] +async fn tls_unknown_sni_mask_policy_falls_back_to_bad_client() { + let secret = [0x49u8; 16]; + let mut config = test_config_with_secret_hex("49494949494949494949494949494949"); + config.censorship.unknown_sni_action = UnknownSniAction::Mask; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.191:44326".parse().unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn tls_missing_sni_keeps_legacy_auth_path() { + let secret = [0x4Au8; 16]; + let mut config = test_config_with_secret_hex("4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a"); + config.censorship.unknown_sni_action = UnknownSniAction::Drop; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.192:44326".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + #[tokio::test] async fn alpn_enforce_rejects_unsupported_client_alpn() { let secret = [0x33u8; 16]; From a40d6929e585229fcff27306e42bc0218d8d97aa Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 22:41:17 +0300 Subject: [PATCH 25/29] Upstream-driver getProxyConfig and getProxyConfig Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/load.rs | 15 +- src/maestro/helpers.rs | 7 +- src/maestro/me_startup.rs | 5 +- src/proxy/handshake.rs | 3 +- src/transport/middle_proxy/config_updater.rs | 49 +++-- src/transport/middle_proxy/http_fetch.rs | 184 +++++++++++++++++++ src/transport/middle_proxy/mod.rs | 6 +- src/transport/middle_proxy/secret.rs | 38 ++-- 8 files changed, 266 insertions(+), 41 deletions(-) create mode 100644 src/transport/middle_proxy/http_fetch.rs diff --git a/src/config/load.rs b/src/config/load.rs index 8355fb4..2c46766 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1267,10 +1267,7 @@ mod tests { cfg.server.proxy_protocol_trusted_cidrs, default_proxy_protocol_trusted_cidrs() ); - assert_eq!( - cfg.censorship.unknown_sni_action, - UnknownSniAction::Drop - ); + assert_eq!(cfg.censorship.unknown_sni_action, UnknownSniAction::Drop); assert_eq!(cfg.server.api.listen, default_api_listen()); assert_eq!(cfg.server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1493,7 +1490,10 @@ mod tests { "#, ) .unwrap(); - assert_eq!(cfg_default.censorship.unknown_sni_action, UnknownSniAction::Drop); + assert_eq!( + cfg_default.censorship.unknown_sni_action, + UnknownSniAction::Drop + ); let cfg_mask: ProxyConfig = toml::from_str( r#" @@ -1506,7 +1506,10 @@ mod tests { "#, ) .unwrap(); - assert_eq!(cfg_mask.censorship.unknown_sni_action, UnknownSniAction::Mask); + assert_eq!( + cfg_mask.censorship.unknown_sni_action, + UnknownSniAction::Mask + ); } #[test] diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index 35f796f..032460c 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -8,8 +8,10 @@ use tracing::{debug, error, info, warn}; use crate::cli; use crate::config::ProxyConfig; +use crate::transport::UpstreamManager; use crate::transport::middle_proxy::{ - ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, + ProxyConfigData, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, + save_proxy_config_cache, }; pub(crate) fn resolve_runtime_config_path( @@ -288,9 +290,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot( cache_path: Option<&str>, me2dc_fallback: bool, label: &'static str, + upstream: Option>, ) -> Option { loop { - match fetch_proxy_config_with_raw(url).await { + match fetch_proxy_config_with_raw_via_upstream(url, upstream.clone()).await { Ok((cfg, raw)) => { if !cfg.map.is_empty() { if let Some(path) = cache_path diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index 022f8ae..b1e605c 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -63,9 +63,10 @@ pub(crate) async fn initialize_me_pool( let proxy_secret_path = config.general.proxy_secret_path.as_deref(); let pool_size = config.general.middle_proxy_pool_size.max(1); let proxy_secret = loop { - match crate::transport::middle_proxy::fetch_proxy_secret( + match crate::transport::middle_proxy::fetch_proxy_secret_with_upstream( proxy_secret_path, config.general.proxy_secret_len_max, + Some(upstream_manager.clone()), ) .await { @@ -129,6 +130,7 @@ pub(crate) async fn initialize_me_pool( config.general.proxy_config_v4_cache_path.as_deref(), me2dc_fallback, "getProxyConfig", + Some(upstream_manager.clone()), ) .await; if cfg_v4.is_some() { @@ -160,6 +162,7 @@ pub(crate) async fn initialize_me_pool( config.general.proxy_config_v6_cache_path.as_deref(), me2dc_fallback, "getProxyConfigV6", + Some(upstream_manager.clone()), ) .await; if cfg_v6.is_some() { diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 8b8c4a3..9d48fe9 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -667,7 +667,8 @@ where let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { - let selected_domain = matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str()); + let selected_domain = + matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str()); let cached_entry = cache.get(selected_domain).await; let use_full_cert_payload = cache .take_full_cert_budget_for_ip( diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 8e5a701..9819e8d 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -11,17 +11,19 @@ use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::error::Result; +use crate::transport::UpstreamManager; +use super::http_fetch::https_get; use super::MePool; use super::rotation::{MeReinitTrigger, enqueue_reinit_trigger}; -use super::secret::download_proxy_secret_with_max_len; +use super::secret::download_proxy_secret_with_max_len_via_upstream; use super::selftest::record_timeskew_sample; use std::time::SystemTime; -async fn retry_fetch(url: &str) -> Option { +async fn retry_fetch(url: &str, upstream: Option>) -> Option { let delays = [1u64, 5, 15]; for (i, d) in delays.iter().enumerate() { - match fetch_proxy_config(url).await { + match fetch_proxy_config_via_upstream(url, upstream.clone()).await { Ok(cfg) => return Some(cfg), Err(e) => { if i == delays.len() - 1 { @@ -96,13 +98,17 @@ pub async fn save_proxy_config_cache(path: &str, raw_text: &str) -> Result<()> { } pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData, String)> { - let resp = reqwest::get(url).await.map_err(|e| { - crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")) - })?; - let http_status = resp.status().as_u16(); + fetch_proxy_config_with_raw_via_upstream(url, None).await +} - if let Some(date) = resp.headers().get(reqwest::header::DATE) - && let Ok(date_str) = date.to_str() +pub async fn fetch_proxy_config_with_raw_via_upstream( + url: &str, + upstream: Option>, +) -> Result<(ProxyConfigData, String)> { + let resp = https_get(url, upstream).await?; + let http_status = resp.status; + + if let Some(date_str) = resp.date_header.as_deref() && let Ok(server_time) = httpdate::parse_http_date(date_str) && let Ok(skew) = SystemTime::now() .duration_since(server_time) @@ -123,9 +129,7 @@ pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData, } } - let text = resp.text().await.map_err(|e| { - crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")) - })?; + let text = String::from_utf8_lossy(&resp.body).into_owned(); let parsed = parse_proxy_config_text(&text, http_status); Ok((parsed, text)) } @@ -261,7 +265,14 @@ fn parse_proxy_line(line: &str) -> Option<(i32, IpAddr, u16)> { } pub async fn fetch_proxy_config(url: &str) -> Result { - fetch_proxy_config_with_raw(url) + fetch_proxy_config_via_upstream(url, None).await +} + +pub async fn fetch_proxy_config_via_upstream( + url: &str, + upstream: Option>, +) -> Result { + fetch_proxy_config_with_raw_via_upstream(url, upstream) .await .map(|(parsed, _raw)| parsed) } @@ -300,6 +311,7 @@ async fn run_update_cycle( state: &mut UpdaterState, reinit_tx: &mpsc::Sender, ) { + let upstream = pool.upstream.clone(); pool.update_runtime_reinit_policy( cfg.general.hardswap, cfg.general.me_pool_drain_ttl_secs, @@ -354,7 +366,7 @@ async fn run_update_cycle( let mut maps_changed = false; let mut ready_v4: Option<(ProxyConfigData, u64)> = None; - let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig").await; + let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig", upstream.clone()).await; if let Some(cfg_v4) = cfg_v4 && snapshot_passes_guards(cfg, &cfg_v4, "getProxyConfig") { @@ -378,7 +390,7 @@ async fn run_update_cycle( } let mut ready_v6: Option<(ProxyConfigData, u64)> = None; - let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6").await; + let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6", upstream.clone()).await; if let Some(cfg_v6) = cfg_v6 && snapshot_passes_guards(cfg, &cfg_v6, "getProxyConfigV6") { @@ -456,7 +468,12 @@ async fn run_update_cycle( pool.reset_stun_state(); if cfg.general.proxy_secret_rotate_runtime { - match download_proxy_secret_with_max_len(cfg.general.proxy_secret_len_max).await { + match download_proxy_secret_with_max_len_via_upstream( + cfg.general.proxy_secret_len_max, + upstream, + ) + .await + { Ok(secret) => { let secret_hash = hash_secret(&secret); let stable_hits = state.secret.observe(secret_hash); diff --git a/src/transport/middle_proxy/http_fetch.rs b/src/transport/middle_proxy/http_fetch.rs new file mode 100644 index 0000000..c1bb4f6 --- /dev/null +++ b/src/transport/middle_proxy/http_fetch.rs @@ -0,0 +1,184 @@ +use std::sync::Arc; +use std::time::Duration; + +use http_body_util::{BodyExt, Empty}; +use hyper::header::{CONNECTION, DATE, HOST, USER_AGENT}; +use hyper::{Method, Request}; +use hyper_util::rt::TokioIo; +use rustls::pki_types::ServerName; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tokio_rustls::TlsConnector; +use tracing::debug; + +use crate::error::{ProxyError, Result}; +use crate::network::dns_overrides::resolve_socket_addr; +use crate::transport::{UpstreamManager, UpstreamStream}; + +const HTTP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +const HTTP_REQUEST_TIMEOUT: Duration = Duration::from_secs(15); + +pub(crate) struct HttpsGetResponse { + pub(crate) status: u16, + pub(crate) date_header: Option, + pub(crate) body: Vec, +} + +fn build_tls_client_config() -> Arc { + let mut root_store = rustls::RootCertStore::empty(); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + Arc::new(config) +} + +fn extract_host_port_path(url: &str) -> Result<(String, u16, String)> { + let parsed = url::Url::parse(url) + .map_err(|e| ProxyError::Proxy(format!("invalid URL '{url}': {e}")))?; + if parsed.scheme() != "https" { + return Err(ProxyError::Proxy(format!( + "unsupported URL scheme '{}': only https is supported", + parsed.scheme() + ))); + } + + let host = parsed + .host_str() + .ok_or_else(|| ProxyError::Proxy(format!("URL has no host: {url}")))? + .to_string(); + let port = parsed + .port_or_known_default() + .ok_or_else(|| ProxyError::Proxy(format!("URL has no known port: {url}")))?; + + let mut path = parsed.path().to_string(); + if path.is_empty() { + path.push('/'); + } + if let Some(query) = parsed.query() { + path.push('?'); + path.push_str(query); + } + + Ok((host, port, path)) +} + +async fn resolve_target_addr(host: &str, port: u16) -> Result { + if let Some(addr) = resolve_socket_addr(host, port) { + return Ok(addr); + } + + let addrs: Vec = tokio::net::lookup_host((host, port)) + .await + .map_err(|e| ProxyError::Proxy(format!("DNS resolve failed for {host}:{port}: {e}")))? + .collect(); + + if let Some(addr) = addrs.iter().copied().find(|addr| addr.is_ipv4()) { + return Ok(addr); + } + + addrs + .first() + .copied() + .ok_or_else(|| ProxyError::Proxy(format!("DNS returned no addresses for {host}:{port}"))) +} + +async fn connect_https_transport( + host: &str, + port: u16, + upstream: Option>, +) -> Result { + if let Some(manager) = upstream { + let target = resolve_target_addr(host, port).await?; + return timeout(HTTP_CONNECT_TIMEOUT, manager.connect(target, None, None)) + .await + .map_err(|_| { + ProxyError::Proxy(format!("upstream connect timeout for {host}:{port}")) + })? + .map_err(|e| { + ProxyError::Proxy(format!( + "upstream connect failed for {host}:{port}: {e}" + )) + }); + } + + if let Some(addr) = resolve_socket_addr(host, port) { + let stream = timeout(HTTP_CONNECT_TIMEOUT, TcpStream::connect(addr)) + .await + .map_err(|_| ProxyError::Proxy(format!("connect timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("connect failed for {host}:{port}: {e}")))?; + return Ok(UpstreamStream::Tcp(stream)); + } + + let stream = timeout(HTTP_CONNECT_TIMEOUT, TcpStream::connect((host, port))) + .await + .map_err(|_| ProxyError::Proxy(format!("connect timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("connect failed for {host}:{port}: {e}")))?; + Ok(UpstreamStream::Tcp(stream)) +} + +pub(crate) async fn https_get( + url: &str, + upstream: Option>, +) -> Result { + let (host, port, path_and_query) = extract_host_port_path(url)?; + let stream = connect_https_transport(&host, port, upstream).await?; + + let server_name = ServerName::try_from(host.clone()) + .map_err(|_| ProxyError::Proxy(format!("invalid TLS server name: {host}")))?; + let connector = TlsConnector::from(build_tls_client_config()); + let tls_stream = timeout(HTTP_REQUEST_TIMEOUT, connector.connect(server_name, stream)) + .await + .map_err(|_| ProxyError::Proxy(format!("TLS handshake timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("TLS handshake failed for {host}:{port}: {e}")))?; + + let (mut sender, connection) = hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) + .await + .map_err(|e| ProxyError::Proxy(format!("HTTP handshake failed for {host}:{port}: {e}")))?; + + tokio::spawn(async move { + if let Err(e) = connection.await { + debug!(error = %e, "HTTPS fetch connection task failed"); + } + }); + + let host_header = if port == 443 { + host.clone() + } else { + format!("{host}:{port}") + }; + + let request = Request::builder() + .method(Method::GET) + .uri(path_and_query) + .header(HOST, host_header) + .header(USER_AGENT, "telemt-middle-proxy/1") + .header(CONNECTION, "close") + .body(Empty::::new()) + .map_err(|e| ProxyError::Proxy(format!("build HTTP request failed for {url}: {e}")))?; + + let response = timeout(HTTP_REQUEST_TIMEOUT, sender.send_request(request)) + .await + .map_err(|_| ProxyError::Proxy(format!("HTTP request timeout for {url}")))? + .map_err(|e| ProxyError::Proxy(format!("HTTP request failed for {url}: {e}")))?; + + let status = response.status().as_u16(); + let date_header = response + .headers() + .get(DATE) + .and_then(|value| value.to_str().ok()) + .map(|value| value.to_string()); + + let body = timeout(HTTP_REQUEST_TIMEOUT, response.into_body().collect()) + .await + .map_err(|_| ProxyError::Proxy(format!("HTTP body read timeout for {url}")))? + .map_err(|e| ProxyError::Proxy(format!("HTTP body read failed for {url}: {e}")))? + .to_bytes() + .to_vec(); + + Ok(HttpsGetResponse { + status, + date_header, + body, + }) +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 5536869..3a3642a 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -4,6 +4,7 @@ mod codec; mod config_updater; mod handshake; mod health; +mod http_fetch; #[cfg(test)] #[path = "tests/health_adversarial_tests.rs"] mod health_adversarial_tests; @@ -44,7 +45,8 @@ use bytes::Bytes; #[allow(unused_imports)] pub use config_updater::{ - ProxyConfigData, fetch_proxy_config, fetch_proxy_config_with_raw, load_proxy_config_cache, + ProxyConfigData, fetch_proxy_config, fetch_proxy_config_via_upstream, + fetch_proxy_config_with_raw, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, me_config_updater, save_proxy_config_cache, }; pub use health::{me_drain_timeout_enforcer, me_health_monitor, me_zombie_writer_watchdog}; @@ -57,7 +59,7 @@ pub use pool::MePool; pub use pool_nat::{detect_public_ip, stun_probe}; pub use registry::ConnRegistry; pub use rotation::{MeReinitTrigger, me_reinit_scheduler, me_rotation_task}; -pub use secret::fetch_proxy_secret; +pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream}; pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots}; pub use wire::proto_flags_for_tag; diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs index 504270a..450c80a 100644 --- a/src/transport/middle_proxy/secret.rs +++ b/src/transport/middle_proxy/secret.rs @@ -1,8 +1,11 @@ use httpdate; +use std::sync::Arc; use std::time::SystemTime; use tracing::{debug, info, warn}; +use super::http_fetch::https_get; use super::selftest::record_timeskew_sample; +use crate::transport::UpstreamManager; use crate::error::{ProxyError, Result}; pub const PROXY_SECRET_MIN_LEN: usize = 32; @@ -34,10 +37,19 @@ pub(super) fn validate_proxy_secret_len(data_len: usize, max_len: usize) -> Resu /// Fetch Telegram proxy-secret binary. pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Result> { + fetch_proxy_secret_with_upstream(cache_path, max_len, None).await +} + +/// Fetch Telegram proxy-secret binary, optionally through upstream routing. +pub async fn fetch_proxy_secret_with_upstream( + cache_path: Option<&str>, + max_len: usize, + upstream: Option>, +) -> Result> { let cache = cache_path.unwrap_or("proxy-secret"); // 1) Try fresh download first. - match download_proxy_secret_with_max_len(max_len).await { + match download_proxy_secret_with_max_len_via_upstream(max_len, upstream).await { Ok(data) => { if let Err(e) = tokio::fs::write(cache, &data).await { warn!(error = %e, "Failed to cache proxy-secret (non-fatal)"); @@ -77,19 +89,23 @@ pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Res } pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result> { - let resp = reqwest::get("https://core.telegram.org/getProxySecret") - .await - .map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?; + download_proxy_secret_with_max_len_via_upstream(max_len, None).await +} - if !resp.status().is_success() { +pub async fn download_proxy_secret_with_max_len_via_upstream( + max_len: usize, + upstream: Option>, +) -> Result> { + let resp = https_get("https://core.telegram.org/getProxySecret", upstream).await?; + + if !(200..=299).contains(&resp.status) { return Err(ProxyError::Proxy(format!( "proxy-secret download HTTP {}", - resp.status() + resp.status ))); } - if let Some(date) = resp.headers().get(reqwest::header::DATE) - && let Ok(date_str) = date.to_str() + if let Some(date_str) = resp.date_header.as_deref() && let Ok(server_time) = httpdate::parse_http_date(date_str) && let Ok(skew) = SystemTime::now() .duration_since(server_time) @@ -110,11 +126,7 @@ pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result Date: Mon, 23 Mar 2026 23:00:46 +0300 Subject: [PATCH 26/29] Rustfmt --- src/transport/middle_proxy/config_updater.rs | 10 ++++++++-- src/transport/middle_proxy/http_fetch.rs | 12 ++++-------- src/transport/middle_proxy/mod.rs | 3 ++- src/transport/middle_proxy/secret.rs | 4 +++- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 9819e8d..ba90c1a 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -13,8 +13,8 @@ use crate::config::ProxyConfig; use crate::error::Result; use crate::transport::UpstreamManager; -use super::http_fetch::https_get; use super::MePool; +use super::http_fetch::https_get; use super::rotation::{MeReinitTrigger, enqueue_reinit_trigger}; use super::secret::download_proxy_secret_with_max_len_via_upstream; use super::selftest::record_timeskew_sample; @@ -97,6 +97,7 @@ pub async fn save_proxy_config_cache(path: &str, raw_text: &str) -> Result<()> { Ok(()) } +#[allow(dead_code)] pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData, String)> { fetch_proxy_config_with_raw_via_upstream(url, None).await } @@ -264,6 +265,7 @@ fn parse_proxy_line(line: &str) -> Option<(i32, IpAddr, u16)> { Some((dc, ip, port)) } +#[allow(dead_code)] pub async fn fetch_proxy_config(url: &str) -> Result { fetch_proxy_config_via_upstream(url, None).await } @@ -390,7 +392,11 @@ async fn run_update_cycle( } let mut ready_v6: Option<(ProxyConfigData, u64)> = None; - let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6", upstream.clone()).await; + let cfg_v6 = retry_fetch( + "https://core.telegram.org/getProxyConfigV6", + upstream.clone(), + ) + .await; if let Some(cfg_v6) = cfg_v6 && snapshot_passes_guards(cfg, &cfg_v6, "getProxyConfigV6") { diff --git a/src/transport/middle_proxy/http_fetch.rs b/src/transport/middle_proxy/http_fetch.rs index c1bb4f6..2f21934 100644 --- a/src/transport/middle_proxy/http_fetch.rs +++ b/src/transport/middle_proxy/http_fetch.rs @@ -34,8 +34,8 @@ fn build_tls_client_config() -> Arc { } fn extract_host_port_path(url: &str) -> Result<(String, u16, String)> { - let parsed = url::Url::parse(url) - .map_err(|e| ProxyError::Proxy(format!("invalid URL '{url}': {e}")))?; + let parsed = + url::Url::parse(url).map_err(|e| ProxyError::Proxy(format!("invalid URL '{url}': {e}")))?; if parsed.scheme() != "https" { return Err(ProxyError::Proxy(format!( "unsupported URL scheme '{}': only https is supported", @@ -92,13 +92,9 @@ async fn connect_https_transport( let target = resolve_target_addr(host, port).await?; return timeout(HTTP_CONNECT_TIMEOUT, manager.connect(target, None, None)) .await - .map_err(|_| { - ProxyError::Proxy(format!("upstream connect timeout for {host}:{port}")) - })? + .map_err(|_| ProxyError::Proxy(format!("upstream connect timeout for {host}:{port}")))? .map_err(|e| { - ProxyError::Proxy(format!( - "upstream connect failed for {host}:{port}: {e}" - )) + ProxyError::Proxy(format!("upstream connect failed for {host}:{port}: {e}")) }); } diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 3a3642a..6dfbee6 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -4,7 +4,6 @@ mod codec; mod config_updater; mod handshake; mod health; -mod http_fetch; #[cfg(test)] #[path = "tests/health_adversarial_tests.rs"] mod health_adversarial_tests; @@ -14,6 +13,7 @@ mod health_integration_tests; #[cfg(test)] #[path = "tests/health_regression_tests.rs"] mod health_regression_tests; +mod http_fetch; mod ping; mod pool; mod pool_config; @@ -59,6 +59,7 @@ pub use pool::MePool; pub use pool_nat::{detect_public_ip, stun_probe}; pub use registry::ConnRegistry; pub use rotation::{MeReinitTrigger, me_reinit_scheduler, me_rotation_task}; +#[allow(unused_imports)] pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream}; pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots}; pub use wire::proto_flags_for_tag; diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs index 450c80a..a167773 100644 --- a/src/transport/middle_proxy/secret.rs +++ b/src/transport/middle_proxy/secret.rs @@ -5,8 +5,8 @@ use tracing::{debug, info, warn}; use super::http_fetch::https_get; use super::selftest::record_timeskew_sample; -use crate::transport::UpstreamManager; use crate::error::{ProxyError, Result}; +use crate::transport::UpstreamManager; pub const PROXY_SECRET_MIN_LEN: usize = 32; @@ -36,6 +36,7 @@ pub(super) fn validate_proxy_secret_len(data_len: usize, max_len: usize) -> Resu } /// Fetch Telegram proxy-secret binary. +#[allow(dead_code)] pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Result> { fetch_proxy_secret_with_upstream(cache_path, max_len, None).await } @@ -88,6 +89,7 @@ pub async fn fetch_proxy_secret_with_upstream( } } +#[allow(dead_code)] pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result> { download_proxy_secret_with_max_len_via_upstream(max_len, None).await } From 655a08fa5cc3676c8ba62ad0c555301d99193b17 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 23 Mar 2026 23:12:50 +0300 Subject: [PATCH 27/29] TLS Fetcher fixes --- src/tls_front/fetcher.rs | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 2f39707..2356a93 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -108,16 +108,28 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { // Session ID: empty body.push(0); - // Cipher suites (common minimal set, TLS1.3 + a few 1.2 fallbacks) - let cipher_suites: [u8; 10] = [ - 0x13, 0x01, // TLS_AES_128_GCM_SHA256 - 0x13, 0x02, // TLS_AES_256_GCM_SHA384 - 0x13, 0x03, // TLS_CHACHA20_POLY1305_SHA256 - 0x00, 0x2f, // TLS_RSA_WITH_AES_128_CBC_SHA (legacy) - 0x00, 0xff, // RENEGOTIATION_INFO_SCSV + // Cipher suites: + // - TLS1.3 set + // - broad TLS1.2 ECDHE set for RSA/ECDSA cert chains + // This keeps raw probing compatible with common production frontends that + // still negotiate TLS1.2. + let cipher_suites: [u16; 11] = [ + 0x1301, // TLS_AES_128_GCM_SHA256 + 0x1302, // TLS_AES_256_GCM_SHA384 + 0x1303, // TLS_CHACHA20_POLY1305_SHA256 + 0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + 0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + 0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + 0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + 0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + 0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + 0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 + 0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV ]; - body.extend_from_slice(&(cipher_suites.len() as u16).to_be_bytes()); - body.extend_from_slice(&cipher_suites); + body.extend_from_slice(&((cipher_suites.len() * 2) as u16).to_be_bytes()); + for suite in cipher_suites { + body.extend_from_slice(&suite.to_be_bytes()); + } // Compression methods: null only body.push(1); @@ -147,7 +159,7 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // signature_algorithms - let sig_algs: [u16; 4] = [0x0804, 0x0805, 0x0403, 0x0503]; // rsa_pss_rsae_sha256/384, ecdsa_secp256r1_sha256, rsa_pkcs1_sha256 + let sig_algs: [u16; 4] = [0x0804, 0x0805, 0x0403, 0x0503]; // rsa_pss_rsae_sha256/384, ecdsa_secp256r1_sha256, ecdsa_secp384r1_sha384 exts.extend_from_slice(&0x000du16.to_be_bytes()); exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes()); From f7868aa00f727b9797daa535a670dff5f73b056e Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:58:24 +0300 Subject: [PATCH 28/29] Advanced TLS Fetcher Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/defaults.rs | 16 + src/config/load.rs | 112 ++++- src/config/types.rs | 81 ++++ src/maestro/tls_bootstrap.rs | 25 +- src/tls_front/fetcher.rs | 785 ++++++++++++++++++++++++++++++----- 5 files changed, 905 insertions(+), 114 deletions(-) diff --git a/src/config/defaults.rs b/src/config/defaults.rs index f02403e..b0aaf5b 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -71,6 +71,22 @@ pub(crate) fn default_tls_fetch_scope() -> String { String::new() } +pub(crate) fn default_tls_fetch_attempt_timeout_ms() -> u64 { + 5_000 +} + +pub(crate) fn default_tls_fetch_total_budget_ms() -> u64 { + 15_000 +} + +pub(crate) fn default_tls_fetch_strict_route() -> bool { + true +} + +pub(crate) fn default_tls_fetch_profile_cache_ttl_secs() -> u64 { + 600 +} + pub(crate) fn default_mask_port() -> u16 { 443 } diff --git a/src/config/load.rs b/src/config/load.rs index 2c46766..3cb6627 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; @@ -977,6 +977,28 @@ impl ProxyConfig { // Normalize optional TLS fetch scope: whitespace-only values disable scoped routing. config.censorship.tls_fetch_scope = config.censorship.tls_fetch_scope.trim().to_string(); + if config.censorship.tls_fetch.profiles.is_empty() { + config.censorship.tls_fetch.profiles = TlsFetchConfig::default().profiles; + } else { + let mut seen = HashSet::new(); + config + .censorship + .tls_fetch + .profiles + .retain(|profile| seen.insert(*profile)); + } + + if config.censorship.tls_fetch.attempt_timeout_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.attempt_timeout_ms must be > 0".to_string(), + )); + } + if config.censorship.tls_fetch.total_budget_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.total_budget_ms must be > 0".to_string(), + )); + } + // Merge primary + extra TLS domains, deduplicate (primary always first). if !config.censorship.tls_domains.is_empty() { let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len()); @@ -2459,6 +2481,94 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn tls_fetch_defaults_are_applied() { + let toml = r#" + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_defaults_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + TlsFetchConfig::default().profiles + ); + assert!(cfg.censorship.tls_fetch.strict_route); + assert_eq!(cfg.censorship.tls_fetch.attempt_timeout_ms, 5_000); + assert_eq!(cfg.censorship.tls_fetch.total_budget_ms, 15_000); + assert_eq!(cfg.censorship.tls_fetch.profile_cache_ttl_secs, 600); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_profiles_are_deduplicated_preserving_order() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + profiles = ["compat_tls12", "modern_chrome_like", "compat_tls12", "legacy_minimal"] + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_profiles_dedup_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + vec![ + TlsFetchProfile::CompatTls12, + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::LegacyMinimal + ] + ); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_attempt_timeout_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + attempt_timeout_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_attempt_timeout_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.attempt_timeout_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_total_budget_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + total_budget_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_total_budget_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.total_budget_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn invalid_ad_tag_is_disabled_during_load() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index 68ba278..3939664 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1367,6 +1367,82 @@ pub enum UnknownSniAction { Mask, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TlsFetchProfile { + ModernChromeLike, + ModernFirefoxLike, + CompatTls12, + LegacyMinimal, +} + +impl TlsFetchProfile { + pub fn as_str(self) -> &'static str { + match self { + TlsFetchProfile::ModernChromeLike => "modern_chrome_like", + TlsFetchProfile::ModernFirefoxLike => "modern_firefox_like", + TlsFetchProfile::CompatTls12 => "compat_tls12", + TlsFetchProfile::LegacyMinimal => "legacy_minimal", + } + } +} + +fn default_tls_fetch_profiles() -> Vec { + vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ] +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsFetchConfig { + /// Ordered list of ClientHello profiles used for adaptive fallback. + #[serde(default = "default_tls_fetch_profiles")] + pub profiles: Vec, + + /// When true and upstream route is configured, TLS fetch fails closed on + /// upstream connect errors and does not fallback to direct TCP. + #[serde(default = "default_tls_fetch_strict_route")] + pub strict_route: bool, + + /// Timeout per one profile attempt in milliseconds. + #[serde(default = "default_tls_fetch_attempt_timeout_ms")] + pub attempt_timeout_ms: u64, + + /// Total wall-clock budget in milliseconds across all profile attempts. + #[serde(default = "default_tls_fetch_total_budget_ms")] + pub total_budget_ms: u64, + + /// Adds GREASE-style values into selected ClientHello extensions. + #[serde(default)] + pub grease_enabled: bool, + + /// Produces deterministic ClientHello randomness for debugging/tests. + #[serde(default)] + pub deterministic: bool, + + /// TTL for winner-profile cache entries in seconds. + /// Set to 0 to disable profile cache. + #[serde(default = "default_tls_fetch_profile_cache_ttl_secs")] + pub profile_cache_ttl_secs: u64, +} + +impl Default for TlsFetchConfig { + fn default() -> Self { + Self { + profiles: default_tls_fetch_profiles(), + strict_route: default_tls_fetch_strict_route(), + attempt_timeout_ms: default_tls_fetch_attempt_timeout_ms(), + total_budget_ms: default_tls_fetch_total_budget_ms(), + grease_enabled: false, + deterministic: false, + profile_cache_ttl_secs: default_tls_fetch_profile_cache_ttl_secs(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] @@ -1385,6 +1461,10 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_tls_fetch_scope")] pub tls_fetch_scope: String, + /// Fetch strategy for TLS front metadata bootstrap and periodic refresh. + #[serde(default)] + pub tls_fetch: TlsFetchConfig, + #[serde(default = "default_true")] pub mask: bool, @@ -1492,6 +1572,7 @@ impl Default for AntiCensorshipConfig { tls_domains: Vec::new(), unknown_sni_action: UnknownSniAction::Drop, tls_fetch_scope: default_tls_fetch_scope(), + tls_fetch: TlsFetchConfig::default(), mask: default_true(), mask_host: None, mask_port: default_mask_port(), diff --git a/src/maestro/tls_bootstrap.rs b/src/maestro/tls_bootstrap.rs index 342a2f9..7cf3039 100644 --- a/src/maestro/tls_bootstrap.rs +++ b/src/maestro/tls_bootstrap.rs @@ -7,6 +7,7 @@ use tracing::warn; use crate::config::ProxyConfig; use crate::startup::{COMPONENT_TLS_FRONT_BOOTSTRAP, StartupTracker}; use crate::tls_front::TlsFrontCache; +use crate::tls_front::fetcher::TlsFetchStrategy; use crate::transport::UpstreamManager; pub(crate) async fn bootstrap_tls_front( @@ -40,7 +41,17 @@ pub(crate) async fn bootstrap_tls_front( let mask_unix_sock = config.censorship.mask_unix_sock.clone(); let tls_fetch_scope = (!config.censorship.tls_fetch_scope.is_empty()) .then(|| config.censorship.tls_fetch_scope.clone()); - let fetch_timeout = Duration::from_secs(5); + let tls_fetch = config.censorship.tls_fetch.clone(); + let fetch_strategy = TlsFetchStrategy { + profiles: tls_fetch.profiles, + strict_route: tls_fetch.strict_route, + attempt_timeout: Duration::from_millis(tls_fetch.attempt_timeout_ms.max(1)), + total_budget: Duration::from_millis(tls_fetch.total_budget_ms.max(1)), + grease_enabled: tls_fetch.grease_enabled, + deterministic: tls_fetch.deterministic, + profile_cache_ttl: Duration::from_secs(tls_fetch.profile_cache_ttl_secs), + }; + let fetch_timeout = fetch_strategy.total_budget; let cache_initial = cache.clone(); let domains_initial = tls_domains.to_vec(); @@ -48,6 +59,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_initial = mask_unix_sock.clone(); let scope_initial = tls_fetch_scope.clone(); let upstream_initial = upstream_manager.clone(); + let strategy_initial = fetch_strategy.clone(); tokio::spawn(async move { let mut join = tokio::task::JoinSet::new(); for domain in domains_initial { @@ -56,12 +68,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_initial.clone(); let scope_domain = scope_initial.clone(); let upstream_domain = upstream_initial.clone(); + let strategy_domain = strategy_initial.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, @@ -107,6 +120,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_refresh = mask_unix_sock.clone(); let scope_refresh = tls_fetch_scope.clone(); let upstream_refresh = upstream_manager.clone(); + let strategy_refresh = fetch_strategy.clone(); tokio::spawn(async move { loop { let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600); @@ -120,12 +134,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_refresh.clone(); let scope_domain = scope_refresh.clone(); let upstream_domain = upstream_refresh.clone(); + let strategy_domain = strategy_refresh.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 2356a93..503b79c 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -1,7 +1,9 @@ #![allow(clippy::too_many_arguments)] +use dashmap::DashMap; use std::sync::Arc; -use std::time::Duration; +use std::sync::OnceLock; +use std::time::{Duration, Instant}; use anyhow::{Result, anyhow}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -21,7 +23,8 @@ use rustls::{DigitallySignedStruct, Error as RustlsError}; use x509_parser::certificate::X509Certificate; use x509_parser::prelude::FromDer; -use crate::crypto::SecureRandom; +use crate::config::TlsFetchProfile; +use crate::crypto::{SecureRandom, sha256}; use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, @@ -78,6 +81,197 @@ impl ServerCertVerifier for NoVerify { } } +#[derive(Debug, Clone)] +pub struct TlsFetchStrategy { + pub profiles: Vec, + pub strict_route: bool, + pub attempt_timeout: Duration, + pub total_budget: Duration, + pub grease_enabled: bool, + pub deterministic: bool, + pub profile_cache_ttl: Duration, +} + +impl TlsFetchStrategy { + #[allow(dead_code)] + pub fn single_attempt(connect_timeout: Duration) -> Self { + Self { + profiles: vec![TlsFetchProfile::CompatTls12], + strict_route: false, + attempt_timeout: connect_timeout.max(Duration::from_millis(1)), + total_budget: connect_timeout.max(Duration::from_millis(1)), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::ZERO, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ProfileCacheKey { + host: String, + port: u16, + sni: String, + scope: Option, + proxy_protocol: u8, + route_hint: RouteHint, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum RouteHint { + Direct, + Upstream, + Unix, +} + +#[derive(Debug, Clone, Copy)] +struct ProfileCacheValue { + profile: TlsFetchProfile, + updated_at: Instant, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FetchErrorKind { + Connect, + Route, + EarlyEof, + Timeout, + ServerHelloMissing, + TlsAlert, + Parse, + Other, +} + +static PROFILE_CACHE: OnceLock> = OnceLock::new(); + +fn profile_cache() -> &'static DashMap { + PROFILE_CACHE.get_or_init(DashMap::new) +} + +fn route_hint( + upstream: Option<&std::sync::Arc>, + unix_sock: Option<&str>, +) -> RouteHint { + if unix_sock.is_some() { + RouteHint::Unix + } else if upstream.is_some() { + RouteHint::Upstream + } else { + RouteHint::Direct + } +} + +fn profile_cache_key( + host: &str, + port: u16, + sni: &str, + upstream: Option<&std::sync::Arc>, + scope: Option<&str>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> ProfileCacheKey { + ProfileCacheKey { + host: host.to_string(), + port, + sni: sni.to_string(), + scope: scope.map(ToString::to_string), + proxy_protocol, + route_hint: route_hint(upstream, unix_sock), + } +} + +fn classify_fetch_error(err: &anyhow::Error) -> FetchErrorKind { + for cause in err.chain() { + if let Some(io) = cause.downcast_ref::() { + return match io.kind() { + std::io::ErrorKind::TimedOut => FetchErrorKind::Timeout, + std::io::ErrorKind::UnexpectedEof => FetchErrorKind::EarlyEof, + std::io::ErrorKind::ConnectionRefused + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::AddrNotAvailable => FetchErrorKind::Connect, + _ => FetchErrorKind::Other, + }; + } + } + + let message = err.to_string().to_lowercase(); + if message.contains("upstream route") { + FetchErrorKind::Route + } else if message.contains("serverhello not received") { + FetchErrorKind::ServerHelloMissing + } else if message.contains("alert") { + FetchErrorKind::TlsAlert + } else if message.contains("parse") { + FetchErrorKind::Parse + } else if message.contains("timed out") || message.contains("deadline has elapsed") { + FetchErrorKind::Timeout + } else if message.contains("eof") { + FetchErrorKind::EarlyEof + } else { + FetchErrorKind::Other + } +} + +fn order_profiles( + strategy: &TlsFetchStrategy, + cache_key: Option<&ProfileCacheKey>, + now: Instant, +) -> Vec { + let mut ordered = if strategy.profiles.is_empty() { + vec![TlsFetchProfile::CompatTls12] + } else { + strategy.profiles.clone() + }; + + if strategy.profile_cache_ttl.is_zero() { + return ordered; + } + + let Some(key) = cache_key else { + return ordered; + }; + + if let Some(cached) = profile_cache().get(key) { + let age = now.saturating_duration_since(cached.updated_at); + if age > strategy.profile_cache_ttl { + drop(cached); + profile_cache().remove(key); + return ordered; + } + + if let Some(pos) = ordered.iter().position(|profile| *profile == cached.profile) { + if pos != 0 { + ordered.swap(0, pos); + } + } + } + + ordered +} + +fn remember_profile_success( + strategy: &TlsFetchStrategy, + cache_key: Option, + profile: TlsFetchProfile, + now: Instant, +) { + if strategy.profile_cache_ttl.is_zero() { + return; + } + let Some(key) = cache_key else { + return; + }; + profile_cache().insert( + key, + ProfileCacheValue { + profile, + updated_at: now, + }, + ); +} + fn build_client_config() -> Arc { let root = rustls::RootCertStore::empty(); @@ -95,7 +289,114 @@ fn build_client_config() -> Arc { Arc::new(config) } -fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { +fn deterministic_bytes(seed: &str, len: usize) -> Vec { + let mut out = Vec::with_capacity(len); + let mut counter: u32 = 0; + while out.len() < len { + let mut chunk_seed = Vec::with_capacity(seed.len() + std::mem::size_of::()); + chunk_seed.extend_from_slice(seed.as_bytes()); + chunk_seed.extend_from_slice(&counter.to_le_bytes()); + out.extend_from_slice(&sha256(&chunk_seed)); + counter = counter.wrapping_add(1); + } + out.truncate(len); + out +} + +fn profile_cipher_suites(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN_CHROME: &[u16] = &[ + 0x1301, 0x1302, 0x1303, 0xc02b, 0xc02c, 0xcca9, 0xc02f, 0xc030, 0xcca8, 0x009e, 0x00ff, + ]; + const MODERN_FIREFOX: &[u16] = &[ + 0x1301, 0x1303, 0x1302, 0xc02b, 0xcca9, 0xc02c, 0xc02f, 0xcca8, 0xc030, 0x009e, 0x00ff, + ]; + const COMPAT_TLS12: &[u16] = &[ + 0xc02b, 0xc02c, 0xc02f, 0xc030, 0xcca9, 0xcca8, 0x1301, 0x1302, 0x1303, 0x009e, 0x00ff, + ]; + const LEGACY_MINIMAL: &[u16] = &[0xc02b, 0xc02f, 0x1301, 0x1302, 0x00ff]; + + match profile { + TlsFetchProfile::ModernChromeLike => MODERN_CHROME, + TlsFetchProfile::ModernFirefoxLike => MODERN_FIREFOX, + TlsFetchProfile::CompatTls12 => COMPAT_TLS12, + TlsFetchProfile::LegacyMinimal => LEGACY_MINIMAL, + } +} + +fn profile_groups(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x001d, 0x0017, 0x0018]; // x25519, secp256r1, secp384r1 + const COMPAT: &[u16] = &[0x001d, 0x0017]; + const LEGACY: &[u16] = &[0x0017]; + + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_sig_algs(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x0804, 0x0805, 0x0403, 0x0503, 0x0806]; + const COMPAT: &[u16] = &[0x0403, 0x0503, 0x0804, 0x0805]; + const LEGACY: &[u16] = &[0x0403, 0x0804]; + + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_alpn(profile: TlsFetchProfile) -> &'static [&'static [u8]] { + const H2_HTTP11: &[&[u8]] = &[b"h2", b"http/1.1"]; + const HTTP11: &[&[u8]] = &[b"http/1.1"]; + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => H2_HTTP11, + TlsFetchProfile::CompatTls12 | TlsFetchProfile::LegacyMinimal => HTTP11, + } +} + +fn profile_supported_versions(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x0304, 0x0303]; + const COMPAT: &[u16] = &[0x0303, 0x0304]; + const LEGACY: &[u16] = &[0x0303]; + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_padding_target(profile: TlsFetchProfile) -> usize { + match profile { + TlsFetchProfile::ModernChromeLike => 220, + TlsFetchProfile::ModernFirefoxLike => 200, + TlsFetchProfile::CompatTls12 => 180, + TlsFetchProfile::LegacyMinimal => 64, + } +} + +fn grease_value(rng: &SecureRandom, deterministic: bool, seed: &str) -> u16 { + const GREASE_VALUES: [u16; 16] = [ + 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, 0x8a8a, 0x9a9a, 0xaaaa, + 0xbaba, 0xcaca, 0xdada, 0xeaea, 0xfafa, + ]; + if deterministic { + let idx = deterministic_bytes(seed, 1)[0] as usize % GREASE_VALUES.len(); + GREASE_VALUES[idx] + } else { + let idx = (rng.bytes(1)[0] as usize) % GREASE_VALUES.len(); + GREASE_VALUES[idx] + } +} + +fn build_client_hello( + sni: &str, + rng: &SecureRandom, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, +) -> Vec { // === ClientHello body === let mut body = Vec::new(); @@ -103,29 +404,20 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { body.extend_from_slice(&[0x03, 0x03]); // Random - body.extend_from_slice(&rng.bytes(32)); + if deterministic { + body.extend_from_slice(&deterministic_bytes(&format!("tls-fetch-random:{sni}"), 32)); + } else { + body.extend_from_slice(&rng.bytes(32)); + } // Session ID: empty body.push(0); - // Cipher suites: - // - TLS1.3 set - // - broad TLS1.2 ECDHE set for RSA/ECDSA cert chains - // This keeps raw probing compatible with common production frontends that - // still negotiate TLS1.2. - let cipher_suites: [u16; 11] = [ - 0x1301, // TLS_AES_128_GCM_SHA256 - 0x1302, // TLS_AES_256_GCM_SHA384 - 0x1303, // TLS_CHACHA20_POLY1305_SHA256 - 0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 - 0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 - 0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 - 0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 - 0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 - 0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - 0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 - 0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV - ]; + let mut cipher_suites = profile_cipher_suites(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("cipher:{sni}")); + cipher_suites.insert(0, grease); + } body.extend_from_slice(&((cipher_suites.len() * 2) as u16).to_be_bytes()); for suite in cipher_suites { body.extend_from_slice(&suite.to_be_bytes()); @@ -150,7 +442,11 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&sni_ext); // supported_groups - let groups: [u16; 2] = [0x001d, 0x0017]; // x25519, secp256r1 + let mut groups = profile_groups(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("group:{sni}")); + groups.insert(0, grease); + } exts.extend_from_slice(&0x000au16.to_be_bytes()); exts.extend_from_slice(&((2 + groups.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes()); @@ -159,7 +455,11 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // signature_algorithms - let sig_algs: [u16; 4] = [0x0804, 0x0805, 0x0403, 0x0503]; // rsa_pss_rsae_sha256/384, ecdsa_secp256r1_sha256, ecdsa_secp384r1_sha384 + let mut sig_algs = profile_sig_algs(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("sigalg:{sni}")); + sig_algs.insert(0, grease); + } exts.extend_from_slice(&0x000du16.to_be_bytes()); exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes()); @@ -167,8 +467,12 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&a.to_be_bytes()); } - // supported_versions (TLS1.3 + TLS1.2) - let versions: [u16; 2] = [0x0304, 0x0303]; + // supported_versions + let mut versions = profile_supported_versions(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("version:{sni}")); + versions.insert(0, grease); + } exts.extend_from_slice(&0x002bu16.to_be_bytes()); exts.extend_from_slice(&((1 + versions.len() * 2) as u16).to_be_bytes()); exts.push((versions.len() * 2) as u8); @@ -177,7 +481,14 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // key_share (x25519) - let key = gen_key_share(rng); + let key = if deterministic { + let det = deterministic_bytes(&format!("keyshare:{sni}"), 32); + let mut key = [0u8; 32]; + key.copy_from_slice(&det); + key + } else { + gen_key_share(rng) + }; let mut keyshare = Vec::with_capacity(4 + key.len()); keyshare.extend_from_slice(&0x001du16.to_be_bytes()); // group keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes()); @@ -187,18 +498,29 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&(keyshare.len() as u16).to_be_bytes()); exts.extend_from_slice(&keyshare); - // ALPN (http/1.1) - let alpn_proto = b"http/1.1"; - exts.extend_from_slice(&0x0010u16.to_be_bytes()); - exts.extend_from_slice(&((2 + 1 + alpn_proto.len()) as u16).to_be_bytes()); - exts.extend_from_slice(&((1 + alpn_proto.len()) as u16).to_be_bytes()); - exts.push(alpn_proto.len() as u8); - exts.extend_from_slice(alpn_proto); + // ALPN + let mut alpn_list = Vec::new(); + for proto in profile_alpn(profile) { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + if !alpn_list.is_empty() { + exts.extend_from_slice(&0x0010u16.to_be_bytes()); + exts.extend_from_slice(&((2 + alpn_list.len()) as u16).to_be_bytes()); + exts.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + exts.extend_from_slice(&alpn_list); + } + + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("ext:{sni}")); + exts.extend_from_slice(&grease.to_be_bytes()); + exts.extend_from_slice(&0u16.to_be_bytes()); + } // padding to reduce recognizability and keep length ~500 bytes - const TARGET_EXT_LEN: usize = 180; - if exts.len() < TARGET_EXT_LEN { - let remaining = TARGET_EXT_LEN - exts.len(); + let target_ext_len = profile_padding_target(profile); + if exts.len() < target_ext_len { + let remaining = target_ext_len - exts.len(); if remaining > 4 { let pad_len = remaining - 4; // minus type+len exts.extend_from_slice(&0x0015u16.to_be_bytes()); // padding extension @@ -414,27 +736,41 @@ async fn connect_tcp_with_upstream( connect_timeout: Duration, upstream: Option>, scope: Option<&str>, + strict_route: bool, ) -> Result { if let Some(manager) = upstream { - if let Some(addr) = resolve_socket_addr(host, port) { - match manager.connect(addr, None, scope).await { - Ok(stream) => return Ok(stream), - Err(e) => { - warn!( - host = %host, - port = port, - scope = ?scope, - error = %e, - "Upstream connect failed, using direct connect" - ); - } - } - } else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await - && let Some(addr) = addrs.find(|a| a.is_ipv4()) - { + let resolved = if let Some(addr) = resolve_socket_addr(host, port) { + Some(addr) + } else { + match tokio::net::lookup_host((host, port)).await { + Ok(mut addrs) => addrs.find(|a| a.is_ipv4()), + Err(e) => { + if strict_route { + return Err(anyhow!( + "upstream route DNS resolution failed for {host}:{port}: {e}" + )); + } + warn!( + host = %host, + port = port, + scope = ?scope, + error = %e, + "Upstream DNS resolution failed, using direct connect" + ); + None + } + } + }; + + if let Some(addr) = resolved { match manager.connect(addr, None, scope).await { Ok(stream) => return Ok(stream), Err(e) => { + if strict_route { + return Err(anyhow!( + "upstream route connect failed for {host}:{port}: {e}" + )); + } warn!( host = %host, port = port, @@ -444,6 +780,10 @@ async fn connect_tcp_with_upstream( ); } } + } else if strict_route { + return Err(anyhow!( + "upstream route resolution produced no usable address for {host}:{port}" + )); } } Ok(UpstreamStream::Tcp( @@ -483,12 +823,15 @@ async fn fetch_via_raw_tls_stream( sni: &str, connect_timeout: Duration, proxy_protocol: u8, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, ) -> Result where S: AsyncRead + AsyncWrite + Unpin, { let rng = SecureRandom::new(); - let client_hello = build_client_hello(sni, &rng); + let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic); timeout(connect_timeout, async { if proxy_protocol > 0 { let header = match proxy_protocol { @@ -562,6 +905,10 @@ async fn fetch_via_raw_tls( scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, + strict_route: bool, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { @@ -572,8 +919,16 @@ async fn fetch_via_raw_tls( sock = %sock_path, "Raw TLS fetch using mask unix socket" ); - return fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol) - .await; + return fetch_via_raw_tls_stream( + stream, + sni, + connect_timeout, + proxy_protocol, + profile, + grease_enabled, + deterministic, + ) + .await; } Ok(Err(e)) => { warn!( @@ -596,8 +951,25 @@ async fn fetch_via_raw_tls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?; - fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await + let stream = connect_tcp_with_upstream( + host, + port, + connect_timeout, + upstream, + scope, + strict_route, + ) + .await?; + fetch_via_raw_tls_stream( + stream, + sni, + connect_timeout, + proxy_protocol, + profile, + grease_enabled, + deterministic, + ) + .await } async fn fetch_via_rustls_stream( @@ -703,6 +1075,7 @@ async fn fetch_via_rustls( scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, + strict_route: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { @@ -736,16 +1109,159 @@ async fn fetch_via_rustls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?; + let stream = connect_tcp_with_upstream( + host, + port, + connect_timeout, + upstream, + scope, + strict_route, + ) + .await?; fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await } -/// Fetch real TLS metadata for the given SNI. -/// -/// Strategy: -/// 1) Probe raw TLS for realistic ServerHello and ApplicationData record sizes. -/// 2) Fetch certificate chain via rustls to build cert payload. -/// 3) Merge both when possible; otherwise auto-fallback to whichever succeeded. +/// Fetch real TLS metadata with an adaptive multi-profile strategy. +pub async fn fetch_real_tls_with_strategy( + host: &str, + port: u16, + sni: &str, + strategy: &TlsFetchStrategy, + upstream: Option>, + scope: Option<&str>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> Result { + let attempt_timeout = strategy.attempt_timeout.max(Duration::from_millis(1)); + let total_budget = strategy.total_budget.max(Duration::from_millis(1)); + let started_at = Instant::now(); + let cache_key = profile_cache_key( + host, + port, + sni, + upstream.as_ref(), + scope, + proxy_protocol, + unix_sock, + ); + let profiles = order_profiles(strategy, Some(&cache_key), started_at); + + let mut raw_result = None; + let mut raw_last_error: Option = None; + let mut raw_last_error_kind = FetchErrorKind::Other; + let mut selected_profile = None; + + for profile in profiles { + let elapsed = started_at.elapsed(); + if elapsed >= total_budget { + break; + } + let timeout_for_attempt = attempt_timeout.min(total_budget - elapsed); + + match fetch_via_raw_tls( + host, + port, + sni, + timeout_for_attempt, + upstream.clone(), + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + profile, + strategy.grease_enabled, + strategy.deterministic, + ) + .await + { + Ok(res) => { + selected_profile = Some(profile); + raw_result = Some(res); + break; + } + Err(err) => { + let kind = classify_fetch_error(&err); + warn!( + sni = %sni, + profile = profile.as_str(), + error_kind = ?kind, + error = %err, + "Raw TLS fetch attempt failed" + ); + raw_last_error_kind = kind; + raw_last_error = Some(err); + if strategy.strict_route && matches!(kind, FetchErrorKind::Route) { + break; + } + } + } + } + + if let Some(profile) = selected_profile { + remember_profile_success(strategy, Some(cache_key), profile, Instant::now()); + } + + if raw_result.is_none() + && strategy.strict_route + && matches!(raw_last_error_kind, FetchErrorKind::Route) + { + if let Some(err) = raw_last_error { + return Err(err); + } + return Err(anyhow!("TLS fetch strict-route failure")); + } + + let elapsed = started_at.elapsed(); + if elapsed >= total_budget { + return match raw_result { + Some(raw) => Ok(raw), + None => Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))), + }; + } + + let rustls_timeout = attempt_timeout.min(total_budget - elapsed); + let rustls_result = fetch_via_rustls( + host, + port, + sni, + rustls_timeout, + upstream, + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + ) + .await; + + match rustls_result { + Ok(rustls) => { + if let Some(mut raw) = raw_result { + raw.cert_info = rustls.cert_info; + raw.cert_payload = rustls.cert_payload; + raw.behavior_profile.source = TlsProfileSource::Merged; + debug!(sni = %sni, "Fetched TLS metadata via adaptive raw probe + rustls cert chain"); + Ok(raw) + } else { + Ok(rustls) + } + } + Err(err) => { + if let Some(raw) = raw_result { + warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); + Ok(raw) + } else if let Some(raw_err) = raw_last_error { + Err(anyhow!( + "TLS fetch failed (raw: {raw_err}; rustls: {err})" + )) + } else { + Err(err) + } + } + } +} + +/// Fetch real TLS metadata for the given SNI using a single-attempt compatibility strategy. +#[allow(dead_code)] pub async fn fetch_real_tls( host: &str, port: u16, @@ -756,62 +1272,30 @@ pub async fn fetch_real_tls( proxy_protocol: u8, unix_sock: Option<&str>, ) -> Result { - let raw_result = match fetch_via_raw_tls( + let strategy = TlsFetchStrategy::single_attempt(connect_timeout); + fetch_real_tls_with_strategy( host, port, sni, - connect_timeout, - upstream.clone(), - scope, - proxy_protocol, - unix_sock, - ) - .await - { - Ok(res) => Some(res), - Err(e) => { - warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); - None - } - }; - - match fetch_via_rustls( - host, - port, - sni, - connect_timeout, + &strategy, upstream, scope, proxy_protocol, unix_sock, ) .await - { - Ok(rustls_result) => { - if let Some(mut raw) = raw_result { - raw.cert_info = rustls_result.cert_info; - raw.cert_payload = rustls_result.cert_payload; - raw.behavior_profile.source = TlsProfileSource::Merged; - debug!(sni = %sni, "Fetched TLS metadata via raw probe + rustls cert chain"); - Ok(raw) - } else { - Ok(rustls_result) - } - } - Err(e) => { - if let Some(raw) = raw_result { - warn!(sni = %sni, error = %e, "Rustls cert fetch failed, using raw TLS metadata only"); - Ok(raw) - } else { - Err(e) - } - } - } } #[cfg(test)] mod tests { - use super::{derive_behavior_profile, encode_tls13_certificate_message}; + use std::time::{Duration, Instant}; + + use super::{ + ProfileCacheValue, TlsFetchStrategy, build_client_hello, derive_behavior_profile, + encode_tls13_certificate_message, order_profiles, profile_cache, profile_cache_key, + }; + use crate::config::TlsFetchProfile; + use crate::crypto::SecureRandom; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, }; @@ -860,4 +1344,89 @@ mod tests { assert_eq!(profile.ticket_record_sizes, vec![220, 180]); assert_eq!(profile.source, TlsProfileSource::Raw); } + + #[test] + fn test_order_profiles_prioritizes_fresh_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ], + strict_route: true, + attempt_timeout: Duration::from_secs(1), + total_budget: Duration::from_secs(2), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::from_secs(60), + }; + let cache_key = profile_cache_key( + "mask.example", + 443, + "tls.example", + None, + Some("tls"), + 0, + None, + ); + profile_cache().remove(&cache_key); + profile_cache().insert( + cache_key.clone(), + ProfileCacheValue { + profile: TlsFetchProfile::CompatTls12, + updated_at: Instant::now(), + }, + ); + + let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); + assert_eq!(ordered[0], TlsFetchProfile::CompatTls12); + profile_cache().remove(&cache_key); + } + + #[test] + fn test_order_profiles_drops_expired_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![TlsFetchProfile::ModernFirefoxLike, TlsFetchProfile::CompatTls12], + strict_route: true, + attempt_timeout: Duration::from_secs(1), + total_budget: Duration::from_secs(2), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::from_secs(5), + }; + let cache_key = profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); + profile_cache().remove(&cache_key); + profile_cache().insert( + cache_key.clone(), + ProfileCacheValue { + profile: TlsFetchProfile::CompatTls12, + updated_at: Instant::now() - Duration::from_secs(6), + }, + ); + + let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); + assert_eq!(ordered[0], TlsFetchProfile::ModernFirefoxLike); + assert!(profile_cache().get(&cache_key).is_none()); + } + + #[test] + fn test_deterministic_client_hello_is_stable() { + let rng = SecureRandom::new(); + let first = build_client_hello( + "stable.example", + &rng, + TlsFetchProfile::ModernChromeLike, + true, + true, + ); + let second = build_client_hello( + "stable.example", + &rng, + TlsFetchProfile::ModernChromeLike, + true, + true, + ); + + assert_eq!(first, second); + } } From 8b92b80b4af7b1f19317b0e12cbd2dbc1aa418f7 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 24 Mar 2026 10:33:06 +0300 Subject: [PATCH 29/29] Rustks CryptoProvider fixes + Rustfmt --- src/main.rs | 1 + src/proxy/handshake.rs | 114 ++++++++++++++--------- src/tls_front/fetcher.rs | 45 ++++----- src/transport/middle_proxy/http_fetch.rs | 5 +- 4 files changed, 93 insertions(+), 72 deletions(-) diff --git a/src/main.rs b/src/main.rs index 406b321..e5d931f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,5 +29,6 @@ mod util; #[tokio::main] async fn main() -> std::result::Result<(), Box> { + let _ = rustls::crypto::ring::default_provider().install_default(); maestro::run().await } diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 9d48fe9..2ef8e1b 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -282,30 +282,9 @@ fn auth_probe_record_failure_with_state( let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let state_len = state.len(); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); - let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); - let mut scanned = 0usize; - for entry in state.iter().skip(start_offset) { - let key = *entry.key(); - let fail_streak = entry.value().fail_streak; - let last_seen = entry.value().last_seen; - match eviction_candidate { - Some((_, current_fail, current_seen)) - if fail_streak > current_fail - || (fail_streak == current_fail && last_seen >= current_seen) => {} - _ => eviction_candidate = Some((key, fail_streak, last_seen)), - } - if auth_probe_state_expired(entry.value(), now) { - stale_keys.push(key); - } - scanned += 1; - if scanned >= scan_limit { - break; - } - } - - if scanned < scan_limit { - for entry in state.iter().take(scan_limit - scanned) { + if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT { + for entry in state.iter() { let key = *entry.key(); let fail_streak = entry.value().fail_streak; let last_seen = entry.value().last_seen; @@ -319,6 +298,46 @@ fn auth_probe_record_failure_with_state( stale_keys.push(key); } } + } else { + let start_offset = + auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } + } + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail + && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } } for stale_key in stale_keys { @@ -608,11 +627,35 @@ where } let client_sni = tls::extract_sni_from_client_hello(handshake); + let preferred_user_hint = client_sni + .as_deref() + .filter(|sni| config.access.users.contains_key(*sni)); let matched_tls_domain = client_sni .as_deref() .and_then(|sni| find_matching_tls_domain(config, sni)); - if client_sni.is_some() && matched_tls_domain.is_none() { + let alpn_list = if config.censorship.alpn_enforce { + tls::extract_alpn_from_client_hello(handshake) + } else { + Vec::new() + }; + let selected_alpn = if config.censorship.alpn_enforce { + if alpn_list.iter().any(|p| p == b"h2") { + Some(b"h2".to_vec()) + } else if alpn_list.iter().any(|p| p == b"http/1.1") { + Some(b"http/1.1".to_vec()) + } else if !alpn_list.is_empty() { + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); + return HandshakeResult::BadClient { reader, writer }; + } else { + None + } + } else { + None + }; + + if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() { auth_probe_record_failure(peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; debug!( @@ -627,7 +670,7 @@ where }; } - let secrets = decode_user_secrets(config, client_sni.as_deref()); + let secrets = decode_user_secrets(config, preferred_user_hint); let validation = match tls::validate_tls_handshake_with_replay_window( handshake, @@ -684,27 +727,6 @@ where None }; - let alpn_list = if config.censorship.alpn_enforce { - tls::extract_alpn_from_client_hello(handshake) - } else { - Vec::new() - }; - let selected_alpn = if config.censorship.alpn_enforce { - if alpn_list.iter().any(|p| p == b"h2") { - Some(b"h2".to_vec()) - } else if alpn_list.iter().any(|p| p == b"http/1.1") { - Some(b"http/1.1".to_vec()) - } else if !alpn_list.is_empty() { - maybe_apply_server_hello_delay(config).await; - debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); - return HandshakeResult::BadClient { reader, writer }; - } else { - None - } - } else { - None - }; - // Add replay digest only for policy-valid handshakes. replay_checker.add_tls_digest(digest_half); diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 503b79c..bbfc336 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -241,7 +241,10 @@ fn order_profiles( return ordered; } - if let Some(pos) = ordered.iter().position(|profile| *profile == cached.profile) { + if let Some(pos) = ordered + .iter() + .position(|profile| *profile == cached.profile) + { if pos != 0 { ordered.swap(0, pos); } @@ -951,15 +954,9 @@ async fn fetch_via_raw_tls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream( - host, - port, - connect_timeout, - upstream, - scope, - strict_route, - ) - .await?; + let stream = + connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) + .await?; fetch_via_raw_tls_stream( stream, sni, @@ -1109,15 +1106,9 @@ async fn fetch_via_rustls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream( - host, - port, - connect_timeout, - upstream, - scope, - strict_route, - ) - .await?; + let stream = + connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) + .await?; fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await } @@ -1215,7 +1206,9 @@ pub async fn fetch_real_tls_with_strategy( if elapsed >= total_budget { return match raw_result { Some(raw) => Ok(raw), - None => Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))), + None => { + Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))) + } }; } @@ -1250,9 +1243,7 @@ pub async fn fetch_real_tls_with_strategy( warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); Ok(raw) } else if let Some(raw_err) = raw_last_error { - Err(anyhow!( - "TLS fetch failed (raw: {raw_err}; rustls: {err})" - )) + Err(anyhow!("TLS fetch failed (raw: {raw_err}; rustls: {err})")) } else { Err(err) } @@ -1386,7 +1377,10 @@ mod tests { #[test] fn test_order_profiles_drops_expired_cached_winner() { let strategy = TlsFetchStrategy { - profiles: vec![TlsFetchProfile::ModernFirefoxLike, TlsFetchProfile::CompatTls12], + profiles: vec![ + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + ], strict_route: true, attempt_timeout: Duration::from_secs(1), total_budget: Duration::from_secs(2), @@ -1394,7 +1388,8 @@ mod tests { deterministic: false, profile_cache_ttl: Duration::from_secs(5), }; - let cache_key = profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); + let cache_key = + profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); profile_cache().remove(&cache_key); profile_cache().insert( cache_key.clone(), diff --git a/src/transport/middle_proxy/http_fetch.rs b/src/transport/middle_proxy/http_fetch.rs index 2f21934..5be601e 100644 --- a/src/transport/middle_proxy/http_fetch.rs +++ b/src/transport/middle_proxy/http_fetch.rs @@ -27,7 +27,10 @@ pub(crate) struct HttpsGetResponse { fn build_tls_client_config() -> Arc { let mut root_store = rustls::RootCertStore::empty(); root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let config = rustls::ClientConfig::builder() + let provider = rustls::crypto::ring::default_provider(); + let config = rustls::ClientConfig::builder_with_provider(Arc::new(provider)) + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .expect("HTTPS fetch rustls protocol versions must be valid") .with_root_certificates(root_store) .with_no_client_auth(); Arc::new(config)