diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 799f2ce..b245679 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -11,7 +11,7 @@ env:
 jobs:
   build:
-    name: Build
+    name: Compile, Test, Lint
     runs-on: ubuntu-latest

     permissions:
@@ -39,23 +39,11 @@ jobs:
         restore-keys: |
           ${{ runner.os }}-cargo-

-      - name: Build Release
-        run: cargo build --release --verbose
+      - name: Compile (no tests)
+        run: cargo check --workspace --all-features --lib --bins --verbose

-      - name: Run tests
-        run: cargo test --verbose
-
-      - name: Stress quota-lock suites (PR only)
-        if: github.event_name == 'pull_request'
-        env:
-          RUST_TEST_THREADS: 16
-        run: |
-          set -euo pipefail
-          for i in $(seq 1 12); do
-            echo "[quota-lock-stress] iteration ${i}/12"
-            cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
-            cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
-          done
+      - name: Run tests (single pass)
+        run: cargo test --workspace --all-features --verbose

 # clippy: don't fail on warnings because telemt is under active development
 # and still has many warnings
diff --git a/.github/workflows/stress.yml b/.github/workflows/stress.yml
new file mode 100644
index 0000000..96b9a1b
--- /dev/null
+++ b/.github/workflows/stress.yml
@@ -0,0 +1,57 @@
+name: Stress Tests
+
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: '0 2 * * *'
+  pull_request:
+    branches: ["*"]
+    paths:
+      - src/proxy/**
+      - src/transport/**
+      - src/stream/**
+      - src/protocol/**
+      - src/tls_front/**
+      - Cargo.toml
+      - Cargo.lock
+
+env:
+  CARGO_TERM_COLOR: always
+
+jobs:
+  quota-lock-stress:
+    name: Quota-lock stress loop
+    runs-on: ubuntu-latest
+
+    permissions:
+      contents: read
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Install latest stable Rust toolchain
+        uses: dtolnay/rust-toolchain@stable
+
+      - name: Cache cargo registry and build artifacts
+        uses: actions/cache@v4
+        with:
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
+            target
+          key: ${{ runner.os }}-cargo-stress-${{ hashFiles('**/Cargo.lock') }}
+          restore-keys: |
+            ${{ runner.os }}-cargo-stress-
+            ${{ runner.os }}-cargo-
+
+      - name: Run quota-lock stress suites
+        env:
+          RUST_TEST_THREADS: 16
+        run: |
+          set -euo pipefail
+          for i in $(seq 1 12); do
+            echo "[quota-lock-stress] iteration ${i}/12"
+            cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
+            cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
+          done
diff --git a/Cargo.lock b/Cargo.lock
index 8159a22..92da630 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1454,9 +1454,9 @@ dependencies = [

 [[package]]
 name = "iri-string"
-version = "0.7.10"
+version = "0.7.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
+checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
 dependencies = [
  "memchr",
  "serde",
@@ -1486,7 +1486,7 @@ dependencies = [
  "cesu8",
  "cfg-if",
  "combine",
- "jni-sys",
+ "jni-sys 0.3.1",
  "log",
  "thiserror 1.0.69",
  "walkdir",
@@ -1495,9 +1495,31 @@ dependencies = [

 [[package]]
 name = "jni-sys"
-version = "0.3.0"
+version = "0.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
+checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
+dependencies = [
+ "jni-sys 0.4.1",
+]
+
+[[package]]
+name = "jni-sys"
+version = "0.4.1"
+source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn", +] [[package]] name = "jobserver" @@ -1659,9 +1681,9 @@ dependencies = [ [[package]] name = "moka" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" dependencies = [ "crossbeam-channel", "crossbeam-epoch", @@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.3.29" +version = "3.3.30" dependencies = [ "aes", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 53082db..1e06b7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,11 @@ [package] name = "telemt" -version = "3.3.29" +version = "3.3.30" edition = "2024" +[features] +redteam_offline_expected_fail = [] + [dependencies] # C libc = "0.2" diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index 33e5b29..e9d42a9 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -202,12 +202,15 @@ This document lists all configuration keys accepted by `config.toml`. | listen_tcp | `bool \| null` | `null` (auto) | — | Explicit TCP listener enable/disable override. | | proxy_protocol | `bool` | `false` | — | Enables HAProxy PROXY protocol parsing on incoming client connections. | | proxy_protocol_header_timeout_ms | `u64` | `500` | Must be `> 0`. | Timeout for PROXY protocol header read/parse (ms). | +| proxy_protocol_trusted_cidrs | `IpNetwork[]` | `[]` | — | When non-empty, only connections from these proxy source CIDRs are allowed to provide PROXY protocol headers. If empty, PROXY headers are rejected by default (security hardening). | | metrics_port | `u16 \| null` | `null` | — | Metrics endpoint port (enables metrics listener). | | metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. | | metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. | | max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). | | accept_permit_timeout_ms | `u64` | `250` | `0..=60000`. | Maximum wait for acquiring a connection-slot permit before the accepted connection is dropped (`0` keeps legacy unbounded wait). | +Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted. + ## [server.api] | Parameter | Type | Default | Constraints / validation | Description | @@ -271,6 +274,8 @@ This document lists all configuration keys accepted by `config.toml`. | mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. 
| Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | | mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | | mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. | +| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. | +| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. | | mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. | | mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. | | mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. | diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 66ffeda..09d146a 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -553,6 +553,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize { 512 } +#[cfg(not(test))] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 5 * 1024 * 1024 +} + +#[cfg(test)] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 32 * 1024 +} + +pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 { + 5 +} + pub(crate) fn default_mask_timing_normalization_enabled() -> bool { false } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index e580b7f..a3f795a 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -600,6 +600,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur || old.censorship.mask_shape_above_cap_blur_max_bytes != new.censorship.mask_shape_above_cap_blur_max_bytes + || old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes + || old.censorship.mask_classifier_prefetch_timeout_ms + != new.censorship.mask_classifier_prefetch_timeout_ms || old.censorship.mask_timing_normalization_enabled != new.censorship.mask_timing_normalization_enabled || old.censorship.mask_timing_normalization_floor_ms diff --git a/src/config/load.rs b/src/config/load.rs index bf6d036..fc54ec2 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -430,6 +430,25 @@ impl ProxyConfig { )); } + if config.censorship.mask_relay_max_bytes == 0 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be > 0".to_string(), + )); + } + + if config.censorship.mask_relay_max_bytes > 67_108_864 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be <= 67108864".to_string(), + )); + } + + if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) { + return Err(ProxyError::Config( + "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]" + .to_string(), + )); + } + if 
config.censorship.mask_timing_normalization_ceiling_ms < config.censorship.mask_timing_normalization_floor_ms { @@ -1134,6 +1153,10 @@ mod load_security_tests; #[path = "tests/load_mask_shape_security_tests.rs"] mod load_mask_shape_security_tests; +#[cfg(test)] +#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"] +mod load_mask_classifier_prefetch_timeout_security_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs new file mode 100644 index 0000000..49ee953 --- /dev/null +++ b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs @@ -0,0 +1,75 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir() + .join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 4 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout below minimum security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 51 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout above max security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 20 +"#, + ); + + let cfg = ProxyConfig::load(&path) + .expect("prefetch timeout within security bounds must be accepted"); + assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index 8986a49..2e4aa41 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8 remove_temp_config(&path); } + +#[test] +fn load_rejects_zero_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be > 0"), + "error must explain non-zero relay cap invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + 
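// Editor's note: illustrative sketch, not part of the patch. The bounds the
// tests around this point pin down are whole-MiB values: the default relay cap
// is 5 * 1024 * 1024 = 5_242_880 bytes (5 MiB) and the hard ceiling is
// 64 * 1024 * 1024 = 67_108_864 bytes (64 MiB). A standalone mirror of the
// load-time invariant (the helper name `mask_relay_cap_is_valid` is
// hypothetical; the real check lives in ProxyConfig::load above):
fn mask_relay_cap_is_valid(bytes: usize) -> bool {
    const MAX_MASK_RELAY_BYTES: usize = 64 * 1024 * 1024; // 67_108_864
    bytes > 0 && bytes <= MAX_MASK_RELAY_BYTES
}

#[test]
fn mask_relay_cap_bounds_sketch() {
    assert!(mask_relay_cap_is_valid(5 * 1024 * 1024)); // default (5 MiB)
    assert!(!mask_relay_cap_is_valid(0)); // rejected: zero cap
    assert!(!mask_relay_cap_is_valid(67_108_865)); // rejected: above ceiling
}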
+#[test]
+fn load_rejects_mask_relay_max_bytes_above_upper_bound() {
+    let path = write_temp_config(
+        r#"
+[censorship]
+mask_relay_max_bytes = 67108865
+"#,
+    );
+
+    let err = ProxyConfig::load(&path)
+        .expect_err("mask_relay_max_bytes above hard cap must be rejected");
+    let msg = err.to_string();
+    assert!(
+        msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"),
+        "error must explain relay cap upper bound invariant, got: {msg}"
+    );
+
+    remove_temp_config(&path);
+}
+
+#[test]
+fn load_accepts_valid_mask_relay_max_bytes() {
+    let path = write_temp_config(
+        r#"
+[censorship]
+mask_relay_max_bytes = 8388608
+"#,
+    );
+
+    let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted");
+    assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608);
+
+    remove_temp_config(&path);
+}
diff --git a/src/config/types.rs b/src/config/types.rs
index aa58dc1..5dc9719 100644
--- a/src/config/types.rs
+++ b/src/config/types.rs
@@ -1450,6 +1450,14 @@ pub struct AntiCensorshipConfig {
     #[serde(default = "default_mask_shape_above_cap_blur_max_bytes")]
     pub mask_shape_above_cap_blur_max_bytes: usize,

+    /// Maximum bytes relayed per direction on unauthenticated masking fallback paths.
+    #[serde(default = "default_mask_relay_max_bytes")]
+    pub mask_relay_max_bytes: usize,
+
+    /// Prefetch timeout (ms) for extending fragmented masking classifier window.
+    #[serde(default = "default_mask_classifier_prefetch_timeout_ms")]
+    pub mask_classifier_prefetch_timeout_ms: u64,
+
     /// Enable outcome-time normalization envelope for masking fallback.
     #[serde(default = "default_mask_timing_normalization_enabled")]
     pub mask_timing_normalization_enabled: bool,
@@ -1488,6 +1496,8 @@ impl Default for AntiCensorshipConfig {
             mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
             mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),
             mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(),
+            mask_relay_max_bytes: default_mask_relay_max_bytes(),
+            mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(),
             mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(),
             mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(),
             mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(),
diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs
index d553eb9..066c853 100644
--- a/src/maestro/runtime_tasks.rs
+++ b/src/maestro/runtime_tasks.rs
@@ -32,6 +32,14 @@ pub(crate) struct RuntimeWatches {
     pub(crate) detected_ip_v6: Option<IpAddr>,
 }

+const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60;
+
+fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> {
+    crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs(
+        QUOTA_USER_LOCK_EVICT_INTERVAL_SECS,
+    ))
+}
+
 #[allow(clippy::too_many_arguments)]
 pub(crate) async fn spawn_runtime_tasks(
     config: &Arc<ProxyConfig>,
@@ -69,6 +77,8 @@ pub(crate) async fn spawn_runtime_tasks(
         rc_clone.run_periodic_cleanup().await;
     });

+    spawn_quota_lock_maintenance_task();
+
     let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4);
     let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6);
     debug!(
@@ -360,3 +370,24 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc<StartupTracker>) {
         .await;
     startup_tracker.mark_ready().await;
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() {
+        crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests();
+
+        let handle = spawn_quota_lock_maintenance_task();
+        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
+
+        assert_eq!(
+            crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(),
+            1,
+            "runtime maintenance path must spawn exactly one quota lock evictor task per call"
+        );
+
+        handle.abort();
+    }
+}
diff --git a/src/proxy/client.rs b/src/proxy/client.rs
index 4b7f57e..a804a2c 100644
--- a/src/proxy/client.rs
+++ b/src/proxy/client.rs
@@ -186,6 +186,67 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
     }
 }

+const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16;
+#[cfg(test)]
+const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5);
+
+fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration {
+    Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms)
+}
+
+fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool {
+    if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW {
+        return false;
+    }
+
+    if initial_data.is_empty() {
+        // Empty initial_data means there is no client probe prefix to refine.
+        // Prefetching in this case can consume fallback relay payload bytes and
+        // accidentally route them through shaping heuristics.
+        return false;
+    }
+
+    if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") {
+        return false;
+    }
+
+    initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ')
+}
+
+#[cfg(test)]
+async fn extend_masking_initial_window<R>(reader: &mut R, initial_data: &mut Vec<u8>)
+where
+    R: AsyncRead + Unpin,
+{
+    extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT)
+        .await;
+}
+
+async fn extend_masking_initial_window_with_timeout<R>(
+    reader: &mut R,
+    initial_data: &mut Vec<u8>,
+    prefetch_timeout: Duration,
+)
+where
+    R: AsyncRead + Unpin,
+{
+    if !should_prefetch_mask_classifier_window(initial_data) {
+        return;
+    }
+
+    let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len());
+    if need == 0 {
+        return;
+    }
+
+    let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW];
+    if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
+        && n > 0
+    {
+        initial_data.extend_from_slice(&extra[..n]);
+    }
+}
+
 fn masking_outcome<R, W>(
     reader: R,
     writer: W,
@@ -200,6 +261,15 @@ where
     W: AsyncWrite + Unpin + Send + 'static,
 {
     HandshakeOutcome::NeedsMasking(Box::pin(async move {
+        let mut reader = reader;
+        let mut initial_data = initial_data;
+        extend_masking_initial_window_with_timeout(
+            &mut reader,
+            &mut initial_data,
+            mask_classifier_prefetch_timeout(&config),
+        )
+        .await;
+
         handle_bad_client(
             reader,
             writer,
@@ -1321,6 +1391,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests;
 #[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"]
 mod masking_probe_evasion_blackhat_tests;

+#[cfg(test)]
+#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"]
+mod masking_fragmented_classifier_security_tests;
+
+#[cfg(test)]
+#[path = "tests/client_masking_replay_timing_security_tests.rs"]
+mod masking_replay_timing_security_tests;
+
+#[cfg(test)]
+#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"]
+mod masking_http2_fragmented_preface_security_tests;
+
+#[cfg(test)]
+#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"]
+mod masking_prefetch_invariant_security_tests;
+
+#[cfg(test)]
+#[path = 
"tests/client_masking_prefetch_timing_matrix_security_tests.rs"] +mod masking_prefetch_timing_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"] +mod masking_prefetch_config_runtime_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"] +mod masking_prefetch_config_pipeline_integration_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"] +mod masking_prefetch_strict_boundary_security_tests; + #[cfg(test)] #[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] mod beobachten_ttl_bounds_security_tests; @@ -1328,3 +1430,15 @@ mod beobachten_ttl_bounds_security_tests; #[cfg(test)] #[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"] mod tls_record_wrap_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/client_clever_advanced_tests.rs"] +mod client_clever_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_more_advanced_tests.rs"] +mod client_more_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_deep_invariants_tests.rs"] +mod client_deep_invariants_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 5632977..96994c7 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -121,6 +121,19 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { hasher.finish() as usize } +fn auth_probe_scan_start_offset( + peer_ip: IpAddr, + now: Instant, + state_len: usize, + scan_limit: usize, +) -> usize { + if state_len == 0 || scan_limit == 0 { + return 0; + } + + auth_probe_eviction_offset(peer_ip, now) % state_len +} + fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); @@ -269,11 +282,7 @@ fn auth_probe_record_failure_with_state( let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let state_len = state.len(); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); - let start_offset = if state_len == 0 { - 0 - } else { - auth_probe_eviction_offset(peer_ip, now) % state_len - }; + let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); let mut scanned = 0usize; for entry in state.iter().skip(start_offset) { @@ -605,16 +614,6 @@ where } }; - // Replay tracking is applied only after successful authentication to avoid - // letting unauthenticated probes evict valid entries from the replay cache. - let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_and_add_tls_digest(digest_half) { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => { @@ -670,6 +669,16 @@ where None }; + // Replay tracking is applied only after full policy validation (including + // ALPN checks) so rejected handshakes cannot poison replay state. 
+ let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_and_add_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( secret, @@ -769,7 +778,7 @@ where let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); - let dec_key = sha256(&dec_key_input); + let dec_key = Zeroizing::new(sha256(&dec_key_input)); let mut dec_iv_arr = [0u8; IV_LEN]; dec_iv_arr.copy_from_slice(dec_iv_bytes); @@ -805,7 +814,7 @@ where let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); - let enc_key = sha256(&enc_key_input); + let enc_key = Zeroizing::new(sha256(&enc_key_input)); let mut enc_iv_arr = [0u8; IV_LEN]; enc_iv_arr.copy_from_slice(enc_iv_bytes); @@ -830,9 +839,9 @@ where user: user.clone(), dc_idx, proto_tag, - dec_key, + dec_key: *dec_key, dec_iv, - enc_key, + enc_key: *enc_key, enc_iv, peer, is_tls, @@ -979,6 +988,38 @@ mod saturation_poison_security_tests; #[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] mod auth_probe_hardening_adversarial_tests; +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"] +mod auth_probe_scan_budget_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"] +mod auth_probe_scan_offset_stress_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"] +mod auth_probe_eviction_bias_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_advanced_clever_tests.rs"] +mod advanced_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_more_clever_tests.rs"] +mod more_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_real_bug_stress_tests.rs"] +mod real_bug_stress_tests; + +#[cfg(test)] +#[path = "tests/handshake_timing_manual_bench_tests.rs"] +mod timing_manual_bench_tests; + +#[cfg(test)] +#[path = "tests/handshake_key_material_zeroization_security_tests.rs"] +mod handshake_key_material_zeroization_security_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. 
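Editor's note on the key-derivation change above: wrapping the sha256 output in `Zeroizing` means the temporary digest is wiped when the binding drops, and `*dec_key` / `*enc_key` copy the array into `HandshakeSuccess`, which handles its own zeroize-on-drop. A minimal standalone sketch of the pattern, assuming the `zeroize` and `sha2` crates (the `derive_key` helper and its arguments are hypothetical stand-ins for the prekey/secret inputs used in the patch; the repo's own `sha256` helper may differ):

use sha2::{Digest, Sha256};
use zeroize::Zeroizing;

fn derive_key(prekey: &[u8], secret: &[u8]) -> [u8; 32] {
    // Collect key input in a buffer that is zeroized when dropped.
    let mut input = Zeroizing::new(Vec::with_capacity(prekey.len() + secret.len()));
    input.extend_from_slice(prekey);
    input.extend_from_slice(secret);

    // Wrap the digest immediately so this stack copy is also wiped on drop.
    let key: Zeroizing<[u8; 32]> = Zeroizing::new(Sha256::digest(&*input).into());

    // `*key` copies the array out; the caller is expected to store it in a
    // struct that zeroizes on drop (as HandshakeSuccess does in this patch).
    *key
}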
diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs
index 3639db1..841749c 100644
--- a/src/proxy/masking.rs
+++ b/src/proxy/masking.rs
@@ -4,14 +4,23 @@ use crate::config::ProxyConfig;
 use crate::network::dns_overrides::resolve_socket_addr;
 use crate::stats::beobachten::BeobachtenStore;
 use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
-use rand::{Rng, RngExt};
-use std::net::SocketAddr;
+#[cfg(unix)]
+use nix::ifaddrs::getifaddrs;
+use rand::rngs::StdRng;
+use rand::{Rng, RngExt, SeedableRng};
+use std::net::{IpAddr, SocketAddr};
 use std::str;
-use std::time::Duration;
+#[cfg(unix)]
+use std::sync::{Mutex, OnceLock};
+#[cfg(test)]
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::time::{Duration, Instant as StdInstant};
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use tokio::net::TcpStream;
 #[cfg(unix)]
 use tokio::net::UnixStream;
+#[cfg(unix)]
+use tokio::sync::Mutex as AsyncMutex;
 use tokio::time::{Instant, timeout};
 use tracing::debug;
@@ -30,13 +39,23 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
 #[cfg(test)]
 const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
 const MASK_BUFFER_SIZE: usize = 8192;
+#[cfg(unix)]
+#[cfg(not(test))]
+const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300);
+#[cfg(all(unix, test))]
+const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1);

 struct CopyOutcome {
     total: usize,
     ended_by_eof: bool,
 }

-async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W) -> CopyOutcome
+async fn copy_with_idle_timeout<R, W>(
+    reader: &mut R,
+    writer: &mut W,
+    byte_cap: usize,
+    shutdown_on_eof: bool,
+) -> CopyOutcome
 where
     R: AsyncRead + Unpin,
     W: AsyncWrite + Unpin,
@@ -44,14 +63,31 @@ where
     let mut buf = [0u8; MASK_BUFFER_SIZE];
     let mut total = 0usize;
     let mut ended_by_eof = false;
+
+    if byte_cap == 0 {
+        return CopyOutcome {
+            total,
+            ended_by_eof,
+        };
+    }
+
     loop {
-        let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
+        let remaining_budget = byte_cap.saturating_sub(total);
+        if remaining_budget == 0 {
+            break;
+        }
+
+        let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
+        let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await;
         let n = match read_res {
             Ok(Ok(n)) => n,
             Ok(Err(_)) | Err(_) => break,
         };
         if n == 0 {
             ended_by_eof = true;
+            if shutdown_on_eof {
+                let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await;
+            }
             break;
         }
         total = total.saturating_add(n);
@@ -68,6 +104,39 @@ where
     }
 }

+fn is_http_probe(data: &[u8]) -> bool {
+    // RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ".
+    const HTTP_METHODS: [&[u8]; 10] = [
+        b"GET ",
+        b"POST",
+        b"HEAD",
+        b"PUT ",
+        b"DELETE",
+        b"OPTIONS",
+        b"CONNECT",
+        b"TRACE",
+        b"PATCH",
+        b"PRI ",
+    ];
+
+    if data.is_empty() {
+        return false;
+    }
+
+    let window = &data[..data.len().min(16)];
+    for method in HTTP_METHODS {
+        if data.len() >= method.len() && window.starts_with(method) {
+            return true;
+        }
+
+        if (2..=3).contains(&window.len()) && method.starts_with(window) {
+            return true;
+        }
+    }
+
+    false
+}
+
 fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize {
     if total == 0 || floor == 0 || cap < floor {
         return total;
@@ -125,6 +194,11 @@ async fn maybe_write_shape_padding(
     let mut remaining = target_total - total_sent;
     let mut pad_chunk = [0u8; 1024];
     let deadline = Instant::now() + MASK_TIMEOUT;
+    // Use a Send RNG so relay futures remain spawn-safe under Tokio.
+    let mut rng = {
+        let mut seed_source = rand::rng();
+        StdRng::from_rng(&mut seed_source)
+    };

     while remaining > 0 {
         let now = Instant::now();
@@ -133,10 +207,7 @@
         }
         let write_len = remaining.min(pad_chunk.len());
-        {
-            let mut rng = rand::rng();
-            rng.fill_bytes(&mut pad_chunk[..write_len]);
-        }
+        rng.fill_bytes(&mut pad_chunk[..write_len]);
         let write_budget = deadline.saturating_duration_since(now);
         match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await {
             Ok(Ok(())) => {}
@@ -167,11 +238,11 @@ where
     }
 }

-async fn consume_client_data_with_timeout<R>(reader: R)
+async fn consume_client_data_with_timeout_and_cap<R>(reader: R, byte_cap: usize)
 where
     R: AsyncRead + Unpin,
 {
-    if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader))
+    if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap))
         .await
         .is_err()
     {
@@ -190,6 +261,9 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
     if config.censorship.mask_timing_normalization_enabled {
         let floor = config.censorship.mask_timing_normalization_floor_ms;
         let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
+        if floor == 0 {
+            return MASK_TIMEOUT;
+        }
         if ceiling > floor {
             let mut rng = rand::rng();
             return Duration::from_millis(rng.random_range(floor..=ceiling));
@@ -219,14 +293,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
 /// Detect client type based on initial data
 fn detect_client_type(data: &[u8]) -> &'static str {
     // Check for HTTP request
-    if data.len() > 4
-        && (data.starts_with(b"GET ")
-            || data.starts_with(b"POST")
-            || data.starts_with(b"HEAD")
-            || data.starts_with(b"PUT ")
-            || data.starts_with(b"DELETE")
-            || data.starts_with(b"OPTIONS"))
-    {
+    if is_http_probe(data) {
         return "HTTP";
     }
@@ -248,6 +315,244 @@ fn detect_client_type(data: &[u8]) -> &'static str {
     "unknown"
 }

+fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
+    if host.starts_with('[') && host.ends_with(']') {
+        return host[1..host.len() - 1].parse::<IpAddr>().ok();
+    }
+    host.parse::<IpAddr>().ok()
+}
+
+fn canonical_ip(ip: IpAddr) -> IpAddr {
+    match ip {
+        IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)),
+        IpAddr::V4(v4) => IpAddr::V4(v4),
+    }
+}
+
+#[cfg(unix)]
+fn collect_local_interface_ips() -> Vec<IpAddr> {
+    #[cfg(test)]
+    LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed);
+
+    let mut out = Vec::new();
+    if let Ok(addrs) = getifaddrs() {
+        for iface in addrs {
+            if let Some(address) = iface.address {
+                if let Some(v4) = address.as_sockaddr_in() {
+                    out.push(canonical_ip(IpAddr::V4(v4.ip())));
+                } else if let Some(v6) = address.as_sockaddr_in6() {
+                    out.push(canonical_ip(IpAddr::V6(v6.ip())));
+                }
+            }
+        }
+    }
+    out
+}
+
+fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec<IpAddr>) -> Vec<IpAddr> {
+    if refreshed.is_empty() && !previous.is_empty() {
+        return previous.to_vec();
+    }
+
+    refreshed
+}
+
+#[cfg(unix)]
+#[derive(Default)]
+struct LocalInterfaceCache {
+    ips: Vec<IpAddr>,
+    refreshed_at: Option<StdInstant>,
+}
+
+#[cfg(unix)]
+static LOCAL_INTERFACE_CACHE: OnceLock<Mutex<LocalInterfaceCache>> = OnceLock::new();
+
+#[cfg(unix)]
+static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock<AsyncMutex<()>> = OnceLock::new();
+
+#[cfg(all(unix, test))]
+fn local_interface_ips() -> Vec<IpAddr> {
+    let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
+    let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());

+    let stale = guard
+        .refreshed_at
+        .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+    if stale {
+        let refreshed = collect_local_interface_ips();
+        guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
+        guard.refreshed_at = Some(StdInstant::now());
+    }
+
+    guard.ips.clone()
+}
+
+#[cfg(unix)]
+async fn local_interface_ips_async() -> Vec<IpAddr> {
+    let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
+
+    {
+        let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+        let stale = guard
+            .refreshed_at
+            .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+        if !stale {
+            return guard.ips.clone();
+        }
+    }
+
+    let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
+    let _refresh_guard = refresh_lock.lock().await;
+
+    {
+        let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+        let stale = guard
+            .refreshed_at
+            .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+        if !stale {
+            return guard.ips.clone();
+        }
+    }
+
+    let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips)
+        .await
+        .unwrap_or_default();
+
+    let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+    let stale = guard
+        .refreshed_at
+        .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+    if stale {
+        guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
+        guard.refreshed_at = Some(StdInstant::now());
+    }
+
+    guard.ips.clone()
+}
+
+#[cfg(all(not(unix), test))]
+fn local_interface_ips() -> Vec<IpAddr> {
+    Vec::new()
+}
+
+#[cfg(not(unix))]
+async fn local_interface_ips_async() -> Vec<IpAddr> {
+    Vec::new()
+}
+
+#[cfg(test)]
+static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0);
+
+#[cfg(test)]
+fn reset_local_interface_enumerations_for_tests() {
+    LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed);
+
+    #[cfg(unix)]
+    if let Some(cache) = LOCAL_INTERFACE_CACHE.get() {
+        let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+        guard.ips.clear();
+        guard.refreshed_at = None;
+    }
+}
+
+#[cfg(test)]
+fn local_interface_enumerations_for_tests() -> usize {
+    LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed)
+}
+
+fn is_mask_target_local_listener_with_interfaces(
+    mask_host: &str,
+    mask_port: u16,
+    local_addr: SocketAddr,
+    resolved_override: Option<SocketAddr>,
+    interface_ips: &[IpAddr],
+) -> bool {
+    if mask_port != local_addr.port() {
+        return false;
+    }
+
+    let local_ip = canonical_ip(local_addr.ip());
+    let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
+
+    if let Some(addr) = resolved_override {
+        let resolved_ip = canonical_ip(addr.ip());
+        if resolved_ip == local_ip {
+            return true;
+        }
+
+        if local_ip.is_unspecified()
+            && (resolved_ip.is_loopback()
+                || resolved_ip.is_unspecified()
+                || interface_ips.contains(&resolved_ip))
+        {
+            return true;
+        }
+    }
+
+    if let Some(mask_ip) = literal_mask_ip {
+        if mask_ip == local_ip {
+            return true;
+        }
+
+        if local_ip.is_unspecified()
+            && (mask_ip.is_loopback()
+                || mask_ip.is_unspecified()
+                || interface_ips.contains(&mask_ip))
+        {
+            return true;
+        }
+    }
+
+    false
+}
+
+#[cfg(test)]
+fn is_mask_target_local_listener(
+    mask_host: &str,
+    mask_port: u16,
+    local_addr: SocketAddr,
+    resolved_override: Option<SocketAddr>,
+) -> bool {
+    if mask_port != local_addr.port() {
+        return false;
+    }
+
+    let interfaces = local_interface_ips();
+    is_mask_target_local_listener_with_interfaces(
+        mask_host,
+        mask_port,
+        local_addr,
+        resolved_override,
+        &interfaces,
+    )
+}
+
+async fn is_mask_target_local_listener_async(
+    mask_host: &str,
+    mask_port: u16,
+    local_addr: SocketAddr,
+    resolved_override: Option<SocketAddr>,
+) -> bool {
+    if mask_port != local_addr.port() {
+        return false;
+    }
+
+    let interfaces = local_interface_ips_async().await;
+    is_mask_target_local_listener_with_interfaces(
+        mask_host,
+        mask_port,
+        local_addr,
+        resolved_override,
+        &interfaces,
+    )
+}
+
+fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration {
+    let minutes = config.general.beobachten_minutes;
+    let clamped = minutes.clamp(1, 24 * 60);
+    Duration::from_secs(clamped.saturating_mul(60))
+}
+
 fn build_mask_proxy_header(
     version: u8,
     peer: SocketAddr,
@@ -290,13 +595,14 @@ pub async fn handle_bad_client(
 {
     let client_type = detect_client_type(initial_data);
     if config.general.beobachten {
-        let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60));
+        let ttl = masking_beobachten_ttl(config);
         beobachten.record(client_type, peer.ip(), ttl);
     }

     if !config.censorship.mask {
         // Masking disabled, just consume data
-        consume_client_data_with_timeout(reader).await;
+        consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
+            .await;
         return;
     }
@@ -341,6 +647,7 @@ pub async fn handle_bad_client(
                     config.censorship.mask_shape_above_cap_blur,
                     config.censorship.mask_shape_above_cap_blur_max_bytes,
                     config.censorship.mask_shape_hardening_aggressive_mode,
+                    config.censorship.mask_relay_max_bytes,
                 ),
             )
             .await
@@ -353,12 +660,12 @@ pub async fn handle_bad_client(
             Ok(Err(e)) => {
                 wait_mask_connect_budget_if_needed(connect_started, config).await;
                 debug!(error = %e, "Failed to connect to mask unix socket");
-                consume_client_data_with_timeout(reader).await;
+                consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
                 wait_mask_outcome_budget(outcome_started, config).await;
             }
             Err(_) => {
                 debug!("Timeout connecting to mask unix socket");
-                consume_client_data_with_timeout(reader).await;
+                consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
                 wait_mask_outcome_budget(outcome_started, config).await;
             }
         }
@@ -372,6 +679,28 @@ pub async fn handle_bad_client(
         .unwrap_or(&config.censorship.tls_domain);
     let mask_port = config.censorship.mask_port;

+    // Fail closed when fallback points at our own listener endpoint.
+    // Self-referential masking can create recursive proxy loops under
+    // misconfiguration and leak distinguishable load spikes to adversaries.
+    let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
+    if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
+        .await
+    {
+        let outcome_started = Instant::now();
+        debug!(
+            client_type = client_type,
+            host = %mask_host,
+            port = mask_port,
+            local = %local_addr,
+            "Mask target resolves to local listener; refusing self-referential masking fallback"
+        );
+        consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
+        wait_mask_outcome_budget(outcome_started, config).await;
+        return;
+    }
+
+    let outcome_started = Instant::now();
+
     debug!(
         client_type = client_type,
         host = %mask_host,
@@ -381,10 +710,9 @@ pub async fn handle_bad_client(
     );

     // Apply runtime DNS override for mask target when configured.
-    let mask_addr = resolve_socket_addr(mask_host, mask_port)
+    let mask_addr = resolved_mask_addr
         .map(|addr| addr.to_string())
         .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
-    let outcome_started = Instant::now();
     let connect_started = Instant::now();
     let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
     match connect_result {
@@ -413,6 +741,7 @@ pub async fn handle_bad_client(
                     config.censorship.mask_shape_above_cap_blur,
                     config.censorship.mask_shape_above_cap_blur_max_bytes,
                     config.censorship.mask_shape_hardening_aggressive_mode,
+                    config.censorship.mask_relay_max_bytes,
                 ),
             )
             .await
@@ -425,12 +754,12 @@ pub async fn handle_bad_client(
         Ok(Err(e)) => {
             wait_mask_connect_budget_if_needed(connect_started, config).await;
             debug!(error = %e, "Failed to connect to mask host");
-            consume_client_data_with_timeout(reader).await;
+            consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
             wait_mask_outcome_budget(outcome_started, config).await;
         }
         Err(_) => {
             debug!("Timeout connecting to mask host");
-            consume_client_data_with_timeout(reader).await;
+            consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
             wait_mask_outcome_budget(outcome_started, config).await;
         }
     }
@@ -449,6 +778,7 @@ async fn relay_to_mask<R, W>(
     shape_above_cap_blur: bool,
     shape_above_cap_blur_max_bytes: usize,
     shape_hardening_aggressive_mode: bool,
+    mask_relay_max_bytes: usize,
 ) where
     R: AsyncRead + Unpin + Send + 'static,
     W: AsyncWrite + Unpin + Send + 'static,
@@ -464,8 +794,18 @@ async fn relay_to_mask<R, W>(
     }

     let (upstream_copy, downstream_copy) = tokio::join!(
-        async { copy_with_idle_timeout(&mut reader, &mut mask_write).await },
-        async { copy_with_idle_timeout(&mut mask_read, &mut writer).await }
+        async {
+            copy_with_idle_timeout(
+                &mut reader,
+                &mut mask_write,
+                mask_relay_max_bytes,
+                !shape_hardening_enabled,
+            )
+            .await
+        },
+        async {
+            copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await
+        }
     );

     let total_sent = initial_data.len().saturating_add(upstream_copy.total);
@@ -491,13 +831,36 @@ async fn relay_to_mask<R, W>(
         let _ = writer.shutdown().await;
     }

-/// Just consume all data from client without responding
-async fn consume_client_data<R>(mut reader: R) {
-    let mut buf = vec![0u8; MASK_BUFFER_SIZE];
-    while let Ok(n) = reader.read(&mut buf).await {
+/// Just consume all data from client without responding.
+async fn consume_client_data<R>(mut reader: R, byte_cap: usize) {
+    if byte_cap == 0 {
+        return;
+    }
+
+    // Keep drain path fail-closed under slow-loris stalls.
+ let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut total = 0usize; + + loop { + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { break; } + + total = total.saturating_add(n); + if total >= byte_cap { + break; + } } } @@ -521,6 +884,10 @@ mod masking_shape_above_cap_blur_security_tests; #[path = "tests/masking_timing_normalization_security_tests.rs"] mod masking_timing_normalization_security_tests; +#[cfg(test)] +#[path = "tests/masking_timing_budget_coupling_security_tests.rs"] +mod masking_timing_budget_coupling_security_tests; + #[cfg(test)] #[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"] mod masking_ab_envelope_blur_integration_security_tests; @@ -548,3 +915,75 @@ mod masking_aggressive_mode_security_tests; #[cfg(test)] #[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] mod masking_timing_sidechannel_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/masking_self_target_loop_security_tests.rs"] +mod masking_self_target_loop_security_tests; + +#[cfg(test)] +#[path = "tests/masking_classification_completeness_security_tests.rs"] +mod masking_classification_completeness_security_tests; + +#[cfg(test)] +#[path = "tests/masking_relay_guardrails_security_tests.rs"] +mod masking_relay_guardrails_security_tests; + +#[cfg(test)] +#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"] +mod masking_connect_failure_close_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/masking_additional_hardening_security_tests.rs"] +mod masking_additional_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_idle_timeout_security_tests.rs"] +mod masking_consume_idle_timeout_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_probe_classification_security_tests.rs"] +mod masking_http2_probe_classification_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http_probe_boundary_security_tests.rs"] +mod masking_http_probe_boundary_security_tests; + +#[cfg(test)] +#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"] +mod masking_rng_hoist_perf_regression_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_preface_integration_security_tests.rs"] +mod masking_http2_preface_integration_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_stress_adversarial_tests.rs"] +mod masking_consume_stress_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_security_tests.rs"] +mod masking_interface_cache_security_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"] +mod masking_interface_cache_defense_in_depth_security_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"] +mod masking_interface_cache_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/masking_production_cap_regression_security_tests.rs"] +mod masking_production_cap_regression_security_tests; + +#[cfg(test)] +#[path = "tests/masking_extended_attack_surface_security_tests.rs"] +mod masking_extended_attack_surface_security_tests; + +#[cfg(test)] +#[path = "tests/masking_padding_timeout_adversarial_tests.rs"] +mod masking_padding_timeout_adversarial_tests; + +#[cfg(all(test, feature = "redteam_offline_expected_fail"))] +#[path = 
"tests/masking_offline_target_redteam_expected_fail_tests.rs"] +mod masking_offline_target_redteam_expected_fail_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index d0f5ffb..14ea001 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,5 +1,7 @@ use std::collections::hash_map::RandomState; use std::collections::{BTreeSet, HashMap}; +#[cfg(test)] +use std::future::Future; use std::hash::{BuildHasher, Hash}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; @@ -39,10 +41,14 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); +const TINY_FRAME_DEBT_PER_TINY: u32 = 8; +const TINY_FRAME_DEBT_LIMIT: u32 = 512; #[cfg(test)] const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); #[cfg(not(test))] const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1); const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; @@ -94,10 +100,23 @@ fn relay_idle_candidate_registry() -> &'static Mutex RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) } +fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> { + let registry = relay_idle_candidate_registry(); + match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + // Fail closed after panic while holding registry lock: drop all + // candidates and pressure cursors to avoid stale cross-session state. 
+            *guard = RelayIdleCandidateRegistry::default();
+            registry.clear_poison();
+            guard
+        }
+    }
+}
+
 fn mark_relay_idle_candidate(conn_id: u64) -> bool {
-    let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
-        return false;
-    };
+    let mut guard = relay_idle_candidate_registry_lock();

     if guard.by_conn_id.contains_key(&conn_id) {
         return false;
@@ -116,9 +135,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool {
 }

 fn clear_relay_idle_candidate(conn_id: u64) {
-    let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
-        return;
-    };
+    let mut guard = relay_idle_candidate_registry_lock();

     if let Some(meta) = guard.by_conn_id.remove(&conn_id) {
         guard.ordered.remove(&(meta.mark_order_seq, conn_id));
@@ -127,23 +144,17 @@ fn clear_relay_idle_candidate(conn_id: u64) {

 #[cfg(test)]
 fn oldest_relay_idle_candidate() -> Option<u64> {
-    let Ok(guard) = relay_idle_candidate_registry().lock() else {
-        return None;
-    };
+    let guard = relay_idle_candidate_registry_lock();
     guard.ordered.iter().next().map(|(_, conn_id)| *conn_id)
 }

 fn note_relay_pressure_event() {
-    let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
-        return;
-    };
+    let mut guard = relay_idle_candidate_registry_lock();
     guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1);
 }

 fn relay_pressure_event_seq() -> u64 {
-    let Ok(guard) = relay_idle_candidate_registry().lock() else {
-        return 0;
-    };
+    let guard = relay_idle_candidate_registry_lock();
     guard.pressure_event_seq
 }

@@ -152,9 +163,7 @@ fn maybe_evict_idle_candidate_on_pressure(
     seen_pressure_seq: &mut u64,
     stats: &Stats,
 ) -> bool {
-    let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
-        return false;
-    };
+    let mut guard = relay_idle_candidate_registry_lock();

     let latest_pressure_seq = guard.pressure_event_seq;
     if latest_pressure_seq == *seen_pressure_seq {
@@ -199,13 +208,9 @@ fn maybe_evict_idle_candidate_on_pressure(

 #[cfg(test)]
 fn clear_relay_idle_pressure_state_for_testing() {
-    if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get()
-        && let Ok(mut guard) = registry.lock()
-    {
-        guard.by_conn_id.clear();
-        guard.ordered.clear();
-        guard.pressure_event_seq = 0;
-        guard.pressure_consumed_seq = 0;
+    if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() {
+        let mut guard = relay_idle_candidate_registry_lock();
+        *guard = RelayIdleCandidateRegistry::default();
     }
     RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed);
 }
@@ -259,6 +264,7 @@ impl RelayClientIdlePolicy {
 struct RelayClientIdleState {
     last_client_frame_at: Instant,
     soft_idle_marked: bool,
+    tiny_frame_debt: u32,
 }

 impl RelayClientIdleState {
@@ -266,6 +272,7 @@ impl RelayClientIdleState {
         Self {
             last_client_frame_at: now,
             soft_idle_marked: false,
+            tiny_frame_debt: 0,
         }
     }

@@ -552,15 +559,6 @@ fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
     limit.saturating_add(overshoot)
 }

-fn quota_exceeded_for_user_soft(
-    stats: &Stats,
-    user: &str,
-    quota_limit: Option<u64>,
-    overshoot: u64,
-) -> bool {
-    quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot))
-}
-
 fn quota_would_be_exceeded_for_user_soft(
     stats: &Stats,
     user: &str,
     quota_limit: Option<u64>,
     bytes: u64,
     overshoot: u64,
 ) -> bool {
-    quota_limit.is_some_and(|quota| {
-        let cap = quota_soft_cap(quota, overshoot);
-        let used = stats.get_user_total_octets(user);
-        used >= cap || bytes > cap.saturating_sub(used)
-    })
+    let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot));
+    quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes)
 }

 fn classify_me_d2c_flush_reason(
@@ -618,6 +613,16 @@ fn observe_me_d2c_flush_event(
     }
 }

+fn rollback_me2c_quota_reservation(
+    stats: &Stats,
+    user: &str,
+    bytes_me2c: &AtomicU64,
+    reserved_bytes: u64,
+) {
+    stats.sub_user_octets_to(user, reserved_bytes);
+    bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed);
+}
+
 #[cfg(test)]
 fn quota_user_lock_test_guard() -> &'static Mutex<()> {
     static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@@ -631,6 +636,19 @@ fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
         .unwrap_or_else(|poisoned| poisoned.into_inner())
 }

+#[cfg(test)]
+fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
+    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+    TEST_LOCK.get_or_init(|| Mutex::new(()))
+}
+
+#[cfg(test)]
+pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> {
+    relay_idle_pressure_test_guard()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
 fn quota_overflow_user_lock(user: &str) -> Arc<tokio::sync::Mutex<()>> {
     let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
         (0..QUOTA_OVERFLOW_LOCK_STRIPES)
@@ -666,6 +684,11 @@ fn quota_user_lock(user: &str) -> Arc<tokio::sync::Mutex<()>> {
     }
 }

+#[cfg(test)]
+pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<tokio::sync::Mutex<()>> {
+    crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
+}
+
 async fn enqueue_c2me_command(
     tx: &mpsc::Sender<C2MeCommand>,
     cmd: C2MeCommand,
@@ -691,6 +714,16 @@ async fn enqueue_c2me_command(
     }
 }

+#[cfg(test)]
+async fn run_relay_test_step_timeout<T, F>(context: &'static str, fut: F) -> T
+where
+    F: Future<Output = T>,
+{
+    timeout(RELAY_TEST_STEP_TIMEOUT, fut)
+        .await
+        .unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs()))
+}
+
 pub(crate) async fn handle_via_middle_proxy<R, W>(
     mut crypto_reader: CryptoReader<R>,
     crypto_writer: CryptoWriter<W>,
@@ -711,6 +744,8 @@ where
 {
     let user = success.user.clone();
     let quota_limit = config.access.user_data_quota.get(&user).copied();
+    let cross_mode_quota_lock =
+        quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
     let peer = success.peer;
     let proto_tag = success.proto_tag;
     let pool_generation = me_pool.current_generation();
@@ -837,6 +872,7 @@ where
     let stats_clone = stats.clone();
     let rng_clone = rng.clone();
     let user_clone = user.clone();
+    let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone();
     let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
     let bytes_me2c_clone = bytes_me2c.clone();
     let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
@@ -858,7 +894,7 @@ where
                 let first_is_downstream_activity =
                     matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));

-                match process_me_writer_response(
+                match process_me_writer_response_with_cross_mode_lock(
                     first,
                     &mut writer,
                     proto_tag,
@@ -868,6 +904,7 @@ where
                     &user_clone,
                     quota_limit,
                     d2c_flush_policy.quota_soft_overshoot_bytes,
+                    cross_mode_quota_lock_me_writer.as_ref(),
                     bytes_me2c_clone.as_ref(),
                     conn_id,
                     d2c_flush_policy.ack_flush_immediate,
@@ -916,7 +953,7 @@ where
                         let next_is_downstream_activity =
                             matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
-                        match process_me_writer_response(
+                        match process_me_writer_response_with_cross_mode_lock(
                            next,
                            &mut writer,
                            proto_tag,
@@ -926,6 +963,7 @@ where
                            &user_clone,
                            quota_limit,
                            d2c_flush_policy.quota_soft_overshoot_bytes,
+                           cross_mode_quota_lock_me_writer.as_ref(),
                            bytes_me2c_clone.as_ref(),
                            conn_id,
                            d2c_flush_policy.ack_flush_immediate,
@@ -977,7 +1015,7 @@ where
                     Ok(Some(next)) => {
                         let next_is_downstream_activity =
                             matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
-                        match process_me_writer_response(
+                        match process_me_writer_response_with_cross_mode_lock(
                             next,
                             &mut writer,
                             proto_tag,
@@ -987,6 +1025,7 @@ where
                             &user_clone,
                             quota_limit,
                             d2c_flush_policy.quota_soft_overshoot_bytes,
+                            cross_mode_quota_lock_me_writer.as_ref(),
                             bytes_me2c_clone.as_ref(),
                             conn_id,
                             d2c_flush_policy.ack_flush_immediate,
@@ -1040,7 +1079,7 @@ where
                         let extra_is_downstream_activity =
                             matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));

-                        match process_me_writer_response(
+                        match process_me_writer_response_with_cross_mode_lock(
                             extra,
                             &mut writer,
                             proto_tag,
@@ -1050,6 +1089,7 @@ where
                             &user_clone,
                             quota_limit,
                             d2c_flush_policy.quota_soft_overshoot_bytes,
+                            cross_mode_quota_lock_me_writer.as_ref(),
                             bytes_me2c_clone.as_ref(),
                             conn_id,
                             d2c_flush_policy.ack_flush_immediate,
@@ -1222,6 +1262,14 @@ where
             if let Some(limit) = quota_limit {
                 let quota_lock = quota_user_lock(&user);
                 let _quota_guard = quota_lock.lock().await;
+                let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else {
+                    main_result = Err(ProxyError::Proxy(
+                        "cross-mode quota lock missing for quota-limited session"
+                            .to_string(),
+                    ));
+                    break;
+                };
+                let _cross_mode_quota_guard = cross_mode_lock.lock().await;
                 stats.add_user_octets_from(&user, payload.len() as u64);
                 if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
                     main_result = Err(ProxyError::DataQuotaExceeded {
@@ -1321,6 +1369,8 @@ async fn read_client_payload_with_idle_policy<R>(
 where
     R: AsyncRead + Unpin + Send + 'static,
 {
+    const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4;
+
     async fn read_exact_with_policy<R>(
         client_reader: &mut CryptoReader<R>,
         buf: &mut [u8],
@@ -1459,6 +1509,7 @@ where
         Ok(())
     }

+    let mut consecutive_zero_len_frames = 0u32;
     loop {
         let (len, quickack, raw_len_bytes) = match proto_tag {
             ProtoTag::Abridged => {
@@ -1539,6 +1590,27 @@ where
         };

         if len == 0 {
+            idle_state.tiny_frame_debt = idle_state
+                .tiny_frame_debt
+                .saturating_add(TINY_FRAME_DEBT_PER_TINY);
+            if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT {
+                stats.increment_relay_protocol_desync_close_total();
+                return Err(ProxyError::Proxy(format!(
+                    "Tiny frame overhead limit exceeded: debt={}, conn_id={}",
+                    idle_state.tiny_frame_debt, forensics.conn_id
+                )));
+            }
+
+            if !idle_policy.enabled {
+                consecutive_zero_len_frames =
+                    consecutive_zero_len_frames.saturating_add(1);
+                if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES {
+                    stats.increment_relay_protocol_desync_close_total();
+                    return Err(ProxyError::Proxy(
+                        "Excessive zero-length abridged frames".to_string(),
+                    ));
+                }
+            }
             continue;
         }
         if len < 4 && proto_tag != ProtoTag::Abridged {
@@ -1607,6 +1679,7 @@ where
         }
         *frame_counter += 1;
         idle_state.on_client_frame(Instant::now());
+        idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1);
         clear_relay_idle_candidate(forensics.conn_id);
         return Ok(Some((payload, quickack)));
     }
@@ -1682,6 +1755,7 @@ enum MeWriterResponseOutcome {
     Close,
 }

+#[cfg(test)]
 async fn process_me_writer_response<W>(
     response: MeResponse,
client_writer: &mut CryptoWriter, @@ -1697,6 +1771,44 @@ async fn process_me_writer_response( ack_flush_immediate: bool, batched: bool, ) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + process_me_writer_response_with_cross_mode_lock( + response, + client_writer, + proto_tag, + rng, + frame_buf, + stats, + user, + quota_limit, + quota_soft_overshoot_bytes, + None, + bytes_me2c, + conn_id, + ack_flush_immediate, + batched, + ) + .await +} + +async fn process_me_writer_response_with_cross_mode_lock( + response: MeResponse, + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + rng: &SecureRandom, + frame_buf: &mut Vec, + stats: &Stats, + user: &str, + quota_limit: Option, + quota_soft_overshoot_bytes: u64, + cross_mode_quota_lock: Option<&Arc>>, + bytes_me2c: &AtomicU64, + conn_id: u64, + ack_flush_immediate: bool, + batched: bool, +) -> Result where W: AsyncWrite + Unpin + Send + 'static, { @@ -1708,39 +1820,76 @@ where trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } let data_len = data.len() as u64; - if quota_would_be_exceeded_for_user_soft( - stats, - user, - quota_limit, - data_len, - quota_soft_overshoot_bytes, - ) { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); - } + if let Some(limit) = quota_limit { + let owned_cross_mode_lock; + let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock { + lock + } else { + owned_cross_mode_lock = + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user); + &owned_cross_mode_lock + }; + let cross_mode_quota_guard = cross_mode_lock.lock().await; + let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); + if quota_would_be_exceeded_for_user_soft( + stats, + user, + Some(limit), + data_len, + quota_soft_overshoot_bytes, + ) { + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } - let write_mode = - write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await?; - stats.increment_me_d2c_write_mode(write_mode); + // Reserve quota before awaiting network I/O to avoid same-user HoL stalls. + // If reservation loses a race or write fails, we roll back immediately. + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + stats.add_user_octets_to(user, data_len); - bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); - stats.add_user_octets_to(user, data.len() as u64); - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data.len() as u64); + if stats.get_user_total_octets(user) > soft_limit { + rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } - if quota_exceeded_for_user_soft( - stats, - user, - quota_limit, - quota_soft_overshoot_bytes, - ) { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); + // Keep cross-mode lock scope explicit and minimal: quota reservation is serialized, + // but socket I/O proceeds without holding same-user cross-mode admission lock. 
+ drop(cross_mode_quota_guard); + + let write_mode = + match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await + { + Ok(mode) => mode, + Err(err) => { + rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); + return Err(err); + } + }; + + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); + + // Do not fail immediately on exact boundary after a successful write. + // Returning an error here can bypass batch flush in the caller and risk + // dropping buffered ciphertext from CryptoWriter. The next frame is + // rejected by the pre-check at function entry. + } else { + let write_mode = + write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await?; + + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + stats.add_user_octets_to(user, data_len); + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); } Ok(MeWriterResponseOutcome::Continue { @@ -1979,3 +2128,55 @@ mod length_cast_hardening_security_tests; #[cfg(test)] #[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] mod blackhat_campaign_integration_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_hol_quota_security_tests.rs"] +mod hol_quota_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"] +mod quota_reservation_adversarial_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"] +mod middle_relay_idle_registry_poison_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"] +mod middle_relay_zero_length_frame_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"] +mod middle_relay_tiny_frame_debt_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"] +mod middle_relay_tiny_frame_debt_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] +mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"] +mod middle_relay_cross_mode_quota_reservation_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"] +mod middle_relay_cross_mode_quota_lock_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"] +mod middle_relay_cross_mode_lookup_efficiency_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"] +mod middle_relay_cross_mode_lock_release_regression_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"] +mod middle_relay_quota_extended_attack_surface_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"] +mod middle_relay_quota_reservation_extreme_security_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index eebc188..519f1b3 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -64,6 +64,7 @@ pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; +pub mod quota_lock_registry; pub mod relay; pub mod route_mode; pub mod session_eviction; diff --git 
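Review note on the ME->C data path above: the previous pre-write/post-write check pair is replaced by reserve-then-rollback accounting, so the per-user counter is charged under the cross-mode lock while the socket write itself runs outside it, and a failed write or a lost reservation race un-charges exactly the reserved amount. A minimal self-contained sketch of that invariant; UserOctets here is a hypothetical stand-in for the real Stats accounting, not the crate's API:

use std::sync::atomic::{AtomicU64, Ordering};

// Hypothetical stand-in for the per-user byte accounting behind `Stats`.
struct UserOctets {
    total: AtomicU64,
}

impl UserOctets {
    fn add(&self, n: u64) { self.total.fetch_add(n, Ordering::Relaxed); }
    fn sub(&self, n: u64) { self.total.fetch_sub(n, Ordering::Relaxed); }
    fn get(&self) -> u64 { self.total.load(Ordering::Relaxed) }
}

/// Reserve `len` bytes against a soft cap before a (slow) socket write.
/// A reservation that overshoots the cap is rolled back immediately, so
/// the counter never leaks bytes that were not actually delivered.
fn reserve(octets: &UserOctets, len: u64, soft_cap: u64) -> Result<(), ()> {
    octets.add(len);
    if octets.get() > soft_cap {
        octets.sub(len); // roll back the losing reservation
        return Err(());
    }
    Ok(())
}

fn main() {
    let octets = UserOctets { total: AtomicU64::new(90) };
    assert!(reserve(&octets, 5, 100).is_ok());   // 95 <= 100, reservation holds
    assert!(reserve(&octets, 10, 100).is_err()); // 105 > 100, rolled back
    assert_eq!(octets.get(), 95);                // accounting is intact
}

The key property is that reserve either leaves the counter charged for bytes that will actually be written, or restores it exactly; the real code applies the same rollback after a failed write_client_payload.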
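Review note: the zero-length-frame branch above implements a leaky bucket. Each tiny frame adds TINY_FRAME_DEBT_PER_TINY to a debt counter and each useful payload frame repays one unit, so keepalive-style traffic stays below TINY_FRAME_DEBT_LIMIT while a pure overhead flood trips it quickly. A compact model of the scheme; the constant values below are assumptions for illustration only (the real ones are defined elsewhere in middle_relay.rs):

// Hypothetical constants; chosen only to make the flood example concrete.
const DEBT_PER_TINY: u32 = 3;
const DEBT_LIMIT: u32 = 12;

#[derive(Default)]
struct TinyFrameDebt(u32);

impl TinyFrameDebt {
    /// Zero-length frame: accrue debt. Closing at the limit bounds the
    /// ratio of overhead frames an attacker can push through the relay.
    fn on_tiny_frame(&mut self) -> Result<(), &'static str> {
        self.0 = self.0.saturating_add(DEBT_PER_TINY);
        if self.0 >= DEBT_LIMIT {
            return Err("tiny frame overhead limit exceeded");
        }
        Ok(())
    }

    /// Useful payload frame: repay one unit, so honest traffic that
    /// interleaves keepalives with data never drifts toward the limit.
    fn on_payload_frame(&mut self) {
        self.0 = self.0.saturating_sub(1);
    }
}

fn main() {
    // Mixed traffic stays healthy: debt oscillates well below the limit.
    let mut mixed = TinyFrameDebt::default();
    for _ in 0..3 {
        mixed.on_tiny_frame().unwrap();
        mixed.on_payload_frame();
    }

    // A pure zero-length flood trips the limit on the fourth tiny frame.
    let mut flood = TinyFrameDebt::default();
    assert!(flood.on_tiny_frame().is_ok());  // debt 3
    assert!(flood.on_tiny_frame().is_ok());  // debt 6
    assert!(flood.on_tiny_frame().is_ok());  // debt 9
    assert!(flood.on_tiny_frame().is_err()); // debt 12 >= limit
}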
diff --git a/src/proxy/quota_lock_registry.rs b/src/proxy/quota_lock_registry.rs
new file mode 100644
index 0000000..7798b09
--- /dev/null
+++ b/src/proxy/quota_lock_registry.rs
@@ -0,0 +1,88 @@
+use dashmap::DashMap;
+use std::sync::{Arc, OnceLock};
+use tokio::sync::Mutex;
+
+#[cfg(test)]
+use std::sync::atomic::{AtomicUsize, Ordering};
+
+#[cfg(test)]
+const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64;
+#[cfg(not(test))]
+const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096;
+#[cfg(test)]
+const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
+#[cfg(not(test))]
+const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
+
+static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
+static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
+
+#[cfg(test)]
+static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0);
+#[cfg(test)]
+static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock<DashMap<String, usize>> = OnceLock::new();
+
+fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
+    let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
+        (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES)
+            .map(|_| Arc::new(Mutex::new(())))
+            .collect()
+    });
+
+    let hash = crc32fast::hash(user.as_bytes()) as usize;
+    Arc::clone(&stripes[hash % stripes.len()])
+}
+
+pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
+    #[cfg(test)]
+    {
+        CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed);
+        let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
+        let mut entry = lookups.entry(user.to_string()).or_insert(0);
+        *entry += 1;
+    }
+
+    let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    if let Some(existing) = locks.get(user) {
+        return Arc::clone(existing.value());
+    }
+
+    if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
+        locks.retain(|_, value| Arc::strong_count(value) > 1);
+    }
+
+    if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
+        return cross_mode_quota_overflow_user_lock(user);
+    }
+
+    let created = Arc::new(Mutex::new(()));
+    match locks.entry(user.to_string()) {
+        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
+        dashmap::mapref::entry::Entry::Vacant(entry) => {
+            entry.insert(Arc::clone(&created));
+            created
+        }
+    }
+}
+
+#[cfg(test)]
+pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() {
+    CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed);
+    let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
+    lookups.clear();
+}
+
+#[cfg(test)]
+pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize {
+    CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed)
+}
+
+#[cfg(test)]
+pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize {
+    let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
+    lookups.get(user).map(|entry| *entry).unwrap_or(0)
+}
+
+#[cfg(test)]
+#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"]
+mod quota_lock_registry_cross_mode_adversarial_tests;
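Review note on quota_lock_registry.rs above: the registry is a capacity-capped get-or-create map. On overflow it first drops entries whose Arc is no longer held by any session (strong count of 1), and only then degrades to CRC32-hashed overflow stripes, so a flood of unique user names cannot grow the map without bound while any given user still always resolves to one shared lock. A hypothetical in-crate test sketching the property callers rely on (assumes the pub(crate) API above; this test is not part of the diff):

#[tokio::test]
async fn same_user_resolves_to_one_cross_mode_lock() {
    // Both the direct-relay admission path (relay.rs) and the ME->C
    // reservation path (middle_relay.rs) resolve "alice" here, so their
    // critical sections serialize across modes.
    let a = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock("alice");
    let b = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock("alice");
    assert!(std::sync::Arc::ptr_eq(&a, &b));

    // While one path holds the lock, the other observes contention.
    let guard = a.lock().await;
    assert!(b.try_lock().is_err());
    drop(guard);
    assert!(b.try_lock().is_ok());
}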
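Review note ahead of the relay.rs hunks below: the per-contention spawned waker task is replaced by an inline Pin<Box<Sleep>> stored on StatsIo plus capped exponential backoff, so a lock-contended poll allocates one timer rather than one task, and repeated contention widens the retry interval geometrically. Under the non-test constants (2 ms base, 64 ms cap, shift clamped at 5) the schedule works out as follows; this standalone sketch only replays the arithmetic of quota_contention_retry_delay:

use std::time::Duration;

const BASE: Duration = Duration::from_millis(2);  // non-test retry interval
const MAX: Duration = Duration::from_millis(64);  // non-test retry cap

fn retry_delay(attempt: u8) -> Duration {
    // Double the delay per attempt, clamping the shift so the multiplier
    // saturates at 32x, then cap the result at MAX.
    let multiplier = 1u32 << u32::from(attempt.min(5));
    BASE.saturating_mul(multiplier).min(MAX)
}

fn main() {
    // 2ms, 4ms, 8ms, 16ms, 32ms, 64ms, then clamped at 64ms thereafter.
    let schedule: Vec<u64> = (0..8).map(|a| retry_delay(a).as_millis() as u64).collect();
    assert_eq!(schedule, vec![2, 4, 8, 16, 32, 64, 64, 64]);
}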
diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs
index 2431ff4..55f1385 100644
--- a/src/proxy/relay.rs
+++ b/src/proxy/relay.rs
@@ -62,7 +62,8 @@ use std::sync::{Arc, Mutex, OnceLock};
 use std::task::{Context, Poll};
 use std::time::Duration;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
-use tokio::time::Instant;
+use tokio::sync::Mutex as AsyncMutex;
+use tokio::time::{Instant, Sleep};
 use tracing::{debug, trace, warn};
 
 // ============= Constants =============
@@ -209,12 +210,16 @@ struct StatsIo {
     counters: Arc,
     stats: Arc<Stats>,
     user: String,
+    quota_lock: Option<Arc<Mutex<()>>>,
+    cross_mode_quota_lock: Option<Arc<AsyncMutex<()>>>,
     quota_limit: Option<u64>,
     quota_exceeded: Arc<AtomicBool>,
     quota_read_wake_scheduled: bool,
     quota_write_wake_scheduled: bool,
-    quota_read_retry_active: Arc<AtomicBool>,
-    quota_write_retry_active: Arc<AtomicBool>,
+    quota_read_retry_sleep: Option<Pin<Box<Sleep>>>,
+    quota_write_retry_sleep: Option<Pin<Box<Sleep>>>,
+    quota_read_retry_attempt: u8,
+    quota_write_retry_attempt: u8,
     epoch: Instant,
 }
@@ -230,30 +235,29 @@ impl StatsIo {
     ) -> Self {
         // Mark initial activity so the watchdog doesn't fire before data flows
         counters.touch(Instant::now(), epoch);
+        let quota_lock = quota_limit.map(|_| quota_user_lock(&user));
+        let cross_mode_quota_lock = quota_limit
+            .map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
         Self {
             inner,
             counters,
             stats,
             user,
+            quota_lock,
+            cross_mode_quota_lock,
             quota_limit,
             quota_exceeded,
             quota_read_wake_scheduled: false,
             quota_write_wake_scheduled: false,
-            quota_read_retry_active: Arc::new(AtomicBool::new(false)),
-            quota_write_retry_active: Arc::new(AtomicBool::new(false)),
+            quota_read_retry_sleep: None,
+            quota_write_retry_sleep: None,
+            quota_read_retry_attempt: 0,
+            quota_write_retry_attempt: 0,
             epoch,
         }
     }
 }
 
-impl Drop for StatsIo {
-    fn drop(&mut self) {
-        self.quota_read_retry_active.store(false, Ordering::Relaxed);
-        self.quota_write_retry_active
-            .store(false, Ordering::Relaxed);
-    }
-}
-
 #[derive(Debug)]
 struct QuotaIoSentinel;
@@ -281,20 +285,69 @@ fn is_quota_io_error(err: &io::Error) -> bool {
 const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
 #[cfg(not(test))]
 const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
+#[cfg(test)]
+const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16);
+#[cfg(not(test))]
+const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64);
 
-fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
-    tokio::task::spawn(async move {
-        loop {
-            if !retry_active.load(Ordering::Relaxed) {
-                break;
-            }
-            tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
-            if !retry_active.load(Ordering::Relaxed) {
-                break;
-            }
-            waker.wake_by_ref();
-        }
-    });
-}
+#[cfg(test)]
+static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0);
+#[cfg(test)]
+static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0);
+
+#[cfg(test)]
+pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() {
+    QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed);
+}
+
+#[cfg(test)]
+pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 {
+    QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed)
+}
+
+#[inline]
+fn quota_contention_retry_delay(retry_attempt: u8) -> Duration {
+    let shift = u32::from(retry_attempt.min(5));
+    let multiplier = 1_u32 << shift;
+    QUOTA_CONTENTION_RETRY_INTERVAL
+        .saturating_mul(multiplier)
+        .min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL)
+}
+
+#[inline]
+fn reset_quota_retry_scheduler(
+    sleep_slot: &mut Option<Pin<Box<Sleep>>>,
+    wake_scheduled: &mut bool,
+    retry_attempt: &mut u8,
+) {
+    *wake_scheduled = false;
+    *sleep_slot = None;
+    *retry_attempt = 0;
+}
+
+fn poll_quota_retry_sleep(
+    sleep_slot: &mut Option<Pin<Box<Sleep>>>,
+    wake_scheduled: &mut bool,
+    retry_attempt: &mut u8,
+    cx: &mut Context<'_>,
+) {
+    if !*wake_scheduled {
+        *wake_scheduled = true;
+        #[cfg(test)]
+        QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed);
+        *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay(
+            *retry_attempt,
+        ))));
+    }
+
+    if let Some(sleep) = sleep_slot.as_mut()
+        && sleep.as_mut().poll(cx).is_ready()
+    {
+        *sleep_slot = None;
+        *wake_scheduled = false;
+        *retry_attempt = retry_attempt.saturating_add(1);
+        cx.waker().wake_by_ref();
+    }
+}
 
 static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
@@ -333,16 +386,47 @@ fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
     Arc::clone(&stripes[hash % stripes.len()])
 }
 
+pub(crate) fn quota_user_lock_evict() {
+    if let Some(locks) = QUOTA_USER_LOCKS.get() {
+        locks.retain(|_, value| Arc::strong_count(value) > 1);
+    }
+}
+
+pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> {
+    let interval = interval.max(Duration::from_millis(1));
+    #[cfg(test)]
+    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed);
+    tokio::spawn(async move {
+        loop {
+            tokio::time::sleep(interval).await;
+            quota_user_lock_evict();
+        }
+    })
+}
+
+#[cfg(test)]
+pub(crate) fn spawn_quota_user_lock_evictor_for_tests(
+    interval: Duration,
+) -> tokio::task::JoinHandle<()> {
+    spawn_quota_user_lock_evictor(interval)
+}
+
+#[cfg(test)]
+pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() {
+    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed);
+}
+
+#[cfg(test)]
+pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 {
+    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed)
+}
+
 fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
     let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
     if let Some(existing) = locks.get(user) {
         return Arc::clone(existing.value());
     }
 
-    if locks.len() >= QUOTA_USER_LOCKS_MAX {
-        locks.retain(|_, value| Arc::strong_count(value) > 1);
-    }
-
     if locks.len() >= QUOTA_USER_LOCKS_MAX {
         return quota_overflow_user_lock(user);
     }
@@ -357,6 +441,11 @@ fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
     }
 }
 
+#[cfg(test)]
+pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
+    crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
+}
+
 impl AsyncRead for StatsIo {
     fn poll_read(
         self: Pin<&mut Self>,
@@ -368,26 +457,16 @@ impl AsyncRead for StatsIo {
             return Poll::Ready(Err(quota_io_error()));
         }
 
-        let quota_lock = this
-            .quota_limit
-            .is_some()
-            .then(|| quota_user_lock(&this.user));
-        let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
+        let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
             match lock.try_lock() {
-                Ok(guard) => {
-                    this.quota_read_wake_scheduled = false;
-                    this.quota_read_retry_active.store(false, Ordering::Relaxed);
-                    Some(guard)
-                }
+                Ok(guard) => Some(guard),
                 Err(_) => {
-                    if !this.quota_read_wake_scheduled {
-                        this.quota_read_wake_scheduled = true;
-                        this.quota_read_retry_active.store(true, Ordering::Relaxed);
-                        spawn_quota_retry_waker(
-                            Arc::clone(&this.quota_read_retry_active),
-                            cx.waker().clone(),
-                        );
-                    }
+                    poll_quota_retry_sleep(
+                        &mut this.quota_read_retry_sleep,
+                        &mut this.quota_read_wake_scheduled,
+                        &mut this.quota_read_retry_attempt,
+                        cx,
+                    );
                     return Poll::Pending;
                 }
             }
@@ -395,6 +474,29 @@ impl AsyncRead for StatsIo {
             None
         };
 
+        let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
+            match lock.try_lock() {
+                Ok(guard) => Some(guard),
+                Err(_) => {
+                    poll_quota_retry_sleep(
+                        &mut this.quota_read_retry_sleep,
+                        &mut this.quota_read_wake_scheduled,
+                        &mut this.quota_read_retry_attempt,
+                        cx,
+                    );
+                    return Poll::Pending;
+                }
+            }
+        } else {
+            None
+        };
+
+        reset_quota_retry_scheduler(
+            &mut this.quota_read_retry_sleep,
+            &mut this.quota_read_wake_scheduled,
+            &mut this.quota_read_retry_attempt,
+        );
+
         if let Some(limit) = this.quota_limit
             && this.stats.get_user_total_octets(&this.user) >= limit
         {
@@ -460,27 +562,16 @@ impl AsyncWrite for StatsIo {
             return Poll::Ready(Err(quota_io_error()));
         }
 
-        let quota_lock = this
-            .quota_limit
-            .is_some()
-            .then(|| quota_user_lock(&this.user));
-        let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
+        let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
             match lock.try_lock() {
-                Ok(guard) => {
-                    this.quota_write_wake_scheduled = false;
-                    this.quota_write_retry_active
-                        .store(false, Ordering::Relaxed);
-                    Some(guard)
-                }
+                Ok(guard) => Some(guard),
                 Err(_) => {
-                    if !this.quota_write_wake_scheduled {
-                        this.quota_write_wake_scheduled = true;
-                        this.quota_write_retry_active.store(true, Ordering::Relaxed);
-                        spawn_quota_retry_waker(
-                            Arc::clone(&this.quota_write_retry_active),
-                            cx.waker().clone(),
-                        );
-                    }
+                    poll_quota_retry_sleep(
+                        &mut this.quota_write_retry_sleep,
+                        &mut this.quota_write_wake_scheduled,
+                        &mut this.quota_write_retry_attempt,
+                        cx,
+                    );
                     return Poll::Pending;
                 }
             }
@@ -488,6 +579,29 @@ impl AsyncWrite for StatsIo {
             None
         };
 
+        let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
+            match lock.try_lock() {
+                Ok(guard) => Some(guard),
+                Err(_) => {
+                    poll_quota_retry_sleep(
+                        &mut this.quota_write_retry_sleep,
+                        &mut this.quota_write_wake_scheduled,
+                        &mut this.quota_write_retry_attempt,
+                        cx,
+                    );
+                    return Poll::Pending;
+                }
+            }
+        } else {
+            None
+        };
+
+        reset_quota_retry_scheduler(
+            &mut this.quota_write_retry_sleep,
+            &mut this.quota_write_wake_scheduled,
+            &mut this.quota_write_retry_attempt,
+        );
+
         let write_buf = if let Some(limit) = this.quota_limit {
             let used = this.stats.get_user_total_octets(&this.user);
             if used >= limit {
@@ -780,6 +894,10 @@ mod relay_quota_model_adversarial_tests;
 #[path = "tests/relay_quota_overflow_regression_tests.rs"]
 mod relay_quota_overflow_regression_tests;
 
+#[cfg(test)]
+#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"]
+mod relay_quota_extended_attack_surface_security_tests;
+
 #[cfg(test)]
 #[path = "tests/relay_watchdog_delta_security_tests.rs"]
 mod relay_watchdog_delta_security_tests;
@@ -791,3 +909,63 @@ mod relay_quota_waker_storm_adversarial_tests;
 #[cfg(test)]
 #[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
 mod relay_quota_wake_liveness_regression_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_quota_lock_identity_security_tests.rs"]
+mod relay_quota_lock_identity_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"]
+mod relay_cross_mode_quota_lock_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"]
+mod relay_quota_retry_scheduler_tdd_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"]
+mod relay_cross_mode_quota_fairness_tdd_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"]
+mod relay_cross_mode_pipeline_hol_integration_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"]
+mod relay_cross_mode_pipeline_latency_benchmark_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_quota_retry_backoff_security_tests.rs"]
+mod relay_quota_retry_backoff_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"]
+mod
relay_quota_retry_backoff_benchmark_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"] +mod relay_dual_lock_backoff_regression_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"] +mod relay_dual_lock_contention_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"] +mod relay_dual_lock_race_harness_security_tests; + +#[cfg(test)] +#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"] +mod relay_dual_lock_alternating_contention_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"] +mod relay_quota_retry_allocation_latency_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"] +mod relay_quota_lock_eviction_lifecycle_tdd_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"] +mod relay_quota_lock_eviction_stress_security_tests; diff --git a/src/proxy/tests/client_clever_advanced_tests.rs b/src/proxy/tests/client_clever_advanced_tests.rs new file mode 100644 index 0000000..da2e703 --- /dev/null +++ b/src/proxy/tests/client_clever_advanced_tests.rs @@ -0,0 +1,409 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType, ProxyConfig}; +use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf, duplex}; +use tokio::net::TcpListener; + +#[test] +fn edge_mask_reject_delay_min_greater_than_max_does_not_panic() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 1000; + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + let elapsed = start.elapsed(); + + assert!(elapsed >= Duration::from_millis(1000)); + assert!(elapsed < Duration::from_millis(1500)); + }); +} + +#[test] +fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert_eq!(timeout.as_secs(), u64::MAX); +} + +#[test] +fn edge_tls_clienthello_len_in_bounds_exact_boundaries() { + assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); + assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1)); + assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); + assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); +} + +#[test] +fn edge_synthetic_local_addr_boundaries() { + assert_eq!(synthetic_local_addr(0).port(), 0); + assert_eq!(synthetic_local_addr(80).port(), 80); + assert_eq!(synthetic_local_addr(u16::MAX).port(), u16::MAX); +} + +#[test] +fn edge_beobachten_record_handshake_failure_class_stream_error_eof() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof_err = ProxyError::Stream(crate::error::StreamError::UnexpectedEof); + let 
peer_ip: IpAddr = "198.51.100.100".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof_err); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn adversarial_tls_handshake_timeout_during_masking_delay() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.server_hello_delay_min_ms = 3000; + cfg.censorship.server_hello_delay_max_ms = 3000; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.1:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout))); + assert_eq!(stats.get_handshake_timeouts(), 1); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_slowloris_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 200; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.2:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(b"PROXY TCP4 192.").await.unwrap(); + tokio::time::sleep(Duration::from_millis(300)).await; + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[test] +fn blackhat_ipv4_mapped_ipv6_proxy_source_bypass_attempt() { + let trusted = vec!["192.0.2.0/24".parse().unwrap()]; + let peer_ip = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + assert!(!is_trusted_proxy_source(peer_ip, &trusted)); +} + +#[tokio::test] +async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 500; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.3:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + 
None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_client_stream_exactly_4_bytes_eof() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.4:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.5:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side.write_all(&vec![0x41; 99]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn integration_non_tls_modes_disabled_immediately_masks() { + let mut cfg = ProxyConfig::default(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + cfg.censorship.mask = true; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.6:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(b"GET / HTTP/1.1\r\n").await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + 
assert_eq!(stats.get_connects_bad(), 1); +} + +struct YieldingReader { + data: Vec, + pos: usize, + yields_left: usize, +} + +impl AsyncRead for YieldingReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + if this.yields_left > 0 { + this.yields_left -= 1; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + if this.pos >= this.data.len() { + return Poll::Ready(Ok(())); + } + buf.put_slice(&this.data[this.pos..this.pos + 1]); + this.pos += 1; + this.yields_left = 2; + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn fuzz_read_with_progress_heavy_yielding() { + let expected_data = b"HEAVY_YIELD_TEST_DATA".to_vec(); + let mut reader = YieldingReader { + data: expected_data.clone(), + pos: 0, + yields_left: 2, + }; + + let mut buf = vec![0u8; expected_data.len()]; + let read_bytes = read_with_progress(&mut reader, &mut buf).await.unwrap(); + + assert_eq!(read_bytes, expected_data.len()); + assert_eq!(buf, expected_data); +} + +#[test] +fn edge_wrap_tls_application_record_exactly_u16_max() { + let payload = vec![0u8; 65535]; + let wrapped = wrap_tls_application_record(&payload); + assert_eq!(wrapped.len(), 65540); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); +} + +#[test] +fn fuzz_wrap_tls_application_record_lengths() { + let lengths = [0, 1, 65534, 65535, 65536, 131070, 131071, 131072]; + for len in lengths { + let payload = vec![0u8; len]; + let wrapped = wrap_tls_application_record(&payload); + let expected_chunks = len.div_ceil(65535).max(1); + assert_eq!(wrapped.len(), len + 5 * expected_chunks); + } +} + +#[tokio::test] +async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() { + let user = "stress-same-ip-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 5); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 10).await; + + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)), 55000); + + let mut tasks = tokio::task::JoinSet::new(); + let mut reservations = Vec::new(); + + for _ in 0..10 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut successes = 0; + let mut failures = 0; + + while let Some(res) = tasks.join_next().await { + match res.unwrap() { + Ok(r) => { + successes += 1; + reservations.push(r); + } + Err(_) => failures += 1, + } + } + + assert_eq!(successes, 5); + assert_eq!(failures, 5); + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for reservation in reservations { + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs new file mode 100644 index 0000000..97c55c6 --- /dev/null +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -0,0 +1,196 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; +use crate::stats::Stats; +use 
crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncWriteExt, duplex}; + +#[test] +fn invariant_wrap_tls_application_record_exact_multiples() { + let chunk_size = u16::MAX as usize; + let payload = vec![0xAA; chunk_size * 2]; + + let wrapped = wrap_tls_application_record(&payload); + + assert_eq!(wrapped.len(), 2 * (5 + chunk_size)); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); + + let second_header_idx = 5 + chunk_size; + assert_eq!(wrapped[second_header_idx], TLS_RECORD_APPLICATION); + assert_eq!( + &wrapped[second_header_idx + 3..second_header_idx + 5], + &65535u16.to_be_bytes() + ); +} + +#[tokio::test] +async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.20:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let claimed_len = MIN_TLS_CLIENT_HELLO_SIZE as u16; + let mut header = vec![0x16, 0x03, 0x01]; + header.extend_from_slice(&claimed_len.to_be_bytes()); + + client_side.write_all(&header).await.unwrap(); + client_side + .write_all(&vec![0x42; MIN_TLS_CLIENT_HELLO_SIZE - 1]) + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn invariant_acquire_reservation_ip_limit_rollback() { + let user = "rollback-test-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 10); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer_a = "198.51.100.21:55000".parse().unwrap(); + let _res_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_a, + ip_tracker.clone(), + ) + .await + .unwrap(); + + assert_eq!(stats.get_user_curr_connects(user), 1); + + let peer_b = "203.0.113.22:55000".parse().unwrap(); + let res_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_b, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + res_b, + Err(ProxyError::ConnectionLimitExceeded { .. 
}) + )); + assert_eq!(stats.get_user_curr_connects(user), 1); +} + +#[tokio::test] +async fn invariant_quota_exact_boundary_inclusive() { + let user = "quota-strict-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1000); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.23:55000".parse().unwrap(); + + stats.add_user_octets_from(user, 999); + let res1 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(res1.is_ok()); + res1.unwrap().release().await; + + stats.add_user_octets_from(user, 1); + let res2 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!(res2, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.25:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0xEF, 0xEF, 0xEF]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + + assert!(result.is_err()); + assert_eq!(stats.get_connects_bad(), 0); + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn invariant_route_mode_snapshot_picks_up_latest_mode() { + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + assert!(matches!( + route_runtime.snapshot().mode, + RelayRouteMode::Direct + )); + + route_runtime.set_mode(RelayRouteMode::Middle); + assert!(matches!( + route_runtime.snapshot().mode, + RelayRouteMode::Middle + )); +} diff --git a/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs new file mode 100644 index 0000000..d7ac4ef --- /dev/null +++ b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs @@ -0,0 +1,100 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +#[tokio::test] +async fn 
fragmented_connect_probe_is_classified_as_http_via_prefetch_window() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"CONNE").await.unwrap(); + client_side + .write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"), + "mask backend must receive the full fragmented CONNECT probe" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.251-1")); +} diff --git a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs new file mode 100644 index 0000000..fcf51ab --- /dev/null +++ b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs @@ -0,0 +1,129 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = 
true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + let first = split_at.min(preface.len()); + client_side.write_all(&preface[..first]).await.unwrap(); + if first < preface.len() { + sleep(Duration::from_millis(delay_ms)).await; + client_side.write_all(&preface[first..]).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(&preface), + "mask backend must receive an intact HTTP/2 preface prefix" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains(&format!("{}-1", peer.ip()))); +} + +#[tokio::test] +async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() { + let cases = [ + (2usize, 0u64), + (3, 0), + (4, 0), + (2, 7), + (3, 7), + (8, 1), + ]; + + for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() { + let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i) + .parse() + .unwrap(); + run_http2_fragment_case(split_at, delay_ms, peer).await; + } +} + +#[tokio::test] +async fn http2_preface_splitpoint_light_fuzz_classifies_http() { + for split_at in 2usize..=12 { + let delay_ms = if split_at % 3 == 0 { 7 } else { 1 }; + let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at) + .parse() + .unwrap(); + run_http2_fragment_case(split_at, delay_ms, peer).await; + } +} diff --git a/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs new file mode 100644 index 0000000..e64dc03 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs @@ -0,0 +1,150 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_pipeline_prefetch_case( + prefetch_timeout_ms: u64, + delayed_tail_ms: u64, + peer: SocketAddr, +) -> (Vec, String) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 
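+        // Review note: binding 127.0.0.1:0 lets the kernel pick a fresh port per
+        // case, so the prefetch matrix below can run its cases back to back
+        // without mask-backend port collisions; local_addr() recovers the port.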
+ let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms; + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"C").await.unwrap(); + sleep(Duration::from_millis(delayed_tail_ms)).await; + + client_side + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + (forwarded, snapshot) +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() { + let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must still receive full payload bytes in-order" + ); + assert!( + snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"), + "unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}" + ); +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() { + let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must receive full CONNECT payload" + ); + assert!( + snapshot.contains("[HTTP]"), + "20ms budget should recover delayed fragmented prefix and classify as HTTP" + ); +} + +#[tokio::test] +async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() { + let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap(); + let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap(); + let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap(); + + let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await; + let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await; + let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await; + + assert!( + snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"), + "unexpected 5ms snapshot: {snap5}" + ); + assert!( + snap20.contains("[HTTP]") || 
snap20.contains("[port-scanner]"), + "unexpected 20ms snapshot: {snap20}" + ); + assert!(snap50.contains("[HTTP]")); +} diff --git a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs new file mode 100644 index 0000000..cdf2136 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs @@ -0,0 +1,82 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, sleep}; + +#[test] +fn prefetch_timeout_budget_reads_from_config() { + let mut cfg = ProxyConfig::default(); + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(5), + "default prefetch timeout budget must remain 5ms" + ); + + cfg.censorship.mask_classifier_prefetch_timeout_ms = 20; + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(20), + "runtime prefetch timeout budget must follow configured value" + ); +} + +#[tokio::test] +async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(15)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer.shutdown().await.expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(20), + ) + .await; + + writer_task + .await + .expect("writer task must not panic in runtime timeout test"); + + assert!( + initial_data.starts_with(b"CONNECT"), + "20ms configured prefetch budget should recover 15ms delayed CONNECT tail" + ); +} + +#[tokio::test] +async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(15)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer.shutdown().await.expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(5), + ) + .await; + + writer_task + .await + .expect("writer task must not panic in runtime timeout test"); + + assert!( + !initial_data.starts_with(b"CONNECT"), + "5ms configured prefetch budget should miss 15ms delayed CONNECT tail" + ); +} diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs new file mode 100644 index 0000000..2e03ce9 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs @@ -0,0 +1,261 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + 
cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[test] +fn empty_initial_data_prefetch_gate_is_fail_closed() { + assert!( + !should_prefetch_mask_classifier_window(&[]), + "empty initial_data must not trigger classifier prefetch" + ); +} + +#[tokio::test] +async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() { + let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec(); + let (mut reader, mut writer) = duplex(1024); + + writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + + assert!( + initial_data.is_empty(), + "empty initial_data must remain empty after prefetch stage" + ); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!( + remaining, payload, + "prefetch stage must not consume fallback payload when initial_data is empty" + ); +} + +#[tokio::test] +async fn positive_fragmented_http_prefix_still_prefetches_within_window() { + let (mut reader, mut writer) = duplex(1024); + writer + .write_all(b"NECT 
example.org:443 HTTP/1.1\r\n") + .await + .unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = b"CON".to_vec(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + + assert!( + initial_data.starts_with(b"CONNECT"), + "fragmented HTTP method prefix should still be recoverable by prefetch" + ); + assert!( + initial_data.len() <= 16, + "prefetch window must remain bounded" + ); +} + +#[tokio::test] +async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() { + let mut seed = 0xD15C_A11E_2026_0322u64; + + for _ in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = ((seed & 0x3f) as usize).saturating_add(1); + let mut payload = vec![0u8; len]; + for (idx, byte) in payload.iter_mut().enumerate() { + *byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17); + } + + let (mut reader, mut writer) = duplex(1024); + writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + assert!(initial_data.is_empty()); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!(remaining, payload); + } +} + +#[tokio::test] +async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xD3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B); + let mut invalid_payload = vec![0u8; HANDSHAKE_LEN]; + invalid_payload[0] = 0xFF; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload); + let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant"); + let expected = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + + let mut one = [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!( + n, 0, + "fallback stream must not append synthetic bytes on empty initial_data path" + ); + }); + + let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.245:56145".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs new file mode 100644 
index 0000000..9ece258
--- /dev/null
+++ b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs
@@ -0,0 +1,70 @@
+use super::*;
+use tokio::io::{AsyncWriteExt, duplex};
+use tokio::time::{Duration, advance, sleep};
+
+async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec<u8> {
+    let (mut reader, mut writer) = duplex(1024);
+
+    let writer_task = tokio::spawn(async move {
+        sleep(Duration::from_millis(tail_delay_ms)).await;
+        let _ = writer.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n").await;
+        let _ = writer.shutdown().await;
+    });
+
+    let mut initial_data = b"C".to_vec();
+    let prefetch_task = tokio::spawn(async move {
+        extend_masking_initial_window_with_timeout(
+            &mut reader,
+            &mut initial_data,
+            Duration::from_millis(prefetch_ms),
+        )
+        .await;
+        initial_data
+    });
+
+    tokio::task::yield_now().await;
+
+    if tail_delay_ms > 0 {
+        advance(Duration::from_millis(tail_delay_ms)).await;
+        tokio::task::yield_now().await;
+    }
+
+    if prefetch_ms > tail_delay_ms {
+        advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await;
+        tokio::task::yield_now().await;
+    }
+
+    let result = prefetch_task.await.expect("prefetch task must not panic");
+    writer_task.await.expect("writer task must not panic");
+    result
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_5ms_misses_15ms_tail() {
+    let got = run_strict_prefetch_case(5, 15).await;
+    assert_eq!(got, b"C".to_vec());
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_20ms_recovers_15ms_tail() {
+    let got = run_strict_prefetch_case(20, 15).await;
+    assert!(got.starts_with(b"CONNECT"));
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_50ms_recovers_35ms_tail() {
+    let got = run_strict_prefetch_case(50, 35).await;
+    assert!(got.starts_with(b"CONNECT"));
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_equal_budget_and_delay_recovers_tail() {
+    let got = run_strict_prefetch_case(20, 20).await;
+    assert!(got.starts_with(b"CONNECT"));
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_one_ms_after_budget_misses_tail() {
+    let got = run_strict_prefetch_case(20, 21).await;
+    assert_eq!(got, b"C".to_vec());
+}
diff --git a/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs
new file mode 100644
index 0000000..3f4ab17
--- /dev/null
+++ b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs
@@ -0,0 +1,95 @@
+use super::*;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::time::{Duration, sleep, timeout};
+
+async fn extend_masking_initial_window_with_budget<R>(
+    reader: &mut R,
+    initial_data: &mut Vec<u8>,
+    prefetch_timeout: Duration,
+) where
+    R: AsyncRead + Unpin,
+{
+    if !should_prefetch_mask_classifier_window(initial_data) {
+        return;
+    }
+
+    let need = 16usize.saturating_sub(initial_data.len());
+    if need == 0 {
+        return;
+    }
+
+    let mut extra = [0u8; 16];
+    if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
+        && n > 0
+    {
+        initial_data.extend_from_slice(&extra[..n]);
+    }
+}
+
+async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool {
+    let (mut reader, mut writer) = duplex(1024);
+
+    let writer_task = tokio::spawn(async move {
+        sleep(Duration::from_millis(delayed_tail_ms)).await;
+        writer
+            .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
+            .await
+            .expect("tail bytes must be writable");
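+        // The tail completes a CONNECT line whose leading "C" is seeded into
+        // initial_data below, emulating a client that fragments the HTTP
+        // method across TCP segments.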
+        writer.shutdown().await.expect("writer shutdown must succeed");
+    });
+
+    let mut initial_data = b"C".to_vec();
+    extend_masking_initial_window_with_budget(
+        &mut reader,
+        &mut initial_data,
+        Duration::from_millis(prefetch_budget_ms),
+    )
+    .await;
+
+    writer_task
+        .await
+        .expect("writer task must not panic during matrix case");
+
+    initial_data.starts_with(b"CONNECT")
+}
+
+#[tokio::test]
+async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() {
+    let cases = [
+        // (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50])
+        (2u64, [true, true, true]),
+        (15u64, [false, true, true]),
+        (35u64, [false, false, true]),
+    ];
+
+    for (tail_delay_ms, expected) in cases {
+        let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await;
+        let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await;
+        let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await;
+
+        assert_eq!(
+            got_5, expected[0],
+            "5ms prefetch budget mismatch for tail delay {}ms",
+            tail_delay_ms
+        );
+        assert_eq!(
+            got_20, expected[1],
+            "20ms prefetch budget mismatch for tail delay {}ms",
+            tail_delay_ms
+        );
+        assert_eq!(
+            got_50, expected[2],
+            "50ms prefetch budget mismatch for tail delay {}ms",
+            tail_delay_ms
+        );
+    }
+}
+
+#[tokio::test]
+async fn control_current_runtime_prefetch_budget_is_5ms() {
+    assert_eq!(
+        MASK_CLASSIFIER_PREFETCH_TIMEOUT,
+        Duration::from_millis(5),
+        "matrix assumptions require current runtime prefetch budget to stay at 5ms"
+    );
+}
diff --git a/src/proxy/tests/client_masking_replay_timing_security_tests.rs b/src/proxy/tests/client_masking_replay_timing_security_tests.rs
new file mode 100644
index 0000000..225ce50
--- /dev/null
+++ b/src/proxy/tests/client_masking_replay_timing_security_tests.rs
@@ -0,0 +1,161 @@
+use super::*;
+use crate::config::{UpstreamConfig, UpstreamType};
+use crate::crypto::sha256_hmac;
+use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
+use crate::protocol::tls;
+use std::sync::Arc;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::time::{Duration, Instant};
+
+fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
+    Arc::new(UpstreamManager::new(
+        vec![UpstreamConfig {
+            upstream_type: UpstreamType::Direct {
+                interface: None,
+                bind_addresses: None,
+            },
+            weight: 1,
+            enabled: true,
+            scopes: String::new(),
+            selected_scope: String::new(),
+        }],
+        1,
+        1,
+        1,
+        1,
+        false,
+        stats,
+    ))
+}
+
+fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
+    let total_len = 5 + tls_len;
+    let mut handshake = vec![fill; total_len];
+
+    handshake[0] = 0x16;
+    handshake[1] = 0x03;
+    handshake[2] = 0x01;
+    handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
+
+    let session_id_len: usize = 32;
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+    handshake
+}
+
+async fn run_replay_candidate_session(
+    replay_checker: Arc<ReplayChecker>,
+    hello: &[u8],
+    peer: SocketAddr,
+    drive_mtproto_fail: bool,
+) -> Duration {
+    let mut cfg = ProxyConfig::default();
+    cfg.general.beobachten = false;
+    cfg.censorship.mask = true;
+    cfg.censorship.mask_unix_sock = None;
+    cfg.censorship.mask_host =
Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.censorship.mask_timing_normalization_enabled = false; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "abababababababababababababababab".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(65536); + let started = Instant::now(); + + let task = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + replay_checker, + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten, + false, + )); + + client_side.write_all(hello).await.unwrap(); + + if drive_mtproto_fail { + let mut server_hello_head = [0u8; 5]; + client_side.read_exact(&mut server_hello_head).await.unwrap(); + assert_eq!(server_hello_head[0], 0x16); + let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize; + let mut body = vec![0u8; body_len]; + client_side.read_exact(&mut body).await.unwrap(); + + let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN); + invalid_mtproto_record.push(0x17); + invalid_mtproto_record.extend_from_slice(&TLS_VERSION); + invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes()); + invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]); + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n") + .await + .unwrap(); + } + + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + started.elapsed() +} + +#[tokio::test] +async fn replay_reject_still_honors_masking_timing_budget() { + let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60))); + let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51); + + let seed_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.201:58001".parse().unwrap(), + true, + ) + .await; + + assert!( + seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250), + "seed replay-candidate run must honor masking timing budget without unbounded delay" + ); + + let replay_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.202:58002".parse().unwrap(), + false, + ) + .await; + + assert!( + replay_elapsed >= Duration::from_millis(40) + && replay_elapsed < Duration::from_millis(250), + "replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay" + ); +} diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs new file mode 100644 index 0000000..021848a --- /dev/null +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -0,0 +1,257 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + +#[tokio::test] +async fn edge_mask_delay_bypassed_if_max_is_zero() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 10_000; + config.censorship.server_hello_delay_max_ms = 0; + + let start = 
std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + assert!(start.elapsed() < Duration::from_millis(50)); +} + +#[test] +fn edge_beobachten_ttl_clamps_exactly_to_24_hours() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 100_000; + + let ttl = beobachten_ttl(&config); + assert_eq!(ttl.as_secs(), 24 * 60 * 60); +} + +#[test] +fn edge_wrap_tls_application_record_empty_payload() { + let wrapped = wrap_tls_application_record(&[]); + assert_eq!(wrapped.len(), 5); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &[0, 0]); +} + +#[tokio::test] +async fn boundary_user_data_quota_exact_match_rejects() { + let user = "quota-boundary-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1024); + + let stats = Arc::new(Stats::new()); + stats.add_user_octets_from(user, 1024); + + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.10:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn boundary_user_expiration_in_past_rejects() { + let user = "expired-boundary-user"; + let mut config = ProxyConfig::default(); + let expired_time = chrono::Utc::now() - chrono::Duration::milliseconds(1); + config + .access + .user_expirations + .insert(user.to_string(), expired_time); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.11:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::UserExpired { .. 
}))); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 300; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.12:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(&vec![b'A'; 2000]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.13:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn security_classic_mode_disabled_masks_valid_length_payload() { + let mut cfg = ProxyConfig::default(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + cfg.censorship.mask = true; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.15:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&vec![0xEF; 64]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() { + let user = "rapid-churn-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 10); + + let stats = 
Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer = "198.51.100.16:55000".parse().unwrap(); + + for _ in 0..500 { + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn quirk_read_with_progress_zero_length_buffer_returns_zero_immediately() { + let (mut server_side, _client_side) = duplex(4096); + let mut empty_buf = &mut [][..]; + + let result = tokio::time::timeout( + Duration::from_millis(50), + read_with_progress(&mut server_side, &mut empty_buf), + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().unwrap(), 0); +} + +#[tokio::test] +async fn stress_read_with_progress_cancellation_safety() { + let (mut server_side, mut client_side) = duplex(4096); + + client_side.write_all(b"12345").await.unwrap(); + + let mut buf = [0u8; 10]; + let result = tokio::time::timeout( + Duration::from_millis(50), + read_with_progress(&mut server_side, &mut buf), + ) + .await; + + assert!(result.is_err()); + + client_side.write_all(b"67890").await.unwrap(); + let mut buf2 = [0u8; 5]; + server_side.read_exact(&mut buf2).await.unwrap(); + assert_eq!(&buf2, b"67890"); +} diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 6338e23..2b1fae6 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -7,6 +7,9 @@ use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use rand::rngs::StdRng; +use rand::Rng; +use rand::SeedableRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; @@ -25,6 +28,202 @@ fn synthetic_local_addr_uses_configured_port_for_max() { assert_eq!(addr.port(), u16::MAX); } +#[test] +fn handshake_timeout_with_mask_grace_includes_mask_margin() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 2; + + config.censorship.mask = false; + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(2)); + + config.censorship.mask = true; + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_millis(2750), + "mask mode extends handshake timeout by 750 ms" + ); +} + +#[tokio::test] +async fn read_with_progress_reads_partial_buffers_before_eof() { + let data = vec![0xAA, 0xBB, 0xCC]; + let mut reader = std::io::Cursor::new(data); + let mut buf = [0u8; 5]; + + let read = read_with_progress(&mut reader, &mut buf).await.unwrap(); + assert_eq!(read, 3); + assert_eq!(&buf[..3], &[0xAA, 0xBB, 0xCC]); +} + +#[test] +fn is_trusted_proxy_source_respects_cidr_list_and_empty_rejects_all() { + let peer: IpAddr = "10.10.10.10".parse().unwrap(); + assert!(!is_trusted_proxy_source(peer, &[])); + + let trusted = vec!["10.0.0.0/8".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer, &trusted)); + + let not_trusted = vec!["192.0.2.0/24".parse().unwrap()]; + assert!(!is_trusted_proxy_source(peer, ¬_trusted)); +} + +#[test] +fn is_trusted_proxy_source_accepts_cidr_zero_zero_as_global_cidr() { + let peer: IpAddr = "203.0.113.42".parse().unwrap(); + let trust_all = 
vec!["0.0.0.0/0".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer, &trust_all)); + + let peer_v6: IpAddr = "2001:db8::1".parse().unwrap(); + let trust_all_v6 = vec!["::/0".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer_v6, &trust_all_v6)); +} + +struct ErrorReader; + +impl tokio::io::AsyncRead for ErrorReader { + fn poll_read( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "fake error"))) + } +} + +#[tokio::test] +async fn read_with_progress_returns_error_from_failed_reader() { + let mut reader = ErrorReader; + let mut buf = [0u8; 8]; + let err = read_with_progress(&mut reader, &mut buf).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); +} + +#[test] +fn handshake_timeout_with_mask_grace_handles_maximum_values_without_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert!(timeout >= Duration::from_secs(u64::MAX)); +} + +#[tokio::test] +async fn read_with_progress_zero_length_buffer_returns_zero() { + let data = vec![1, 2, 3]; + let mut reader = std::io::Cursor::new(data); + let mut buf = []; + + let read = read_with_progress(&mut reader, &mut buf).await.unwrap(); + assert_eq!(read, 0); +} + +#[test] +fn handshake_timeout_without_mask_is_exact_base() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 7; + config.censorship.mask = false; + + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(7)); +} + +#[test] +fn handshake_timeout_mask_enabled_adds_750ms() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 3; + config.censorship.mask = true; + + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_millis(3750)); +} + +#[tokio::test] +async fn read_with_progress_full_then_empty_transition() { + let data = vec![0x10, 0x20]; + let mut cursor = std::io::Cursor::new(data); + let mut buf = [0u8; 2]; + + assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 2); + assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn read_with_progress_fragmented_io_works_over_multiple_calls() { + let mut cursor = std::io::Cursor::new(vec![1, 2, 3, 4, 5]); + let mut result = Vec::new(); + + for chunk_size in 1..=5 { + let mut b = vec![0u8; chunk_size]; + let n = read_with_progress(&mut cursor, &mut b).await.unwrap(); + result.extend_from_slice(&b[..n]); + if n == 0 { break; } + } + + assert_eq!(result, vec![1,2,3,4,5]); +} + +#[tokio::test] +async fn read_with_progress_stress_randomized_chunk_sizes() { + for i in 0..128 { + let mut rng = StdRng::seed_from_u64(i as u64 + 1); + let mut input: Vec = (0..(i % 41)).map(|_| rng.next_u32() as u8).collect(); + let mut cursor = std::io::Cursor::new(input.clone()); + let mut collected = Vec::new(); + + while cursor.position() < cursor.get_ref().len() as u64 { + let chunk = 1 + (rng.next_u32() as usize % 8); + let mut b = vec![0u8; chunk]; + let read = read_with_progress(&mut cursor, &mut b).await.unwrap(); + collected.extend_from_slice(&b[..read]); + if read == 0 { break; } + } + + assert_eq!(collected, input); + } +} + +#[test] +fn is_trusted_proxy_source_boundary_narrow_ipv4() { + let matching = "172.16.0.1".parse().unwrap(); + let 
+    let not_matching = "172.15.255.255".parse().unwrap();
+    let cidr = vec!["172.16.0.0/12".parse().unwrap()];
+    assert!(is_trusted_proxy_source(matching, &cidr));
+    assert!(!is_trusted_proxy_source(not_matching, &cidr));
+}
+
+#[test]
+fn is_trusted_proxy_source_rejects_out_of_family_ipv6_v4_cidr() {
+    let peer = "2001:db8::1".parse().unwrap();
+    let cidr = vec!["10.0.0.0/8".parse().unwrap()];
+    assert!(!is_trusted_proxy_source(peer, &cidr));
+}
+
+#[test]
+fn wrap_tls_application_record_reserved_chunks_look_reasonable() {
+    let payload = vec![0xAA; 1 + (u16::MAX as usize) + 2];
+    let wrapped = wrap_tls_application_record(&payload);
+    assert!(wrapped.len() > payload.len());
+    assert!(wrapped.contains(&0x17));
+}
+
+#[test]
+fn wrap_tls_application_record_roundtrip_size_check() {
+    let payload_len = 3000;
+    let payload = vec![0x55; payload_len];
+    let wrapped = wrap_tls_application_record(&payload);
+
+    let mut idx = 0;
+    let mut consumed = 0;
+    while idx + 5 <= wrapped.len() {
+        assert_eq!(wrapped[idx], 0x17);
+        let len = u16::from_be_bytes([wrapped[idx + 3], wrapped[idx + 4]]) as usize;
+        consumed += len;
+        idx += 5 + len;
+        if idx >= wrapped.len() {
+            break;
+        }
+    }
+
+    assert_eq!(consumed, payload_len);
+}
+
 fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
 where
     R: tokio::io::AsyncRead + Unpin,
diff --git a/src/proxy/tests/handshake_advanced_clever_tests.rs b/src/proxy/tests/handshake_advanced_clever_tests.rs
new file mode 100644
index 0000000..9b12f21
--- /dev/null
+++ b/src/proxy/tests/handshake_advanced_clever_tests.rs
@@ -0,0 +1,647 @@
+use super::*;
+use crate::crypto::{sha256, sha256_hmac, AesCtr};
+use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+// --- Helpers ---
+
+fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
+    auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
+    let mut cfg = ProxyConfig::default();
+    cfg.access.users.clear();
+    cfg.access
+        .users
+        .insert("user".to_string(), secret_hex.to_string());
+    cfg.access.ignore_time_skew = true;
+    cfg.general.modes.secure = true;
+    cfg.general.modes.classic = true;
+    cfg.general.modes.tls = true;
+    cfg
+}
+
+fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
+    let session_id_len: usize = 32;
+    let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
+    let mut handshake = vec![0x42u8; len];
+
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+    handshake
+}
+
+fn make_valid_mtproto_handshake(
+    secret_hex: &str,
+    proto_tag: ProtoTag,
+    dc_idx: i16,
+) -> [u8; HANDSHAKE_LEN] {
+    let secret = hex::decode(secret_hex).expect("secret hex must decode");
+    let mut handshake = [0x5Au8; HANDSHAKE_LEN];
+    for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
+        .iter_mut()
+        .enumerate()
+    {
+        *b = (idx as u8).wrapping_add(1);
+    }
+
+    let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
+    let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
+
+    let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
+    dec_key_input.extend_from_slice(dec_prekey);
+    dec_key_input.extend_from_slice(&secret);
+    let dec_key = sha256(&dec_key_input);
+
+    let mut dec_iv_arr = [0u8; IV_LEN];
+    dec_iv_arr.copy_from_slice(dec_iv_bytes);
+    let dec_iv = u128::from_be_bytes(dec_iv_arr);
+
+    let mut stream = AesCtr::new(&dec_key, dec_iv);
+    let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
+
+    let mut target_plain = [0u8; HANDSHAKE_LEN];
+    target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
+    target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
+
+    for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
+        handshake[idx] = target_plain[idx] ^ keystream[idx];
+    }
+
+    handshake
+}
+
+fn make_valid_tls_client_hello_with_alpn(
+    secret: &[u8],
+    timestamp: u32,
+    alpn_protocols: &[&[u8]],
+) -> Vec<u8> {
+    let mut body = Vec::new();
+    body.extend_from_slice(&TLS_VERSION);
+    body.extend_from_slice(&[0u8; 32]);
+    body.push(32);
+    body.extend_from_slice(&[0x42u8; 32]);
+    body.extend_from_slice(&2u16.to_be_bytes());
+    body.extend_from_slice(&[0x13, 0x01]);
+    body.push(1);
+    body.push(0);
+
+    let mut ext_blob = Vec::new();
+    if !alpn_protocols.is_empty() {
+        let mut alpn_list = Vec::new();
+        for proto in alpn_protocols {
+            alpn_list.push(proto.len() as u8);
+            alpn_list.extend_from_slice(proto);
+        }
+        let mut alpn_data = Vec::new();
+        alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
+        alpn_data.extend_from_slice(&alpn_list);
+
+        ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
+        ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
+        ext_blob.extend_from_slice(&alpn_data);
+    }
+    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
+    body.extend_from_slice(&ext_blob);
+
+    let mut handshake = Vec::new();
+    handshake.push(0x01);
+    let body_len = (body.len() as u32).to_be_bytes();
+    handshake.extend_from_slice(&body_len[1..4]);
+    handshake.extend_from_slice(&body);
+
+    let mut record = Vec::new();
+    record.push(TLS_RECORD_HANDSHAKE);
+    record.extend_from_slice(&[0x03, 0x01]);
+    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
+    record.extend_from_slice(&handshake);
+
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let computed = sha256_hmac(secret, &record);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
+
+    record
+}
+
+// --- Category 1: Edge Cases & Protocol Boundaries ---
+
+#[tokio::test]
+async fn tls_minimum_viable_length_boundary() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x11u8; 16];
+    let config = test_config_with_secret_hex("11111111111111111111111111111111");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
+
+    let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
+    let mut exact_min_handshake = vec![0x42u8; min_len];
+    exact_min_handshake[min_len - 1] = 0;
+    exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let digest = sha256_hmac(&secret, &exact_min_handshake);
+    exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest); + + let res = handle_tls_handshake( + &exact_min_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::Success(_)), "Exact minimum length TLS handshake must succeed"); + + let short_handshake = vec![0x42u8; min_len - 1]; + let res_short = handle_tls_handshake( + &short_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_short, HandshakeResult::BadClient { .. }), "Handshake 1 byte shorter than minimum must fail closed"); +} + +#[tokio::test] +async fn mtproto_extreme_dc_index_serialization() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "22222222222222222222222222222222"; + let config = test_config_with_secret_hex(secret_hex); + for (idx, extreme_dc) in [i16::MIN, i16::MAX, -1, 0].into_iter().enumerate() { + // Keep replay state independent per case so we validate dc_idx encoding, + // not duplicate-handshake rejection behavior. + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2)), 12345 + idx as u16); + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, extreme_dc); + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!(success.dc_idx, extreme_dc, "Extreme DC index {} must serialize/deserialize perfectly", extreme_dc); + } + _ => panic!("MTProto handshake with extreme DC index {} failed", extreme_dc), + } + } +} + +#[tokio::test] +async fn alpn_strict_case_and_padding_rejection() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x33u8; 16]; + let mut config = test_config_with_secret_hex("33333333333333333333333333333333"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); + + let bad_alpns: &[&[u8]] = &[b"H2", b"h2\0", b" http/1.1", b"http/1.1\n"]; + + for bad_alpn in bad_alpns { + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[*bad_alpn]); + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::BadClient { .. 
}), "ALPN strict enforcement must reject {:?}", bad_alpn); + } +} + +#[test] +fn ipv4_mapped_ipv6_bucketing_anomaly() { + let ipv4_mapped_1 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + let ipv4_mapped_2 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc633, 0x6402)); + + let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1); + let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2); + + assert_eq!(norm_1, norm_2, "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)"); + assert_eq!(norm_1, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), "The bucket must be exactly ::0"); +} + +// --- Category 2: Adversarial & Black Hat --- + +#[tokio::test] +async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "55555555555555555555555555555555"; + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.5:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let mut invalid_handshake = valid_handshake; + invalid_handshake[SKIP_LEN + PREKEY_LEN + IV_LEN + 1] ^= 0xFF; + + let res_invalid = handle_mtproto_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. })); + + let res_valid = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid MTProto ciphertext must not poison the replay cache"); +} + +#[tokio::test] +async fn tls_invalid_session_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x66u8; 16]; + let config = test_config_with_secret_hex("66666666666666666666666666666666"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.6:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + let mut invalid_handshake = valid_handshake.clone(); + let session_idx = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + invalid_handshake[session_idx] ^= 0xFF; + + let res_invalid = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. 
})); + + let res_valid = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid TLS payload must not poison the replay cache"); +} + +#[tokio::test] +async fn server_hello_delay_timing_neutrality_on_hmac_failure() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x77u8; 16]; + let mut config = test_config_with_secret_hex("77777777777777777777777777777777"); + config.censorship.server_hello_delay_min_ms = 50; + config.censorship.server_hello_delay_max_ms = 50; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.7:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let start = Instant::now(); + let res = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); + assert!(elapsed >= Duration::from_millis(45), "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels"); +} + +#[tokio::test] +async fn server_hello_delay_inversion_resilience() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x88u8; 16]; + let mut config = test_config_with_secret_hex("88888888888888888888888888888888"); + config.censorship.server_hello_delay_min_ms = 100; + config.censorship.server_hello_delay_max_ms = 10; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.8:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let start = Instant::now(); + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::Success(_))); + assert!(elapsed >= Duration::from_millis(90), "Delay logic must gracefully handle min > max inversions via max.max(min)"); +} + +#[tokio::test] +async fn mixed_valid_and_invalid_user_secrets_configuration() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + let _warn_guard = warned_secrets_test_lock().lock().unwrap(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.ignore_time_skew = true; + + for i in 0..9 { + let bad_secret = if i % 2 == 0 { "badhex!" 
} else { "1122" }; + config.access.users.insert(format!("bad_user_{}", i), bad_secret.to_string()); + } + let valid_secret_hex = "99999999999999999999999999999999"; + config.access.users.insert("good_user".to_string(), valid_secret_hex.to_string()); + config.general.modes.secure = true; + config.general.modes.classic = true; + config.general.modes.tls = true; + + let secret = [0x99u8; 16]; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.9:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "Proxy must gracefully skip invalid secrets and authenticate the valid one"); +} + +#[tokio::test] +async fn tls_emulation_fallback_when_cache_missing() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0xAAu8; 16]; + let mut config = test_config_with_secret_hex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + config.censorship.tls_emulation = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "TLS emulation must gracefully fall back to standard ServerHello if cache is missing"); +} + +#[tokio::test] +async fn classic_mode_over_tls_transport_protocol_confusion() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"; + let mut config = test_config_with_secret_hex(secret_hex); + config.general.modes.classic = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.11:12345".parse().unwrap(); + + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Intermediate, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + true, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior"); +} + +#[test] +fn generate_tg_nonce_never_emits_reserved_bytes() { + let client_enc_key = [0xCCu8; 32]; + let client_enc_iv = 123456789u128; + let rng = SecureRandom::new(); + + for _ in 0..10_000 { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 1, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert!(!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), "Nonce must never start with reserved bytes"); + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; + assert!(!RESERVED_NONCE_BEGINNINGS.contains(&first_four), "Nonce must never match reserved 4-byte beginnings"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn dashmap_concurrent_saturation_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + 
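+    // 100 spawned tasks alternate between two source IPs, each recording 50
+    // failures; both IPs must come out throttled, so this exercises DashMap
+    // consistency under contention rather than exact failure counts.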
+ let ip_a: IpAddr = "192.0.2.13".parse().unwrap(); + let ip_b: IpAddr = "198.51.100.13".parse().unwrap(); + let mut tasks = Vec::new(); + + for i in 0..100 { + let target_ip = if i % 2 == 0 { ip_a } else { ip_b }; + tasks.push(tokio::spawn(async move { + for _ in 0..50 { + auth_probe_record_failure(target_ip, Instant::now()); + } + })); + } + + for task in tasks { + task.await.expect("Task panicked during concurrent DashMap stress"); + } + + assert!(auth_probe_is_throttled_for_testing(ip_a), "IP A must be throttled after concurrent stress"); + assert!(auth_probe_is_throttled_for_testing(ip_b), "IP B must be throttled after concurrent stress"); +} + +#[test] +fn prototag_invalid_bytes_fail_closed() { + let invalid_tags: [[u8; 4]; 5] = [ + [0, 0, 0, 0], + [0xFF, 0xFF, 0xFF, 0xFF], + [0xDE, 0xAD, 0xBE, 0xEF], + [0xDD, 0xDD, 0xDD, 0xDE], + [0x11, 0x22, 0x33, 0x44], + ]; + + for tag in invalid_tags { + assert_eq!(ProtoTag::from_bytes(tag), None, "Invalid ProtoTag bytes {:?} must fail closed", tag); + } +} + +#[test] +fn auth_probe_eviction_hash_collision_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let now = Instant::now(); + + for i in 0..10_000u32 { + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, (i >> 8) as u8, (i & 0xFF) as u8)); + auth_probe_record_failure_with_state(state, ip, now); + } + + assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, "Eviction logic must successfully bound the map size under heavy insertion stress"); +} + +#[test] +fn encrypt_tg_nonce_with_ciphers_advances_counter_correctly() { + let client_enc_key = [0xDDu8; 32]; + let client_enc_iv = 987654321u128; + let rng = SecureRandom::new(); + + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let (_, mut returned_encryptor, _) = encrypt_tg_nonce_with_ciphers(&nonce); + let zeros = [0u8; 64]; + let returned_keystream = returned_encryptor.encrypt(&zeros); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let mut expected_enc_key = [0u8; 32]; + expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_enc_iv_arr = [0u8; IV_LEN]; + expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr); + + let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv); + + let mut manual_input = Vec::new(); + manual_input.extend_from_slice(&nonce); + manual_input.extend_from_slice(&zeros); + let manual_output = manual_encryptor.encrypt(&manual_input); + + assert_eq!( + returned_keystream, + &manual_output[64..128], + "encrypt_tg_nonce_with_ciphers must correctly advance the AES-CTR counter by exactly the nonce length" + ); +} diff --git a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs new file mode 100644 index 0000000..6c48cc1 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs @@ -0,0 +1,93 @@ +use super::*; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn adversarial_large_state_offsets_escape_first_scan_window() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 
65_536usize; + let scan_limit = 1_024usize; + + let mut saw_offset_outside_first_window = false; + for i in 0..8_192u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(131)) & 0xff) as u8, + )); + let now = base + Duration::from_nanos(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + if start >= scan_limit { + saw_offset_outside_first_window = true; + break; + } + } + + assert!( + saw_offset_outside_first_window, + "scan start offset must cover the full auth-probe state, not only the first scan window" + ); +} + +#[test] +fn stress_large_state_offsets_cover_many_scan_windows() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut covered_windows = HashSet::new(); + for i in 0..16_384u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(17)) & 0xff) as u8, + )); + let now = base + Duration::from_micros(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + covered_windows.insert(start / scan_limit); + } + + assert!( + covered_windows.len() >= 16, + "eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}", + covered_windows.len(), + state_len / scan_limit + ); +} + +#[test] +fn light_fuzz_offset_always_stays_inside_state_len() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xC0FF_EE12_3456_789Au64; + let base = Instant::now(); + + for _ in 0..8_192usize { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x0fff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + + assert!(start < state_len, "scan offset must stay inside state length"); + } +} \ No newline at end of file diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs new file mode 100644 index 0000000..ece6ff5 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -0,0 +1,99 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn edge_zero_state_len_yields_zero_start_offset() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44)); + let now = Instant::now(); + + assert_eq!( + auth_probe_scan_start_offset(ip, now, 0, 16), + 0, + "empty map must not produce non-zero scan offset" + ); +} + +#[test] +fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let scan_limit = 16usize; + let state_len = 65_536usize; + + let mut saw_offset_outside_window = false; + for i in 0..2048u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 203, + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + let now = base + Duration::from_micros(i as u64); + let start = 
auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
+        assert!(
+            start < state_len,
+            "start offset must stay within state length; start={start}, len={state_len}"
+        );
+        if start >= scan_limit {
+            saw_offset_outside_window = true;
+            break;
+        }
+    }
+
+    assert!(
+        saw_offset_outside_window,
+        "large-state eviction must sample beyond the first scan window"
+    );
+}
+
+#[test]
+fn positive_state_smaller_than_scan_limit_caps_to_state_len() {
+    let _guard = auth_probe_test_guard();
+    let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17));
+    let now = Instant::now();
+
+    for state_len in 1..32usize {
+        let start = auth_probe_scan_start_offset(ip, now, state_len, 64);
+        assert!(
+            start < state_len,
+            "start offset must never exceed state length when scan limit is larger"
+        );
+    }
+}
+
+#[test]
+fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
+    let _guard = auth_probe_test_guard();
+    let mut seed = 0x5A41_5356_4C32_3236u64;
+    let base = Instant::now();
+
+    for _ in 0..4096 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            (seed >> 24) as u8,
+            (seed >> 16) as u8,
+            (seed >> 8) as u8,
+            seed as u8,
+        ));
+        let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1);
+        let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1);
+        let now = base + Duration::from_nanos(seed & 0xffff);
+        let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
+
+        assert!(
+            start < state_len,
+            "scan offset must stay inside state length"
+        );
+    }
+}
\ No newline at end of file
diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs
new file mode 100644
index 0000000..260a1b9
--- /dev/null
+++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs
@@ -0,0 +1,116 @@
+use super::*;
+use std::collections::HashSet;
+use std::net::{IpAddr, Ipv4Addr};
+use std::time::{Duration, Instant};
+
+fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
+    auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+#[test]
+fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
+    let _guard = auth_probe_test_guard();
+    let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77));
+    let base = Instant::now();
+    let mut uniq = HashSet::new();
+
+    for i in 0..512u64 {
+        let now = base + Duration::from_nanos(i);
+        let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16);
+        uniq.insert(offset);
+    }
+
+    assert!(
+        uniq.len() >= 256,
+        "offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})",
+        uniq.len()
+    );
+}
+
+#[test]
+fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
+    let _guard = auth_probe_test_guard();
+    let now = Instant::now();
+    let mut uniq = HashSet::new();
+
+    for i in 0..1024u32 {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            (i >> 16) as u8,
+            (i >> 8) as u8,
+            i as u8,
+            (255 - (i as u8)),
+        ));
+        uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16));
+    }
+
+    assert!(
+        uniq.len() >= 512,
+        "scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})",
+        uniq.len()
+    );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let start = Instant::now();
+    let mut workers = Vec::new();
+    for worker in 0..8u8 {
+        workers.push(tokio::spawn(async move {
+            for i in 0..8192u32 {
+                let ip = IpAddr::V4(Ipv4Addr::new(
+                    10,
+                    worker,
+                    ((i >> 8) & 0xff) as u8,
+                    (i & 0xff) as u8,
+                ));
+                auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64));
+            }
+        }));
+    }
+
+    for worker in workers {
+        worker.await.expect("saturation worker must not panic");
+    }
+
+    assert!(
+        auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
+        "state must remain hard-capped under parallel saturation churn"
+    );
+
+    let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1));
+    let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1));
+}
+
+#[test]
+fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
+    let _guard = auth_probe_test_guard();
+    let mut seed = 0xA55A_1357_2468_9BDFu64;
+    let base = Instant::now();
+
+    for _ in 0..8192 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            (seed >> 24) as u8,
+            (seed >> 16) as u8,
+            (seed >> 8) as u8,
+            seed as u8,
+        ));
+        let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1);
+        let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1);
+        let now = base + Duration::from_nanos(seed & 0x1fff);
+
+        let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
+        assert!(
+            offset < state_len,
+            "scan offset must always remain inside state length"
+        );
+    }
+}
\ No newline at end of file
diff --git a/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs
new file mode 100644
index 0000000..7176b1c
--- /dev/null
+++ b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs
@@ -0,0 +1,42 @@
+use super::*;
+
+fn handshake_source() -> &'static str {
+    include_str!("../handshake.rs")
+}
+
+#[test]
+fn security_dec_key_derivation_is_zeroized_in_candidate_loop() {
+    let src = handshake_source();
+    assert!(
+        src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"),
+        "candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
+    );
+}
+
+#[test]
+fn security_enc_key_derivation_is_zeroized_in_candidate_loop() {
+    let src = handshake_source();
+    assert!(
+        src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"),
+        "candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
+    );
+}
+
+#[test]
+fn security_aes_ctr_initialization_uses_zeroizing_references() {
+    let src = handshake_source();
+    assert!(
+        src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);")
+            && src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"),
+        "AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables"
+    );
+}
+
+#[test]
+fn security_success_struct_copies_out_of_zeroizing_wrappers() {
+    let src = handshake_source();
+    assert!(
+        src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"),
+        "HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized"
+    );
+}
diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs
new file mode 100644
index 0000000..77df442
--- /dev/null
+++ b/src/proxy/tests/handshake_more_clever_tests.rs
@@ -0,0 +1,614 @@
+use super::*;
+use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
+use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES, TLS_RECORD_HANDSHAKE, TLS_VERSION};
+use rand::{Rng, RngCore, SeedableRng};
+use rand::rngs::StdRng;
+use std::collections::HashSet;
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::Barrier;
+
+// --- Helpers ---
+
+fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
+    auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
+    let mut cfg = ProxyConfig::default();
+    cfg.access.users.clear();
+    cfg.access
+        .users
+        .insert("user".to_string(), secret_hex.to_string());
+    cfg.access.ignore_time_skew = true;
+    cfg.general.modes.secure = true;
+    cfg.general.modes.classic = true;
+    cfg.general.modes.tls = true;
+    cfg
+}
+
+fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
+    let session_id_len: usize = 32;
+    let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
+    let mut handshake = vec![0x42u8; len];
+
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+    handshake
+}
+
+fn make_valid_mtproto_handshake(
+    secret_hex: &str,
+    proto_tag: ProtoTag,
+    dc_idx: i16,
+) -> [u8; HANDSHAKE_LEN] {
+    let secret = hex::decode(secret_hex).expect("secret hex must decode");
+    let mut handshake = [0x5Au8; HANDSHAKE_LEN];
+    for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
+        .iter_mut()
+        .enumerate()
+    {
+        *b = (idx as u8).wrapping_add(1);
+    }
+
+    let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
+    let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
+
+    let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
+    dec_key_input.extend_from_slice(dec_prekey);
+    dec_key_input.extend_from_slice(&secret);
+    let dec_key = sha256(&dec_key_input);
+
+    let mut dec_iv_arr = [0u8; IV_LEN];
+    dec_iv_arr.copy_from_slice(dec_iv_bytes);
+    let dec_iv = u128::from_be_bytes(dec_iv_arr);
+
+    let mut stream = AesCtr::new(&dec_key, dec_iv);
+    let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
+
+    let mut target_plain = [0u8; HANDSHAKE_LEN];
+    target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
+    target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
+
+    for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
+        handshake[idx] = target_plain[idx] ^ keystream[idx];
+    }
+
+    handshake
+}
+
+fn make_valid_tls_client_hello_with_sni_and_alpn(
+    secret: &[u8],
+    timestamp: u32,
+    sni_host: &str,
+    alpn_protocols: &[&[u8]],
+) -> Vec<u8> {
+    let mut body = Vec::new();
+    body.extend_from_slice(&TLS_VERSION);
+    body.extend_from_slice(&[0u8; 32]);
+    body.push(32);
+    body.extend_from_slice(&[0x42u8; 32]);
+    body.extend_from_slice(&2u16.to_be_bytes());
+    body.extend_from_slice(&[0x13, 0x01]);
+    body.push(1);
+    body.push(0);
+
+    let mut ext_blob = Vec::new();
+
+    let host_bytes = sni_host.as_bytes();
+    let mut sni_payload = Vec::new();
+    sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
+    sni_payload.push(0);
+    sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
+    sni_payload.extend_from_slice(host_bytes);
+    ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
+    ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
+    ext_blob.extend_from_slice(&sni_payload);
+
+    if !alpn_protocols.is_empty() {
+        let mut alpn_list = Vec::new();
+        for proto in alpn_protocols {
+            alpn_list.push(proto.len() as u8);
+            alpn_list.extend_from_slice(proto);
+        }
+        let mut alpn_data = Vec::new();
+        alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
+        alpn_data.extend_from_slice(&alpn_list);
+
+        ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
+        ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
+        ext_blob.extend_from_slice(&alpn_data);
+    }
+
+    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
+    body.extend_from_slice(&ext_blob);
+
+    let mut handshake = Vec::new();
+    handshake.push(0x01);
+    let body_len = (body.len() as u32).to_be_bytes();
+    handshake.extend_from_slice(&body_len[1..4]);
+    handshake.extend_from_slice(&body);
+
+    let mut record = Vec::new();
+    record.push(TLS_RECORD_HANDSHAKE);
+    record.extend_from_slice(&[0x03, 0x01]);
+    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
+    record.extend_from_slice(&handshake);
+
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let computed = sha256_hmac(secret, &record);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
+
+    record
+}
+
+// --- Category 1: Timing & Delay Invariants ---
+
+#[tokio::test]
+async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x1Au8; 16];
+    let mut config = test_config_with_secret_hex("1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a");
+    config.censorship.server_hello_delay_min_ms = 5000;
+    config.censorship.server_hello_delay_max_ms = 0;
+
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "192.0.2.101:12345".parse().unwrap();
+
+    let mut invalid_handshake = make_valid_tls_handshake(&secret, 0);
+    invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF;
+
+    let fut = handle_tls_handshake(
+        &invalid_handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    );
+
+    // Deterministic assertion: with max_ms == 0 there must be no sleep path,
+    // so the handshake should complete promptly under a generous timeout budget.
+    let res = tokio::time::timeout(Duration::from_millis(250), fut)
+        .await
+        .expect("max_ms=0 should bypass artificial delay and complete quickly");
+
+    assert!(matches!(res, HandshakeResult::BadClient { .. }));
+}
+
+#[test]
+fn auth_probe_backoff_extreme_fail_streak_clamps_safely() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let state = auth_probe_state_map();
+    let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 99));
+    let now = Instant::now();
+
+    state.insert(
+        peer_ip,
+        AuthProbeState {
+            fail_streak: u32::MAX - 1,
+            blocked_until: now,
+            last_seen: now,
+        },
+    );
+
+    auth_probe_record_failure_with_state(&state, peer_ip, now);
+
+    let updated = state.get(&peer_ip).unwrap();
+    assert_eq!(updated.fail_streak, u32::MAX);
+
+    let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS);
+    assert_eq!(updated.blocked_until, expected_blocked_until, "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS");
+}
+
+#[test]
+fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() {
+    let client_enc_key = [0x2Bu8; 32];
+    let client_enc_iv = 1337u128;
+    let rng = SecureRandom::new();
+
+    let mut nonces = HashSet::new();
+    let mut total_set_bits = 0usize;
+    let iterations = 5_000;
+
+    for _ in 0..iterations {
+        let (nonce, _, _, _, _) = generate_tg_nonce(
+            ProtoTag::Secure,
+            2,
+            &client_enc_key,
+            client_enc_iv,
+            &rng,
+            false,
+        );
+
+        for byte in nonce.iter() {
+            total_set_bits += byte.count_ones() as usize;
+        }
+
+        assert!(nonces.insert(nonce), "generate_tg_nonce emitted a duplicate nonce! RNG is stuck.");
+    }
+
+    let total_bits = iterations * HANDSHAKE_LEN * 8;
+    let ratio = (total_set_bits as f64) / (total_bits as f64);
+    assert!(ratio > 0.48 && ratio < 0.52, "Nonce entropy is degraded. Set bit ratio: {}", ratio);
+}
+
+#[tokio::test]
+async fn mtproto_multi_user_decryption_isolation() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let mut config = ProxyConfig::default();
+    config.general.modes.secure = true;
+    config.access.ignore_time_skew = true;
+
+    config.access.users.insert("user_a".to_string(), "11111111111111111111111111111111".to_string());
+    config.access.users.insert("user_b".to_string(), "22222222222222222222222222222222".to_string());
+    let good_secret_hex = "33333333333333333333333333333333";
+    config.access.users.insert("user_c".to_string(), good_secret_hex.to_string());
+
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap();
+
+    let valid_handshake = make_valid_mtproto_handshake(good_secret_hex, ProtoTag::Secure, 1);
+
+    let res = handle_mtproto_handshake(
+        &valid_handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        false,
+        None,
+    )
+    .await;
+
+    match res {
+        HandshakeResult::Success((_, _, success)) => {
+            assert_eq!(success.user, "user_c", "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user");
+        }
+        _ => panic!("Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."),
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn invalid_secret_warning_lock_contention_and_bound() {
+    let _guard = warned_secrets_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_warned_secrets_for_testing();
+
+    let tasks = 50;
+    let iterations_per_task = 100;
+    let barrier = Arc::new(Barrier::new(tasks));
+    let mut handles = Vec::new();
+
+    for t in 0..tasks {
+        let b = barrier.clone();
+        handles.push(tokio::spawn(async move {
+            b.wait().await;
+            for i in 0..iterations_per_task {
+                let user_name = format!("contention_user_{}_{}", t, i);
+                warn_invalid_secret_once(&user_name, "invalid_hex", ACCESS_SECRET_BYTES, None);
+            }
+        }));
+    }
+
+    for handle in handles {
+        handle.await.unwrap();
+    }
+
+    let warned = INVALID_SECRET_WARNED.get().unwrap();
+    let guard = warned.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
+
+    assert_eq!(
+        guard.len(),
+        WARNED_SECRET_MAX_ENTRIES,
+        "Concurrent spam of invalid secrets must strictly bound the HashSet memory to WARNED_SECRET_MAX_ENTRIES"
+    );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn mtproto_strict_concurrent_replay_race_condition() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A";
+    let config = Arc::new(test_config_with_secret_hex(secret_hex));
+    let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
+    let valid_handshake = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1));
+
+    let tasks = 100;
+    let barrier = Arc::new(Barrier::new(tasks));
+    let mut handles = Vec::new();
+
+    for i in 0..tasks {
+        let b = barrier.clone();
+        let cfg = config.clone();
+        let rc = replay_checker.clone();
+        let hs = valid_handshake.clone();
+
+        handles.push(tokio::spawn(async move {
+            let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), 10000 + i as u16);
+            b.wait().await;
+            handle_mtproto_handshake(
+                &hs,
+                tokio::io::empty(),
+                tokio::io::sink(),
+                peer,
+                &cfg,
+                &rc,
+                false,
+                None,
+            )
+            .await
+        }));
+    }
+
+    let mut successes = 0;
+    let mut failures = 0;
+
+    for handle in handles {
+        match handle.await.unwrap() {
+            HandshakeResult::Success(_) => successes += 1,
+            HandshakeResult::BadClient { .. } => failures += 1,
+            _ => panic!("Unexpected error result in concurrent MTProto replay test"),
+        }
+    }
+
+    assert_eq!(successes, 1, "Replay cache race condition allowed multiple identical MTProto handshakes to succeed");
+    assert_eq!(failures, tasks - 1, "Replay cache failed to forcefully reject concurrent duplicates");
+}
+
+#[tokio::test]
+async fn tls_alpn_zero_length_protocol_handled_safely() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x5Bu8; 16];
+    let mut config = test_config_with_secret_hex("5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b");
+    config.censorship.alpn_enforce = true;
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap();
+
+    let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]);
+
+    let res = handle_tls_handshake(
+        &handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(res, HandshakeResult::BadClient { .. }), "0-length ALPN must be safely rejected without panicking");
+}
+
+#[tokio::test]
+async fn tls_sni_massive_hostname_does_not_panic() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x6Cu8; 16];
+    let config = test_config_with_secret_hex("6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap();
+
+    let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap();
+    let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]);
+
+    let res = handle_tls_handshake(
+        &handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(res, HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }), "Massive SNI hostname must be processed or ignored without stack overflow or panic");
+}
+
+#[tokio::test]
+async fn tls_progressive_truncation_fuzzing_no_panics() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x7Du8; 16];
+    let config = test_config_with_secret_hex("7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap();
+
+    let valid_handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]);
+    let full_len = valid_handshake.len();
+
+    // Truncated corpus only: full_len is a valid baseline and should not be
+    // asserted as BadClient in a truncation-specific test.
+    for i in (0..full_len).rev() {
+        let truncated = &valid_handshake[..i];
+        let res = handle_tls_handshake(
+            truncated,
+            tokio::io::empty(),
+            tokio::io::sink(),
+            peer,
+            &config,
+            &replay_checker,
+            &rng,
+            None,
+        )
+        .await;
+        assert!(matches!(res, HandshakeResult::BadClient { .. }), "Truncated TLS handshake at len {} must fail safely without panicking", i);
+    }
+}
+
+#[tokio::test]
+async fn mtproto_pure_entropy_fuzzing_no_panics() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let peer: SocketAddr = "192.0.2.110:12345".parse().unwrap();
+
+    let mut seeded = StdRng::seed_from_u64(0xDEADBEEFCAFE);
+
+    for _ in 0..10_000 {
+        let mut noise = [0u8; HANDSHAKE_LEN];
+        seeded.fill_bytes(&mut noise);
+
+        let res = handle_mtproto_handshake(
+            &noise,
+            tokio::io::empty(),
+            tokio::io::sink(),
+            peer,
+            &config,
+            &replay_checker,
+            false,
+            None,
+        )
+        .await;
+
+        assert!(matches!(res, HandshakeResult::BadClient { .. }), "Pure entropy MTProto payload must fail closed and never panic");
+    }
+}
+
+#[test]
+fn decode_user_secret_odd_length_hex_rejection() {
+    let _guard = warned_secrets_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_warned_secrets_for_testing();
+
+    let mut config = ProxyConfig::default();
+    config.access.users.clear();
+    config.access.users.insert("odd_user".to_string(), "1234567890123456789012345678901".to_string());
+
+    let decoded = decode_user_secrets(&config, None);
+    assert!(decoded.is_empty(), "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping");
+}
+
+#[test]
+fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let state = auth_probe_state_map();
+    let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 112));
+    let now = Instant::now();
+
+    let extreme_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + 5;
+    state.insert(
+        peer_ip,
+        AuthProbeState {
+            fail_streak: extreme_streak,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        },
+    );
+
+    {
+        let mut guard = auth_probe_saturation_state_lock();
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        });
+    }
+
+    let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now);
+    assert!(is_throttled, "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period");
+}
+
+#[test]
+fn auth_probe_saturation_note_resets_retention_window() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let base_time = Instant::now();
+
+    auth_probe_note_saturation(base_time);
+    let later = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS - 1);
+    auth_probe_note_saturation(later);
+
+    let check_time = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 5);
+
+    // This call may return false if backoff has elapsed, but it must not clear
+    // the saturation state because `later` refreshed last_seen.
+    let _ = auth_probe_saturation_is_throttled_at_for_testing(check_time);
+    let guard = auth_probe_saturation_state_lock();
+    assert!(
+        guard.is_some(),
+        "Ongoing saturation notes must refresh last_seen so saturation state remains retained past the original window"
+    );
+}
+
+#[test]
+fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() {
+    let mut config = ProxyConfig::default();
+    config.general.modes.classic = false;
+    config.general.modes.secure = true;
+    config.general.modes.tls = false;
+
+    assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false));
+    assert!(!mode_enabled_for_proto(&config, ProtoTag::Intermediate, false));
+}
+
+#[test]
+fn mtproto_secure_tag_rejected_when_only_classic_mode_enabled() {
+    let mut config = ProxyConfig::default();
+    config.general.modes.classic = true;
+    config.general.modes.secure = false;
+    config.general.modes.tls = false;
+
+    assert!(!mode_enabled_for_proto(&config, ProtoTag::Secure, false));
+}
+
+#[test]
+fn ipv6_localhost_and_unspecified_normalization() {
+    let localhost = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
+    let unspecified = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
+
+    let norm_local = normalize_auth_probe_ip(localhost);
+    let norm_unspec = normalize_auth_probe_ip(unspecified);
+
+    let expected_bucket = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
+
+    assert_eq!(norm_local, expected_bucket);
+    assert_eq!(norm_unspec, expected_bucket);
+}
diff --git a/src/proxy/tests/handshake_real_bug_stress_tests.rs b/src/proxy/tests/handshake_real_bug_stress_tests.rs
new file mode 100644
index 0000000..d7234ff
--- /dev/null
+++ b/src/proxy/tests/handshake_real_bug_stress_tests.rs
@@ -0,0 +1,337 @@
+use super::*;
+use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
+use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::Barrier;
+
+fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
+    auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
+    let mut cfg = ProxyConfig::default();
+    cfg.access.users.clear();
+    cfg.access
+        .users
+        .insert("user".to_string(), secret_hex.to_string());
+    cfg.access.ignore_time_skew = true;
+    cfg.general.modes.secure = true;
+    cfg.general.modes.classic = true;
+    cfg.general.modes.tls = true;
+    cfg
+}
+
+fn make_valid_tls_client_hello_with_alpn(
+    secret: &[u8],
+    timestamp: u32,
+    alpn_protocols: &[&[u8]],
+) -> Vec<u8> {
+    let mut body = Vec::new();
+    body.extend_from_slice(&TLS_VERSION);
+    body.extend_from_slice(&[0u8; 32]);
+    body.push(32);
+    body.extend_from_slice(&[0x42u8; 32]);
+    body.extend_from_slice(&2u16.to_be_bytes());
+    body.extend_from_slice(&[0x13, 0x01]);
+    body.push(1);
+    body.push(0);
+
+    let mut ext_blob = Vec::new();
+    if !alpn_protocols.is_empty() {
+        let mut alpn_list = Vec::new();
+        for proto in alpn_protocols {
+            alpn_list.push(proto.len() as u8);
+            alpn_list.extend_from_slice(proto);
+        }
+        let mut alpn_data = Vec::new();
+        alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
+        alpn_data.extend_from_slice(&alpn_list);
+
+        ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
+        ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
+        ext_blob.extend_from_slice(&alpn_data);
+    }
+
+    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
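+    // The extensions block is length-prefixed ([u16 total][blob]); the ALPN
+    // extension (type 0x0010) nests its own u16 list length over length-prefixed
+    // protocol names, mirroring RFC 7301 framing.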
+    body.extend_from_slice(&ext_blob);
+
+    let mut handshake = Vec::new();
+    handshake.push(0x01);
+    let body_len = (body.len() as u32).to_be_bytes();
+    handshake.extend_from_slice(&body_len[1..4]);
+    handshake.extend_from_slice(&body);
+
+    let mut record = Vec::new();
+    record.push(TLS_RECORD_HANDSHAKE);
+    record.extend_from_slice(&[0x03, 0x01]);
+    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
+    record.extend_from_slice(&handshake);
+
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+
+    let computed = sha256_hmac(secret, &record);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+
+    record
+}
+
+fn make_valid_mtproto_handshake(
+    secret_hex: &str,
+    proto_tag: ProtoTag,
+    dc_idx: i16,
+) -> [u8; HANDSHAKE_LEN] {
+    let secret = hex::decode(secret_hex).expect("secret hex must decode");
+    let mut handshake = [0x5Au8; HANDSHAKE_LEN];
+    for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
+        .iter_mut()
+        .enumerate()
+    {
+        *b = (idx as u8).wrapping_add(1);
+    }
+
+    let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
+    let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
+
+    let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
+    dec_key_input.extend_from_slice(dec_prekey);
+    dec_key_input.extend_from_slice(&secret);
+    let dec_key = sha256(&dec_key_input);
+
+    let mut dec_iv_arr = [0u8; IV_LEN];
+    dec_iv_arr.copy_from_slice(dec_iv_bytes);
+    let dec_iv = u128::from_be_bytes(dec_iv_arr);
+
+    let mut stream = AesCtr::new(&dec_key, dec_iv);
+    let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
+
+    let mut target_plain = [0u8; HANDSHAKE_LEN];
+    target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
+    target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
+
+    for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
+        handshake[idx] = target_plain[idx] ^ keystream[idx];
+    }
+
+    handshake
+}
+
+#[tokio::test]
+async fn tls_alpn_reject_does_not_pollute_replay_cache() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x11u8; 16];
+    let mut config = test_config_with_secret_hex("11111111111111111111111111111111");
+    config.censorship.alpn_enforce = true;
+
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "192.0.2.201:12345".parse().unwrap();
+
+    let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
+    let before = replay_checker.stats();
+
+    let res = handle_tls_handshake(
+        &handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    let after = replay_checker.stats();
+
+    assert!(matches!(res, HandshakeResult::BadClient { .. }));
+    assert_eq!(
+        before.total_additions, after.total_additions,
+        "ALPN policy reject must not add TLS digest into replay cache"
+    );
+}
+
+#[tokio::test]
+async fn tls_truncated_session_id_len_fails_closed_without_panic() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("33333333333333333333333333333333");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "192.0.2.203:12345".parse().unwrap();
+
+    let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
+    let mut malicious = vec![0x42u8; min_len];
+    malicious[min_len - 1] = u8::MAX;
+
+    let res = handle_tls_handshake(
+        &malicious,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(res, HandshakeResult::BadClient { .. }));
+}
+
+#[test]
+fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let state = auth_probe_state_map();
+    let same = Instant::now();
+
+    for i in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
+        let ip = IpAddr::V4(Ipv4Addr::new(10, 1, (i >> 8) as u8, (i & 0xFF) as u8));
+        state.insert(
+            ip,
+            AuthProbeState {
+                fail_streak: 7,
+                blocked_until: same,
+                last_seen: same,
+            },
+        );
+    }
+
+    let new_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 21, 21));
+    auth_probe_record_failure_with_state(state, new_ip, same + Duration::from_millis(1));
+
+    assert_eq!(state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES);
+    assert!(state.contains_key(&new_ip));
+}
+
+#[test]
+fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let saturation = auth_probe_saturation_state();
+    let poison_thread = std::thread::spawn(move || {
+        let _hold = saturation
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        panic!("intentional poison for regression coverage");
+    });
+    let _ = poison_thread.join();
+
+    clear_auth_probe_state_for_testing();
+
+    let guard = auth_probe_saturation_state()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    assert!(guard.is_none());
+}
+
+#[tokio::test]
+async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() {
+    let _probe_guard = auth_probe_test_guard();
+    let _warn_guard = warned_secrets_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+    clear_warned_secrets_for_testing();
+
+    let mut config = ProxyConfig::default();
+    config.general.modes.secure = true;
+    config.access.ignore_time_skew = true;
+
+    config.access.users.insert(
+        "short_user".to_string(),
+        "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(),
+    );
+
+    let valid_secret_hex = "77777777777777777777777777777777";
+    config
+        .access
+        .users
+        .insert("good_user".to_string(), valid_secret_hex.to_string());
+
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let peer: SocketAddr = "192.0.2.207:12345".parse().unwrap();
+    let handshake = make_valid_mtproto_handshake(valid_secret_hex, ProtoTag::Secure, 1);
+
+    let res = handle_mtproto_handshake(
+        &handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        false,
+        None,
+    )
+    .await;
+
+    assert!(matches!(res, HandshakeResult::Success(_)));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 80));
+    let now = Instant::now();
+
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        });
+    }
+
+    let state = auth_probe_state_map();
+    state.insert(
+        peer_ip,
+        AuthProbeState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1,
+            blocked_until: now,
+            last_seen: now,
+        },
+    );
+
+    let tasks = 32;
+    let barrier = Arc::new(Barrier::new(tasks));
+    let mut handles = Vec::new();
+
+    for _ in 0..tasks {
+        let b = barrier.clone();
+        handles.push(tokio::spawn(async move {
+            b.wait().await;
+            auth_probe_record_failure(peer_ip, Instant::now());
+        }));
+    }
+
+    for handle in handles {
+        handle.await.unwrap();
+    }
+
+    let final_state = state.get(&peer_ip).expect("state must exist");
+    assert!(
+        final_state.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
+    );
+    assert!(auth_probe_should_apply_preauth_throttle(peer_ip, Instant::now()));
+}
diff --git a/src/proxy/tests/handshake_timing_manual_bench_tests.rs b/src/proxy/tests/handshake_timing_manual_bench_tests.rs
new file mode 100644
index 0000000..95e9f49
--- /dev/null
+++ b/src/proxy/tests/handshake_timing_manual_bench_tests.rs
@@ -0,0 +1,318 @@
+use super::*;
+use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
+use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
+use std::net::SocketAddr;
+use std::time::{Duration, Instant};
+
+fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
+    auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+fn make_valid_mtproto_handshake(
+    secret_hex: &str,
+    proto_tag: ProtoTag,
+    dc_idx: i16,
+    salt: u8,
+) -> [u8; HANDSHAKE_LEN] {
+    let secret = hex::decode(secret_hex).expect("secret hex must decode");
+    let mut handshake = [0x5Au8; HANDSHAKE_LEN];
+
+    for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
+        .iter_mut()
+        .enumerate()
+    {
+        *b = (idx as u8).wrapping_add(1).wrapping_add(salt);
+    }
+
+    let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
+    let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
+
+    let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
+    dec_key_input.extend_from_slice(dec_prekey);
+    dec_key_input.extend_from_slice(&secret);
+    let dec_key = sha256(&dec_key_input);
+
+    let mut dec_iv_arr = [0u8; IV_LEN];
+    dec_iv_arr.copy_from_slice(dec_iv_bytes);
+    let dec_iv = u128::from_be_bytes(dec_iv_arr);
+
+    let mut stream = AesCtr::new(&dec_key, dec_iv);
+    let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
+
+    let mut target_plain = [0u8; HANDSHAKE_LEN];
+    target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
+    target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
+
+    for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
+        handshake[idx] = target_plain[idx] ^ keystream[idx];
+    }
+
+    handshake
+}
+
+fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
+    let session_id_len: usize = 32;
+    let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
+    let mut handshake = vec![0x42u8; len];
+
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+    handshake
+}
+
+fn make_valid_tls_client_hello_with_sni_and_alpn(
+    secret: &[u8],
+    timestamp: u32,
+    sni_host: &str,
+    alpn_protocols: &[&[u8]],
+) -> Vec<u8> {
+    let mut body = Vec::new();
+    body.extend_from_slice(&TLS_VERSION);
+    body.extend_from_slice(&[0u8; 32]);
+    body.push(32);
+    body.extend_from_slice(&[0x42u8; 32]);
+    body.extend_from_slice(&2u16.to_be_bytes());
+    body.extend_from_slice(&[0x13, 0x01]);
+    body.push(1);
+    body.push(0);
+
+    let mut ext_blob = Vec::new();
+    let host_bytes = sni_host.as_bytes();
+    let mut sni_payload = Vec::new();
+    sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
+    sni_payload.push(0);
+    sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
+    sni_payload.extend_from_slice(host_bytes);
+    ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
+    ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
+    ext_blob.extend_from_slice(&sni_payload);
+
+    if !alpn_protocols.is_empty() {
+        let mut alpn_list = Vec::new();
+        for proto in alpn_protocols {
+            alpn_list.push(proto.len() as u8);
+            alpn_list.extend_from_slice(proto);
+        }
+        let mut alpn_data = Vec::new();
+        alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
+        alpn_data.extend_from_slice(&alpn_list);
+
+        ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
+        ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
+        ext_blob.extend_from_slice(&alpn_data);
+    }
+
+    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
+    body.extend_from_slice(&ext_blob);
+
+    let mut handshake = Vec::new();
+    handshake.push(0x01);
+    let body_len = (body.len() as u32).to_be_bytes();
+    handshake.extend_from_slice(&body_len[1..4]);
+    handshake.extend_from_slice(&body);
+
+    let mut record = Vec::new();
+    record.push(TLS_RECORD_HANDSHAKE);
+    record.extend_from_slice(&[0x03, 0x01]);
+    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
+    record.extend_from_slice(&handshake);
+
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let computed = sha256_hmac(secret, &record);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
+
+    record
+}
+
+fn median_ns(samples: &mut [u128]) -> u128 {
+    samples.sort_unstable();
+    samples[samples.len() / 2]
+}
+
+#[tokio::test]
+#[ignore = "manual benchmark: timing-sensitive and host-dependent"]
+async fn mtproto_user_scan_timing_manual_benchmark() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    const DECOY_USERS: usize = 8_000;
+    const ITERATIONS: usize = 250;
+
+    let preferred_user = "target_user";
+    let target_secret_hex = "dededededededededededededededede";
+
+    let mut config = ProxyConfig::default();
+    config.general.modes.secure = true;
+    config.access.ignore_time_skew = true;
+
+    for i in 0..DECOY_USERS {
+        config.access.users.insert(
+            format!("decoy_{i}"),
+            "00000000000000000000000000000000".to_string(),
+        );
+    }
+
+    config.access.users.insert(
+        preferred_user.to_string(),
+        target_secret_hex.to_string(),
+    );
+
+    let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60));
+    let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60));
+    let peer_a: SocketAddr = "192.0.2.241:12345".parse().unwrap();
+    let peer_b: SocketAddr = "192.0.2.242:12345".parse().unwrap();
+
+    let mut preferred_samples = Vec::with_capacity(ITERATIONS);
+    let mut full_scan_samples = Vec::with_capacity(ITERATIONS);
+
+    for i in 0..ITERATIONS {
+        let handshake = make_valid_mtproto_handshake(
+            target_secret_hex,
+            ProtoTag::Secure,
+            1 + i as i16,
+            (i % 251) as u8,
+        );
+
+        let started_preferred = Instant::now();
+        let preferred = handle_mtproto_handshake(
+            &handshake,
+            tokio::io::empty(),
+            tokio::io::sink(),
+            peer_a,
+            &config,
+            &replay_checker_preferred,
+            false,
+            Some(preferred_user),
+        )
+        .await;
+        preferred_samples.push(started_preferred.elapsed().as_nanos());
+        assert!(matches!(preferred, HandshakeResult::Success(_)));
+
+        let started_scan = Instant::now();
+        let full_scan = handle_mtproto_handshake(
+            &handshake,
+            tokio::io::empty(),
+            tokio::io::sink(),
+            peer_b,
+            &config,
+            &replay_checker_full_scan,
+            false,
+            None,
+        )
+        .await;
+        full_scan_samples.push(started_scan.elapsed().as_nanos());
+        assert!(matches!(full_scan, HandshakeResult::Success(_)));
+    }
+
+    let preferred_median = median_ns(&mut preferred_samples);
+    let full_scan_median = median_ns(&mut full_scan_samples);
+
+    let ratio = if preferred_median == 0 {
+        0.0
+    } else {
+        full_scan_median as f64 / preferred_median as f64
+    };
+
+    println!(
+        "manual timing benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, preferred_median_ns={preferred_median}, full_scan_median_ns={full_scan_median}, ratio={ratio:.3}"
+    );
+
+    assert!(
+        full_scan_median >= preferred_median,
+        "full user scan should not be faster than preferred-user path in this benchmark"
+    );
+}
+
+#[tokio::test]
+#[ignore = "manual benchmark: timing-sensitive and host-dependent"]
+async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() {
+    let _guard = auth_probe_test_guard();
+
+    const DECOY_USERS: usize = 8_000;
+    const ITERATIONS: usize = 250;
+
+    let preferred_user = "user-b";
+    let target_secret_hex = "abababababababababababababababab";
+    let target_secret = [0xABu8; 16];
+
+    let mut config = ProxyConfig::default();
+    config.general.modes.tls = true;
+    config.access.ignore_time_skew = true;
+
+    for i in 0..DECOY_USERS {
+        config.access.users.insert(
+            format!("decoy_{i}"),
+            "00000000000000000000000000000000".to_string(),
+        );
+    }
+
+    config
+        .access
+        .users
+        .insert(preferred_user.to_string(), target_secret_hex.to_string());
+
+    let mut sni_samples = Vec::with_capacity(ITERATIONS);
+    let mut no_sni_samples = Vec::with_capacity(ITERATIONS);
+
+    for i in 0..ITERATIONS {
+        let with_sni = make_valid_tls_client_hello_with_sni_and_alpn(
+            &target_secret,
+            i as u32,
+            preferred_user,
+            &[b"h2"],
+        );
+        let no_sni = make_valid_tls_handshake(&target_secret, (i as u32).wrapping_add(10_000));
+
+        let started_sni = Instant::now();
+        let sni_secrets = decode_user_secrets(&config, Some(preferred_user));
+        let sni_result = tls::validate_tls_handshake_with_replay_window(
+            &with_sni,
+            &sni_secrets,
+            config.access.ignore_time_skew,
+            config.access.replay_window_secs,
+        );
+        sni_samples.push(started_sni.elapsed().as_nanos());
+        assert!(sni_result.is_some());
+
+        let started_no_sni = Instant::now();
+        let no_sni_secrets = decode_user_secrets(&config, None);
+        let no_sni_result = tls::validate_tls_handshake_with_replay_window(
+            &no_sni,
+            &no_sni_secrets,
+            config.access.ignore_time_skew,
+            config.access.replay_window_secs,
+        );
+        no_sni_samples.push(started_no_sni.elapsed().as_nanos());
+        assert!(no_sni_result.is_some());
+    }
+
+    let sni_median = median_ns(&mut sni_samples);
+    let no_sni_median = median_ns(&mut no_sni_samples);
+
+    let ratio = if sni_median == 0 {
+        0.0
+    } else {
+        no_sni_median as f64 / sni_median as f64
+    };
+
+    println!(
+        "manual tls benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, sni_median_ns={sni_median}, no_sni_median_ns={no_sni_median}, ratio_no_sni_over_sni={ratio:.3}"
+    );
+}
diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs
index 3e860e8..84c904f 100644
--- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs
+++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs
@@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
     ];
 
     let mut meaningful_improvement_seen = false;
-    let mut baseline_sum = 0.0f64;
-    let mut hardened_sum = 0.0f64;
-    let mut pair_count = 0usize;
+    let mut informative_baseline_sum = 0.0f64;
+    let mut informative_hardened_sum = 0.0f64;
+    let mut informative_pair_count = 0usize;
+    let mut low_info_baseline_sum = 0.0f64;
+    let mut low_info_hardened_sum = 0.0f64;
+    let mut low_info_pair_count = 0usize;
 
     let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
     let tolerated_pair_regression = acc_quant_step + 0.03;
@@ -522,6 +525,16 @@
             hardened_acc <= baseline_acc + tolerated_pair_regression,
             "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}"
         );
+        informative_baseline_sum += baseline_acc;
+        informative_hardened_sum += hardened_acc;
+        informative_pair_count += 1;
+    } else {
+        // Low-information pairs (near-random baseline separability) are expected
+        // to exhibit quantized jitter at low sample counts; do not fold them into
+        // strict average-regression checks used for informative side-channel signal.
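+        // With accuracy quantized to acc_quant_step = 1/(2*SAMPLE_COUNT), a
+        // near-0.5 pair can drift by several quanta between runs, hence the
+        // looser 0.40 budget applied to the low-info average further below.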
+        low_info_baseline_sum += baseline_acc;
+        low_info_hardened_sum += hardened_acc;
+        low_info_pair_count += 1;
     }
 
     println!(
@@ -532,19 +545,30 @@
         meaningful_improvement_seen = true;
     }
 
-        baseline_sum += baseline_acc;
-        hardened_sum += hardened_acc;
-        pair_count += 1;
     }
 
-    let baseline_avg = baseline_sum / pair_count as f64;
-    let hardened_avg = hardened_sum / pair_count as f64;
+    assert!(
+        informative_pair_count > 0,
+        "expected at least one informative pair for timing-separability guard"
+    );
+
+    let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64;
+    let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64;
 
     assert!(
-        hardened_avg <= baseline_avg + 0.10,
-        "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
+        informative_hardened_avg <= informative_baseline_avg + 0.10,
+        "normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}"
     );
 
+    if low_info_pair_count > 0 {
+        let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64;
+        let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64;
+        assert!(
+            low_info_hardened_avg <= low_info_baseline_avg + 0.40,
+            "normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}"
+        );
+    }
+
     // Optional signal only: do not require improvement on every run because
     // noisy CI schedulers can flatten pairwise differences at low sample counts.
     let _ = meaningful_improvement_seen;
diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs
new file mode 100644
index 0000000..29170c1
--- /dev/null
+++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs
@@ -0,0 +1,122 @@
+use super::*;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::task::{Context, Poll};
+use tokio::io::AsyncRead;
+use tokio::time::{Duration, timeout};
+
+struct EndlessReader {
+    produced: Arc<AtomicUsize>,
+}
+
+impl AsyncRead for EndlessReader {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let len = buf.remaining().max(1);
+        let fill = vec![0xAA; len];
+        buf.put_slice(&fill);
+        self.produced.fetch_add(len, Ordering::Relaxed);
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[test]
+fn loop_guard_unspecified_bind_uses_interface_inventory() {
+    let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
+    let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap();
+    let interfaces = vec!["192.168.44.10".parse().unwrap()];
+
+    assert!(is_mask_target_local_listener_with_interfaces(
+        "mask.example",
+        443,
+        local,
+        Some(resolved),
+        &interfaces,
+    ));
+}
+
+#[tokio::test]
+async fn consume_client_data_stops_after_byte_cap_without_eof() {
+    let produced = Arc::new(AtomicUsize::new(0));
+    let reader = EndlessReader {
+        produced: Arc::clone(&produced),
+    };
+    let cap = 10_000usize;
+
+    consume_client_data(reader, cap).await;
+
+    let total = produced.load(Ordering::Relaxed);
+    assert!(
+        total >= cap,
+        "consume path must read at least up to cap before stopping"
+    );
+    assert!(
+        total <= cap + 8192,
+        "consume path must stop within one read chunk above cap"
+    );
+}
+
+#[test]
+fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() {
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = true;
+    config.general.beobachten_minutes = 0;
+
+    let ttl = masking_beobachten_ttl(&config);
+    assert_eq!(ttl, std::time::Duration::from_secs(60));
+}
+
+#[test]
+fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() {
+    let mut config = ProxyConfig::default();
+    config.censorship.mask_timing_normalization_enabled = true;
+    config.censorship.mask_timing_normalization_floor_ms = 0;
+    config.censorship.mask_timing_normalization_ceiling_ms = 0;
+
+    let budget = mask_outcome_target_budget(&config);
+    assert_eq!(budget, MASK_TIMEOUT);
+}
+
+#[tokio::test]
+async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() {
+    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = listener.local_addr().unwrap();
+    let accept_task = tokio::spawn(async move {
+        timeout(Duration::from_millis(120), listener.accept())
+            .await
+            .is_ok()
+    });
+
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = false;
+    config.censorship.mask = true;
+    config.censorship.mask_unix_sock = None;
+    config.censorship.mask_host = Some("127.0.0.1".to_string());
+    config.censorship.mask_port = backend_addr.port();
+    config.censorship.mask_proxy_protocol = 2;
+
+    let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap();
+    let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap();
+    let beobachten = BeobachtenStore::new();
+
+    handle_bad_client(
+        tokio::io::empty(),
+        tokio::io::sink(),
+        b"GET / HTTP/1.1\r\n\r\n",
+        peer,
+        local_addr,
+        &config,
+        &beobachten,
+    )
+    .await;
+
+    let accepted = accept_task.await.unwrap();
+    assert!(
+        !accepted,
+        "loop guard must fail closed before any recursive PROXY protocol amplification"
+    );
+}
diff --git a/src/proxy/tests/masking_classification_completeness_security_tests.rs b/src/proxy/tests/masking_classification_completeness_security_tests.rs
new file mode 100644
index 0000000..35bf87b
--- /dev/null
+++ b/src/proxy/tests/masking_classification_completeness_security_tests.rs
@@ -0,0 +1,16 @@
+use super::*;
+
+#[test]
+fn detect_client_type_recognizes_extended_http_probe_verbs() {
+    assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP");
+    assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP");
+    assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP");
+}
+
+#[test]
+fn detect_client_type_recognizes_fragmented_http_method_prefixes() {
+    assert_eq!(detect_client_type(b"CO"), "HTTP");
+    assert_eq!(detect_client_type(b"CON"), "HTTP");
+    assert_eq!(detect_client_type(b"TR"), "HTTP");
+    assert_eq!(detect_client_type(b"PAT"), "HTTP");
+}
diff --git a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs
new file mode 100644
index 0000000..614af9b
--- /dev/null
+++ b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs
@@ -0,0 +1,127 @@
+use super::*;
+use crate::network::dns_overrides::install_entries;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::time::{Duration, Instant, timeout};
+
+async fn run_connect_failure_case(
+    host: &str,
+    port: u16,
+    timing_normalization_enabled: bool,
+    peer: SocketAddr,
+) -> Duration {
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = false;
+    config.censorship.mask = true;
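+    // Pin the TCP mask path: no unix socket, explicit host/port, and (when
+    // normalization is enabled) floor == ceiling == 120ms so the refusal
+    // envelopes asserted by the matrix tests below are deterministic.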
+    config.censorship.mask_unix_sock = None;
+    config.censorship.mask_host = Some(host.to_string());
+    config.censorship.mask_port = port;
+    config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
+    config.censorship.mask_timing_normalization_floor_ms = 120;
+    config.censorship.mask_timing_normalization_ceiling_ms = 120;
+
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+    let beobachten = BeobachtenStore::new();
+    let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n";
+
+    let (mut client_writer, client_reader) = duplex(1024);
+    let (mut client_visible_reader, client_visible_writer) = duplex(1024);
+
+    let started = Instant::now();
+    let task = tokio::spawn(async move {
+        handle_bad_client(
+            client_reader,
+            client_visible_writer,
+            probe,
+            peer,
+            local_addr,
+            &config,
+            &beobachten,
+        )
+        .await;
+    });
+
+    client_writer.shutdown().await.unwrap();
+
+    timeout(Duration::from_secs(4), task)
+        .await
+        .unwrap()
+        .unwrap();
+
+    let mut buf = [0u8; 1];
+    let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(n, 0, "connect-failure path must close client-visible writer");
+
+    started.elapsed()
+}
+
+#[tokio::test]
+async fn connect_failure_refusal_close_behavior_matrix() {
+    let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let unused_port = temp_listener.local_addr().unwrap().port();
+    drop(temp_listener);
+
+    for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
+        let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16)
+            .parse()
+            .unwrap();
+        let elapsed = run_connect_failure_case(
+            "127.0.0.1",
+            unused_port,
+            timing_normalization_enabled,
+            peer,
+        )
+        .await;
+
+        if timing_normalization_enabled {
+            assert!(
+                elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
+                "normalized refusal path must honor configured timing envelope without stalling"
+            );
+        } else {
+            assert!(
+                elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
+                "non-normalized refusal path must honor baseline connect budget without stalling"
+            );
+        }
+    }
+}
+
+#[tokio::test]
+async fn connect_failure_overridden_hostname_close_behavior_matrix() {
+    let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let unused_port = temp_listener.local_addr().unwrap().port();
+    drop(temp_listener);
+
+    // Make hostname resolution deterministic in tests so timing ceilings are meaningful.
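+    // Each entry is formatted host:port:ip (as used here), pinning mask.invalid
+    // to 127.0.0.1; the trailing install_entries(&[]) below clears the override.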
+    install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap();
+
+    for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
+        let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16)
+            .parse()
+            .unwrap();
+        let elapsed = run_connect_failure_case(
+            "mask.invalid",
+            unused_port,
+            timing_normalization_enabled,
+            peer,
+        )
+        .await;
+
+        if timing_normalization_enabled {
+            assert!(
+                elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
+                "normalized overridden-host path must honor configured timing envelope without stalling"
+            );
+        } else {
+            assert!(
+                elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
+                "non-normalized overridden-host path must honor baseline connect budget without stalling"
+            );
+        }
+    }
+
+    install_entries(&[]).unwrap();
+}
diff --git a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs
new file mode 100644
index 0000000..b52af35
--- /dev/null
+++ b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs
@@ -0,0 +1,85 @@
+use super::*;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::time::Instant;
+use tokio::io::{AsyncRead, ReadBuf};
+
+struct OneByteThenStall {
+    sent: bool,
+}
+
+impl AsyncRead for OneByteThenStall {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        if !self.sent {
+            self.sent = true;
+            buf.put_slice(&[0x42]);
+            Poll::Ready(Ok(()))
+        } else {
+            Poll::Pending
+        }
+    }
+}
+
+#[tokio::test]
+async fn stalling_client_terminates_at_idle_not_relay_timeout() {
+    let reader = OneByteThenStall { sent: false };
+    let started = Instant::now();
+
+    let result = tokio::time::timeout(
+        MASK_RELAY_TIMEOUT,
+        consume_client_data(reader, MASK_BUFFER_SIZE * 4),
+    )
+    .await;
+
+    assert!(
+        result.is_ok(),
+        "consume_client_data should complete by per-read idle timeout, not hit relay timeout"
+    );
+
+    let elapsed = started.elapsed();
+    assert!(
+        elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2),
+        "consume_client_data returned too quickly for idle-timeout path: {elapsed:?}"
+    );
+    assert!(
+        elapsed < MASK_RELAY_TIMEOUT,
+        "consume_client_data waited full relay timeout ({elapsed:?}); \
+         per-read idle timeout is missing"
+    );
+}
+
+#[tokio::test]
+async fn fast_reader_drains_to_eof() {
+    let data = vec![0xAAu8; 32 * 1024];
+    let reader = std::io::Cursor::new(data);
+
+    tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX))
+        .await
+        .expect("consume_client_data did not complete for fast EOF reader");
+}
+
+#[tokio::test]
+async fn io_error_terminates_cleanly() {
+    struct ErrReader;
+
+    impl AsyncRead for ErrReader {
+        fn poll_read(
+            self: Pin<&mut Self>,
+            _cx: &mut Context<'_>,
+            _buf: &mut ReadBuf<'_>,
+        ) -> Poll<std::io::Result<()>> {
+            Poll::Ready(Err(std::io::Error::new(
+                std::io::ErrorKind::ConnectionReset,
+                "simulated reset",
+            )))
+        }
+    }
+
+    tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX))
+        .await
+        .expect("consume_client_data did not return on I/O error");
+}
diff --git a/src/proxy/tests/masking_consume_stress_adversarial_tests.rs b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs
new file mode 100644
index 0000000..12287b5
--- /dev/null
+++ b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs
@@ -0,0 +1,64 @@
+use super::*;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::time::Instant;
+use tokio::io::{AsyncRead, ReadBuf};
+use tokio::task::JoinSet;
+
+struct OneByteThenStall {
+    sent: bool,
+}
+
+impl AsyncRead for OneByteThenStall {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        if !self.sent {
+            self.sent = true;
+            buf.put_slice(&[0xAA]);
+            Poll::Ready(Ok(()))
+        } else {
+            Poll::Pending
+        }
+    }
+}
+
+#[tokio::test]
+async fn consume_stall_stress_finishes_within_idle_budget() {
+    let mut set = JoinSet::new();
+    let started = Instant::now();
+
+    for _ in 0..64 {
+        set.spawn(async {
+            tokio::time::timeout(
+                MASK_RELAY_TIMEOUT,
+                consume_client_data(OneByteThenStall { sent: false }, usize::MAX),
+            )
+            .await
+            .expect("consume_client_data exceeded relay timeout under stall load");
+        });
+    }
+
+    while let Some(res) = set.join_next().await {
+        res.unwrap();
+    }
+
+    // Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling
+    // for 100ms should complete well under a strict 600ms boundary.
+    assert!(
+        started.elapsed() < MASK_RELAY_TIMEOUT * 3,
+        "stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking"
+    );
+}
+
+#[tokio::test]
+async fn consume_zero_cap_returns_immediately() {
+    let started = Instant::now();
+    consume_client_data(tokio::io::empty(), 0).await;
+    assert!(
+        started.elapsed() < MASK_RELAY_IDLE_TIMEOUT,
+        "zero byte cap must return immediately"
+    );
+}
diff --git a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs
new file mode 100644
index 0000000..040f567
--- /dev/null
+++ b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs
@@ -0,0 +1,217 @@
+use super::*;
+use tokio::io::{AsyncWriteExt, duplex};
+use tokio::time::{Duration, Instant, timeout};
+
+fn make_self_target_config(
+    timing_normalization_enabled: bool,
+    floor_ms: u64,
+    ceiling_ms: u64,
+    beobachten_enabled: bool,
+) -> ProxyConfig {
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = beobachten_enabled;
+    config.general.beobachten_minutes = 5;
+    config.censorship.mask = true;
+    config.censorship.mask_unix_sock = None;
+    config.censorship.mask_host = Some("127.0.0.1".to_string());
+    config.censorship.mask_port = 443;
+    config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
+    config.censorship.mask_timing_normalization_floor_ms = floor_ms;
+    config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms;
+    config
+}
+
+async fn run_self_target_refusal(
+    config: ProxyConfig,
+    peer: SocketAddr,
+    initial: &'static [u8],
+) -> Duration {
+    let beobachten = BeobachtenStore::new();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
+
+    let (mut client, server) = duplex(1024);
+    let started = Instant::now();
+    let task = tokio::spawn(async move {
+        handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten)
+            .await;
+    });
+
+    client
+        .shutdown()
+        .await
+        .expect("client shutdown must succeed");
+
+    timeout(Duration::from_secs(3), task)
+        .await
+        .expect("self-target refusal must complete in bounded time")
+        .expect("self-target refusal task must not panic");
+
+    started.elapsed()
+}
+
+#[tokio::test]
+async fn positive_self_target_refusal_honors_normalization_floor() {
+    let config = make_self_target_config(true, 120, 120, false);
+    let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer");
+
+    let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260), + "normalized self-target refusal must stay within expected envelope" + ); +} + +#[tokio::test] +async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() { + let config = make_self_target_config(false, 240, 240, false); + let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(180), + "non-normalized path must not inherit normalization floor delays" + ); +} + +#[tokio::test] +async fn edge_ceiling_below_floor_uses_floor_fail_closed() { + let config = make_self_target_config(true, 140, 80, false); + let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280), + "ceiling max { + max = elapsed; + } + assert!( + elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320), + "parallel probe latency must stay bounded under normalization" + ); + } + + assert!( + max.saturating_sub(min) <= Duration::from_millis(130), + "normalization should limit path variance across adversarial parallel probes" + ); +} + +#[tokio::test] +async fn integration_beobachten_records_probe_classification_on_refusal() { + let config = make_self_target_config(false, 0, 0, true); + let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer"); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + beobachten.snapshot_text(Duration::from_secs(60)) + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + let snapshot = timeout(Duration::from_secs(3), task) + .await + .expect("integration task must complete") + .expect("integration task must not panic"); + + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.71-1")); +} + +#[tokio::test] +async fn light_fuzz_timing_configuration_matrix_is_bounded() { + let mut seed = 0xA17E_55AA_2026_0323u64; + + for case in 0..48u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let enabled = (seed & 1) == 0; + let floor = (seed >> 8) % 180; + let ceiling = (seed >> 24) % 180; + let config = make_self_target_config(enabled, floor, ceiling, false); + let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16)) + .parse() + .expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(420), + "fuzz case must stay bounded and never hang" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() { + let workers = 64usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let config = make_self_target_config(false, 0, 0, false); + let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16) + .parse() + 
.expect("valid peer"); + run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await + })); + } + + timeout(Duration::from_secs(5), async { + for task in tasks { + let elapsed = task.await.expect("stress task must not panic"); + assert!( + elapsed < Duration::from_millis(260), + "stress refusal must remain bounded without normalization" + ); + } + }) + .await + .expect("high-fanout refusal workload must complete without deadlock"); +} \ No newline at end of file diff --git a/src/proxy/tests/masking_http2_preface_integration_security_tests.rs b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs new file mode 100644 index 0000000..7f1c03f --- /dev/null +++ b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs @@ -0,0 +1,55 @@ +use super::*; +use tokio::net::TcpListener; +use tokio::time::Duration; + +#[tokio::test] +async fn http2_preface_is_forwarded_and_recorded_as_http() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let preface = preface.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; preface.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, preface); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let (client_reader, _client_writer) = tokio::io::duplex(512); + let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + client_reader, + client_visible_writer, + &preface, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.130-1")); +} diff --git a/src/proxy/tests/masking_http2_probe_classification_security_tests.rs b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs new file mode 100644 index 0000000..34e04a9 --- /dev/null +++ b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs @@ -0,0 +1,92 @@ +use super::*; + +#[test] +fn full_http2_preface_classified_as_http_probe() { + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + assert!( + is_http_probe(preface), + "HTTP/2 connection preface must be classified as HTTP probe" + ); +} + +#[test] +fn partial_http2_preface_3_bytes_classified() { + assert!( + is_http_probe(b"PRI"), + "3-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn partial_http2_preface_2_bytes_classified() { + assert!( + is_http_probe(b"PR"), + "2-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn existing_http1_methods_unaffected() { + for prefix in [ + b"GET / HTTP/1.1\r\n".as_ref(), + b"POST /api HTTP/1.1\r\n".as_ref(), + b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(), + b"TRACE / 
HTTP/1.1\r\n".as_ref(), + b"PATCH / HTTP/1.1\r\n".as_ref(), + ] { + assert!(is_http_probe(prefix)); + } +} + +#[test] +fn non_http_data_not_classified() { + for data in [ + b"\x16\x03\x01\x00\xf1".as_ref(), + b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(), + b"\x00\x01\x02\x03".as_ref(), + b"".as_ref(), + b"P".as_ref(), + ] { + assert!(!is_http_probe(data)); + } +} + +#[test] +fn light_fuzz_non_http_prefixes_not_misclassified() { + // Deterministic pseudo-fuzz to exercise classifier edges while avoiding + // known HTTP method and partial windows. + let mut x = 0x1234_5678u32; + for _ in 0..1024 { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + let len = 4 + ((x >> 8) as usize % 12); + let mut data = vec![0u8; len]; + for byte in &mut data { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = (x & 0xFF) as u8; + } + + if [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"DELETE".as_ref(), + b"OPTIONS".as_ref(), + b"CONNECT".as_ref(), + b"TRACE".as_ref(), + b"PATCH".as_ref(), + b"PRI ".as_ref(), + ] + .iter() + .any(|m| data.starts_with(m)) + { + continue; + } + + assert!( + !is_http_probe(&data), + "non-http pseudo-fuzz input misclassified: {:?}", + &data[..data.len().min(8)] + ); + } +} diff --git a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs new file mode 100644 index 0000000..47b6dc6 --- /dev/null +++ b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs @@ -0,0 +1,79 @@ +use super::*; + +#[test] +fn exact_four_byte_http_tokens_are_classified() { + for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] { + assert!( + is_http_probe(token), + "exact 4-byte token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn exact_four_byte_non_http_tokens_are_not_classified() { + for token in [ + b"GEX ".as_ref(), + b"POXT".as_ref(), + b"HEA/".as_ref(), + b"PU\0 ".as_ref(), + b"PRI/".as_ref(), + ] { + assert!( + !is_http_probe(token), + "non-HTTP 4-byte token must not be classified: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"PRI "), "HTTP"); +} + +#[test] +fn exact_long_http_tokens_are_classified() { + for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] { + assert!( + is_http_probe(token), + "exact long HTTP token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() { + assert_eq!(detect_client_type(b"CONNECT"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH"), "HTTP"); +} + +#[test] +fn light_fuzz_four_byte_ascii_noise_not_misclassified() { + // Deterministic pseudo-fuzz over 4-byte printable ASCII inputs. 
+ let mut x = 0xA17C_93E5u32; + for _ in 0..2048 { + let mut token = [0u8; 4]; + for byte in &mut token { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset + } + + if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "] + .iter() + .any(|m| token.as_slice() == *m) + { + continue; + } + + assert!( + !is_http_probe(&token), + "pseudo-fuzz noise misclassified as HTTP probe: {:?}", + token + ); + } +} \ No newline at end of file diff --git a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs new file mode 100644 index 0000000..8d99b8f --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs @@ -0,0 +1,41 @@ +#![cfg(unix)] + +use super::*; +use std::sync::{Mutex, OnceLock}; +use tokio::sync::Barrier; + +fn interface_cache_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let workers = 32usize; + let barrier = std::sync::Arc::new(Barrier::new(workers)); + let mut tasks = Vec::with_capacity(workers); + + for _ in 0..workers { + let barrier = std::sync::Arc::clone(&barrier); + tasks.push(tokio::spawn(async move { + barrier.wait().await; + is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await + })); + } + + for task in tasks { + let _ = task.await.expect("parallel cache task must not panic"); + } + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "parallel cold misses must coalesce into a single interface enumeration" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs new file mode 100644 index 0000000..d82cf82 --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs @@ -0,0 +1,51 @@ +#![cfg(unix)] + +use super::*; + +#[test] +fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() { + let previous = vec![ + "192.168.100.7" + .parse::<std::net::IpAddr>() + .expect("must parse interface ip"), + ]; + let refreshed = Vec::new(); + + let next = choose_interface_snapshot(&previous, refreshed); + + assert_eq!( + next, previous, + "empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions" + ); +} + +#[test] +fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() { + let previous = vec![ + "192.168.100.7" + .parse::<std::net::IpAddr>() + .expect("must parse interface ip"), + ]; + let refreshed = vec![ + "10.55.0.3" + .parse::<std::net::IpAddr>() + .expect("must parse refreshed interface ip"), + ]; + + let next = choose_interface_snapshot(&previous, refreshed.clone()); + + assert_eq!(next, refreshed); +} + +#[test] +fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() { + let previous = Vec::new(); + let refreshed = Vec::new(); + + let next = choose_interface_snapshot(&previous, refreshed); + + assert!( + next.is_empty(), + "empty refresh with no previous snapshot should remain empty" + ); +} diff --git 
a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs new file mode 100644 index 0000000..6be99d0 --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -0,0 +1,46 @@ +#![cfg(unix)] + +use super::*; +use std::sync::{Mutex, OnceLock}; + +fn interface_cache_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test] +async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "interface enumeration must be cached across repeated bad-client checks" + ); +} + +#[tokio::test] +async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; + + assert!(!is_local, "different port must not be treated as local listener"); + assert_eq!( + local_interface_enumerations_for_tests(), + 0, + "port mismatch should bypass interface enumeration entirely" + ); +} diff --git a/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..efa4529 --- /dev/null +++ b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs @@ -0,0 +1,178 @@ +use super::*; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[tokio::test] +#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"] +async fn redteam_offline_target_should_drop_idle_client_early() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(150)).await; + let write_res = client_write.write_all(b"probe-should-be-closed").await; + assert!( + write_res.is_err(), + 
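// A write into a duplex only fails once the peer half is dropped, so is_err() requires the handler to have torn down its read side already. + 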
"offline target path still keeps client writable before consume timeout" + ); + + handler.abort(); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"] +async fn redteam_offline_target_should_not_sleep_to_mask_refusal() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"\x16\x03\x01\x00\x05hello", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + let elapsed = started.elapsed(); + + assert!( + elapsed < Duration::from_millis(10), + "offline target path still applies coarse masking sleep and is fingerprintable" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"] +async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() { + let mut samples = Vec::new(); + + for i in 0..12u16 { + let (client_read, mut client_write) = duplex(1024); + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + samples.push(started.elapsed()); + } + + let min = samples.iter().copied().min().unwrap_or_default(); + let max = samples.iter().copied().max().unwrap_or_default(); + let spread = max.saturating_sub(min); + + assert!( + spread <= Duration::from_millis(5), + "offline refusal timing spread too wide for strict red-team envelope: {:?}", + spread + ); +} + +#[tokio::test] +#[ignore = "manual red-team: host resolver failure should complete without panic"] +async fn redteam_dns_resolution_failure_must_not_panic() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string()); + cfg.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + 
tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(2), handler).await; + assert!( + result.is_ok(), + "dns failure path stalled or panicked instead of terminating" + ); +} diff --git a/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs new file mode 100644 index 0000000..b99b4bc --- /dev/null +++ b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs @@ -0,0 +1,51 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::AsyncWrite; + +struct NeverWritable; + +impl AsyncWrite for NeverWritable { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Pending + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() { + let mut writer = NeverWritable; + let started = Instant::now(); + + maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await; + + assert!( + started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30), + "shape padding blocked past timeout budget" + ); +} + +#[tokio::test] +async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() { + let mut output = Vec::new(); + { + let mut writer = tokio::io::BufWriter::new(&mut output); + maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await; + use tokio::io::AsyncWriteExt; + writer.flush().await.unwrap(); + } + + assert!(output.is_empty()); +} diff --git a/src/proxy/tests/masking_production_cap_regression_security_tests.rs b/src/proxy/tests/masking_production_cap_regression_security_tests.rs new file mode 100644 index 0000000..f2368a1 --- /dev/null +++ b/src/proxy/tests/masking_production_cap_regression_security_tests.rs @@ -0,0 +1,289 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::{Duration, Instant, timeout}; + +const PROD_CAP_BYTES: usize = 5 * 1024 * 1024; + +struct FinitePatternReader { + remaining: usize, + chunk: usize, + read_calls: Arc<AtomicUsize>, +} + +impl FinitePatternReader { + fn new(total: usize, chunk: usize, read_calls: Arc<AtomicUsize>) -> Self { + Self { + remaining: total, + chunk, + read_calls, + } + } +} + +impl AsyncRead for FinitePatternReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + self.read_calls.fetch_add(1, Ordering::Relaxed); + + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(self.chunk).min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + let fill = vec![0x5Au8; take]; + buf.put_slice(&fill); + self.remaining -= take; + Poll::Ready(Ok(())) + } +} + +#[derive(Default)] +struct CountingWriter { + written: usize, +} + +impl AsyncWrite for CountingWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + self.written = self.written.saturating_add(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + 
self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } +} + +struct NeverReadyReader; + +impl AsyncRead for NeverReadyReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + Poll::Pending + } +} + +struct BudgetProbeReader { + remaining: usize, + total_read: Arc<AtomicUsize>, +} + +impl BudgetProbeReader { + fn new(total: usize, total_read: Arc<AtomicUsize>) -> Self { + Self { + remaining: total, + total_read, + } + } +} + +impl AsyncRead for BudgetProbeReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + let fill = vec![0xA5u8; take]; + buf.put_slice(&fill); + self.remaining -= take; + self.total_read.fetch_add(take, Ordering::Relaxed); + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn positive_copy_with_production_cap_stops_exactly_at_budget() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await; + + assert_eq!( + outcome.total, PROD_CAP_BYTES, + "copy path must stop at explicit production cap" + ); + assert_eq!(writer.written, PROD_CAP_BYTES); + assert!( + !outcome.ended_by_eof, + "byte-cap stop must not be misclassified as EOF" + ); +} + +#[tokio::test] +async fn negative_consume_with_zero_cap_performs_no_reads() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls)); + + consume_client_data_with_timeout_and_cap(reader, 0).await; + + assert_eq!( + read_calls.load(Ordering::Relaxed), + 0, + "zero cap must return before reading attacker-controlled bytes" + ); +} + +#[tokio::test] +async fn edge_copy_below_cap_reports_eof_without_overread() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let payload = 73 * 1024; + let mut reader = FinitePatternReader::new(payload, 3072, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await; + + assert_eq!(outcome.total, payload); + assert_eq!(writer.written, payload); + assert!( + outcome.ended_by_eof, + "finite upstream below cap must terminate via EOF path" + ); +} + +#[tokio::test] +async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() { + let started = Instant::now(); + + consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await; + + assert!( + started.elapsed() < Duration::from_millis(350), + "never-ready reader must be bounded by idle/relay timeout protections" + ); +} + +#[tokio::test] +async fn integration_consume_path_honors_production_cap_for_large_payload() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls); + + let bounded = timeout( + Duration::from_millis(350), + consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES), + ) + .await; + + assert!( + bounded.is_ok(), + "consume path with production cap must finish within bounded time" + ); +} + +#[tokio::test] 
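+// Invariant under test: the consume path may never read past the declared byte cap, even from a reader able to serve 256 KiB at once.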
+async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() { + let byte_cap = 5usize; + let total_read = Arc::new(AtomicUsize::new(0)); + let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read)); + + consume_client_data_with_timeout_and_cap(reader, byte_cap).await; + + assert!( + total_read.load(Ordering::Relaxed) <= byte_cap, + "consume path must not read more than configured byte cap" + ); +} + +#[tokio::test] +async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() { + let mut seed = 0x1234_5678_9ABC_DEF0u64; + + for _case in 0..96u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let cap = ((seed & 0x3ffff) as usize).saturating_add(1); + let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1); + let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1); + + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new(payload, chunk, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await; + let expected = payload.min(cap); + + assert_eq!( + outcome.total, expected, + "copy total must match min(payload, cap) under fuzzed inputs" + ); + assert_eq!(writer.written, expected); + if payload <= cap { + assert!(outcome.ended_by_eof); + } else { + assert!(!outcome.ended_by_eof); + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() { + let workers = 8usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new( + PROD_CAP_BYTES + (idx + 1) * 4096, + 4096 + (idx * 257), + read_calls, + ); + let mut writer = CountingWriter::default(); + copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await + })); + } + + timeout(Duration::from_secs(3), async { + for task in tasks { + let outcome = task.await.expect("stress task must not panic"); + assert_eq!( + outcome.total, PROD_CAP_BYTES, + "stress copy task must stay within production cap" + ); + assert!( + !outcome.ended_by_eof, + "stress task should end due to cap, not EOF" + ); + } + }) + .await + .expect("stress suite must complete in bounded time"); +} diff --git a/src/proxy/tests/masking_relay_guardrails_security_tests.rs b/src/proxy/tests/masking_relay_guardrails_security_tests.rs new file mode 100644 index 0000000..257c0f8 --- /dev/null +++ b/src/proxy/tests/masking_relay_guardrails_security_tests.rs @@ -0,0 +1,105 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink}; +use tokio::time::{Duration, timeout}; + +#[tokio::test] +async fn relay_to_mask_enforces_masking_session_byte_cap() { + let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01]; + let extra = vec![0xAB; 96 * 1024]; + + let (client_reader, mut client_writer) = duplex(128 * 1024); + let (mask_read, _mask_read_peer) = duplex(1024); + let (mut mask_observer, mask_write) = duplex(256 * 1024); + let initial_for_task = initial.clone(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + mask_read, + mask_write, + &initial_for_task, + false, + 512, + 4096, + false, + 0, + false, + 32 * 1024, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay) + .await 
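+ // First unwrap: the relay finished inside the 2 s timeout; second unwrap: the spawned relay task did not panic.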
+ .unwrap() + .unwrap(); + + let mut observed = Vec::new(); + timeout( + Duration::from_secs(2), + mask_observer.read_to_end(&mut observed), + ) + .await + .unwrap() + .unwrap(); + + // In this deterministic test, relay must stop exactly at the configured cap. + assert_eq!( + observed.len(), + initial.len() + (32 * 1024), + "masked relay must forward exactly up to the cap (observed={} initial={} cap={})", + observed.len(), + initial.len(), + 32 * 1024 + ); +} + +#[tokio::test] +async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() { + let initial = b"GET /half-close HTTP/1.1\r\n".to_vec(); + + let (client_reader, mut client_writer) = duplex(8 * 1024); + let (mask_read, _mask_read_peer) = duplex(8 * 1024); + let (mut mask_observer, mask_write) = duplex(8 * 1024); + let initial_for_task = initial.clone(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + mask_read, + mask_write, + &initial_for_task, + false, + 512, + 4096, + false, + 0, + false, + 32 * 1024, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + + let mut observed = Vec::new(); + timeout( + Duration::from_millis(80), + mask_observer.read_to_end(&mut observed), + ) + .await + .expect("mask backend write side should be half-closed promptly") + .unwrap(); + + assert_eq!(&observed[..initial.len()], initial.as_slice()); + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs new file mode 100644 index 0000000..627c48b --- /dev/null +++ b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs @@ -0,0 +1,100 @@ +use super::*; +use tokio::io::AsyncReadExt; +use tokio::time::{Duration, timeout}; + +async fn collect_padding( + total_sent: usize, + enabled: bool, + floor: usize, + cap: usize, + above_cap_blur: bool, + blur_max: usize, + aggressive: bool, +) -> Vec<u8> { + let (mut tx, mut rx) = tokio::io::duplex(256 * 1024); + + maybe_write_shape_padding( + &mut tx, + total_sent, + enabled, + floor, + cap, + above_cap_blur, + blur_max, + aggressive, + ) + .await; + + drop(tx); + + let mut output = Vec::new(); + timeout(Duration::from_secs(1), rx.read_to_end(&mut output)) + .await + .expect("reading padded output timed out") + .expect("failed reading padded output"); + output +} + +#[tokio::test] +async fn padding_output_is_not_all_zero() { + let output = collect_padding(1, true, 256, 4096, false, 0, false).await; + + assert!( + output.len() >= 255, + "expected at least 255 padding bytes, got {}", + output.len() + ); + + let nonzero = output.iter().filter(|&&b| b != 0).count(); + // In 255 bytes of uniform randomness, the expected number of zero bytes is ~1. + // A weak nonzero check can miss severe entropy collapse. 
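+ // Demanding >= 240 nonzero bytes tolerates 15 zeros, far above the expected ~1 (sigma ~ 1), so the check stays robust yet still flags a collapsed RNG.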
+ assert!( + nonzero >= 240, + "RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}", + nonzero, + output.len(), + ); +} + +#[tokio::test] +async fn padding_reaches_first_bucket_boundary() { + let output = collect_padding(1, true, 64, 4096, false, 0, false).await; + assert_eq!(output.len(), 63); +} + +#[tokio::test] +async fn disabled_padding_produces_no_output() { + let output = collect_padding(0, false, 256, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn at_cap_without_blur_produces_no_output() { + let output = collect_padding(4096, true, 64, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() { + let output = collect_padding(4096, true, 64, 4096, true, 128, true).await; + assert!(!output.is_empty()); + assert!(output.len() <= 128, "blur exceeded max: {}", output.len()); +} + +#[tokio::test] +async fn stress_padding_runs_are_not_constant_pattern() { + // Stress and sanity-check: repeated runs should not collapse to identical + // first 16 bytes across all samples. + let mut first_chunks = Vec::new(); + for _ in 0..64 { + let out = collect_padding(1, true, 64, 4096, false, 0, false).await; + first_chunks.push(out[..16].to_vec()); + } + + let first = &first_chunks[0]; + let all_same = first_chunks.iter().all(|chunk| chunk == first); + assert!( + !all_same, + "all stress samples had identical prefix, rng output appears degenerate" + ); +} diff --git a/src/proxy/tests/masking_security_tests.rs b/src/proxy/tests/masking_security_tests.rs index 4519d85..c698b55 100644 --- a/src/proxy/tests/masking_security_tests.rs +++ b/src/proxy/tests/masking_security_tests.rs @@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall false, 0, false, + 5 * 1024 * 1024, ) .await; }); @@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { false, 0, false, + 5 * 1024 * 1024, ), ) .await; diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs new file mode 100644 index 0000000..18cb0d7 --- /dev/null +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -0,0 +1,360 @@ +use super::*; +use std::net::TcpListener as StdTcpListener; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant, timeout}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[tokio::test] +async fn self_target_detection_matches_literal_ipv4_listener() { + let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async( + "198.51.100.40", + 443, + local, + None, + ) + .await); +} + +#[tokio::test] +async fn self_target_detection_matches_bracketed_ipv6_listener() { + let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); + assert!(is_mask_target_local_listener_async( + "[2001:db8::44]", + 8443, + local, + None, + ) + .await); +} + +#[tokio::test] +async fn self_target_detection_keeps_same_ip_different_port_forwardable() { + let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener_async( + "203.0.113.44", + 8443, + local, + None, + ) + .await); +} + +#[tokio::test] +async fn 
self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async( + "::ffff:127.0.0.1", + 443, + local, + None, + ) + .await); +} + +#[tokio::test] +async fn self_target_detection_unspecified_bind_blocks_loopback_target() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async( + "127.0.0.1", + 443, + local, + None, + ) + .await); +} + +#[tokio::test] +async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener_async( + "mask.example", + 443, + local, + Some(remote), + ) + .await); +} + +#[tokio::test] +async fn self_target_fallback_refuses_recursive_loopback_connect() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(local_addr.ip().to_string()); + config.censorship.mask_port = local_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + b"GET /", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "self-target masking must fail closed without connecting to local listener" + ); +} + +#[tokio::test] +async fn same_ip_different_port_still_forwards_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /".to_vec(); + let accept_task = tokio::spawn({ + let expected = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn detect_client_type_http_boundary_get_and_post() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"GET /"), "HTTP"); + + assert_eq!(detect_client_type(b"POST"), "HTTP"); + assert_eq!(detect_client_type(b"POST "), "HTTP"); + assert_eq!(detect_client_type(b"POSTX"), "HTTP"); +} + +#[test] +fn detect_client_type_tls_and_length_boundaries() { + 
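// Classification needs at least 4 bytes for the TLS-scanner label and 10 bytes for "unknown"; anything shorter is binned as a port-scanner. + 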
assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner"); + assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner"); + + assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + +#[test] +fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() { + let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap(); + let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); + let header = build_mask_proxy_header(1, peer, local).unwrap(); + assert_eq!(header, b"PROXY UNKNOWN\r\n"); +} + +#[test] +fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() { + let floor = usize::MAX / 2 + 1; + let cap = usize::MAX; + let total = floor + 1; + assert_eq!(next_mask_shape_bucket(total, floor, cap), total); +} + +#[tokio::test] +async fn self_target_reject_path_keeps_timing_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (client, server) = duplex(1024); + drop(client); + + let started = Instant::now(); + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250), + "self-target reject path must keep coarse timing budget without stalling" + ); +} + +#[tokio::test] +async fn relay_path_idle_timeout_eviction_remains_effective() { + let (client_read, mut client_write) = duplex(1024); + let (mask_read, mask_write) = duplex(1024); + + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + client_write.write_all(b"a").await.unwrap(); + tokio::time::sleep(Duration::from_millis(180)).await; + let _ = client_write.write_all(b"b").await; + }); + + let started = Instant::now(); + relay_to_mask( + client_read, + tokio::io::sink(), + mask_read, + mask_write, + b"init", + false, + 0, + 0, + false, + 0, + false, + 5 * 1024 * 1024, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180), + "idle-timeout eviction must occur before late trickle write" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_respects_timing_normalization_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + 
&beobachten, + ) + .await; + }); + + client.shutdown().await.unwrap(); + timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220), + "offline-refusal path must honor normalization budget without unbounded drift" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = false; + + let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(120)).await; + client + .write_all(b"still-open-before-timeout") + .await + .expect("connection should still be open before consume timeout expires"); + + timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350), + "offline-refusal path must not retain idle client indefinitely" + ); +} diff --git a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs index 982fd26..4fa8da7 100644 --- a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs @@ -43,6 +43,7 @@ async fn run_relay_case( above_cap_blur, above_cap_blur_max_bytes, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs index 3c886ba..9abf3c0 100644 --- a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs @@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() { false, 0, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs new file mode 100644 index 0000000..1c342ea --- /dev/null +++ b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs @@ -0,0 +1,55 @@ +#![cfg(unix)] + +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + 
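// Floor == ceiling pins the normalization sleep at exactly 120 ms, so the assertion window below can tell a floor that starts after the 80 ms lookup stall (about 200 ms total) from one consumed by it (about 120 ms). + 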
config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer"); + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(())); + let held_refresh_guard = refresh_lock.lock().await; + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(80)).await; + drop(held_refresh_guard); + client.shutdown().await.expect("client shutdown must succeed"); + + timeout(Duration::from_secs(2), task) + .await + .expect("task must finish in bounded time") + .expect("task must not panic"); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350), + "timing normalization floor must start after pre-outcome self-target checks" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs index 2c9f3f6..6f0e91a 100644 --- a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs +++ b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs @@ -9,6 +9,7 @@ use tokio::time::{Duration, timeout}; #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() { let _guard = super::quota_user_lock_test_scope(); + let _pressure_guard = super::relay_idle_pressure_test_scope(); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); map.clear(); diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs index fff26b4..44c201f 100644 --- a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs +++ b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs @@ -645,6 +645,75 @@ fn quota_exceeded_boundary_is_inclusive() { assert!(!quota_exceeded_for_user(&stats, user, Some(51))); } +#[test] +fn quota_soft_helper_matches_capped_generic_helper_matrix() { + let stats = Stats::new(); + let user = "quota-soft-parity"; + + for used in [0u64, 1, 7, 63, 127, 255] { + stats.sub_user_octets_to(user, stats.get_user_total_octets(user)); + stats.add_user_octets_to(user, used); + + for quota in [8u64, 64, 128, 256] { + for overshoot in [0u64, 1, 5, 32] { + for bytes in [0u64, 1, 2, 7, 31, 64] { + let soft = quota_would_be_exceeded_for_user_soft( + &stats, + user, + Some(quota), + bytes, + overshoot, + ); + let capped = quota_would_be_exceeded_for_user( + &stats, + user, + Some(quota_soft_cap(quota, overshoot)), + bytes, + ); + assert_eq!( + soft, capped, + "soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}" + ); + } + } + } + } +} + +#[test] +fn quota_soft_helper_none_limit_never_rejects() { + let stats = Stats::new(); + let user = "quota-soft-none"; + stats.add_user_octets_to(user, u64::MAX); + + assert!(!quota_would_be_exceeded_for_user_soft( + &stats, + user, + None, + u64::MAX, + u64::MAX, + )); +} + +#[test] +fn quota_soft_cap_saturates_and_stays_fail_closed() { + let stats = Stats::new(); + let user = 
"quota-soft-saturating"; + let quota = u64::MAX - 2; + let overshoot = 100; + + assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX); + + stats.add_user_octets_to(user, u64::MAX - 1); + assert!(quota_would_be_exceeded_for_user_soft( + &stats, + user, + Some(quota), + 2, + overshoot, + )); +} + #[tokio::test] async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() { let (tx, mut rx) = mpsc::channel::(4); diff --git a/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs new file mode 100644 index 0000000..a787aa6 --- /dev/null +++ b/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs @@ -0,0 +1,295 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use tokio::io::AsyncWrite; +use tokio::sync::Notify; +use tokio::task::JoinSet; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[derive(Default)] +struct BlockingWriteState { + write_entered: AtomicBool, + released: AtomicBool, + write_waker: Mutex>, + write_entered_notify: Notify, +} + +struct BlockingWrite { + state: Arc, +} + +impl BlockingWrite { + fn new(state: Arc) -> Self { + Self { state } + } +} + +impl AsyncWrite for BlockingWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.state.write_entered.store(true, Ordering::Release); + self.state.write_entered_notify.notify_waiters(); + + if self.state.released.load(Ordering::Acquire) { + return Poll::Ready(Ok(buf.len())); + } + + if let Ok(mut slot) = self.state.write_waker.lock() { + *slot = Some(cx.waker().clone()); + } + + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +async fn wait_until_blocking_write_entered(state: &Arc) { + for _ in 0..8 { + if state.write_entered.load(Ordering::Acquire) { + return; + } + let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; + } + + panic!("blocking writer did not enter poll_write in bounded time"); +} + +fn release_blocking_write(state: &Arc) { + state.released.store(true, Ordering::Release); + if let Ok(mut slot) = state.write_waker.lock() + && let Some(waker) = slot.take() + { + waker.wake(); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-release-regression-{}", std::process::id()); + let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let writer_state = Arc::new(BlockingWriteState::default()); + + let first = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + let writer_state = Arc::clone(&writer_state); + tokio::spawn(async 
move { + let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); + let mut frame_buf = Vec::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(4), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_000, + false, + false, + ) + .await + }) + }; + + wait_until_blocking_write_entered(&writer_state).await; + + let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) + .await + .expect("cross-mode lock must be released while first write is pending"); + drop(guard); + + let second = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + tokio::spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + timeout( + Duration::from_millis(150), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xEE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(4), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_001, + false, + false, + ), + ) + .await + }) + }; + + let second_result = second + .await + .expect("second task must not panic") + .expect("second write must not block on cross-mode lock"); + assert!( + matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })), + "second write must fail closed due to first write reservation" + ); + + release_blocking_write(&writer_state); + + let first_result = timeout(Duration::from_millis(300), first) + .await + .expect("first task timed out") + .expect("first task must not panic"); + assert!(first_result.is_ok()); + + assert_eq!(stats.get_user_total_octets(&user), 4); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-release-stress-{}", std::process::id()); + let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let writer_state = Arc::new(BlockingWriteState::default()); + + let first = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + let writer_state = Arc::clone(&writer_state); + tokio::spawn(async move { + let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); + let mut frame_buf = Vec::new(); + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x01, 0x02]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(3), + 0, + Some(&cross_mode_lock), + bytes_me2c.as_ref(), + 41_100, + false, + false, + ) + .await + }) + }; + + wait_until_blocking_write_entered(&writer_state).await; + + let mut set = JoinSet::new(); + for idx in 0..48u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let cross_mode_lock = Arc::clone(&cross_mode_lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + set.spawn(async move { + let mut writer = 
make_crypto_writer(tokio::io::sink());
+            let mut frame_buf = Vec::new();
+            timeout(
+                Duration::from_millis(200),
+                process_me_writer_response_with_cross_mode_lock(
+                    MeResponse::Data {
+                        flags: 0,
+                        data: Bytes::from_static(&[0x10]),
+                    },
+                    &mut writer,
+                    ProtoTag::Intermediate,
+                    &SecureRandom::new(),
+                    &mut frame_buf,
+                    stats.as_ref(),
+                    &user,
+                    Some(3),
+                    0,
+                    Some(&cross_mode_lock),
+                    bytes_me2c.as_ref(),
+                    41_200 + idx,
+                    false,
+                    false,
+                ),
+            )
+            .await
+        });
+    }
+
+    let mut ok = 0usize;
+    let mut quota_exceeded = 0usize;
+    while let Some(done) = set.join_next().await {
+        let timed = done.expect("waiter task must not panic");
+        let result = timed.expect("waiter must not block behind pending first write");
+        match result {
+            Ok(_) => ok += 1,
+            Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1,
+            Err(other) => panic!("unexpected error in waiter: {other:?}"),
+        }
+    }
+
+    assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota");
+    assert_eq!(quota_exceeded, 47);
+
+    release_blocking_write(&writer_state);
+
+    let first_result = timeout(Duration::from_millis(300), first)
+        .await
+        .expect("first task timed out")
+        .expect("first task must not panic");
+    assert!(first_result.is_ok());
+
+    assert_eq!(stats.get_user_total_octets(&user), 3);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3);
+}
diff --git a/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs
new file mode 100644
index 0000000..37e1b87
--- /dev/null
+++ b/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs
@@ -0,0 +1,116 @@
+use super::*;
+use crate::crypto::{AesCtr, SecureRandom};
+use crate::stats::Stats;
+use crate::stream::CryptoWriter;
+use bytes::Bytes;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::{Mutex, OnceLock};
+
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+    W: tokio::io::AsyncWrite + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
+}
+
+fn lookup_counter_test_lock() -> &'static Mutex<()> {
+    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+    LOCK.get_or_init(|| Mutex::new(()))
+}
+
+#[tokio::test]
+async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() {
+    let _guard = lookup_counter_test_lock()
+        .lock()
+        .unwrap_or_else(|poison| poison.into_inner());
+
+    let stats = Stats::new();
+    let user = format!("middle-cross-mode-lookup-{}", std::process::id());
+    let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+
+    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    for idx in 0..8u64 {
+        let outcome = process_me_writer_response_with_cross_mode_lock(
+            MeResponse::Data {
+                flags: 0,
+                data: Bytes::from_static(&[0xAB]),
+            },
+            &mut writer,
+            ProtoTag::Intermediate,
+            &SecureRandom::new(),
+            &mut frame_buf,
+            &stats,
+            &user,
+            Some(1024),
+            0,
+            Some(&cross_mode_lock),
+            &bytes_me2c,
+            20_000 + idx,
+            false,
+            false,
+        )
+        .await;
+
+        assert!(outcome.is_ok());
+    }
+
+    assert_eq!(
+        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
+        0,
+        "prefetched lock path must not re-query lock registry per frame"
+    );
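+    // All eight one-byte frames above were served through the prefetched
+    // lock, so the registry counter stays at zero while the per-user total
+    // and the ME->C forensic counter below each advance to eight.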
+    assert_eq!(stats.get_user_total_octets(&user), 8);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8);
+}
+
+#[tokio::test]
+async fn control_without_prefetched_lock_still_uses_registry_lookup_path() {
+    let _guard = lookup_counter_test_lock()
+        .lock()
+        .unwrap_or_else(|poison| poison.into_inner());
+
+    let stats = Stats::new();
+    let user = format!("middle-cross-mode-lookup-control-{}", std::process::id());
+
+    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let outcome = process_me_writer_response_with_cross_mode_lock(
+        MeResponse::Data {
+            flags: 0,
+            data: Bytes::from_static(&[0xCD]),
+        },
+        &mut writer,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf,
+        &stats,
+        &user,
+        Some(1024),
+        0,
+        None,
+        &bytes_me2c,
+        20_100,
+        false,
+        false,
+    )
+    .await;
+
+    assert!(outcome.is_ok());
+    assert_eq!(
+        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
+        1,
+        "fallback path without prefetched lock should perform a registry lookup"
+    );
+}
diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs
new file mode 100644
index 0000000..bc7c857
--- /dev/null
+++ b/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs
@@ -0,0 +1,376 @@
+use super::*;
+use crate::crypto::{AesCtr, SecureRandom};
+use crate::stats::Stats;
+use crate::stream::CryptoWriter;
+use bytes::Bytes;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicU64, Ordering};
+use tokio::time::{Duration, timeout};
+
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+    W: tokio::io::AsyncWrite + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
+}
+
+#[tokio::test]
+async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() {
+    let stats = Stats::new();
+    let user = format!("middle-cross-matrix-positive-{}", std::process::id());
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let result = process_me_writer_response(
+        MeResponse::Data {
+            flags: 0,
+            data: Bytes::from_static(&[1, 2, 3, 4]),
+        },
+        &mut writer,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf,
+        &stats,
+        &user,
+        Some(128),
+        0,
+        &bytes_me2c,
+        10_001,
+        false,
+        false,
+    )
+    .await;
+
+    assert!(result.is_ok());
+    assert_eq!(stats.get_user_total_octets(&user), 4);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
+}
+
+#[tokio::test]
+async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() {
+    let stats = Stats::new();
+    let user = format!("middle-cross-matrix-negative-{}", std::process::id());
+    let held = cross_mode_quota_user_lock_for_tests(&user);
+    let held_guard = held
+        .try_lock()
+        .expect("test must hold lock before ME->C call");
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let blocked = timeout(
+        Duration::from_millis(25),
+        process_me_writer_response(
+            MeResponse::Data {
+                flags: 0,
+                data: Bytes::from_static(&[0x41]),
+            },
+            &mut writer,
+            ProtoTag::Intermediate,
+            &SecureRandom::new(),
+            &mut frame_buf,
+            &stats,
+            &user,
+            Some(256),
+            0,
+            &bytes_me2c,
+            10_002,
+ false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + drop(held_guard); +} + +#[tokio::test] +async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-edge-none-{}", std::process::id()); + let held = cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock while quota is disabled"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let outcome = timeout( + Duration::from_millis(80), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x11, 0x22]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + None, + 0, + &bytes_me2c, + 10_003, + false, + false, + ), + ) + .await + .expect("quota-none path must not wait on cross-mode lock"); + + assert!(outcome.is_ok()); + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() { + let stats = Arc::new(Stats::new()); + let user = format!("middle-cross-matrix-adversarial-{}", std::process::id()); + let limit = 64u64; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = Vec::new(); + + for idx in 0..256u64 { + let stats = Arc::clone(&stats); + let bytes_me2c = Arc::clone(&bytes_me2c); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xEE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(limit), + 0, + bytes_me2c.as_ref(), + 11_000 + idx, + false, + false, + ) + .await + })); + } + + let mut ok = 0usize; + for task in tasks { + match task.await.expect("task must not panic") { + Ok(_) => ok += 1, + Err(ProxyError::DataQuotaExceeded { .. 
}) => {} + Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"), + } + } + + assert_eq!(ok, limit as usize); + assert_eq!(stats.get_user_total_octets(&user), limit); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() { + let user = format!("middle-cross-matrix-integration-{}", std::process::id()); + let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user); + let middle_lock = cross_mode_quota_user_lock_for_tests(&user); + assert!( + Arc::ptr_eq(&relay_lock, &middle_lock), + "relay and middle-relay must share the same cross-mode lock identity" + ); + + let held_guard = relay_lock + .try_lock() + .expect("test must hold shared cross-mode lock"); + + let stats = Stats::new(); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let middle_blocked = timeout( + Duration::from_millis(25), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x92]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 12_001, + false, + false, + ), + ) + .await; + assert!(middle_blocked.is_err()); + + drop(held_guard); + + let middle_ready = timeout( + Duration::from_millis(250), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x94]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 12_002, + false, + false, + ), + ) + .await + .expect("middle path must complete after release"); + + assert!(middle_ready.is_ok()); +} + +#[tokio::test] +async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() { + let stats = Stats::new(); + let user = format!("middle-cross-matrix-fuzz-{}", std::process::id()); + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xC0DE_1234_55AA_9988u64; + + for case in 0..96u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold = (seed & 0x03) == 0; + let mut held_lock = None; + let maybe_guard = if hold { + held_lock = Some(cross_mode_quota_user_lock_for_tests(&user)); + Some( + held_lock + .as_ref() + .expect("held lock should be present") + .try_lock() + .expect("cross-mode lock should be acquirable in fuzz round"), + ) + } else { + None + }; + + let payload_len = ((seed >> 8) as usize % 8) + 1; + let payload = vec![(seed & 0xff) as u8; payload_len]; + let before = stats.get_user_total_octets(&user); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let timed = timeout( + Duration::from_millis(20), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(1024), + 0, + &bytes_me2c, + 13_000 + case as u64, + false, + false, + ), + ) + .await; + + if hold { + assert!(timed.is_err(), "held-lock fuzz round must block within timeout"); + assert_eq!(stats.get_user_total_octets(&user), before); + } else { + let done = timed.expect("unheld fuzz round must complete in time"); + assert!(done.is_ok()); + } + + drop(maybe_guard); + drop(held_lock); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 
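+        // Per-round invariant: the ME->C forensic counter must mirror the
+        // user's accounted octets whether the round completed or was blocked.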
stats.get_user_total_octets(&user));
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() {
+    let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id());
+    let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id());
+
+    let held = cross_mode_quota_user_lock_for_tests(&held_user);
+    let held_guard = held
+        .try_lock()
+        .expect("test must hold lock for blocked user");
+
+    let mut tasks = Vec::new();
+    for idx in 0..64u64 {
+        let user = free_user.clone();
+        tasks.push(tokio::spawn(async move {
+            let stats = Stats::new();
+            let mut writer = make_crypto_writer(tokio::io::sink());
+            let mut frame_buf = Vec::new();
+            let bytes_me2c = AtomicU64::new(0);
+            process_me_writer_response(
+                MeResponse::Data {
+                    flags: 0,
+                    data: Bytes::from_static(&[0xA0]),
+                },
+                &mut writer,
+                ProtoTag::Intermediate,
+                &SecureRandom::new(),
+                &mut frame_buf,
+                &stats,
+                &user,
+                Some(1),
+                0,
+                &bytes_me2c,
+                14_000 + idx,
+                false,
+                false,
+            )
+            .await
+        }));
+    }
+
+    timeout(Duration::from_secs(2), async {
+        for task in tasks {
+            let done = task.await.expect("free-user task must not panic");
+            assert!(done.is_ok());
+        }
+    })
+    .await
+    .expect("free-user tasks should complete without waiting for held user's lock");
+
+    drop(held_guard);
+}
diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs
new file mode 100644
index 0000000..51092bd
--- /dev/null
+++ b/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs
@@ -0,0 +1,254 @@
+use super::*;
+use crate::crypto::{AesCtr, SecureRandom};
+use crate::stats::Stats;
+use crate::stream::CryptoWriter;
+use bytes::Bytes;
+use std::pin::Pin;
+use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll, Waker};
+use tokio::io::AsyncWrite;
+use tokio::sync::Notify;
+use tokio::time::{Duration, timeout};
+
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+    W: tokio::io::AsyncWrite + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
+}
+
+#[derive(Default)]
+struct BlockingWriteState {
+    write_entered: AtomicBool,
+    released: AtomicBool,
+    write_waker: Mutex<Option<Waker>>,
+    write_entered_notify: Notify,
+}
+
+struct BlockingWrite {
+    state: Arc<BlockingWriteState>,
+}
+
+impl BlockingWrite {
+    fn new(state: Arc<BlockingWriteState>) -> Self {
+        Self { state }
+    }
+}
+
+impl AsyncWrite for BlockingWrite {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<std::io::Result<usize>> {
+        self.state.write_entered.store(true, Ordering::Release);
+        self.state.write_entered_notify.notify_waiters();
+
+        if self.state.released.load(Ordering::Acquire) {
+            return Poll::Ready(Ok(buf.len()));
+        }
+
+        if let Ok(mut slot) = self.state.write_waker.lock() {
+            *slot = Some(cx.waker().clone());
+        }
+        Poll::Pending
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
+    for _ in 0..8 {
+        if state.write_entered.load(Ordering::Acquire) {
+            return;
+        }
+        let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
+    }
+
+    panic!("blocking writer did not enter poll_write in bounded time");
+}
+
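+// Test-only release helper: flips `released` first, then wakes any parked
+// writer so its next poll_write observes the flag and completes.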
+fn release_blocking_write(state: &Arc<BlockingWriteState>) {
+    state.released.store(true, Ordering::Release);
+    if let Ok(mut slot) = state.write_waker.lock()
+        && let Some(waker) = slot.take()
+    {
+        waker.wake();
+    }
+}
+
+#[tokio::test]
+async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() {
+    let stats = Stats::new();
+    let user = format!("middle-me2c-cross-mode-held-{}", std::process::id());
+    let held = cross_mode_quota_user_lock_for_tests(&user);
+    let held_guard = held
+        .try_lock()
+        .expect("test must hold shared cross-mode lock before ME->C write path");
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let blocked = timeout(
+        Duration::from_millis(25),
+        process_me_writer_response(
+            MeResponse::Data {
+                flags: 0,
+                data: Bytes::from_static(&[0x41]),
+            },
+            &mut writer,
+            ProtoTag::Intermediate,
+            &SecureRandom::new(),
+            &mut frame_buf,
+            &stats,
+            &user,
+            Some(1024),
+            0,
+            &bytes_me2c,
+            9901,
+            false,
+            false,
+        ),
+    )
+    .await;
+
+    assert!(
+        blocked.is_err(),
+        "ME->C quota reservation path must be serialized by held shared cross-mode lock"
+    );
+
+    drop(held_guard);
+
+    let released = timeout(
+        Duration::from_millis(250),
+        process_me_writer_response(
+            MeResponse::Data {
+                flags: 0,
+                data: Bytes::from_static(&[0x42]),
+            },
+            &mut writer,
+            ProtoTag::Intermediate,
+            &SecureRandom::new(),
+            &mut frame_buf,
+            &stats,
+            &user,
+            Some(1024),
+            0,
+            &bytes_me2c,
+            9902,
+            false,
+            false,
+        ),
+    )
+    .await
+    .expect("ME->C write must complete after cross-mode lock release");
+
+    assert!(released.is_ok());
+}
+
+#[tokio::test]
+async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() {
+    let stats = Stats::new();
+    let user = format!("middle-me2c-cross-mode-free-{}", std::process::id());
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let outcome = timeout(
+        Duration::from_millis(250),
+        process_me_writer_response(
+            MeResponse::Data {
+                flags: 0,
+                data: Bytes::from_static(&[0x55, 0x66]),
+            },
+            &mut writer,
+            ProtoTag::Intermediate,
+            &SecureRandom::new(),
+            &mut frame_buf,
+            &stats,
+            &user,
+            Some(1024),
+            0,
+            &bytes_me2c,
+            9903,
+            false,
+            false,
+        ),
+    )
+    .await
+    .expect("uncontended ME->C path should not stall");
+
+    assert!(outcome.is_ok());
+    assert_eq!(stats.get_user_total_octets(&user), 2);
+    assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2);
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() {
+    let stats = Arc::new(Stats::new());
+    let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id());
+    let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user);
+    let bytes_me2c = Arc::new(AtomicU64::new(0));
+    let writer_state = Arc::new(BlockingWriteState::default());
+
+    let worker = {
+        let stats = Arc::clone(&stats);
+        let user = user.clone();
+        let cross_mode_lock = Arc::clone(&cross_mode_lock);
+        let bytes_me2c = Arc::clone(&bytes_me2c);
+        let writer_state = Arc::clone(&writer_state);
+        tokio::spawn(async move {
+            let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
+            let mut frame_buf = Vec::new();
+            let rng = SecureRandom::new();
+            process_me_writer_response_with_cross_mode_lock(
+                MeResponse::Data {
+                    flags: 0,
+                    data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]),
+                },
+                &mut writer,
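+                // This write parks inside the BlockingWrite sink; the parent
+                // test asserts the cross-mode lock is already free while the
+                // write is still pending.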
+                ProtoTag::Intermediate,
+                &rng,
+                &mut frame_buf,
+                stats.as_ref(),
+                &user,
+                Some(1024),
+                0,
+                Some(&cross_mode_lock),
+                bytes_me2c.as_ref(),
+                9910,
+                false,
+                false,
+            )
+            .await
+        })
+    };
+
+    wait_until_blocking_write_entered(&writer_state).await;
+
+    let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
+        .await
+        .expect("cross-mode lock must be free while ME->C write is pending");
+    drop(acquired_guard);
+
+    release_blocking_write(&writer_state);
+
+    let result = timeout(Duration::from_millis(300), worker)
+        .await
+        .expect("ME->C worker timed out after releasing blocking writer")
+        .expect("ME->C worker must not panic");
+
+    assert!(result.is_ok());
+    assert_eq!(stats.get_user_total_octets(&user), 4);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
+}
diff --git a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs
new file mode 100644
index 0000000..3ce0235
--- /dev/null
+++ b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs
@@ -0,0 +1,232 @@
+use super::*;
+use crate::crypto::{AesCtr, SecureRandom};
+use crate::stats::Stats;
+use crate::stream::CryptoWriter;
+use bytes::Bytes;
+use std::io;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+use std::task::{Context, Poll, Waker};
+use tokio::io::AsyncWrite;
+use tokio::time::{Duration, timeout};
+
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+    W: tokio::io::AsyncWrite + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
+}
+
+#[derive(Default)]
+struct GateState {
+    open: AtomicBool,
+    parked_waker: std::sync::Mutex<Option<Waker>>,
+}
+
+impl GateState {
+    fn open(&self) {
+        self.open.store(true, Ordering::Relaxed);
+        if let Ok(mut guard) = self.parked_waker.lock()
+            && let Some(w) = guard.take()
+        {
+            w.wake();
+        }
+    }
+
+    fn has_waiter(&self) -> bool {
+        self.parked_waker
+            .lock()
+            .map(|guard| guard.is_some())
+            .unwrap_or(false)
+    }
+}
+
+#[derive(Default)]
+struct GateWriter {
+    gate: Arc<GateState>,
+}
+
+impl GateWriter {
+    fn new(gate: Arc<GateState>) -> Self {
+        Self { gate }
+    }
+}
+
+impl AsyncWrite for GateWriter {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        if self.gate.open.load(Ordering::Relaxed) {
+            return Poll::Ready(Ok(buf.len()));
+        }
+
+        if let Ok(mut guard) = self.gate.parked_waker.lock() {
+            *guard = Some(cx.waker().clone());
+        }
+        Poll::Pending
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+struct FailingWriter;
+
+impl AsyncWrite for FailingWriter {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        _buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        Poll::Ready(Err(io::Error::new(
+            io::ErrorKind::BrokenPipe,
+            "injected writer failure",
+        )))
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() {
+    let stats = Stats::new();
+    let bytes_me2c = AtomicU64::new(0);
+    let rng = SecureRandom::new();
+    let quota_limit = Some(1024);
+    let user = "hol-quota-user";
+
+    let gate = Arc::new(GateState::default());
+
+    let
mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate))); + let slow_task = tokio::spawn(async move { + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]), + }, + &mut blocked_writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + user, + quota_limit, + 0, + &bytes_me2c, + 7001, + false, + false, + ) + .await + }); + + timeout(Duration::from_millis(100), async { + loop { + if gate.has_waiter() { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("first writer must reach backpressure and park"); + + let stats_fast = Stats::new(); + let bytes_fast = AtomicU64::new(0); + let rng_fast = SecureRandom::new(); + let mut fast_writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_fast = Vec::new(); + + timeout( + Duration::from_millis(50), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut fast_writer, + ProtoTag::Intermediate, + &rng_fast, + &mut frame_buf_fast, + &stats_fast, + user, + quota_limit, + 0, + &bytes_fast, + 7002, + false, + false, + ), + ) + .await + .expect("peer connection must not be blocked by same-user stalled write") + .expect("fast peer write must succeed"); + + gate.open(); + let slow_result = timeout(Duration::from_secs(1), slow_task) + .await + .expect("stalled task must complete once gate opens") + .expect("stalled task must not panic"); + assert!(slow_result.is_ok()); +} + +#[tokio::test] +async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() { + let stats = Stats::new(); + let user = "rollback-user"; + stats.add_user_octets_from(user, 7); + + let bytes_me2c = AtomicU64::new(0); + let rng = SecureRandom::new(); + let mut writer = make_crypto_writer(FailingWriter); + let mut frame_buf = Vec::new(); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + user, + Some(64), + 0, + &bytes_me2c, + 7003, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Io(_)))); + assert_eq!( + stats.get_user_total_octets(user), + 7, + "failed client write must not overcharge user quota accounting" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + 0, + "failed client write must not inflate ME->C forensic byte counter" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs index 3e0b30f..6ea182b 100644 --- a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -3,7 +3,7 @@ use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; use std::sync::atomic::AtomicU64; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::Arc; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; @@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl } } -fn idle_pressure_test_lock() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> { - match idle_pressure_test_lock().lock() { - 
Ok(guard) => guard, - Err(poisoned) => poisoned.into_inner(), - } -} - #[tokio::test] async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() { let (reader, _writer) = duplex(1024); @@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() { #[test] fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { #[test] fn pressure_does_not_evict_without_new_pressure_signal() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() { #[test] fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { #[test] fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { #[test] fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { #[test] fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { #[test] fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { #[test] fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated( #[test] fn blackhat_stale_pressure_must_not_survive_candidate_churn() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() { #[test] fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = 
relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting( #[test] fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); @@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); diff --git a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs new file mode 100644 index 0000000..112d926 --- /dev/null +++ b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs @@ -0,0 +1,59 @@ +use super::*; +use std::panic::{AssertUnwindSafe, catch_unwind}; + +#[test] +fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let mut guard = registry + .lock() + .expect("registry lock must be acquired before poison"); + guard.by_conn_id.insert( + 999, + RelayIdleCandidateMeta { + mark_order_seq: 1, + mark_pressure_seq: 0, + }, + ); + guard.ordered.insert((1, 999)); + panic!("intentional poison for idle-registry recovery"); + })); + + // Helper lock must recover from poison, reset stale state, and continue. 
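+    // (catch_unwind above panicked while holding the registry mutex, so the
+    // std lock is now poisoned; the helpers are expected to recover it, for
+    // example via PoisonError::into_inner, instead of propagating the panic.)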
+    assert!(mark_relay_idle_candidate(42));
+    assert_eq!(oldest_relay_idle_candidate(), Some(42));
+
+    let before = relay_pressure_event_seq();
+    note_relay_pressure_event();
+    let after = relay_pressure_event_seq();
+    assert!(after > before, "pressure accounting must still advance after poison");
+
+    clear_relay_idle_pressure_state_for_testing();
+}
+
+#[test]
+fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() {
+    let _guard = relay_idle_pressure_test_scope();
+    clear_relay_idle_pressure_state_for_testing();
+
+    let _ = catch_unwind(AssertUnwindSafe(|| {
+        let registry = relay_idle_candidate_registry();
+        let _guard = registry
+            .lock()
+            .expect("registry lock must be acquired before poison");
+        panic!("intentional poison while lock held");
+    }));
+
+    clear_relay_idle_pressure_state_for_testing();
+
+    assert_eq!(oldest_relay_idle_candidate(), None);
+    assert_eq!(relay_pressure_event_seq(), 0);
+
+    assert!(mark_relay_idle_candidate(7));
+    assert_eq!(oldest_relay_idle_candidate(), Some(7));
+
+    clear_relay_idle_pressure_state_for_testing();
+}
diff --git a/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs
new file mode 100644
index 0000000..29384e0
--- /dev/null
+++ b/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs
@@ -0,0 +1,372 @@
+use super::*;
+use crate::crypto::{AesCtr, SecureRandom};
+use crate::error::ProxyError;
+use crate::stats::Stats;
+use crate::stream::CryptoWriter;
+use bytes::Bytes;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::{Arc, OnceLock, Mutex};
+use tokio::sync::Mutex as AsyncMutex;
+use tokio::task::JoinSet;
+use tokio::time::{Duration, timeout};
+
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+    W: tokio::io::AsyncWrite + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
+}
+
+fn lookup_test_lock() -> &'static Mutex<()> {
+    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+    LOCK.get_or_init(|| Mutex::new(()))
+}
+
+#[tokio::test]
+async fn positive_me2c_quota_counts_bytes_exactly_once() {
+    let _guard = lookup_test_lock().lock().unwrap();
+    let stats = Stats::new();
+    let user = format!("quota-middle-ext-positive-{}", std::process::id());
+    let lock = Arc::new(AsyncMutex::new(()));
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let result = process_me_writer_response_with_cross_mode_lock(
+        MeResponse::Data {
+            flags: 0,
+            data: Bytes::from_static(&[1, 2, 3, 4, 5]),
+        },
+        &mut writer,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf,
+        &stats,
+        &user,
+        Some(64),
+        0,
+        Some(&lock),
+        &bytes_me2c,
+        70_001,
+        false,
+        false,
+    )
+    .await;
+
+    assert!(result.is_ok());
+    assert_eq!(stats.get_user_total_octets(&user), 5);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5);
+}
+
+#[tokio::test]
+async fn negative_held_crossmode_lock_blocks_me2c_write() {
+    let _guard = lookup_test_lock().lock().unwrap();
+    let stats = Stats::new();
+    let user = format!("quota-middle-ext-negative-{}", std::process::id());
+
+    let lock = Arc::new(AsyncMutex::new(()));
+    let _held = lock.try_lock().expect("lock must be held");
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let blocked = timeout(
+        Duration::from_millis(25),
+
process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xFE]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(16), + 0, + Some(&lock), + &bytes_me2c, + 70_101, + false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn edge_zero_quota_zero_payload_is_fail_closed() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-edge-{}", std::process::id()); + + let lock = Arc::new(AsyncMutex::new(())); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(0), + 0, + Some(&lock), + &bytes_me2c, + 70_201, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(&user), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Arc::new(Stats::new()); + let user = format!("quota-middle-ext-blackhat-{}", std::process::id()); + let quota = 64u64; + let lock = Arc::new(AsyncMutex::new(())); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + + let mut set = JoinSet::new(); + for i in 0..256u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let lock = Arc::clone(&lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize]; + + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(quota), + 0, + Some(&lock), + bytes_me2c.as_ref(), + 70_301 + i, + false, + false, + ) + .await + }); + } + + let mut succeeded = 0usize; + while let Some(done) = set.join_next().await { + match done.expect("task must not panic") { + Ok(_) => succeeded += 1, + Err(ProxyError::DataQuotaExceeded { .. 
}) => {} + Err(other) => panic!("unexpected error {other:?}"), + } + } + + assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed)); + assert!(stats.get_user_total_octets(&user) <= quota); + assert!(succeeded <= quota as usize); +} + +#[tokio::test] +async fn integration_shared_prefetched_lock_blocks_then_releases_writer() { + let stats = Stats::new(); + let user = format!("quota-middle-ext-integration-{}", std::process::id()); + let lock = Arc::new(AsyncMutex::new(())); + let held = lock + .try_lock() + .expect("integration test must hold prefetched lock first"); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let blocked = timeout( + Duration::from_millis(25), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA1]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(8), + 0, + Some(&lock), + &bytes_me2c, + 70_360, + false, + false, + ), + ) + .await; + assert!(blocked.is_err()); + + drop(held); + + let after_release = timeout( + Duration::from_millis(150), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA2]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(8), + 0, + Some(&lock), + &bytes_me2c, + 70_361, + false, + false, + ), + ) + .await + .expect("writer should progress once the shared lock is released"); + + assert!(after_release.is_ok()); +} + +#[tokio::test] +async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() { + let _guard = lookup_test_lock().lock().unwrap(); + let stats = Stats::new(); + let user = format!("quota-middle-ext-fuzz-{}", std::process::id()); + let mut seed = 0xCAFE_BABE_1234u64; + let bytes_me2c = AtomicU64::new(0); + + for case in 0..48u32 { + seed ^= seed << 5; + seed ^= seed >> 12; + seed ^= seed << 13; + let hold = (seed & 0x1) == 0; + + let lock = Arc::new(AsyncMutex::new(())); + let maybe_guard = if hold { + Some(lock.try_lock().unwrap()) + } else { + None + }; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let result = timeout( + Duration::from_millis(30), + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(128), + 0, + Some(&lock), + &bytes_me2c, + 70_401 + case as u64, + false, + false, + ), + ) + .await; + + if hold { + assert!(result.is_err()); + } else { + assert!(result.unwrap().is_ok()); + } + + drop(maybe_guard); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() { + let _guard = lookup_test_lock().lock().unwrap(); + let held = Arc::new(AsyncMutex::new(())); + let _held_guard = held.try_lock().unwrap(); + + let mut set = JoinSet::new(); + for i in 0..48u64 { + set.spawn(async move { + let stats = Stats::new(); + let user = format!("quota-middle-ext-stress-free-{i}"); + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let free_lock = Arc::new(AsyncMutex::new(())); + + process_me_writer_response_with_cross_mode_lock( + 
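+                // Every spawned task writes one byte for its own user behind
+                // its own lock, so none can queue behind the held user's
+                // guard; the 2 s outer timeout is the liveness bound.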
MeResponse::Data {
+                    flags: 0,
+                    data: Bytes::from_static(&[0xEE]),
+                },
+                &mut writer,
+                ProtoTag::Intermediate,
+                &SecureRandom::new(),
+                &mut frame_buf,
+                &stats,
+                &user,
+                Some(1),
+                0,
+                Some(&free_lock),
+                &bytes_me2c,
+                70_500 + i,
+                false,
+                false,
+            )
+            .await
+        });
+    }
+
+    timeout(Duration::from_secs(2), async {
+        while let Some(task) = set.join_next().await {
+            task.unwrap().unwrap();
+        }
+    })
+    .await
+    .unwrap();
+}
diff --git a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs
new file mode 100644
index 0000000..963b3e0
--- /dev/null
+++ b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs
@@ -0,0 +1,1066 @@
+use super::*;
+use crate::crypto::{AesCtr, SecureRandom};
+use crate::stats::Stats;
+use crate::stream::CryptoWriter;
+use bytes::Bytes;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::task::{Context, Poll};
+use tokio::io::AsyncWrite;
+use tokio::task::JoinSet;
+
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+    W: tokio::io::AsyncWrite + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
+}
+
+struct FailingWriter;
+
+impl AsyncWrite for FailingWriter {
+    fn poll_write(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        _buf: &[u8],
+    ) -> Poll<std::io::Result<usize>> {
+        Poll::Ready(Err(std::io::Error::other("forced writer failure")))
+    }
+
+    fn poll_flush(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+struct FailAfterBudgetWriter {
+    remaining: usize,
+    written: usize,
+}
+
+impl FailAfterBudgetWriter {
+    fn new(remaining: usize) -> Self {
+        Self {
+            remaining,
+            written: 0,
+        }
+    }
+}
+
+impl AsyncWrite for FailAfterBudgetWriter {
+    fn poll_write(
+        mut self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<std::io::Result<usize>> {
+        if self.remaining == 0 {
+            return Poll::Ready(Err(std::io::Error::other("forced short-write exhaustion")));
+        }
+
+        let n = self.remaining.min(buf.len());
+        self.remaining -= n;
+        self.written += n;
+        Poll::Ready(Ok(n))
+    }
+
+    fn poll_flush(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test]
+async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() {
+    let stats = Stats::new();
+    let user = "quota-boundary-user";
+    let bytes_me2c = AtomicU64::new(0);
+
+    stats.add_user_octets_from(user, 5);
+
+    let mut writer_one = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf_one = Vec::new();
+    let first = process_me_writer_response(
+        MeResponse::Data {
+            flags: 0,
+            data: Bytes::from_static(&[1, 2, 3]),
+        },
+        &mut writer_one,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf_one,
+        &stats,
+        user,
+        Some(8),
+        0,
+        &bytes_me2c,
+        7101,
+        false,
+        false,
+    )
+    .await;
+
+    assert!(first.is_ok(), "frame that reaches boundary must be allowed");
+    assert_eq!(stats.get_user_total_octets(user), 8);
+
+    let mut writer_two = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf_two = Vec::new();
+    let second = process_me_writer_response(
+        MeResponse::Data {
+            flags: 0,
+            data: Bytes::from_static(&[9]),
+        },
+        &mut writer_two,
+
ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(8), + 0, + &bytes_me2c, + 7102, + false, + false, + ) + .await; + + assert!( + matches!(second, Err(ProxyError::DataQuotaExceeded { .. })), + "frame after boundary must be rejected" + ); + assert_eq!(stats.get_user_total_octets(user), 8); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_counters() { + let stats = Arc::new(Stats::new()); + let user = "reservation-stress-user"; + let quota_limit = 64u64; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAB]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(quota_limit), + 0, + bytes_ref.as_ref(), + 7200 + idx, + false, + false, + ) + .await + }); + } + + let mut ok = 0usize; + let mut denied = 0usize; + while let Some(joined) = tasks.join_next().await { + match joined.expect("reservation stress task must not panic") { + Ok(_) => ok += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => denied += 1, + Err(other) => panic!("unexpected error in stress case: {other:?}"), + } + } + + let total = stats.get_user_total_octets(user); + assert_eq!( + total, quota_limit, + "quota must be exactly exhausted without overshoot" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + total, + "ME->C forensic bytes must track committed quota usage" + ); + assert_eq!(ok, quota_limit as usize, "exactly quota_limit tasks must succeed"); + assert_eq!( + denied, + 256usize - (quota_limit as usize), + "remaining tasks must be exactly denied without silently swallowing state" + ); +} + +#[tokio::test] +async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() { + let stats = Stats::new(); + let user = "reservation-fuzz-user"; + let quota_limit = 128u64; + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xC0FE_EE11_8899_2211u64; + + for conn in 0..512u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let len = ((seed & 0x0f) + 1) as usize; + let payload = vec![0x5A; len]; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + 0, + &bytes_me2c, + 7300 + conn, + false, + false, + ) + .await; + + if let Err(err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. 
}), + "fuzz run produced unexpected error variant: {err:?}" + ); + } + } + + let total = stats.get_user_total_octets(user); + assert!(total <= quota_limit); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn positive_soft_overshoot_allows_burst_inside_soft_cap_then_blocks() { + let stats = Stats::new(); + let user = "soft-cap-boundary-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 10u64; + let overshoot = 3u64; + + stats.add_user_octets_from(user, 10); + + let mut writer_one = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_one = Vec::new(); + let first = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer_one, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_one, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7401, + false, + false, + ) + .await; + assert!(first.is_ok(), "soft-cap buffer should allow reaching limit+overshoot"); + assert_eq!(stats.get_user_total_octets(user), 13); + + let mut writer_two = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_two = Vec::new(); + let second = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[9]), + }, + &mut writer_two, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7402, + false, + false, + ) + .await; + assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 13); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); +} + +#[tokio::test] +async fn negative_soft_overshoot_rejects_when_payload_exceeds_remaining_soft_budget() { + let stats = Stats::new(); + let user = "soft-cap-remaining-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 10u64; + let overshoot = 4u64; + + stats.add_user_octets_from(user, 12); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7501, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. 
}))); + assert_eq!(stats.get_user_total_octets(user), 12); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn negative_write_failure_rolls_back_reservation_under_soft_cap_mode() { + let stats = Stats::new(); + let user = "soft-cap-rollback-user"; + let bytes_me2c = AtomicU64::new(0); + let mut writer = make_crypto_writer(FailingWriter); + let mut frame_buf = Vec::new(); + + stats.add_user_octets_from(user, 9); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(10), + 8, + &bytes_me2c, + 7601, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Io(_)))); + assert_eq!(stats.get_user_total_octets(user), 9); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_soft_cap_stress_never_exceeds_soft_limit() { + let stats = Arc::new(Stats::new()); + let user = "soft-cap-stress-user"; + let quota_limit = 40u64; + let overshoot = 5u64; + let soft_limit = quota_limit + overshoot; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x42]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(quota_limit), + overshoot, + bytes_ref.as_ref(), + 7700 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + match joined.expect("soft-cap stress task must not panic") { + Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} + Err(other) => panic!("unexpected error in soft-cap stress case: {other:?}"), + } + } + + let total = stats.get_user_total_octets(user); + assert!(total <= soft_limit, "soft-cap stress must never overshoot soft limit"); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn light_fuzz_soft_cap_matrix_keeps_counters_and_limits_consistent() { + let stats = Stats::new(); + let user = "soft-cap-fuzz-user"; + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0x9E37_79B9_7F4A_7C15u64; + + for conn in 0..1024u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let quota_limit = 32 + (seed & 0x3f); + let overshoot = seed.rotate_left(13) & 0x0f; + let len = ((seed >> 3) & 0x07) + 1; + let payload = vec![0xA5; len as usize]; + let before = stats.get_user_total_octets(user); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + overshoot, + &bytes_me2c, + 7800 + conn, + false, + false, + ) + .await; + + if let Err(ref err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. 
}), + "soft-cap fuzz produced unexpected error variant: {err:?}" + ); + } + + let after = stats.get_user_total_octets(user); + let soft_limit = quota_limit.saturating_add(overshoot); + match result { + Ok(_) => { + assert_eq!(after, before.saturating_add(len)); + assert!(after <= soft_limit, "accepted write must stay within active soft cap"); + } + Err(_) => { + assert_eq!(after, before, "rejected write must not mutate quota state"); + } + } + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + after, + "soft-cap fuzz must keep counters synchronized" + ); + } +} + +#[tokio::test] +async fn positive_no_quota_limit_accumulates_data_octets_exactly() { + let stats = Stats::new(); + let user = "no-quota-user"; + let bytes_me2c = AtomicU64::new(0); + let mut expected = 0u64; + + for (idx, len) in [1usize, 2, 3, 5, 8, 13, 21].iter().copied().enumerate() { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let payload = vec![0x41; len]; + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + None, + 0, + &bytes_me2c, + 7900 + idx as u64, + false, + false, + ) + .await; + + assert!(result.is_ok()); + expected += len as u64; + } + + assert_eq!(stats.get_user_total_octets(user), expected); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), expected); +} + +#[tokio::test] +async fn negative_zero_quota_rejects_non_empty_payload() { + let stats = Stats::new(); + let user = "zero-quota-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(0), + 0, + &bytes_me2c, + 8001, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn edge_zero_length_payload_with_zero_quota_is_fail_closed() { + let stats = Stats::new(); + let user = "zero-len-zero-quota-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(0), + 0, + &bytes_me2c, + 8002, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. 
})));
+    assert_eq!(stats.get_user_total_octets(user), 0);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
+}
+
+#[tokio::test]
+async fn positive_ack_response_does_not_touch_quota_counters() {
+    let stats = Stats::new();
+    let user = "ack-accounting-user";
+    let bytes_me2c = AtomicU64::new(11);
+    stats.add_user_octets_to(user, 23);
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let result = process_me_writer_response(
+        MeResponse::Ack(0x33445566),
+        &mut writer,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf,
+        &stats,
+        user,
+        Some(24),
+        0,
+        &bytes_me2c,
+        8003,
+        true,
+        true,
+    )
+    .await;
+
+    assert!(result.is_ok());
+    assert_eq!(stats.get_user_total_octets(user), 23);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 11);
+}
+
+#[tokio::test]
+async fn edge_close_response_is_accounting_noop() {
+    let stats = Stats::new();
+    let user = "close-accounting-user";
+    let bytes_me2c = AtomicU64::new(19);
+    stats.add_user_octets_to(user, 31);
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let result = process_me_writer_response(
+        MeResponse::Close,
+        &mut writer,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf,
+        &stats,
+        user,
+        Some(40),
+        3,
+        &bytes_me2c,
+        8004,
+        false,
+        true,
+    )
+    .await;
+
+    assert!(result.is_ok());
+    assert_eq!(stats.get_user_total_octets(user), 31);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 19);
+}
+
+#[tokio::test]
+async fn negative_preloaded_above_soft_cap_rejects_even_single_byte() {
+    let stats = Stats::new();
+    let user = "preloaded-over-soft-cap-user";
+    let bytes_me2c = AtomicU64::new(0);
+    let quota_limit = 20u64;
+    let overshoot = 2u64;
+    stats.add_user_octets_to(user, quota_limit + overshoot + 1);
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let result = process_me_writer_response(
+        MeResponse::Data {
+            flags: 0,
+            data: Bytes::from_static(&[1]),
+        },
+        &mut writer,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf,
+        &stats,
+        user,
+        Some(quota_limit),
+        overshoot,
+        &bytes_me2c,
+        8005,
+        false,
+        false,
+    )
+    .await;
+
+    assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
+    assert_eq!(stats.get_user_total_octets(user), quota_limit + overshoot + 1);
+}
+
+#[tokio::test]
+async fn adversarial_fail_writer_path_never_desynchronizes_quota_accounting() {
+    let stats = Stats::new();
+    let user = "partial-write-rollback-user";
+    let bytes_me2c = AtomicU64::new(0);
+    let mut writer = make_crypto_writer(FailAfterBudgetWriter::new(7));
+    let mut frame_buf = Vec::new();
+    let payload_len = 16 * 1024u64;
+
+    let result = process_me_writer_response(
+        MeResponse::Data {
+            flags: 0,
+            data: Bytes::from(vec![0x42; 16 * 1024]),
+        },
+        &mut writer,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf,
+        &stats,
+        user,
+        Some(payload_len),
+        0,
+        &bytes_me2c,
+        8006,
+        false,
+        false,
+    )
+    .await;
+
+    let total_after = stats.get_user_total_octets(user);
+    let forensic_after = bytes_me2c.load(Ordering::Relaxed);
+    assert_eq!(forensic_after, total_after);
+    assert!(
+        total_after == 0 || total_after == payload_len,
+        "writer failure path must either roll back fully or commit exactly one payload"
+    );
+
+    // Regardless of whether I/O failure surfaced immediately or was deferred,
+    // accounting must remain fail-closed and prevent silent overshoot.
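+    // The single-byte follow-up below probes which branch happened: a fully
+    // committed payload leaves the user exactly at the limit, so the probe
+    // must be rejected; a full rollback leaves headroom, so it must succeed.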
+ let mut writer_two = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_two = Vec::new(); + let second = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x99]), + }, + &mut writer_two, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(payload_len), + 0, + &bytes_me2c, + 8007, + false, + false, + ) + .await; + + if total_after == payload_len { + assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); + } else { + assert!(second.is_ok()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_oversized_frames_fail_closed_without_counter_leak() { + let stats = Arc::new(Stats::new()); + let user = "parallel-fail-rollback-user"; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![0xEE; 12 * 1024]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(512), + 0, + bytes_ref.as_ref(), + 8100 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + let result = joined.expect("parallel fail writer task must not panic"); + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + } + + assert_eq!(stats.get_user_total_octets(user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn integration_mixed_data_ack_close_sequence_preserves_data_only_accounting() { + let stats = Stats::new(); + let user = "mixed-sequence-user"; + let bytes_me2c = AtomicU64::new(0); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + + let data_one = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8201, + false, + false, + ) + .await; + assert!(data_one.is_ok()); + + let ack = process_me_writer_response( + MeResponse::Ack(0x0102_0304), + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8202, + true, + true, + ) + .await; + assert!(ack.is_ok()); + + let data_two = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[4, 5]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8203, + false, + true, + ) + .await; + assert!(data_two.is_ok()); + + let close = process_me_writer_response( + MeResponse::Close, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(32), + 0, + &bytes_me2c, + 8204, + false, + true, + ) + .await; + assert!(close.is_ok()); + + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_multi_user_quota_isolation_no_cross_user_leakage() { + let stats = Arc::new(Stats::new()); + let 
user_a = "quota-isolation-a"; + let user_b = "quota-isolation-b"; + let limit_a = 50u64; + let limit_b = 80u64; + let bytes_a = Arc::new(AtomicU64::new(0)); + let bytes_b = Arc::new(AtomicU64::new(0)); + + let mut tasks = JoinSet::new(); + for idx in 0..200u64 { + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_a); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xA1]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + user_a, + Some(limit_a), + 0, + bytes_ref.as_ref(), + 8300 + idx, + false, + false, + ) + .await + }); + } + + for idx in 0..220u64 { + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_b); + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xB2]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + user_b, + Some(limit_b), + 0, + bytes_ref.as_ref(), + 8500 + idx, + false, + false, + ) + .await + }); + } + + while let Some(joined) = tasks.join_next().await { + let result = joined.expect("quota isolation task must not panic"); + assert!(result.is_ok() || matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + } + + assert_eq!(stats.get_user_total_octets(user_a), limit_a); + assert_eq!(stats.get_user_total_octets(user_b), limit_b); + assert_eq!(bytes_a.load(Ordering::Relaxed), limit_a); + assert_eq!(bytes_b.load(Ordering::Relaxed), limit_b); +} + +#[tokio::test] +async fn light_fuzz_mixed_me_responses_preserve_quota_and_counter_invariants() { + let stats = Stats::new(); + let user = "mixed-fuzz-user"; + let bytes_me2c = AtomicU64::new(0); + let quota_limit = 96u64; + let mut seed = 0xDEAD_BEEF_2026_0323u64; + + for idx in 0..2048u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let choice = (seed & 0x03) as u8; + let response = if choice == 0 { + MeResponse::Ack((seed >> 8) as u32) + } else if choice == 1 { + MeResponse::Close + } else { + let len = ((seed >> 16) & 0x07) as usize; + let mut payload = vec![0u8; len]; + payload.fill((seed & 0xff) as u8); + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + } + }; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + response, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + 0, + &bytes_me2c, + 8800 + idx, + (idx & 1) == 0, + (idx & 2) == 0, + ) + .await; + + if let Err(err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. 
}),
+                "mixed fuzz produced unexpected error variant: {err:?}"
+            );
+        }
+
+        let total = stats.get_user_total_octets(user);
+        assert!(
+            total <= quota_limit,
+            "mixed fuzz must keep usage at or below quota limit"
+        );
+        assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
+    }
+}
\ No newline at end of file
diff --git a/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs
new file mode 100644
index 0000000..e4d0c6e
--- /dev/null
+++ b/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs
@@ -0,0 +1,399 @@
+use super::*;
+use crate::crypto::{AesCtr, SecureRandom};
+use crate::stats::Stats;
+use crate::stream::CryptoWriter;
+use bytes::Bytes;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::{Arc, Mutex, OnceLock};
+use tokio::sync::Mutex as AsyncMutex;
+use tokio::task::JoinSet;
+use tokio::time::{Duration, timeout};
+
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+    W: tokio::io::AsyncWrite + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
+}
+
+// Serializes tests that reset or read the process-global lock-lookup counter.
+fn lookup_counter_test_lock() -> &'static Mutex<()> {
+    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+    LOCK.get_or_init(|| Mutex::new(()))
+}
+
+#[tokio::test]
+async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() {
+    let _guard = lookup_counter_test_lock()
+        .lock()
+        .unwrap_or_else(|poison| poison.into_inner());
+
+    let stats = Stats::new();
+    let user = format!("quota-extreme-positive-{}", std::process::id());
+    let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    for idx in 0..12u64 {
+        let payload = vec![0x5A; ((idx % 4) + 1) as usize];
+        let result = process_me_writer_response_with_cross_mode_lock(
+            MeResponse::Data {
+                flags: 0,
+                data: Bytes::from(payload),
+            },
+            &mut writer,
+            ProtoTag::Intermediate,
+            &SecureRandom::new(),
+            &mut frame_buf,
+            &stats,
+            &user,
+            Some(512),
+            0,
+            Some(&lock),
+            &bytes_me2c,
+            31_000 + idx,
+            false,
+            false,
+        )
+        .await;
+
+        assert!(result.is_ok());
+    }
+
+    assert_eq!(
+        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
+        0,
+        "prefetched lock path must avoid hot-path registry lookups"
+    );
+    assert_eq!(
+        stats.get_user_total_octets(&user),
+        bytes_me2c.load(Ordering::Relaxed),
+        "forensics and quota accounting must remain synchronized"
+    );
+}
+
+#[tokio::test]
+async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() {
+    let _guard = lookup_counter_test_lock()
+        .lock()
+        .unwrap_or_else(|poison| poison.into_inner());
+
+    let stats = Stats::new();
+    let user = format!("quota-extreme-negative-{}", std::process::id());
+    let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold lock before calling ME->C writer");
+
+    let mut writer = make_crypto_writer(tokio::io::sink());
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let blocked = timeout(
+        Duration::from_millis(25),
+        process_me_writer_response_with_cross_mode_lock(
+            MeResponse::Data {
+                flags: 0,
+                data: Bytes::from_static(&[1, 2, 3]),
+            },
+            &mut writer,
+            ProtoTag::Intermediate,
+
&SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(64), + 0, + Some(&lock), + &bytes_me2c, + 31_100, + false, + false, + ), + ) + .await; + + assert!(blocked.is_err()); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); + + drop(held_guard); +} + +#[tokio::test] +async fn edge_zero_quota_and_zero_payload_is_fail_closed() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-edge-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::new(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(0), + 0, + Some(&lock), + &bytes_me2c, + 31_200, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_total_octets(&user), 0); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Arc::new(Stats::new()); + let user = format!("quota-extreme-blackhat-{}", std::process::id()); + let quota = 80u64; + let overshoot = 7u64; + let soft_limit = quota + overshoot; + let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let bytes_me2c = Arc::new(AtomicU64::new(0)); + + let mut set = JoinSet::new(); + for idx in 0..256u64 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let lock = Arc::clone(&lock); + let bytes_me2c = Arc::clone(&bytes_me2c); + + set.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let len = ((idx % 5) + 1) as usize; + let payload = vec![0xAA; len]; + + process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats.as_ref(), + &user, + Some(quota), + overshoot, + Some(&lock), + bytes_me2c.as_ref(), + 31_300 + idx, + false, + false, + ) + .await + }); + } + + while let Some(done) = set.join_next().await { + match done.expect("task must not panic") { + Ok(_) | Err(ProxyError::DataQuotaExceeded { .. 
}) => {} + Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"), + } + } + + let total = stats.get_user_total_octets(&user); + assert!( + total <= soft_limit, + "parallel adversarial race must stay under soft cap" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} + +#[tokio::test] +async fn integration_without_prefetched_lock_uses_registry_lookup_path() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-integration-{}", std::process::id()); + crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + for idx in 0..3u64 { + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(16), + 0, + None, + &bytes_me2c, + 31_400 + idx, + false, + false, + ) + .await; + + assert!(result.is_ok()); + } + + assert_eq!( + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), + 3, + "control path should perform one lock-registry lookup per call" + ); +} + +#[tokio::test] +async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() { + let _guard = lookup_counter_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + + let stats = Stats::new(); + let user = format!("quota-extreme-fuzz-{}", std::process::id()); + let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xA11C_55EE_2026_0323u64; + + for idx in 0..512u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let quota = 24 + (seed & 0x3f); + let overshoot = (seed >> 13) & 0x0f; + let len = ((seed >> 19) & 0x07) + 1; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let before = stats.get_user_total_octets(&user); + + let result = process_me_writer_response_with_cross_mode_lock( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![0x11; len as usize]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + &user, + Some(quota), + overshoot, + Some(&lock), + &bytes_me2c, + 31_500 + idx, + false, + false, + ) + .await; + + let after = stats.get_user_total_octets(&user); + if result.is_ok() { + assert!(after >= before); + } else { + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. 
})));
+            assert_eq!(after, before);
+        }
+        assert_eq!(bytes_me2c.load(Ordering::Relaxed), after);
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() {
+    let _guard = lookup_counter_test_lock()
+        .lock()
+        .unwrap_or_else(|poison| poison.into_inner());
+
+    let stats = Arc::new(Stats::new());
+    let user = format!("quota-extreme-stress-{}", std::process::id());
+    let quota = 96u64;
+    let lock: Arc<AsyncMutex<()>> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+    let bytes_me2c = Arc::new(AtomicU64::new(0));
+
+    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
+
+    let mut set = JoinSet::new();
+    for idx in 0..384u64 {
+        let stats = Arc::clone(&stats);
+        let user = user.clone();
+        let lock = Arc::clone(&lock);
+        let bytes_me2c = Arc::clone(&bytes_me2c);
+
+        set.spawn(async move {
+            let mut writer = make_crypto_writer(tokio::io::sink());
+            let mut frame_buf = Vec::new();
+            process_me_writer_response_with_cross_mode_lock(
+                MeResponse::Data {
+                    flags: 0,
+                    data: Bytes::from_static(&[0xFF]),
+                },
+                &mut writer,
+                ProtoTag::Intermediate,
+                &SecureRandom::new(),
+                &mut frame_buf,
+                stats.as_ref(),
+                &user,
+                Some(quota),
+                0,
+                Some(&lock),
+                bytes_me2c.as_ref(),
+                31_600 + idx,
+                false,
+                false,
+            )
+            .await
+        });
+    }
+
+    let mut success = 0usize;
+    while let Some(done) = set.join_next().await {
+        match done.expect("task must not panic") {
+            Ok(_) => success += 1,
+            Err(ProxyError::DataQuotaExceeded { .. }) => {}
+            Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"),
+        }
+    }
+
+    assert_eq!(success, quota as usize);
+    assert_eq!(stats.get_user_total_octets(&user), quota);
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota);
+    assert_eq!(
+        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
+        0,
+        "stress prefetched path must not use lock registry lookups"
+    );
+}
diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs
new file mode 100644
index 0000000..34fc454
--- /dev/null
+++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs
@@ -0,0 +1,361 @@
+use super::*;
+use crate::crypto::AesCtr;
+use crate::stats::Stats;
+use crate::stream::{BufferPool, CryptoReader};
+use std::sync::Arc;
+use std::sync::atomic::AtomicU64;
+use std::time::Instant;
+use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
+use tokio::task::JoinSet;
+use tokio::time::{Duration as TokioDuration, sleep};
+
+fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
+where
+    T: AsyncRead + Unpin + Send + 'static,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoReader::new(reader, AesCtr::new(&key, iv))
+}
+
+fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
+    let key = [0u8; 32];
+    let iv = 0u128;
+    let mut cipher = AesCtr::new(&key, iv);
+    cipher.encrypt(plaintext)
+}
+
+fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
+    RelayForensicsState {
+        trace_id: 0xB200_0000 + conn_id,
+        conn_id,
+        user: format!("tiny-frame-debt-concurrency-user-{conn_id}"),
+        peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
+        peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
+        started_at,
+        bytes_c2me: 0,
+        bytes_me2c: Arc::new(AtomicU64::new(0)),
+        desync_all_full: false,
+    }
+}
+
+// Millisecond-scale timers make soft/hard idle transitions observable in tests.
+fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
+    RelayClientIdlePolicy {
+        enabled: true,
+        soft_idle: Duration::from_millis(50),
+        hard_idle: Duration::from_millis(120),
+        grace_after_downstream_activity: Duration::from_secs(0),
+        legacy_frame_read_timeout: Duration::from_millis(50),
+    }
+}
+
+// Drives one read through the idle-policy relay path with a fresh pool and stats per call.
+async fn read_once<R: AsyncRead + Unpin + Send + 'static>(
+    crypto_reader: &mut CryptoReader<R>,
+    proto: ProtoTag,
+    forensics: &RelayForensicsState,
+    frame_counter: &mut u64,
+    idle_state: &mut RelayClientIdleState,
+) -> Result<Option<(PooledBuffer, bool)>> {
+    let buffer_pool = Arc::new(BufferPool::new());
+    let stats = Stats::new();
+    let idle_policy = make_enabled_idle_policy();
+    let last_downstream_activity_ms = AtomicU64::new(0);
+    read_client_payload_with_idle_policy(
+        crypto_reader,
+        proto,
+        1024,
+        &buffer_pool,
+        forensics,
+        frame_counter,
+        &stats,
+        &idle_policy,
+        idle_state,
+        &last_downstream_activity_ms,
+        forensics.started_at,
+    )
+    .await
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
+    let mut set = JoinSet::new();
+
+    for idx in 0..32u64 {
+        set.spawn(async move {
+            let (reader, mut writer) = duplex(4096);
+            let mut crypto_reader = make_crypto_reader(reader);
+            let started = Instant::now();
+            let forensics = make_forensics(1000 + idx, started);
+            let mut frame_counter = 0u64;
+            let mut idle_state = RelayClientIdleState::new(started);
+
+            let flood_plaintext = vec![0u8; 1024];
+            let flood_encrypted = encrypt_for_reader(&flood_plaintext);
+            writer.write_all(&flood_encrypted).await.unwrap();
+            drop(writer);
+
+            let result = run_relay_test_step_timeout(
+                "tiny flood task",
+                read_once(
+                    &mut crypto_reader,
+                    ProtoTag::Abridged,
+                    &forensics,
+                    &mut frame_counter,
+                    &mut idle_state,
+                ),
+            )
+            .await;
+
+            assert!(matches!(result, Err(ProxyError::Proxy(_))));
+            assert_eq!(frame_counter, 0);
+        });
+    }
+
+    while let Some(result) = set.join_next().await {
+        result.expect("parallel tiny flood worker must not panic");
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_parallel_benign_tiny_burst_then_real_all_pass() {
+    let mut set = JoinSet::new();
+
+    for idx in 0..24u64 {
+        set.spawn(async move {
+            let (reader, mut writer) = duplex(2048);
+            let mut crypto_reader = make_crypto_reader(reader);
+            let started = Instant::now();
+            let forensics = make_forensics(2000 + idx, started);
+            let mut frame_counter = 0u64;
+            let mut idle_state = RelayClientIdleState::new(started);
+
+            let payload = [idx as u8, 2, 3, 4];
+            let mut plaintext = Vec::with_capacity(20);
+            for _ in 0..6 {
+                plaintext.push(0x00);
+            }
+            plaintext.push(0x01);
+            plaintext.extend_from_slice(&payload);
+            let encrypted = encrypt_for_reader(&plaintext);
+            writer.write_all(&encrypted).await.unwrap();
+
+            let result = run_relay_test_step_timeout(
+                "benign tiny burst read",
+                read_once(
+                    &mut crypto_reader,
+                    ProtoTag::Abridged,
+                    &forensics,
+                    &mut frame_counter,
+                    &mut idle_state,
+                ),
+            )
+            .await
+            .expect("benign payload must parse")
+            .expect("benign payload must return frame");
+
+            assert_eq!(result.0.as_ref(), &payload);
+            assert_eq!(frame_counter, 1);
+        });
+    }
+
+    while let Some(result) = set.join_next().await {
+        result.expect("parallel benign worker must not panic");
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {
+    let mut set = JoinSet::new();
+
+    for idx in 0..12u64 {
+        set.spawn(async move {
+            let (reader, mut writer) = duplex(8192);
+            let mut crypto_reader = make_crypto_reader(reader);
+            let started =
Instant::now(); + let forensics = make_forensics(3000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(2000); + for n in 0..180u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + for chunk in encrypted.chunks(17) { + writer.write_all(chunk).await.unwrap(); + sleep(TokioDuration::from_millis(1)).await; + } + drop(writer); + }); + + let mut closed = false; + for _ in 0..220 { + let result = run_relay_test_step_timeout( + "alternating jitter read step", + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match result { + Ok(Some((_payload, _))) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error in alternating jitter case: {other}"), + } + } + + writer_task.await.expect("writer jitter task must not panic"); + assert!(closed, "alternating attack must close before EOF"); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("alternating jitter worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_mixed_population_attackers_close_benign_survive() { + let mut set = JoinSet::new(); + + for idx in 0..20u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(4000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + if idx % 2 == 0 { + let mut plaintext = Vec::with_capacity(1280); + for n in 0..140u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n, n, n]); + } + writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..200 { + match read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected attacker error: {other}"), + } + } + assert!(closed, "attacker session must fail closed"); + } else { + let payload = [1u8, 9, 8, 7]; + let mut plaintext = Vec::new(); + for _ in 0..4 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + + let got = read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("benign session must parse") + .expect("benign session must return a frame"); + assert_eq!(got.0.as_ref(), &payload); + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("mixed-population worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_parallel_patterns_no_hang_or_panic() { + let mut set = JoinSet::new(); + + for case in 0..40u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(5000 
+ case, started);
+            let mut frame_counter = 0u64;
+            let mut idle_state = RelayClientIdleState::new(started);
+
+            let mut seed = 0x9E37_79B9u64 ^ (case << 8);
+            let mut plaintext = Vec::with_capacity(2048);
+            for _ in 0..256 {
+                seed ^= seed << 7;
+                seed ^= seed >> 9;
+                seed ^= seed << 8;
+                let is_tiny = (seed & 1) == 0;
+                if is_tiny {
+                    plaintext.push(0x00);
+                } else {
+                    plaintext.push(0x01);
+                    plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]);
+                }
+            }
+
+            writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
+            drop(writer);
+
+            for _ in 0..320 {
+                let step = run_relay_test_step_timeout(
+                    "fuzz case read step",
+                    read_once(
+                        &mut crypto_reader,
+                        ProtoTag::Abridged,
+                        &forensics,
+                        &mut frame_counter,
+                        &mut idle_state,
+                    ),
+                )
+                .await;
+
+                match step {
+                    Ok(Some(_)) => {}
+                    Err(ProxyError::Proxy(_)) => break,
+                    Ok(None) => break,
+                    Err(other) => panic!("unexpected fuzz case error: {other}"),
+                }
+            }
+        });
+    }
+
+    while let Some(result) = set.join_next().await {
+        result.expect("fuzz worker must not panic");
+    }
+}
diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs
new file mode 100644
index 0000000..853b381
--- /dev/null
+++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs
@@ -0,0 +1,425 @@
+use super::*;
+use crate::crypto::AesCtr;
+use crate::stats::Stats;
+use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
+use std::sync::Arc;
+use std::sync::atomic::AtomicU64;
+use std::time::Instant;
+use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
+use tokio::time::{Duration as TokioDuration, sleep};
+
+fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
+where
+    T: AsyncRead + Unpin + Send + 'static,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoReader::new(reader, AesCtr::new(&key, iv))
+}
+
+fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
+    let key = [0u8; 32];
+    let iv = 0u128;
+    let mut cipher = AesCtr::new(&key, iv);
+    cipher.encrypt(plaintext)
+}
+
+fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
+    RelayForensicsState {
+        trace_id: 0xB300_0000 + conn_id,
+        conn_id,
+        user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"),
+        peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
+        peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
+        started_at,
+        bytes_c2me: 0,
+        bytes_me2c: Arc::new(AtomicU64::new(0)),
+        desync_all_full: false,
+    }
+}
+
+fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
+    RelayClientIdlePolicy {
+        enabled: true,
+        soft_idle: Duration::from_millis(50),
+        hard_idle: Duration::from_millis(120),
+        grace_after_downstream_activity: Duration::from_secs(0),
+        legacy_frame_read_timeout: Duration::from_millis(50),
+    }
+}
+
+fn append_tiny_frame(plaintext: &mut Vec<u8>, proto: ProtoTag) {
+    match proto {
+        ProtoTag::Abridged => plaintext.push(0x00),
+        ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()),
+    }
+}
+
+fn append_real_frame(plaintext: &mut Vec<u8>, proto: ProtoTag, payload: [u8; 4]) {
+    match proto {
+        ProtoTag::Abridged => {
+            plaintext.push(0x01);
+            plaintext.extend_from_slice(&payload);
+        }
+        ProtoTag::Intermediate | ProtoTag::Secure => {
+            plaintext.extend_from_slice(&4u32.to_le_bytes());
+            plaintext.extend_from_slice(&payload);
+        }
+    }
+}
+
+// Dribbles ciphertext to the reader in 1..=32 byte chunks with 0-2 ms of jitter.
+async fn write_chunked_with_jitter(
+    writer: &mut tokio::io::DuplexStream,
+    bytes: &[u8],
+    mut seed: u64,
+) {
+    let mut offset = 0usize;
+    while offset < bytes.len() {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+        let chunk_len = 1 + ((seed as usize) & 0x1f);
+        let end = (offset + chunk_len).min(bytes.len());
+        writer.write_all(&bytes[offset..end]).await.unwrap();
+
+        let delay_ms = ((seed >> 16) % 3) as u64;
+        if delay_ms > 0 {
+            sleep(TokioDuration::from_millis(delay_ms)).await;
+        }
+        offset = end;
+    }
+}
+
+async fn read_once_with_state<R: AsyncRead + Unpin + Send + 'static>(
+    crypto_reader: &mut CryptoReader<R>,
+    proto: ProtoTag,
+    forensics: &RelayForensicsState,
+    frame_counter: &mut u64,
+    idle_state: &mut RelayClientIdleState,
+) -> Result<Option<(PooledBuffer, bool)>> {
+    let buffer_pool = Arc::new(BufferPool::new());
+    let stats = Stats::new();
+    let idle_policy = make_enabled_idle_policy();
+    let last_downstream_activity_ms = AtomicU64::new(0);
+    read_client_payload_with_idle_policy(
+        crypto_reader,
+        proto,
+        1024,
+        &buffer_pool,
+        forensics,
+        frame_counter,
+        &stats,
+        &idle_policy,
+        idle_state,
+        &last_downstream_activity_ms,
+        forensics.started_at,
+    )
+    .await
+}
+
+fn is_fail_closed_outcome(result: &Result<Option<(PooledBuffer, bool)>>) -> bool {
+    matches!(result, Err(ProxyError::Proxy(_)))
+        || matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut)
+}
+
+#[tokio::test]
+async fn intermediate_chunked_zero_flood_fail_closed() {
+    let (reader, mut writer) = duplex(4096);
+    let mut crypto_reader = make_crypto_reader(reader);
+    let started = Instant::now();
+    let forensics = make_forensics(6101, started);
+    let mut frame_counter = 0u64;
+    let mut idle_state = RelayClientIdleState::new(started);
+
+    let mut plaintext = Vec::with_capacity(4 * 256);
+    for _ in 0..256 {
+        append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
+    }
+    let encrypted = encrypt_for_reader(&plaintext);
+    write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await;
+    drop(writer);
+
+    let result = run_relay_test_step_timeout(
+        "intermediate flood read",
+        read_once_with_state(
+            &mut crypto_reader,
+            ProtoTag::Intermediate,
+            &forensics,
+            &mut frame_counter,
+            &mut idle_state,
+        ),
+    )
+    .await;
+
+    assert!(
+        is_fail_closed_outcome(&result),
+        "zero-length flood must fail closed via debt guard or idle timeout"
+    );
+    assert_eq!(frame_counter, 0);
+}
+
+#[tokio::test]
+async fn secure_chunked_zero_flood_fail_closed() {
+    let (reader, mut writer) = duplex(4096);
+    let mut crypto_reader = make_crypto_reader(reader);
+    let started = Instant::now();
+    let forensics = make_forensics(6102, started);
+    let mut frame_counter = 0u64;
+    let mut idle_state = RelayClientIdleState::new(started);
+
+    let mut plaintext = Vec::with_capacity(4 * 256);
+    for _ in 0..256 {
+        append_tiny_frame(&mut plaintext, ProtoTag::Secure);
+    }
+    let encrypted = encrypt_for_reader(&plaintext);
+    write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await;
+    drop(writer);
+
+    let result = run_relay_test_step_timeout(
+        "secure flood read",
+        read_once_with_state(
+            &mut crypto_reader,
+            ProtoTag::Secure,
+            &forensics,
+            &mut frame_counter,
+            &mut idle_state,
+        ),
+    )
+    .await;
+
+    assert!(
+        is_fail_closed_outcome(&result),
+        "secure zero-length flood must fail closed via debt guard or idle timeout"
+    );
+    assert_eq!(frame_counter, 0);
+}
+
+#[tokio::test]
+async fn intermediate_chunked_alternating_attack_closes_before_eof() {
+    let (reader, mut writer) = duplex(8192);
+    let mut crypto_reader = make_crypto_reader(reader);
+    let started = Instant::now();
+    let forensics = make_forensics(6103, started);
+    let mut frame_counter = 0u64;
+    let mut idle_state =
RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = run_relay_test_step_timeout( + "intermediate alternating read step", + read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected intermediate alternating error: {other}"), + } + } + + writer_task.await.expect("intermediate writer task must not panic"); + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn secure_chunked_alternating_attack_closes_before_eof() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6104, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = run_relay_test_step_timeout( + "secure alternating read step", + read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected secure alternating error: {other}"), + } + } + + writer_task.await.expect("secure writer task must not panic"); + assert!(closed, "secure alternating attack must fail closed"); +} + +#[tokio::test] +async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6105, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [9u8, 8, 7, 6]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + } + append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("intermediate safe burst should parse") + .expect("intermediate safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn 
secure_chunked_safe_small_burst_still_returns_real_frame() {
+    let (reader, mut writer) = duplex(1024);
+    let mut crypto_reader = make_crypto_reader(reader);
+    let started = Instant::now();
+    let forensics = make_forensics(6106, started);
+    let mut frame_counter = 0u64;
+    let mut idle_state = RelayClientIdleState::new(started);
+
+    let payload = [3u8, 1, 4, 1];
+    let mut plaintext = Vec::new();
+    for _ in 0..7 {
+        append_tiny_frame(&mut plaintext, ProtoTag::Secure);
+    }
+    append_real_frame(&mut plaintext, ProtoTag::Secure, payload);
+    let encrypted = encrypt_for_reader(&plaintext);
+    write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await;
+
+    let result = read_once_with_state(
+        &mut crypto_reader,
+        ProtoTag::Secure,
+        &forensics,
+        &mut frame_counter,
+        &mut idle_state,
+    )
+    .await
+    .expect("secure safe burst should parse")
+    .expect("secure safe burst should return a frame");
+
+    assert_eq!(result.0.as_ref(), &payload);
+    assert_eq!(frame_counter, 1);
+}
+
+#[tokio::test]
+async fn light_fuzz_proto_chunking_outcomes_are_bounded() {
+    let mut seed = 0xDEAD_BEEF_2026_0322u64;
+
+    for case in 0..48u64 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let proto = if (seed & 1) == 0 {
+            ProtoTag::Intermediate
+        } else {
+            ProtoTag::Secure
+        };
+
+        let (reader, mut writer) = duplex(8192);
+        let mut crypto_reader = make_crypto_reader(reader);
+        let started = Instant::now();
+        let forensics = make_forensics(6200 + case, started);
+        let mut frame_counter = 0u64;
+        let mut idle_state = RelayClientIdleState::new(started);
+
+        let mut stream = Vec::new();
+        let mut local_seed = seed ^ case;
+        for _ in 0..220 {
+            local_seed ^= local_seed << 7;
+            local_seed ^= local_seed >> 9;
+            local_seed ^= local_seed << 8;
+            if (local_seed & 1) == 0 {
+                append_tiny_frame(&mut stream, proto);
+            } else {
+                let b = (local_seed >> 8) as u8;
+                append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]);
+            }
+        }
+
+        let encrypted = encrypt_for_reader(&stream);
+        write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await;
+        drop(writer);
+
+        for _ in 0..260 {
+            let step = run_relay_test_step_timeout(
+                "fuzz proto read step",
+                read_once_with_state(
+                    &mut crypto_reader,
+                    proto,
+                    &forensics,
+                    &mut frame_counter,
+                    &mut idle_state,
+                ),
+            )
+            .await;
+
+            match step {
+                Ok(Some((_payload, _))) => {}
+                Err(ProxyError::Proxy(_)) => break,
+                Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break,
+                Ok(None) => break,
+                Err(other) => panic!("unexpected proto chunking fuzz error: {other}"),
+            }
+        }
+    }
+}
diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs
new file mode 100644
index 0000000..dee5dd9
--- /dev/null
+++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs
@@ -0,0 +1,798 @@
+use super::*;
+use crate::crypto::AesCtr;
+use crate::stats::Stats;
+use crate::stream::{BufferPool, CryptoReader};
+use std::sync::atomic::AtomicU64;
+use std::sync::Arc;
+use std::time::Instant;
+use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
+
+fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
+where
+    T: AsyncRead + Unpin + Send + 'static,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoReader::new(reader, AesCtr::new(&key, iv))
+}
+
+fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
+    let key = [0u8; 32];
+    let iv = 0u128;
+    let mut cipher = AesCtr::new(&key, iv);
+    cipher.encrypt(plaintext)
+}
+
+fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
+    RelayForensicsState {
+        trace_id: 0xB100_0000 + conn_id,
+        conn_id,
+        user: format!("tiny-frame-debt-user-{conn_id}"),
+        peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
+        peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
+        started_at,
+        bytes_c2me: 0,
+        bytes_me2c: Arc::new(AtomicU64::new(0)),
+        desync_all_full: false,
+    }
+}
+
+fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
+    RelayClientIdlePolicy {
+        enabled: true,
+        soft_idle: Duration::from_millis(50),
+        hard_idle: Duration::from_millis(120),
+        grace_after_downstream_activity: Duration::from_secs(0),
+        legacy_frame_read_timeout: Duration::from_millis(50),
+    }
+}
+
+async fn read_bounded<R: AsyncRead + Unpin + Send + 'static>(
+    crypto_reader: &mut CryptoReader<R>,
+    proto_tag: ProtoTag,
+    buffer_pool: &Arc<BufferPool>,
+    forensics: &RelayForensicsState,
+    frame_counter: &mut u64,
+    stats: &Stats,
+    idle_policy: &RelayClientIdlePolicy,
+    idle_state: &mut RelayClientIdleState,
+    last_downstream_activity_ms: &AtomicU64,
+    session_started_at: Instant,
+) -> Result<Option<(PooledBuffer, bool)>> {
+    run_relay_test_step_timeout(
+        "tiny-frame debt read step",
+        read_client_payload_with_idle_policy(
+            crypto_reader,
+            proto_tag,
+            1024,
+            buffer_pool,
+            forensics,
+            frame_counter,
+            stats,
+            idle_policy,
+            idle_state,
+            last_downstream_activity_ms,
+            session_started_at,
+        ),
+    )
+    .await
+}
+
+// Pure model of the relay's tiny-frame debt guard: each tiny frame adds
+// TINY_FRAME_DEBT_PER_TINY, each real frame drains 1, and crossing
+// TINY_FRAME_DEBT_LIMIT closes the session.
+fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option<usize>, u32, usize) {
+    let mut debt = 0u32;
+    let mut reals = 0usize;
+    for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() {
+        if is_tiny {
+            debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
+            if debt >= TINY_FRAME_DEBT_LIMIT {
+                return (Some(idx + 1), debt, reals);
+            }
+        } else {
+            reals = reals.saturating_add(1);
+            debt = debt.saturating_sub(1);
+        }
+    }
+    (None, debt, reals)
+}
+
+#[test]
+fn tiny_frame_debt_constants_match_security_budget_expectations() {
+    assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8);
+    assert_eq!(TINY_FRAME_DEBT_LIMIT, 512);
+}
+
+#[test]
+fn relay_client_idle_state_initial_debt_is_zero() {
+    let state = RelayClientIdleState::new(Instant::now());
+    assert_eq!(state.tiny_frame_debt, 0);
+}
+
+#[test]
+fn on_client_frame_does_not_reset_tiny_frame_debt() {
+    let now = Instant::now();
+    let mut state = RelayClientIdleState::new(now);
+    state.tiny_frame_debt = 77;
+    state.on_client_frame(now);
+    assert_eq!(state.tiny_frame_debt, 77);
+}
+
+#[test]
+fn tiny_frame_debt_increment_is_saturating() {
+    let mut debt = u32::MAX - 1;
+    debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
+    assert_eq!(debt, u32::MAX);
+}
+
+#[test]
+fn tiny_frame_debt_decrement_is_saturating() {
+    let mut debt = 0u32;
+    debt = debt.saturating_sub(1);
+    assert_eq!(debt, 0);
+}
+
+#[test]
+fn consecutive_tiny_frames_close_exactly_at_threshold() {
+    let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize;
+    let pattern = vec![true; max_tiny_without_close];
+    let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
+    assert_eq!(closed_at, Some(max_tiny_without_close));
+}
+
+#[test]
+fn one_less_than_threshold_tiny_frames_do_not_close() {
+    let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1;
+    let pattern = vec![true; tiny_count];
+    let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
+    assert_eq!(closed_at, None);
+    assert!(debt < TINY_FRAME_DEBT_LIMIT);
+}
+
+#[test]
+fn alternating_one_to_one_closes_with_bounded_real_frame_count() {
+    let mut pattern =
Vec::with_capacity(512); + for _ in 0..256 { + pattern.push(true); + pattern.push(false); + } + let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(closed_at.is_some()); + assert!(reals <= 80, "expected bounded real frames before close, got {reals}"); +} + +#[test] +fn alternating_one_to_eight_is_stable_for_long_runs() { + let mut pattern = Vec::with_capacity(9 * 5000); + for _ in 0..5000 { + pattern.push(true); + for _ in 0..8 { + pattern.push(false); + } + } + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert!(debt <= TINY_FRAME_DEBT_PER_TINY); +} + +#[test] +fn alternating_one_to_seven_eventually_closes() { + let mut pattern = Vec::with_capacity(8 * 2000); + for _ in 0..2000 { + pattern.push(true); + for _ in 0..7 { + pattern.push(false); + } + } + let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close"); +} + +#[test] +fn two_tiny_one_real_closes_faster_than_one_to_one() { + let mut one_to_one = Vec::with_capacity(512); + for _ in 0..256 { + one_to_one.push(true); + one_to_one.push(false); + } + + let mut two_to_one = Vec::with_capacity(768); + for _ in 0..256 { + two_to_one.push(true); + two_to_one.push(true); + two_to_one.push(false); + } + + let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len()); + let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len()); + assert!(a_close.is_some() && b_close.is_some()); + assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0)); +} + +#[test] +fn burst_then_drain_can_recover_without_close() { + let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize; + let mut pattern = Vec::with_capacity(burst_tiny + 600); + for _ in 0..burst_tiny { + pattern.push(true); + } + pattern.extend(std::iter::repeat_n(false, 600)); + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert_eq!(debt, 0); +} + +#[test] +fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() { + let mut seed = 0xA5A5_91C3_2026_0322u64; + for _case in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = 512 + ((seed as usize) & 0x3ff); + let mut pattern = Vec::with_capacity(len); + let mut local_seed = seed; + for _ in 0..len { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + pattern.push((local_seed & 1) == 0); + } + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + if closed_at.is_none() { + assert!(debt < TINY_FRAME_DEBT_LIMIT); + } + assert!(debt <= u32::MAX); + } +} + +#[test] +fn stress_many_independent_simulations_keep_isolated_debt_state() { + for idx in 0..2048usize { + let mut pattern = Vec::with_capacity(64); + for j in 0..64usize { + pattern.push(((idx ^ j) & 3) == 0); + } + let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY)); + } +} + +#[tokio::test] +async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(11, session_started_at); + let 
mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Intermediate, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(12, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Secure, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn intermediate_alternating_zero_and_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(13, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(3000); + for idx in 0..160u8 { + plaintext.extend_from_slice(&0u32.to_le_bytes()); + plaintext.extend_from_slice(&4u32.to_le_bytes()); + plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]); + } + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..220 { + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Intermediate, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error while probing alternating close: {other}"), + } + } + + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = 
Instant::now(); + let forensics = make_forensics(14, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(64); + for _ in 0..8 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&[1, 2, 3, 4]); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let first = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match first { + Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]), + Err(e) => panic!("unexpected close after small tiny burst: {e}"), + Ok(None) => panic!("unexpected EOF before real frame"), + } +} + +#[tokio::test] +async fn idle_policy_enabled_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "idle policy enabled must fail closed for pure zero-length flood" + ); +} + +#[tokio::test] +async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(256 * 6); + for idx in 0..=255u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]); + } + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("alternating flood bytes must be writable"); + drop(writer); + + let mut saw_proxy_close = false; + for _ in 0..300 { + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some((_payload, _quickack))) => {} + Err(ProxyError::Proxy(_)) => { + saw_proxy_close = true; + break; + } + 
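// An Io error here would mean the transport failed before the debt guard
+                // tripped; debt-based closure must surface as ProxyError::Proxy.
+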
Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"), + Ok(None) => panic!("unexpected EOF before debt-based closure"), + Err(other) => panic!("unexpected error before close: {other}"), + } + } + + assert!( + saw_proxy_close, + "alternating tiny/real sequence must eventually fail closed" + ); +} + +#[tokio::test] +async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(3, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let payload = [7u8, 8, 9, 10]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero frame must be writable"); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + .expect("valid frame should decode") + .expect("valid frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn abridged_quickack_tiny_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(21, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0x80u8; 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "quickack-marked zero-length flood must fail closed" + ); +} + +#[tokio::test] +async fn abridged_extended_zero_len_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(22, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut flood_plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]); + } + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let 
result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "extended zero-length abridged flood must fail closed" + ); +} + +#[tokio::test] +async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() { + let mut plaintext = Vec::with_capacity(9 * 300); + for idx in 0..300usize { + plaintext.push(0x00); + for _ in 0..8 { + let b = idx as u8; + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]); + } + } + + // Keep the test single-task and deterministic: make duplex capacity larger than the + // generated ciphertext so write_all cannot block waiting for a concurrent reader. + let duplex_capacity = plaintext.len().saturating_add(1024); + let (reader, mut writer) = duplex(duplex_capacity); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(23, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..3000 { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Err(other) => panic!("unexpected error in 1:8 wire test: {other}"), + } + } + + assert!( + !closed, + "wire-level 1:8 tiny-to-real pattern should not trigger debt close" + ); +}
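// Editorial sketch (not part of the patch): the deterministic fuzz tests in this
// diff repeat an inline xorshift-style scramble (`seed ^= seed << 7; seed ^= seed >> 9;
// seed ^= seed << 8;`). A hypothetical standalone helper using those same shift
// constants could look like this; `next_seed` is an invented name, not a telemt identifier.
fn next_seed(mut seed: u64) -> u64 {
    // Each `x ^= x << k` / `x ^= x >> k` step is an invertible linear map over
    // GF(2), so a nonzero seed can never collapse to zero across rounds.
    seed ^= seed << 7;
    seed ^= seed >> 9;
    seed ^= seed << 8;
    seed
}

#[test]
fn next_seed_is_deterministic_and_preserves_nonzero_seeds() {
    let a = next_seed(0xD1CE_BAAD_2026_0322);
    let b = next_seed(0xD1CE_BAAD_2026_0322);
    assert_eq!(a, b, "the scramble must be a pure function of the seed");
    assert_ne!(a, 0, "nonzero seeds must stay nonzero across rounds");
}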
+ +#[tokio::test] +async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() { + let mut seed = 0xD1CE_BAAD_2026_0322u64; + + for case_idx in 0..32u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let events = 300 + ((seed as usize) & 0xff); + let mut pattern = Vec::with_capacity(events); + let mut local = seed; + for _ in 0..events { + local ^= local << 7; + local ^= local >> 9; + local ^= local << 8; + pattern.push((local & 0x03) == 0); + } + + let mut plaintext = Vec::with_capacity(events * 6); + for (idx, tiny) in pattern.iter().copied().enumerate() { + if tiny { + plaintext.push(0x00); + } else { + let b = (idx as u8) ^ (case_idx as u8); + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]); + } + } + + let (reader, mut writer) = duplex(16 * 1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(500 + case_idx, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + let mut observed_close = false; + + for _ in 0..(events + 8) { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + observed_close = true; + break; + } + Err(other) => panic!("unexpected fuzz error: {other}"), + } + } + + assert_eq!( + observed_close, + expected_close.is_some(), + "wire parser behavior must match debt model for case {case_idx}" + ); + } +} diff --git a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs new file mode 100644 index 0000000..765c253 --- /dev/null +++ b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs @@ -0,0 +1,121 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use std::time::Instant; + +fn make_crypto_reader<T>(reader: T) -> CryptoReader<T> +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB000_0000 + conn_id, + conn_id, + user: format!("zero-len-test-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +#[tokio::test] +async fn adversarial_legacy_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + + let flood_plaintext = vec![0u8; 128]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + match result { + Err(ProxyError::Proxy(msg)) => { + assert!( + msg.contains("Excessive zero-length"), + "legacy mode must close flood with explicit zero-length reason, got: {msg}" + ); + } + Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"), + Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"), + Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"), + } +}
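// Editorial sketch (not part of the patch): a minimal, hypothetical model of the
// fail-closed accounting the test above exercises. Zero-length frames accumulate
// "debt" and the session closes once a threshold is crossed. The real
// `read_client_payload_legacy` logic is richer; the type name and the limit below
// are invented for illustration, not the production values.
struct ZeroLenGuard {
    consecutive_zero: u32,
    limit: u32, // assumed threshold for the sketch
}

impl ZeroLenGuard {
    fn new(limit: u32) -> Self {
        Self { consecutive_zero: 0, limit }
    }

    /// Returns Err(..) once the zero-length flood budget is exhausted.
    fn on_frame(&mut self, payload_len: usize) -> Result<(), &'static str> {
        if payload_len == 0 {
            self.consecutive_zero += 1;
            if self.consecutive_zero > self.limit {
                return Err("Excessive zero-length frames");
            }
        } else {
            // A real frame repays the debt in this simplified model.
            self.consecutive_zero = 0;
        }
        Ok(())
    }
}

#[test]
fn zero_length_flood_fails_closed_in_bounded_frames() {
    let mut guard = ZeroLenGuard::new(16);
    // A pure zero-length flood must terminate after a bounded number of frames.
    let close = (0..100).find_map(|_| guard.on_frame(0).err());
    assert_eq!(close, Some("Excessive zero-length frames"));
}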
+ +#[tokio::test] +async fn business_abridged_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + + let payload = [1u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero abridged frame must be writable"); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("valid abridged frame should decode") + .expect("valid abridged frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1, "quickack flag must remain false"); + assert_eq!(frame_counter, 1); +} diff --git a/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs b/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs new file mode 100644 index 0000000..fb0cf93 --- /dev/null +++ b/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs @@ -0,0 +1,108 @@ +use super::*; +use std::sync::Arc; +use std::sync::{Mutex, OnceLock}; + +fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> { + static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + TEST_LOCK + .get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn same_user_returns_same_lock_identity() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let a = cross_mode_quota_user_lock("cross-mode-same-user"); + let b = cross_mode_quota_user_lock("cross-mode-same-user"); + + assert!( + Arc::ptr_eq(&a, &b), + "same user must reuse a stable lock identity" + ); +} + +#[test] +fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let prefix = format!("cross-mode-saturated-{}", std::process::id()); + let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX); + for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX { + retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "lock cache must be saturated for overflow check" + ); + + let overflow_user = format!("cross-mode-overflow-{}", std::process::id()); + let overflow_a = cross_mode_quota_user_lock(&overflow_user); + let overflow_b = cross_mode_quota_user_lock(&overflow_user); + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "overflow path must not grow bounded lock cache" + ); + assert!( + locks.get(&overflow_user).is_none(), + "overflow user must stay on striped fallback while cache is saturated" + ); + assert!( + Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user must receive a stable striped lock across repeated lookups" + ); + + drop(retained); +} + +#[test] +fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let prefix = format!("cross-mode-reclaim-{}", std::process::id()); + let protected_user = format!("{prefix}-protected"); + + let protected_lock = cross_mode_quota_user_lock(&protected_user); + let mut retained 
= Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)); + for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) { + retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "fixture must saturate lock cache before reclaim path is exercised" + ); + + drop(retained); + + let newcomer_user = format!("{prefix}-newcomer"); + let _newcomer = cross_mode_quota_user_lock(&newcomer_user); + + assert!( + locks.get(&protected_user).is_some(), + "active protected user must remain cache-resident after reclaim" + ); + let locked = locks + .get(&protected_user) + .expect("protected user must remain in map after reclaim"); + assert!( + Arc::ptr_eq(locked.value(), &protected_lock), + "reclaim must not swap active user lock identity" + ); + assert!( + locks.get(&newcomer_user).is_some(), + "newcomer should become cacheable after stale entries are reclaimed" + ); +} diff --git a/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs new file mode 100644 index 0000000..9ea921c --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs @@ -0,0 +1,267 @@ +use super::relay_bidirectional; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() { + let _guard = quota_test_guard(); + + let user = format!("relay-pipeline-stall-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[0xA1]) + .await + .expect("server write should enqueue while relay is stalled"); + + let mut one = [0u8; 1]; + let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await; + assert!( + blocked_read.is_err(), + "same-user relay must remain blocked while cross-mode lock is held" + ); + + drop(held_guard); + + timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) + .await + .expect("blocked relay must resume after cross-mode lock release") + .expect("resumed relay must deliver queued byte"); + assert_eq!(one, [0xA1]); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must complete") + .expect("relay task must not panic"); + assert!(relay_result.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() 
{ + let _guard = quota_test_guard(); + + let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id()); + let free_user = format!("relay-pipeline-free-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); + let held_guard = held + .try_lock() + .expect("test must hold blocked user's shared cross-mode lock"); + + let stats_blocked = Arc::new(Stats::new()); + let stats_free = Arc::new(Stats::new()); + + let (mut blocked_client, blocked_relay_client) = duplex(1024); + let (blocked_relay_server, mut blocked_server) = duplex(1024); + let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client); + let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server); + + let (mut free_client, free_relay_client) = duplex(1024); + let (free_relay_server, mut free_server) = duplex(1024); + let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client); + let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server); + + let blocked_task = { + let user = blocked_user.clone(); + let stats = Arc::clone(&stats_blocked); + tokio::spawn(async move { + relay_bidirectional( + blocked_client_reader, + blocked_client_writer, + blocked_server_reader, + blocked_server_writer, + 256, + 256, + &user, + stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }) + }; + + let free_task = { + let user = free_user.clone(); + let stats = Arc::clone(&stats_free); + tokio::spawn(async move { + relay_bidirectional( + free_client_reader, + free_client_writer, + free_server_reader, + free_server_writer, + 256, + 256, + &user, + stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }) + }; + + blocked_server + .write_all(&[0xB1]) + .await + .expect("blocked user server write should queue"); + free_server + .write_all(&[0xC1]) + .await + .expect("free user server write should queue"); + + let mut blocked_buf = [0u8; 1]; + let mut free_buf = [0u8; 1]; + + let blocked_stalled = timeout( + Duration::from_millis(40), + blocked_client.read_exact(&mut blocked_buf), + ) + .await; + assert!( + blocked_stalled.is_err(), + "blocked user must remain stalled while its lock is held" + ); + + timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf)) + .await + .expect("free user must make progress while other user is blocked") + .expect("free user read must succeed"); + assert_eq!(free_buf, [0xC1]); + + drop(held_guard); + + timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf)) + .await + .expect("blocked user must resume after release") + .expect("blocked user resumed read must succeed"); + assert_eq!(blocked_buf, [0xB1]); + + drop(blocked_client); + drop(blocked_server); + drop(free_client); + drop(free_server); + + assert!( + timeout(Duration::from_secs(1), blocked_task) + .await + .expect("blocked relay task must complete") + .expect("blocked relay task must not panic") + .is_ok() + ); + assert!( + timeout(Duration::from_secs(1), free_task) + .await + .expect("free relay task must complete") + .expect("free relay task must not panic") + .is_ok() + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() { + let _guard = quota_test_guard(); + + let mut seed = 0x5EED_C0DE_2026_0323u64; + for round in 0..24u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_ms = 2 + (seed % 10); + let 
user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock during fuzz round"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(1024), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[0xD1]) + .await + .expect("server write should queue in fuzz round"); + + let mut one = [0u8; 1]; + let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await; + assert!(stalled.is_err(), "held phase must stall same-user relay"); + + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + drop(held_guard); + + timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) + .await + .expect("released phase must resume same-user relay") + .expect("released phase read must succeed"); + assert_eq!(one, [0xD1]); + + drop(client_peer); + drop(server_peer); + + assert!( + timeout(Duration::from_secs(1), relay_task) + .await + .expect("fuzz relay task must complete") + .expect("fuzz relay task must not panic") + .is_ok() + ); + } +} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs new file mode 100644 index 0000000..c967861 --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs @@ -0,0 +1,213 @@ +use super::relay_bidirectional; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::{Arc, Mutex}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::sync::{Barrier, watch}; +use tokio::time::{Duration, Instant, timeout}; + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn percentile_index(len: usize, percentile: usize) -> usize { + ((len * percentile) / 100).min(len.saturating_sub(1)) +} + +#[tokio::test] +async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() { + let _guard = quota_test_guard(); + + let rounds = 64usize; + let user = format!("relay-pipeline-latency-single-{}", std::process::id()); + let mut samples_ms = Vec::with_capacity(rounds); + + for round in 0..rounds { + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock before round"); + + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(2048), + 
Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[(round as u8) ^ 0xA5]) + .await + .expect("server write should queue before release"); + + let release_at = Instant::now(); + drop(held_guard); + + let mut one = [0u8; 1]; + timeout(Duration::from_millis(450), client_peer.read_exact(&mut one)) + .await + .expect("client must receive queued byte after release") + .expect("queued byte read must succeed"); + samples_ms.push(release_at.elapsed().as_millis() as u64); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must complete") + .expect("relay task must not panic"); + assert!(relay_result.is_ok()); + } + + samples_ms.sort_unstable(); + let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; + let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; + + assert!( + p50_ms <= 45, + "single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}" + ); + assert!( + p95_ms <= 130, + "single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" + ); +}
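// Editorial worked example (not part of the patch) for the `percentile_index`
// helper defined at the top of this file: with 64 sorted samples, p50 maps to
// index (64 * 50) / 100 = 32 and p95 to (64 * 95) / 100 = 60; the
// `.min(len.saturating_sub(1))` clamp only matters at p100 or for tiny or
// empty sample sets.
#[test]
fn percentile_index_worked_examples() {
    assert_eq!(percentile_index(64, 50), 32);
    assert_eq!(percentile_index(64, 95), 60);
    assert_eq!(percentile_index(64, 100), 63); // clamped down from 64
    assert_eq!(percentile_index(1, 95), 0);
    assert_eq!(percentile_index(0, 50), 0); // saturating_sub keeps the empty case safe
}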
+ +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() { + let _guard = quota_test_guard(); + + let waiters = 128usize; + let user = format!("relay-pipeline-latency-fanout-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold shared lock before fanout release benchmark"); + + let ready_barrier = Arc::new(Barrier::new(waiters + 1)); + let release_at = Arc::new(Mutex::new(None::<Instant>)); + let (release_tx, release_rx) = watch::channel(false); + let mut tasks = Vec::with_capacity(waiters); + + for idx in 0..waiters { + let user = user.clone(); + let barrier = Arc::clone(&ready_barrier); + let release_at = Arc::clone(&release_at); + let mut release_rx = release_rx.clone(); + + tasks.push(tokio::spawn(async move { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(512); + let (relay_server, mut server_peer) = duplex(512); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user; + let relay_stats = Arc::clone(&stats); + let relay_task = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(2048), + Arc::new(BufferPool::new()), + ) + .await + }); + + server_peer + .write_all(&[(idx as u8) ^ 0x5A]) + .await + .expect("fanout server write should queue before release"); + + barrier.wait().await; + release_rx + .changed() + .await + .expect("release signal should remain available"); + + let started = { + let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); + guard.expect("release timestamp must be populated before signal") + }; + + let mut one = [0u8; 1]; + timeout(Duration::from_millis(900), client_peer.read_exact(&mut one)) + .await + .expect("fanout waiter must receive queued byte after release") + .expect("fanout waiter read must succeed"); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("fanout relay task must complete") + .expect("fanout relay task must not panic"); + assert!(relay_result.is_ok()); + + started.elapsed().as_millis() as u64 + })); + } + + ready_barrier.wait().await; + { + let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); + *guard = Some(Instant::now()); + } + drop(held_guard); + release_tx + .send(true) + .expect("release broadcast must succeed"); + + let mut samples_ms = Vec::with_capacity(waiters); + timeout(Duration::from_secs(8), async { + for task in tasks { + let elapsed = task.await.expect("fanout waiter must not panic"); + samples_ms.push(elapsed); + } + }) + .await + .expect("fanout benchmark must complete in bounded time"); + + samples_ms.sort_unstable(); + let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; + let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; + let max_ms = *samples_ms.last().unwrap_or(&0); + + assert!( + p50_ms <= 120, + "fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); + assert!( + p95_ms <= 260, + "fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); + assert!( + max_ms <= 700, + "fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs new file mode 100644 index 0000000..adbdb22 --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs @@ -0,0 +1,604 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Poll, Waker}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::sync::Barrier; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc<Self>) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc<Self>) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn build_context() -> (Arc<WakeCounter>, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +#[tokio::test] +async fn positive_cross_mode_uncontended_writer_progresses() { + let _guard = quota_test_guard(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + "cross-mode-tdd-uncontended".to_string(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let result = io.write_all(&[0x11, 0x22]).await; + assert!(result.is_ok(), "uncontended writer must progress"); +} + +#[tokio::test] +async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-held-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before polling writer"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = 
Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); + assert!(poll.is_pending(), "writer must not bypass held cross-mode lock"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_parallel_waiters_resume_after_cross_mode_release() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-resume-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before launching waiters"); + + let stats = Arc::new(Stats::new()); + let mut waiters = Vec::new(); + for _ in 0..16 { + let stats = Arc::clone(&stats); + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + stats, + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x7F]).await + })); + } + + tokio::time::sleep(Duration::from_millis(5)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let result = waiter.await.expect("waiter task must not panic"); + assert!(result.is_ok(), "waiter must complete after cross-mode release"); + } + }) + .await + .expect("all waiters must complete in bounded time"); +} + +#[tokio::test] +async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-wakes-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before polling"); + + let stats = Arc::new(Stats::new()); + let mut ios = Vec::new(); + let mut counters = Vec::new(); + for _ in 0..20 { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let poll = Pin::new(io).poll_write(&mut cx, &[0x33]); + assert!(poll.is_pending()); + counters.push(wake_counter); + } + + tokio::time::sleep(Duration::from_millis(25)).await; + let total_wakes: usize = counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= 20 * 4, + "cross-mode contention should not create wake storms; wakes={total_wakes}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() { + let _guard = quota_test_guard(); + + let mut seed = 0xC0DE_BAAD_2026_0322u64; + for round in 0..16u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let sleep_ms = 2 + (seed as u64 % 8); + let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock in fuzz round"); + + let stats = Arc::new(Stats::new()); + let user_reader = user.clone(); + let reader_task = tokio::spawn(async move { + let mut io = 
StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user_reader, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut one = [0u8; 1]; + io.read(&mut one).await + }); + + let user_writer = user.clone(); + let writer_task = tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user_writer, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x44]).await + }); + + tokio::time::sleep(Duration::from_millis(sleep_ms)).await; + drop(held_guard); + + let read_done = timeout(Duration::from_millis(350), reader_task) + .await + .expect("reader task must complete after release") + .expect("reader task must not panic"); + assert!(read_done.is_ok()); + + let write_done = timeout(Duration::from_millis(350), writer_task) + .await + .expect("writer task must complete after release") + .expect("writer task must not panic"); + assert!(write_done.is_ok()); + } +} + +#[tokio::test] +async fn integration_middle_lock_blocks_relay_reader_for_same_user() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-middle-reader-block-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold middle-relay shared lock"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let mut one = [0u8; 1]; + let mut buf = ReadBuf::new(&mut one); + let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn integration_middle_lock_release_unblocks_relay_reader() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-middle-reader-release-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold middle-relay shared lock"); + + let task = tokio::spawn({ + let user = user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut one = [0u8; 1]; + io.read(&mut one).await + } + }); + + tokio::time::sleep(Duration::from_millis(5)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(300), task) + .await + .expect("reader task must complete after release") + .expect("reader task must not panic"); + assert!(done.is_ok()); +} + +#[tokio::test] +async fn business_different_user_middle_lock_does_not_block_relay_writer() { + let _guard = quota_test_guard(); + + let held_user = format!("cross-mode-middle-held-{}", std::process::id()); + let active_user = format!("cross-mode-middle-active-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user); + let _held_guard = held + .try_lock() + .expect("test must hold middle-relay lock for other user"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + active_user, + Some(1024), + 
Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]); + assert!(matches!(poll, Poll::Ready(Ok(1)))); +} + +#[tokio::test] +async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-none-limit-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold lock while quota is disabled"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + None, + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]); + assert!(matches!(poll, Poll::Ready(Ok(2)))); +} + +#[tokio::test] +async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-pre-exceeded-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold shared lock before poll"); + + let quota_exceeded = Arc::new(AtomicBool::new(true)); + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::clone(&quota_exceeded), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]); + assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e))); +} + +#[tokio::test] +async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-repoll-held-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let _held_guard = held + .try_lock() + .expect("test must hold lock for repoll sequence"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + for _ in 0..8 { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]); + assert!(poll.is_pending()); + } + + assert_eq!(stats.get_user_total_octets(&user), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_same_user_mixed_read_write_waiters_resume_after_release() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-mixed-resume-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock before spawning mixed waiters"); + + let mut tasks = Vec::new(); + for i in 0..12usize { + let user = user.clone(); + tasks.push(tokio::spawn(async move { + if i % 2 == 0 { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut b = [0u8; 1]; + io.read(&mut b).await.map(|_| ()) + } else { + let mut io = StatsIo::new( + 
tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x66]).await + } + })); + } + + tokio::time::sleep(Duration::from_millis(8)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for task in tasks { + let result = task.await.expect("mixed waiter task must not panic"); + assert!(result.is_ok()); + } + }) + .await + .expect("all mixed waiters must finish after release"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() { + let _guard = quota_test_guard(); + + let blocked_user = format!("cross-mode-blocked-{}", std::process::id()); + let free_user = format!("cross-mode-free-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); + let held_guard = held + .try_lock() + .expect("test must hold blocked user lock"); + + let blocked_task = tokio::spawn({ + let blocked_user = blocked_user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + blocked_user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x77]).await + } + }); + + let free_task = tokio::spawn({ + let free_user = free_user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + free_user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x78]).await + } + }); + + let free_done = timeout(Duration::from_millis(250), free_task) + .await + .expect("free user must not be blocked") + .expect("free user task must not panic"); + assert!(free_done.is_ok()); + + drop(held_guard); + let blocked_done = timeout(Duration::from_secs(1), blocked_task) + .await + .expect("blocked user must resume after release") + .expect("blocked user task must not panic"); + assert!(blocked_done.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-fanout-{}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock before fanout"); + + let waiters = 48usize; + let gate = Arc::new(Barrier::new(waiters + 1)); + let mut tasks = Vec::new(); + for _ in 0..waiters { + let user = user.clone(); + let gate = Arc::clone(&gate); + tasks.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x79]).await + })); + } + + gate.wait().await; + tokio::time::sleep(Duration::from_millis(10)).await; + drop(held_guard); + + timeout(Duration::from_secs(2), async { + for task in tasks { + let result = task.await.expect("fanout task must not panic"); + assert!(result.is_ok()); + } + }) + .await + .expect("fanout waiters must complete after release"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() { + let 
_guard = quota_test_guard(); + + let mut seed = 0xA11C_EE55_2026_0323u64; + for round in 0..20u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let hold_ms = 2 + (seed % 10); + let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); + let held_guard = held + .try_lock() + .expect("test must hold lock in fuzz round"); + + let writer = tokio::spawn({ + let user = user.clone(); + async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x7A]).await + } + }); + + tokio::time::sleep(Duration::from_millis(hold_ms)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(400), writer) + .await + .expect("writer must complete after lock release") + .expect("writer task must not panic"); + assert!(done.is_ok()); + } +} diff --git a/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs b/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs new file mode 100644 index 0000000..5ea806a --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs @@ -0,0 +1,81 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::Waker; +use std::task::{Context, Poll}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc<Self>) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc<Self>) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn build_context() -> (Arc<WakeCounter>, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +#[tokio::test] +async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() { + let _guard = quota_user_lock_test_scope(); + + let user = "cross-mode-lock-shared-user"; + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user); + let _held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock before relay poll"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(crate::stats::Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]); + + assert!( + matches!(poll, Poll::Pending), + "relay writer must not bypass cross-mode lock held by middle-relay path" + ); +} + +#[tokio::test] +async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() { + let _guard = quota_user_lock_test_scope(); + + let user = "cross-mode-lock-progress-user"; + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(crate::stats::Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]); + + assert!( + matches!(poll, Poll::Ready(Ok(2))), + "relay writer should progress when shared cross-mode lock is uncontended" + ); +}
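// Editorial sketch (not part of the patch): a condensed, self-contained version
// of the manual-poll pattern the file above uses. A counting Waker built from
// `std::task::Wake` lets tests assert on wake counts without a runtime driving
// the future. Only std APIs are used; the demo future below is invented for
// illustration and is not a telemt type.
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll, Wake, Waker};

struct CountingWake(AtomicUsize);

impl Wake for CountingWake {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::Relaxed);
    }
    fn wake_by_ref(self: &Arc<Self>) {
        self.0.fetch_add(1, Ordering::Relaxed);
    }
}

#[test]
fn counting_waker_observes_self_wakeups() {
    let counter = Arc::new(CountingWake(AtomicUsize::new(0)));
    let waker = Waker::from(Arc::clone(&counter));
    let mut cx = Context::from_waker(&waker);

    // Demo future: wakes itself once, then completes on the second poll.
    let mut polled = false;
    let mut fut = std::future::poll_fn(move |cx| {
        if polled {
            Poll::Ready(())
        } else {
            polled = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    });

    assert!(Pin::new(&mut fut).poll(&mut cx).is_pending());
    assert_eq!(counter.0.load(Ordering::Relaxed), 1);
    assert!(Pin::new(&mut fut).poll(&mut cx).is_ready());
}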
diff --git a/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs b/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs new file mode 100644 index 0000000..9ac4621 --- /dev/null +++ b/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs @@ -0,0 +1,340 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::AsyncWriteExt; +use tokio::time::{Duration, Instant, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc<Self>) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc<Self>) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() { + let _guard = quota_test_guard(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + format!("dual-lock-alt-positive-{}", std::process::id()), + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let write = io.write_all(&[0xAA, 0xBB]).await; + assert!(write.is_ok(), "uncontended write must complete"); + assert_eq!( + io.quota_write_retry_attempt, 0, + "uncontended write must not advance retry backoff" + ); +} + +#[tokio::test] +async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-adversarial-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("test must hold local quota lock initially"), + ); + let mut cross_guard = None; + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(first.is_pending(), "held local lock must block first poll"); + + let mut observed_wakes = 0usize; + for idx in 0..18usize { + tokio::time::sleep(Duration::from_millis(6)).await; + + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = Some( + cross_mode_lock + .try_lock() + .expect("cross-mode lock should be acquirable while local lock released"), + ); + } else { + drop(cross_guard.take()); + local_guard = Some( + local_lock + .try_lock() + .expect("local lock should be acquirable while cross lock released"), + ); + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed_wakes { + observed_wakes = wakes; + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); + assert!( + pending.is_pending(), + "alternating contention must keep write pending while one lock is held" + ); + } + } + + assert!( + io.quota_write_retry_attempt >= 2, + "alternating contention must still ramp retry backoff; got {}", + io.quota_write_retry_attempt + ); + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 32, + 
"alternating contention must stay wake-rate-limited" + ); + + drop(local_guard); + drop(cross_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]); + assert!(ready.is_ready(), "writer must resume after both locks released"); +} + +#[tokio::test] +async fn edge_retry_scheduler_resets_after_alternating_contention_clears() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-edge-reset-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let local_guard = local_lock + .try_lock() + .expect("test must hold local lock for edge scenario"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]); + assert!(first.is_pending()); + tokio::time::sleep(Duration::from_millis(15)).await; + if wake_counter.wakes.load(Ordering::Relaxed) > 0 { + let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); + assert!(next.is_pending()); + } + + drop(local_guard); + + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]); + assert!(ready.is_ready()); + assert_eq!( + io.quota_write_retry_attempt, 0, + "successful dual-lock acquisition must reset retry scheduler" + ); + assert!(!io.quota_write_wake_scheduled); + assert!(io.quota_write_retry_sleep.is_none()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-integration-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut waiters = Vec::new(); + for _ in 0..16usize { + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + timeout(Duration::from_secs(2), io.write_all(&[0x31])).await + })); + } + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("integration toggle must acquire local lock first"), + ); + let mut cross_guard = None; + + for idx in 0..24usize { + tokio::time::sleep(Duration::from_millis(4)).await; + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = cross_mode_lock.try_lock().ok(); + } else { + drop(cross_guard.take()); + local_guard = local_lock.try_lock().ok(); + } + } + + drop(local_guard); + drop(cross_guard); + + for waiter in waiters { + let done = waiter.await.expect("waiter task must not panic"); + assert!( + done.is_ok(), + "waiter must finish once alternating contention window ends" + ); + assert!(done.expect("waiter timeout must not fire").is_ok()); + } +} + +#[tokio::test] +async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-fuzz-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let mut seed = 0xD00D_BAAD_F00D_2026u64; + + for _round in 0..64u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + 
seed ^= seed << 8; + + let hold_mode = (seed % 3) as u8; + let local_guard = if hold_mode == 0 { + Some( + local_lock + .try_lock() + .expect("fuzz local lock should be acquirable"), + ) + } else { + None + }; + let cross_guard = if hold_mode == 1 { + Some( + cross_mode_lock + .try_lock() + .expect("fuzz cross lock should be acquirable"), + ) + } else { + None + }; + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + + let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await; + if hold_mode == 2 { + assert!(write.is_ok(), "unheld fuzz round must make progress"); + assert!(write.expect("unheld round timeout").is_ok()); + } else { + assert!( + write.is_err(), + "held-lock fuzz round must remain pending inside bounded window" + ); + } + + drop(local_guard); + drop(cross_guard); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_fanout_alternating_contention_recovers_without_hanging() { + let _guard = quota_test_guard(); + + let user = format!("dual-lock-alt-stress-{}", std::process::id()); + let local_lock = quota_user_lock(&user); + let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + + let mut waiters = Vec::new(); + for _ in 0..48usize { + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + Instant::now(), + ); + timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await + })); + } + + let mut local_guard = Some( + local_lock + .try_lock() + .expect("stress toggle must acquire local lock first"), + ); + let mut cross_guard = None; + for idx in 0..40usize { + tokio::time::sleep(Duration::from_millis(3)).await; + if idx % 2 == 0 { + drop(local_guard.take()); + cross_guard = cross_mode_lock.try_lock().ok(); + } else { + drop(cross_guard.take()); + local_guard = local_lock.try_lock().ok(); + } + } + + drop(local_guard); + drop(cross_guard); + + for waiter in waiters { + let done = waiter.await.expect("stress waiter task must not panic"); + assert!(done.is_ok(), "stress waiter timed out under alternating contention"); + assert!(done.expect("stress waiter timeout should not fire").is_ok()); + } +}
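// Editorial sketch (not part of the patch): the backoff tests in this diff
// assert that `quota_write_retry_attempt` ramps and then resets while wakes stay
// bounded. A minimal, hypothetical capped-exponential schedule of the kind being
// exercised might look like this; the constants and the helper name are
// assumptions for illustration, not telemt's production values.
use std::time::Duration;

fn retry_delay(attempt: u32) -> Duration {
    const BASE_MS: u64 = 1;
    const CAP_MS: u64 = 50;
    // 1ms, 2ms, 4ms, ... capped so a long-held lock cannot produce wake storms
    // yet a released lock is noticed within a bounded window.
    let exp = BASE_MS.saturating_mul(1u64 << attempt.min(16));
    Duration::from_millis(exp.min(CAP_MS))
}

#[test]
fn retry_delay_ramps_then_saturates() {
    assert_eq!(retry_delay(0), Duration::from_millis(1));
    assert_eq!(retry_delay(1), Duration::from_millis(2));
    assert_eq!(retry_delay(3), Duration::from_millis(8));
    assert_eq!(retry_delay(10), Duration::from_millis(50)); // capped
}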
+    let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+    let held_cross_mode_guard = cross_mode_lock
+        .try_lock()
+        .expect("test must hold cross-mode lock before polling");
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user,
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        Instant::now(),
+    );
+
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    let mut cx = Context::from_waker(&waker);
+
+    let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
+    assert!(first.is_pending(), "held cross-mode lock must block writer");
+
+    let started = Instant::now();
+    let mut last_wakes = 0usize;
+    while started.elapsed() < Duration::from_millis(120) {
+        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
+        if wakes > last_wakes {
+            last_wakes = wakes;
+            let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
+            assert!(next.is_pending(), "writer must remain blocked while lock is held");
+        }
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    assert!(
+        io.quota_write_retry_attempt >= 2,
+        "retry attempt must ramp under sustained second-lock contention; got {}",
+        io.quota_write_retry_attempt
+    );
+
+    drop(held_cross_mode_guard);
+}
diff --git a/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs b/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs
new file mode 100644
index 0000000..513d92b
--- /dev/null
+++ b/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs
@@ -0,0 +1,325 @@
+use super::*;
+use crate::stats::Stats;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::task::{Context, Waker};
+use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
+use tokio::time::{Duration, Instant, timeout};
+
+#[derive(Default)]
+struct WakeCounter {
+    wakes: AtomicUsize,
+}
+
+impl std::task::Wake for WakeCounter {
+    fn wake(self: Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+
+    fn wake_by_ref(self: &Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+}
+
+fn quota_test_guard() -> impl Drop {
+    super::quota_user_lock_test_scope()
+}
+
+fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
+    (wake_counter, Context::from_waker(leaked_waker))
+}
+
+#[tokio::test]
+async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() {
+    let _guard = quota_test_guard();
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        format!("dual-lock-positive-{}", std::process::id()),
+        Some(4096),
+        Arc::new(AtomicBool::new(false)),
+        Instant::now(),
+    );
+
+    let (_wake_counter, mut cx) = build_context();
+    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]);
+    assert!(poll.is_ready());
+    assert_eq!(io.quota_write_retry_attempt, 0);
+    assert!(!io.quota_write_wake_scheduled);
+    assert!(io.quota_write_retry_sleep.is_none());
+}
+
+#[tokio::test]
+async fn negative_local_lock_contention_read_retry_attempt_ramps() {
+    let _guard = quota_test_guard();
+
+    let user = format!("dual-lock-local-contention-{}", std::process::id());
+    let held = quota_user_lock(&user);
+    let held_guard = held
+        .try_lock()
+        .expect("test must hold local quota lock before polling");
+
+    let mut io = StatsIo::new(
+        tokio::io::empty(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user,
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        Instant::now(),
+    );
+
+    let (wake_counter, mut cx) = build_context();
+    let mut one = [0u8; 1];
+    let mut buf = ReadBuf::new(&mut one);
+    let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
+    assert!(first.is_pending());
+
+    let started = Instant::now();
+    let mut observed = 0usize;
+    while started.elapsed() < Duration::from_millis(120) {
+        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
+        if wakes > observed {
+            observed = wakes;
+            let mut step_buf = ReadBuf::new(&mut one);
+            let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf);
+            assert!(next.is_pending());
+        }
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    assert!(
+        io.quota_read_retry_attempt >= 2,
+        "retry attempt must ramp under sustained local-lock contention; got {}",
+        io.quota_read_retry_attempt
+    );
+
+    drop(held_guard);
+}
+
+#[tokio::test]
+async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() {
+    let _guard = quota_test_guard();
+
+    let user = format!("dual-lock-reset-{}", std::process::id());
+    let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+    let held_guard = cross_mode
+        .try_lock()
+        .expect("test must hold cross-mode lock before polling");
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user,
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        Instant::now(),
+    );
+
+    let (wake_counter, mut cx) = build_context();
+    let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]);
+    assert!(first.is_pending());
+
+    tokio::time::sleep(Duration::from_millis(20)).await;
+    if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
+        let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
+        assert!(next.is_pending());
+    }
+
+    drop(held_guard);
+    let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
+    assert!(ready.is_ready());
+    assert_eq!(io.quota_write_retry_attempt, 0);
+    assert!(!io.quota_write_wake_scheduled);
+    assert!(io.quota_write_retry_sleep.is_none());
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() {
+    let _guard = quota_test_guard();
+
+    let user = format!("dual-lock-adversarial-{}", std::process::id());
+    let stats = Arc::new(Stats::new());
+    let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+    let held_guard = held
+        .try_lock()
+        .expect("test must hold cross-mode lock before launching waiters");
+
+    let mut tasks = Vec::new();
+    for _ in 0..24usize {
+        let stats = Arc::clone(&stats);
+        let user = user.clone();
+        tasks.push(tokio::spawn(async move {
+            let mut io = StatsIo::new(
+                tokio::io::sink(),
+                Arc::new(SharedCounters::new()),
+                stats,
+                user,
+                Some(1024),
+                Arc::new(AtomicBool::new(false)),
+                Instant::now(),
+            );
+            timeout(Duration::from_millis(40), io.write_all(&[0x33])).await
+        }));
+    }
+
+    for task in tasks {
+        let timed = task.await.expect("waiter task must not panic");
+        assert!(timed.is_err(), "held cross-mode lock must keep waiter pending");
+    }
+
+    assert_eq!(stats.get_user_total_octets(&user), 0);
+    drop(held_guard);
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn integration_waiters_resume_after_cross_mode_release() {
+    let _guard = quota_test_guard();
+
+    let user = format!("dual-lock-integration-{}", std::process::id());
+    let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+    let held_guard = held
+        .try_lock()
+        .expect("test must hold cross-mode lock before starting waiter");
+
+    let task = tokio::spawn({
+        let user = user.clone();
+        async move {
+            let mut io = StatsIo::new(
+                tokio::io::sink(),
+                Arc::new(SharedCounters::new()),
+                Arc::new(Stats::new()),
+                user,
+                Some(1024),
+                Arc::new(AtomicBool::new(false)),
+                Instant::now(),
+            );
+            io.write_all(&[0x44]).await
+        }
+    });
+
+    tokio::time::sleep(Duration::from_millis(10)).await;
+    drop(held_guard);
+
+    let done = timeout(Duration::from_secs(1), task)
+        .await
+        .expect("waiter task must complete after release")
+        .expect("waiter task must not panic");
+    assert!(done.is_ok());
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() {
+    let _guard = quota_test_guard();
+
+    let user = format!("dual-lock-fuzz-{}", std::process::id());
+    let stats = Arc::new(Stats::new());
+    let mut seed = 0xA55A_55AA_C3D2_E1F0u64;
+
+    for _round in 0..48u32 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let hold_mode = (seed % 3) as u8;
+        let mut local_lock = None;
+        let mut cross_lock = None;
+        let mut local_guard = None;
+        let mut cross_guard = None;
+
+        if hold_mode == 0 {
+            local_lock = Some(quota_user_lock(&user));
+            local_guard = Some(
+                local_lock
+                    .as_ref()
+                    .expect("local lock should be present")
+                    .try_lock()
+                    .expect("local lock should be acquirable in fuzz round"),
+            );
+        } else if hold_mode == 1 {
+            cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
+                &user,
+            ));
+            cross_guard = Some(
+                cross_lock
+                    .as_ref()
+                    .expect("cross lock should be present")
+                    .try_lock()
+                    .expect("cross lock should be acquirable in fuzz round"),
+            );
+        }
+
+        let mut io = StatsIo::new(
+            tokio::io::sink(),
+            Arc::new(SharedCounters::new()),
+            Arc::clone(&stats),
+            user.clone(),
+            Some(4096),
+            Arc::new(AtomicBool::new(false)),
+            Instant::now(),
+        );
+
+        let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await;
+        if hold_mode == 2 {
+            assert!(write.is_ok(), "unheld round must make progress");
+        } else {
+            assert!(write.is_err(), "held-lock round must stay blocked within timeout");
+        }
+
+        drop(local_guard);
+        drop(cross_guard);
+        drop(local_lock);
+        drop(cross_lock);
+    }
+
+    assert!(stats.get_user_total_octets(&user) <= 4096);
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_fanout_waiters_complete_after_release_without_panics() {
+    let _guard = quota_test_guard();
+
+    let user = format!("dual-lock-stress-{}", std::process::id());
+    let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+    let held_guard = held
+        .try_lock()
+        .expect("test must hold cross-mode lock before stress fanout");
+
+    let waiters = 64usize;
+    let mut tasks = Vec::new();
+    for _ in 0..waiters {
+        let user = user.clone();
+        tasks.push(tokio::spawn(async move {
+            let mut io = StatsIo::new(
+                tokio::io::empty(),
+                Arc::new(SharedCounters::new()),
+                Arc::new(Stats::new()),
+                user,
+                Some(1024),
+                Arc::new(AtomicBool::new(false)),
+                Instant::now(),
+            );
+            let mut one = [0u8; 1];
+            io.read(&mut one).await
+        }));
+    }
+
+    tokio::time::sleep(Duration::from_millis(12)).await;
+    drop(held_guard);
+
+    timeout(Duration::from_secs(2), async {
+        for task in tasks {
+            let result = task.await.expect("stress waiter task must not panic");
+            assert!(result.is_ok());
+        }
+    })
+    .await
+    .expect("all stress waiters must complete after release");
+}
diff --git a/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs b/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs
new file mode 100644
index 0000000..ec180e8
--- /dev/null
+++ b/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs
@@ -0,0 +1,128 @@
+use super::*;
+use crate::stats::Stats;
+use std::sync::Arc;
+use std::sync::atomic::AtomicBool;
+use tokio::io::AsyncWriteExt;
+use tokio::time::{Duration, timeout};
+
+fn quota_test_guard() -> impl Drop {
+    super::quota_user_lock_test_scope()
+}
+
+fn make_stats_io(user: String) -> StatsIo<tokio::io::Sink> {
+    StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user,
+        Some(4096),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    )
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() {
+    let _guard = quota_test_guard();
+
+    let user = format!("dual-lock-race-fuzz-{}", std::process::id());
+    let mut seed = 0xD1CE_BAAD_5EED_1234u64;
+
+    for round in 0..1024u32 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let hold = (seed & 1) == 0;
+        let hold_ms = (seed % 3) as u64;
+
+        let maybe_lock = if hold {
+            Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
+                &user,
+            ))
+        } else {
+            None
+        };
+
+        let maybe_guard = maybe_lock.as_ref().map(|lock| {
+            lock.try_lock()
+                .expect("cross-mode lock must be acquirable in fuzz round")
+        });
+
+        if hold {
+            let mut blocked_io = make_stats_io(user.clone());
+            let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await;
+            assert!(
+                blocked.is_err(),
+                "held round must block waiter before lock release (round={round})"
+            );
+
+            if hold_ms > 0 {
+                tokio::time::sleep(Duration::from_millis(hold_ms)).await;
+            }
+        } else {
+            let mut free_io = make_stats_io(user.clone());
+            let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await;
+            assert!(
+                free.is_ok(),
+                "unheld round must complete promptly (round={round})"
+            );
+            assert!(free.expect("unheld round should complete").is_ok());
+        }
+
+        drop(maybe_guard);
+
+        let done = timeout(Duration::from_millis(350), async {
+            let user = user.clone();
+            let mut io = make_stats_io(user);
+            io.write_all(&[0xA6]).await
+        })
+        .await
+        .expect("post-release write must complete in bounded time");
+        assert!(done.is_ok());
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() {
+    let _guard = quota_test_guard();
+
+    let user = format!("dual-lock-race-stress-{}", std::process::id());
+    let mut seed = 0xC0FF_EE77_4444_9999u64;
+
+    for round in 0..256u32 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let hold_ms = (seed % 4) as u64;
+        let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
+        let guard = lock
+            .try_lock()
+            .expect("cross-mode lock must be acquirable at round start");
+
+        let mut waiters = Vec::new();
+        for _ in 0..3usize {
+            let user = user.clone();
+            waiters.push(tokio::spawn(async move {
+                let mut io = make_stats_io(user);
+                io.write_all(&[0x55]).await
+            }));
+        }
+
+        tokio::time::sleep(Duration::from_millis(hold_ms)).await;
+        drop(guard);
+
+        timeout(Duration::from_secs(1), async {
+            for waiter in waiters {
+                let done = waiter.await.expect("waiter task must not panic");
panic"); + assert!( + done.is_ok(), + "waiter must complete after release (round={round})" + ); + } + }) + .await + .expect("all waiters must complete in bounded time after release"); + } +} diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..5ee6522 --- /dev/null +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -0,0 +1,332 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{Duration, timeout}; + +async fn read_available(reader: &mut R, budget: Duration) -> usize { + let start = tokio::time::Instant::now(); + let mut total = 0usize; + let mut buf = [0u8; 128]; + + loop { + let elapsed = start.elapsed(); + if elapsed >= budget { + break; + } + let remaining = budget.saturating_sub(elapsed); + match timeout(remaining, reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + + total +} + +#[tokio::test] +async fn positive_quota_path_forwards_both_directions_within_limit() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-positive-user"; + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(16), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap(); + server_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap(); + client_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(relay_result.is_ok()); + assert!(stats.get_user_total_octets(user) <= 16); +} + +#[tokio::test] +async fn negative_preloaded_quota_forbids_any_forwarding() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-negative-user"; + stats.add_user_octets_from(user, 8); + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(8), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0xAA]).await.unwrap(); + server_peer.write_all(&[0xBB]).await.unwrap(); + + assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0); + assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0); + + let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
+    assert!(stats.get_user_total_octets(user) <= 8);
+}
+
+#[tokio::test]
+async fn edge_quota_one_ensures_at_most_one_byte_across_directions() {
+    let stats = Arc::new(Stats::new());
+    let user = "quota-extended-edge-user";
+
+    let (mut client_peer, relay_client) = duplex(1024);
+    let (relay_server, mut server_peer) = duplex(1024);
+    let (client_reader, client_writer) = tokio::io::split(relay_client);
+    let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+    let relay = tokio::spawn(relay_bidirectional(
+        client_reader,
+        client_writer,
+        server_reader,
+        server_writer,
+        128,
+        128,
+        user,
+        Arc::clone(&stats),
+        Some(1),
+        Arc::new(BufferPool::new()),
+    ));
+
+    let _ = tokio::join!(
+        client_peer.write_all(&[0xFE]),
+        server_peer.write_all(&[0xEF]),
+    );
+
+    let mut buf = [0u8; 1];
+    let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await.unwrap().unwrap_or(0);
+    let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await.unwrap().unwrap_or(0);
+
+    assert!(delivered_s2c + delivered_c2s <= 1);
+
+    let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
+    assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
+}
+
+#[tokio::test]
+async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() {
+    let stats = Arc::new(Stats::new());
+    let user = "quota-extended-blackhat-user";
+    let quota = 24u64;
+
+    let (mut client_peer, relay_client) = duplex(4096);
+    let (relay_server, mut server_peer) = duplex(4096);
+    let (client_reader, client_writer) = tokio::io::split(relay_client);
+    let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+    let relay = tokio::spawn(relay_bidirectional(
+        client_reader,
+        client_writer,
+        server_reader,
+        server_writer,
+        256,
+        256,
+        user,
+        Arc::clone(&stats),
+        Some(quota),
+        Arc::new(BufferPool::new()),
+    ));
+
+    let mut total_forwarded = 0usize;
+
+    for i in 0..256usize {
+        if relay.is_finished() {
+            break;
+        }
+        if (i & 1) == 0 {
+            let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await;
+            let mut one = [0u8; 1];
+            if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
+                total_forwarded += n;
+            }
+        } else {
+            let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await;
+            let mut one = [0u8; 1];
+            if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
+                total_forwarded += n;
+            }
+        }
+
+        tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await;
+    }
+
+    let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap();
+    assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
+    assert!(total_forwarded <= quota as usize);
+    assert!(stats.get_user_total_octets(user) <= quota);
+}
+
+#[tokio::test]
+async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() {
+    let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE);
+
+    for case in 0..32u64 {
+        let stats = Arc::new(Stats::new());
+        let user = format!("quota-extended-fuzz-{case}");
+        let quota = rng.random_range(1u64..=35u64);
+
+        let (mut client_peer, relay_client) = duplex(4096);
+        let (relay_server, mut server_peer) = duplex(4096);
+        let (client_reader, client_writer) = tokio::io::split(relay_client);
+        let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+        let relay_user = user.clone();
+        let relay_stats = Arc::clone(&stats);
+        let relay = tokio::spawn(async move {
+            relay_bidirectional(
+                client_reader,
+                client_writer,
+                server_reader,
+                server_writer,
+                256,
+                256,
+                &relay_user,
+                Arc::clone(&relay_stats),
+                Some(quota),
+                Arc::new(BufferPool::new()),
+            )
+            .await
+        });
+
+        let mut total_forwarded = 0usize;
+
+        for _ in 0..96usize {
+            if relay.is_finished() {
+                break;
+            }
+
+            if rng.random::<bool>() {
+                let _ = client_peer.write_all(&[rng.random::<u8>()]).await;
+                let mut one = [0u8; 1];
+                if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await {
+                    total_forwarded += n;
+                }
+            } else {
+                let _ = server_peer.write_all(&[rng.random::<u8>()]).await;
+                let mut one = [0u8; 1];
+                if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await {
+                    total_forwarded += n;
+                }
+            }
+        }
+
+        drop(client_peer);
+        drop(server_peer);
+
+        let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
+        assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
+        assert!(total_forwarded <= quota as usize);
+        assert!(stats.get_user_total_octets(&user) <= quota);
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_parallel_relays_for_one_user_obey_global_quota() {
+    let stats = Arc::new(Stats::new());
+    let user = "quota-extended-stress-user".to_string();
+    let quota = 64u64;
+
+    let mut tasks = Vec::new();
+
+    for worker in 0..4u8 {
+        let stats = Arc::clone(&stats);
+        let user = user.clone();
+
+        tasks.push(tokio::spawn(async move {
+            let (mut client_peer, relay_client) = duplex(2048);
+            let (relay_server, mut server_peer) = duplex(2048);
+            let (client_reader, client_writer) = tokio::io::split(relay_client);
+            let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+            let relay_user = user.clone();
+            let relay_stats = Arc::clone(&stats);
+            let relay = tokio::spawn(async move {
+                relay_bidirectional(
+                    client_reader,
+                    client_writer,
+                    server_reader,
+                    server_writer,
+                    128,
+                    128,
+                    &relay_user,
+                    Arc::clone(&relay_stats),
+                    Some(quota),
+                    Arc::new(BufferPool::new()),
+                )
+                .await
+            });
+
+            let mut total = 0usize;
+            for step in 0..64u8 {
+                if relay.is_finished() {
+                    break;
+                }
+                if (step as usize + worker as usize) % 2 == 0 {
+                    let _ = client_peer.write_all(&[(step ^ 0x5A)]).await;
+                    let mut one = [0u8; 1];
+                    if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
+                        total += n;
+                    }
+                } else {
+                    let _ = server_peer.write_all(&[(step ^ 0xA5)]).await;
+                    let mut one = [0u8; 1];
+                    if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
+                        total += n;
+                    }
+                }
+                tokio::time::sleep(Duration::from_millis(1)).await;
+            }
+
+            drop(client_peer);
+            drop(server_peer);
+
+            let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
+            assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
+            total
+        }));
+    }
+
+    let mut delivered = 0usize;
+    for task in tasks {
+        delivered += task.await.unwrap();
+    }
+
+    assert!(stats.get_user_total_octets(&user) <= quota);
+    assert!(delivered <= quota as usize);
+}
diff --git a/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs
new file mode 100644
index 0000000..806efb6
--- /dev/null
+++ b/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs
@@ -0,0 +1,79 @@
+use super::*;
+use dashmap::DashMap;
+use std::sync::Arc;
+use tokio::time::{Duration, timeout};
+
+#[test]
+fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() {
+    let _guard = quota_user_lock_test_scope();
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let held_user = format!("quota-evict-held-{}", std::process::id());
+    let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id());
+    let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id());
+
+    let held = quota_user_lock(&held_user);
+    let stale_a = quota_user_lock(&stale_a_user);
+    let stale_b = quota_user_lock(&stale_b_user);
+
+    assert!(map.get(&held_user).is_some());
+    assert!(map.get(&stale_a_user).is_some());
+    assert!(map.get(&stale_b_user).is_some());
+
+    drop(stale_a);
+    drop(stale_b);
+
+    quota_user_lock_evict();
+
+    assert!(
+        map.get(&held_user).is_some(),
+        "held entry must survive eviction"
+    );
+    assert!(
+        map.get(&stale_a_user).is_none(),
+        "unheld stale entry must be reclaimed"
+    );
+    assert!(
+        map.get(&stale_b_user).is_none(),
+        "unheld stale entry must be reclaimed"
+    );
+
+    drop(held);
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() {
+    let _guard = quota_user_lock_test_scope();
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let held_user = format!("quota-evict-loop-held-{}", std::process::id());
+    let stale_user = format!("quota-evict-loop-stale-{}", std::process::id());
+
+    let held = quota_user_lock(&held_user);
+    let stale = quota_user_lock(&stale_user);
+
+    assert_eq!(map.len(), 2);
+    drop(stale);
+
+    let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));
+
+    timeout(Duration::from_millis(200), async {
+        loop {
+            if map.get(&stale_user).is_none() {
+                break;
+            }
+            tokio::time::sleep(Duration::from_millis(5)).await;
+        }
+    })
+    .await
+    .expect("periodic quota lock evictor must reclaim stale entry");
+
+    evictor.abort();
+
+    assert!(map.get(&held_user).is_some());
+    assert!(map.get(&stale_user).is_none());
+
+    drop(held);
+}
diff --git a/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs
new file mode 100644
index 0000000..251582a
--- /dev/null
+++ b/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs
@@ -0,0 +1,153 @@
+use super::*;
+use dashmap::DashMap;
+use std::sync::Arc;
+use tokio::task::JoinSet;
+use tokio::time::{Duration, timeout};
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() {
+    let _guard = quota_user_lock_test_scope();
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));
+
+    let mut tasks = JoinSet::new();
+    for worker in 0..24u32 {
+        tasks.spawn(async move {
+            for round in 0..320u32 {
+                let user = format!(
+                    "quota-evict-stress-user-{}-{}-{}",
+                    std::process::id(),
+                    worker,
+                    round
+                );
+                let lock = quota_user_lock(&user);
+                if round % 19 == 0 {
+                    tokio::task::yield_now().await;
+                }
+                drop(lock);
+            }
+        });
+    }
+
+    while let Some(done) = tasks.join_next().await {
+        done.expect("stress worker must not panic");
+    }
+
+    quota_user_lock_evict();
+    tokio::time::sleep(Duration::from_millis(20)).await;
+
+    assert!(
+        map.len() <= QUOTA_USER_LOCKS_MAX,
+        "quota lock map must remain bounded after churn + eviction"
+    );
+
+    let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id());
+    let sanity_lock = quota_user_lock(&sanity_user);
+    assert!(
+        map.get(&sanity_user).is_some(),
+        "sanity user should be cacheable after eviction reclaimed stale entries"
+    );
+
+    drop(sanity_lock);
+    evictor.abort();
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() {
+    let _guard = quota_user_lock_test_scope();
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let held_user = format!("quota-evict-held-survive-{}", std::process::id());
+    let held = quota_user_lock(&held_user);
+
+    let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3));
+
+    for idx in 0..512u32 {
+        let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx);
+        let temp = quota_user_lock(&user);
+        drop(temp);
+        if idx % 32 == 0 {
+            tokio::task::yield_now().await;
+        }
+    }
+
+    let reacquired = quota_user_lock(&held_user);
+    assert!(
+        Arc::ptr_eq(&held, &reacquired),
+        "held user lock identity must remain stable across repeated evictions"
+    );
+    assert!(
+        map.get(&held_user).is_some(),
+        "held user entry must not be reclaimed while externally referenced"
+    );
+
+    drop(reacquired);
+    drop(held);
+
+    timeout(Duration::from_millis(300), async {
+        loop {
+            if map.get(&held_user).is_none() {
+                break;
+            }
+            tokio::time::sleep(Duration::from_millis(5)).await;
+        }
+    })
+    .await
+    .expect("released held lock must be reclaimed by periodic evictor");
+
+    evictor.abort();
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() {
+    let _guard = quota_user_lock_test_scope();
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
+    let prefix = format!("quota-evict-saturated-{}", std::process::id());
+    for idx in 0..QUOTA_USER_LOCKS_MAX {
+        retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
+    }
+
+    assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
+
+    let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id());
+    let overflow_before = quota_user_lock(&overflow_user);
+    assert!(
+        map.get(&overflow_user).is_none(),
+        "saturated map must initially route new user to overflow stripe"
+    );
+
+    drop(retained);
+
+    let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4));
+
+    timeout(Duration::from_millis(400), async {
+        loop {
+            if map.len() < QUOTA_USER_LOCKS_MAX {
+                break;
+            }
+            tokio::time::sleep(Duration::from_millis(5)).await;
+        }
+    })
+    .await
+    .expect("periodic evictor must reclaim stale saturated entries");
+
+    let overflow_after = quota_user_lock(&overflow_user);
+    assert!(
+        map.get(&overflow_user).is_some(),
+        "after eviction, overflow user should become cacheable again"
+    );
+    assert!(
+        Arc::strong_count(&overflow_after) >= 2,
+        "cacheable lock should be held by map and caller"
+    );
+
+    drop(overflow_before);
+    drop(overflow_after);
+    evictor.abort();
+}
diff --git a/src/proxy/tests/relay_quota_lock_identity_security_tests.rs b/src/proxy/tests/relay_quota_lock_identity_security_tests.rs
new file mode 100644
index 0000000..f717f54
--- /dev/null
+++ b/src/proxy/tests/relay_quota_lock_identity_security_tests.rs
@@ -0,0 +1,135 @@
+use super::*;
+use crate::stats::Stats;
+use dashmap::DashMap;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::task::{Context, Poll, Waker};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+#[derive(Default)]
+struct WakeCounter {
+    wakes: AtomicUsize,
+}
+
+impl std::task::Wake for WakeCounter {
+    fn wake(self: Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+
+    fn wake_by_ref(self: &Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+}
+
+fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    // Context stores a reference; leak one Waker for deterministic test scope.
+    let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
+    (wake_counter, Context::from_waker(leaked_waker))
+}
+
+#[tokio::test]
+async fn adversarial_map_churn_cannot_bypass_held_writer_lock() {
+    let _guard = quota_user_lock_test_scope();
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let user = "quota-identity-writer-user";
+    let held_lock = quota_user_lock(user);
+    let _held_guard = held_lock
+        .try_lock()
+        .expect("test must hold initial user lock before StatsIo poll");
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user.to_string(),
+        Some(1024),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    map.clear();
+    let churned_lock = quota_user_lock(user);
+    assert!(
+        !Arc::ptr_eq(&held_lock, &churned_lock),
+        "precondition: map churn should produce a distinct lock identity"
+    );
+
+    let (_wake_counter, mut cx) = build_context();
+    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]);
+
+    assert!(
+        matches!(poll, Poll::Pending),
+        "writer must remain pending on the originally-held lock identity"
+    );
+}
+
+#[tokio::test]
+async fn adversarial_map_churn_cannot_bypass_held_reader_lock() {
+    let _guard = quota_user_lock_test_scope();
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let user = "quota-identity-reader-user";
+    let held_lock = quota_user_lock(user);
+    let _held_guard = held_lock
+        .try_lock()
+        .expect("test must hold initial user lock before StatsIo poll");
+
+    let mut io = StatsIo::new(
+        tokio::io::empty(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user.to_string(),
+        Some(1024),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    map.clear();
+    let churned_lock = quota_user_lock(user);
+    assert!(
+        !Arc::ptr_eq(&held_lock, &churned_lock),
+        "precondition: map churn should produce a distinct lock identity"
+    );
+
+    let (_wake_counter, mut cx) = build_context();
+    let mut storage = [0u8; 8];
+    let mut read_buf = ReadBuf::new(&mut storage);
+    let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf);
+
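+    // If this poll returned Ready, the reader would have re-resolved the user
+    // lock by name and picked up the unheld churned entry instead of honouring
+    // the originally captured lock identity.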
+    assert!(
+        matches!(poll, Poll::Pending),
+        "reader must remain pending on the originally-held lock identity"
+    );
+}
+
+#[tokio::test]
+async fn business_no_lock_contention_keeps_writer_progress() {
+    let _guard = quota_user_lock_test_scope();
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let user = "quota-identity-progress-user";
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user.to_string(),
+        Some(1024),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let (_wake_counter, mut cx) = build_context();
+    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]);
+
+    assert!(
+        matches!(poll, Poll::Ready(Ok(2))),
+        "writer should progress immediately without contention"
+    );
+}
diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs
index e29e86e..5687965 100644
--- a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs
+++ b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs
@@ -127,7 +127,7 @@ fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
 }
 
 #[test]
-fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
+fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() {
     let _guard = super::quota_user_lock_test_scope();
     let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
     map.clear();
@@ -142,6 +142,8 @@ fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
 
     drop(retained);
 
+    quota_user_lock_evict();
+
     let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
     let overflow = quota_user_lock(&overflow_user);
 
diff --git a/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs b/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs
new file mode 100644
index 0000000..447a090
--- /dev/null
+++ b/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs
@@ -0,0 +1,249 @@
+use super::*;
+use crate::stats::Stats;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::task::{Context, Waker};
+use tokio::io::AsyncWriteExt;
+use tokio::time::{Duration, Instant, timeout};
+
+#[derive(Default)]
+struct WakeCounter {
+    wakes: AtomicUsize,
+}
+
+impl std::task::Wake for WakeCounter {
+    fn wake(self: Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+
+    fn wake_by_ref(self: &Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+}
+
+fn quota_test_guard() -> impl Drop {
+    super::quota_user_lock_test_scope()
+}
+
+fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
+    (wake_counter, Context::from_waker(leaked_waker))
+}
+
+fn sleep_slot_ptr(slot: &Option<Pin<Box<tokio::time::Sleep>>>) -> usize {
+    slot.as_ref()
+        .map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize)
+        .unwrap_or(0)
+}
+
+#[tokio::test]
+async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() {
+    let _guard = quota_test_guard();
+
+    let user = format!("retry-alloc-single-pending-{}", std::process::id());
+    let lock = quota_user_lock(&user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold local lock to force retry scheduling");
+
+    reset_quota_retry_sleep_allocs_for_tests();
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user,
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        Instant::now(),
+    );
+
+    let (_wake_counter, mut cx) = build_context();
+
+    let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]);
+    assert!(first.is_pending());
+    let allocs_after_first = quota_retry_sleep_allocs_for_tests();
+    let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep);
+
+    let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]);
+    assert!(second.is_pending());
+    let allocs_after_second = quota_retry_sleep_allocs_for_tests();
+    let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep);
+
+    assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer");
+    assert_eq!(
+        allocs_after_second, 1,
+        "repoll while the same timer is pending must not allocate again"
+    );
+    assert_eq!(
+        ptr_after_first, ptr_after_second,
+        "repoll while pending should retain the same timer allocation"
+    );
+
+    drop(held_guard);
+}
+
+#[tokio::test]
+async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() {
+    let _guard = quota_test_guard();
+
+    let user = format!("retry-alloc-per-cycle-{}", std::process::id());
+    let lock = quota_user_lock(&user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold local lock to keep write path pending");
+
+    reset_quota_retry_sleep_allocs_for_tests();
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user,
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        Instant::now(),
+    );
+
+    let (wake_counter, mut cx) = build_context();
+
+    let mut polls = 0u64;
+    let mut observed_wakes = 0usize;
+    let started = Instant::now();
+    while started.elapsed() < Duration::from_millis(70) {
+        let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]);
+        polls = polls.saturating_add(1);
+        assert!(poll.is_pending());
+
+        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
+        if wakes > observed_wakes {
+            observed_wakes = wakes;
+        }
+
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    let allocs = quota_retry_sleep_allocs_for_tests();
+    assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers");
+    assert!(
+        allocs < polls,
+        "timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})"
+    );
+
+    drop(held_guard);
+}
+
+#[tokio::test]
+async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() {
+    let _guard = quota_test_guard();
+
+    let user = format!("retry-latency-envelope-{}", std::process::id());
+    let lock = quota_user_lock(&user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold local lock for sustained contention");
+
+    reset_quota_retry_sleep_allocs_for_tests();
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::new(Stats::new()),
+        user,
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        Instant::now(),
+    );
+
+    let (wake_counter, mut cx) = build_context();
+
+    let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]);
+    assert!(first.is_pending());
+
+    let started = Instant::now();
+    let mut last_wakes = 0usize;
+    let mut wake_instants = Vec::new();
+
+    while started.elapsed() < Duration::from_millis(120) {
+        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
+        if wakes > last_wakes {
+            last_wakes = wakes;
+            wake_instants.push(Instant::now());
+            let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]);
+            assert!(pending.is_pending());
+        }
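+        // Sleep ~1ms per iteration to hand control back to the runtime so the
+        // scheduled retry timer can actually fire between samples.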
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    let mut max_gap = Duration::from_millis(0);
+    for idx in 1..wake_instants.len() {
+        let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]);
+        if gap > max_gap {
+            max_gap = gap;
+        }
+    }
+
+    assert!(
+        max_gap <= Duration::from_millis(35),
+        "retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}"
+    );
+    assert!(
+        quota_retry_sleep_allocs_for_tests() <= 16,
+        "allocation cycles must remain bounded during a short contention window"
+    );
+
+    drop(held_guard);
+}
+
+#[tokio::test]
+async fn micro_benchmark_release_to_completion_latency_stays_bounded() {
+    let _guard = quota_test_guard();
+
+    let rounds = 96usize;
+    let mut samples_ms = Vec::with_capacity(rounds);
+
+    for round in 0..rounds {
+        let user = format!("retry-release-latency-{}-{round}", std::process::id());
+        let lock = quota_user_lock(&user);
+        let held_guard = lock
+            .try_lock()
+            .expect("test must hold local lock before spawning blocked writer");
+
+        let writer = tokio::spawn(async move {
+            let mut io = StatsIo::new(
+                tokio::io::sink(),
+                Arc::new(SharedCounters::new()),
+                Arc::new(Stats::new()),
+                user,
+                Some(2048),
+                Arc::new(AtomicBool::new(false)),
+                Instant::now(),
+            );
+            io.write_all(&[0xD1]).await
+        });
+
+        tokio::time::sleep(Duration::from_millis(2)).await;
+        let release_at = Instant::now();
+        drop(held_guard);
+
+        let done = timeout(Duration::from_millis(120), writer)
+            .await
+            .expect("blocked writer must complete after release")
+            .expect("writer task must not panic");
+        assert!(done.is_ok());
+
+        samples_ms.push(release_at.elapsed().as_millis() as u64);
+    }
+
+    samples_ms.sort_unstable();
+    let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1));
+    let p95_ms = samples_ms[p95_idx];
+
+    assert!(
+        p95_ms <= 40,
+        "contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
+    );
+}
diff --git a/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs
new file mode 100644
index 0000000..7083eb2
--- /dev/null
+++ b/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs
@@ -0,0 +1,241 @@
+use super::*;
+use crate::stats::Stats;
+use dashmap::DashMap;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::task::{Context, Waker};
+use tokio::io::ReadBuf;
+use tokio::time::{Duration, Instant};
+
+#[derive(Default)]
+struct WakeCounter {
+    wakes: AtomicUsize,
+}
+
+impl std::task::Wake for WakeCounter {
+    fn wake(self: Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+
+    fn wake_by_ref(self: &Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+}
+
+fn quota_test_guard() -> impl Drop {
+    super::quota_user_lock_test_scope()
+}
+
+fn saturate_quota_user_locks() -> Vec<Arc<tokio::sync::Mutex<()>>> {
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
+    for idx in 0..QUOTA_USER_LOCKS_MAX {
+        retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}")));
+    }
+    retained
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_contention_wake_rate_decays_with_backoff_curve() {
+    let _guard = quota_test_guard();
+
+    let _retained = saturate_quota_user_locks();
+    let user = format!("quota-backoff-bench-{}", std::process::id());
+    let stats = Arc::new(Stats::new());
+
+    let lock = quota_user_lock(&user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold quota lock before benchmark run");
+
+    let waiters = 64usize;
+    let mut ios = Vec::with_capacity(waiters);
+    let mut wake_counters = Vec::with_capacity(waiters);
+
+    for _ in 0..waiters {
+        ios.push(StatsIo::new(
+            tokio::io::sink(),
+            Arc::new(SharedCounters::new()),
+            Arc::clone(&stats),
+            user.clone(),
+            Some(4096),
+            Arc::new(AtomicBool::new(false)),
+            tokio::time::Instant::now(),
+        ));
+    }
+
+    for io in &mut ios {
+        let counter = Arc::new(WakeCounter::default());
+        let waker = Waker::from(Arc::clone(&counter));
+        let mut cx = Context::from_waker(&waker);
+        let pending = Pin::new(io).poll_write(&mut cx, &[0x71]);
+        assert!(pending.is_pending());
+        wake_counters.push(counter);
+    }
+
+    let mut observed = vec![0usize; waiters];
+    let start = Instant::now();
+    let mut wakes_at_40ms = 0usize;
+    let mut wakes_at_160ms = 0usize;
+
+    while start.elapsed() < Duration::from_millis(200) {
+        for (idx, counter) in wake_counters.iter().enumerate() {
+            let wakes = counter.wakes.load(Ordering::Relaxed);
+            if wakes > observed[idx] {
+                observed[idx] = wakes;
+                let waker = Waker::from(Arc::clone(counter));
+                let mut cx = Context::from_waker(&waker);
+                let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]);
+                assert!(pending.is_pending());
+            }
+        }
+
+        let elapsed = start.elapsed();
+        if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
+            wakes_at_40ms = wake_counters
+                .iter()
+                .map(|counter| counter.wakes.load(Ordering::Relaxed))
+                .sum();
+        }
+        if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
+            wakes_at_160ms = wake_counters
+                .iter()
+                .map(|counter| counter.wakes.load(Ordering::Relaxed))
+                .sum();
+        }
+
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    let total_wakes: usize = wake_counters
+        .iter()
+        .map(|counter| counter.wakes.load(Ordering::Relaxed))
+        .sum();
+
+    let wakes_at_200ms = total_wakes;
+    let early_window_wakes = wakes_at_40ms;
+    let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
+
+    assert!(
+        total_wakes <= waiters * 28,
+        "backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
+    );
+
+    assert!(
+        early_window_wakes > 0,
+        "benchmark failed to observe early contention wakes"
+    );
+
+    assert!(
+        late_window_wakes * 4 <= early_window_wakes * 3,
+        "wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
+    );
+
+    drop(held_guard);
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_read_contention_wake_rate_decays_with_backoff_curve() {
+    let _guard = quota_test_guard();
+
+    let _retained = saturate_quota_user_locks();
+    let user = format!("quota-backoff-read-bench-{}", std::process::id());
+    let stats = Arc::new(Stats::new());
+
+    let lock = quota_user_lock(&user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold quota lock before read benchmark run");
+
+    let waiters = 64usize;
+    let mut ios = Vec::with_capacity(waiters);
+    let mut wake_counters = Vec::with_capacity(waiters);
+
+    for _ in 0..waiters {
+        ios.push(StatsIo::new(
+            tokio::io::empty(),
+            Arc::new(SharedCounters::new()),
+            Arc::clone(&stats),
+            user.clone(),
+            Some(4096),
+            Arc::new(AtomicBool::new(false)),
+            tokio::time::Instant::now(),
+        ));
+    }
+
+    for io in &mut ios {
+        let counter = Arc::new(WakeCounter::default());
+        let waker = Waker::from(Arc::clone(&counter));
+        let mut cx = Context::from_waker(&waker);
+        let mut storage = [0u8; 1];
+        let mut buf = ReadBuf::new(&mut storage);
+        let pending = Pin::new(io).poll_read(&mut cx, &mut buf);
+        assert!(pending.is_pending());
+        wake_counters.push(counter);
+    }
+
+    let mut observed = vec![0usize; waiters];
+    let start = Instant::now();
+    let mut wakes_at_40ms = 0usize;
+    let mut wakes_at_160ms = 0usize;
+
+    while start.elapsed() < Duration::from_millis(200) {
+        for (idx, counter) in wake_counters.iter().enumerate() {
+            let wakes = counter.wakes.load(Ordering::Relaxed);
+            if wakes > observed[idx] {
+                observed[idx] = wakes;
+                let waker = Waker::from(Arc::clone(counter));
+                let mut cx = Context::from_waker(&waker);
+                let mut storage = [0u8; 1];
+                let mut buf = ReadBuf::new(&mut storage);
+                let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf);
+                assert!(pending.is_pending());
+            }
+        }
+
+        let elapsed = start.elapsed();
+        if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
+            wakes_at_40ms = wake_counters
+                .iter()
+                .map(|counter| counter.wakes.load(Ordering::Relaxed))
+                .sum();
+        }
+        if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
+            wakes_at_160ms = wake_counters
+                .iter()
+                .map(|counter| counter.wakes.load(Ordering::Relaxed))
+                .sum();
+        }
+
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    let total_wakes: usize = wake_counters
+        .iter()
+        .map(|counter| counter.wakes.load(Ordering::Relaxed))
+        .sum();
+
+    let wakes_at_200ms = total_wakes;
+    let early_window_wakes = wakes_at_40ms;
+    let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
+
+    assert!(
+        total_wakes <= waiters * 28,
+        "read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
+    );
+
+    assert!(
+        early_window_wakes > 0,
+        "read benchmark failed to observe early contention wakes"
+    );
+
+    assert!(
+        late_window_wakes * 4 <= early_window_wakes * 3,
+        "read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
+    );
+
+    drop(held_guard);
+}
diff --git a/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs
new file mode 100644
index 0000000..7f1e451
--- /dev/null
+++ b/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs
@@ -0,0 +1,339 @@
+use super::*;
+use crate::stats::Stats;
+use dashmap::DashMap;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::task::{Context, Waker};
+use tokio::io::ReadBuf;
+use tokio::time::{Duration, Instant};
+
+#[derive(Default)]
+struct WakeCounter {
+    wakes: AtomicUsize,
+}
+
+impl std::task::Wake for WakeCounter {
+    fn wake(self: Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+
+    fn wake_by_ref(self: &Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+}
+
+fn quota_test_guard() -> impl Drop {
+    super::quota_user_lock_test_scope()
+}
+
+fn saturate_quota_user_locks() -> Vec<Arc<tokio::sync::Mutex<()>>> {
+    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    map.clear();
+
+    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
+    for idx in 0..QUOTA_USER_LOCKS_MAX {
+        retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}")));
+    }
+    retained
+}
+
+#[tokio::test]
+async fn positive_uncontended_writer_keeps_retry_wakes_zero() {
+    let _guard = quota_test_guard();
+
+    let stats = Arc::new(Stats::new());
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::clone(&stats),
+        "quota-backoff-positive".to_string(),
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    let mut cx = Context::from_waker(&waker);
+
+    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]);
+    assert!(poll.is_ready(), "uncontended writer must complete immediately");
+    assert_eq!(
+        wake_counter.wakes.load(Ordering::Relaxed),
+        0,
+        "uncontended path must not schedule deferred contention wakes"
+    );
+}
+
+#[tokio::test]
+async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() {
+    let _guard = quota_test_guard();
+
+    let _retained = saturate_quota_user_locks();
+    let user = "quota-backoff-adversarial-writer";
+    let stats = Arc::new(Stats::new());
+
+    let lock = quota_user_lock(user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold quota lock before polling writer");
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::clone(&stats),
+        user.to_string(),
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    let mut cx = Context::from_waker(&waker);
+
+    let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
+    assert!(first.is_pending());
+
+    let start = Instant::now();
+    let mut observed = 0usize;
+    while start.elapsed() < Duration::from_millis(80) {
+        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
+        if wakes > observed {
+            observed = wakes;
+            let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
+            assert!(pending.is_pending());
+        }
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    assert!(
+        wake_counter.wakes.load(Ordering::Relaxed) <= 16,
+        "sustained contention must be rate limited; observed wakes={} in 80ms",
+        wake_counter.wakes.load(Ordering::Relaxed)
+    );
+
+    drop(held_guard);
+    let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]);
+    assert!(ready.is_ready());
+}
+
+#[tokio::test]
+async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() {
+    let _guard = quota_test_guard();
+
+    let _retained = saturate_quota_user_locks();
+    let user = "quota-backoff-adversarial-reader";
+    let stats = Arc::new(Stats::new());
+
+    let lock = quota_user_lock(user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold quota lock before polling reader");
+
+    let mut io = StatsIo::new(
+        tokio::io::empty(),
+        Arc::new(SharedCounters::new()),
+        Arc::clone(&stats),
+        user.to_string(),
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    let mut cx = Context::from_waker(&waker);
+    let mut storage = [0u8; 1];
+
+    let mut buf = ReadBuf::new(&mut storage);
+    let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
+    assert!(first.is_pending());
+
+    let start = Instant::now();
+    let mut observed = 0usize;
+    while start.elapsed() < Duration::from_millis(80) {
+        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
+        if wakes > observed {
+            observed = wakes;
+            let mut next = ReadBuf::new(&mut storage);
+            let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next);
+            assert!(pending.is_pending());
+        }
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    assert!(
+        wake_counter.wakes.load(Ordering::Relaxed) <= 16,
+        "sustained contention must be rate limited; observed wakes={} in 80ms",
+        wake_counter.wakes.load(Ordering::Relaxed)
+    );
+
+    drop(held_guard);
+    let mut done = ReadBuf::new(&mut storage);
+    let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done);
+    assert!(ready.is_ready());
+}
+
+#[tokio::test]
+async fn edge_backoff_attempt_resets_after_contention_release() {
+    let _guard = quota_test_guard();
+
+    let _retained = saturate_quota_user_locks();
+    let user = "quota-backoff-edge-reset";
+    let stats = Arc::new(Stats::new());
+
+    let lock = quota_user_lock(user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold quota lock before polling writer");
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::clone(&stats),
+        user.to_string(),
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    let mut cx = Context::from_waker(&waker);
+
+    let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]);
+    assert!(initial.is_pending());
+
+    tokio::time::sleep(Duration::from_millis(15)).await;
+    let wakes = wake_counter.wakes.load(Ordering::Relaxed);
+    if wakes > 0 {
+        let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]);
+        assert!(pending.is_pending());
+    }
+
+    drop(held_guard);
+    let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]);
+    assert!(ready.is_ready());
+    assert!(
+        !io.quota_write_wake_scheduled,
+        "successful write must clear deferred wake scheduling flag"
+    );
+    assert!(
+        io.quota_write_retry_sleep.is_none(),
+        "successful write must clear deferred sleep slot"
+    );
+}
+
+#[tokio::test]
+async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() {
+    let _guard = quota_test_guard();
+
+    let _retained = saturate_quota_user_locks();
+    let user = "quota-backoff-fuzz-writer";
+    let stats = Arc::new(Stats::new());
+
+    let lock = quota_user_lock(user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold quota lock before fuzz loop");
+
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::clone(&stats),
+        user.to_string(),
+        Some(2048),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let wake_counter = Arc::new(WakeCounter::default());
+    let waker = Waker::from(Arc::clone(&wake_counter));
+    let mut cx = Context::from_waker(&waker);
+
+    let mut seed = 0x5EED_CAFE_7788_9900u64;
+    for _ in 0..64 {
+        let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]);
+        assert!(poll.is_pending());
+
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+        let sleep_ms = (seed % 4) as u64;
+        tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
+    }
+
+    assert!(
+        wake_counter.wakes.load(Ordering::Relaxed) <= 24,
+        "fuzzed repoll schedule must keep wake budget bounded; observed wakes={}",
+        wake_counter.wakes.load(Ordering::Relaxed)
+    );
+
+    drop(held_guard);
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() {
+    let _guard = quota_test_guard();
+
+    let _retained = saturate_quota_user_locks();
+    let user = format!("quota-backoff-stress-{}", std::process::id());
+    let stats = Arc::new(Stats::new());
+
+    let lock = quota_user_lock(&user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold quota lock before launching stress waiters");
+
+    let waiters = 48usize;
+    let mut ios = Vec::with_capacity(waiters);
+    let mut wake_counters = Vec::with_capacity(waiters);
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() {
+    let _guard = quota_test_guard();
+
+    let _retained = saturate_quota_user_locks();
+    let user = format!("quota-backoff-stress-{}", std::process::id());
+    let stats = Arc::new(Stats::new());
+
+    let lock = quota_user_lock(&user);
+    let held_guard = lock
+        .try_lock()
+        .expect("test must hold quota lock before launching stress waiters");
+
+    let waiters = 48usize;
+    let mut ios = Vec::with_capacity(waiters);
+    let mut wake_counters = Vec::with_capacity(waiters);
+
+    for _ in 0..waiters {
+        ios.push(StatsIo::new(
+            tokio::io::sink(),
+            Arc::new(SharedCounters::new()),
+            Arc::clone(&stats),
+            user.clone(),
+            Some(4096),
+            Arc::new(AtomicBool::new(false)),
+            tokio::time::Instant::now(),
+        ));
+    }
+
+    for io in &mut ios {
+        let counter = Arc::new(WakeCounter::default());
+        let waker = Waker::from(Arc::clone(&counter));
+        let mut cx = Context::from_waker(&waker);
+        let pending = Pin::new(io).poll_write(&mut cx, &[0x61]);
+        assert!(pending.is_pending());
+        wake_counters.push(counter);
+    }
+
+    let start = Instant::now();
+    while start.elapsed() < Duration::from_millis(120) {
+        for (idx, counter) in wake_counters.iter().enumerate() {
+            if counter.wakes.load(Ordering::Relaxed) > 0 {
+                let waker = Waker::from(Arc::clone(counter));
+                let mut cx = Context::from_waker(&waker);
+                let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]);
+                assert!(pending.is_pending());
+            }
+        }
+        tokio::time::sleep(Duration::from_millis(1)).await;
+    }
+
+    let total_wakes: usize = wake_counters
+        .iter()
+        .map(|counter| counter.wakes.load(Ordering::Relaxed))
+        .sum();
+
+    assert!(
+        total_wakes <= waiters * 20,
+        "stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}"
+    );
+
+    drop(held_guard);
+}
diff --git a/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs b/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs
new file mode 100644
index 0000000..35a6b6e
--- /dev/null
+++ b/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs
@@ -0,0 +1,246 @@
+use super::*;
+use crate::stats::Stats;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::task::{Context, Poll, Waker};
+use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
+use tokio::time::{Duration, timeout};
+
+#[derive(Default)]
+struct WakeCounter {
+    wakes: AtomicUsize,
+}
+
+impl std::task::Wake for WakeCounter {
+    fn wake(self: Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+
+    fn wake_by_ref(self: &Arc<Self>) {
+        self.wakes.fetch_add(1, Ordering::Relaxed);
+    }
+}
+
+fn quota_test_guard() -> impl Drop {
+    super::quota_user_lock_test_scope()
+}
+
+#[tokio::test]
+async fn positive_uncontended_quota_limited_writer_completes() {
+    let _guard = quota_test_guard();
+
+    let stats = Arc::new(Stats::new());
+    let mut io = StatsIo::new(
+        tokio::io::sink(),
+        Arc::new(SharedCounters::new()),
+        Arc::clone(&stats),
+        "tdd-uncontended".to_string(),
+        Some(1024),
+        Arc::new(AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let result = io.write_all(&[0x41, 0x42, 0x43]).await;
+    assert!(result.is_ok(), "uncontended writer must complete");
+}
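Editor's note: WakeCounter implements std::task::Wake so every scheduled wakeup becomes an observable counter increment, which is what lets these suites assert wake budgets. The same pattern in isolation, independent of this crate's types:

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Wake, Waker};

    struct CountingWake(AtomicUsize);

    impl Wake for CountingWake {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    fn demo() {
        let counter = Arc::new(CountingWake(AtomicUsize::new(0)));
        // Waker::from(Arc<impl Wake>) is provided by std.
        let waker = Waker::from(Arc::clone(&counter));
        waker.wake_by_ref(); // default wake_by_ref clones the Arc and calls wake
        assert_eq!(counter.0.load(Ordering::Relaxed), 1);
    }

Building a Context from such a waker and hand-polling a future, as the tests do, makes every wakeup countable without running a real executor.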
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() {
+    let _guard = quota_test_guard();
+
+    let user = format!("tdd-writer-storm-{}", std::process::id());
+    let held = quota_user_lock(&user);
+    let _held_guard = held
+        .try_lock()
+        .expect("test must hold quota lock before polling writers");
+
+    let stats = Arc::new(Stats::new());
+    let writers = 24usize;
+    let mut ios = Vec::with_capacity(writers);
+    let mut wake_counters = Vec::with_capacity(writers);
+
+    for _ in 0..writers {
+        ios.push(StatsIo::new(
+            tokio::io::sink(),
+            Arc::new(SharedCounters::new()),
+            Arc::clone(&stats),
+            user.clone(),
+            Some(1024),
+            Arc::new(AtomicBool::new(false)),
+            tokio::time::Instant::now(),
+        ));
+    }
+
+    for io in &mut ios {
+        let counter = Arc::new(WakeCounter::default());
+        let waker = Waker::from(Arc::clone(&counter));
+        let mut cx = Context::from_waker(&waker);
+        let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]);
+        assert!(poll.is_pending(), "writer must be pending under held lock");
+        wake_counters.push(counter);
+    }
+
+    tokio::time::sleep(Duration::from_millis(25)).await;
+
+    let total_wakes: usize = wake_counters
+        .iter()
+        .map(|counter| counter.wakes.load(Ordering::Relaxed))
+        .sum();
+
+    assert!(
+        total_wakes <= writers * 4,
+        "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}"
+    );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() {
+    let _guard = quota_test_guard();
+
+    let user = format!("tdd-reader-storm-{}", std::process::id());
+    let held = quota_user_lock(&user);
+    let _held_guard = held
+        .try_lock()
+        .expect("test must hold quota lock before polling readers");
+
+    let stats = Arc::new(Stats::new());
+    let readers = 24usize;
+    let mut ios = Vec::with_capacity(readers);
+    let mut wake_counters = Vec::with_capacity(readers);
+
+    for _ in 0..readers {
+        ios.push(StatsIo::new(
+            tokio::io::empty(),
+            Arc::new(SharedCounters::new()),
+            Arc::clone(&stats),
+            user.clone(),
+            Some(1024),
+            Arc::new(AtomicBool::new(false)),
+            tokio::time::Instant::now(),
+        ));
+    }
+
+    for io in &mut ios {
+        let counter = Arc::new(WakeCounter::default());
+        let waker = Waker::from(Arc::clone(&counter));
+        let mut cx = Context::from_waker(&waker);
+        let mut storage = [0u8; 1];
+        let mut buf = ReadBuf::new(&mut storage);
+        let poll = Pin::new(io).poll_read(&mut cx, &mut buf);
+        assert!(poll.is_pending(), "reader must be pending under held lock");
+        wake_counters.push(counter);
+    }
+
+    tokio::time::sleep(Duration::from_millis(25)).await;
+
+    let total_wakes: usize = wake_counters
+        .iter()
+        .map(|counter| counter.wakes.load(Ordering::Relaxed))
+        .sum();
+
+    assert!(
+        total_wakes <= readers * 4,
+        "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}"
+    );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn integration_contended_waiters_resume_after_lock_release() {
+    let _guard = quota_test_guard();
+
+    let user = format!("tdd-resume-{}", std::process::id());
+    let held = quota_user_lock(&user);
+    let held_guard = held
+        .try_lock()
+        .expect("test must hold quota lock before launching waiters");
+
+    let stats = Arc::new(Stats::new());
+    let mut waiters = Vec::new();
+    for _ in 0..12 {
+        let stats = Arc::clone(&stats);
+        let user = user.clone();
+        waiters.push(tokio::spawn(async move {
+            let mut io = StatsIo::new(
+                tokio::io::sink(),
+                Arc::new(SharedCounters::new()),
+                stats,
+                user,
+                Some(2048),
+                Arc::new(AtomicBool::new(false)),
+                tokio::time::Instant::now(),
+            );
+            io.write_all(&[0x5A]).await
+        }));
+    }
+
+    tokio::time::sleep(Duration::from_millis(5)).await;
+    drop(held_guard);
+
+    timeout(Duration::from_secs(1), async {
+        for waiter in waiters {
+            let result = waiter.await.expect("waiter task must not panic");
+            assert!(result.is_ok(), "waiter must complete after release");
+        }
+    })
+    .await
+    .expect("all waiters must complete in bounded time");
+}
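Editor's note: the integration test above bounds total completion with tokio::time::timeout so a stuck waiter fails the test instead of hanging CI. The same guard pattern in isolation (function name is illustrative, not part of the patch):

    use tokio::time::{Duration, timeout};

    async fn join_all_bounded(tasks: Vec<tokio::task::JoinHandle<()>>) {
        timeout(Duration::from_secs(1), async {
            for task in tasks {
                task.await.expect("task must not panic");
            }
        })
        .await
        .expect("all tasks must finish within the deadline");
    }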
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() {
+    let _guard = quota_test_guard();
+
+    let mut seed = 0x9E37_79B9_AA55_1234u64;
+    for round in 0..20u32 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let writers = 8 + (seed as usize % 12);
+        let sleep_ms = 10 + (seed as u64 % 15);
+        let user = format!("tdd-fuzz-{}-{round}", std::process::id());
+
+        let held = quota_user_lock(&user);
+        let _held_guard = held
+            .try_lock()
+            .expect("test must hold quota lock in fuzz round");
+
+        let stats = Arc::new(Stats::new());
+        let mut ios = Vec::with_capacity(writers);
+        let mut wake_counters = Vec::with_capacity(writers);
+
+        for _ in 0..writers {
+            ios.push(StatsIo::new(
+                tokio::io::sink(),
+                Arc::new(SharedCounters::new()),
+                Arc::clone(&stats),
+                user.clone(),
+                Some(2048),
+                Arc::new(AtomicBool::new(false)),
+                tokio::time::Instant::now(),
+            ));
+        }
+
+        for io in &mut ios {
+            let counter = Arc::new(WakeCounter::default());
+            let waker = Waker::from(Arc::clone(&counter));
+            let mut cx = Context::from_waker(&waker);
+            let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]);
+            assert!(matches!(poll, Poll::Pending));
+            wake_counters.push(counter);
+        }
+
+        tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
+
+        let total_wakes: usize = wake_counters
+            .iter()
+            .map(|counter| counter.wakes.load(Ordering::Relaxed))
+            .sum();
+
+        assert!(
+            total_wakes <= writers * 4,
+            "fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}"
+        );
+    }
+}
diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs
index 50cdfa3..7375192 100644
--- a/src/proxy/tests/relay_security_tests.rs
+++ b/src/proxy/tests/relay_security_tests.rs
@@ -137,10 +137,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_
     for _ in 0..8 {
         tokio::task::yield_now().await;
     }
-    assert_eq!(
-        wake_counter.wakes.load(Ordering::Relaxed),
-        wakes_after_first_yield,
-        "writer contention should not schedule unbounded wake storms before lock acquisition"
+    let wakes_after_second_window = wake_counter.wakes.load(Ordering::Relaxed);
+    assert!(
+        wakes_after_second_window <= wakes_after_first_yield.saturating_add(2),
+        "writer contention should keep retry wakes bounded before lock acquisition: before={wakes_after_first_yield}, after={wakes_after_second_window}"
     );
 
     drop(held_lock);
diff --git a/src/stats/mod.rs b/src/stats/mod.rs
index d13d834..dc455a1 100644
--- a/src/stats/mod.rs
+++ b/src/stats/mod.rs
@@ -1884,6 +1884,32 @@ impl Stats {
         stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
     }
 
+    pub fn sub_user_octets_to(&self, user: &str, bytes: u64) {
+        if !self.telemetry_user_enabled() {
+            return;
+        }
+        self.maybe_cleanup_user_stats();
+        let Some(stats) = self.user_stats.get(user) else {
+            return;
+        };
+
+        Self::touch_user_stats(stats.value());
+        let counter = &stats.octets_to_client;
+        let mut current = counter.load(Ordering::Relaxed);
+        loop {
+            let next = current.saturating_sub(bytes);
+            match counter.compare_exchange_weak(
+                current,
+                next,
+                Ordering::Relaxed,
+                Ordering::Relaxed,
+            ) {
+                Ok(_) => break,
+                Err(actual) => current = actual,
+            }
+        }
+    }
+
     pub fn increment_user_msgs_from(&self, user: &str) {
         if !self.telemetry_user_enabled() {
             return;
@@ -2440,3 +2466,7 @@ mod connection_lease_security_tests;
 #[cfg(test)]
 #[path = "tests/replay_checker_security_tests.rs"]
 mod replay_checker_security_tests;
+
+#[cfg(test)]
+#[path = "tests/user_octets_sub_security_tests.rs"]
+mod user_octets_sub_security_tests;
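Editor's note: the lock-free saturating decrement in sub_user_octets_to is a standard compare-exchange retry loop; the weak variant may fail spuriously, which is why it sits inside a loop. std's AtomicU64::fetch_update expresses the same retry loop more compactly, as this equivalent standalone sketch shows (the helper is illustrative, not part of the patch):

    use std::sync::atomic::{AtomicU64, Ordering};

    /// Subtract `bytes`, clamping at zero instead of wrapping on underflow.
    fn saturating_sub_atomic(counter: &AtomicU64, bytes: u64) {
        // fetch_update retries internally on contention, matching the
        // compare_exchange_weak loop in the patch above.
        let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
            Some(current.saturating_sub(bytes))
        });
    }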
diff --git a/src/stats/tests/user_octets_sub_security_tests.rs b/src/stats/tests/user_octets_sub_security_tests.rs
new file mode 100644
index 0000000..d4e7580
--- /dev/null
+++ b/src/stats/tests/user_octets_sub_security_tests.rs
@@ -0,0 +1,151 @@
+use super::*;
+use std::sync::Arc;
+use std::thread;
+
+#[test]
+fn sub_user_octets_to_underflow_saturates_at_zero() {
+    let stats = Stats::new();
+    let user = "sub-underflow-user";
+
+    stats.add_user_octets_to(user, 3);
+    stats.sub_user_octets_to(user, 100);
+
+    assert_eq!(stats.get_user_total_octets(user), 0);
+}
+
+#[test]
+fn sub_user_octets_to_does_not_affect_octets_from_client() {
+    let stats = Stats::new();
+    let user = "sub-isolation-user";
+
+    stats.add_user_octets_from(user, 17);
+    stats.add_user_octets_to(user, 5);
+    stats.sub_user_octets_to(user, 3);
+
+    assert_eq!(stats.get_user_total_octets(user), 19);
+}
+
+#[test]
+fn light_fuzz_add_sub_model_matches_saturating_reference() {
+    let stats = Stats::new();
+    let user = "sub-fuzz-user";
+    let mut seed = 0x91D2_4CB8_EE77_1101u64;
+    let mut model_to = 0u64;
+
+    for _ in 0..8192 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let amt = ((seed >> 8) & 0x3f) + 1;
+        if (seed & 1) == 0 {
+            stats.add_user_octets_to(user, amt);
+            model_to = model_to.saturating_add(amt);
+        } else {
+            stats.sub_user_octets_to(user, amt);
+            model_to = model_to.saturating_sub(amt);
+        }
+    }
+
+    assert_eq!(stats.get_user_total_octets(user), model_to);
+}
+
+#[test]
+fn stress_parallel_add_sub_never_underflows_or_panics() {
+    let stats = Arc::new(Stats::new());
+    let user = "sub-stress-user";
+    // Pre-fund with a large offset so subtractions never saturate at zero.
+    // This guarantees commutative updates, making the final state deterministic.
+    let base_offset = 10_000_000u64;
+    stats.add_user_octets_to(user, base_offset);
+
+    let mut workers = Vec::new();
+
+    for tid in 0..16u64 {
+        let stats_for_thread = Arc::clone(&stats);
+        workers.push(thread::spawn(move || {
+            let mut seed = 0xD00D_1000_0000_0000u64 ^ tid;
+            let mut net_delta = 0i64;
+            for _ in 0..4096 {
+                seed ^= seed << 7;
+                seed ^= seed >> 9;
+                seed ^= seed << 8;
+                let amt = ((seed >> 8) & 0x1f) + 1;
+
+                if (seed & 1) == 0 {
+                    stats_for_thread.add_user_octets_to(user, amt);
+                    net_delta += amt as i64;
+                } else {
+                    stats_for_thread.sub_user_octets_to(user, amt);
+                    net_delta -= amt as i64;
+                }
+            }
+
+            net_delta
+        }));
+    }
+
+    let mut expected_net_delta = 0i64;
+    for worker in workers {
+        expected_net_delta += worker
+            .join()
+            .expect("sub-user stress worker must not panic");
+    }
+
+    let expected_total = (base_offset as i64 + expected_net_delta) as u64;
+    let total = stats.get_user_total_octets(user);
+    assert_eq!(
+        total, expected_total,
+        "concurrent add/sub lost updates or suffered ABA races"
+    );
+}
+
+#[test]
+fn sub_user_octets_to_missing_user_is_noop() {
+    let stats = Stats::new();
+    stats.sub_user_octets_to("missing-user", 1024);
+    assert_eq!(stats.get_user_total_octets("missing-user"), 0);
+}
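Editor's note: the pre-funding comment in the stress test above is load-bearing. Saturating subtraction is not order-independent near zero, so without headroom the deterministic-total check would be unsound. A tiny illustration of the non-commutativity (standalone, not part of the test file):

    // Starting from 0, order changes the outcome once saturation kicks in:
    fn main() {
        let a = 0u64.saturating_sub(5) + 3;   // sub first: 0, then +3 => 3
        let b = (0u64 + 3).saturating_sub(5); // add first: 3, then -5 => 0
        assert_eq!(a, 3);
        assert_eq!(b, 0);
    }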
+
+#[test]
+fn stress_parallel_per_user_models_remain_exact() {
+    let stats = Arc::new(Stats::new());
+    let mut workers = Vec::new();
+
+    for tid in 0..16u64 {
+        let stats_for_thread = Arc::clone(&stats);
+        workers.push(thread::spawn(move || {
+            let user = format!("sub-per-user-{tid}");
+            let mut seed = 0xFACE_0000_0000_0000u64 ^ tid;
+            let mut model = 0u64;
+
+            for _ in 0..4096 {
+                seed ^= seed << 7;
+                seed ^= seed >> 9;
+                seed ^= seed << 8;
+                let amt = ((seed >> 8) & 0x3f) + 1;
+
+                if (seed & 1) == 0 {
+                    stats_for_thread.add_user_octets_to(&user, amt);
+                    model = model.saturating_add(amt);
+                } else {
+                    stats_for_thread.sub_user_octets_to(&user, amt);
+                    model = model.saturating_sub(amt);
+                }
+            }
+
+            (user, model)
+        }));
+    }
+
+    for worker in workers {
+        let (user, model) = worker
+            .join()
+            .expect("per-user subtract stress worker must not panic");
+        assert_eq!(
+            stats.get_user_total_octets(&user),
+            model,
+            "per-user parallel model diverged"
+        );
+    }
+}
\ No newline at end of file
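Editor's note: the isolation test implies that get_user_total_octets reports the sum of both directions (from-client plus to-client) and that the new subtract API touches only the to-client side. A hypothetical reference model of those assumed accessor semantics, for review purposes only:

    // Assumed semantics, inferred from the isolation test:
    // total = octets_from_client + octets_to_client.
    struct UserModel {
        from_client: u64,
        to_client: u64,
    }

    impl UserModel {
        fn total(&self) -> u64 {
            self.from_client + self.to_client
        }
    }

    fn main() {
        let mut m = UserModel { from_client: 17, to_client: 0 };
        m.to_client = m.to_client.saturating_add(5).saturating_sub(3);
        assert_eq!(m.total(), 19); // matches the isolation test's expectation
    }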