diff --git a/Cargo.lock b/Cargo.lock index 8159a22..c4cde39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1486,7 +1486,7 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.1", "log", "thiserror 1.0.69", "walkdir", @@ -1495,9 +1495,31 @@ dependencies = [ [[package]] name = "jni-sys" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" +dependencies = [ + "jni-sys 0.4.1", +] + +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn", +] [[package]] name = "jobserver" @@ -1659,9 +1681,9 @@ dependencies = [ [[package]] name = "moka" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" dependencies = [ "crossbeam-channel", "crossbeam-epoch", @@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.3.29" +version = "3.3.30" dependencies = [ "aes", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 53082db..1e06b7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,11 @@ [package] name = "telemt" -version = "3.3.29" +version = "3.3.30" edition = "2024" +[features] +redteam_offline_expected_fail = [] + [dependencies] # C libc = "0.2" diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index f8c56a0..73d36e1 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -269,6 +269,8 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a | mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | | mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | | mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. | +| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. | +| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. | | mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. | | mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. | | mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. | diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 66ffeda..09d146a 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -553,6 +553,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize { 512 } +#[cfg(not(test))] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 5 * 1024 * 1024 +} + +#[cfg(test)] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 32 * 1024 +} + +pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 { + 5 +} + pub(crate) fn default_mask_timing_normalization_enabled() -> bool { false } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index e580b7f..a3f795a 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -600,6 +600,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur || old.censorship.mask_shape_above_cap_blur_max_bytes != new.censorship.mask_shape_above_cap_blur_max_bytes + || old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes + || old.censorship.mask_classifier_prefetch_timeout_ms + != new.censorship.mask_classifier_prefetch_timeout_ms || old.censorship.mask_timing_normalization_enabled != new.censorship.mask_timing_normalization_enabled || old.censorship.mask_timing_normalization_floor_ms diff --git a/src/config/load.rs b/src/config/load.rs index bf6d036..fc54ec2 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -430,6 +430,25 @@ impl ProxyConfig { )); } + if config.censorship.mask_relay_max_bytes == 0 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be > 0".to_string(), + )); + } + + if config.censorship.mask_relay_max_bytes > 67_108_864 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be <= 67108864".to_string(), + )); + } + + if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) { + return Err(ProxyError::Config( + "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]" + .to_string(), + )); + } + if config.censorship.mask_timing_normalization_ceiling_ms < config.censorship.mask_timing_normalization_floor_ms { @@ -1134,6 +1153,10 @@ mod load_security_tests; #[path = "tests/load_mask_shape_security_tests.rs"] mod load_mask_shape_security_tests; +#[cfg(test)] +#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"] +mod load_mask_classifier_prefetch_timeout_security_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs new file mode 100644 index 0000000..49ee953 --- /dev/null +++ b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs @@ -0,0 +1,75 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir() + .join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 4 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout below minimum security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 51 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout above max security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 20 +"#, + ); + + let cfg = ProxyConfig::load(&path) + .expect("prefetch timeout within security bounds must be accepted"); + assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index 8986a49..2e4aa41 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8 remove_temp_config(&path); } + +#[test] +fn load_rejects_zero_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be > 0"), + "error must explain non-zero relay cap invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_relay_max_bytes_above_upper_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 67108865 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("mask_relay_max_bytes above hard cap must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"), + "error must explain relay cap upper bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_valid_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 8388608 +"#, + ); + + let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted"); + assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608); + + remove_temp_config(&path); +} diff --git a/src/config/types.rs b/src/config/types.rs index aa58dc1..5dc9719 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1450,6 +1450,14 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_mask_shape_above_cap_blur_max_bytes")] pub mask_shape_above_cap_blur_max_bytes: usize, + /// Maximum bytes relayed per direction on unauthenticated masking fallback paths. + #[serde(default = "default_mask_relay_max_bytes")] + pub mask_relay_max_bytes: usize, + + /// Prefetch timeout (ms) for extending fragmented masking classifier window. + #[serde(default = "default_mask_classifier_prefetch_timeout_ms")] + pub mask_classifier_prefetch_timeout_ms: u64, + /// Enable outcome-time normalization envelope for masking fallback. #[serde(default = "default_mask_timing_normalization_enabled")] pub mask_timing_normalization_enabled: bool, @@ -1488,6 +1496,8 @@ impl Default for AntiCensorshipConfig { mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(), mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(), mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(), + mask_relay_max_bytes: default_mask_relay_max_bytes(), + mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(), mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(), mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(), mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(), diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 1941f36..a804a2c 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -186,6 +186,67 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration { } } +const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16; +#[cfg(test)] +const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5); + +fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration { + Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms) +} + +fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool { + if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW { + return false; + } + + if initial_data.is_empty() { + // Empty initial_data means there is no client probe prefix to refine. + // Prefetching in this case can consume fallback relay payload bytes and + // accidentally route them through shaping heuristics. + return false; + } + + if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") { + return false; + } + + initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ') +} + +#[cfg(test)] +async fn extend_masking_initial_window(reader: &mut R, initial_data: &mut Vec) +where + R: AsyncRead + Unpin, +{ + extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT) + .await; +} + +async fn extend_masking_initial_window_with_timeout( + reader: &mut R, + initial_data: &mut Vec, + prefetch_timeout: Duration, +) +where + R: AsyncRead + Unpin, +{ + if !should_prefetch_mask_classifier_window(initial_data) { + return; + } + + let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len()); + if need == 0 { + return; + } + + let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW]; + if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await + && n > 0 + { + initial_data.extend_from_slice(&extra[..n]); + } +} + fn masking_outcome( reader: R, writer: W, @@ -200,6 +261,15 @@ where W: AsyncWrite + Unpin + Send + 'static, { HandshakeOutcome::NeedsMasking(Box::pin(async move { + let mut reader = reader; + let mut initial_data = initial_data; + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + mask_classifier_prefetch_timeout(&config), + ) + .await; + handle_bad_client( reader, writer, @@ -1321,6 +1391,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests; #[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"] mod masking_probe_evasion_blackhat_tests; +#[cfg(test)] +#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"] +mod masking_fragmented_classifier_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_replay_timing_security_tests.rs"] +mod masking_replay_timing_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"] +mod masking_http2_fragmented_preface_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"] +mod masking_prefetch_invariant_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"] +mod masking_prefetch_timing_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"] +mod masking_prefetch_config_runtime_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"] +mod masking_prefetch_config_pipeline_integration_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"] +mod masking_prefetch_strict_boundary_security_tests; + #[cfg(test)] #[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] mod beobachten_ttl_bounds_security_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 7e4b62c..3444a88 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -121,6 +121,20 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { hasher.finish() as usize } +fn auth_probe_scan_start_offset( + peer_ip: IpAddr, + now: Instant, + state_len: usize, + scan_limit: usize, +) -> usize { + if state_len == 0 || scan_limit == 0 { + return 0; + } + + let window = state_len.min(scan_limit); + auth_probe_eviction_offset(peer_ip, now) % window +} + fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); @@ -269,11 +283,7 @@ fn auth_probe_record_failure_with_state( let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let state_len = state.len(); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); - let start_offset = if state_len == 0 { - 0 - } else { - auth_probe_eviction_offset(peer_ip, now) % state_len - }; + let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); let mut scanned = 0usize; for entry in state.iter().skip(start_offset) { @@ -769,7 +779,7 @@ where let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); - let dec_key = sha256(&dec_key_input); + let dec_key = Zeroizing::new(sha256(&dec_key_input)); let mut dec_iv_arr = [0u8; IV_LEN]; dec_iv_arr.copy_from_slice(dec_iv_bytes); @@ -805,7 +815,7 @@ where let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); - let enc_key = sha256(&enc_key_input); + let enc_key = Zeroizing::new(sha256(&enc_key_input)); let mut enc_iv_arr = [0u8; IV_LEN]; enc_iv_arr.copy_from_slice(enc_iv_bytes); @@ -830,9 +840,9 @@ where user: user.clone(), dc_idx, proto_tag, - dec_key, + dec_key: *dec_key, dec_iv, - enc_key, + enc_key: *enc_key, enc_iv, peer, is_tls, @@ -979,6 +989,14 @@ mod saturation_poison_security_tests; #[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] mod auth_probe_hardening_adversarial_tests; +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"] +mod auth_probe_scan_budget_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"] +mod auth_probe_scan_offset_stress_tests; + #[cfg(test)] #[path = "tests/handshake_advanced_clever_tests.rs"] mod advanced_clever_tests; @@ -995,6 +1013,10 @@ mod real_bug_stress_tests; #[path = "tests/handshake_timing_manual_bench_tests.rs"] mod timing_manual_bench_tests; +#[cfg(test)] +#[path = "tests/handshake_key_material_zeroization_security_tests.rs"] +mod handshake_key_material_zeroization_security_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 3639db1..7d970c2 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -4,10 +4,17 @@ use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; -use rand::{Rng, RngExt}; -use std::net::SocketAddr; +#[cfg(unix)] +use nix::ifaddrs::getifaddrs; +use rand::rngs::StdRng; +use rand::{Rng, RngExt, SeedableRng}; +use std::net::{IpAddr, SocketAddr}; use std::str; -use std::time::Duration; +#[cfg(unix)] +use std::sync::{Mutex, OnceLock}; +#[cfg(test)] +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::{Duration, Instant as StdInstant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; #[cfg(unix)] @@ -30,13 +37,23 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5); #[cfg(test)] const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +#[cfg(unix)] +#[cfg(not(test))] +const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300); +#[cfg(all(unix, test))] +const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1); struct CopyOutcome { total: usize, ended_by_eof: bool, } -async fn copy_with_idle_timeout(reader: &mut R, writer: &mut W) -> CopyOutcome +async fn copy_with_idle_timeout( + reader: &mut R, + writer: &mut W, + byte_cap: usize, + shutdown_on_eof: bool, +) -> CopyOutcome where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, @@ -44,14 +61,31 @@ where let mut buf = [0u8; MASK_BUFFER_SIZE]; let mut total = 0usize; let mut ended_by_eof = false; + + if byte_cap == 0 { + return CopyOutcome { + total, + ended_by_eof, + }; + } + loop { - let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await; let n = match read_res { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, }; if n == 0 { ended_by_eof = true; + if shutdown_on_eof { + let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await; + } break; } total = total.saturating_add(n); @@ -61,6 +95,10 @@ where Ok(Ok(())) => {} Ok(Err(_)) | Err(_) => break, } + + if total >= byte_cap { + break; + } } CopyOutcome { total, @@ -68,6 +106,39 @@ where } } +fn is_http_probe(data: &[u8]) -> bool { + // RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ". + const HTTP_METHODS: [&[u8]; 10] = [ + b"GET ", + b"POST", + b"HEAD", + b"PUT ", + b"DELETE", + b"OPTIONS", + b"CONNECT", + b"TRACE", + b"PATCH", + b"PRI ", + ]; + + if data.is_empty() { + return false; + } + + let window = &data[..data.len().min(16)]; + for method in HTTP_METHODS { + if data.len() >= method.len() && window.starts_with(method) { + return true; + } + + if (2..=3).contains(&window.len()) && method.starts_with(window) { + return true; + } + } + + false +} + fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize { if total == 0 || floor == 0 || cap < floor { return total; @@ -125,6 +196,11 @@ async fn maybe_write_shape_padding( let mut remaining = target_total - total_sent; let mut pad_chunk = [0u8; 1024]; let deadline = Instant::now() + MASK_TIMEOUT; + // Use a Send RNG so relay futures remain spawn-safe under Tokio. + let mut rng = { + let mut seed_source = rand::rng(); + StdRng::from_rng(&mut seed_source) + }; while remaining > 0 { let now = Instant::now(); @@ -133,10 +209,7 @@ async fn maybe_write_shape_padding( } let write_len = remaining.min(pad_chunk.len()); - { - let mut rng = rand::rng(); - rng.fill_bytes(&mut pad_chunk[..write_len]); - } + rng.fill_bytes(&mut pad_chunk[..write_len]); let write_budget = deadline.saturating_duration_since(now); match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await { Ok(Ok(())) => {} @@ -167,11 +240,11 @@ where } } -async fn consume_client_data_with_timeout(reader: R) +async fn consume_client_data_with_timeout_and_cap(reader: R, byte_cap: usize) where R: AsyncRead + Unpin, { - if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)) + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap)) .await .is_err() { @@ -190,6 +263,9 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { if config.censorship.mask_timing_normalization_enabled { let floor = config.censorship.mask_timing_normalization_floor_ms; let ceiling = config.censorship.mask_timing_normalization_ceiling_ms; + if floor == 0 { + return MASK_TIMEOUT; + } if ceiling > floor { let mut rng = rand::rng(); return Duration::from_millis(rng.random_range(floor..=ceiling)); @@ -219,14 +295,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) { /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request - if data.len() > 4 - && (data.starts_with(b"GET ") - || data.starts_with(b"POST") - || data.starts_with(b"HEAD") - || data.starts_with(b"PUT ") - || data.starts_with(b"DELETE") - || data.starts_with(b"OPTIONS")) - { + if is_http_probe(data) { return "HTTP"; } @@ -248,6 +317,172 @@ fn detect_client_type(data: &[u8]) -> &'static str { "unknown" } +fn parse_mask_host_ip_literal(host: &str) -> Option { + if host.starts_with('[') && host.ends_with(']') { + return host[1..host.len() - 1].parse::().ok(); + } + host.parse::().ok() +} + +fn canonical_ip(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)), + IpAddr::V4(v4) => IpAddr::V4(v4), + } +} + +#[cfg(unix)] +fn collect_local_interface_ips() -> Vec { + #[cfg(test)] + LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed); + + let mut out = Vec::new(); + if let Ok(addrs) = getifaddrs() { + for iface in addrs { + if let Some(address) = iface.address { + if let Some(v4) = address.as_sockaddr_in() { + out.push(canonical_ip(IpAddr::V4(v4.ip()))); + } else if let Some(v6) = address.as_sockaddr_in6() { + out.push(canonical_ip(IpAddr::V6(v6.ip()))); + } + } + } + } + out +} + +fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec) -> Vec { + if refreshed.is_empty() && !previous.is_empty() { + return previous.to_vec(); + } + + refreshed +} + +#[cfg(unix)] +#[derive(Default)] +struct LocalInterfaceCache { + ips: Vec, + refreshed_at: Option, +} + +#[cfg(unix)] +static LOCAL_INTERFACE_CACHE: OnceLock> = OnceLock::new(); + +#[cfg(unix)] +fn local_interface_ips() -> Vec { + let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default())); + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if stale { + let refreshed = collect_local_interface_ips(); + guard.ips = choose_interface_snapshot(&guard.ips, refreshed); + guard.refreshed_at = Some(StdInstant::now()); + } + + guard.ips.clone() +} + +#[cfg(not(unix))] +fn local_interface_ips() -> Vec { + Vec::new() +} + +#[cfg(test)] +static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0); + +#[cfg(test)] +fn reset_local_interface_enumerations_for_tests() { + LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed); + + #[cfg(unix)] + if let Some(cache) = LOCAL_INTERFACE_CACHE.get() { + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + guard.ips.clear(); + guard.refreshed_at = None; + } +} + +#[cfg(test)] +fn local_interface_enumerations_for_tests() -> usize { + LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed) +} + +fn is_mask_target_local_listener_with_interfaces( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, + interface_ips: &[IpAddr], +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let local_ip = canonical_ip(local_addr.ip()); + let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip); + + if let Some(addr) = resolved_override { + let resolved_ip = canonical_ip(addr.ip()); + if resolved_ip == local_ip { + return true; + } + + if local_ip.is_unspecified() + && (resolved_ip.is_loopback() + || resolved_ip.is_unspecified() + || interface_ips.contains(&resolved_ip)) + { + return true; + } + } + + if let Some(mask_ip) = literal_mask_ip { + if mask_ip == local_ip { + return true; + } + + if local_ip.is_unspecified() + && (mask_ip.is_loopback() + || mask_ip.is_unspecified() + || interface_ips.contains(&mask_ip)) + { + return true; + } + } + + false +} + +fn is_mask_target_local_listener( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let interfaces = local_interface_ips(); + is_mask_target_local_listener_with_interfaces( + mask_host, + mask_port, + local_addr, + resolved_override, + &interfaces, + ) +} + +fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration { + let minutes = config.general.beobachten_minutes; + let clamped = minutes.clamp(1, 24 * 60); + Duration::from_secs(clamped.saturating_mul(60)) +} + fn build_mask_proxy_header( version: u8, peer: SocketAddr, @@ -290,13 +525,14 @@ pub async fn handle_bad_client( { let client_type = detect_client_type(initial_data); if config.general.beobachten { - let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)); + let ttl = masking_beobachten_ttl(config); beobachten.record(client_type, peer.ip(), ttl); } if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes) + .await; return; } @@ -341,6 +577,7 @@ pub async fn handle_bad_client( config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_hardening_aggressive_mode, + config.censorship.mask_relay_max_bytes, ), ) .await @@ -353,12 +590,12 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -371,6 +608,24 @@ pub async fn handle_bad_client( .as_deref() .unwrap_or(&config.censorship.tls_domain); let mask_port = config.censorship.mask_port; + let outcome_started = Instant::now(); + + // Fail closed when fallback points at our own listener endpoint. + // Self-referential masking can create recursive proxy loops under + // misconfiguration and leak distinguishable load spikes to adversaries. + let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port); + if is_mask_target_local_listener(mask_host, mask_port, local_addr, resolved_mask_addr) { + debug!( + client_type = client_type, + host = %mask_host, + port = mask_port, + local = %local_addr, + "Mask target resolves to local listener; refusing self-referential masking fallback" + ); + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + wait_mask_outcome_budget(outcome_started, config).await; + return; + } debug!( client_type = client_type, @@ -381,10 +636,9 @@ pub async fn handle_bad_client( ); // Apply runtime DNS override for mask target when configured. - let mask_addr = resolve_socket_addr(mask_host, mask_port) + let mask_addr = resolved_mask_addr .map(|addr| addr.to_string()) .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); - let outcome_started = Instant::now(); let connect_started = Instant::now(); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { @@ -413,6 +667,7 @@ pub async fn handle_bad_client( config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_hardening_aggressive_mode, + config.censorship.mask_relay_max_bytes, ), ) .await @@ -425,12 +680,12 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -449,6 +704,7 @@ async fn relay_to_mask( shape_above_cap_blur: bool, shape_above_cap_blur_max_bytes: usize, shape_hardening_aggressive_mode: bool, + mask_relay_max_bytes: usize, ) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, @@ -464,8 +720,18 @@ async fn relay_to_mask( } let (upstream_copy, downstream_copy) = tokio::join!( - async { copy_with_idle_timeout(&mut reader, &mut mask_write).await }, - async { copy_with_idle_timeout(&mut mask_read, &mut writer).await } + async { + copy_with_idle_timeout( + &mut reader, + &mut mask_write, + mask_relay_max_bytes, + !shape_hardening_enabled, + ) + .await + }, + async { + copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await + } ); let total_sent = initial_data.len().saturating_add(upstream_copy.total); @@ -491,13 +757,30 @@ async fn relay_to_mask( let _ = writer.shutdown().await; } -/// Just consume all data from client without responding -async fn consume_client_data(mut reader: R) { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - while let Ok(n) = reader.read(&mut buf).await { +/// Just consume all data from client without responding. +async fn consume_client_data(mut reader: R, byte_cap: usize) { + if byte_cap == 0 { + return; + } + + // Keep drain path fail-closed under slow-loris stalls. + let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut total = 0usize; + + loop { + let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { break; } + + total = total.saturating_add(n); + if total >= byte_cap { + break; + } } } @@ -548,3 +831,63 @@ mod masking_aggressive_mode_security_tests; #[cfg(test)] #[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] mod masking_timing_sidechannel_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/masking_self_target_loop_security_tests.rs"] +mod masking_self_target_loop_security_tests; + +#[cfg(test)] +#[path = "tests/masking_classification_completeness_security_tests.rs"] +mod masking_classification_completeness_security_tests; + +#[cfg(test)] +#[path = "tests/masking_relay_guardrails_security_tests.rs"] +mod masking_relay_guardrails_security_tests; + +#[cfg(test)] +#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"] +mod masking_connect_failure_close_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/masking_additional_hardening_security_tests.rs"] +mod masking_additional_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_idle_timeout_security_tests.rs"] +mod masking_consume_idle_timeout_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_probe_classification_security_tests.rs"] +mod masking_http2_probe_classification_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http_probe_boundary_security_tests.rs"] +mod masking_http_probe_boundary_security_tests; + +#[cfg(test)] +#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"] +mod masking_rng_hoist_perf_regression_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_preface_integration_security_tests.rs"] +mod masking_http2_preface_integration_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_stress_adversarial_tests.rs"] +mod masking_consume_stress_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_security_tests.rs"] +mod masking_interface_cache_security_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"] +mod masking_interface_cache_defense_in_depth_security_tests; + +#[cfg(test)] +#[path = "tests/masking_padding_timeout_adversarial_tests.rs"] +mod masking_padding_timeout_adversarial_tests; + +#[cfg(all(test, feature = "redteam_offline_expected_fail"))] +#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"] +mod masking_offline_target_redteam_expected_fail_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 8b8d3dc..0d2a748 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -39,6 +39,8 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); +const TINY_FRAME_DEBT_PER_TINY: u32 = 8; +const TINY_FRAME_DEBT_LIMIT: u32 = 512; #[cfg(test)] const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); #[cfg(not(test))] @@ -94,10 +96,23 @@ fn relay_idle_candidate_registry() -> &'static Mutex RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) } +fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> { + let registry = relay_idle_candidate_registry(); + match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + // Fail closed after panic while holding registry lock: drop all + // candidates and pressure cursors to avoid stale cross-session state. + *guard = RelayIdleCandidateRegistry::default(); + registry.clear_poison(); + guard + } + } +} + fn mark_relay_idle_candidate(conn_id: u64) -> bool { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return false; - }; + let mut guard = relay_idle_candidate_registry_lock(); if guard.by_conn_id.contains_key(&conn_id) { return false; @@ -116,9 +131,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool { } fn clear_relay_idle_candidate(conn_id: u64) { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return; - }; + let mut guard = relay_idle_candidate_registry_lock(); if let Some(meta) = guard.by_conn_id.remove(&conn_id) { guard.ordered.remove(&(meta.mark_order_seq, conn_id)); @@ -127,23 +140,17 @@ fn clear_relay_idle_candidate(conn_id: u64) { #[cfg(test)] fn oldest_relay_idle_candidate() -> Option { - let Ok(guard) = relay_idle_candidate_registry().lock() else { - return None; - }; + let guard = relay_idle_candidate_registry_lock(); guard.ordered.iter().next().map(|(_, conn_id)| *conn_id) } fn note_relay_pressure_event() { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return; - }; + let mut guard = relay_idle_candidate_registry_lock(); guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1); } fn relay_pressure_event_seq() -> u64 { - let Ok(guard) = relay_idle_candidate_registry().lock() else { - return 0; - }; + let guard = relay_idle_candidate_registry_lock(); guard.pressure_event_seq } @@ -152,9 +159,7 @@ fn maybe_evict_idle_candidate_on_pressure( seen_pressure_seq: &mut u64, stats: &Stats, ) -> bool { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return false; - }; + let mut guard = relay_idle_candidate_registry_lock(); let latest_pressure_seq = guard.pressure_event_seq; if latest_pressure_seq == *seen_pressure_seq { @@ -199,13 +204,9 @@ fn maybe_evict_idle_candidate_on_pressure( #[cfg(test)] fn clear_relay_idle_pressure_state_for_testing() { - if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get() - && let Ok(mut guard) = registry.lock() - { - guard.by_conn_id.clear(); - guard.ordered.clear(); - guard.pressure_event_seq = 0; - guard.pressure_consumed_seq = 0; + if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() { + let mut guard = relay_idle_candidate_registry_lock(); + *guard = RelayIdleCandidateRegistry::default(); } RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed); } @@ -259,6 +260,7 @@ impl RelayClientIdlePolicy { struct RelayClientIdleState { last_client_frame_at: Instant, soft_idle_marked: bool, + tiny_frame_debt: u32, } impl RelayClientIdleState { @@ -266,6 +268,7 @@ impl RelayClientIdleState { Self { last_client_frame_at: now, soft_idle_marked: false, + tiny_frame_debt: 0, } } @@ -551,15 +554,6 @@ fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { limit.saturating_add(overshoot) } -fn quota_exceeded_for_user_soft( - stats: &Stats, - user: &str, - quota_limit: Option, - overshoot: u64, -) -> bool { - quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot)) -} - fn quota_would_be_exceeded_for_user_soft( stats: &Stats, user: &str, @@ -617,6 +611,16 @@ fn observe_me_d2c_flush_event( } } +fn rollback_me2c_quota_reservation( + stats: &Stats, + user: &str, + bytes_me2c: &AtomicU64, + reserved_bytes: u64, +) { + stats.sub_user_octets_to(user, reserved_bytes); + bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed); +} + #[cfg(test)] fn quota_user_lock_test_guard() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); @@ -630,6 +634,19 @@ fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { .unwrap_or_else(|poisoned| poisoned.into_inner()) } +#[cfg(test)] +fn relay_idle_pressure_test_guard() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> { + relay_idle_pressure_test_guard() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + fn quota_overflow_user_lock(user: &str) -> Arc> { let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { (0..QUOTA_OVERFLOW_LOCK_STRIPES) @@ -665,6 +682,11 @@ fn quota_user_lock(user: &str) -> Arc> { } } +#[cfg(test)] +pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) +} + async fn enqueue_c2me_command( tx: &mpsc::Sender, cmd: C2MeCommand, @@ -710,6 +732,8 @@ where { let user = success.user.clone(); let quota_limit = config.access.user_data_quota.get(&user).copied(); + let cross_mode_quota_lock = + quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); let peer = success.peer; let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); @@ -1221,6 +1245,17 @@ where if let Some(limit) = quota_limit { let quota_lock = quota_user_lock(&user); let _quota_guard = quota_lock.lock().await; + let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else { + main_result = Err(ProxyError::Proxy( + "cross-mode quota lock missing for quota-limited session" + .to_string(), + )); + break; + }; + let _cross_mode_quota_guard = match cross_mode_lock.lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + }; stats.add_user_octets_from(&user, payload.len() as u64); if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { main_result = Err(ProxyError::DataQuotaExceeded { @@ -1320,6 +1355,8 @@ async fn read_client_payload_with_idle_policy( where R: AsyncRead + Unpin + Send + 'static, { + const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4; + async fn read_exact_with_policy( client_reader: &mut CryptoReader, buf: &mut [u8], @@ -1458,6 +1495,7 @@ where Ok(()) } + let mut consecutive_zero_len_frames = 0u32; loop { let (len, quickack, raw_len_bytes) = match proto_tag { ProtoTag::Abridged => { @@ -1538,6 +1576,27 @@ where }; if len == 0 { + idle_state.tiny_frame_debt = idle_state + .tiny_frame_debt + .saturating_add(TINY_FRAME_DEBT_PER_TINY); + if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT { + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy(format!( + "Tiny frame overhead limit exceeded: debt={}, conn_id={}", + idle_state.tiny_frame_debt, forensics.conn_id + ))); + } + + if !idle_policy.enabled { + consecutive_zero_len_frames = + consecutive_zero_len_frames.saturating_add(1); + if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES { + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy( + "Excessive zero-length abridged frames".to_string(), + )); + } + } continue; } if len < 4 && proto_tag != ProtoTag::Abridged { @@ -1606,6 +1665,7 @@ where } *frame_counter += 1; idle_state.on_client_frame(Instant::now()); + idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1); clear_relay_idle_candidate(forensics.conn_id); return Ok(Some((payload, quickack))); } @@ -1707,39 +1767,57 @@ where trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } let data_len = data.len() as u64; - if quota_would_be_exceeded_for_user_soft( - stats, - user, - quota_limit, - data_len, - quota_soft_overshoot_bytes, - ) { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); - } + if let Some(limit) = quota_limit { + let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); + if quota_would_be_exceeded_for_user(stats, user, Some(soft_limit), data_len) { + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } - let write_mode = - write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await?; - stats.increment_me_d2c_write_mode(write_mode); + // Reserve quota before awaiting network I/O to avoid same-user HoL stalls. + // If reservation loses a race or write fails, we roll back immediately. + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + stats.add_user_octets_to(user, data_len); - bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); - stats.add_user_octets_to(user, data.len() as u64); - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data.len() as u64); + if stats.get_user_total_octets(user) > soft_limit { + rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } - if quota_exceeded_for_user_soft( - stats, - user, - quota_limit, - quota_soft_overshoot_bytes, - ) { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); + let write_mode = + match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await + { + Ok(mode) => mode, + Err(err) => { + rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); + return Err(err); + } + }; + + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); + + // Do not fail immediately on exact boundary after a successful write. + // Returning an error here can bypass batch flush in the caller and risk + // dropping buffered ciphertext from CryptoWriter. The next frame is + // rejected by the pre-check at function entry. + } else { + let write_mode = + write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await?; + + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + stats.add_user_octets_to(user, data_len); + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); } Ok(MeWriterResponseOutcome::Continue { @@ -1978,3 +2056,31 @@ mod length_cast_hardening_security_tests; #[cfg(test)] #[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] mod blackhat_campaign_integration_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_hol_quota_security_tests.rs"] +mod hol_quota_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"] +mod quota_reservation_adversarial_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"] +mod middle_relay_idle_registry_poison_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"] +mod middle_relay_zero_length_frame_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"] +mod middle_relay_tiny_frame_debt_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"] +mod middle_relay_tiny_frame_debt_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] +mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index eebc188..519f1b3 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -64,6 +64,7 @@ pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; +pub mod quota_lock_registry; pub mod relay; pub mod route_mode; pub mod session_eviction; diff --git a/src/proxy/quota_lock_registry.rs b/src/proxy/quota_lock_registry.rs new file mode 100644 index 0000000..ac64a57 --- /dev/null +++ b/src/proxy/quota_lock_registry.rs @@ -0,0 +1,53 @@ +use dashmap::DashMap; +use std::sync::{Arc, Mutex, OnceLock}; + +#[cfg(test)] +const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64; +#[cfg(not(test))] +const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096; +#[cfg(test)] +const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; +#[cfg(not(test))] +const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; + +static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); +static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); + +fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { + let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { + (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES) + .map(|_| Arc::new(Mutex::new(()))) + .collect() + }); + + let hash = crc32fast::hash(user.as_bytes()) as usize; + Arc::clone(&stripes[hash % stripes.len()]) +} + +pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + if let Some(existing) = locks.get(user) { + return Arc::clone(existing.value()); + } + + if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } + + if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { + return cross_mode_quota_overflow_user_lock(user); + } + + let created = Arc::new(Mutex::new(())); + match locks.entry(user.to_string()) { + dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(Arc::clone(&created)); + created + } + } +} + +#[cfg(test)] +#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"] +mod quota_lock_registry_cross_mode_adversarial_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 2431ff4..dcacedd 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -62,7 +62,7 @@ use std::sync::{Arc, Mutex, OnceLock}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; -use tokio::time::Instant; +use tokio::time::{Instant, Sleep}; use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -209,12 +209,16 @@ struct StatsIo { counters: Arc, stats: Arc, user: String, + quota_lock: Option>>, + cross_mode_quota_lock: Option>>, quota_limit: Option, quota_exceeded: Arc, quota_read_wake_scheduled: bool, quota_write_wake_scheduled: bool, - quota_read_retry_active: Arc, - quota_write_retry_active: Arc, + quota_read_retry_sleep: Option>>, + quota_write_retry_sleep: Option>>, + quota_read_retry_attempt: u8, + quota_write_retry_attempt: u8, epoch: Instant, } @@ -230,30 +234,29 @@ impl StatsIo { ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); + let quota_lock = quota_limit.map(|_| quota_user_lock(&user)); + let cross_mode_quota_lock = quota_limit + .map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); Self { inner, counters, stats, user, + quota_lock, + cross_mode_quota_lock, quota_limit, quota_exceeded, quota_read_wake_scheduled: false, quota_write_wake_scheduled: false, - quota_read_retry_active: Arc::new(AtomicBool::new(false)), - quota_write_retry_active: Arc::new(AtomicBool::new(false)), + quota_read_retry_sleep: None, + quota_write_retry_sleep: None, + quota_read_retry_attempt: 0, + quota_write_retry_attempt: 0, epoch, } } } -impl Drop for StatsIo { - fn drop(&mut self) { - self.quota_read_retry_active.store(false, Ordering::Relaxed); - self.quota_write_retry_active - .store(false, Ordering::Relaxed); - } -} - #[derive(Debug)] struct QuotaIoSentinel; @@ -281,20 +284,52 @@ fn is_quota_io_error(err: &io::Error) -> bool { const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1); #[cfg(not(test))] const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2); +#[cfg(test)] +const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16); +#[cfg(not(test))] +const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64); -fn spawn_quota_retry_waker(retry_active: Arc, waker: std::task::Waker) { - tokio::task::spawn(async move { - loop { - if !retry_active.load(Ordering::Relaxed) { - break; - } - tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await; - if !retry_active.load(Ordering::Relaxed) { - break; - } - waker.wake_by_ref(); - } - }); +#[inline] +fn quota_contention_retry_delay(retry_attempt: u8) -> Duration { + let shift = u32::from(retry_attempt.min(5)); + let multiplier = 1_u32 << shift; + QUOTA_CONTENTION_RETRY_INTERVAL + .saturating_mul(multiplier) + .min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL) +} + +#[inline] +fn reset_quota_retry_scheduler( + sleep_slot: &mut Option>>, + wake_scheduled: &mut bool, + retry_attempt: &mut u8, +) { + *wake_scheduled = false; + *sleep_slot = None; + *retry_attempt = 0; +} + +fn poll_quota_retry_sleep( + sleep_slot: &mut Option>>, + wake_scheduled: &mut bool, + retry_attempt: &mut u8, + cx: &mut Context<'_>, +) { + if !*wake_scheduled { + *wake_scheduled = true; + *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay( + *retry_attempt, + )))); + } + + if let Some(sleep) = sleep_slot.as_mut() + && sleep.as_mut().poll(cx).is_ready() + { + *sleep_slot = None; + *wake_scheduled = false; + *retry_attempt = retry_attempt.saturating_add(1); + cx.waker().wake_by_ref(); + } } static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); @@ -357,6 +392,11 @@ fn quota_user_lock(user: &str) -> Arc> { } } +#[cfg(test)] +pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { + crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) +} + impl AsyncRead for StatsIo { fn poll_read( self: Pin<&mut Self>, @@ -368,26 +408,47 @@ impl AsyncRead for StatsIo { return Poll::Ready(Err(quota_io_error())); } - let quota_lock = this - .quota_limit - .is_some() - .then(|| quota_user_lock(&this.user)); - let _quota_guard = if let Some(lock) = quota_lock.as_ref() { + let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { match lock.try_lock() { Ok(guard) => { - this.quota_read_wake_scheduled = false; - this.quota_read_retry_active.store(false, Ordering::Relaxed); + reset_quota_retry_scheduler( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + ); Some(guard) } Err(_) => { - if !this.quota_read_wake_scheduled { - this.quota_read_wake_scheduled = true; - this.quota_read_retry_active.store(true, Ordering::Relaxed); - spawn_quota_retry_waker( - Arc::clone(&this.quota_read_retry_active), - cx.waker().clone(), - ); - } + poll_quota_retry_sleep( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + cx, + ); + return Poll::Pending; + } + } + } else { + None + }; + + let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { + match lock.try_lock() { + Ok(guard) => { + reset_quota_retry_scheduler( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + ); + Some(guard) + } + Err(_) => { + poll_quota_retry_sleep( + &mut this.quota_read_retry_sleep, + &mut this.quota_read_wake_scheduled, + &mut this.quota_read_retry_attempt, + cx, + ); return Poll::Pending; } } @@ -460,27 +521,47 @@ impl AsyncWrite for StatsIo { return Poll::Ready(Err(quota_io_error())); } - let quota_lock = this - .quota_limit - .is_some() - .then(|| quota_user_lock(&this.user)); - let _quota_guard = if let Some(lock) = quota_lock.as_ref() { + let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { match lock.try_lock() { Ok(guard) => { - this.quota_write_wake_scheduled = false; - this.quota_write_retry_active - .store(false, Ordering::Relaxed); + reset_quota_retry_scheduler( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + ); Some(guard) } Err(_) => { - if !this.quota_write_wake_scheduled { - this.quota_write_wake_scheduled = true; - this.quota_write_retry_active.store(true, Ordering::Relaxed); - spawn_quota_retry_waker( - Arc::clone(&this.quota_write_retry_active), - cx.waker().clone(), - ); - } + poll_quota_retry_sleep( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + cx, + ); + return Poll::Pending; + } + } + } else { + None + }; + + let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { + match lock.try_lock() { + Ok(guard) => { + reset_quota_retry_scheduler( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + ); + Some(guard) + } + Err(_) => { + poll_quota_retry_sleep( + &mut this.quota_write_retry_sleep, + &mut this.quota_write_wake_scheduled, + &mut this.quota_write_retry_attempt, + cx, + ); return Poll::Pending; } } @@ -791,3 +872,27 @@ mod relay_quota_waker_storm_adversarial_tests; #[cfg(test)] #[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] mod relay_quota_wake_liveness_regression_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_identity_security_tests.rs"] +mod relay_quota_lock_identity_security_tests; + +#[cfg(test)] +#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"] +mod relay_cross_mode_quota_lock_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"] +mod relay_quota_retry_scheduler_tdd_tests; + +#[cfg(test)] +#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"] +mod relay_cross_mode_quota_fairness_tdd_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_backoff_security_tests.rs"] +mod relay_quota_retry_backoff_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"] +mod relay_quota_retry_backoff_benchmark_security_tests; diff --git a/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs new file mode 100644 index 0000000..d7ac4ef --- /dev/null +++ b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs @@ -0,0 +1,100 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +#[tokio::test] +async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"CONNE").await.unwrap(); + client_side + .write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"), + "mask backend must receive the full fragmented CONNECT probe" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.251-1")); +} diff --git a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs new file mode 100644 index 0000000..fcf51ab --- /dev/null +++ b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs @@ -0,0 +1,129 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + let first = split_at.min(preface.len()); + client_side.write_all(&preface[..first]).await.unwrap(); + if first < preface.len() { + sleep(Duration::from_millis(delay_ms)).await; + client_side.write_all(&preface[first..]).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(&preface), + "mask backend must receive an intact HTTP/2 preface prefix" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains(&format!("{}-1", peer.ip()))); +} + +#[tokio::test] +async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() { + let cases = [ + (2usize, 0u64), + (3, 0), + (4, 0), + (2, 7), + (3, 7), + (8, 1), + ]; + + for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() { + let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i) + .parse() + .unwrap(); + run_http2_fragment_case(split_at, delay_ms, peer).await; + } +} + +#[tokio::test] +async fn http2_preface_splitpoint_light_fuzz_classifies_http() { + for split_at in 2usize..=12 { + let delay_ms = if split_at % 3 == 0 { 7 } else { 1 }; + let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at) + .parse() + .unwrap(); + run_http2_fragment_case(split_at, delay_ms, peer).await; + } +} diff --git a/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs new file mode 100644 index 0000000..e64dc03 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs @@ -0,0 +1,150 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_pipeline_prefetch_case( + prefetch_timeout_ms: u64, + delayed_tail_ms: u64, + peer: SocketAddr, +) -> (Vec, String) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms; + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"C").await.unwrap(); + sleep(Duration::from_millis(delayed_tail_ms)).await; + + client_side + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + (forwarded, snapshot) +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() { + let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must still receive full payload bytes in-order" + ); + assert!( + snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"), + "unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}" + ); +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() { + let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must receive full CONNECT payload" + ); + assert!( + snapshot.contains("[HTTP]"), + "20ms budget should recover delayed fragmented prefix and classify as HTTP" + ); +} + +#[tokio::test] +async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() { + let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap(); + let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap(); + let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap(); + + let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await; + let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await; + let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await; + + assert!( + snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"), + "unexpected 5ms snapshot: {snap5}" + ); + assert!( + snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"), + "unexpected 20ms snapshot: {snap20}" + ); + assert!(snap50.contains("[HTTP]")); +} diff --git a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs new file mode 100644 index 0000000..cdf2136 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs @@ -0,0 +1,82 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, sleep}; + +#[test] +fn prefetch_timeout_budget_reads_from_config() { + let mut cfg = ProxyConfig::default(); + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(5), + "default prefetch timeout budget must remain 5ms" + ); + + cfg.censorship.mask_classifier_prefetch_timeout_ms = 20; + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(20), + "runtime prefetch timeout budget must follow configured value" + ); +} + +#[tokio::test] +async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(15)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer.shutdown().await.expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(20), + ) + .await; + + writer_task + .await + .expect("writer task must not panic in runtime timeout test"); + + assert!( + initial_data.starts_with(b"CONNECT"), + "20ms configured prefetch budget should recover 15ms delayed CONNECT tail" + ); +} + +#[tokio::test] +async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(15)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer.shutdown().await.expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(5), + ) + .await; + + writer_task + .await + .expect("writer task must not panic in runtime timeout test"); + + assert!( + !initial_data.starts_with(b"CONNECT"), + "5ms configured prefetch budget should miss 15ms delayed CONNECT tail" + ); +} diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs new file mode 100644 index 0000000..2e03ce9 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs @@ -0,0 +1,261 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[test] +fn empty_initial_data_prefetch_gate_is_fail_closed() { + assert!( + !should_prefetch_mask_classifier_window(&[]), + "empty initial_data must not trigger classifier prefetch" + ); +} + +#[tokio::test] +async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() { + let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec(); + let (mut reader, mut writer) = duplex(1024); + + writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + + assert!( + initial_data.is_empty(), + "empty initial_data must remain empty after prefetch stage" + ); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!( + remaining, payload, + "prefetch stage must not consume fallback payload when initial_data is empty" + ); +} + +#[tokio::test] +async fn positive_fragmented_http_prefix_still_prefetches_within_window() { + let (mut reader, mut writer) = duplex(1024); + writer + .write_all(b"NECT example.org:443 HTTP/1.1\r\n") + .await + .unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = b"CON".to_vec(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + + assert!( + initial_data.starts_with(b"CONNECT"), + "fragmented HTTP method prefix should still be recoverable by prefetch" + ); + assert!( + initial_data.len() <= 16, + "prefetch window must remain bounded" + ); +} + +#[tokio::test] +async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() { + let mut seed = 0xD15C_A11E_2026_0322u64; + + for _ in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = ((seed & 0x3f) as usize).saturating_add(1); + let mut payload = vec![0u8; len]; + for (idx, byte) in payload.iter_mut().enumerate() { + *byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17); + } + + let (mut reader, mut writer) = duplex(1024); + writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + assert!(initial_data.is_empty()); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!(remaining, payload); + } +} + +#[tokio::test] +async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xD3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B); + let mut invalid_payload = vec![0u8; HANDSHAKE_LEN]; + invalid_payload[0] = 0xFF; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload); + let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant"); + let expected = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + + let mut one = [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!( + n, 0, + "fallback stream must not append synthetic bytes on empty initial_data path" + ); + }); + + let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.245:56145".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs new file mode 100644 index 0000000..9ece258 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs @@ -0,0 +1,70 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, advance, sleep}; + +async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(tail_delay_ms)).await; + let _ = writer.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n").await; + let _ = writer.shutdown().await; + }); + + let mut initial_data = b"C".to_vec(); + let mut prefetch_task = tokio::spawn(async move { + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(prefetch_ms), + ) + .await; + initial_data + }); + + tokio::task::yield_now().await; + + if tail_delay_ms > 0 { + advance(Duration::from_millis(tail_delay_ms)).await; + tokio::task::yield_now().await; + } + + if prefetch_ms > tail_delay_ms { + advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await; + tokio::task::yield_now().await; + } + + let result = prefetch_task.await.expect("prefetch task must not panic"); + writer_task.await.expect("writer task must not panic"); + result +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_5ms_misses_15ms_tail() { + let got = run_strict_prefetch_case(5, 15).await; + assert_eq!(got, b"C".to_vec()); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_20ms_recovers_15ms_tail() { + let got = run_strict_prefetch_case(20, 15).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_50ms_recovers_35ms_tail() { + let got = run_strict_prefetch_case(50, 35).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_equal_budget_and_delay_recovers_tail() { + let got = run_strict_prefetch_case(20, 20).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_one_ms_after_budget_misses_tail() { + let got = run_strict_prefetch_case(20, 21).await; + assert_eq!(got, b"C".to_vec()); +} diff --git a/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs new file mode 100644 index 0000000..3f4ab17 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs @@ -0,0 +1,95 @@ +use super::*; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, sleep, timeout}; + +async fn extend_masking_initial_window_with_budget( + reader: &mut R, + initial_data: &mut Vec, + prefetch_timeout: Duration, +) where + R: AsyncRead + Unpin, +{ + if !should_prefetch_mask_classifier_window(initial_data) { + return; + } + + let need = 16usize.saturating_sub(initial_data.len()); + if need == 0 { + return; + } + + let mut extra = [0u8; 16]; + if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await + && n > 0 + { + initial_data.extend_from_slice(&extra[..n]); + } +} + +async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(delayed_tail_ms)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer.shutdown().await.expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_budget( + &mut reader, + &mut initial_data, + Duration::from_millis(prefetch_budget_ms), + ) + .await; + + writer_task + .await + .expect("writer task must not panic during matrix case"); + + initial_data.starts_with(b"CONNECT") +} + +#[tokio::test] +async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() { + let cases = [ + // (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50]) + (2u64, [true, true, true]), + (15u64, [false, true, true]), + (35u64, [false, false, true]), + ]; + + for (tail_delay_ms, expected) in cases { + let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await; + let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await; + let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await; + + assert_eq!( + got_5, expected[0], + "5ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + assert_eq!( + got_20, expected[1], + "20ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + assert_eq!( + got_50, expected[2], + "50ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + } +} + +#[tokio::test] +async fn control_current_runtime_prefetch_budget_is_5ms() { + assert_eq!( + MASK_CLASSIFIER_PREFETCH_TIMEOUT, + Duration::from_millis(5), + "matrix assumptions require current runtime prefetch budget to stay at 5ms" + ); +} diff --git a/src/proxy/tests/client_masking_replay_timing_security_tests.rs b/src/proxy/tests/client_masking_replay_timing_security_tests.rs new file mode 100644 index 0000000..225ce50 --- /dev/null +++ b/src/proxy/tests/client_masking_replay_timing_security_tests.rs @@ -0,0 +1,161 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +async fn run_replay_candidate_session( + replay_checker: Arc, + hello: &[u8], + peer: SocketAddr, + drive_mtproto_fail: bool, +) -> Duration { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.censorship.mask_timing_normalization_enabled = false; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "abababababababababababababababab".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(65536); + let started = Instant::now(); + + let task = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + replay_checker, + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten, + false, + )); + + client_side.write_all(hello).await.unwrap(); + + if drive_mtproto_fail { + let mut server_hello_head = [0u8; 5]; + client_side.read_exact(&mut server_hello_head).await.unwrap(); + assert_eq!(server_hello_head[0], 0x16); + let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize; + let mut body = vec![0u8; body_len]; + client_side.read_exact(&mut body).await.unwrap(); + + let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN); + invalid_mtproto_record.push(0x17); + invalid_mtproto_record.extend_from_slice(&TLS_VERSION); + invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes()); + invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]); + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n") + .await + .unwrap(); + } + + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + started.elapsed() +} + +#[tokio::test] +async fn replay_reject_still_honors_masking_timing_budget() { + let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60))); + let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51); + + let seed_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.201:58001".parse().unwrap(), + true, + ) + .await; + + assert!( + seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250), + "seed replay-candidate run must honor masking timing budget without unbounded delay" + ); + + let replay_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.202:58002".parse().unwrap(), + false, + ) + .await; + + assert!( + replay_elapsed >= Duration::from_millis(40) + && replay_elapsed < Duration::from_millis(250), + "replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay" + ); +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs new file mode 100644 index 0000000..c5e57d7 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -0,0 +1,90 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn edge_zero_state_len_yields_zero_start_offset() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44)); + let now = Instant::now(); + + assert_eq!( + auth_probe_scan_start_offset(ip, now, 0, 16), + 0, + "empty map must not produce non-zero scan offset" + ); +} + +#[test] +fn adversarial_large_state_must_bound_start_offset_to_scan_budget() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let scan_limit = 16usize; + let state_len = 65_536usize; + + for i in 0..2048u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 203, + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + let now = base + Duration::from_micros(i as u64); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + assert!( + start < scan_limit, + "start offset must stay within scan window; start={start}, limit={scan_limit}" + ); + } +} + +#[test] +fn positive_state_smaller_than_scan_limit_caps_to_state_len() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17)); + let now = Instant::now(); + + for state_len in 1..32usize { + let start = auth_probe_scan_start_offset(ip, now, state_len, 64); + assert!( + start < state_len, + "start offset must never exceed state length when scan limit is larger" + ); + } +} + +#[test] +fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() { + let _guard = auth_probe_test_guard(); + let mut seed = 0x5A41_5356_4C32_3236u64; + let base = Instant::now(); + + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1); + let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0xffff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + let effective_window = state_len.min(scan_limit); + + assert!( + start < effective_window, + "scan offset must stay inside effective window" + ); + } +} \ No newline at end of file diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs new file mode 100644 index 0000000..cdaf498 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs @@ -0,0 +1,113 @@ +use super::*; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn positive_same_ip_moving_time_yields_diverse_scan_offsets() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)); + let base = Instant::now(); + let mut uniq = HashSet::new(); + + for i in 0..512u64 { + let now = base + Duration::from_nanos(i); + let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16); + uniq.insert(offset); + } + + assert_eq!( + uniq.len(), + 16, + "offset randomization must cover the entire scan window over 512 samples" + ); +} + +#[test] +fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() { + let _guard = auth_probe_test_guard(); + let now = Instant::now(); + let mut uniq = HashSet::new(); + + for i in 0..1024u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + (i >> 16) as u8, + (i >> 8) as u8, + i as u8, + (255 - (i as u8)), + )); + uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16)); + } + + assert_eq!( + uniq.len(), + 16, + "scan offset distribution collapsed unexpectedly across peer set" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let start = Instant::now(); + let mut workers = Vec::new(); + for worker in 0..8u8 { + workers.push(tokio::spawn(async move { + for i in 0..8192u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + worker, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64)); + } + })); + } + + for worker in workers { + worker.await.expect("saturation worker must not panic"); + } + + assert!( + auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "state must remain hard-capped under parallel saturation churn" + ); + + let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1)); + let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1)); +} + +#[test] +fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xA55A_1357_2468_9BDFu64; + let base = Instant::now(); + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x1fff); + + let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + assert!(offset < state_len.min(scan_limit)); + } +} \ No newline at end of file diff --git a/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs new file mode 100644 index 0000000..7176b1c --- /dev/null +++ b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs @@ -0,0 +1,42 @@ +use super::*; + +fn handshake_source() -> &'static str { + include_str!("../handshake.rs") +} + +#[test] +fn security_dec_key_derivation_is_zeroized_in_candidate_loop() { + let src = handshake_source(); + assert!( + src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"), + "candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths" + ); +} + +#[test] +fn security_enc_key_derivation_is_zeroized_in_candidate_loop() { + let src = handshake_source(); + assert!( + src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"), + "candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths" + ); +} + +#[test] +fn security_aes_ctr_initialization_uses_zeroizing_references() { + let src = handshake_source(); + assert!( + src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);") + && src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"), + "AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables" + ); +} + +#[test] +fn security_success_struct_copies_out_of_zeroizing_wrappers() { + let src = handshake_source(); + assert!( + src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"), + "HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized" + ); +} diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs index 3e860e8..84c904f 100644 --- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u ]; let mut meaningful_improvement_seen = false; - let mut baseline_sum = 0.0f64; - let mut hardened_sum = 0.0f64; - let mut pair_count = 0usize; + let mut informative_baseline_sum = 0.0f64; + let mut informative_hardened_sum = 0.0f64; + let mut informative_pair_count = 0usize; + let mut low_info_baseline_sum = 0.0f64; + let mut low_info_hardened_sum = 0.0f64; + let mut low_info_pair_count = 0usize; let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64; let tolerated_pair_regression = acc_quant_step + 0.03; @@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u hardened_acc <= baseline_acc + tolerated_pair_regression, "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}" ); + informative_baseline_sum += baseline_acc; + informative_hardened_sum += hardened_acc; + informative_pair_count += 1; + } else { + // Low-information pairs (near-random baseline separability) are expected + // to exhibit quantized jitter at low sample counts; do not fold them into + // strict average-regression checks used for informative side-channel signal. + low_info_baseline_sum += baseline_acc; + low_info_hardened_sum += hardened_acc; + low_info_pair_count += 1; } println!( @@ -532,19 +545,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u meaningful_improvement_seen = true; } - baseline_sum += baseline_acc; - hardened_sum += hardened_acc; - pair_count += 1; } - let baseline_avg = baseline_sum / pair_count as f64; - let hardened_avg = hardened_sum / pair_count as f64; + assert!( + informative_pair_count > 0, + "expected at least one informative pair for timing-separability guard" + ); + + let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64; + let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64; assert!( - hardened_avg <= baseline_avg + 0.10, - "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}" + informative_hardened_avg <= informative_baseline_avg + 0.10, + "normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}" ); + if low_info_pair_count > 0 { + let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64; + let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64; + assert!( + low_info_hardened_avg <= low_info_baseline_avg + 0.40, + "normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}" + ); + } + // Optional signal only: do not require improvement on every run because // noisy CI schedulers can flatten pairwise differences at low sample counts. let _ = meaningful_improvement_seen; diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs new file mode 100644 index 0000000..29170c1 --- /dev/null +++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs @@ -0,0 +1,122 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; +use tokio::time::{Duration, timeout}; + +struct EndlessReader { + produced: Arc, +} + +impl AsyncRead for EndlessReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let len = buf.remaining().max(1); + let fill = vec![0xAA; len]; + buf.put_slice(&fill); + self.produced.fetch_add(len, Ordering::Relaxed); + Poll::Ready(Ok(())) + } +} + +#[test] +fn loop_guard_unspecified_bind_uses_interface_inventory() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap(); + let interfaces = vec!["192.168.44.10".parse().unwrap()]; + + assert!(is_mask_target_local_listener_with_interfaces( + "mask.example", + 443, + local, + Some(resolved), + &interfaces, + )); +} + +#[tokio::test] +async fn consume_client_data_stops_after_byte_cap_without_eof() { + let produced = Arc::new(AtomicUsize::new(0)); + let reader = EndlessReader { + produced: Arc::clone(&produced), + }; + let cap = 10_000usize; + + consume_client_data(reader, cap).await; + + let total = produced.load(Ordering::Relaxed); + assert!( + total >= cap, + "consume path must read at least up to cap before stopping" + ); + assert!( + total <= cap + 8192, + "consume path must stop within one read chunk above cap" + ); +} + +#[test] +fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 0; + + let ttl = masking_beobachten_ttl(&config); + assert_eq!(ttl, std::time::Duration::from_secs(60)); +} + +#[test] +fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() { + let mut config = ProxyConfig::default(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 0; + config.censorship.mask_timing_normalization_ceiling_ms = 0; + + let budget = mask_outcome_target_budget(&config); + assert_eq!(budget, MASK_TIMEOUT); +} + +#[tokio::test] +async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap(); + let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "loop guard must fail closed before any recursive PROXY protocol amplification" + ); +} diff --git a/src/proxy/tests/masking_classification_completeness_security_tests.rs b/src/proxy/tests/masking_classification_completeness_security_tests.rs new file mode 100644 index 0000000..35bf87b --- /dev/null +++ b/src/proxy/tests/masking_classification_completeness_security_tests.rs @@ -0,0 +1,16 @@ +use super::*; + +#[test] +fn detect_client_type_recognizes_extended_http_probe_verbs() { + assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP"); +} + +#[test] +fn detect_client_type_recognizes_fragmented_http_method_prefixes() { + assert_eq!(detect_client_type(b"CO"), "HTTP"); + assert_eq!(detect_client_type(b"CON"), "HTTP"); + assert_eq!(detect_client_type(b"TR"), "HTTP"); + assert_eq!(detect_client_type(b"PAT"), "HTTP"); +} diff --git a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs new file mode 100644 index 0000000..614af9b --- /dev/null +++ b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs @@ -0,0 +1,127 @@ +use super::*; +use crate::network::dns_overrides::install_entries; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +async fn run_connect_failure_case( + host: &str, + port: u16, + timing_normalization_enabled: bool, + peer: SocketAddr, +) -> Duration { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(host.to_string()); + config.censorship.mask_port = port; + config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n"; + + let (mut client_writer, client_reader) = duplex(1024); + let (mut client_visible_reader, client_visible_writer) = duplex(1024); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0, "connect-failure path must close client-visible writer"); + + started.elapsed() +} + +#[tokio::test] +async fn connect_failure_refusal_close_behavior_matrix() { + let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() { + let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16) + .parse() + .unwrap(); + let elapsed = run_connect_failure_case( + "127.0.0.1", + unused_port, + timing_normalization_enabled, + peer, + ) + .await; + + if timing_normalization_enabled { + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250), + "normalized refusal path must honor configured timing envelope without stalling" + ); + } else { + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150), + "non-normalized refusal path must honor baseline connect budget without stalling" + ); + } + } +} + +#[tokio::test] +async fn connect_failure_overridden_hostname_close_behavior_matrix() { + let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + // Make hostname resolution deterministic in tests so timing ceilings are meaningful. + install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap(); + + for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() { + let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16) + .parse() + .unwrap(); + let elapsed = run_connect_failure_case( + "mask.invalid", + unused_port, + timing_normalization_enabled, + peer, + ) + .await; + + if timing_normalization_enabled { + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250), + "normalized overridden-host path must honor configured timing envelope without stalling" + ); + } else { + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150), + "non-normalized overridden-host path must honor baseline connect budget without stalling" + ); + } + } + + install_entries(&[]).unwrap(); +} diff --git a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs new file mode 100644 index 0000000..b52af35 --- /dev/null +++ b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs @@ -0,0 +1,85 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::{AsyncRead, ReadBuf}; + +struct OneByteThenStall { + sent: bool, +} + +impl AsyncRead for OneByteThenStall { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if !self.sent { + self.sent = true; + buf.put_slice(&[0x42]); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +#[tokio::test] +async fn stalling_client_terminates_at_idle_not_relay_timeout() { + let reader = OneByteThenStall { sent: false }; + let started = Instant::now(); + + let result = tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(reader, MASK_BUFFER_SIZE * 4), + ) + .await; + + assert!( + result.is_ok(), + "consume_client_data should complete by per-read idle timeout, not hit relay timeout" + ); + + let elapsed = started.elapsed(); + assert!( + elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2), + "consume_client_data returned too quickly for idle-timeout path: {elapsed:?}" + ); + assert!( + elapsed < MASK_RELAY_TIMEOUT, + "consume_client_data waited full relay timeout ({elapsed:?}); \ + per-read idle timeout is missing" + ); +} + +#[tokio::test] +async fn fast_reader_drains_to_eof() { + let data = vec![0xAAu8; 32 * 1024]; + let reader = std::io::Cursor::new(data); + + tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX)) + .await + .expect("consume_client_data did not complete for fast EOF reader"); +} + +#[tokio::test] +async fn io_error_terminates_cleanly() { + struct ErrReader; + + impl AsyncRead for ErrReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "simulated reset", + ))) + } + } + + tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX)) + .await + .expect("consume_client_data did not return on I/O error"); +} diff --git a/src/proxy/tests/masking_consume_stress_adversarial_tests.rs b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs new file mode 100644 index 0000000..12287b5 --- /dev/null +++ b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs @@ -0,0 +1,64 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::task::JoinSet; + +struct OneByteThenStall { + sent: bool, +} + +impl AsyncRead for OneByteThenStall { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if !self.sent { + self.sent = true; + buf.put_slice(&[0xAA]); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +#[tokio::test] +async fn consume_stall_stress_finishes_within_idle_budget() { + let mut set = JoinSet::new(); + let started = Instant::now(); + + for _ in 0..64 { + set.spawn(async { + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(OneByteThenStall { sent: false }, usize::MAX), + ) + .await + .expect("consume_client_data exceeded relay timeout under stall load"); + }); + } + + while let Some(res) = set.join_next().await { + res.unwrap(); + } + + // Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling + // for 100ms should complete well under a strict 600ms boundary. + assert!( + started.elapsed() < MASK_RELAY_TIMEOUT * 3, + "stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking" + ); +} + +#[tokio::test] +async fn consume_zero_cap_returns_immediately() { + let started = Instant::now(); + consume_client_data(tokio::io::empty(), 0).await; + assert!( + started.elapsed() < MASK_RELAY_IDLE_TIMEOUT, + "zero byte cap must return immediately" + ); +} diff --git a/src/proxy/tests/masking_http2_preface_integration_security_tests.rs b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs new file mode 100644 index 0000000..7f1c03f --- /dev/null +++ b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs @@ -0,0 +1,55 @@ +use super::*; +use tokio::net::TcpListener; +use tokio::time::Duration; + +#[tokio::test] +async fn http2_preface_is_forwarded_and_recorded_as_http() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let preface = preface.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; preface.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, preface); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let (client_reader, _client_writer) = tokio::io::duplex(512); + let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + client_reader, + client_visible_writer, + &preface, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.130-1")); +} diff --git a/src/proxy/tests/masking_http2_probe_classification_security_tests.rs b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs new file mode 100644 index 0000000..34e04a9 --- /dev/null +++ b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs @@ -0,0 +1,92 @@ +use super::*; + +#[test] +fn full_http2_preface_classified_as_http_probe() { + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + assert!( + is_http_probe(preface), + "HTTP/2 connection preface must be classified as HTTP probe" + ); +} + +#[test] +fn partial_http2_preface_3_bytes_classified() { + assert!( + is_http_probe(b"PRI"), + "3-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn partial_http2_preface_2_bytes_classified() { + assert!( + is_http_probe(b"PR"), + "2-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn existing_http1_methods_unaffected() { + for prefix in [ + b"GET / HTTP/1.1\r\n".as_ref(), + b"POST /api HTTP/1.1\r\n".as_ref(), + b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(), + b"TRACE / HTTP/1.1\r\n".as_ref(), + b"PATCH / HTTP/1.1\r\n".as_ref(), + ] { + assert!(is_http_probe(prefix)); + } +} + +#[test] +fn non_http_data_not_classified() { + for data in [ + b"\x16\x03\x01\x00\xf1".as_ref(), + b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(), + b"\x00\x01\x02\x03".as_ref(), + b"".as_ref(), + b"P".as_ref(), + ] { + assert!(!is_http_probe(data)); + } +} + +#[test] +fn light_fuzz_non_http_prefixes_not_misclassified() { + // Deterministic pseudo-fuzz to exercise classifier edges while avoiding + // known HTTP method and partial windows. + let mut x = 0x1234_5678u32; + for _ in 0..1024 { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + let len = 4 + ((x >> 8) as usize % 12); + let mut data = vec![0u8; len]; + for byte in &mut data { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = (x & 0xFF) as u8; + } + + if [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"DELETE".as_ref(), + b"OPTIONS".as_ref(), + b"CONNECT".as_ref(), + b"TRACE".as_ref(), + b"PATCH".as_ref(), + b"PRI ".as_ref(), + ] + .iter() + .any(|m| data.starts_with(m)) + { + continue; + } + + assert!( + !is_http_probe(&data), + "non-http pseudo-fuzz input misclassified: {:?}", + &data[..data.len().min(8)] + ); + } +} diff --git a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs new file mode 100644 index 0000000..47b6dc6 --- /dev/null +++ b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs @@ -0,0 +1,79 @@ +use super::*; + +#[test] +fn exact_four_byte_http_tokens_are_classified() { + for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] { + assert!( + is_http_probe(token), + "exact 4-byte token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn exact_four_byte_non_http_tokens_are_not_classified() { + for token in [ + b"GEX ".as_ref(), + b"POXT".as_ref(), + b"HEA/".as_ref(), + b"PU\0 ".as_ref(), + b"PRI/".as_ref(), + ] { + assert!( + !is_http_probe(token), + "non-HTTP 4-byte token must not be classified: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"PRI "), "HTTP"); +} + +#[test] +fn exact_long_http_tokens_are_classified() { + for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] { + assert!( + is_http_probe(token), + "exact long HTTP token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() { + assert_eq!(detect_client_type(b"CONNECT"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH"), "HTTP"); +} + +#[test] +fn light_fuzz_four_byte_ascii_noise_not_misclassified() { + // Deterministic pseudo-fuzz over 4-byte printable ASCII inputs. + let mut x = 0xA17C_93E5u32; + for _ in 0..2048 { + let mut token = [0u8; 4]; + for byte in &mut token { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset + } + + if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "] + .iter() + .any(|m| token.as_slice() == *m) + { + continue; + } + + assert!( + !is_http_probe(&token), + "pseudo-fuzz noise misclassified as HTTP probe: {:?}", + token + ); + } +} \ No newline at end of file diff --git a/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs new file mode 100644 index 0000000..d82cf82 --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs @@ -0,0 +1,51 @@ +#![cfg(unix)] + +use super::*; + +#[test] +fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() { + let previous = vec![ + "192.168.100.7" + .parse::() + .expect("must parse interface ip"), + ]; + let refreshed = Vec::new(); + + let next = choose_interface_snapshot(&previous, refreshed); + + assert_eq!( + next, previous, + "empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions" + ); +} + +#[test] +fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() { + let previous = vec![ + "192.168.100.7" + .parse::() + .expect("must parse interface ip"), + ]; + let refreshed = vec![ + "10.55.0.3" + .parse::() + .expect("must parse refreshed interface ip"), + ]; + + let next = choose_interface_snapshot(&previous, refreshed.clone()); + + assert_eq!(next, refreshed); +} + +#[test] +fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() { + let previous = Vec::new(); + let refreshed = Vec::new(); + + let next = choose_interface_snapshot(&previous, refreshed); + + assert!( + next.is_empty(), + "empty refresh with no previous snapshot should remain empty" + ); +} diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs new file mode 100644 index 0000000..b14d7c3 --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -0,0 +1,46 @@ +#![cfg(unix)] + +use super::*; +use std::sync::{Mutex, OnceLock}; + +fn interface_cache_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[test] +fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + + let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None); + let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None); + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "interface enumeration must be cached across repeated bad-client checks" + ); +} + +#[test] +fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let is_local = is_mask_target_local_listener("127.0.0.1", 8443, local_addr, None); + + assert!(!is_local, "different port must not be treated as local listener"); + assert_eq!( + local_interface_enumerations_for_tests(), + 0, + "port mismatch should bypass interface enumeration entirely" + ); +} diff --git a/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..efa4529 --- /dev/null +++ b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs @@ -0,0 +1,178 @@ +use super::*; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[tokio::test] +#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"] +async fn redteam_offline_target_should_drop_idle_client_early() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(150)).await; + let write_res = client_write.write_all(b"probe-should-be-closed").await; + assert!( + write_res.is_err(), + "offline target path still keeps client writable before consume timeout" + ); + + handler.abort(); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"] +async fn redteam_offline_target_should_not_sleep_to_mask_refusal() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"\x16\x03\x01\x00\x05hello", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + let elapsed = started.elapsed(); + + assert!( + elapsed < Duration::from_millis(10), + "offline target path still applies coarse masking sleep and is fingerprintable" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"] +async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() { + let mut samples = Vec::new(); + + for i in 0..12u16 { + let (client_read, mut client_write) = duplex(1024); + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + samples.push(started.elapsed()); + } + + let min = samples.iter().copied().min().unwrap_or_default(); + let max = samples.iter().copied().max().unwrap_or_default(); + let spread = max.saturating_sub(min); + + assert!( + spread <= Duration::from_millis(5), + "offline refusal timing spread too wide for strict red-team envelope: {:?}", + spread + ); +} + +#[tokio::test] +#[ignore = "manual red-team: host resolver failure should complete without panic"] +async fn redteam_dns_resolution_failure_must_not_panic() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string()); + cfg.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(2), handler).await; + assert!( + result.is_ok(), + "dns failure path stalled or panicked instead of terminating" + ); +} diff --git a/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs new file mode 100644 index 0000000..b99b4bc --- /dev/null +++ b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs @@ -0,0 +1,51 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::AsyncWrite; + +struct NeverWritable; + +impl AsyncWrite for NeverWritable { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() { + let mut writer = NeverWritable; + let started = Instant::now(); + + maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await; + + assert!( + started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30), + "shape padding blocked past timeout budget" + ); +} + +#[tokio::test] +async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() { + let mut output = Vec::new(); + { + let mut writer = tokio::io::BufWriter::new(&mut output); + maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await; + use tokio::io::AsyncWriteExt; + writer.flush().await.unwrap(); + } + + assert!(output.is_empty()); +} diff --git a/src/proxy/tests/masking_relay_guardrails_security_tests.rs b/src/proxy/tests/masking_relay_guardrails_security_tests.rs new file mode 100644 index 0000000..257c0f8 --- /dev/null +++ b/src/proxy/tests/masking_relay_guardrails_security_tests.rs @@ -0,0 +1,105 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink}; +use tokio::time::{Duration, timeout}; + +#[tokio::test] +async fn relay_to_mask_enforces_masking_session_byte_cap() { + let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01]; + let extra = vec![0xAB; 96 * 1024]; + + let (client_reader, mut client_writer) = duplex(128 * 1024); + let (mask_read, _mask_read_peer) = duplex(1024); + let (mut mask_observer, mask_write) = duplex(256 * 1024); + let initial_for_task = initial.clone(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + mask_read, + mask_write, + &initial_for_task, + false, + 512, + 4096, + false, + 0, + false, + 32 * 1024, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + + let mut observed = Vec::new(); + timeout( + Duration::from_secs(2), + mask_observer.read_to_end(&mut observed), + ) + .await + .unwrap() + .unwrap(); + + // In this deterministic test, relay must stop exactly at the configured cap. + assert_eq!( + observed.len(), + initial.len() + (32 * 1024), + "masked relay must forward exactly up to the cap (observed={} initial={} cap={})", + observed.len(), + initial.len(), + 32 * 1024 + ); +} + +#[tokio::test] +async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() { + let initial = b"GET /half-close HTTP/1.1\r\n".to_vec(); + + let (client_reader, mut client_writer) = duplex(8 * 1024); + let (mask_read, _mask_read_peer) = duplex(8 * 1024); + let (mut mask_observer, mask_write) = duplex(8 * 1024); + let initial_for_task = initial.clone(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + mask_read, + mask_write, + &initial_for_task, + false, + 512, + 4096, + false, + 0, + false, + 32 * 1024, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + + let mut observed = Vec::new(); + timeout( + Duration::from_millis(80), + mask_observer.read_to_end(&mut observed), + ) + .await + .expect("mask backend write side should be half-closed promptly") + .unwrap(); + + assert_eq!(&observed[..initial.len()], initial.as_slice()); + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs new file mode 100644 index 0000000..627c48b --- /dev/null +++ b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs @@ -0,0 +1,100 @@ +use super::*; +use tokio::io::AsyncReadExt; +use tokio::time::{Duration, timeout}; + +async fn collect_padding( + total_sent: usize, + enabled: bool, + floor: usize, + cap: usize, + above_cap_blur: bool, + blur_max: usize, + aggressive: bool, +) -> Vec { + let (mut tx, mut rx) = tokio::io::duplex(256 * 1024); + + maybe_write_shape_padding( + &mut tx, + total_sent, + enabled, + floor, + cap, + above_cap_blur, + blur_max, + aggressive, + ) + .await; + + drop(tx); + + let mut output = Vec::new(); + timeout(Duration::from_secs(1), rx.read_to_end(&mut output)) + .await + .expect("reading padded output timed out") + .expect("failed reading padded output"); + output +} + +#[tokio::test] +async fn padding_output_is_not_all_zero() { + let output = collect_padding(1, true, 256, 4096, false, 0, false).await; + + assert!( + output.len() >= 255, + "expected at least 255 padding bytes, got {}", + output.len() + ); + + let nonzero = output.iter().filter(|&&b| b != 0).count(); + // In 255 bytes of uniform randomness, the expected number of zero bytes is ~1. + // A weak nonzero check can miss severe entropy collapse. + assert!( + nonzero >= 240, + "RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}", + nonzero, + output.len(), + ); +} + +#[tokio::test] +async fn padding_reaches_first_bucket_boundary() { + let output = collect_padding(1, true, 64, 4096, false, 0, false).await; + assert_eq!(output.len(), 63); +} + +#[tokio::test] +async fn disabled_padding_produces_no_output() { + let output = collect_padding(0, false, 256, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn at_cap_without_blur_produces_no_output() { + let output = collect_padding(4096, true, 64, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() { + let output = collect_padding(4096, true, 64, 4096, true, 128, true).await; + assert!(!output.is_empty()); + assert!(output.len() <= 128, "blur exceeded max: {}", output.len()); +} + +#[tokio::test] +async fn stress_padding_runs_are_not_constant_pattern() { + // Stress and sanity-check: repeated runs should not collapse to identical + // first 16 bytes across all samples. + let mut first_chunks = Vec::new(); + for _ in 0..64 { + let out = collect_padding(1, true, 64, 4096, false, 0, false).await; + first_chunks.push(out[..16].to_vec()); + } + + let first = &first_chunks[0]; + let all_same = first_chunks.iter().all(|chunk| chunk == first); + assert!( + !all_same, + "all stress samples had identical prefix, rng output appears degenerate" + ); +} diff --git a/src/proxy/tests/masking_security_tests.rs b/src/proxy/tests/masking_security_tests.rs index 4519d85..c698b55 100644 --- a/src/proxy/tests/masking_security_tests.rs +++ b/src/proxy/tests/masking_security_tests.rs @@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall false, 0, false, + 5 * 1024 * 1024, ) .await; }); @@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { false, 0, false, + 5 * 1024 * 1024, ), ) .await; diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs new file mode 100644 index 0000000..b92ce3d --- /dev/null +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -0,0 +1,354 @@ +use super::*; +use std::net::TcpListener as StdTcpListener; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant, timeout}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[test] +fn self_target_detection_matches_literal_ipv4_listener() { + let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); + assert!(is_mask_target_local_listener( + "198.51.100.40", + 443, + local, + None, + )); +} + +#[test] +fn self_target_detection_matches_bracketed_ipv6_listener() { + let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); + assert!(is_mask_target_local_listener( + "[2001:db8::44]", + 8443, + local, + None, + )); +} + +#[test] +fn self_target_detection_keeps_same_ip_different_port_forwardable() { + let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener( + "203.0.113.44", + 8443, + local, + None, + )); +} + +#[test] +fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + assert!(is_mask_target_local_listener( + "::ffff:127.0.0.1", + 443, + local, + None, + )); +} + +#[test] +fn self_target_detection_unspecified_bind_blocks_loopback_target() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + assert!(is_mask_target_local_listener( + "127.0.0.1", + 443, + local, + None, + )); +} + +#[test] +fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener( + "mask.example", + 443, + local, + Some(remote), + )); +} + +#[tokio::test] +async fn self_target_fallback_refuses_recursive_loopback_connect() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(local_addr.ip().to_string()); + config.censorship.mask_port = local_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + b"GET /", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "self-target masking must fail closed without connecting to local listener" + ); +} + +#[tokio::test] +async fn same_ip_different_port_still_forwards_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /".to_vec(); + let accept_task = tokio::spawn({ + let expected = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn detect_client_type_http_boundary_get_and_post() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"GET /"), "HTTP"); + + assert_eq!(detect_client_type(b"POST"), "HTTP"); + assert_eq!(detect_client_type(b"POST "), "HTTP"); + assert_eq!(detect_client_type(b"POSTX"), "HTTP"); +} + +#[test] +fn detect_client_type_tls_and_length_boundaries() { + assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner"); + assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner"); + + assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + +#[test] +fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() { + let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap(); + let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); + let header = build_mask_proxy_header(1, peer, local).unwrap(); + assert_eq!(header, b"PROXY UNKNOWN\r\n"); +} + +#[test] +fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() { + let floor = usize::MAX / 2 + 1; + let cap = usize::MAX; + let total = floor + 1; + assert_eq!(next_mask_shape_bucket(total, floor, cap), total); +} + +#[tokio::test] +async fn self_target_reject_path_keeps_timing_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (client, server) = duplex(1024); + drop(client); + + let started = Instant::now(); + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250), + "self-target reject path must keep coarse timing budget without stalling" + ); +} + +#[tokio::test] +async fn relay_path_idle_timeout_eviction_remains_effective() { + let (client_read, mut client_write) = duplex(1024); + let (mask_read, mask_write) = duplex(1024); + + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + client_write.write_all(b"a").await.unwrap(); + tokio::time::sleep(Duration::from_millis(180)).await; + let _ = client_write.write_all(b"b").await; + }); + + let started = Instant::now(); + relay_to_mask( + client_read, + tokio::io::sink(), + mask_read, + mask_write, + b"init", + false, + 0, + 0, + false, + 0, + false, + 5 * 1024 * 1024, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180), + "idle-timeout eviction must occur before late trickle write" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_respects_timing_normalization_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client.shutdown().await.unwrap(); + timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220), + "offline-refusal path must honor normalization budget without unbounded drift" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = false; + + let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(120)).await; + client + .write_all(b"still-open-before-timeout") + .await + .expect("connection should still be open before consume timeout expires"); + + timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350), + "offline-refusal path must not retain idle client indefinitely" + ); +} diff --git a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs index 982fd26..4fa8da7 100644 --- a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs @@ -43,6 +43,7 @@ async fn run_relay_case( above_cap_blur, above_cap_blur_max_bytes, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs index 3c886ba..9abf3c0 100644 --- a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs @@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() { false, 0, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs index 2c9f3f6..6f0e91a 100644 --- a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs +++ b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs @@ -9,6 +9,7 @@ use tokio::time::{Duration, timeout}; #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() { let _guard = super::quota_user_lock_test_scope(); + let _pressure_guard = super::relay_idle_pressure_test_scope(); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); map.clear(); diff --git a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs new file mode 100644 index 0000000..3d7929b --- /dev/null +++ b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs @@ -0,0 +1,229 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::task::{Context, Poll, Waker}; +use tokio::io::AsyncWrite; +use tokio::time::{Duration, timeout}; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[derive(Default)] +struct GateState { + open: AtomicBool, + parked_waker: std::sync::Mutex>, +} + +impl GateState { + fn open(&self) { + self.open.store(true, Ordering::Relaxed); + if let Ok(mut guard) = self.parked_waker.lock() + && let Some(w) = guard.take() + { + w.wake(); + } + } + + fn has_waiter(&self) -> bool { + self.parked_waker + .lock() + .map(|guard| guard.is_some()) + .unwrap_or(false) + } +} + +#[derive(Default)] +struct GateWriter { + gate: Arc, +} + +impl GateWriter { + fn new(gate: Arc) -> Self { + Self { gate } + } +} + +impl AsyncWrite for GateWriter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.gate.open.load(Ordering::Relaxed) { + return Poll::Ready(Ok(buf.len())); + } + + if let Ok(mut guard) = self.gate.parked_waker.lock() { + *guard = Some(cx.waker().clone()); + } + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct FailingWriter; + +impl AsyncWrite for FailingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "injected writer failure", + ))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let rng = SecureRandom::new(); + let quota_limit = Some(1024); + let user = "hol-quota-user"; + + let gate = Arc::new(GateState::default()); + + let mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate))); + let slow_task = tokio::spawn(async move { + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]), + }, + &mut blocked_writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + user, + quota_limit, + &bytes_me2c, + 7001, + false, + false, + ) + .await + }); + + timeout(Duration::from_millis(100), async { + loop { + if gate.has_waiter() { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("first writer must reach backpressure and park"); + + let stats_fast = Stats::new(); + let bytes_fast = AtomicU64::new(0); + let rng_fast = SecureRandom::new(); + let mut fast_writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_fast = Vec::new(); + + timeout( + Duration::from_millis(50), + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x41]), + }, + &mut fast_writer, + ProtoTag::Intermediate, + &rng_fast, + &mut frame_buf_fast, + &stats_fast, + user, + quota_limit, + &bytes_fast, + 7002, + false, + false, + ), + ) + .await + .expect("peer connection must not be blocked by same-user stalled write") + .expect("fast peer write must succeed"); + + gate.open(); + let slow_result = timeout(Duration::from_secs(1), slow_task) + .await + .expect("stalled task must complete once gate opens") + .expect("stalled task must not panic"); + assert!(slow_result.is_ok()); +} + +#[tokio::test] +async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() { + let stats = Stats::new(); + let user = "rollback-user"; + stats.add_user_octets_from(user, 7); + + let bytes_me2c = AtomicU64::new(0); + let rng = SecureRandom::new(); + let mut writer = make_crypto_writer(FailingWriter); + let mut frame_buf = Vec::new(); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + user, + Some(64), + &bytes_me2c, + 7003, + false, + false, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Io(_)))); + assert_eq!( + stats.get_user_total_octets(user), + 7, + "failed client write must not overcharge user quota accounting" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + 0, + "failed client write must not inflate ME->C forensic byte counter" + ); +} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs index 3e0b30f..6ea182b 100644 --- a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -3,7 +3,7 @@ use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; use std::sync::atomic::AtomicU64; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::Arc; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; @@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl } } -fn idle_pressure_test_lock() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> { - match idle_pressure_test_lock().lock() { - Ok(guard) => guard, - Err(poisoned) => poisoned.into_inner(), - } -} - #[tokio::test] async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() { let (reader, _writer) = duplex(1024); @@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() { #[test] fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { #[test] fn pressure_does_not_evict_without_new_pressure_signal() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() { #[test] fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { #[test] fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { #[test] fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { #[test] fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { #[test] fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { #[test] fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated( #[test] fn blackhat_stale_pressure_must_not_survive_candidate_churn() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() { #[test] fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting( #[test] fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); @@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); diff --git a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs new file mode 100644 index 0000000..112d926 --- /dev/null +++ b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs @@ -0,0 +1,59 @@ +use super::*; +use std::panic::{AssertUnwindSafe, catch_unwind}; + +#[test] +fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let mut guard = registry + .lock() + .expect("registry lock must be acquired before poison"); + guard.by_conn_id.insert( + 999, + RelayIdleCandidateMeta { + mark_order_seq: 1, + mark_pressure_seq: 0, + }, + ); + guard.ordered.insert((1, 999)); + panic!("intentional poison for idle-registry recovery"); + })); + + // Helper lock must recover from poison, reset stale state, and continue. + assert!(mark_relay_idle_candidate(42)); + assert_eq!(oldest_relay_idle_candidate(), Some(42)); + + let before = relay_pressure_event_seq(); + note_relay_pressure_event(); + let after = relay_pressure_event_seq(); + assert!(after > before, "pressure accounting must still advance after poison"); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let _guard = registry + .lock() + .expect("registry lock must be acquired before poison"); + panic!("intentional poison while lock held"); + })); + + clear_relay_idle_pressure_state_for_testing(); + + assert_eq!(oldest_relay_idle_candidate(), None); + assert_eq!(relay_pressure_event_seq(), 0); + + assert!(mark_relay_idle_candidate(7)); + assert_eq!(oldest_relay_idle_candidate(), Some(7)); + + clear_relay_idle_pressure_state_for_testing(); +} diff --git a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs new file mode 100644 index 0000000..717a375 --- /dev/null +++ b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs @@ -0,0 +1,192 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::Stats; +use crate::stream::CryptoWriter; +use bytes::Bytes; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use tokio::task::JoinSet; + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { + let stats = Stats::new(); + let user = "quota-boundary-user"; + let bytes_me2c = AtomicU64::new(0); + + stats.add_user_octets_from(user, 5); + + let mut writer_one = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_one = Vec::new(); + let first = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[1, 2, 3]), + }, + &mut writer_one, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_one, + &stats, + user, + Some(8), + &bytes_me2c, + 7101, + false, + false, + ) + .await; + + assert!(first.is_ok(), "frame that reaches boundary must be allowed"); + assert_eq!(stats.get_user_total_octets(user), 8); + + let mut writer_two = make_crypto_writer(tokio::io::sink()); + let mut frame_buf_two = Vec::new(); + let second = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[9]), + }, + &mut writer_two, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf_two, + &stats, + user, + Some(8), + &bytes_me2c, + 7102, + false, + false, + ) + .await; + + assert!( + matches!(second, Err(ProxyError::DataQuotaExceeded { .. })), + "frame after boundary must be rejected" + ); + assert_eq!(stats.get_user_total_octets(user), 8); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_counters() { + let stats = Arc::new(Stats::new()); + let user = "reservation-stress-user"; + let quota_limit = 64u64; + let bytes_me2c = Arc::new(AtomicU64::new(0)); + let mut tasks = JoinSet::new(); + + for idx in 0..256u64 { + let user_owned = user.to_string(); + let stats_ref = Arc::clone(&stats); + let bytes_ref = Arc::clone(&bytes_me2c); + + tasks.spawn(async move { + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAB]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + stats_ref.as_ref(), + &user_owned, + Some(quota_limit), + bytes_ref.as_ref(), + 7200 + idx, + false, + false, + ) + .await + }); + } + + let mut ok = 0usize; + let mut denied = 0usize; + while let Some(joined) = tasks.join_next().await { + match joined.expect("reservation stress task must not panic") { + Ok(_) => ok += 1, + Err(ProxyError::DataQuotaExceeded { .. }) => denied += 1, + Err(other) => panic!("unexpected error in stress case: {other:?}"), + } + } + + let total = stats.get_user_total_octets(user); + assert_eq!( + total, quota_limit, + "quota must be exactly exhausted without overshoot" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + total, + "ME->C forensic bytes must track committed quota usage" + ); + assert_eq!(ok, quota_limit as usize, "exactly quota_limit tasks must succeed"); + assert_eq!( + denied, + 256usize - (quota_limit as usize), + "remaining tasks must be exactly denied without silently swallowing state" + ); +} + +#[tokio::test] +async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() { + let stats = Stats::new(); + let user = "reservation-fuzz-user"; + let quota_limit = 128u64; + let bytes_me2c = AtomicU64::new(0); + let mut seed = 0xC0FE_EE11_8899_2211u64; + + for conn in 0..512u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let len = ((seed & 0x0f) + 1) as usize; + let payload = vec![0x5A; len]; + + let mut writer = make_crypto_writer(tokio::io::sink()); + let mut frame_buf = Vec::new(); + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(quota_limit), + &bytes_me2c, + 7300 + conn, + false, + false, + ) + .await; + + if let Err(err) = result { + assert!( + matches!(err, ProxyError::DataQuotaExceeded { .. }), + "fuzz run produced unexpected error variant: {err:?}" + ); + } + } + + let total = stats.get_user_total_octets(user); + assert!(total <= quota_limit); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); +} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs new file mode 100644 index 0000000..1bf3123 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs @@ -0,0 +1,365 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use tokio::task::JoinSet; +use tokio::time::{Duration as TokioDuration, sleep, timeout}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB200_0000 + conn_id, + conn_id, + user: format!("tiny-frame-debt-concurrency-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_secs(30), + hard_idle: Duration::from_secs(60), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_secs(30), + } +} + +async fn read_once( + crypto_reader: &mut CryptoReader, + proto: ProtoTag, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + idle_state: &mut RelayClientIdleState, +) -> Result> { + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + read_client_payload_with_idle_policy( + crypto_reader, + proto, + 1024, + &buffer_pool, + forensics, + frame_counter, + &stats, + &idle_policy, + idle_state, + &last_downstream_activity_ms, + forensics.started_at, + ) + .await +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_pure_tiny_floods_all_fail_closed() { + let mut set = JoinSet::new(); + + for idx in 0..32u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(1000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = timeout( + TokioDuration::from_secs(1), + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("tiny flood task must complete"); + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert_eq!(frame_counter, 0); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("parallel tiny flood worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_benign_tiny_burst_then_real_all_pass() { + let mut set = JoinSet::new(); + + for idx in 0..24u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(2048); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(2000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [idx as u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(20); + for _ in 0..6 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = timeout( + TokioDuration::from_secs(1), + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("benign task must complete") + .expect("benign payload must parse") + .expect("benign payload must return frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("parallel benign worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { + let mut set = JoinSet::new(); + + for idx in 0..12u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(3000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(2000); + for n in 0..180u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + for chunk in encrypted.chunks(17) { + writer.write_all(chunk).await.unwrap(); + sleep(TokioDuration::from_millis(1)).await; + } + drop(writer); + }); + + let mut closed = false; + for _ in 0..220 { + let result = timeout( + TokioDuration::from_secs(1), + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("alternating reader step must complete"); + + match result { + Ok(Some((_payload, _))) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error in alternating jitter case: {other}"), + } + } + + writer_task.await.expect("writer jitter task must not panic"); + assert!(closed, "alternating attack must close before EOF"); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("alternating jitter worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_mixed_population_attackers_close_benign_survive() { + let mut set = JoinSet::new(); + + for idx in 0..20u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(4000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + if idx % 2 == 0 { + let mut plaintext = Vec::with_capacity(1280); + for n in 0..140u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n, n, n]); + } + writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..200 { + match read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected attacker error: {other}"), + } + } + assert!(closed, "attacker session must fail closed"); + } else { + let payload = [1u8, 9, 8, 7]; + let mut plaintext = Vec::new(); + for _ in 0..4 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + + let got = read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("benign session must parse") + .expect("benign session must return a frame"); + assert_eq!(got.0.as_ref(), &payload); + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("mixed-population worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_parallel_patterns_no_hang_or_panic() { + let mut set = JoinSet::new(); + + for case in 0..40u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(5000 + case, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut seed = 0x9E37_79B9u64 ^ (case << 8); + let mut plaintext = Vec::with_capacity(2048); + for _ in 0..256 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let is_tiny = (seed & 1) == 0; + if is_tiny { + plaintext.push(0x00); + } else { + plaintext.push(0x01); + plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]); + } + } + + writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + drop(writer); + + for _ in 0..320 { + let step = timeout( + TokioDuration::from_secs(1), + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("fuzz case read step must complete"); + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => break, + Ok(None) => break, + Err(other) => panic!("unexpected fuzz case error: {other}"), + } + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("fuzz worker must not panic"); + } +} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs new file mode 100644 index 0000000..0ff46a2 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs @@ -0,0 +1,418 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, PooledBuffer}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use tokio::time::{Duration as TokioDuration, sleep, timeout}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB300_0000 + conn_id, + conn_id, + user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_secs(30), + hard_idle: Duration::from_secs(60), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_secs(30), + } +} + +fn append_tiny_frame(plaintext: &mut Vec, proto: ProtoTag) { + match proto { + ProtoTag::Abridged => plaintext.push(0x00), + ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()), + } +} + +fn append_real_frame(plaintext: &mut Vec, proto: ProtoTag, payload: [u8; 4]) { + match proto { + ProtoTag::Abridged => { + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + } + ProtoTag::Intermediate | ProtoTag::Secure => { + plaintext.extend_from_slice(&4u32.to_le_bytes()); + plaintext.extend_from_slice(&payload); + } + } +} + +async fn write_chunked_with_jitter( + writer: &mut tokio::io::DuplexStream, + bytes: &[u8], + mut seed: u64, +) { + let mut offset = 0usize; + while offset < bytes.len() { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let chunk_len = 1 + ((seed as usize) & 0x1f); + let end = (offset + chunk_len).min(bytes.len()); + writer.write_all(&bytes[offset..end]).await.unwrap(); + + let delay_ms = ((seed >> 16) % 3) as u64; + if delay_ms > 0 { + sleep(TokioDuration::from_millis(delay_ms)).await; + } + offset = end; + } +} + +async fn read_once_with_state( + crypto_reader: &mut CryptoReader, + proto: ProtoTag, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + idle_state: &mut RelayClientIdleState, +) -> Result> { + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + read_client_payload_with_idle_policy( + crypto_reader, + proto, + 1024, + &buffer_pool, + forensics, + frame_counter, + &stats, + &idle_policy, + idle_state, + &last_downstream_activity_ms, + forensics.started_at, + ) + .await +} + +#[tokio::test] +async fn intermediate_chunked_zero_flood_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6101, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + } + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await; + drop(writer); + + let result = timeout( + TokioDuration::from_secs(2), + read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("intermediate flood read must complete"); + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert_eq!(frame_counter, 0); +} + +#[tokio::test] +async fn secure_chunked_zero_flood_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6102, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + } + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await; + drop(writer); + + let result = timeout( + TokioDuration::from_secs(2), + read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("secure flood read must complete"); + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert_eq!(frame_counter, 0); +} + +#[tokio::test] +async fn intermediate_chunked_alternating_attack_closes_before_eof() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6103, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = timeout( + TokioDuration::from_secs(1), + read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("intermediate alternating read step must complete"); + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected intermediate alternating error: {other}"), + } + } + + writer_task.await.expect("intermediate writer task must not panic"); + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn secure_chunked_alternating_attack_closes_before_eof() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6104, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = timeout( + TokioDuration::from_secs(1), + read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("secure alternating read step must complete"); + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected secure alternating error: {other}"), + } + } + + writer_task.await.expect("secure writer task must not panic"); + assert!(closed, "secure alternating attack must fail closed"); +} + +#[tokio::test] +async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6105, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [9u8, 8, 7, 6]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + } + append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("intermediate safe burst should parse") + .expect("intermediate safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn secure_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6106, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [3u8, 1, 4, 1]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + } + append_real_frame(&mut plaintext, ProtoTag::Secure, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("secure safe burst should parse") + .expect("secure safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn light_fuzz_proto_chunking_outcomes_are_bounded() { + let mut seed = 0xDEAD_BEEF_2026_0322u64; + + for case in 0..48u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let proto = if (seed & 1) == 0 { + ProtoTag::Intermediate + } else { + ProtoTag::Secure + }; + + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6200 + case, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut stream = Vec::new(); + let mut local_seed = seed ^ case; + for _ in 0..220 { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + if (local_seed & 1) == 0 { + append_tiny_frame(&mut stream, proto); + } else { + let b = (local_seed >> 8) as u8; + append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]); + } + } + + let encrypted = encrypt_for_reader(&stream); + write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await; + drop(writer); + + for _ in 0..260 { + let step = timeout( + TokioDuration::from_secs(1), + read_once_with_state( + &mut crypto_reader, + proto, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("fuzz proto read step must complete"); + + match step { + Ok(Some((_payload, _))) => {} + Err(ProxyError::Proxy(_)) => break, + Ok(None) => break, + Err(other) => panic!("unexpected proto chunking fuzz error: {other}"), + } + } + } +} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs new file mode 100644 index 0000000..d0719c8 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs @@ -0,0 +1,550 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB100_0000 + conn_id, + conn_id, + user: format!("tiny-frame-debt-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_secs(30), + hard_idle: Duration::from_secs(60), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_secs(30), + } +} + +fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option, u32, usize) { + let mut debt = 0u32; + let mut reals = 0usize; + for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() { + if is_tiny { + debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY); + if debt >= TINY_FRAME_DEBT_LIMIT { + return (Some(idx + 1), debt, reals); + } + } else { + reals = reals.saturating_add(1); + debt = debt.saturating_sub(1); + } + } + (None, debt, reals) +} + +#[test] +fn tiny_frame_debt_constants_match_security_budget_expectations() { + assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8); + assert_eq!(TINY_FRAME_DEBT_LIMIT, 512); +} + +#[test] +fn relay_client_idle_state_initial_debt_is_zero() { + let state = RelayClientIdleState::new(Instant::now()); + assert_eq!(state.tiny_frame_debt, 0); +} + +#[test] +fn on_client_frame_does_not_reset_tiny_frame_debt() { + let now = Instant::now(); + let mut state = RelayClientIdleState::new(now); + state.tiny_frame_debt = 77; + state.on_client_frame(now); + assert_eq!(state.tiny_frame_debt, 77); +} + +#[test] +fn tiny_frame_debt_increment_is_saturating() { + let mut debt = u32::MAX - 1; + debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY); + assert_eq!(debt, u32::MAX); +} + +#[test] +fn tiny_frame_debt_decrement_is_saturating() { + let mut debt = 0u32; + debt = debt.saturating_sub(1); + assert_eq!(debt, 0); +} + +#[test] +fn consecutive_tiny_frames_close_exactly_at_threshold() { + let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize; + let pattern = vec![true; max_tiny_without_close]; + let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, Some(max_tiny_without_close)); +} + +#[test] +fn one_less_than_threshold_tiny_frames_do_not_close() { + let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1; + let pattern = vec![true; tiny_count]; + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert!(debt < TINY_FRAME_DEBT_LIMIT); +} + +#[test] +fn alternating_one_to_one_closes_with_bounded_real_frame_count() { + let mut pattern = Vec::with_capacity(512); + for _ in 0..256 { + pattern.push(true); + pattern.push(false); + } + let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(closed_at.is_some()); + assert!(reals <= 80, "expected bounded real frames before close, got {reals}"); +} + +#[test] +fn alternating_one_to_eight_is_stable_for_long_runs() { + let mut pattern = Vec::with_capacity(9 * 5000); + for _ in 0..5000 { + pattern.push(true); + for _ in 0..8 { + pattern.push(false); + } + } + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert!(debt <= TINY_FRAME_DEBT_PER_TINY); +} + +#[test] +fn alternating_one_to_seven_eventually_closes() { + let mut pattern = Vec::with_capacity(8 * 2000); + for _ in 0..2000 { + pattern.push(true); + for _ in 0..7 { + pattern.push(false); + } + } + let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close"); +} + +#[test] +fn two_tiny_one_real_closes_faster_than_one_to_one() { + let mut one_to_one = Vec::with_capacity(512); + for _ in 0..256 { + one_to_one.push(true); + one_to_one.push(false); + } + + let mut two_to_one = Vec::with_capacity(768); + for _ in 0..256 { + two_to_one.push(true); + two_to_one.push(true); + two_to_one.push(false); + } + + let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len()); + let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len()); + assert!(a_close.is_some() && b_close.is_some()); + assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0)); +} + +#[test] +fn burst_then_drain_can_recover_without_close() { + let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize; + let mut pattern = Vec::with_capacity(burst_tiny + 600); + for _ in 0..burst_tiny { + pattern.push(true); + } + pattern.extend(std::iter::repeat_n(false, 600)); + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert_eq!(debt, 0); +} + +#[test] +fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() { + let mut seed = 0xA5A5_91C3_2026_0322u64; + for _case in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = 512 + ((seed as usize) & 0x3ff); + let mut pattern = Vec::with_capacity(len); + let mut local_seed = seed; + for _ in 0..len { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + pattern.push((local_seed & 1) == 0); + } + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + if closed_at.is_none() { + assert!(debt < TINY_FRAME_DEBT_LIMIT); + } + assert!(debt <= u32::MAX); + } +} + +#[test] +fn stress_many_independent_simulations_keep_isolated_debt_state() { + for idx in 0..2048usize { + let mut pattern = Vec::with_capacity(64); + for j in 0..64usize { + pattern.push(((idx ^ j) & 3) == 0); + } + let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY)); + } +} + +#[tokio::test] +async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(11, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(12, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn intermediate_alternating_zero_and_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(13, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(3000); + for idx in 0..160u8 { + plaintext.extend_from_slice(&0u32.to_le_bytes()); + plaintext.extend_from_slice(&4u32.to_le_bytes()); + plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]); + } + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..220 { + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error while probing alternating close: {other}"), + } + } + + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(14, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(64); + for _ in 0..8 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&[1, 2, 3, 4]); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let first = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match first { + Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]), + Err(e) => panic!("unexpected close after small tiny burst: {e}"), + Ok(None) => panic!("unexpected EOF before real frame"), + } +} + +#[tokio::test] +async fn idle_policy_enabled_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "idle policy enabled must fail closed for pure zero-length flood" + ); +} + +#[tokio::test] +async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(256 * 6); + for idx in 0..=255u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]); + } + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("alternating flood bytes must be writable"); + drop(writer); + + let mut saw_proxy_close = false; + for _ in 0..300 { + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some((_payload, _quickack))) => {} + Err(ProxyError::Proxy(_)) => { + saw_proxy_close = true; + break; + } + Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"), + Ok(None) => panic!("unexpected EOF before debt-based closure"), + Err(other) => panic!("unexpected error before close: {other}"), + } + } + + assert!( + saw_proxy_close, + "alternating tiny/real sequence must eventually fail closed" + ); +} + +#[tokio::test] +async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(3, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let payload = [7u8, 8, 9, 10]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero frame must be writable"); + + let result = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + .expect("valid frame should decode") + .expect("valid frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1); + assert_eq!(frame_counter, 1); +} diff --git a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs new file mode 100644 index 0000000..765c253 --- /dev/null +++ b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs @@ -0,0 +1,121 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use std::time::Instant; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB000_0000 + conn_id, + conn_id, + user: format!("zero-len-test-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +#[tokio::test] +async fn adversarial_legacy_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + + let flood_plaintext = vec![0u8; 128]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + match result { + Err(ProxyError::Proxy(msg)) => { + assert!( + msg.contains("Excessive zero-length"), + "legacy mode must close flood with explicit zero-length reason, got: {msg}" + ); + } + Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"), + Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"), + Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"), + } +} + +#[tokio::test] +async fn business_abridged_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + + let payload = [1u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero abridged frame must be writable"); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("valid abridged frame should decode") + .expect("valid abridged frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1, "quickack flag must remain false"); + assert_eq!(frame_counter, 1); +} diff --git a/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs b/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs new file mode 100644 index 0000000..fb0cf93 --- /dev/null +++ b/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs @@ -0,0 +1,108 @@ +use super::*; +use std::sync::Arc; +use std::sync::{Mutex, OnceLock}; + +fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK + .get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn same_user_returns_same_lock_identity() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let a = cross_mode_quota_user_lock("cross-mode-same-user"); + let b = cross_mode_quota_user_lock("cross-mode-same-user"); + + assert!( + Arc::ptr_eq(&a, &b), + "same user must reuse a stable lock identity" + ); +} + +#[test] +fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let prefix = format!("cross-mode-saturated-{}", std::process::id()); + let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX); + for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX { + retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "lock cache must be saturated for overflow check" + ); + + let overflow_user = format!("cross-mode-overflow-{}", std::process::id()); + let overflow_a = cross_mode_quota_user_lock(&overflow_user); + let overflow_b = cross_mode_quota_user_lock(&overflow_user); + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "overflow path must not grow bounded lock cache" + ); + assert!( + locks.get(&overflow_user).is_none(), + "overflow user must stay on striped fallback while cache is saturated" + ); + assert!( + Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user must receive a stable striped lock across repeated lookups" + ); + + drop(retained); +} + +#[test] +fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() { + let _guard = cross_mode_lock_test_guard(); + let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); + locks.clear(); + + let prefix = format!("cross-mode-reclaim-{}", std::process::id()); + let protected_user = format!("{prefix}-protected"); + + let protected_lock = cross_mode_quota_user_lock(&protected_user); + let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)); + for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) { + retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!( + locks.len(), + CROSS_MODE_QUOTA_USER_LOCKS_MAX, + "fixture must saturate lock cache before reclaim path is exercised" + ); + + drop(retained); + + let newcomer_user = format!("{prefix}-newcomer"); + let _newcomer = cross_mode_quota_user_lock(&newcomer_user); + + assert!( + locks.get(&protected_user).is_some(), + "active protected user must remain cache-resident after reclaim" + ); + let locked = locks + .get(&protected_user) + .expect("protected user must remain in map after reclaim"); + assert!( + Arc::ptr_eq(locked.value(), &protected_lock), + "reclaim must not swap active user lock identity" + ); + assert!( + locks.get(&newcomer_user).is_some(), + "newcomer should become cacheable after stale entries are reclaimed" + ); +} diff --git a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs new file mode 100644 index 0000000..87944ba --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs @@ -0,0 +1,225 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_cross_mode_uncontended_writer_progresses() { + let _guard = quota_test_guard(); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + "cross-mode-tdd-uncontended".to_string(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let result = io.write_all(&[0x11, 0x22]).await; + assert!(result.is_ok(), "uncontended writer must progress"); +} + +#[tokio::test] +async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-held-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before polling writer"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); + assert!(poll.is_pending(), "writer must not bypass held cross-mode lock"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_parallel_waiters_resume_after_cross_mode_release() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-resume-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before launching waiters"); + + let stats = Arc::new(Stats::new()); + let mut waiters = Vec::new(); + for _ in 0..16 { + let stats = Arc::clone(&stats); + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + stats, + user, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x7F]).await + })); + } + + tokio::time::sleep(Duration::from_millis(5)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let result = waiter.await.expect("waiter task must not panic"); + assert!(result.is_ok(), "waiter must complete after cross-mode release"); + } + }) + .await + .expect("all waiters must complete in bounded time"); +} + +#[tokio::test] +async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() { + let _guard = quota_test_guard(); + + let user = format!("cross-mode-tdd-wakes-{}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold cross-mode lock before polling"); + + let stats = Arc::new(Stats::new()); + let mut ios = Vec::new(); + let mut counters = Vec::new(); + for _ in 0..20 { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let poll = Pin::new(io).poll_write(&mut cx, &[0x33]); + assert!(poll.is_pending()); + counters.push(wake_counter); + } + + tokio::time::sleep(Duration::from_millis(25)).await; + let total_wakes: usize = counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= 20 * 4, + "cross-mode contention should not create wake storms; wakes={total_wakes}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() { + let _guard = quota_test_guard(); + + let mut seed = 0xC0DE_BAAD_2026_0322u64; + for round in 0..16u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let sleep_ms = 2 + (seed as u64 % 8); + let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id()); + let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold cross-mode lock in fuzz round"); + + let stats = Arc::new(Stats::new()); + let user_reader = user.clone(); + let reader_task = tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user_reader, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + let mut one = [0u8; 1]; + io.read(&mut one).await + }); + + let user_writer = user.clone(); + let writer_task = tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user_writer, + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x44]).await + }); + + tokio::time::sleep(Duration::from_millis(sleep_ms)).await; + drop(held_guard); + + let read_done = timeout(Duration::from_millis(350), reader_task) + .await + .expect("reader task must complete after release") + .expect("reader task must not panic"); + assert!(read_done.is_ok()); + + let write_done = timeout(Duration::from_millis(350), writer_task) + .await + .expect("writer task must complete after release") + .expect("writer task must not panic"); + assert!(write_done.is_ok()); + } +} diff --git a/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs b/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs new file mode 100644 index 0000000..5ea806a --- /dev/null +++ b/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs @@ -0,0 +1,81 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::Waker; +use std::task::{Context, Poll}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +#[tokio::test] +async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() { + let _guard = quota_user_lock_test_scope(); + + let user = "cross-mode-lock-shared-user"; + let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user); + let _held_guard = held + .try_lock() + .expect("test must hold shared cross-mode lock before relay poll"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(crate::stats::Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]); + + assert!( + matches!(poll, Poll::Pending), + "relay writer must not bypass cross-mode lock held by middle-relay path" + ); +} + +#[tokio::test] +async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() { + let _guard = quota_user_lock_test_scope(); + + let user = "cross-mode-lock-progress-user"; + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(crate::stats::Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]); + + assert!( + matches!(poll, Poll::Ready(Ok(2))), + "relay writer should progress when shared cross-mode lock is uncontended" + ); +} diff --git a/src/proxy/tests/relay_quota_lock_identity_security_tests.rs b/src/proxy/tests/relay_quota_lock_identity_security_tests.rs new file mode 100644 index 0000000..f717f54 --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_identity_security_tests.rs @@ -0,0 +1,135 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::Waker; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn build_context() -> (Arc, Context<'static>) { + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + // Context stores a reference; leak one Waker for deterministic test scope. + let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); + (wake_counter, Context::from_waker(leaked_waker)) +} + +#[tokio::test] +async fn adversarial_map_churn_cannot_bypass_held_writer_lock() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let user = "quota-identity-writer-user"; + let held_lock = quota_user_lock(user); + let _held_guard = held_lock + .try_lock() + .expect("test must hold initial user lock before StatsIo poll"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + map.clear(); + let churned_lock = quota_user_lock(user); + assert!( + !Arc::ptr_eq(&held_lock, &churned_lock), + "precondition: map churn should produce a distinct lock identity" + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]); + + assert!( + matches!(poll, Poll::Pending), + "writer must remain pending on the originally-held lock identity" + ); +} + +#[tokio::test] +async fn adversarial_map_churn_cannot_bypass_held_reader_lock() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let user = "quota-identity-reader-user"; + let held_lock = quota_user_lock(user); + let _held_guard = held_lock + .try_lock() + .expect("test must hold initial user lock before StatsIo poll"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + map.clear(); + let churned_lock = quota_user_lock(user); + assert!( + !Arc::ptr_eq(&held_lock, &churned_lock), + "precondition: map churn should produce a distinct lock identity" + ); + + let (_wake_counter, mut cx) = build_context(); + let mut storage = [0u8; 8]; + let mut read_buf = ReadBuf::new(&mut storage); + let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf); + + assert!( + matches!(poll, Poll::Pending), + "reader must remain pending on the originally-held lock identity" + ); +} + +#[tokio::test] +async fn business_no_lock_contention_keeps_writer_progress() { + let _guard = quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let user = "quota-identity-progress-user"; + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::new(Stats::new()), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let (_wake_counter, mut cx) = build_context(); + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]); + + assert!( + matches!(poll, Poll::Ready(Ok(2))), + "writer should progress immediately without contention" + ); +} diff --git a/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs new file mode 100644 index 0000000..7083eb2 --- /dev/null +++ b/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs @@ -0,0 +1,241 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::ReadBuf; +use tokio::time::{Duration, Instant}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn saturate_quota_user_locks() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}"))); + } + retained +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_contention_wake_rate_decays_with_backoff_curve() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = format!("quota-backoff-bench-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before benchmark run"); + + let waiters = 64usize; + let mut ios = Vec::with_capacity(waiters); + let mut wake_counters = Vec::with_capacity(waiters); + + for _ in 0..waiters { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let pending = Pin::new(io).poll_write(&mut cx, &[0x71]); + assert!(pending.is_pending()); + wake_counters.push(counter); + } + + let mut observed = vec![0usize; waiters]; + let start = Instant::now(); + let mut wakes_at_40ms = 0usize; + let mut wakes_at_160ms = 0usize; + + while start.elapsed() < Duration::from_millis(200) { + for (idx, counter) in wake_counters.iter().enumerate() { + let wakes = counter.wakes.load(Ordering::Relaxed); + if wakes > observed[idx] { + observed[idx] = wakes; + let waker = Waker::from(Arc::clone(counter)); + let mut cx = Context::from_waker(&waker); + let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]); + assert!(pending.is_pending()); + } + } + + let elapsed = start.elapsed(); + if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 { + wakes_at_40ms = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + } + if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 { + wakes_at_160ms = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + let wakes_at_200ms = total_wakes; + let early_window_wakes = wakes_at_40ms; + let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms); + + assert!( + total_wakes <= waiters * 28, + "backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}" + ); + + assert!( + early_window_wakes > 0, + "benchmark failed to observe early contention wakes" + ); + + assert!( + late_window_wakes * 4 <= early_window_wakes * 3, + "wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}" + ); + + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_read_contention_wake_rate_decays_with_backoff_curve() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = format!("quota-backoff-read-bench-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before read benchmark run"); + + let waiters = 64usize; + let mut ios = Vec::with_capacity(waiters); + let mut wake_counters = Vec::with_capacity(waiters); + + for _ in 0..waiters { + ios.push(StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + let pending = Pin::new(io).poll_read(&mut cx, &mut buf); + assert!(pending.is_pending()); + wake_counters.push(counter); + } + + let mut observed = vec![0usize; waiters]; + let start = Instant::now(); + let mut wakes_at_40ms = 0usize; + let mut wakes_at_160ms = 0usize; + + while start.elapsed() < Duration::from_millis(200) { + for (idx, counter) in wake_counters.iter().enumerate() { + let wakes = counter.wakes.load(Ordering::Relaxed); + if wakes > observed[idx] { + observed[idx] = wakes; + let waker = Waker::from(Arc::clone(counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf); + assert!(pending.is_pending()); + } + } + + let elapsed = start.elapsed(); + if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 { + wakes_at_40ms = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + } + if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 { + wakes_at_160ms = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + } + + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + let wakes_at_200ms = total_wakes; + let early_window_wakes = wakes_at_40ms; + let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms); + + assert!( + total_wakes <= waiters * 28, + "read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}" + ); + + assert!( + early_window_wakes > 0, + "read benchmark failed to observe early contention wakes" + ); + + assert!( + late_window_wakes * 4 <= early_window_wakes * 3, + "read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}" + ); + + drop(held_guard); +} diff --git a/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs new file mode 100644 index 0000000..7f1e451 --- /dev/null +++ b/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs @@ -0,0 +1,339 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::ReadBuf; +use tokio::time::{Duration, Instant}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn saturate_quota_user_locks() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}"))); + } + retained +} + +#[tokio::test] +async fn positive_uncontended_writer_keeps_retry_wakes_zero() { + let _guard = quota_test_guard(); + + let stats = Arc::new(Stats::new()); + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + "quota-backoff-positive".to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]); + assert!(poll.is_ready(), "uncontended writer must complete immediately"); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "uncontended path must not schedule deferred contention wakes" + ); +} + +#[tokio::test] +async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-backoff-adversarial-writer"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before polling writer"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); + assert!(first.is_pending()); + + let start = Instant::now(); + let mut observed = 0usize; + while start.elapsed() < Duration::from_millis(80) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed { + observed = wakes; + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]); + assert!(pending.is_pending()); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 16, + "sustained contention must be rate limited; observed wakes={} in 80ms", + wake_counter.wakes.load(Ordering::Relaxed) + ); + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-backoff-adversarial-reader"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before polling reader"); + + let mut io = StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + + let mut buf = ReadBuf::new(&mut storage); + let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(first.is_pending()); + + let start = Instant::now(); + let mut observed = 0usize; + while start.elapsed() < Duration::from_millis(80) { + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > observed { + observed = wakes; + let mut next = ReadBuf::new(&mut storage); + let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next); + assert!(pending.is_pending()); + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 16, + "sustained contention must be rate limited; observed wakes={} in 80ms", + wake_counter.wakes.load(Ordering::Relaxed) + ); + + drop(held_guard); + let mut done = ReadBuf::new(&mut storage); + let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn edge_backoff_attempt_resets_after_contention_release() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-backoff-edge-reset"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before polling writer"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]); + assert!(initial.is_pending()); + + tokio::time::sleep(Duration::from_millis(15)).await; + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + if wakes > 0 { + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]); + assert!(pending.is_pending()); + } + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); + assert!(ready.is_ready()); + assert!( + !io.quota_write_wake_scheduled, + "successful write must clear deferred wake scheduling flag" + ); + assert!( + io.quota_write_retry_sleep.is_none(), + "successful write must clear deferred sleep slot" + ); +} + +#[tokio::test] +async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-backoff-fuzz-writer"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before fuzz loop"); + + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.to_string(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let mut seed = 0x5EED_CAFE_7788_9900u64; + for _ in 0..64 { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]); + assert!(poll.is_pending()); + + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let sleep_ms = (seed % 4) as u64; + tokio::time::sleep(Duration::from_millis(sleep_ms)).await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 24, + "fuzzed repoll schedule must keep wake budget bounded; observed wakes={}", + wake_counter.wakes.load(Ordering::Relaxed) + ); + + drop(held_guard); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = format!("quota-backoff-stress-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold quota lock before launching stress waiters"); + + let waiters = 48usize; + let mut ios = Vec::with_capacity(waiters); + let mut wake_counters = Vec::with_capacity(waiters); + + for _ in 0..waiters { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(4096), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let pending = Pin::new(io).poll_write(&mut cx, &[0x61]); + assert!(pending.is_pending()); + wake_counters.push(counter); + } + + let start = Instant::now(); + while start.elapsed() < Duration::from_millis(120) { + for (idx, counter) in wake_counters.iter().enumerate() { + if counter.wakes.load(Ordering::Relaxed) > 0 { + let waker = Waker::from(Arc::clone(counter)); + let mut cx = Context::from_waker(&waker); + let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]); + assert!(pending.is_pending()); + } + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= waiters * 20, + "stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}" + ); + + drop(held_guard); +} diff --git a/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs b/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs new file mode 100644 index 0000000..35a6b6e --- /dev/null +++ b/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs @@ -0,0 +1,246 @@ +use super::*; +use crate::stats::Stats; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Poll, Waker}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_uncontended_quota_limited_writer_completes() { + let _guard = quota_test_guard(); + + let stats = Arc::new(Stats::new()); + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + "tdd-uncontended".to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let result = io.write_all(&[0x41, 0x42, 0x43]).await; + assert!(result.is_ok(), "uncontended writer must complete"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() { + let _guard = quota_test_guard(); + + let user = format!("tdd-writer-storm-{}", std::process::id()); + let held = quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold quota lock before polling writers"); + + let stats = Arc::new(Stats::new()); + let writers = 24usize; + let mut ios = Vec::with_capacity(writers); + let mut wake_counters = Vec::with_capacity(writers); + + for _ in 0..writers { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]); + assert!(poll.is_pending(), "writer must be pending under held lock"); + wake_counters.push(counter); + } + + tokio::time::sleep(Duration::from_millis(25)).await; + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= writers * 4, + "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() { + let _guard = quota_test_guard(); + + let user = format!("tdd-reader-storm-{}", std::process::id()); + let held = quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold quota lock before polling readers"); + + let stats = Arc::new(Stats::new()); + let readers = 24usize; + let mut ios = Vec::with_capacity(readers); + let mut wake_counters = Vec::with_capacity(readers); + + for _ in 0..readers { + ios.push(StatsIo::new( + tokio::io::empty(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + let poll = Pin::new(io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending(), "reader must be pending under held lock"); + wake_counters.push(counter); + } + + tokio::time::sleep(Duration::from_millis(25)).await; + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= readers * 4, + "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_contended_waiters_resume_after_lock_release() { + let _guard = quota_test_guard(); + + let user = format!("tdd-resume-{}", std::process::id()); + let held = quota_user_lock(&user); + let held_guard = held + .try_lock() + .expect("test must hold quota lock before launching waiters"); + + let stats = Arc::new(Stats::new()); + let mut waiters = Vec::new(); + for _ in 0..12 { + let stats = Arc::clone(&stats); + let user = user.clone(); + waiters.push(tokio::spawn(async move { + let mut io = StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + stats, + user, + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + io.write_all(&[0x5A]).await + })); + } + + tokio::time::sleep(Duration::from_millis(5)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let result = waiter.await.expect("waiter task must not panic"); + assert!(result.is_ok(), "waiter must complete after release"); + } + }) + .await + .expect("all waiters must complete in bounded time"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() { + let _guard = quota_test_guard(); + + let mut seed = 0x9E37_79B9_AA55_1234u64; + for round in 0..20u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let writers = 8 + (seed as usize % 12); + let sleep_ms = 10 + (seed as u64 % 15); + let user = format!("tdd-fuzz-{}-{round}", std::process::id()); + + let held = quota_user_lock(&user); + let _held_guard = held + .try_lock() + .expect("test must hold quota lock in fuzz round"); + + let stats = Arc::new(Stats::new()); + let mut ios = Vec::with_capacity(writers); + let mut wake_counters = Vec::with_capacity(writers); + + for _ in 0..writers { + ios.push(StatsIo::new( + tokio::io::sink(), + Arc::new(SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(2048), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + )); + } + + for io in &mut ios { + let counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&counter)); + let mut cx = Context::from_waker(&waker); + let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]); + assert!(matches!(poll, Poll::Pending)); + wake_counters.push(counter); + } + + tokio::time::sleep(Duration::from_millis(sleep_ms)).await; + + let total_wakes: usize = wake_counters + .iter() + .map(|counter| counter.wakes.load(Ordering::Relaxed)) + .sum(); + + assert!( + total_wakes <= writers * 4, + "fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}" + ); + } +} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs index 50cdfa3..7375192 100644 --- a/src/proxy/tests/relay_security_tests.rs +++ b/src/proxy/tests/relay_security_tests.rs @@ -137,10 +137,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_ for _ in 0..8 { tokio::task::yield_now().await; } - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - wakes_after_first_yield, - "writer contention should not schedule unbounded wake storms before lock acquisition" + let wakes_after_second_window = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes_after_second_window <= wakes_after_first_yield.saturating_add(2), + "writer contention should keep retry wakes bounded before lock acquisition: before={wakes_after_first_yield}, after={wakes_after_second_window}" ); drop(held_lock); diff --git a/src/stats/mod.rs b/src/stats/mod.rs index d13d834..dc455a1 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -1884,6 +1884,32 @@ impl Stats { stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); } + pub fn sub_user_octets_to(&self, user: &str, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + self.maybe_cleanup_user_stats(); + let Some(stats) = self.user_stats.get(user) else { + return; + }; + + Self::touch_user_stats(stats.value()); + let counter = &stats.octets_to_client; + let mut current = counter.load(Ordering::Relaxed); + loop { + let next = current.saturating_sub(bytes); + match counter.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -2440,3 +2466,7 @@ mod connection_lease_security_tests; #[cfg(test)] #[path = "tests/replay_checker_security_tests.rs"] mod replay_checker_security_tests; + +#[cfg(test)] +#[path = "tests/user_octets_sub_security_tests.rs"] +mod user_octets_sub_security_tests; diff --git a/src/stats/tests/user_octets_sub_security_tests.rs b/src/stats/tests/user_octets_sub_security_tests.rs new file mode 100644 index 0000000..d4e7580 --- /dev/null +++ b/src/stats/tests/user_octets_sub_security_tests.rs @@ -0,0 +1,151 @@ +use super::*; +use std::sync::Arc; +use std::thread; + +#[test] +fn sub_user_octets_to_underflow_saturates_at_zero() { + let stats = Stats::new(); + let user = "sub-underflow-user"; + + stats.add_user_octets_to(user, 3); + stats.sub_user_octets_to(user, 100); + + assert_eq!(stats.get_user_total_octets(user), 0); +} + +#[test] +fn sub_user_octets_to_does_not_affect_octets_from_client() { + let stats = Stats::new(); + let user = "sub-isolation-user"; + + stats.add_user_octets_from(user, 17); + stats.add_user_octets_to(user, 5); + stats.sub_user_octets_to(user, 3); + + assert_eq!(stats.get_user_total_octets(user), 19); +} + +#[test] +fn light_fuzz_add_sub_model_matches_saturating_reference() { + let stats = Stats::new(); + let user = "sub-fuzz-user"; + let mut seed = 0x91D2_4CB8_EE77_1101u64; + let mut model_to = 0u64; + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let amt = ((seed >> 8) & 0x3f) + 1; + if (seed & 1) == 0 { + stats.add_user_octets_to(user, amt); + model_to = model_to.saturating_add(amt); + } else { + stats.sub_user_octets_to(user, amt); + model_to = model_to.saturating_sub(amt); + } + } + + assert_eq!(stats.get_user_total_octets(user), model_to); +} + +#[test] +fn stress_parallel_add_sub_never_underflows_or_panics() { + let stats = Arc::new(Stats::new()); + let user = "sub-stress-user"; + // Pre-fund with a large offset so subtractions never saturate at zero. + // This guarantees commutative updates, making the final state deterministic. + let base_offset = 10_000_000u64; + stats.add_user_octets_to(user, base_offset); + + let mut workers = Vec::new(); + + for tid in 0..16u64 { + let stats_for_thread = Arc::clone(&stats); + workers.push(thread::spawn(move || { + let mut seed = 0xD00D_1000_0000_0000u64 ^ tid; + let mut net_delta = 0i64; + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let amt = ((seed >> 8) & 0x1f) + 1; + + if (seed & 1) == 0 { + stats_for_thread.add_user_octets_to(user, amt); + net_delta += amt as i64; + } else { + stats_for_thread.sub_user_octets_to(user, amt); + net_delta -= amt as i64; + } + } + + net_delta + })); + } + + let mut expected_net_delta = 0i64; + for worker in workers { + expected_net_delta += worker + .join() + .expect("sub-user stress worker must not panic"); + } + + let expected_total = (base_offset as i64 + expected_net_delta) as u64; + let total = stats.get_user_total_octets(user); + assert_eq!( + total, expected_total, + "concurrent add/sub lost updates or suffered ABA races" + ); +} + +#[test] +fn sub_user_octets_to_missing_user_is_noop() { + let stats = Stats::new(); + stats.sub_user_octets_to("missing-user", 1024); + assert_eq!(stats.get_user_total_octets("missing-user"), 0); +} + +#[test] +fn stress_parallel_per_user_models_remain_exact() { + let stats = Arc::new(Stats::new()); + let mut workers = Vec::new(); + + for tid in 0..16u64 { + let stats_for_thread = Arc::clone(&stats); + workers.push(thread::spawn(move || { + let user = format!("sub-per-user-{tid}"); + let mut seed = 0xFACE_0000_0000_0000u64 ^ tid; + let mut model = 0u64; + + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let amt = ((seed >> 8) & 0x3f) + 1; + + if (seed & 1) == 0 { + stats_for_thread.add_user_octets_to(&user, amt); + model = model.saturating_add(amt); + } else { + stats_for_thread.sub_user_octets_to(&user, amt); + model = model.saturating_sub(amt); + } + } + + (user, model) + })); + } + + for worker in workers { + let (user, model) = worker + .join() + .expect("per-user subtract stress worker must not panic"); + assert_eq!( + stats.get_user_total_octets(&user), + model, + "per-user parallel model diverged" + ); + } +} \ No newline at end of file