diff --git a/AGENTS.md b/AGENTS.md index e7f94a5..c17cc76 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -390,6 +390,12 @@ you MUST explain why existing invariants remain valid. - Do not modify existing tests unless the task explicitly requires it. - Do not weaken assertions. - Preserve determinism in testable components. +- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new. +- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code. +- Differential/model catches logic drift over time. +- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find subtle bugs; 10,000 concurrent calls will find it on the first run. +- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly. +- Dead parameter is a code smell rule. ### 15. Security Constraints diff --git a/Cargo.lock b/Cargo.lock index 89eefd6..7749ef5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -425,6 +425,32 @@ dependencies = [ "cipher", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -517,6 +543,12 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "filetime" version = "0.2.27" @@ -1609,7 +1641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha", - "rand_core", + "rand_core 0.9.5", ] [[package]] @@ -1619,9 +1651,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.5", ] +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + [[package]] name = "rand_core" version = "0.9.5" @@ -1637,7 +1675,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" dependencies = [ - "rand_core", + "rand_core 0.9.5", ] [[package]] @@ -2093,7 +2131,7 @@ dependencies = [ [[package]] name = "telemt" -version = "3.3.19" +version = "3.3.20" dependencies = [ "aes", "anyhow", @@ -2145,6 +2183,7 @@ dependencies = [ "tracing-subscriber", "url", "webpki-roots 0.26.11", + "x25519-dalek", "x509-parser", "zeroize", ] @@ -3144,6 +3183,18 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4", + "serde", + "zeroize", +] + [[package]] name = "x509-parser" version = "0.15.1" diff --git a/Cargo.toml b/Cargo.toml index 4e12cad..a47a4e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,7 @@ regex = "1.11" crossbeam-queue = "0.3" num-bigint = "0.4" num-traits = "0.2" +x25519-dalek = "2" anyhow = "1.0" # HTTP diff --git a/config.toml b/config.toml index 63fa4ae..f4eb3ae 100644 --- a/config.toml +++ b/config.toml @@ -32,6 +32,7 @@ show = "*" port = 443 # proxy_protocol = false # Enable if behind HAProxy/nginx with PROXY protocol # metrics_port = 9090 +# metrics_listen = "0.0.0.0:9090" # Listen address for metrics (overrides metrics_port) # metrics_whitelist = ["127.0.0.1", "::1", "0.0.0.0/0"] [server.api] diff --git a/docs/fronting-splitting/TLS-F-TCP-S.ru.md b/docs/fronting-splitting/TLS-F-TCP-S.ru.md index 6ae6f05..1f9f872 100644 --- a/docs/fronting-splitting/TLS-F-TCP-S.ru.md +++ b/docs/fronting-splitting/TLS-F-TCP-S.ru.md @@ -38,8 +38,9 @@ umweltschutz.de -> A-запись 198.18.88.88 В конфигурации Telemt: -``` -tls_domain = umweltschutz.de +```toml +[censorship] +tls_domain = "umweltschutz.de" ``` Этот домен используется клиентом как SNI в ClientHello @@ -56,8 +57,9 @@ tls_domain = umweltschutz.de В конфигурации Telemt: -``` -mask_host = 127.0.0.1 +```toml +[censorship] +mask_host = "127.0.0.1" mask_port = 8443 ``` @@ -151,16 +153,18 @@ mask_host:mask_port Например: -``` -tls_domain = github.com -mask_host = github.com +```toml +[censorship] +tls_domain = "github.com" +mask_host = "github.com" mask_port = 443 ``` или -``` -mask_host = 140.82.121.4 +```toml +[censorship] +mask_host = "140.82.121.4" ``` В этом случае: diff --git a/src/cli.rs b/src/cli.rs index a1182a7..8ea9c9f 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -239,7 +239,7 @@ tls_full_cert_ttl_secs = 90 [access] replay_check_len = 65536 -replay_window_secs = 1800 +replay_window_secs = 120 ignore_time_skew = false [access.users] diff --git a/src/config/defaults.rs b/src/config/defaults.rs index ea9250d..73b12d8 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -73,7 +73,9 @@ pub(crate) fn default_replay_check_len() -> usize { } pub(crate) fn default_replay_window_secs() -> u64 { - 1800 + // Keep replay cache TTL tight by default to reduce replay surface. + // Deployments with higher RTT or longer reconnect jitter can override this in config. + 120 } pub(crate) fn default_handshake_timeout() -> u64 { @@ -456,11 +458,11 @@ pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 { } pub(crate) fn default_server_hello_delay_min_ms() -> u64 { - 0 + 8 } pub(crate) fn default_server_hello_delay_max_ms() -> u64 { - 0 + 24 } pub(crate) fn default_alpn_enforce() -> bool { diff --git a/src/config/types.rs b/src/config/types.rs index 808698d..0c5f09b 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1163,9 +1163,17 @@ pub struct ServerConfig { #[serde(default)] pub proxy_protocol_trusted_cidrs: Vec, + /// Port for the Prometheus-compatible metrics endpoint. + /// Enables metrics when set; binds on all interfaces (dual-stack) by default. #[serde(default)] pub metrics_port: Option, + /// Listen address for metrics in `IP:PORT` format (e.g. `"127.0.0.1:9090"`). + /// When set, takes precedence over `metrics_port` and binds on the specified address only. + #[serde(default)] + pub metrics_listen: Option, + + /// CIDR whitelist for the metrics endpoint. #[serde(default = "default_metrics_whitelist")] pub metrics_whitelist: Vec, @@ -1194,6 +1202,7 @@ impl Default for ServerConfig { proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), proxy_protocol_trusted_cidrs: Vec::new(), metrics_port: None, + metrics_listen: None, metrics_whitelist: default_metrics_whitelist(), api: ApiConfig::default(), listeners: Vec::new(), diff --git a/src/ip_tracker_regression_tests.rs b/src/ip_tracker_regression_tests.rs new file mode 100644 index 0000000..5d6b358 --- /dev/null +++ b/src/ip_tracker_regression_tests.rs @@ -0,0 +1,450 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::Duration; + +use crate::config::UserMaxUniqueIpsMode; +use crate::ip_tracker::UserIpTracker; + +fn ip_from_idx(idx: u32) -> IpAddr { + let a = 10u8; + let b = ((idx / 65_536) % 256) as u8; + let c = ((idx / 256) % 256) as u8; + let d = (idx % 256) as u8; + IpAddr::V4(Ipv4Addr::new(a, b, c, d)) +} + +#[tokio::test] +async fn active_window_enforces_large_unique_ip_burst() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("burst_user", 64).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30) + .await; + + for idx in 0..64 { + assert!(tracker.check_and_add("burst_user", ip_from_idx(idx)).await.is_ok()); + } + assert!(tracker.check_and_add("burst_user", ip_from_idx(9_999)).await.is_err()); + assert_eq!(tracker.get_active_ip_count("burst_user").await, 64); +} + +#[tokio::test] +async fn global_limit_applies_across_many_users() { + let tracker = UserIpTracker::new(); + tracker.load_limits(3, &HashMap::new()).await; + + for user_idx in 0..150u32 { + let user = format!("u{}", user_idx); + assert!(tracker.check_and_add(&user, ip_from_idx(user_idx * 10)).await.is_ok()); + assert!(tracker + .check_and_add(&user, ip_from_idx(user_idx * 10 + 1)) + .await + .is_ok()); + assert!(tracker + .check_and_add(&user, ip_from_idx(user_idx * 10 + 2)) + .await + .is_ok()); + assert!(tracker + .check_and_add(&user, ip_from_idx(user_idx * 10 + 3)) + .await + .is_err()); + } + + assert_eq!(tracker.get_stats().await.len(), 150); +} + +#[tokio::test] +async fn user_zero_override_falls_back_to_global_limit() { + let tracker = UserIpTracker::new(); + let mut limits = HashMap::new(); + limits.insert("target".to_string(), 0); + tracker.load_limits(2, &limits).await; + + assert!(tracker.check_and_add("target", ip_from_idx(1)).await.is_ok()); + assert!(tracker.check_and_add("target", ip_from_idx(2)).await.is_ok()); + assert!(tracker.check_and_add("target", ip_from_idx(3)).await.is_err()); + assert_eq!(tracker.get_user_limit("target").await, Some(2)); +} + +#[tokio::test] +async fn remove_ip_is_idempotent_after_counter_reaches_zero() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("u", 2).await; + let ip = ip_from_idx(42); + + tracker.check_and_add("u", ip).await.unwrap(); + tracker.remove_ip("u", ip).await; + tracker.remove_ip("u", ip).await; + tracker.remove_ip("u", ip).await; + + assert_eq!(tracker.get_active_ip_count("u").await, 0); + assert!(!tracker.is_ip_active("u", ip).await); +} + +#[tokio::test] +async fn clear_user_ips_resets_active_and_recent() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("u", 10).await; + + for idx in 0..6 { + tracker.check_and_add("u", ip_from_idx(idx)).await.unwrap(); + } + + tracker.clear_user_ips("u").await; + + assert_eq!(tracker.get_active_ip_count("u").await, 0); + let counts = tracker + .get_recent_counts_for_users(&["u".to_string()]) + .await; + assert_eq!(counts.get("u").copied().unwrap_or(0), 0); +} + +#[tokio::test] +async fn clear_all_resets_multi_user_state() { + let tracker = UserIpTracker::new(); + + for user_idx in 0..80u32 { + let user = format!("u{}", user_idx); + for ip_idx in 0..3 { + tracker + .check_and_add(&user, ip_from_idx(user_idx * 100 + ip_idx)) + .await + .unwrap(); + } + } + + tracker.clear_all().await; + + assert!(tracker.get_stats().await.is_empty()); + let users = (0..80u32) + .map(|idx| format!("u{}", idx)) + .collect::>(); + let recent = tracker.get_recent_counts_for_users(&users).await; + assert!(recent.values().all(|count| *count == 0)); +} + +#[tokio::test] +async fn get_active_ips_for_users_are_sorted() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("user", 10).await; + + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9))) + .await + .unwrap(); + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))) + .await + .unwrap(); + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5))) + .await + .unwrap(); + + let map = tracker + .get_active_ips_for_users(&["user".to_string()]) + .await; + let ips = map.get("user").cloned().unwrap_or_default(); + + assert_eq!( + ips, + vec![ + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)), + ] + ); +} + +#[tokio::test] +async fn get_recent_ips_for_users_are_sorted() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("user", 10).await; + + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9))) + .await + .unwrap(); + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1))) + .await + .unwrap(); + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5))) + .await + .unwrap(); + + let map = tracker + .get_recent_ips_for_users(&["user".to_string()]) + .await; + let ips = map.get("user").cloned().unwrap_or_default(); + + assert_eq!( + ips, + vec![ + IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)), + IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)), + IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)), + ] + ); +} + +#[tokio::test] +async fn time_window_expires_for_large_rotation() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("tw", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) + .await; + + tracker.check_and_add("tw", ip_from_idx(1)).await.unwrap(); + tracker.remove_ip("tw", ip_from_idx(1)).await; + assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_err()); + + tokio::time::sleep(Duration::from_millis(1_100)).await; + assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_ok()); +} + +#[tokio::test] +async fn combined_mode_blocks_recent_after_disconnect() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("cmb", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::Combined, 2) + .await; + + tracker.check_and_add("cmb", ip_from_idx(11)).await.unwrap(); + tracker.remove_ip("cmb", ip_from_idx(11)).await; + + assert!(tracker.check_and_add("cmb", ip_from_idx(12)).await.is_err()); +} + +#[tokio::test] +async fn load_limits_replaces_large_limit_map() { + let tracker = UserIpTracker::new(); + let mut first = HashMap::new(); + let mut second = HashMap::new(); + + for idx in 0..300usize { + first.insert(format!("u{}", idx), 2usize); + } + for idx in 150..450usize { + second.insert(format!("u{}", idx), 4usize); + } + + tracker.load_limits(0, &first).await; + tracker.load_limits(0, &second).await; + + assert_eq!(tracker.get_user_limit("u20").await, None); + assert_eq!(tracker.get_user_limit("u200").await, Some(4)); + assert_eq!(tracker.get_user_limit("u420").await, Some(4)); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_same_user_unique_ip_pressure_stays_bounded() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("hot", 32).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30) + .await; + + let mut handles = Vec::new(); + for worker in 0..16u32 { + let tracker_cloned = tracker.clone(); + handles.push(tokio::spawn(async move { + let base = worker * 200; + for step in 0..200u32 { + let _ = tracker_cloned + .check_and_add("hot", ip_from_idx(base + step)) + .await; + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + assert!(tracker.get_active_ip_count("hot").await <= 32); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_many_users_isolate_limits() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.load_limits(4, &HashMap::new()).await; + + let mut handles = Vec::new(); + for user_idx in 0..120u32 { + let tracker_cloned = tracker.clone(); + handles.push(tokio::spawn(async move { + let user = format!("u{}", user_idx); + for ip_idx in 0..10u32 { + let _ = tracker_cloned + .check_and_add(&user, ip_from_idx(user_idx * 1_000 + ip_idx)) + .await; + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let stats = tracker.get_stats().await; + assert_eq!(stats.len(), 120); + assert!(stats.iter().all(|(_, active, limit)| *active <= 4 && *limit == 4)); +} + +#[tokio::test] +async fn same_ip_reconnect_high_frequency_keeps_single_unique() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("same", 2).await; + let ip = ip_from_idx(9); + + for _ in 0..2_000 { + tracker.check_and_add("same", ip).await.unwrap(); + } + + assert_eq!(tracker.get_active_ip_count("same").await, 1); + assert!(tracker.is_ip_active("same", ip).await); +} + +#[tokio::test] +async fn format_stats_contains_expected_limited_and_unlimited_markers() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("limited", 2).await; + tracker.check_and_add("limited", ip_from_idx(1)).await.unwrap(); + tracker.check_and_add("open", ip_from_idx(2)).await.unwrap(); + + let text = tracker.format_stats().await; + + assert!(text.contains("limited")); + assert!(text.contains("open")); + assert!(text.contains("unlimited")); +} + +#[tokio::test] +async fn stats_report_global_default_for_users_without_override() { + let tracker = UserIpTracker::new(); + tracker.load_limits(5, &HashMap::new()).await; + + tracker.check_and_add("a", ip_from_idx(1)).await.unwrap(); + tracker.check_and_add("b", ip_from_idx(2)).await.unwrap(); + + let stats = tracker.get_stats().await; + assert!(stats.iter().any(|(user, _, limit)| user == "a" && *limit == 5)); + assert!(stats.iter().any(|(user, _, limit)| user == "b" && *limit == 5)); +} + +#[tokio::test] +async fn stress_cycle_add_remove_clear_preserves_empty_end_state() { + let tracker = UserIpTracker::new(); + + for cycle in 0..50u32 { + let user = format!("cycle{}", cycle); + tracker.set_user_limit(&user, 128).await; + + for ip_idx in 0..128u32 { + tracker + .check_and_add(&user, ip_from_idx(cycle * 10_000 + ip_idx)) + .await + .unwrap(); + } + + for ip_idx in 0..128u32 { + tracker + .remove_ip(&user, ip_from_idx(cycle * 10_000 + ip_idx)) + .await; + } + + tracker.clear_user_ips(&user).await; + } + + assert!(tracker.get_stats().await.is_empty()); +} + +#[tokio::test] +async fn remove_unknown_user_or_ip_does_not_corrupt_state() { + let tracker = UserIpTracker::new(); + + tracker.remove_ip("no_user", ip_from_idx(1)).await; + tracker.check_and_add("x", ip_from_idx(2)).await.unwrap(); + tracker.remove_ip("x", ip_from_idx(3)).await; + + assert_eq!(tracker.get_active_ip_count("x").await, 1); + assert!(tracker.is_ip_active("x", ip_from_idx(2)).await); +} + +#[tokio::test] +async fn active_and_recent_views_match_after_mixed_workload() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("mix", 16).await; + + for ip_idx in 0..12u32 { + tracker.check_and_add("mix", ip_from_idx(ip_idx)).await.unwrap(); + } + for ip_idx in 0..6u32 { + tracker.remove_ip("mix", ip_from_idx(ip_idx)).await; + } + + let active = tracker + .get_active_ips_for_users(&["mix".to_string()]) + .await + .get("mix") + .cloned() + .unwrap_or_default(); + let recent_count = tracker + .get_recent_counts_for_users(&["mix".to_string()]) + .await + .get("mix") + .copied() + .unwrap_or(0); + + assert_eq!(active.len(), 6); + assert!(recent_count >= active.len()); + assert!(recent_count <= 12); +} + +#[tokio::test] +async fn global_limit_switch_updates_enforcement_immediately() { + let tracker = UserIpTracker::new(); + tracker.load_limits(2, &HashMap::new()).await; + + assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_err()); + + tracker.clear_user_ips("u").await; + tracker.load_limits(4, &HashMap::new()).await; + + assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(4)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(5)).await.is_err()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("cc", 8).await; + + let mut handles = Vec::new(); + for worker in 0..8u32 { + let tracker_cloned = tracker.clone(); + handles.push(tokio::spawn(async move { + let ip = ip_from_idx(50 + worker); + for _ in 0..500u32 { + let _ = tracker_cloned.check_and_add("cc", ip).await; + tracker_cloned.remove_ip("cc", ip).await; + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + assert!(tracker.get_active_ip_count("cc").await <= 8); +} diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index d4bda4d..c2233c7 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -279,11 +279,32 @@ pub(crate) async fn spawn_metrics_if_configured( ip_tracker: Arc, config_rx: watch::Receiver>, ) { - if let Some(port) = config.server.metrics_port { + // metrics_listen takes precedence; fall back to metrics_port for backward compat. + let metrics_target: Option<(u16, Option)> = + if let Some(ref listen) = config.server.metrics_listen { + match listen.parse::() { + Ok(addr) => Some((addr.port(), Some(listen.clone()))), + Err(e) => { + startup_tracker + .skip_component( + COMPONENT_METRICS_START, + Some(format!("invalid metrics_listen \"{}\": {}", listen, e)), + ) + .await; + None + } + } + } else { + config.server.metrics_port.map(|p| (p, None)) + }; + + if let Some((port, listen)) = metrics_target { + let fallback_label = format!("port {}", port); + let label = listen.as_deref().unwrap_or(&fallback_label); startup_tracker .start_component( COMPONENT_METRICS_START, - Some(format!("spawn metrics endpoint on {}", port)), + Some(format!("spawn metrics endpoint on {}", label)), ) .await; let stats = stats.clone(); @@ -294,6 +315,7 @@ pub(crate) async fn spawn_metrics_if_configured( tokio::spawn(async move { metrics::serve( port, + listen, stats, beobachten, ip_tracker_metrics, @@ -308,7 +330,7 @@ pub(crate) async fn spawn_metrics_if_configured( Some("metrics task spawned".to_string()), ) .await; - } else { + } else if config.server.metrics_listen.is_none() { startup_tracker .skip_component( COMPONENT_METRICS_START, diff --git a/src/main.rs b/src/main.rs index 73ada8c..2cfbe28 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,8 @@ mod config; mod crypto; mod error; mod ip_tracker; +#[cfg(test)] +mod ip_tracker_regression_tests; mod maestro; mod metrics; mod network; diff --git a/src/metrics.rs b/src/metrics.rs index 02edfd7..f4f8a2e 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -21,6 +21,7 @@ use crate::transport::{ListenOptions, create_listener}; pub async fn serve( port: u16, + listen: Option, stats: Arc, beobachten: Arc, ip_tracker: Arc, @@ -28,6 +29,33 @@ pub async fn serve( whitelist: Vec, ) { let whitelist = Arc::new(whitelist); + + // If `metrics_listen` is set, bind on that single address only. + if let Some(ref listen_addr) = listen { + let addr: SocketAddr = match listen_addr.parse() { + Ok(a) => a, + Err(e) => { + warn!(error = %e, "Invalid metrics_listen address: {}", listen_addr); + return; + } + }; + let is_ipv6 = addr.is_ipv6(); + match bind_metrics_listener(addr, is_ipv6) { + Ok(listener) => { + info!("Metrics endpoint: http://{}/metrics and /beobachten", addr); + serve_listener( + listener, stats, beobachten, ip_tracker, config_rx, whitelist, + ) + .await; + } + Err(e) => { + warn!(error = %e, "Failed to bind metrics on {}", addr); + } + } + return; + } + + // Fallback: bind on 0.0.0.0 and [::] using metrics_port. let mut listener_v4 = None; let mut listener_v6 = None; diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 0f54245..3f9f981 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -11,9 +11,8 @@ use crate::crypto::{sha256_hmac, SecureRandom}; use crate::error::ProxyError; use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; -use num_bigint::BigUint; -use num_traits::One; use subtle::ConstantTimeEq; +use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; // ============= Public Constants ============= @@ -27,8 +26,12 @@ pub const TLS_DIGEST_POS: usize = 11; pub const TLS_DIGEST_HALF_LEN: usize = 16; /// Time skew limits for anti-replay (in seconds) -pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before -pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after +/// +/// The default window is intentionally narrow to reduce replay acceptance. +/// Operators with known clock-drifted clients should tune deployment config +/// (for example replay-window policy) to match their environment. +pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before +pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after /// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; @@ -117,27 +120,6 @@ impl TlsExtensionBuilder { self } - /// Add ALPN extension with a single selected protocol. - fn add_alpn(&mut self, proto: &[u8]) -> &mut Self { - // Extension type: ALPN (0x0010) - self.extensions.extend_from_slice(&extension_type::ALPN.to_be_bytes()); - - // ALPN extension format: - // extension_data length (2 bytes) - // protocols length (2 bytes) - // protocol name length (1 byte) - // protocol name bytes - let proto_len = proto.len() as u8; - let list_len: u16 = 1 + u16::from(proto_len); - let ext_len: u16 = 2 + list_len; - - self.extensions.extend_from_slice(&ext_len.to_be_bytes()); - self.extensions.extend_from_slice(&list_len.to_be_bytes()); - self.extensions.push(proto_len); - self.extensions.extend_from_slice(proto); - self - } - /// Build final extensions with length prefix fn build(self) -> Vec { let mut result = Vec::with_capacity(2 + self.extensions.len()); @@ -173,8 +155,6 @@ struct ServerHelloBuilder { compression: u8, /// Extensions extensions: TlsExtensionBuilder, - /// Selected ALPN protocol (if any) - alpn: Option>, } impl ServerHelloBuilder { @@ -185,7 +165,6 @@ impl ServerHelloBuilder { cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256, compression: 0x00, extensions: TlsExtensionBuilder::new(), - alpn: None, } } @@ -200,18 +179,9 @@ impl ServerHelloBuilder { self } - fn with_alpn(mut self, proto: Option>) -> Self { - self.alpn = proto; - self - } - /// Build ServerHello message (without record header) fn build_message(&self) -> Vec { - let mut ext_builder = self.extensions.clone(); - if let Some(ref alpn) = self.alpn { - ext_builder.add_alpn(alpn); - } - let extensions = ext_builder.extensions.clone(); + let extensions = self.extensions.extensions.clone(); let extensions_len = extensions.len() as u16; // Calculate total length @@ -316,7 +286,14 @@ pub fn validate_tls_handshake_with_replay_window( }; let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); - let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32); + // Boot-time bypass and ignore_time_skew serve different compatibility paths. + // When skew checks are disabled, force boot-time cap to zero to prevent + // accidental future coupling of boot-time logic into the ignore-skew path. + let boot_time_cap_secs = if ignore_time_skew { + 0 + } else { + BOOT_TIME_MAX_SECS.min(replay_window_u32) + }; validate_tls_handshake_at_time_with_boot_cap( handshake, @@ -369,6 +346,9 @@ fn validate_tls_handshake_at_time_with_boot_cap( // Extract session ID let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; let session_id_len = handshake.get(session_id_len_pos).copied()? as usize; + if session_id_len > 32 { + return None; + } let session_id_start = session_id_len_pos + 1; if handshake.len() < session_id_start + session_id_len { @@ -411,7 +391,7 @@ fn validate_tls_handshake_at_time_with_boot_cap( if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) // This is a quirk in some clients that use uptime instead of real time - let is_boot_time = timestamp < boot_time_cap_secs; + let is_boot_time = boot_time_cap_secs > 0 && timestamp < boot_time_cap_secs; if !is_boot_time { let time_diff = now - i64::from(timestamp); if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { @@ -433,27 +413,14 @@ fn validate_tls_handshake_at_time_with_boot_cap( }) } -fn curve25519_prime() -> BigUint { - (BigUint::one() << 255) - BigUint::from(19u32) -} - /// Generate a fake X25519 public key for TLS /// -/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p, -/// which matches Python/C behavior and avoids DPI fingerprinting. +/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint, +/// yielding distribution-consistent public keys for anti-fingerprinting. pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] { - let mut n_bytes = [0u8; 32]; - n_bytes.copy_from_slice(&rng.bytes(32)); - - let n = BigUint::from_bytes_le(&n_bytes); - let p = curve25519_prime(); - let pk = (&n * &n) % &p; - - let mut out = pk.to_bytes_le(); - out.resize(32, 0); - let mut result = [0u8; 32]; - result.copy_from_slice(&out[..32]); - result + let mut scalar = [0u8; 32]; + scalar.copy_from_slice(&rng.bytes(32)); + x25519(scalar, X25519_BASEPOINT_BYTES) } /// Build TLS ServerHello response @@ -470,7 +437,7 @@ pub fn build_server_hello( session_id: &[u8], fake_cert_len: usize, rng: &SecureRandom, - alpn: Option>, + _alpn: Option>, new_session_tickets: u8, ) -> Vec { const MIN_APP_DATA: usize = 64; @@ -482,7 +449,6 @@ pub fn build_server_hello( let server_hello = ServerHelloBuilder::new(session_id.to_vec()) .with_x25519_key(&x25519_key) .with_tls13_version() - .with_alpn(alpn) .build_record(); // Build Change Cipher Spec record @@ -705,10 +671,10 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { return false; } - // TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0) + // TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303. first_bytes[0] == TLS_RECORD_HANDSHAKE && first_bytes[1] == 0x03 - && first_bytes[2] == 0x01 + && (first_bytes[2] == 0x01 || first_bytes[2] == 0x03) } /// Parse TLS record header, returns (record_type, length) diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index 98d7319..9f568b5 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -1,5 +1,8 @@ use super::*; use crate::crypto::sha256_hmac; +use crate::tls_front::emulator::build_emulated_server_hello; +use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource}; +use std::time::SystemTime; /// Build a TLS-handshake-like buffer that contains a valid HMAC digest /// for the given `secret` and `timestamp`. @@ -369,16 +372,16 @@ fn one_byte_session_id_validates_and_is_preserved() { } #[test] -fn max_session_id_len_255_with_valid_digest_is_accepted() { +fn max_session_id_len_255_with_valid_digest_is_rejected_by_rfc_cap() { let secret = b"sid_len_255_test"; let session_id = vec![0xCCu8; 255]; let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id); let secrets = vec![("u".to_string(), secret.to_vec())]; - let result = validate_tls_handshake(&handshake, &secrets, true) - .expect("session_id_len=255 with valid digest must validate"); - assert_eq!(result.session_id.len(), 255); - assert_eq!(result.session_id, session_id); + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "legacy_session_id length > 32 must be rejected even with valid digest" + ); } // ------------------------------------------------------------------ @@ -731,6 +734,246 @@ fn replay_window_cap_still_allows_small_boot_timestamp() { ); } +#[test] +fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() { + let secret = b"ignore_skew_boot_cap_decouple_test"; + let ts: u32 = 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0); + let cap_nonzero = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_MAX_SECS); + + assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC"); + assert!( + cap_nonzero.is_some(), + "ignore_time_skew path must not depend on boot-time cap" + ); + + let a = cap_zero.unwrap(); + let b = cap_nonzero.unwrap(); + assert_eq!(a.user, b.user); + assert_eq!(a.timestamp, b.timestamp); +} + +#[test] +fn adversarial_small_boot_timestamp_matrix_rejected_when_boot_cap_forced_zero() { + let secret = b"boot_cap_zero_matrix_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for ts in 0u32..1024u32 { + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "boot cap=0 must reject timestamp {ts} when skew checks are active" + ); + } +} + +#[test] +fn light_fuzz_boot_cap_zero_rejects_small_timestamp_space() { + let secret = b"boot_cap_zero_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = (s as u32) % 2048; + + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "fuzzed boot-range timestamp {ts} must be rejected when cap=0" + ); + } +} + +#[test] +fn stress_boot_cap_zero_rejection_is_deterministic_under_high_iteration_count() { + let secret = b"boot_cap_zero_stress_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for i in 0u32..20_000u32 { + let ts = i % 4096; + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "iteration {i}: timestamp {ts} must be rejected with cap=0" + ); + } +} + +#[test] +fn replay_window_one_allows_only_zero_timestamp_boot_bypass() { + let secret = b"replay_window_one_boot_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts0 = make_valid_tls_handshake(secret, 0); + let ts1 = make_valid_tls_handshake(secret, 1); + + assert!( + validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 1).is_some(), + "replay_window=1 must allow timestamp 0 via boot-time compatibility" + ); + assert!( + validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 1).is_none(), + "replay_window=1 must reject timestamp 1 on normal wall-clock systems" + ); +} + +#[test] +fn replay_window_two_allows_ts0_ts1_but_rejects_ts2() { + let secret = b"replay_window_two_boot_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts0 = make_valid_tls_handshake(secret, 0); + let ts1 = make_valid_tls_handshake(secret, 1); + let ts2 = make_valid_tls_handshake(secret, 2); + + assert!(validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 2).is_some()); + assert!(validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 2).is_some()); + assert!( + validate_tls_handshake_with_replay_window(&ts2, &secrets, false, 2).is_none(), + "timestamp equal to replay-window cap must not use boot-time bypass" + ); +} + +#[test] +fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disabled() { + let secret = b"skew_boundary_matrix_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for offset in -1500i64..=1500i64 { + let ts_i64 = now - offset; + let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix"); + let h = make_valid_tls_handshake(secret, ts); + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset); + assert_eq!( + accepted, expected, + "offset {offset} must match inclusive skew window when boot bypass is disabled" + ); + } +} + +#[test] +fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() { + let secret = b"skew_outside_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0x0123_4567_89AB_CDEF; + + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let magnitude = 1300i64 + ((s % 2000u64) as i64); + let sign = if (s & 1) == 0 { 1i64 } else { -1i64 }; + let offset = sign * magnitude; + let ts_i64 = now - offset; + let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test"); + + let h = make_valid_tls_handshake(secret, ts); + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + assert!( + !accepted, + "offset {offset} must be rejected outside strict skew window" + ); + } +} + +#[test] +fn stress_boot_disabled_validation_matches_time_diff_oracle() { + let secret = b"boot_disabled_oracle_stress_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0xBADC_0FFE_EE11_2233; + + for _ in 0..25_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = s as u32; + let h = make_valid_tls_handshake(secret, ts); + + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + let time_diff = now - i64::from(ts); + let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff); + assert_eq!( + accepted, expected, + "boot-disabled validation must match pure time-diff oracle" + ); + } +} + +#[test] +fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() { + let now: i64 = 1_700_000_000; + let target_secret = b"target_user_secret"; + let target_ts = (now - 1) as u32; + let handshake = make_valid_tls_handshake(target_secret, target_ts); + + let mut secrets = Vec::new(); + for i in 0..512u32 { + secrets.push((format!("noise-{i}"), format!("noise-secret-{i}").into_bytes())); + } + secrets.push(("target-user".to_string(), target_secret.to_vec())); + + let result = validate_tls_handshake_at_time_with_boot_cap(&handshake, &secrets, false, now, 0) + .expect("matching user should validate within strict skew window"); + assert_eq!(result.user, "target-user"); +} + +#[test] +fn light_fuzz_ignore_time_skew_accepts_wide_timestamp_range_with_valid_hmac() { + let secret = b"ignore_skew_fuzz_accept_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let mut s: u64 = 0xC0FF_EE11_2233_4455; + + for _ in 0..2048 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = s as u32; + + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_with_replay_window(&h, &secrets, true, 60); + assert!( + result.is_some(), + "ignore_time_skew=true must accept valid HMAC for arbitrary timestamp" + ); + } +} + +#[test] +fn light_fuzz_small_replay_window_rejects_far_timestamps_when_skew_enabled() { + let secret = b"replay_window_reject_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + for ts in 300u32..=1323u32 { + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, 0, 300); + assert!( + result.is_none(), + "with skew checks enabled and boot cap=300, timestamp >=300 at now=0 must be rejected" + ); + } +} + // ------------------------------------------------------------------ // Extreme timestamp values // ------------------------------------------------------------------ @@ -897,7 +1140,9 @@ fn first_matching_user_wins_over_later_duplicate_secret() { #[test] fn test_is_tls_handshake() { assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x03, 0x02, 0x00])); assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); assert!(!is_tls_handshake(&[0x16, 0x03])); @@ -945,17 +1190,158 @@ fn test_gen_fake_x25519_key() { } #[test] -fn test_fake_x25519_key_is_quadratic_residue() { - use num_bigint::BigUint; - use num_traits::One; - +fn test_fake_x25519_key_is_nonzero_and_varies() { let rng = crate::crypto::SecureRandom::new(); - let key = gen_fake_x25519_key(&rng); - let p = curve25519_prime(); - let k_num = BigUint::from_bytes_le(&key); - let exponent = (&p - BigUint::one()) >> 1; - let legendre = k_num.modpow(&exponent, &p); - assert_eq!(legendre, BigUint::one()); + let mut unique = std::collections::HashSet::new(); + let mut saw_non_zero = false; + + for _ in 0..64 { + let key = gen_fake_x25519_key(&rng); + if key != [0u8; 32] { + saw_non_zero = true; + } + unique.insert(key); + } + + assert!( + saw_non_zero, + "generated X25519 public keys must not collapse to all-zero output" + ); + assert!( + unique.len() > 1, + "generated X25519 public keys must vary across invocations" + ); +} + +#[test] +fn validate_tls_handshake_rejects_session_id_longer_than_rfc_cap() { + let secret = b"session_id_cap_secret"; + let oversized_sid = vec![0x42u8; 33]; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &oversized_sid); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "legacy_session_id length > 32 must be rejected" + ); +} + +fn server_hello_extension_types(record: &[u8]) -> Vec { + if record.len() < 9 || record[0] != TLS_RECORD_HANDSHAKE || record[5] != 0x02 { + return Vec::new(); + } + + let record_len = u16::from_be_bytes([record[3], record[4]]) as usize; + if record.len() < 5 + record_len { + return Vec::new(); + } + + let hs_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; + let hs_start = 5; + let hs_end = hs_start + 4 + hs_len; + if hs_end > record.len() { + return Vec::new(); + } + + let mut pos = hs_start + 4 + 2 + 32; + if pos >= hs_end { + return Vec::new(); + } + let sid_len = record[pos] as usize; + pos += 1 + sid_len; + if pos + 2 + 1 + 2 > hs_end { + return Vec::new(); + } + + pos += 2 + 1; + let ext_len = u16::from_be_bytes([record[pos], record[pos + 1]]) as usize; + pos += 2; + let ext_end = pos + ext_len; + if ext_end > hs_end { + return Vec::new(); + } + + let mut out = Vec::new(); + while pos + 4 <= ext_end { + let etype = u16::from_be_bytes([record[pos], record[pos + 1]]); + let elen = u16::from_be_bytes([record[pos + 2], record[pos + 3]]) as usize; + pos += 4; + if pos + elen > ext_end { + break; + } + out.push(etype); + pos += elen; + } + out +} + +#[test] +fn build_server_hello_never_places_alpn_in_server_hello_extensions() { + let secret = b"alpn_sh_forbidden"; + let client_digest = [0x11u8; 32]; + let session_id = vec![0xAA; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 1024, + &rng, + Some(b"h2".to_vec()), + 0, + ); + let exts = server_hello_extension_types(&response); + assert!( + !exts.contains(&0x0010), + "ALPN extension must not appear in ServerHello" + ); +} + +#[test] +fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() { + let secret = b"alpn_emulated_forbidden"; + let client_digest = [0x22u8; 32]; + let session_id = vec![0xAB; 32]; + let rng = crate::crypto::SecureRandom::new(); + let cached = CachedTlsData { + server_hello_template: ParsedServerHello { + version: TLS_VERSION, + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }, + cert_info: None, + cert_payload: None, + app_data_records_sizes: vec![1024], + total_app_data_len: 1024, + behavior_profile: TlsBehaviorProfile { + change_cipher_spec_count: 1, + app_data_record_sizes: vec![1024], + ticket_record_sizes: Vec::new(), + source: TlsProfileSource::Default, + }, + fetched_at: SystemTime::now(), + domain: "example.com".to_string(), + }; + + let response = build_emulated_server_hello( + secret, + &client_digest, + &session_id, + &cached, + false, + &rng, + Some(b"h2".to_vec()), + 0, + ); + let exts = server_hello_extension_types(&response); + assert!( + !exts.contains(&0x0010), + "ALPN extension must not appear in emulated ServerHello" + ); } #[test] @@ -1394,3 +1780,191 @@ fn server_hello_application_data_payload_varies_across_runs() { "ApplicationData payload should vary across runs to reduce fingerprintability" ); } + +#[test] +fn replay_window_zero_disables_boot_bypass_for_any_nonzero_timestamp() { + let secret = b"window_zero_boot_bypass_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts1 = make_valid_tls_handshake(secret, 1); + assert!( + validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 0).is_none(), + "replay_window_secs=0 must reject nonzero timestamps even in boot-time range" + ); + + let ts0 = make_valid_tls_handshake(secret, 0); + assert!( + validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 0).is_none(), + "replay_window_secs=0 enforces strict skew check and rejects timestamp=0 on normal wall-clock systems" + ); +} + +#[test] +fn large_replay_window_does_not_expand_time_skew_acceptance() { + let secret = b"large_replay_window_skew_bound_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + let ts_far_past = (now - 600) as u32; + let valid = make_valid_tls_handshake(secret, ts_far_past); + assert!( + validate_tls_handshake_with_replay_window(&valid, &secrets, false, 86_400).is_none(), + "large replay window must not relax strict skew check once boot-time bypass is not in play" + ); +} + +#[test] +fn parse_tls_record_header_accepts_tls_version_constant() { + let header = [TLS_RECORD_HANDSHAKE, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x2A]; + let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted"); + assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE); + assert_eq!(parsed.1, 42); +} + +#[test] +fn server_hello_clamps_fake_cert_len_lower_bound() { + let secret = b"fake_cert_lower_bound_test"; + let client_digest = [0x11u8; TLS_DIGEST_LEN]; + let session_id = vec![0x77; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello(secret, &client_digest, &session_id, 1, &rng, None, 0); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + + assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); + assert_eq!(app_len, 64, "fake cert payload must be clamped to minimum 64 bytes"); +} + +#[test] +fn server_hello_clamps_fake_cert_len_upper_bound() { + let secret = b"fake_cert_upper_bound_test"; + let client_digest = [0x22u8; TLS_DIGEST_LEN]; + let session_id = vec![0x66; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello(secret, &client_digest, &session_id, 65_535, &rng, None, 0); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + + assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); + assert_eq!(app_len, 16_640, "fake cert payload must be clamped to TLS record max bound"); +} + +#[test] +fn server_hello_new_session_ticket_count_matches_configuration() { + let secret = b"ticket_count_surface_test"; + let client_digest = [0x33u8; TLS_DIGEST_LEN]; + let session_id = vec![0x55; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let tickets: u8 = 3; + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, tickets); + + let mut pos = 0usize; + let mut app_records = 0usize; + while pos + 5 <= response.len() { + let rtype = response[pos]; + let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; + let next = pos + 5 + rlen; + assert!(next <= response.len(), "TLS record must stay inside response bounds"); + if rtype == TLS_RECORD_APPLICATION { + app_records += 1; + } + pos = next; + } + + assert_eq!( + app_records, + 1 + tickets as usize, + "response must contain one main application record plus configured ticket-like tail records" + ); +} + +#[test] +fn exhaustive_tls_minor_version_classification_matches_policy() { + for minor in 0u8..=u8::MAX { + let first = [TLS_RECORD_HANDSHAKE, 0x03, minor]; + let expected = minor == 0x01 || minor == 0x03; + assert_eq!( + is_tls_handshake(&first), + expected, + "minor version {minor:#04x} classification mismatch" + ); + } +} + +#[test] +fn light_fuzz_tls_header_classifier_and_parser_policy_consistency() { + // Deterministic xorshift state keeps this fuzz test reproducible. + let mut s: u64 = 0x9E37_79B9_AA95_5A5D; + + for _ in 0..10_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let header = [ + (s & 0xff) as u8, + ((s >> 8) & 0xff) as u8, + ((s >> 16) & 0xff) as u8, + ((s >> 24) & 0xff) as u8, + ((s >> 32) & 0xff) as u8, + ]; + + let classified = is_tls_handshake(&header[..3]); + let expected_classified = header[0] == TLS_RECORD_HANDSHAKE + && header[1] == 0x03 + && (header[2] == 0x01 || header[2] == 0x03); + assert_eq!( + classified, + expected_classified, + "classifier policy mismatch for header {header:02x?}" + ); + + let parsed = parse_tls_record_header(&header); + let expected_parsed = header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]); + assert_eq!( + parsed.is_some(), + expected_parsed, + "parser policy mismatch for header {header:02x?}" + ); + } +} + +#[test] +fn stress_random_noise_handshakes_never_authenticate() { + let secret = b"stress_noise_secret"; + let secrets = vec![("noise-user".to_string(), secret.to_vec())]; + + // Deterministic xorshift state keeps this stress test reproducible. + let mut s: u64 = 0xD1B5_4A32_9C6E_77F1; + + for _ in 0..5_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let len = 1 + ((s as usize) % 196); + let mut buf = vec![0u8; len]; + for b in &mut buf { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + *b = (s & 0xff) as u8; + } + + assert!( + validate_tls_handshake(&buf, &secrets, true).is_none(), + "random noise must never authenticate" + ); + } +} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 5ccbd40..199f775 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -24,6 +24,72 @@ enum HandshakeOutcome { Handled, } +#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"] +struct UserConnectionReservation { + stats: Arc, + ip_tracker: Arc, + user: String, + ip: IpAddr, + active: bool, + runtime_handle: Option, +} + +impl UserConnectionReservation { + fn new(stats: Arc, ip_tracker: Arc, user: String, ip: IpAddr) -> Self { + let runtime_handle = tokio::runtime::Handle::try_current().ok(); + Self { + stats, + ip_tracker, + user, + ip, + active: true, + runtime_handle, + } + } + + async fn release(mut self) { + if !self.active { + return; + } + self.ip_tracker.remove_ip(&self.user, self.ip).await; + self.active = false; + self.stats.decrement_user_curr_connects(&self.user); + } +} + +impl Drop for UserConnectionReservation { + fn drop(&mut self) { + if !self.active { + return; + } + self.active = false; + self.stats.decrement_user_curr_connects(&self.user); + + if let Some(handle) = &self.runtime_handle { + let ip_tracker = self.ip_tracker.clone(); + let user = self.user.clone(); + let ip = self.ip; + let handle = handle.clone(); + handle.spawn(async move { + ip_tracker.remove_ip(&user, ip).await; + }); + } else if let Ok(handle) = tokio::runtime::Handle::try_current() { + let ip_tracker = self.ip_tracker.clone(); + let user = self.user.clone(); + let ip = self.ip; + handle.spawn(async move { + ip_tracker.remove_ip(&user, ip).await; + }); + } else { + warn!( + user = %self.user, + ip = %self.ip, + "UserConnectionReservation dropped without Tokio runtime; IP reservation cleanup skipped" + ); + } + } +} + use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{HandshakeResult, ProxyError, Result, StreamError}; @@ -45,7 +111,19 @@ use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; fn beobachten_ttl(config: &ProxyConfig) -> Duration { - Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)) + let minutes = config.general.beobachten_minutes; + if minutes == 0 { + static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock = OnceLock::new(); + let warned = BEOBACHTEN_ZERO_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + "general.beobachten_minutes=0 is insecure because entries expire immediately; forcing minimum TTL to 1 minute" + ); + } + return Duration::from_secs(60); + } + + Duration::from_secs(minutes.saturating_mul(60)) } fn record_beobachten_class( @@ -90,6 +168,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { trusted.iter().any(|cidr| cidr.contains(peer_ip)) } +fn synthetic_local_addr(port: u16) -> SocketAddr { + SocketAddr::from(([0, 0, 0, 0], port)) +} + pub async fn handle_client_stream( mut stream: S, peer: SocketAddr, @@ -113,9 +195,7 @@ where let mut real_peer = normalize_ip(peer); // For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst - let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) - .parse() - .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + let mut local_addr = synthetic_local_addr(config.server.port); if proxy_protocol_enabled { let proxy_header_timeout = Duration::from_millis( @@ -426,7 +506,6 @@ impl RunningClientHandler { pub async fn run(self) -> Result<()> { self.stats.increment_connects_all(); let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, "New connection"); if let Err(e) = configure_client_socket( @@ -557,7 +636,6 @@ impl RunningClientHandler { let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); @@ -570,7 +648,6 @@ impl RunningClientHandler { async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -694,7 +771,6 @@ impl RunningClientHandler { async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); @@ -798,10 +874,22 @@ impl RunningClientHandler { { let user = success.user.clone(); - if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await { - warn!(user = %user, error = %e, "User limit exceeded"); - return Err(e); - } + let user_limit_reservation = + match Self::acquire_user_connection_reservation_static( + &user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + { + Ok(reservation) => reservation, + Err(e) => { + warn!(user = %user, error = %e, "User admission check failed"); + return Err(e); + } + }; let route_snapshot = route_runtime.snapshot(); let session_id = rng.u64(); @@ -858,15 +946,68 @@ impl RunningClientHandler { ) .await }; - - stats.decrement_user_curr_connects(&user); - ip_tracker.remove_ip(&user, peer_addr.ip()).await; + user_limit_reservation.release().await; relay_result } + async fn acquire_user_connection_reservation_static( + user: &str, + config: &ProxyConfig, + stats: Arc, + peer_addr: SocketAddr, + ip_tracker: Arc, + ) -> Result { + if let Some(expiration) = config.access.user_expirations.get(user) + && chrono::Utc::now() > *expiration + { + return Err(ProxyError::UserExpired { + user: user.to_string(), + }); + } + + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => {} + Err(reason) => { + stats.decrement_user_curr_connects(user); + warn!( + user = %user, + ip = %peer_addr.ip(), + reason = %reason, + "IP limit exceeded" + ); + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + } + + Ok(UserConnectionReservation::new( + stats, + ip_tracker, + user.to_string(), + peer_addr.ip(), + )) + } + + #[cfg(test)] async fn check_user_limits_static( - user: &str, - config: &ProxyConfig, + user: &str, + config: &ProxyConfig, stats: &Stats, peer_addr: SocketAddr, ip_tracker: &UserIpTracker, @@ -899,7 +1040,10 @@ impl RunningClientHandler { } match ip_tracker.check_and_add(user, peer_addr.ip()).await { - Ok(()) => {} + Ok(()) => { + ip_tracker.remove_ip(user, peer_addr.ip()).await; + stats.decrement_user_curr_connects(user); + } Err(reason) => { stats.decrement_user_curr_connects(user); warn!( diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 415cafd..6b236aa 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -1,11 +1,279 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::AesCtr; use crate::crypto::sha256_hmac; +use crate::protocol::constants::ProtoTag; use crate::protocol::tls; +use crate::proxy::handshake::HandshakeSuccess; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use crate::stream::{CryptoReader, CryptoWriter}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; +#[test] +fn synthetic_local_addr_uses_configured_port_for_zero() { + let addr = synthetic_local_addr(0); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), 0); +} + +#[test] +fn synthetic_local_addr_uses_configured_port_for_max() { + let addr = synthetic_local_addr(u16::MAX); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), u16::MAX); +} + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn relay_task_abort_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "abort-user"; + let peer_addr: SocketAddr = "198.51.100.230:50000".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime, + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "task abort must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "task abort must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn relay_cutover_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "cutover-user"; + let peer_addr: SocketAddr = "198.51.100.231:50001".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime.clone(), + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("relay must terminate after cutover") + .expect("relay task must not panic"); + assert!(relay_result.is_err(), "cutover must terminate direct relay session"); + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "cutover exit must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "cutover exit must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + #[tokio::test] async fn short_tls_probe_is_masked_through_client_pipeline() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -93,6 +361,93 @@ async fn short_tls_probe_is_masked_through_client_pipeline() { accept_task.await.unwrap(); } +#[tokio::test] +async fn tls12_record_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x03, 0x00, 0x10]; + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.78:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + #[tokio::test] async fn handle_client_stream_increments_connects_all_exactly_once() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -1113,6 +1468,34 @@ fn non_eof_error_is_classified_as_other() { ); } +#[test] +fn beobachten_ttl_zero_minutes_is_floored_to_one_minute() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 0; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(60), + "beobachten_minutes=0 must be fail-closed to a one-minute minimum TTL" + ); +} + +#[test] +fn beobachten_ttl_positive_minutes_remain_unchanged() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 7; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(7 * 60), + "configured positive beobacten TTL must be preserved" + ); +} + #[tokio::test] async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { let mut config = ProxyConfig::default(); @@ -1152,6 +1535,857 @@ async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { ); } +#[tokio::test] +async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 0); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!(stats.get_user_curr_connects("user"), 0); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { + let user = "check-helper-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + + let first = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(first.is_ok(), "first check-only limit validation must succeed"); + + let second = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(second.is_ok(), "second check-only validation must not fail from leaked state"); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn stress_check_user_limits_static_success_never_leaks_state() { + let user = "check-helper-stress-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + + for i in 0..4096u16 { + let peer_addr = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 110, (i % 250) as u8 + 1)), + 40000 + (i % 1024), + ); + + let result = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(result.is_ok(), "check-only helper must remain leak-free under stress"); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "stress success loop must not leak user connection counters" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "stress success loop must not leak active IP reservations" + ); +} + +#[tokio::test] +async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak() { + let user = "rollback-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 128); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let keeper_peer: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + let keeper = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + keeper_peer, + ip_tracker.clone(), + ) + .await + .expect("keeper reservation must succeed"); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..64u8 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 101, i.saturating_add(1))), + 41000 + i as u16, + ); + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "rollback-storm-user" + )); + }); + } + + while let Some(joined) = tasks.join_next().await { + joined.unwrap(); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "failed distinct-IP attempts must rollback acquired user slots" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "failed distinct-IP attempts must not leave extra active IPs" + ); + + keeper.release().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn explicit_reservation_release_cleans_user_and_ip_immediately() { + let user = "release-user"; + let peer_addr: SocketAddr = "198.51.100.240:50002".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + reservation.release().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "explicit release must synchronously free user connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "explicit release must synchronously remove reserved user IP" + ); +} + +#[tokio::test] +async fn explicit_reservation_release_does_not_double_decrement_on_drop() { + let user = "release-once-user"; + let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + .expect("reservation acquisition must succeed"); + + reservation.release().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "release must disarm drop and prevent double decrement" + ); +} + +#[tokio::test] +async fn drop_fallback_eventually_cleans_user_and_ip_reservation() { + let user = "drop-fallback-user"; + let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + drop(reservation); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("drop fallback must eventually clean both user slot and active IP"); +} + +#[tokio::test] +async fn explicit_release_allows_immediate_cross_ip_reacquire_under_limit() { + let user = "cross-ip-user"; + let peer1: SocketAddr = "198.51.100.243:50005".parse().unwrap(); + let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer1, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + first.release().await; + + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer2, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed immediately after explicit release"); + second.release().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn release_abort_storm_does_not_leak_user_or_ip_reservations() { + const ATTEMPTS: usize = 256; + + let user = "release-abort-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), ATTEMPTS + 16); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + for idx in 0..ATTEMPTS { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 114, (idx % 250 + 1) as u8)), + 52000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed in abort storm"); + + let release_task = tokio::spawn(async move { + reservation.release().await; + }); + release_task.abort(); + let _ = release_task.await; + } + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("release abort storm must not leak user slots or active IP entries"); +} + +#[tokio::test] +async fn release_abort_loop_preserves_immediate_same_ip_reacquire() { + const ITERATIONS: usize = 128; + + let user = "release-abort-reacquire-user"; + let peer: SocketAddr = "198.51.100.246:53001".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + for _ in 0..ITERATIONS { + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("baseline acquisition must succeed"); + + let release_task = tokio::spawn(async move { + reservation.release().await; + }); + release_task.abort(); + let _ = release_task.await; + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("aborted release must still converge to zero footprint"); + } + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("same-ip reacquire must succeed after repeated abort-release churn"); + reservation.release().await; +} + +#[tokio::test] +async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() { + const RESERVATIONS: usize = 192; + + let user = "mixed-wave-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), RESERVATIONS + 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut reservations = Vec::with_capacity(RESERVATIONS); + for idx in 0..RESERVATIONS { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 115, (idx % 250 + 1) as u8)), + 54000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("mixed-wave acquisition must succeed"); + reservations.push(reservation); + } + + let mut seed: u64 = 0xDEAD_BEEF_CAFE_BA5E; + let mut join_set = tokio::task::JoinSet::new(); + for reservation in reservations { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + match seed % 3 { + 0 => { + join_set.spawn(async move { + reservation.release().await; + }); + } + 1 => { + drop(reservation); + } + _ => { + let task = tokio::spawn(async move { + reservation.release().await; + }); + task.abort(); + let _ = task.await; + } + } + } + + while let Some(result) = join_set.join_next().await { + result.expect("release subtask must not panic"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("mixed release/drop/abort wave must converge to zero footprint"); +} + +#[tokio::test] +async fn parallel_users_abort_release_isolation_preserves_independent_cleanup() { + let user_a = "abort-isolation-a"; + let user_b = "abort-isolation-b"; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user_a.to_string(), 64); + config.access.user_max_tcp_conns.insert(user_b.to_string(), 64); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for idx in 0..64usize { + let user = if idx % 2 == 0 { user_a } else { user_b }; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 18, 0, (idx % 250 + 1) as u8)), + 55000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("parallel-user acquisition must succeed"); + + tasks.spawn(async move { + let t = tokio::spawn(async move { + reservation.release().await; + }); + t.abort(); + let _ = t.await; + }); + } + + while let Some(result) = tasks.join_next().await { + result.expect("parallel-user abort task must not panic"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user_a) == 0 + && stats.get_user_curr_connects(user_b) == 0 + && ip_tracker.get_active_ip_count(user_a).await == 0 + && ip_tracker.get_active_ip_count(user_b).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("parallel users must cleanup independently under abort churn"); +} + +#[tokio::test] +async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() { + const RESERVATIONS: usize = 64; + + let user = "release-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), RESERVATIONS + 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut reservations = Vec::with_capacity(RESERVATIONS); + for idx in 0..RESERVATIONS { + let ip = std::net::Ipv4Addr::new(203, 0, 113, (idx + 1) as u8); + let peer = SocketAddr::new(IpAddr::V4(ip), 51000 + idx as u16); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition in storm must succeed"); + reservations.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), RESERVATIONS as u64); + assert_eq!(ip_tracker.get_active_ip_count(user).await, RESERVATIONS); + + let mut tasks = tokio::task::JoinSet::new(); + for reservation in reservations { + tasks.spawn(async move { + reservation.release().await; + }); + } + + while let Some(result) = tasks.join_next().await { + result.expect("release task must not panic"); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "release storm must drain user current-connection counter to zero" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "release storm must clear all active IP entries" + ); +} + +#[tokio::test] +async fn relay_connect_error_releases_user_and_ip_before_return() { + let user = "relay-error-user"; + let peer_addr: SocketAddr = "198.51.100.245:50007".parse().unwrap(); + + let dead_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dead_port = dead_listener.local_addr().unwrap().port(); + drop(dead_listener); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + config + .dc_overrides + .insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, _client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let result = RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime, + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + ) + .await; + + assert!(result.is_err(), "relay must fail when upstream DC is unreachable"); + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "error return must release user slot before returning" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "error return must release user IP reservation before returning" + ); +} + +#[tokio::test] +async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() { + let user = "same-ip-mixed-user"; + let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let reservation_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 2); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + reservation_a.release().await; + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "explicit release must decrement only one active reservation" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "same IP must remain active while second reservation exists" + ); + + drop(reservation_b); + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("drop fallback must clear final same-IP reservation"); +} + +#[tokio::test] +async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() { + let user = "same-ip-drop-one-user"; + let peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let reservation_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed"); + + drop(reservation_a); + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("dropping one reservation must keep same-IP activity for remaining reservation"); + + reservation_b.release().await; + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("final release must converge to zero footprint after async fallback cleanup"); +} + #[tokio::test] async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { let mut config = ProxyConfig::default(); @@ -1188,6 +2422,227 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { ); } +#[tokio::test] +async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() { + let mut config = ProxyConfig::default(); + config + .access + .user_expirations + .insert("user".to_string(), chrono::Utc::now() - chrono::Duration::seconds(1)); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "203.0.113.212:50002".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::UserExpired { user }) if user == "user" + )); + assert_eq!(stats.get_user_curr_connects("user"), 0); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn same_ip_second_reservation_succeeds_under_unique_ip_limit_one() { + let user = "same-ip-unique-limit-user"; + let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation from same IP must succeed under unique-ip limit=1"); + + assert_eq!(stats.get_user_curr_connects(user), 2); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + first.release().await; + second.release().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn second_distinct_ip_is_rejected_under_unique_ip_limit_one() { + let user = "distinct-ip-unique-limit-user"; + let peer1: SocketAddr = "198.51.100.249:50011".parse().unwrap(); + let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer1, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer2, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + second, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "distinct-ip-unique-limit-user" + )); + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + first.release().await; +} + +#[tokio::test] +async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() { + let user = "cross-thread-drop-user"; + let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop thread must not panic"); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cross-thread drop must still converge to zero user and IP footprint"); +} + +#[tokio::test] +async fn immediate_reacquire_after_cross_thread_drop_succeeds() { + let user = "cross-thread-reacquire-user"; + let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("initial reservation must succeed"); + + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop thread must not panic"); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cross-thread cleanup must settle before reacquire check"); + + let reacquire = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer_addr, + ip_tracker, + ) + .await; + assert!( + reacquire.is_ok(), + "reacquire must succeed after cross-thread drop cleanup" + ); +} + #[tokio::test] async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() { const PARALLEL_IPS: usize = 64; @@ -1281,16 +2736,24 @@ async fn atomic_limit_gate_allows_only_one_concurrent_acquire() { IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (i + 1) as u8)), 30000 + i, ); - RunningClientHandler::check_user_limits_static("user", &config, &stats, peer, &ip_tracker) - .await - .is_ok() + RunningClientHandler::acquire_user_connection_reservation_static( + "user", + &config, + stats, + peer, + ip_tracker, + ) + .await + .ok() }); } let mut successes = 0u64; + let mut held_reservations = Vec::new(); while let Some(joined) = tasks.join_next().await { - if joined.unwrap() { + if let Some(reservation) = joined.unwrap() { successes += 1; + held_reservations.push(reservation); } } @@ -1299,6 +2762,8 @@ async fn atomic_limit_gate_allows_only_one_concurrent_acquire() { "exactly one concurrent acquire must pass for a limit=1 user" ); assert_eq!(stats.get_user_curr_connects("user"), 1); + + drop(held_reservations); } #[tokio::test] @@ -2069,3 +3534,478 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { "Valid max-length ClientHello must not increment bad counter" ); } + +fn lcg_next(state: &mut u64) -> u64 { + *state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + *state +} + +async fn wait_for_user_and_ip_zero( + stats: &Arc, + ip_tracker: &Arc, + user: &str, +) { + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cleanup must converge to zero user and IP footprint"); +} + +async fn burst_acquire_distinct_ips( + user: &'static str, + config: Arc, + stats: Arc, + ip_tracker: Arc, + third_octet: u8, + attempts: u16, +) -> (Vec, usize) { + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..attempts { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let host = (i as u8).saturating_add(1); + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third_octet, host)), + 55000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut successes = Vec::new(); + let mut failures = 0usize; + while let Some(joined) = tasks.join_next().await { + match joined.expect("burst acquire task must not panic") { + Ok(reservation) => successes.push(reservation), + Err(err) => { + assert!(matches!( + err, + ProxyError::ConnectionLimitExceeded { user: ref denied_user } + if denied_user == user + )); + failures = failures.saturating_add(1); + } + } + } + + (successes, failures) +} + +#[tokio::test] +async fn deterministic_mixed_reservation_churn_preserves_counter_and_eventual_cleanup() { + let user = "deterministic-churn-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 12); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; + + let mut seed = 0xD1F2_A4C8_991B_77E1u64; + let mut reservations: Vec> = Vec::new(); + + for step in 0..220u64 { + let op = (lcg_next(&mut seed) % 100) as u8; + let active = reservations.iter().filter(|entry| entry.is_some()).count(); + + if active == 0 || op < 55 { + let ip_octet = (lcg_next(&mut seed) % 16 + 1) as u8; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 120, ip_octet)), + 52000 + (step % 4000) as u16, + ); + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + if let Ok(reservation) = result { + reservations.push(Some(reservation)); + } else { + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "deterministic-churn-user" + )); + } + } else { + let selected = reservations + .iter() + .enumerate() + .filter(|(_, entry)| entry.is_some()) + .map(|(idx, _)| idx) + .nth((lcg_next(&mut seed) as usize) % active) + .unwrap(); + + let reservation = reservations[selected].take().unwrap(); + if op < 80 { + reservation.release().await; + } else { + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("cross-thread drop must not panic"); + } + } + + let live_slots = reservations.iter().filter(|entry| entry.is_some()).count() as u64; + assert_eq!( + stats.get_user_curr_connects(user), + live_slots, + "current-connects counter must match number of live reservations" + ); + assert!( + stats.get_user_curr_connects(user) <= 12, + "current-connects must stay within configured TCP limit" + ); + assert!( + ip_tracker.get_active_ip_count(user).await <= 4, + "active unique IPs must stay within configured per-user IP limit" + ); + } + + for reservation in reservations.into_iter().flatten() { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() { + let user = "drop-storm-reacquire-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 64); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; + + let mut initial = Vec::new(); + for i in 0..32u16 { + let ip_octet = (i % 8 + 1) as u8; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 120, ip_octet)), + 53000 + i, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("initial reservation must succeed"); + initial.push(reservation); + } + + let mut second_half = initial.split_off(16); + + let mut releases = Vec::new(); + for reservation in initial { + releases.push(tokio::spawn(async move { + reservation.release().await; + })); + } + for release_task in releases { + release_task.await.expect("release task must not panic"); + } + + let mut drop_threads = Vec::new(); + for reservation in second_half.drain(..) { + drop_threads.push(std::thread::spawn(move || { + drop(reservation); + })); + } + for drop_thread in drop_threads { + drop_thread + .join() + .expect("cross-thread drop worker must not panic"); + } + + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; + + let mut reacquire_tasks = tokio::task::JoinSet::new(); + for i in 0..16u16 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + reacquire_tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 121, (i + 1) as u8)), + 54000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut acquired = Vec::new(); + while let Some(joined) = reacquire_tasks.join_next().await { + match joined.expect("reacquire task must not panic") { + Ok(reservation) => acquired.push(reservation), + Err(err) => { + assert!(matches!( + err, + ProxyError::ConnectionLimitExceeded { user } + if user == "drop-storm-reacquire-user" + )); + } + } + } + + assert!( + acquired.len() <= 8, + "parallel distinct-IP reacquire wave must not exceed per-user unique IP limit" + ); + for reservation in acquired { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() { + let user: &'static str = "scheduled-attack-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 6); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 2).await; + + let mut base = Vec::new(); + for i in 0..5u16 { + let peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), 56000 + i); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("near-limit warmup reservation must succeed"); + base.push(reservation); + } + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + let (wave1_success, wave1_fail) = burst_acquire_distinct_ips( + user, + config.clone(), + stats.clone(), + ip_tracker.clone(), + 131, + 32, + ) + .await; + assert_eq!(wave1_success.len(), 1); + assert_eq!(wave1_fail, 31); + assert_eq!(stats.get_user_curr_connects(user), 6); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 2); + + let released = base.pop().expect("must have releasable reservation"); + released.release().await; + for reservation in wave1_success { + reservation.release().await; + } + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 4 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("window cleanup must settle to expected occupancy"); + + let (wave2_success, wave2_fail) = burst_acquire_distinct_ips( + user, + config, + stats.clone(), + ip_tracker.clone(), + 132, + 32, + ) + .await; + assert_eq!(wave2_success.len(), 1); + assert_eq!(wave2_fail, 31); + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 2); + + let tail = base.split_off(2); + + let mut drop_threads = Vec::new(); + for reservation in base { + drop_threads.push(std::thread::spawn(move || { + drop(reservation); + })); + } + for drop_thread in drop_threads { + drop_thread + .join() + .expect("cross-thread scheduled cleanup must not panic"); + } + + for reservation in tail { + reservation.release().await; + } + for reservation in wave2_success { + reservation.release().await; + } + + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn scheduled_mode_switch_burst_churn_preserves_limits_and_cleanup() { + let user: &'static str = "scheduled-mode-switch-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 3).await; + + let base_peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 140, 1)), 57000); + let mut base = Vec::new(); + for i in 0..7u16 { + let peer = SocketAddr::new(base_peer.ip(), base_peer.port().saturating_add(i)); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("base occupancy reservation must succeed"); + base.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), 7); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for round in 0..8u8 { + let (wave_success, wave_fail) = burst_acquire_distinct_ips( + user, + config.clone(), + stats.clone(), + ip_tracker.clone(), + 141u8.saturating_add(round), + 24, + ) + .await; + + assert!( + wave_success.len() <= 2, + "burst must not exceed available unique-IP headroom under limit=3" + ); + assert_eq!(wave_success.len() + wave_fail, 24); + assert_eq!( + stats.get_user_curr_connects(user), + 7 + wave_success.len() as u64, + "slot counter must reflect base occupancy plus successful burst leases" + ); + assert!(ip_tracker.get_active_ip_count(user).await <= 3); + + if round % 2 == 0 { + for reservation in wave_success { + reservation.release().await; + } + let rotated = base.pop().expect("base rotation reservation must exist"); + rotated.release().await; + } else { + for reservation in wave_success { + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop-heavy burst cleanup thread must not panic"); + } + let rotated = base.pop().expect("base rotation reservation must exist"); + std::thread::spawn(move || { + drop(rotated); + }) + .join() + .expect("drop-heavy base cleanup thread must not panic"); + } + + let replacement = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + base_peer, + ip_tracker.clone(), + ) + .await + .expect("base replacement reservation must succeed after each round"); + base.push(replacement); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 7 + && ip_tracker.get_active_ip_count(user).await <= 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("round cleanup must converge to steady base occupancy"); + } + + for reservation in base { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 9c6116c..4a7b9a9 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -1,6 +1,8 @@ +use std::ffi::OsString; use std::fs::OpenOptions; use std::io::Write; use std::net::SocketAddr; +use std::path::{Component, Path, PathBuf}; use std::sync::Arc; use std::collections::HashSet; use std::sync::{Mutex, OnceLock}; @@ -24,14 +26,28 @@ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; + const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); +#[derive(Clone)] +struct SanitizedUnknownDcLogPath { + resolved_path: PathBuf, + allowed_parent: PathBuf, + file_name: OsString, +} + // In tests, this function shares global mutable state. Callers that also use // cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions // deterministic under parallel execution. fn should_log_unknown_dc(dc_idx: i16) -> bool { let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new())); + should_log_unknown_dc_with_set(set, dc_idx) +} + +fn should_log_unknown_dc_with_set(set: &Mutex>, dc_idx: i16) -> bool { match set.lock() { Ok(mut guard) => { if guard.contains(&dc_idx) { @@ -42,9 +58,81 @@ fn should_log_unknown_dc(dc_idx: i16) -> bool { } guard.insert(dc_idx) } - // If the lock is poisoned, keep logging rather than silently dropping - // operator-visible diagnostics. - Err(_) => true, + // Fail closed on poisoned state to avoid unbounded blocking log writes. + Err(_) => false, + } +} + +fn sanitize_unknown_dc_log_path(path: &str) -> Option { + let candidate = Path::new(path); + if candidate.as_os_str().is_empty() { + return None; + } + if candidate + .components() + .any(|component| matches!(component, Component::ParentDir)) + { + return None; + } + + let cwd = std::env::current_dir().ok()?; + let file_name = candidate.file_name()?; + let parent = candidate.parent().unwrap_or_else(|| Path::new(".")); + let parent_path = if parent.is_absolute() { + parent.to_path_buf() + } else { + cwd.join(parent) + }; + let canonical_parent = parent_path.canonicalize().ok()?; + if !canonical_parent.is_dir() { + return None; + } + + Some(SanitizedUnknownDcLogPath { + resolved_path: canonical_parent.join(file_name), + allowed_parent: canonical_parent, + file_name: file_name.to_os_string(), + }) +} + +fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool { + let Some(parent) = path.resolved_path.parent() else { + return false; + }; + let Ok(current_parent) = parent.canonicalize() else { + return false; + }; + if current_parent != path.allowed_parent { + return false; + } + + if let Ok(canonical_target) = path.resolved_path.canonicalize() { + let Some(target_parent) = canonical_target.parent() else { + return false; + }; + let Some(target_name) = canonical_target.file_name() else { + return false; + }; + if target_parent != path.allowed_parent || target_name != path.file_name { + return false; + } + } + + true +} + +fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { + #[cfg(unix)] + { + OpenOptions::new() + .create(true) + .append(true) + .custom_flags(libc::O_NOFOLLOW) + .open(path) + } + #[cfg(not(unix))] + { + OpenOptions::new().create(true).append(true).open(path) } } @@ -105,7 +193,7 @@ where debug!(peer = %success.peer, "TG handshake complete, starting relay"); stats.increment_user_connects(user); - stats.increment_current_connections_direct(); + let _direct_connection_lease = stats.acquire_direct_connection_lease(); let relay_result = relay_bidirectional( client_reader, @@ -148,8 +236,6 @@ where } }; - stats.decrement_current_connections_direct(); - match &relay_result { Ok(()) => debug!(user = %user, "Direct relay completed"), Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"), @@ -202,12 +288,17 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { && should_log_unknown_dc(dc_idx) && let Ok(handle) = tokio::runtime::Handle::try_current() { - let path = path.clone(); - handle.spawn_blocking(move || { - if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { - let _ = writeln!(file, "dc_idx={dc_idx}"); - } - }); + if let Some(path) = sanitize_unknown_dc_log_path(path) { + handle.spawn_blocking(move || { + if unknown_dc_log_path_is_still_safe(&path) + && let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path) + { + let _ = writeln!(file, "dc_idx={dc_idx}"); + } + }); + } else { + warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path"); + } } } diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index 3b3185a..e47164f 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -1,4 +1,41 @@ use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::protocol::constants::ProtoTag; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::UpstreamManager; +use std::fs; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; +use tokio::io::duplex; +use tokio::net::TcpListener; + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +fn nonempty_line_count(text: &str) -> usize { + text.lines().filter(|line| !line.trim().is_empty()).count() +} #[test] fn unknown_dc_log_is_deduplicated_per_dc_idx() { @@ -38,6 +75,771 @@ fn unknown_dc_log_respects_distinct_limit() { ); } +#[test] +fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() { + let poisoned = Arc::new(std::sync::Mutex::new(std::collections::HashSet::::new())); + let poisoned_for_thread = poisoned.clone(); + + let _ = std::thread::spawn(move || { + let _guard = poisoned_for_thread + .lock() + .expect("poison setup lock must be available"); + panic!("intentional poison for fail-closed regression"); + }) + .join(); + + assert!( + !should_log_unknown_dc_with_set(poisoned.as_ref(), 4242), + "poisoned unknown-DC dedup lock must fail closed" + ); +} + +#[test] +fn stress_unknown_dc_log_concurrent_unique_churn_respects_cap() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let accepted = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + + // Adversarial model: many concurrent peers rotate dc_idx values rapidly. + for worker in 0..16usize { + let accepted = Arc::clone(&accepted); + workers.push(std::thread::spawn(move || { + let base = (worker * 2048) as i32; + for offset in 0..512i32 { + let raw = base + offset; + let dc = (raw % i16::MAX as i32) as i16; + if should_log_unknown_dc(dc) { + accepted.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must not panic"); + } + + assert_eq!( + accepted.load(Ordering::Relaxed), + UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "concurrent unique churn must never admit more than the configured distinct cap" + ); +} + +#[test] +fn light_fuzz_unknown_dc_log_mixed_duplicates_never_exceeds_cap() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + // Deterministic xorshift sequence for reproducible mixed duplicate fuzzing. + let mut s: u64 = 0xA5A5_5A5A_C3C3_3C3C; + let mut admitted = 0usize; + + for _ in 0..20_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let dc = (s as i16).wrapping_sub(i16::MAX / 2); + if should_log_unknown_dc(dc) { + admitted += 1; + } + } + + assert!( + admitted <= UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "mixed-duplicate fuzzed inputs must not admit more than cap" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_parent_traversal_inputs() { + assert!( + sanitize_unknown_dc_log_path("../unknown-dc.txt").is_none(), + "parent traversal paths must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path("logs/../unknown-dc.txt").is_none(), + "embedded parent traversal must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path("./../unknown-dc.txt").is_none(), + "relative parent traversal must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_absolute_paths_with_existing_parent() { + let absolute = std::env::temp_dir().join("unknown-dc.txt"); + let absolute_str = absolute + .to_str() + .expect("temp absolute path must be valid UTF-8"); + + let sanitized = sanitize_unknown_dc_log_path(absolute_str) + .expect("absolute paths with existing parent must be accepted"); + assert_eq!(sanitized.resolved_path, absolute); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_absolute_parent_traversal() { + assert!( + sanitize_unknown_dc_log_path("/tmp/../etc/passwd").is_none(), + "absolute parent traversal must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_safe_relative_path() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-{}", std::process::id())); + fs::create_dir_all(&base).expect("temp test directory must be creatable"); + + let candidate = base.join("unknown-dc.txt"); + let candidate_relative = format!("target/telemt-unknown-dc-log-{}/unknown-dc.txt", std::process::id()); + + let sanitized = sanitize_unknown_dc_log_path(&candidate_relative) + .expect("safe relative path with existing parent must be accepted"); + assert_eq!(sanitized.resolved_path, candidate); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_empty_or_dot_only_inputs() { + assert!( + sanitize_unknown_dc_log_path("").is_none(), + "empty path must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path(".").is_none(), + "dot-only path without filename must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_directory_only_as_filename_projection() { + let sanitized = sanitize_unknown_dc_log_path("target/") + .expect("directory-only input is interpreted as filename projection in current sanitizer"); + assert!( + sanitized.resolved_path.ends_with("target"), + "directory-only input should resolve to canonical parent plus filename projection" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_dot_prefixed_relative_path() { + let rel_dir = format!("target/telemt-unknown-dc-dot-{}", std::process::id()); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("dot-prefixed test directory must be creatable"); + + let rel_candidate = format!("./{rel_dir}/unknown-dc.log"); + let expected = abs_dir.join("unknown-dc.log"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("dot-prefixed safe path must be accepted"); + assert_eq!(sanitized.resolved_path, expected); +} + +#[test] +fn light_fuzz_unknown_dc_path_parentdir_inputs_always_rejected() { + let mut s: u64 = 0xD00D_BAAD_1234_5678; + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let a = (s as usize) % 32; + let b = ((s >> 8) as usize) % 32; + let candidate = format!("target/{a}/../{b}/unknown-dc.log"); + assert!( + sanitize_unknown_dc_log_path(&candidate).is_none(), + "parent-dir candidate must be rejected: {candidate}" + ); + } +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_nonexistent_parent_directory() { + let rel_candidate = format!( + "target/telemt-unknown-dc-missing-{}/nested/unknown-dc.txt", + std::process::id() + ); + + assert!( + sanitize_unknown_dc_log_path(&rel_candidate).is_none(), + "path with missing parent must be rejected to avoid implicit directory creation" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_sanitizer_accepts_symlinked_parent_inside_workspace() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-symlink-internal-{}", std::process::id())); + let real_parent = base.join("real_parent"); + fs::create_dir_all(&real_parent).expect("real parent dir must be creatable"); + + let symlink_parent = base.join("internal_link"); + let _ = fs::remove_file(&symlink_parent); + symlink(&real_parent, &symlink_parent).expect("internal symlink must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-log-symlink-internal-{}/internal_link/unknown-dc.txt", + std::process::id() + ); + + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("symlinked parent that resolves inside workspace must be accepted"); + assert!( + sanitized.resolved_path.starts_with(&real_parent), + "sanitized path must resolve to canonical internal parent" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_sanitizer_accepts_symlink_parent_escape_as_canonical_path() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-symlink-{}", std::process::id())); + fs::create_dir_all(&base).expect("symlink test directory must be creatable"); + + let symlink_parent = base.join("escape_link"); + let _ = fs::remove_file(&symlink_parent); + symlink("/tmp", &symlink_parent).expect("symlink parent must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-log-symlink-{}/escape_link/unknown-dc.txt", + std::process::id() + ); + + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("symlinked parent must canonicalize to target path"); + assert!( + sanitized.resolved_path.starts_with(Path::new("/tmp")), + "sanitized path must resolve to canonical symlink target" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_revalidation_rejects_symlinked_target_escape() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-target-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("target-link base must be creatable"); + + let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id())); + let _ = fs::remove_file(&outside); + fs::write(&outside, "outside").expect("outside file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("target symlink must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-target-link-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate should sanitize before final revalidation"); + + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "final revalidation must reject symlinked target escape" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_open_append_rejects_symlink_target_with_nofollow() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-nofollow-{}", std::process::id())); + fs::create_dir_all(&base).expect("nofollow base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-nofollow-outside-{}.log", + std::process::id() + )); + let _ = fs::remove_file(&outside); + fs::write(&outside, "outside\n").expect("outside file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("symlink target must be creatable"); + + let err = open_unknown_dc_log_append(&linked_target) + .expect_err("O_NOFOLLOW open must fail for symlink target"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "symlink target must be rejected with ELOOP when O_NOFOLLOW is applied" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_open_append_rejects_broken_symlink_target_with_nofollow() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-broken-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("broken-link base must be creatable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(base.join("missing-target.log"), &linked_target) + .expect("broken symlink target must be creatable"); + + let err = open_unknown_dc_log_append(&linked_target) + .expect_err("O_NOFOLLOW open must fail for broken symlink target"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "broken symlink target must be rejected with ELOOP when O_NOFOLLOW is applied" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_unknown_dc_open_append_symlink_flip_never_writes_outside_file() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-symlink-flip-{}", std::process::id())); + fs::create_dir_all(&base).expect("symlink-flip base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-symlink-flip-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside-baseline\n").expect("outside baseline file must be writable"); + let outside_before = fs::read_to_string(&outside).expect("outside baseline must be readable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + for step in 0..1024usize { + let _ = fs::remove_file(&target); + if step % 2 == 0 { + symlink(&outside, &target).expect("symlink creation in flip loop must succeed"); + } + if let Ok(mut file) = open_unknown_dc_log_append(&target) { + writeln!(file, "dc_idx={step}").expect("append on regular file must succeed"); + } + } + + let outside_after = fs::read_to_string(&outside).expect("outside file must remain readable"); + assert_eq!( + outside_after, outside_before, + "outside file must never be modified under symlink-flip adversarial churn" + ); +} + +#[test] +fn unknown_dc_open_append_creates_regular_file() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-open-{}", std::process::id())); + fs::create_dir_all(&base).expect("open test base must be creatable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + { + let mut file = open_unknown_dc_log_append(&target) + .expect("regular target must be creatable with append open"); + writeln!(file, "dc_idx=1234").expect("append write must succeed"); + } + + let meta = fs::symlink_metadata(&target).expect("created target metadata must be readable"); + assert!(meta.file_type().is_file(), "target must be a regular file"); + assert!( + !meta.file_type().is_symlink(), + "regular target open path must not produce symlink artifacts" + ); +} + +#[test] +fn stress_unknown_dc_open_append_regular_file_preserves_line_integrity() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-open-stress-{}", std::process::id())); + fs::create_dir_all(&base).expect("stress open base must be creatable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + let writes = 2048usize; + for idx in 0..writes { + let mut file = open_unknown_dc_log_append(&target) + .expect("stress append open on regular file must succeed"); + writeln!(file, "dc_idx={idx}").expect("stress append write must succeed"); + } + + let content = fs::read_to_string(&target).expect("stress output file must be readable"); + assert_eq!( + nonempty_line_count(&content), + writes, + "regular-file append stress must preserve one logical line per write" + ); +} + +#[test] +fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-safe-target-{}", std::process::id())); + fs::create_dir_all(&base).expect("safe target base must be creatable"); + + let target = base.join("unknown-dc.log"); + fs::write(&target, "seed\n").expect("safe target seed write must succeed"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-safe-target-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("safe candidate must sanitize"); + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must allow safe existing regular files" + ); +} + +#[test] +fn unknown_dc_log_path_revalidation_rejects_deleted_parent_after_sanitize() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-vanish-parent-{}", std::process::id())); + fs::create_dir_all(&base).expect("vanish-parent base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-vanish-parent-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent deletion"); + + fs::remove_dir_all(&base).expect("test parent directory must be removable"); + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must fail when sanitized parent disappears before write" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() { + use std::os::unix::fs::symlink; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-parent-swap-{}", std::process::id())); + fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-parent-swap-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent swap"); + + let moved = parent.with_extension("bak"); + let _ = fs::remove_dir_all(&moved); + fs::rename(&parent, &moved).expect("parent must be movable for swap simulation"); + symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable"); + + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must fail when canonical parent is swapped to a symlinked target" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() { + use std::os::unix::fs::symlink; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-check-open-race-{}", std::process::id())); + fs::create_dir_all(&parent).expect("check-open-race parent must be creatable"); + + let target = parent.join("unknown-dc.log"); + fs::write(&target, "seed\n").expect("seed target file must be writable"); + let rel_candidate = format!( + "target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize"); + + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "precondition: target should initially pass revalidation" + ); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-check-open-race-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside\n").expect("outside file must be writable"); + fs::remove_file(&target).expect("target removal before flip must succeed"); + symlink(&outside, &target).expect("target symlink flip must be creatable"); + + let err = open_unknown_dc_log_append(&sanitized.resolved_path) + .expect_err("nofollow open must fail after symlink flip between check and open"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "symlink flip in check/open window must be neutralized by O_NOFOLLOW" + ); +} + +#[tokio::test] +async fn unknown_dc_absolute_log_path_writes_one_entry() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_001; + let file_path = std::env::temp_dir().join(format!( + "telemt-unknown-dc-abs-{}-{}.log", + std::process::id(), + dc_idx + )); + let _ = fs::remove_file(&file_path); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some( + file_path + .to_str() + .expect("temp file path must be valid UTF-8") + .to_string(), + ); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + let mut content = None; + for _ in 0..20 { + if let Ok(text) = fs::read_to_string(&file_path) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(15)).await; + } + + let text = content.expect("absolute unknown-DC log path must produce exactly one log write"); + assert!( + text.contains(&format!("dc_idx={dc_idx}")), + "absolute unknown-DC integration log must contain requested dc_idx" + ); +} + +#[tokio::test] +async fn unknown_dc_safe_relative_log_path_writes_one_entry() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_002; + let rel_dir = format!("target/telemt-unknown-dc-int-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("integration test log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + let mut content = None; + for _ in 0..20 { + if let Ok(text) = fs::read_to_string(&abs_file) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(15)).await; + } + + let text = content.expect("safe relative path must produce exactly one log write"); + assert!( + text.contains(&format!("dc_idx={dc_idx}")), + "unknown-DC integration log must contain requested dc_idx" + ); +} + +#[tokio::test] +async fn unknown_dc_same_index_burst_writes_only_once() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_010; + let rel_dir = format!("target/telemt-unknown-dc-same-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir().unwrap().join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("same-index log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + for _ in 0..64 { + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + } + + let mut content = None; + for _ in 0..30 { + if let Ok(text) = fs::read_to_string(&abs_file) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let text = content.expect("same-index burst must produce at least one log write"); + assert_eq!( + nonempty_line_count(&text), + 1, + "same unknown dc index must be deduplicated to one file line" + ); +} + +#[tokio::test] +async fn unknown_dc_distinct_burst_is_hard_capped_on_file_writes() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-unknown-dc-cap-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir().unwrap().join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("cap log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + for i in 0..(UNKNOWN_DC_LOG_DISTINCT_LIMIT + 128) { + let dc_idx = 20_000i16.wrapping_add(i as i16); + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + } + + let mut final_text = String::new(); + for _ in 0..80 { + if let Ok(text) = fs::read_to_string(&abs_file) { + final_text = text; + if nonempty_line_count(&final_text) >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + break; + } + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let line_count = nonempty_line_count(&final_text); + assert!( + line_count > 0, + "distinct unknown-dc burst must write at least one line" + ); + assert!( + line_count <= UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "distinct unknown-dc writes must stay within dedup hard cap" + ); +} + +#[cfg(unix)] +#[tokio::test] +async fn unknown_dc_symlinked_target_escape_is_not_written_integration() { + use std::os::unix::fs::symlink; + + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-no-write-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("integration symlink base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "baseline\n").expect("outside baseline file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("symlink target must be creatable"); + + let rel_file = format!( + "target/telemt-unknown-dc-no-write-link-{}/unknown-dc.log", + std::process::id() + ); + let dc_idx: i16 = 31_050; + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let before = fs::read_to_string(&outside).expect("must read baseline outside file"); + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + tokio::time::sleep(Duration::from_millis(80)).await; + let after = fs::read_to_string(&outside).expect("must read outside file after attempt"); + + assert_eq!( + after, before, + "symlink target escape must not be written by unknown-DC logging" + ); +} + #[test] fn fallback_dc_never_panics_with_single_dc_list() { let mut cfg = ProxyConfig::default(); @@ -49,3 +851,359 @@ fn fallback_dc_never_panics_with_single_dc_list() { let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); assert_eq!(addr, expected); } + +#[tokio::test] +async fn direct_relay_abort_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50000".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xabad1dea, + )); + + let started = tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await; + assert!(started.is_ok(), "direct relay must increment route gauge before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted direct relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay task is aborted mid-flight" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn direct_relay_cutover_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50002".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xface_cafe, + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("direct relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("direct relay must terminate after cutover") + .expect("direct relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate direct relay session" + ); + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "client-visible cutover error must stay generic and avoid route-internal metadata" + ); + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay exits on cutover" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { + let session_count = 6usize; + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let mut held_streams = Vec::with_capacity(session_count); + for _ in 0..session_count { + let (stream, _) = tg_listener.accept().await.unwrap(); + held_streams.push(stream); + } + tokio::time::sleep(Duration::from_secs(60)).await; + drop(held_streams); + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let mut relay_tasks = Vec::with_capacity(session_count); + let mut client_sides = Vec::with_capacity(session_count); + + for idx in 0..session_count { + let (server_side, client_side) = duplex(64 * 1024); + client_sides.push(client_side); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: format!("cutover-storm-direct-user-{idx}"), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + 51000 + idx as u16, + ), + is_tls: false, + }; + + relay_tasks.push(tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + rng.clone(), + route_runtime.subscribe(), + route_snapshot, + 0xA000_0000 + idx as u64, + ))); + } + + tokio::time::timeout(Duration::from_secs(4), async { + loop { + if stats.get_current_connections_direct() == session_count as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all direct sessions must become active before cutover storm"); + + let route_runtime_flipper = route_runtime.clone(); + let flipper = tokio::spawn(async move { + for step in 0..64u32 { + let mode = if (step & 1) == 0 { + RelayRouteMode::Middle + } else { + RelayRouteMode::Direct + }; + let _ = route_runtime_flipper.set_mode(mode); + tokio::time::sleep(Duration::from_millis(15)).await; + } + }); + + for relay_task in relay_tasks { + let relay_result = tokio::time::timeout(Duration::from_secs(10), relay_task) + .await + .expect("direct relay task must finish under cutover storm") + .expect("direct relay task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "storm-cutover termination must remain generic for all direct sessions" + ); + } + + flipper.abort(); + let _ = flipper.await; + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after cutover storm" + ); + + drop(client_sides); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index dbd50d5..3659754 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -4,11 +4,11 @@ use std::net::SocketAddr; use std::collections::HashSet; +use std::collections::hash_map::RandomState; use std::net::{IpAddr, Ipv6Addr}; use std::sync::Arc; use std::sync::{Mutex, OnceLock}; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; +use std::hash::{BuildHasher, Hash, Hasher}; use std::time::{Duration, Instant}; use dashmap::DashMap; use dashmap::mapref::entry::Entry; @@ -36,6 +36,7 @@ const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256; const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536; const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024; const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4; +const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2; #[cfg(test)] const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1; @@ -54,12 +55,25 @@ struct AuthProbeState { last_seen: Instant, } +#[derive(Clone, Copy)] +struct AuthProbeSaturationState { + fail_streak: u32, + blocked_until: Instant, + last_seen: Instant, +} + static AUTH_PROBE_STATE: OnceLock> = OnceLock::new(); +static AUTH_PROBE_SATURATION_STATE: OnceLock>> = OnceLock::new(); +static AUTH_PROBE_EVICTION_HASHER: OnceLock = OnceLock::new(); fn auth_probe_state_map() -> &'static DashMap { AUTH_PROBE_STATE.get_or_init(DashMap::new) } +fn auth_probe_saturation_state() -> &'static Mutex> { + AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None)) +} + fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr { match peer_ip { IpAddr::V4(ip) => IpAddr::V4(ip), @@ -88,7 +102,8 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool { } fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { - let mut hasher = DefaultHasher::new(); + let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new); + let mut hasher = hasher_state.build_hasher(); peer_ip.hash(&mut hasher); now.hash(&mut hasher); hasher.finish() as usize @@ -108,6 +123,83 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { now < entry.blocked_until } +fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + let Some(entry) = state.get(&peer_ip) else { + return false; + }; + if auth_probe_state_expired(&entry, now) { + drop(entry); + state.remove(&peer_ip); + return false; + } + + entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS +} + +fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool { + if !auth_probe_is_throttled(peer_ip, now) { + return false; + } + + if !auth_probe_saturation_is_throttled(now) { + return true; + } + + auth_probe_saturation_grace_exhausted(peer_ip, now) +} + +fn auth_probe_saturation_is_throttled(now: Instant) -> bool { + let saturation = auth_probe_saturation_state(); + let mut guard = match saturation.lock() { + Ok(guard) => guard, + Err(_) => return false, + }; + + let Some(state) = guard.as_mut() else { + return false; + }; + + if now.duration_since(state.last_seen) > Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) { + *guard = None; + return false; + } + + if now < state.blocked_until { + return true; + } + + false +} + +fn auth_probe_note_saturation(now: Instant) { + let saturation = auth_probe_saturation_state(); + let mut guard = match saturation.lock() { + Ok(guard) => guard, + Err(_) => return, + }; + + match guard.as_mut() { + Some(state) + if now.duration_since(state.last_seen) + <= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) => + { + state.fail_streak = state.fail_streak.saturating_add(1); + state.last_seen = now; + state.blocked_until = now + auth_probe_backoff(state.fail_streak); + } + _ => { + let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS; + *guard = Some(AuthProbeSaturationState { + fail_streak, + blocked_until: now + auth_probe_backoff(fail_streak), + last_seen: now, + }); + } + } +} + fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); @@ -144,24 +236,79 @@ fn auth_probe_record_failure_with_state( } if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - let mut stale_keys = Vec::new(); - let mut eviction_candidates = Vec::new(); - for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { - eviction_candidates.push(*entry.key()); - if auth_probe_state_expired(entry.value(), now) { - stale_keys.push(*entry.key()); - } - } - for stale_key in stale_keys { - state.remove(&stale_key); - } - if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - if eviction_candidates.is_empty() { + let mut rounds = 0usize; + while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + rounds += 1; + if rounds > 8 { + auth_probe_note_saturation(now); return; } - let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len(); - let evict_key = eviction_candidates[idx]; + + let mut stale_keys = Vec::new(); + let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; + let state_len = state.len(); + let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); + let start_offset = if state_len == 0 { + 0 + } else { + auth_probe_eviction_offset(peer_ip, now) % state_len + }; + + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => + { + } + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } + } + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => + { + } + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } + + for stale_key in stale_keys { + state.remove(&stale_key); + } + + if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES { + break; + } + + let Some((evict_key, _, _)) = eviction_candidate else { + auth_probe_note_saturation(now); + return; + }; state.remove(&evict_key); + auth_probe_note_saturation(now); } } @@ -186,6 +333,11 @@ fn clear_auth_probe_state_for_testing() { if let Some(state) = AUTH_PROBE_STATE.get() { state.clear(); } + if let Some(saturation) = AUTH_PROBE_SATURATION_STATE.get() + && let Ok(mut guard) = saturation.lock() + { + *guard = None; + } } #[cfg(test)] @@ -200,6 +352,16 @@ fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool { auth_probe_is_throttled(peer_ip, Instant::now()) } +#[cfg(test)] +fn auth_probe_saturation_is_throttled_for_testing() -> bool { + auth_probe_saturation_is_throttled(Instant::now()) +} + +#[cfg(test)] +fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool { + auth_probe_saturation_is_throttled(now) +} + #[cfg(test)] fn auth_probe_test_lock() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); @@ -317,6 +479,24 @@ fn decode_user_secrets( secrets } +async fn maybe_apply_server_hello_delay(config: &ProxyConfig) { + if config.censorship.server_hello_delay_max_ms == 0 { + return; + } + + let min = config.censorship.server_hello_delay_min_ms; + let max = config.censorship.server_hello_delay_max_ms.max(min); + let delay_ms = if max == min { + max + } else { + rand::rng().random_range(min..=max) + }; + + if delay_ms > 0 { + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } +} + /// Result of successful handshake /// /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is @@ -367,17 +547,21 @@ where { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); - if auth_probe_is_throttled(peer.ip(), Instant::now()) { + let throttle_now = Instant::now(); + if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; } if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; } - let secrets = decode_user_secrets(config, None); + let client_sni = tls::extract_sni_from_client_hello(handshake); + let secrets = decode_user_secrets(config, client_sni.as_deref()); let validation = match tls::validate_tls_handshake_with_replay_window( handshake, @@ -388,6 +572,7 @@ where Some(v) => v, None => { auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!( peer = %peer, ignore_time_skew = config.access.ignore_time_skew, @@ -402,20 +587,24 @@ where let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; if replay_checker.check_and_add_tls_digest(digest_half) { auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); return HandshakeResult::BadClient { reader, writer }; } let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, - None => return HandshakeResult::BadClient { reader, writer }, + None => { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } }; let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { - let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { + let selected_domain = if let Some(sni) = client_sni.as_ref() { if cache.contains_domain(&sni).await { - sni + sni.clone() } else { config.censorship.tls_domain.clone() } @@ -448,6 +637,7 @@ where } else if alpn_list.iter().any(|p| p == b"http/1.1") { Some(b"http/1.1".to_vec()) } else if !alpn_list.is_empty() { + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); return HandshakeResult::BadClient { reader, writer }; } else { @@ -480,19 +670,9 @@ where ) }; - // Optional anti-fingerprint delay before sending ServerHello. - if config.censorship.server_hello_delay_max_ms > 0 { - let min = config.censorship.server_hello_delay_min_ms; - let max = config.censorship.server_hello_delay_max_ms.max(min); - let delay_ms = if max == min { - max - } else { - rand::rng().random_range(min..=max) - }; - if delay_ms > 0 { - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - } - } + // Apply the same optional delay budget used by reject paths to reduce + // distinguishability between success and fail-closed handshakes. + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); @@ -538,7 +718,9 @@ where { trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); - if auth_probe_is_throttled(peer.ip(), Instant::now()) { + let throttle_now = Instant::now(); + if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; } @@ -609,6 +791,7 @@ where // authentication check first to avoid poisoning the replay cache. if replay_checker.check_and_add_handshake(dec_prekey_iv) { auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; warn!(peer = %peer, user = %user, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; } @@ -645,6 +828,7 @@ where } auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "MTProto handshake: no matching user found"); HandshakeResult::BadClient { reader, writer } } diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 6bdc345..2132fbe 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1,6 +1,8 @@ use super::*; -use crate::crypto::sha256_hmac; +use crate::crypto::{sha256, sha256_hmac}; use dashmap::DashMap; +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -84,6 +86,77 @@ fn make_valid_tls_client_hello_with_alpn( record } +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + record +} + fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { let mut cfg = ProxyConfig::default(); cfg.access.users.clear(); @@ -94,6 +167,43 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { cfg } +fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper"); + + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + #[test] fn test_generate_tg_nonce() { let client_enc_key = [0x24u8; 32]; @@ -293,6 +403,45 @@ async fn tls_replay_second_identical_handshake_is_rejected() { assert!(matches!(second, HandshakeResult::BadClient { .. })); } +#[tokio::test] +async fn tls_replay_with_ignore_time_skew_and_small_boot_timestamp_is_still_blocked() { + let secret = [0x19u8; 16]; + let config = test_config_with_secret_hex("19191919191919191919191919191919"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.121:44321".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 1); + + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(replay, HandshakeResult::BadClient { .. }), + "ignore_time_skew must not weaken replay rejection for small boot timestamps" + ); +} + #[tokio::test] async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { let secret = [0x77u8; 16]; @@ -338,6 +487,177 @@ async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() ); } +#[tokio::test] +async fn tls_replay_matrix_rotating_peers_first_accept_then_rejects() { + let secret = [0x52u8; 16]; + let config = test_config_with_secret_hex("52525252525252525252525252525252"); + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let handshake = make_valid_tls_handshake(&secret, 17); + + let first_peer: SocketAddr = "198.51.100.31:44001".parse().unwrap(); + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + first_peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + for i in 0..128u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 51, 100, ((i % 250) + 1) as u8)), + 45000 + i, + ); + let replay = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(replay, HandshakeResult::BadClient { .. }), + "replay digest must be rejected regardless of source peer rotation" + ); + } +} + +#[tokio::test] +async fn adversarial_tls_replay_churn_allows_only_unique_digests() { + let secret = [0x5Au8; 16]; + let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + config.access.ignore_time_skew = true; + let config = Arc::new(config); + let replay_checker = Arc::new(ReplayChecker::new(8192, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + + let make_tagged_handshake = |timestamp: u32, tag: u8| { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![tag; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(&secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake + }; + + let mut tasks = Vec::new(); + + // 128 exact duplicates: only one should pass. + let duplicated = Arc::new(make_valid_tls_handshake(&secret, 999)); + for i in 0..128u16 { + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let rng = Arc::clone(&rng); + let duplicated = Arc::clone(&duplicated); + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, ((i % 250) + 1) as u8)), + 46000 + i, + ); + handle_tls_handshake( + &duplicated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + // 128 unique timestamps: all should pass because HMAC digest differs. + for i in 0..128u16 { + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let rng = Arc::clone(&rng); + let handshake = make_tagged_handshake(10_000 + i as u32, (i as u8).wrapping_add(0x80)); + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 18, 0, ((i % 250) + 1) as u8)), + 47000 + i, + ); + handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let mut duplicate_success = 0usize; + let mut duplicate_reject = 0usize; + let mut unique_success = 0usize; + let mut unique_reject = 0usize; + + for (idx, task) in tasks.into_iter().enumerate() { + let result = task.await.unwrap(); + let is_duplicate_group = idx < 128; + match result { + HandshakeResult::Success(_) => { + if is_duplicate_group { + duplicate_success += 1; + } else { + unique_success += 1; + } + } + HandshakeResult::BadClient { .. } => { + if is_duplicate_group { + duplicate_reject += 1; + } else { + unique_reject += 1; + } + } + HandshakeResult::Error(e) => panic!("unexpected handshake error in churn test: {e}"), + } + } + + assert_eq!( + duplicate_success, 1, + "duplicate replay group must allow exactly one successful handshake" + ); + assert_eq!( + duplicate_reject, 127, + "duplicate replay group must reject all remaining replays" + ); + assert_eq!( + unique_success, 128, + "unique digest group must fully pass under replay churn" + ); + assert_eq!( + unique_reject, 0, + "unique digest group must not be falsely rejected as replay" + ); +} + #[tokio::test] async fn invalid_tls_probe_does_not_pollute_replay_cache() { let config = test_config_with_secret_hex("11111111111111111111111111111111"); @@ -349,6 +669,7 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() { invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; let before = replay_checker.stats(); + let result = handle_tls_handshake( &invalid, tokio::io::empty(), @@ -490,6 +811,149 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() { assert!(matches!(result, HandshakeResult::Success(_))); } +#[tokio::test] +async fn tls_sni_preferred_user_hint_selects_matching_identity_first() { + let shared_secret = [0x3Bu8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.users.insert( + "user-a".to_string(), + "3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b".to_string(), + ); + config.access.users.insert( + "user-b".to_string(), + "3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b".to_string(), + ); + config.access.ignore_time_skew = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn( + &shared_secret, + 0, + "user-b", + &[b"h2"], + ); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + match result { + HandshakeResult::Success((_, _, user)) => { + assert_eq!( + user, "user-b", + "TLS SNI preferred-user hint must select matching identity before equivalent decoys" + ); + } + _ => panic!("TLS handshake must succeed for valid shared-secret SNI case"), + } +} + +#[test] +fn stress_decode_user_secrets_keeps_preferred_user_first_in_large_set() { + let mut config = ProxyConfig::default(); + config.access.users.clear(); + + let preferred_user = "target-user.example".to_string(); + let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); + + for i in 0..4096usize { + config.access.users.insert( + format!("decoy-{i:04}.example"), + secret_hex.clone(), + ); + } + config + .access + .users + .insert(preferred_user.clone(), secret_hex.clone()); + + let decoded = decode_user_secrets(&config, Some(preferred_user.as_str())); + assert_eq!( + decoded.len(), + config.access.users.len(), + "decoded secret set must preserve full user cardinality under stress" + ); + assert_eq!( + decoded.first().map(|(name, _)| name.as_str()), + Some(preferred_user.as_str()), + "preferred user must be first even under adversarial large user sets" + ); + assert_eq!( + decoded + .iter() + .filter(|(name, _)| name == &preferred_user) + .count(), + 1, + "preferred user must appear exactly once in decoded list" + ); +} + +#[tokio::test] +async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { + let shared_secret = [0x7Fu8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.ignore_time_skew = true; + + let preferred_user = "target-user.example".to_string(); + let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); + + for i in 0..4096usize { + config.access.users.insert( + format!("decoy-{i:04}.example"), + secret_hex.clone(), + ); + } + config + .access + .users + .insert(preferred_user.clone(), secret_hex); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.189:44326".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn( + &shared_secret, + 0, + preferred_user.as_str(), + &[b"h2"], + ); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + match result { + HandshakeResult::Success((_, _, user)) => { + assert_eq!( + user, + preferred_user, + "SNI preferred-user hint must remain stable under large user cardinality" + ); + } + _ => panic!("TLS handshake must succeed for valid preferred-user stress case"), + } +} + #[tokio::test] async fn alpn_enforce_rejects_unsupported_client_alpn() { let secret = [0x33u8; 16]; @@ -580,6 +1044,72 @@ async fn malformed_tls_classes_complete_within_bounded_time() { } } +#[tokio::test] +async fn tls_invalid_hmac_respects_configured_anti_fingerprint_delay() { + let secret = [0x5Au8; 16]; + let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.32:44331".parse().unwrap(); + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01; + + let started = Instant::now(); + let result = handle_tls_handshake( + &bad_hmac, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert!( + started.elapsed() >= Duration::from_millis(18), + "configured anti-fingerprint delay must apply to invalid TLS handshakes" + ); +} + +#[tokio::test] +async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() { + let secret = [0x6Bu8; 16]; + let mut config = test_config_with_secret_hex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"); + config.censorship.alpn_enforce = true; + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.33:44332".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let started = Instant::now(); + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert!( + started.elapsed() >= Duration::from_millis(18), + "configured anti-fingerprint delay must apply to ALPN-mismatch rejects" + ); +} + #[tokio::test] #[ignore = "timing-sensitive; run manually on low-jitter hosts"] async fn malformed_tls_classes_share_close_latency_buckets() { @@ -643,6 +1173,82 @@ async fn malformed_tls_classes_share_close_latency_buckets() { ); } +#[tokio::test] +#[ignore = "timing matrix; run manually with --ignored --nocapture"] +async fn timing_matrix_tls_classes_under_fixed_delay_budget() { + const ITER: usize = 48; + const BUCKET_MS: u128 = 10; + + let secret = [0x77u8; 16]; + let mut config = test_config_with_secret_hex("77777777777777777777777777777777"); + config.censorship.alpn_enforce = true; + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let rng = SecureRandom::new(); + let base_ip = std::net::Ipv4Addr::new(198, 51, 100, 34); + + let too_short = vec![0x16, 0x03, 0x01]; + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01; + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let valid_h2 = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2"]); + + let classes = vec![ + ("too_short", too_short), + ("bad_hmac", bad_hmac), + ("alpn_mismatch", alpn_mismatch), + ("valid_h2", valid_h2), + ]; + + for (class, probe) in classes { + let mut samples_ms = Vec::with_capacity(ITER); + for idx in 0..ITER { + clear_auth_probe_state_for_testing(); + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let peer: SocketAddr = SocketAddr::from((base_ip, 44_000 + idx as u16)); + let started = Instant::now(); + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = started.elapsed(); + samples_ms.push(elapsed.as_millis()); + + if class == "valid_h2" { + assert!(matches!(result, HandshakeResult::Success(_))); + } else { + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + } + + samples_ms.sort_unstable(); + let sum: u128 = samples_ms.iter().copied().sum(); + let mean = sum as f64 / samples_ms.len() as f64; + let min = samples_ms[0]; + let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize; + let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)]; + let max = samples_ms[samples_ms.len() - 1]; + + println!( + "TIMING_MATRIX tls class={} mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + class, + mean, + min, + p95, + max, + (mean as u128) / BUCKET_MS + ); + } +} + #[test] fn secure_tag_requires_tls_mode_on_tls_transport() { let mut config = ProxyConfig::default(); @@ -871,7 +1477,12 @@ fn auth_probe_capacity_prunes_stale_entries_for_new_ips() { } #[test] -fn auth_probe_capacity_forces_bounded_eviction_when_map_is_fresh_and_full() { +fn auth_probe_capacity_fresh_full_map_still_tracks_newcomer_with_bounded_eviction() { + let _guard = auth_probe_test_lock() + .lock() + .expect("auth probe test lock must be available"); + clear_auth_probe_state_for_testing(); + let state = DashMap::new(); let now = Instant::now(); @@ -887,23 +1498,281 @@ fn auth_probe_capacity_forces_bounded_eviction_when_map_is_fresh_and_full() { AuthProbeState { fail_streak: 1, blocked_until: now, - last_seen: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), }, ); } + let oldest = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 0)); + state.insert( + oldest, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 55)); auth_probe_record_failure_with_state(&state, newcomer, now); assert!( state.get(&newcomer).is_some(), - "when all entries are fresh and full, one bounded eviction must admit a new probe source" + "fresh-at-cap auth probe map must still track a new source after bounded eviction" + ); + assert!( + state.get(&oldest).is_none(), + "capacity eviction must remove the oldest tracked source first" ); assert_eq!( state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES, - "auth probe map must stay at the configured cap after forced eviction" + "auth probe map must stay at configured cap after bounded eviction" ); + assert!( + auth_probe_saturation_is_throttled_at_for_testing(now), + "capacity pressure should still activate coarse global pre-auth throttling" + ); +} + +#[test] +fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 2, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 2048) as u64), + }, + ); + } + + for step in 0..1024usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 0, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + let now = base_now + Duration::from_millis(10_000 + step as u64); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!( + state.get(&newcomer).is_some(), + "new source must still be tracked under sustained at-capacity churn" + ); + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map size must stay hard-bounded at capacity" + ); + } +} + +#[test] +fn auth_probe_capacity_prefers_evicting_low_fail_streak_entries_first() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + + // Fill map at capacity with mostly high fail streak entries. + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 20, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 9, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let low_fail = IpAddr::V4(Ipv4Addr::new(172, 21, 0, 1)); + state.insert( + low_fail, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_secs(30), + }, + ); + + let high_fail_old = IpAddr::V4(Ipv4Addr::new(172, 21, 0, 2)); + state.insert( + high_fail_old, + AuthProbeState { + fail_streak: 12, + blocked_until: now, + last_seen: now - Duration::from_secs(10), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 201)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!(state.get(&newcomer).is_some(), "new source must be tracked"); + assert!( + state.get(&low_fail).is_none(), + "least-penalized entry should be evicted before high-penalty entries" + ); + assert!( + state.get(&high_fail_old).is_some(), + "high fail-streak entry should be preserved under mixed-priority eviction" + ); +} + +#[test] +fn auth_probe_capacity_tie_breaker_evicts_oldest_with_equal_fail_streak() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES - 2) { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 30, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 5, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let oldest = IpAddr::V4(Ipv4Addr::new(172, 31, 0, 1)); + let newer = IpAddr::V4(Ipv4Addr::new(172, 31, 0, 2)); + state.insert( + oldest, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(20), + }, + ); + state.insert( + newer, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 202)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!(state.get(&newcomer).is_some(), "new source must be tracked"); + assert!( + state.get(&oldest).is_none(), + "among equal fail streak candidates, oldest entry must be evicted" + ); + assert!( + state.get(&newer).is_some(), + "newer equal-priority entry should be retained" + ); +} + +#[test] +fn stress_auth_probe_capacity_churn_preserves_high_fail_sentinels() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + let sentinel_a = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250)); + let sentinel_b = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 251)); + + state.insert( + sentinel_a, + AuthProbeState { + fail_streak: 20, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(30), + }, + ); + state.insert( + sentinel_b, + AuthProbeState { + fail_streak: 21, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(31), + }, + ); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES - 2) { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 4, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 1024) as u64), + }, + ); + } + + for step in 0..1024usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 1, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + let now = base_now + Duration::from_millis(10_000 + step as u64); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map must remain hard-bounded at capacity" + ); + assert!( + state.get(&sentinel_a).is_some() && state.get(&sentinel_b).is_some(), + "high fail-streak sentinels should survive low-streak newcomer churn" + ); + } } #[test] @@ -996,6 +1865,97 @@ fn auth_probe_eviction_offset_varies_with_input() { assert_ne!(a, c, "different peer IPs should not collapse to one offset"); } +#[test] +fn auth_probe_eviction_offset_changes_with_time_component() { + let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 77)); + let now = Instant::now(); + let later = now + Duration::from_millis(1); + + let a = auth_probe_eviction_offset(ip, now); + let b = auth_probe_eviction_offset(ip, later); + + assert_ne!( + a, b, + "eviction offset must incorporate timestamp entropy and not only peer IP" + ); +} + +#[test] +fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() { + let mut rng = StdRng::seed_from_u64(0xA11CE5EED); + let base = Instant::now(); + + for _ in 0..4096usize { + let ip = IpAddr::V4(Ipv4Addr::new(rng.random(), rng.random(), rng.random(), rng.random())); + let offset_ns = rng.random_range(0_u64..2_000_000); + let when = base + Duration::from_nanos(offset_ns); + + let first = auth_probe_eviction_offset(ip, when); + let second = auth_probe_eviction_offset(ip, when); + assert_eq!( + first, second, + "eviction offset must be stable for identical (ip, now) pairs" + ); + } +} + +#[test] +fn adversarial_eviction_offset_spread_avoids_single_bucket_collapse() { + let modulus = AUTH_PROBE_TRACK_MAX_ENTRIES; + let mut bucket_hits = vec![0usize; modulus]; + let now = Instant::now(); + + for idx in 0..8192usize { + let ip = IpAddr::V4(Ipv4Addr::new( + 100, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + ((idx.wrapping_mul(37)) & 0xff) as u8, + )); + let bucket = auth_probe_eviction_offset(ip, now) % modulus; + bucket_hits[bucket] += 1; + } + + let non_empty_buckets = bucket_hits.iter().filter(|&&hits| hits > 0).count(); + assert!( + non_empty_buckets >= modulus / 2, + "adversarial sequential input should cover a broad bucket set (covered {non_empty_buckets}/{modulus})" + ); + + let max_hits = bucket_hits.iter().copied().max().unwrap_or(0); + let min_non_zero_hits = bucket_hits + .iter() + .copied() + .filter(|&hits| hits > 0) + .min() + .unwrap_or(0); + assert!( + max_hits <= min_non_zero_hits.saturating_mul(32).max(1), + "bucket skew is unexpectedly extreme for keyed hasher spread (max={max_hits}, min_non_zero={min_non_zero_hits})" + ); +} + +#[test] +fn stress_auth_probe_eviction_offset_high_volume_uniqueness_sanity() { + let now = Instant::now(); + let mut seen = std::collections::HashSet::new(); + + for idx in 0..50_000usize { + let ip = IpAddr::V4(Ipv4Addr::new( + 198, + ((idx >> 16) & 0xff) as u8, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + seen.insert(auth_probe_eviction_offset(ip, now)); + } + + assert!( + seen.len() >= 40_000, + "high-volume eviction offsets should not collapse excessively under keyed hashing" + ); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() { let _guard = auth_probe_test_lock() @@ -1108,3 +2068,1118 @@ async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() "successful victim handshake must not retain pre-auth failure streak" ); } + +#[test] +fn auth_probe_saturation_state_expires_after_retention_window() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(30), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + assert!( + !auth_probe_saturation_is_throttled_for_testing(), + "expired saturation state must stop throttling and self-clear" + ); + + let guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + assert!(guard.is_none(), "expired saturation state must be removed"); +} + +#[tokio::test] +async fn global_saturation_marker_does_not_block_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x41u8; 16]; + let config = test_config_with_secret_hex("41414141414141414141414141414141"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.101:45101".parse().unwrap(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "global saturation marker must not block valid authenticated TLS handshakes" + ); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful handshake under saturation marker must not retain per-ip probe failures" + ); +} + +#[tokio::test] +async fn expired_global_saturation_allows_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x55u8; 16]; + let config = test_config_with_secret_hex("55555555555555555555555555555555"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.102:45102".parse().unwrap(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "expired saturation marker must not block valid handshake" + ); +} + +#[tokio::test] +async fn valid_tls_is_blocked_by_per_ip_preauth_throttle_without_saturation() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x61u8; 16]; + let config = test_config_with_secret_hex("61616161616161616161616161616161"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.103:45103".parse().unwrap(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: Instant::now() + Duration::from_secs(5), + last_seen: Instant::now(), + }, + ); + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn saturation_allows_valid_tls_even_when_peer_ip_is_currently_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x62u8; 16]; + let config = test_config_with_secret_hex("62626262626262626262626262626262"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.104:45104".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful auth under saturation must clear the peer's throttled state" + ); +} + +#[tokio::test] +async fn saturation_still_rejects_invalid_tls_probe_and_records_failure() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("63636363636363636363636363636363"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.105:45105".parse().unwrap(); + let now = Instant::now(); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(1), + "invalid TLS during saturation must still increment per-ip failure tracking" + ); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_preauth_throttles_repeated_invalid_tls_probe() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("63636363636363636363636363636363"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.205:45205".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "pre-auth throttle under exhausted saturation grace must reject without re-processing invalid TLS" + ); +} + +#[tokio::test] +async fn saturation_allows_valid_mtproto_even_when_peer_ip_is_currently_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret_hex = "64646464646464646464646464646464"; + let mut config = test_config_with_secret_hex(secret_hex); + config.general.modes.secure = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.106:45106".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let result = handle_mtproto_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful mtproto auth under saturation must clear the peer's throttled state" + ); +} + +#[tokio::test] +async fn saturation_still_rejects_invalid_mtproto_probe_and_records_failure() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("65656565656565656565656565656565"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.107:45107".parse().unwrap(); + let now = Instant::now(); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(1), + "invalid mtproto during saturation must still increment per-ip failure tracking" + ); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_preauth_throttles_repeated_invalid_mtproto_probe() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("65656565656565656565656565656565"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.206:45206".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "pre-auth throttle under exhausted saturation grace must reject without re-processing invalid MTProto" + ); +} + +#[tokio::test] +async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("70707070707070707070707070707070"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.207:45207".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for expected in [ + AUTH_PROBE_BACKOFF_START_FAILS + 1, + AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + ] { + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + } + + { + let mut entry = auth_probe_state_map() + .get_mut(&normalize_auth_probe_ip(peer.ip())) + .expect("peer state must exist before exhaustion recheck"); + entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; + entry.blocked_until = Instant::now() + Duration::from_secs(1); + entry.last_seen = Instant::now(); + } + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "once grace is exhausted, repeated invalid TLS must be pre-auth throttled without further fail-streak growth" + ); +} + +#[tokio::test] +async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementing() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("71717171717171717171717171717171"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.208:45208".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + + for expected in [ + AUTH_PROBE_BACKOFF_START_FAILS + 1, + AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + ] { + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + } + + { + let mut entry = auth_probe_state_map() + .get_mut(&normalize_auth_probe_ip(peer.ip())) + .expect("peer state must exist before exhaustion recheck"); + entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; + entry.blocked_until = Instant::now() + Duration::from_secs(1); + entry.last_seen = Instant::now(); + } + + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "once grace is exhausted, repeated invalid MTProto must be pre-auth throttled without further fail-streak growth" + ); +} + +#[tokio::test] +async fn saturation_grace_boundary_still_admits_valid_tls_before_exhaustion() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x72u8; 16]; + let config = test_config_with_secret_hex("72727272727272727272727272727272"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.209:45209".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "valid TLS should still pass while peer remains within saturation grace budget" + ); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), None); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_blocks_valid_tls_until_backoff_expires() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x73u8; 16]; + let config = test_config_with_secret_hex("73737373737373737373737373737373"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.210:45210".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_millis(200), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let blocked = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(blocked, HandshakeResult::BadClient { .. })); + + tokio::time::sleep(Duration::from_millis(230)).await; + + let allowed = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(allowed, HandshakeResult::Success(_)), + "valid TLS should recover after peer-specific pre-auth backoff has elapsed" + ); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), None); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_is_shared_across_tls_and_mtproto_for_same_peer() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("74747474747474747474747474747474"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.211:45211".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_mtproto = [0u8; HANDSHAKE_LEN]; + + let tls_result = handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(tls_result, HandshakeResult::BadClient { .. })); + + let mtproto_result = handle_mtproto_handshake( + &invalid_mtproto, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(mtproto_result, HandshakeResult::BadClient { .. })); + + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "saturation grace exhaustion must gate both TLS and MTProto pre-auth paths for one peer" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_same_peer_invalid_tls_storm_does_not_bypass_saturation_grace_cap() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = Arc::new(test_config_with_secret_hex("75757575757575757575757575757575")); + let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_tls = Arc::new(invalid_tls); + + let mut tasks = Vec::new(); + for _ in 0..64usize { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid_tls = invalid_tls.clone(); + tasks.push(tokio::spawn(async move { + handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + for task in tasks { + let result = task.await.unwrap(); + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "same-peer invalid storm under exhausted grace must stay pre-auth throttled without fail-streak growth" + ); +} + +#[tokio::test] +async fn light_fuzz_saturation_grace_tls_invalid_inputs_never_authenticate_or_panic() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("76767676767676767676767676767676"); + let replay_checker = ReplayChecker::new(2048, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.213:45213".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut seeded = StdRng::seed_from_u64(0xD15EA5E5_u64); + for _ in 0..128usize { + let len = seeded.random_range(0usize..96usize); + let mut probe = vec![0u8; len]; + seeded.fill(&mut probe[..]); + + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + let streak = auth_probe_fail_streak_for_testing(peer.ip()) + .expect("peer should remain tracked after repeated invalid fuzz probes"); + assert!( + streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + "fuzzed invalid TLS probes under saturation must not reduce fail-streak below exhaustion threshold" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshakes() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret_hex = "66666666666666666666666666666666"; + let secret = [0x66u8; 16]; + let mut cfg = test_config_with_secret_hex(secret_hex); + cfg.general.modes.secure = true; + let config = Arc::new(cfg); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let now = Instant::now(); + + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0)); + let valid_mtproto = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 3)); + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_tls = Arc::new(invalid_tls); + + let mut invalid_tls_tasks = Vec::new(); + for idx in 0..48u16 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid_tls = invalid_tls.clone(); + invalid_tls_tasks.push(tokio::spawn(async move { + let octet = ((idx % 200) + 1) as u8; + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 46000 + idx); + handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let valid_tls_task = { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let valid_tls = valid_tls.clone(); + tokio::spawn(async move { + handle_tls_handshake( + &valid_tls, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.108:45108".parse().unwrap(), + &config, + &replay_checker, + &rng, + None, + ) + .await + }) + }; + + let valid_mtproto_task = { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let valid_mtproto = valid_mtproto.clone(); + tokio::spawn(async move { + handle_mtproto_handshake( + &valid_mtproto, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.109:45109".parse().unwrap(), + &config, + &replay_checker, + false, + None, + ) + .await + }) + }; + + let mut bad_clients = 0usize; + for task in invalid_tls_tasks { + match task.await.unwrap() { + HandshakeResult::BadClient { .. } => bad_clients += 1, + HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"), + HandshakeResult::Error(err) => panic!("unexpected error in invalid TLS saturation burst test: {err}"), + } + } + + let valid_tls_result = valid_tls_task.await.unwrap(); + assert!( + matches!(valid_tls_result, HandshakeResult::Success(_)), + "valid TLS probe must authenticate during saturation burst" + ); + + let valid_mtproto_result = valid_mtproto_task.await.unwrap(); + assert!( + matches!(valid_mtproto_result, HandshakeResult::Success(_)), + "valid MTProto probe must authenticate during saturation burst" + ); + + assert_eq!( + bad_clients, + 48, + "all invalid TLS probes in mixed saturation burst must be rejected" + ); +} + +#[tokio::test] +async fn expired_saturation_keeps_per_ip_throttle_enforced_for_valid_tls() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x67u8; 16]; + let config = test_config_with_secret_hex("67676767676767676767676767676767"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.110:45110".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::BadClient { .. }), + "expired saturation marker must not disable per-ip pre-auth throttle" + ); +} diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 9a23c5b..b0f6985 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -7,7 +7,7 @@ use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; -use tokio::time::timeout; +use tokio::time::{Instant, timeout}; use tracing::debug; use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; @@ -24,8 +24,36 @@ const MASK_TIMEOUT: Duration = Duration::from_millis(50); const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60); #[cfg(test)] const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200); +#[cfg(not(test))] +const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +async fn copy_with_idle_timeout(reader: &mut R, writer: &mut W) +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut buf = [0u8; MASK_BUFFER_SIZE]; + loop { + let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; + let n = match read_res { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { + break; + } + + let write_res = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&buf[..n])).await; + match write_res { + Ok(Ok(())) => {} + Ok(Err(_)) | Err(_) => break, + } + } +} + async fn write_proxy_header_with_timeout(mask_write: &mut W, header: &[u8]) -> bool where W: AsyncWrite + Unpin, @@ -49,6 +77,20 @@ where } } +async fn wait_mask_connect_budget(started: Instant) { + let elapsed = started.elapsed(); + if elapsed < MASK_TIMEOUT { + tokio::time::sleep(MASK_TIMEOUT - elapsed).await; + } +} + +async fn wait_mask_outcome_budget(started: Instant) { + let elapsed = started.elapsed(); + if elapsed < MASK_TIMEOUT { + tokio::time::sleep(MASK_TIMEOUT - elapsed).await; + } +} + /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request @@ -107,6 +149,8 @@ where // Connect via Unix socket or TCP #[cfg(unix)] if let Some(ref sock_path) = config.censorship.mask_unix_sock { + let outcome_started = Instant::now(); + let connect_started = Instant::now(); debug!( client_type = client_type, sock = %sock_path, @@ -143,14 +187,18 @@ where if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { debug!("Mask relay timed out (unix socket)"); } + wait_mask_outcome_budget(outcome_started).await; } Ok(Err(e)) => { + wait_mask_connect_budget(connect_started).await; debug!(error = %e, "Failed to connect to mask unix socket"); consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started).await; } } return; @@ -172,6 +220,8 @@ where let mask_addr = resolve_socket_addr(mask_host, mask_port) .map(|addr| addr.to_string()) .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); + let outcome_started = Instant::now(); + let connect_started = Instant::now(); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { Ok(Ok(stream)) => { @@ -202,14 +252,18 @@ where if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { debug!("Mask relay timed out"); } + wait_mask_outcome_budget(outcome_started).await; } Ok(Err(e)) => { + wait_mask_connect_budget(connect_started).await; debug!(error = %e, "Failed to connect to mask host"); consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started).await; } Err(_) => { debug!("Timeout connecting to mask host"); consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started).await; } } } @@ -238,11 +292,11 @@ where let _ = tokio::join!( async { - let _ = tokio::io::copy(&mut reader, &mut mask_write).await; + copy_with_idle_timeout(&mut reader, &mut mask_write).await; let _ = mask_write.shutdown().await; }, async { - let _ = tokio::io::copy(&mut mask_read, &mut writer).await; + copy_with_idle_timeout(&mut mask_read, &mut writer).await; let _ = writer.shutdown().await; } ); diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 25b6a76..1cee108 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -8,7 +8,7 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio::net::TcpListener; #[cfg(unix)] use tokio::net::UnixListener; -use tokio::time::{sleep, timeout, Duration}; +use tokio::time::{Instant, sleep, timeout, Duration}; #[tokio::test] async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { @@ -216,6 +216,373 @@ async fn backend_unavailable_falls_back_to_silent_consume() { assert_eq!(n, 0); } +#[tokio::test] +async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; + + // Close client reader immediately to force the refusal path to rely on masking budget timing. + let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task) + .await + .expect_err("masking fallback must not complete before connect budget elapses"); + assert!( + started.elapsed() >= Duration::from_millis(35), + "fallback path must absorb immediate connect refusal into connect budget" + ); +} + +#[tokio::test] +async fn backend_reachable_fast_response_waits_mask_outcome_budget() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + assert!( + started.elapsed() >= Duration::from_millis(45), + "reachable mask path must also satisfy coarse outcome budget" + ); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"x", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + assert!( + started.elapsed() < Duration::from_millis(20), + "mask-disabled fallback should keep immediate EOF behavior" + ); +} + +#[tokio::test] +async fn backend_reachable_slow_response_not_padded_twice() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + sleep(Duration::from_millis(90)).await; + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + let elapsed = started.elapsed(); + + assert!(elapsed >= Duration::from_millis(85)); + assert!( + elapsed < Duration::from_millis(170), + "slow reachable backend should not incur an extra full budget after already exceeding it" + ); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() { + const ITER: usize = 20; + const BUCKET_MS: u128 = 10; + + let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n"; + let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let mut refused = Vec::with_capacity(ITER); + for _ in 0..ITER { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + refused.push(started.elapsed().as_millis()); + } + + let mut reachable = Vec::with_capacity(ITER); + for _ in 0..ITER { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe_vec = probe.to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut received).await.unwrap(); + stream.write_all(&backend_reply).await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + reachable.push(started.elapsed().as_millis()); + accept_task.await.unwrap(); + } + + let refused_mean = refused.iter().copied().sum::() as f64 / refused.len() as f64; + let reachable_mean = reachable.iter().copied().sum::() as f64 / reachable.len() as f64; + let refused_bucket = (refused_mean as u128) / BUCKET_MS; + let reachable_bucket = (reachable_mean as u128) / BUCKET_MS; + + assert!( + refused_bucket.abs_diff(reachable_bucket) <= 1, + "enabled refused and reachable paths must collapse into the same coarse latency bucket" + ); +} + +#[tokio::test] +async fn light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() { + let mut seed: u64 = 0xA5A5_5A5A_1337_4242; + let mut next = || { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + seed + }; + + let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + for _ in 0..40 { + let probe_len = (next() as usize % 96).saturating_add(8); + let mut probe = vec![0u8; probe_len]; + for byte in &mut probe { + *byte = next() as u8; + } + + let use_reachable = (next() & 1) == 0; + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(512); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + if use_reachable { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let probe_vec = probe.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut observed).await.unwrap(); + }); + + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + accept_task.await.unwrap(); + } else { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + } + + assert!( + started.elapsed() >= Duration::from_millis(45), + "mask-enabled fallback must preserve coarse timing budget under varied probe shapes" + ); + } +} + #[tokio::test] async fn mask_disabled_consumes_client_data_without_response() { let mut config = ProxyConfig::default(); @@ -524,6 +891,59 @@ async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() { timeout(Duration::from_secs(1), task).await.unwrap().unwrap(); } +#[tokio::test] +async fn mask_enabled_idle_relay_is_closed_by_idle_timeout_before_global_relay_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /idle HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + sleep(Duration::from_millis(300)).await; + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.34:45456".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(512); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed < Duration::from_millis(150), + "idle unauth relay must terminate on idle timeout instead of waiting for full relay timeout" + ); + + accept_task.await.unwrap(); +} + struct PendingWriter; impl tokio::io::AsyncWrite for PendingWriter { @@ -729,3 +1149,321 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { assert!(mask_reader_dropped.load(Ordering::SeqCst)); assert!(mask_writer_dropped.load(Ordering::SeqCst)); } + +#[tokio::test] +#[ignore = "timing matrix; run manually with --ignored --nocapture"] +async fn timing_matrix_masking_classes_under_controlled_inputs() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 10; + + let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n"; + let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + // Class 1: masking disabled with immediate EOF (fast fail-closed consume path). + let mut disabled_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + disabled_samples.push(started.elapsed().as_millis()); + } + + // Class 2: masking enabled, backend connect refused. + let mut refused_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + refused_samples.push(started.elapsed().as_millis()); + } + + // Class 3: masking enabled, backend reachable and immediately responds. + let mut reachable_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + let probe_vec = probe.to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe_vec); + stream.write_all(&backend_reply).await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + reachable_samples.push(started.elapsed().as_millis()); + accept_task.await.unwrap(); + } + + fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) { + samples_ms.sort_unstable(); + let sum: u128 = samples_ms.iter().copied().sum(); + let mean = sum as f64 / samples_ms.len() as f64; + let min = samples_ms[0]; + let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize; + let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)]; + let max = samples_ms[samples_ms.len() - 1]; + (mean, min, p95, max) + } + + let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples); + let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples); + let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples); + + println!( + "TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + disabled_mean, + disabled_min, + disabled_p95, + disabled_max, + (disabled_mean as u128) / BUCKET_MS + ); + println!( + "TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + refused_mean, + refused_min, + refused_p95, + refused_max, + (refused_mean as u128) / BUCKET_MS + ); + println!( + "TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + reachable_mean, + reachable_min, + reachable_p95, + reachable_max, + (reachable_mean as u128) / BUCKET_MS + ); +} + +#[tokio::test] +async fn backend_connect_refusal_completes_within_bounded_mask_budget() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.41:51001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /bounded HTTP/1.1\r\nHost: x\r\n\r\n"; + + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(45), + "connect refusal path must respect minimum masking budget" + ); + assert!( + elapsed < Duration::from_millis(500), + "connect refusal path must stay bounded and avoid unbounded stall" + ); +} + +#[tokio::test] +async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /oneshot HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let response = response.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&response).await.unwrap(); + sleep(Duration::from_millis(300)).await; + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.42:51002".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + let elapsed = started.elapsed(); + + let mut observed = vec![0u8; response.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, response); + assert!( + elapsed < Duration::from_millis(190), + "idle backend silence after first response must be cut by relay idle timeout" + ); + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /drip HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut extra = [0u8; 1]; + let read_res = timeout(Duration::from_millis(220), stream.read_exact(&mut extra)).await; + assert!( + read_res.is_err() || read_res.unwrap().is_err(), + "drip-fed post-probe byte arriving after idle timeout should not be forwarded" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.43:51003".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (mut client_writer_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let relay_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + sleep(Duration::from_millis(160)).await; + let _ = client_writer_side.write_all(b"X").await; + drop(client_writer_side); + + timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap(); + accept_task.await.unwrap(); +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 1acbdc1..bf23045 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,4 +1,5 @@ -use std::collections::hash_map::DefaultHasher; +use std::collections::hash_map::RandomState; +use std::hash::BuildHasher; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicU64, Ordering}; @@ -41,6 +42,7 @@ const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); +static DESYNC_HASHER: OnceLock = OnceLock::new(); struct RelayForensicsState { trace_id: u64, @@ -80,7 +82,8 @@ impl MeD2cFlushPolicy { } fn hash_value(value: &T) -> u64 { - let mut hasher = DefaultHasher::new(); + let state = DESYNC_HASHER.get_or_init(RandomState::new); + let mut hasher = state.build_hasher(); value.hash(&mut hasher); hasher.finish() } @@ -106,12 +109,17 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { let mut stale_keys = Vec::new(); - let mut eviction_candidate = None; + let mut oldest_candidate: Option<(u64, Instant)> = None; for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) { - if eviction_candidate.is_none() { - eviction_candidate = Some(*entry.key()); + let key = *entry.key(); + let seen_at = *entry.value(); + + match oldest_candidate { + Some((_, oldest_seen)) if seen_at >= oldest_seen => {} + _ => oldest_candidate = Some((key, seen_at)), } - if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW { + + if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW { stale_keys.push(*entry.key()); } } @@ -119,7 +127,7 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { dedup.remove(&stale_key); } if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { - let Some(evict_key) = eviction_candidate else { + let Some((evict_key, _)) = oldest_candidate else { return false; }; dedup.remove(&evict_key); @@ -306,7 +314,7 @@ where }; stats.increment_user_connects(&user); - stats.increment_current_connections_me(); + let _me_connection_lease = stats.acquire_me_connection_lease(); if let Some(cutover) = affected_cutover_state( &route_rx, @@ -324,7 +332,6 @@ where tokio::time::sleep(delay).await; let _ = me_pool.send_close(conn_id).await; me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); } @@ -672,7 +679,6 @@ where "ME relay cleanup" ); me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); result } diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index a2a6c3e..441595e 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -2,8 +2,15 @@ use super::*; use bytes::Bytes; use crate::crypto::AesCtr; use crate::crypto::SecureRandom; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::network::probe::NetworkDecision; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; +use crate::transport::middle_proxy::MePool; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::AtomicU64; @@ -215,6 +222,190 @@ fn desync_dedup_full_cache_churn_stays_suppressed() { } } +#[test] +fn dedup_hash_is_stable_for_same_input_within_process() { + let sample = ( + "scope_user", + hash_ip("198.51.100.7".parse().unwrap()), + ProtoTag::Secure, + ); + let first = hash_value(&sample); + let second = hash_value(&sample); + assert_eq!( + first, second, + "dedup hash must be stable within a process for cache lookups" + ); +} + +#[test] +fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() { + let mut seen = HashSet::new(); + + for octet in 1u16..=2048 { + let third = ((octet / 256) & 0xff) as u8; + let fourth = (octet & 0xff) as u8; + let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth)); + let key = hash_value(&( + "scope_user", + hash_ip(ip), + ProtoTag::Secure, + DESYNC_ERROR_CLASS, + )); + seen.insert(key); + } + + assert_eq!( + seen.len(), + 2048, + "adversarial peer-IP burst should not collapse dedup keys via trivial collisions" + ); +} + +#[test] +fn light_fuzz_dedup_hash_collision_rate_stays_negligible() { + let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4); + let mut seen = HashSet::new(); + let samples = 8192usize; + + for _ in 0..samples { + let user_seed: u64 = rng.random(); + let peer_seed: u64 = rng.random(); + let proto = if (peer_seed & 1) == 0 { + ProtoTag::Secure + } else { + ProtoTag::Intermediate + }; + let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS)); + seen.insert(key); + } + + let collisions = samples - seen.len(); + assert!( + collisions <= 1, + "light fuzz collision count should remain negligible for 64-bit dedup keys" + ); +} + +#[test] +fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + let total = DESYNC_DEDUP_MAX_ENTRIES + 8192; + + for key in 0..total as u64 { + let emitted = should_emit_full_desync(key, false, now); + if key < DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!(emitted, "keys below cap must be admitted initially"); + } else { + assert!( + !emitted, + "new keys above cap must stay suppressed under sustained churn" + ); + } + } + + let len = DESYNC_DEDUP + .get() + .expect("dedup cache must be initialized by stress run") + .len(); + assert!( + len <= DESYNC_DEDUP_MAX_ENTRIES, + "dedup cache must stay bounded under stress churn" + ); +} + +#[test] +fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + + // Fill with fresh entries so stale-pruning does not apply. + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let before_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); + + let newcomer_key = u64::MAX; + let emitted = should_emit_full_desync(newcomer_key, false, base_now); + assert!( + !emitted, + "new entry under full fresh cache must stay suppressed" + ); + assert!( + dedup.get(&newcomer_key).is_some(), + "new key must be inserted after bounded eviction" + ); + + let after_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); + let removed_count = before_keys.difference(&after_keys).count(); + let added_count = after_keys.difference(&before_keys).count(); + + assert_eq!( + removed_count, 1, + "full-cache insertion must evict exactly one prior key" + ); + assert_eq!( + added_count, 1, + "full-cache insertion must add exactly one newcomer key" + ); + assert!( + dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, + "dedup cache must remain hard-bounded after full-cache churn" + ); +} + +#[test] +fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let key = 0xC0DE_CAFE_u64; + let start = Instant::now(); + + assert!( + should_emit_full_desync(key, false, start), + "first event for key must emit full forensic record" + ); + + // Deterministic pseudo-random time deltas around dedup window edge. + let mut s: u64 = 0x1234_5678_9ABC_DEF0; + for _ in 0..2048 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1); + let now = start + TokioDuration::from_millis(delta_ms); + let emitted = should_emit_full_desync(key, false, now); + + if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 { + assert!( + !emitted, + "events inside dedup window must remain suppressed" + ); + } else { + // Once window elapsed for this key, at least one sample should re-emit and refresh. + if emitted { + return; + } + } + } + + panic!("expected at least one post-window sample to re-emit forensic record"); +} + fn make_forensics_state() -> RelayForensicsState { RelayForensicsState { trace_id: 1, @@ -229,18 +420,108 @@ fn make_forensics_state() -> RelayForensicsState { } } -fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader { +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ let key = [0u8; 32]; let iv = 0u128; CryptoReader::new(reader, AesCtr::new(&key, iv)) } -fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter { +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ let key = [0u8; 32]; let iv = 0u128; CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +async fn make_me_pool_for_abort_test(stats: Arc) -> Arc { + let general = GeneralConfig::default(); + + MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + Arc::new(SecureRandom::new()), + stats, + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ) +} + fn encrypt_for_reader(plaintext: &[u8]) -> Vec { let key = [0u8; 32]; let iv = 0u128; @@ -779,3 +1060,259 @@ async fn process_me_writer_response_data_updates_byte_accounting() { "ME->C byte accounting must increase by emitted payload size" ); } + +#[tokio::test] +async fn middle_relay_abort_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50001".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xdecafbad, + )); + + let started = tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await; + assert!(started.is_ok(), "middle relay must increment route gauge before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted middle relay task must return join error"); + + tokio::time::sleep(TokioDuration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay task is aborted mid-flight" + ); + + drop(client_side); +} + +#[tokio::test] +async fn middle_relay_cutover_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50003".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xfeed_beef, + )); + + tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("middle relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) + .await + .expect("middle relay must terminate after cutover") + .expect("middle relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate middle relay session" + ); + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "client-visible cutover error must stay generic and avoid route-internal metadata" + ); + + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay exits on cutover" + ); + + drop(client_side); +} + +#[tokio::test] +async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { + let session_count = 6usize; + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let mut relay_tasks = Vec::with_capacity(session_count); + let mut client_sides = Vec::with_capacity(session_count); + + for idx in 0..session_count { + let (server_side, client_side) = duplex(64 * 1024); + client_sides.push(client_side); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: format!("cutover-storm-middle-user-{idx}"), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + 52000 + idx as u16, + ), + is_tls: false, + }; + + relay_tasks.push(tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + "127.0.0.1:443".parse().unwrap(), + rng.clone(), + route_runtime.subscribe(), + route_snapshot, + 0xB000_0000 + idx as u64, + ))); + } + + tokio::time::timeout(TokioDuration::from_secs(4), async { + loop { + if stats.get_current_connections_me() == session_count as u64 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("all middle sessions must become active before cutover storm"); + + let route_runtime_flipper = route_runtime.clone(); + let flipper = tokio::spawn(async move { + for step in 0..64u32 { + let mode = if (step & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = route_runtime_flipper.set_mode(mode); + tokio::time::sleep(TokioDuration::from_millis(15)).await; + } + }); + + for relay_task in relay_tasks { + let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task) + .await + .expect("middle relay task must finish under cutover storm") + .expect("middle relay task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "storm-cutover termination must remain generic for all middle sessions" + ); + } + + flipper.abort(); + let _ = flipper.await; + + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after cutover storm" + ); + + drop(client_sides); +} diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs index 306c536..114babe 100644 --- a/src/proxy/route_mode.rs +++ b/src/proxy/route_mode.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::watch; -pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover"; +pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated"; #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[repr(u8)] @@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode { } impl RelayRouteMode { - pub(crate) fn as_u8(self) -> u8 { - self as u8 - } - - pub(crate) fn from_u8(value: u8) -> Self { - match value { - 1 => Self::Middle, - _ => Self::Direct, - } - } - pub(crate) fn as_str(self) -> &'static str { match self { Self::Direct => "direct", @@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState { #[derive(Clone)] pub(crate) struct RouteRuntimeController { - mode: Arc, - generation: Arc, direct_since_epoch_secs: Arc, tx: watch::Sender, } @@ -60,18 +47,13 @@ impl RouteRuntimeController { 0 }; Self { - mode: Arc::new(AtomicU8::new(initial_mode.as_u8())), - generation: Arc::new(AtomicU64::new(0)), direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)), tx, } } pub(crate) fn snapshot(&self) -> RouteCutoverState { - RouteCutoverState { - mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)), - generation: self.generation.load(Ordering::Relaxed), - } + *self.tx.borrow() } pub(crate) fn subscribe(&self) -> watch::Receiver { @@ -84,20 +66,29 @@ impl RouteRuntimeController { } pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option { - let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed); - if previous == mode.as_u8() { + let mut next = None; + let changed = self.tx.send_if_modified(|state| { + if state.mode == mode { + return false; + } + state.mode = mode; + state.generation = state.generation.saturating_add(1); + next = Some(*state); + true + }); + + if !changed { return None; } + if matches!(mode, RelayRouteMode::Direct) { self.direct_since_epoch_secs .store(now_epoch_secs(), Ordering::Relaxed); } else { self.direct_since_epoch_secs.store(0, Ordering::Relaxed); } - let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; - let next = RouteCutoverState { mode, generation }; - self.tx.send_replace(next); - Some(next) + + next } } @@ -110,10 +101,10 @@ fn now_epoch_secs() -> u64 { pub(crate) fn is_session_affected_by_cutover( current: RouteCutoverState, - _session_mode: RelayRouteMode, + session_mode: RelayRouteMode, session_generation: u64, ) -> bool { - current.generation > session_generation + current.generation > session_generation && current.mode != session_mode } pub(crate) fn affected_cutover_state( @@ -140,3 +131,7 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio let ms = 1000 + (value % 1000); Duration::from_millis(ms) } + +#[cfg(test)] +#[path = "route_mode_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/route_mode_security_tests.rs b/src/proxy/route_mode_security_tests.rs new file mode 100644 index 0000000..e86d574 --- /dev/null +++ b/src/proxy/route_mode_security_tests.rs @@ -0,0 +1,340 @@ +use super::*; +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +#[test] +fn cutover_stagger_delay_is_deterministic_for_same_inputs() { + let d1 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42); + let d2 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42); + assert_eq!( + d1, d2, + "stagger delay must be deterministic for identical session/generation inputs" + ); +} + +#[test] +fn cutover_stagger_delay_stays_within_budget_bounds() { + // Black-hat model: censors trigger many cutovers and correlate disconnect timing. + // Keep delay inside a narrow coarse window to avoid long-tail spikes. + for generation in [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX] { + for session_id in [ + 0u64, + 1, + 2, + 0xdead_beef, + 0xfeed_face_cafe_babe, + u64::MAX, + ] { + let delay = cutover_stagger_delay(session_id, generation); + assert!( + (1000..=1999).contains(&delay.as_millis()), + "stagger delay must remain in fixed 1000..=1999ms budget" + ); + } + } +} + +#[test] +fn cutover_stagger_delay_changes_with_generation_for_same_session() { + let session_id = 0x0123_4567_89ab_cdef; + let first = cutover_stagger_delay(session_id, 100); + let second = cutover_stagger_delay(session_id, 101); + assert_ne!( + first, second, + "adjacent cutover generations should decorrelate disconnect delays" + ); +} + +#[test] +fn route_runtime_set_mode_is_idempotent_for_same_mode() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let first = runtime.snapshot(); + let changed = runtime.set_mode(RelayRouteMode::Direct); + let second = runtime.snapshot(); + + assert!( + changed.is_none(), + "setting already-active mode must not produce a cutover event" + ); + assert_eq!( + first.generation, second.generation, + "idempotent mode set must not bump generation" + ); +} + +#[test] +fn affected_cutover_state_triggers_only_for_newer_generation() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + let initial = runtime.snapshot(); + + assert!( + affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation).is_none(), + "current generation must not be considered a cutover for existing session" + ); + + let next = runtime + .set_mode(RelayRouteMode::Middle) + .expect("mode change must produce cutover state"); + let seen = affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation) + .expect("newer generation must be observed as cutover"); + + assert_eq!(seen.generation, next.generation); + assert_eq!(seen.mode, RelayRouteMode::Middle); +} + +#[test] +fn integration_watch_and_snapshot_follow_same_transition_sequence() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + + let sequence = [ + RelayRouteMode::Middle, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + RelayRouteMode::Direct, + RelayRouteMode::Middle, + ]; + + let mut expected_generation = 0u64; + let mut expected_mode = RelayRouteMode::Direct; + + for target in sequence { + let changed = runtime.set_mode(target); + if target == expected_mode { + assert!(changed.is_none(), "idempotent transition must return none"); + } else { + expected_mode = target; + expected_generation = expected_generation.saturating_add(1); + let emitted = changed.expect("real transition must emit cutover state"); + assert_eq!(emitted.mode, expected_mode); + assert_eq!(emitted.generation, expected_generation); + } + + let snap = runtime.snapshot(); + let watched = *rx.borrow(); + assert_eq!(snap, watched, "snapshot and watch state must stay aligned"); + assert_eq!(snap.mode, expected_mode); + assert_eq!(snap.generation, expected_generation); + } +} + +#[test] +fn session_is_not_affected_when_mode_matches_even_if_generation_advanced() { + let session_mode = RelayRouteMode::Direct; + let current = RouteCutoverState { + mode: RelayRouteMode::Direct, + generation: 2, + }; + let session_generation = 0; + + assert!( + !is_session_affected_by_cutover(current, session_mode, session_generation), + "session on matching final route mode should not be force-cut over on intermediate generation bumps" + ); +} + +#[test] +fn cutover_predicate_rejects_equal_generation_even_if_mode_differs() { + let current = RouteCutoverState { + mode: RelayRouteMode::Middle, + generation: 77, + }; + assert!( + !is_session_affected_by_cutover(current, RelayRouteMode::Direct, 77), + "equal generation must never trigger cutover regardless of mode mismatch" + ); +} + +#[test] +fn adversarial_route_oscillation_only_cuts_over_sessions_with_different_final_mode() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + let session_generation = runtime.snapshot().generation; + + runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must transition"); + runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must transition"); + + assert!( + affected_cutover_state(&rx, RelayRouteMode::Direct, session_generation).is_none(), + "direct session should survive when final mode returns to direct" + ); + assert!( + affected_cutover_state(&rx, RelayRouteMode::Middle, session_generation).is_some(), + "middle session should be cut over when final mode is direct" + ); +} + +#[test] +fn light_fuzz_cutover_predicate_matches_reference_oracle() { + let mut rng = StdRng::seed_from_u64(0xC0DEC0DE5EED); + for _ in 0..20_000 { + let current = RouteCutoverState { + mode: if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }, + generation: rng.random_range(0u64..1_000_000), + }; + let session_mode = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let session_generation = rng.random_range(0u64..1_000_000); + + let expected = current.generation > session_generation && current.mode != session_mode; + let actual = is_session_affected_by_cutover(current, session_mode, session_generation); + assert_eq!( + actual, expected, + "cutover predicate must match mode-aware generation oracle" + ); + } +} + +#[test] +fn light_fuzz_set_mode_generation_tracks_only_real_transitions() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let mut rng = StdRng::seed_from_u64(0x0DDC0FFE); + + let mut expected_mode = RelayRouteMode::Direct; + let mut expected_generation = 0u64; + + for _ in 0..10_000 { + let candidate = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let changed = runtime.set_mode(candidate); + + if candidate == expected_mode { + assert!(changed.is_none(), "idempotent set_mode must not emit cutover state"); + } else { + expected_mode = candidate; + expected_generation = expected_generation.saturating_add(1); + let next = changed.expect("mode transition must emit cutover state"); + assert_eq!(next.mode, expected_mode); + assert_eq!(next.generation, expected_generation); + } + } + + let final_state = runtime.snapshot(); + assert_eq!(final_state.mode, expected_mode); + assert_eq!(final_state.generation, expected_generation); +} + +#[test] +fn stress_snapshot_and_watch_state_remain_consistent_under_concurrent_switch_storm() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + std::thread::scope(|scope| { + let mut writers = Vec::new(); + for worker in 0..4usize { + let runtime = Arc::clone(&runtime); + writers.push(scope.spawn(move || { + for step in 0..20_000usize { + let mode = if (worker + step) % 2 == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = runtime.set_mode(mode); + } + })); + } + + for writer in writers { + writer + .join() + .expect("route mode writer thread must not panic"); + } + + let rx = runtime.subscribe(); + for _ in 0..128 { + assert_eq!( + runtime.snapshot(), + *rx.borrow(), + "snapshot and watch state must converge after concurrent set_mode churn" + ); + std::thread::yield_now(); + } + }); +} + +#[test] +fn stress_concurrent_transition_count_matches_final_generation() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let successful_transitions = Arc::new(AtomicU64::new(0)); + + std::thread::scope(|scope| { + let mut workers = Vec::new(); + for worker in 0..6usize { + let runtime = Arc::clone(&runtime); + let successful_transitions = Arc::clone(&successful_transitions); + workers.push(scope.spawn(move || { + let mut state = (worker as u64 + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15); + for _ in 0..25_000usize { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + let mode = if (state & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + if runtime.set_mode(mode).is_some() { + successful_transitions.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("route mode transition worker must not panic"); + } + }); + + let final_state = runtime.snapshot(); + assert_eq!( + final_state.generation, + successful_transitions.load(Ordering::Relaxed), + "final generation must equal number of accepted mode transitions" + ); + assert_eq!( + final_state, + *runtime.subscribe().borrow(), + "watch and snapshot state must match after concurrent transition accounting" + ); +} + +#[test] +fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() { + // Deterministic xorshift fuzzing keeps this test stable across runs. + let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + + for _ in 0..20_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let session_id = s; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let generation = s; + + let delay = cutover_stagger_delay(session_id, generation); + assert!( + (1000..=1999).contains(&delay.as_millis()), + "fuzzed inputs must always map into fixed stagger window" + ); + } +} diff --git a/src/stats/connection_lease_security_tests.rs b/src/stats/connection_lease_security_tests.rs new file mode 100644 index 0000000..69ae89a --- /dev/null +++ b/src/stats/connection_lease_security_tests.rs @@ -0,0 +1,265 @@ +use super::*; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Barrier; + +#[test] +fn direct_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_direct(), 0); + + { + let _lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + } + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn middle_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_me(), 0); + + { + let _lease = stats.acquire_me_connection_lease(); + assert_eq!(stats.get_current_connections_me(), 1); + } + + assert_eq!(stats.get_current_connections_me(), 0); +} + +#[test] +fn connection_lease_disarm_prevents_double_release() { + let stats = Arc::new(Stats::new()); + + let mut lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + + stats.decrement_current_connections_direct(); + assert_eq!(stats.get_current_connections_direct(), 0); + + lease.disarm(); + drop(lease); + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn direct_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_direct_connection_lease(); + panic!("intentional panic to verify lease drop path"); + })); + + assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert_eq!( + stats.get_current_connections_direct(), + 0, + "panic unwind must release direct route gauge" + ); +} + +#[test] +fn middle_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_me_connection_lease(); + panic!("intentional panic to verify middle lease drop path"); + })); + + assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert_eq!( + stats.get_current_connections_me(), + 0, + "panic unwind must release middle route gauge" + ); +} + +#[tokio::test] +async fn concurrent_mixed_route_lease_churn_balances_to_zero() { + const TASKS: usize = 48; + const ITERATIONS_PER_TASK: usize = 256; + + let stats = Arc::new(Stats::new()); + let barrier = Arc::new(Barrier::new(TASKS)); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + let barrier_for_task = barrier.clone(); + workers.push(tokio::spawn(async move { + barrier_for_task.wait().await; + for iter in 0..ITERATIONS_PER_TASK { + if (task_idx + iter) % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::task::yield_now().await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::task::yield_now().await; + } + } + })); + } + + for worker in workers { + worker + .await + .expect("lease churn worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after concurrent lease churn" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after concurrent lease churn" + ); +} + +#[tokio::test] +async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() { + const TASKS: usize = 64; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + workers.push(tokio::spawn(async move { + if task_idx % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } + })); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + let total = stats.get_current_connections_direct() + stats.get_current_connections_me(); + if total == TASKS as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all storm tasks must acquire route leases before abort"); + + for worker in &workers { + worker.abort(); + } + for worker in workers { + let joined = worker.await; + assert!(joined.is_err(), "aborted worker must return join error"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all route gauges must drain to zero after abort storm"); +} + +#[test] +fn saturating_route_decrements_do_not_underflow_under_race() { + const THREADS: usize = 16; + const DECREMENTS_PER_THREAD: usize = 4096; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(THREADS); + + for _ in 0..THREADS { + let stats_for_thread = stats.clone(); + workers.push(std::thread::spawn(move || { + for _ in 0..DECREMENTS_PER_THREAD { + stats_for_thread.decrement_current_connections_direct(); + stats_for_thread.decrement_current_connections_me(); + } + })); + } + + for worker in workers { + worker + .join() + .expect("decrement race worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route decrement races must never underflow" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route decrement races must never underflow" + ); +} + +#[tokio::test] +async fn direct_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_direct(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "aborted task must release direct route gauge" + ); +} + +#[tokio::test] +async fn middle_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_me(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "aborted task must release middle route gauge" + ); +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 603552d..3ad361f 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -6,6 +6,7 @@ pub mod beobachten; pub mod telemetry; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; +use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use dashmap::DashMap; use parking_lot::Mutex; @@ -19,6 +20,46 @@ use tracing::debug; use crate::config::{MeTelemetryLevel, MeWriterPickMode}; use self::telemetry::TelemetryPolicy; +#[derive(Clone, Copy)] +enum RouteConnectionGauge { + Direct, + Middle, +} + +#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"] +pub struct RouteConnectionLease { + stats: Arc, + gauge: RouteConnectionGauge, + active: bool, +} + +impl RouteConnectionLease { + fn new(stats: Arc, gauge: RouteConnectionGauge) -> Self { + Self { + stats, + gauge, + active: true, + } + } + + #[cfg(test)] + fn disarm(&mut self) { + self.active = false; + } +} + +impl Drop for RouteConnectionLease { + fn drop(&mut self) { + if !self.active { + return; + } + match self.gauge { + RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(), + RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(), + } + } +} + // ============= Stats ============= #[derive(Default)] @@ -285,6 +326,16 @@ impl Stats { pub fn decrement_current_connections_me(&self) { Self::decrement_atomic_saturating(&self.current_connections_me); } + + pub fn acquire_direct_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_direct(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct) + } + + pub fn acquire_me_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_me(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle) + } pub fn increment_handshake_timeouts(&self) { if self.telemetry_core_enabled() { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); @@ -1772,3 +1823,7 @@ mod tests { assert_eq!(checker.stats().total_entries, 500); } } + +#[cfg(test)] +#[path = "connection_lease_security_tests.rs"] +mod connection_lease_security_tests; diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index 3278f63..7e329c5 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -103,7 +103,7 @@ pub fn build_emulated_server_hello( cached: &CachedTlsData, use_full_cert_payload: bool, rng: &SecureRandom, - alpn: Option>, + _alpn: Option>, new_session_tickets: u8, ) -> Vec { // --- ServerHello --- @@ -117,15 +117,6 @@ pub fn build_emulated_server_hello( extensions.extend_from_slice(&0x002bu16.to_be_bytes()); extensions.extend_from_slice(&(2u16).to_be_bytes()); extensions.extend_from_slice(&0x0304u16.to_be_bytes()); - if let Some(alpn_proto) = &alpn { - extensions.extend_from_slice(&0x0010u16.to_be_bytes()); - let list_len: u16 = 1 + alpn_proto.len() as u16; - let ext_len: u16 = 2 + list_len; - extensions.extend_from_slice(&ext_len.to_be_bytes()); - extensions.extend_from_slice(&list_len.to_be_bytes()); - extensions.push(alpn_proto.len() as u8); - extensions.extend_from_slice(alpn_proto); - } let extensions_len = extensions.len() as u16; let body_len = 2 + // version diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index a2e107d..a6b1031 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -25,6 +25,9 @@ const HEALTH_RECONNECT_BUDGET_PER_CORE: usize = 2; const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1; const HEALTH_RECONNECT_BUDGET_MIN: usize = 4; const HEALTH_RECONNECT_BUDGET_MAX: usize = 128; +const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16; +const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16; +const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256; #[derive(Debug, Clone)] struct DcFloorPlanEntry { @@ -111,106 +114,75 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c } } -async fn reap_draining_writers( +pub(super) async fn reap_draining_writers( pool: &Arc, warn_next_allowed: &mut HashMap, ) { - if pool.draining_active_runtime() == 0 { - return; - } - let now_epoch_secs = MePool::now_epoch_secs(); let now = Instant::now(); let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed); let drain_threshold = pool .me_pool_drain_threshold .load(std::sync::atomic::Ordering::Relaxed); - let mut draining_writers = { - let writers = pool.writers.read().await; - let mut draining_writers = Vec::::new(); - for writer in writers.iter() { - if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { - continue; - } - draining_writers.push(DrainingWriterSnapshot { - id: writer.id, - writer_dc: writer.writer_dc, - addr: writer.addr, - generation: writer.generation, - created_at: writer.created_at, - draining_started_at_epoch_secs: writer - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed), - drain_deadline_epoch_secs: writer - .drain_deadline_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed), - allow_drain_fallback: writer - .allow_drain_fallback - .load(std::sync::atomic::Ordering::Relaxed), - }); + let activity = pool.registry.writer_activity_snapshot().await; + let mut draining_writers = Vec::::new(); + let mut empty_writer_ids = Vec::::new(); + let mut force_close_writer_ids = Vec::::new(); + let writers = pool.writers.read().await; + for writer in writers.iter() { + if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { + continue; } - draining_writers - }; - - if draining_writers.is_empty() { - return; - } - - let draining_ids: Vec = draining_writers.iter().map(|writer| writer.id).collect(); - let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await; - let mut non_empty_draining_writers = - Vec::::with_capacity(draining_writers.len()); - for writer in draining_writers.drain(..) { - if non_empty_writer_ids.contains(&writer.id) { - non_empty_draining_writers.push(writer); - } else { - pool.remove_writer_and_close_clients(writer.id).await; + if activity + .bound_clients_by_writer + .get(&writer.id) + .copied() + .unwrap_or(0) + == 0 + { + empty_writer_ids.push(writer.id); + continue; } + draining_writers.push(DrainingWriterSnapshot { + id: writer.id, + writer_dc: writer.writer_dc, + addr: writer.addr, + generation: writer.generation, + created_at: writer.created_at, + draining_started_at_epoch_secs: writer + .draining_started_at_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + drain_deadline_epoch_secs: writer + .drain_deadline_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback: writer + .allow_drain_fallback + .load(std::sync::atomic::Ordering::Relaxed), + }); } - draining_writers = non_empty_draining_writers; - if draining_writers.is_empty() { - return; - } + drop(writers); let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { draining_writers.len().saturating_sub(drain_threshold as usize) } else { 0 }; - let has_deadline_expired = draining_writers.iter().any(|writer| { - writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs - }); - let can_drop_with_replacement = if overflow > 0 || has_deadline_expired { - pool.has_non_draining_writer_per_desired_dc_group().await - } else { - false - }; if overflow > 0 { - if can_drop_with_replacement { - draining_writers.sort_by(|left, right| { - left.draining_started_at_epoch_secs - .cmp(&right.draining_started_at_epoch_secs) - .then_with(|| left.created_at.cmp(&right.created_at)) - .then_with(|| left.id.cmp(&right.id)) - }); - warn!( - draining_writers = draining_writers.len(), - me_pool_drain_threshold = drain_threshold, - removing_writers = overflow, - "ME draining writer threshold exceeded, force-closing oldest draining writers" - ); - for writer in draining_writers.drain(..overflow) { - pool.stats.increment_pool_force_close_total(); - pool.remove_writer_and_close_clients(writer.id).await; - } - } else { - warn!( - draining_writers = draining_writers.len(), - me_pool_drain_threshold = drain_threshold, - overflow, - "ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers" - ); + draining_writers.sort_by(|left, right| { + left.draining_started_at_epoch_secs + .cmp(&right.draining_started_at_epoch_secs) + .then_with(|| left.created_at.cmp(&right.created_at)) + .then_with(|| left.id.cmp(&right.id)) + }); + warn!( + draining_writers = draining_writers.len(), + me_pool_drain_threshold = drain_threshold, + removing_writers = overflow, + "ME draining writer threshold exceeded, force-closing oldest draining writers" + ); + for writer in draining_writers.drain(..overflow) { + force_close_writer_ids.push(writer.id); } } @@ -238,25 +210,71 @@ async fn reap_draining_writers( } if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs { - if can_drop_with_replacement { - warn!(writer_id = writer.id, "Drain timeout, force-closing"); - pool.stats.increment_pool_force_close_total(); - pool.remove_writer_and_close_clients(writer.id).await; - } else if should_emit_writer_warn( - warn_next_allowed, - writer.id, - now, - pool.warn_rate_limit_duration(), - ) { - warn!( - writer_id = writer.id, - writer_dc = writer.writer_dc, - endpoint = %writer.addr, - "Drain timeout reached, but replacement coverage is incomplete; keeping draining writer" - ); - } + warn!(writer_id = writer.id, "Drain timeout, force-closing"); + force_close_writer_ids.push(writer.id); } } + + let close_budget = health_drain_close_budget(); + let requested_force_close = force_close_writer_ids.len(); + let requested_empty_close = empty_writer_ids.len(); + let requested_close_total = requested_force_close.saturating_add(requested_empty_close); + let mut closed_writer_ids = HashSet::::new(); + let mut closed_total = 0usize; + for writer_id in force_close_writer_ids { + if closed_total >= close_budget { + break; + } + if !closed_writer_ids.insert(writer_id) { + continue; + } + pool.stats.increment_pool_force_close_total(); + pool.remove_writer_and_close_clients(writer_id).await; + closed_total = closed_total.saturating_add(1); + } + for writer_id in empty_writer_ids { + if closed_total >= close_budget { + break; + } + if !closed_writer_ids.insert(writer_id) { + continue; + } + if !pool.remove_writer_if_empty(writer_id).await { + continue; + } + closed_total = closed_total.saturating_add(1); + } + + let pending_close_total = requested_close_total.saturating_sub(closed_total); + if pending_close_total > 0 { + warn!( + close_budget, + closed_total, + pending_close_total, + "ME draining close backlog deferred to next health cycle" + ); + } + + // Keep warn cooldown state for draining writers still present in the pool; + // drop state only once a writer is actually removed. + let active_draining_writer_ids = { + let writers = pool.writers.read().await; + writers + .iter() + .filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed)) + .map(|writer| writer.id) + .collect::>() + }; + warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id)); +} + +pub(super) fn health_drain_close_budget() -> usize { + let cpu_cores = std::thread::available_parallelism() + .map(std::num::NonZeroUsize::get) + .unwrap_or(1); + cpu_cores + .saturating_mul(HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE) + .clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX) } #[derive(Debug, Clone)] @@ -1521,7 +1539,6 @@ mod tests { pool.writers.write().await.push(writer); pool.registry.register_writer(writer_id, tx).await; pool.conn_count.fetch_add(1, Ordering::Relaxed); - pool.increment_draining_active_runtime(); assert!( pool.registry .bind_writer( @@ -1570,7 +1587,6 @@ mod tests { async fn reap_draining_writers_force_closes_oldest_over_threshold() { let pool = make_pool(2).await; insert_live_writer(&pool, 1, 2).await; - assert!(pool.has_non_draining_writer_per_desired_dc_group().await); let now_epoch_secs = MePool::now_epoch_secs(); let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await; @@ -1588,7 +1604,7 @@ mod tests { } #[tokio::test] - async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() { + async fn reap_draining_writers_force_closes_overflow_without_replacement() { let pool = make_pool(2).await; let now_epoch_secs = MePool::now_epoch_secs(); let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; @@ -1600,8 +1616,8 @@ mod tests { let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); writer_ids.sort_unstable(); - assert_eq!(writer_ids, vec![10, 20, 30]); - assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10); + assert_eq!(writer_ids, vec![20, 30]); + assert!(pool.registry.get_writer(conn_a).await.is_none()); assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30); } diff --git a/src/transport/middle_proxy/health_adversarial_tests.rs b/src/transport/middle_proxy/health_adversarial_tests.rs new file mode 100644 index 0000000..cd06fdf --- /dev/null +++ b/src/transport/middle_proxy/health_adversarial_tests.rs @@ -0,0 +1,615 @@ +use std::collections::HashMap; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use super::health::{health_drain_close_budget, reap_draining_writers}; +use super::pool::{MePool, MeWriter, WriterContour}; +use super::registry::ConnMeta; +use super::me_health_monitor; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool( + me_pool_drain_threshold: u64, + me_health_interval_ms_unhealthy: u64, + me_health_interval_ms_healthy: u64, +) -> (Arc, Arc) { + let general = GeneralConfig { + me_pool_drain_threshold, + me_health_interval_ms_unhealthy, + me_health_interval_ms_healthy, + ..GeneralConfig::default() + }; + + let rng = Arc::new(SecureRandom::new()); + let pool = MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + rng.clone(), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ); + + (pool, rng) +} + +async fn insert_draining_writer( + pool: &Arc, + writer_id: u64, + drain_started_at_epoch_secs: u64, + bound_clients: usize, + drain_deadline_epoch_secs: u64, +) { + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6000 + writer_id as u16), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc: 2, + generation: 1, + contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())), + created_at: Instant::now() - Duration::from_secs(writer_id), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(true)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + + for idx in 0..bound_clients { + let (conn_id, _rx) = pool.registry.register().await; + assert!( + pool.registry + .bind_writer( + conn_id, + writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::LOCALHOST), + 8000 + idx as u16, + ), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await + ); + } +} + +async fn writer_count(pool: &Arc) -> usize { + pool.writers.read().await.len() +} + +async fn sorted_writer_ids(pool: &Arc) -> Vec { + let mut ids = pool + .writers + .read() + .await + .iter() + .map(|writer| writer.id) + .collect::>(); + ids.sort_unstable(); + ids +} + +fn lcg_next(state: &mut u64) -> u64 { + *state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + *state +} + +async fn draining_writer_ids(pool: &Arc) -> HashSet { + pool.writers + .read() + .await + .iter() + .filter(|writer| writer.draining.load(Ordering::Relaxed)) + .map(|writer| writer.id) + .collect::>() +} + +async fn set_writer_runtime_state( + pool: &Arc, + writer_id: u64, + draining: bool, + drain_started_at_epoch_secs: u64, + drain_deadline_epoch_secs: u64, +) { + let writers = pool.writers.read().await; + if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) { + writer.draining.store(draining, Ordering::Relaxed); + writer + .draining_started_at_epoch_secs + .store(drain_started_at_epoch_secs, Ordering::Relaxed); + writer + .drain_deadline_epoch_secs + .store(drain_deadline_epoch_secs, Ordering::Relaxed); + } +} + +#[tokio::test] +async fn reap_draining_writers_clears_warn_state_when_pool_empty() { + let (pool, _rng) = make_pool(128, 1, 1).await; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert(11, Instant::now() + Duration::from_secs(5)); + warn_next_allowed.insert(22, Instant::now() + Duration::from_secs(5)); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(warn_next_allowed.is_empty()); +} + +#[tokio::test] +async fn reap_draining_writers_respects_threshold_across_multiple_overflow_cycles() { + let threshold = 3u64; + let (pool, _rng) = make_pool(threshold, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=60u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(600).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _ in 0..64 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if writer_count(&pool).await <= threshold as usize { + break; + } + } + + assert_eq!(writer_count(&pool).await, threshold as usize); + assert_eq!(sorted_writer_ids(&pool).await, vec![58, 59, 60]); +} + +#[tokio::test] +async fn reap_draining_writers_handles_large_empty_writer_population() { + let (pool, _rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let total = health_drain_close_budget().saturating_mul(3).saturating_add(27); + + for writer_id in 1..=total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 0, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _ in 0..24 { + if writer_count(&pool).await == 0 { + break; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + } + + assert_eq!(writer_count(&pool).await, 0); +} + +#[tokio::test] +async fn reap_draining_writers_processes_mass_deadline_expiry_without_unbounded_growth() { + let (pool, _rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let total = health_drain_close_budget().saturating_mul(4).saturating_add(31); + + for writer_id in 1..=total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(180), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _ in 0..40 { + if writer_count(&pool).await == 0 { + break; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + } + + assert_eq!(writer_count(&pool).await, 0); +} + +#[tokio::test] +async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_churn() { + let (pool, _rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let mut warn_next_allowed = HashMap::new(); + + for wave in 0..40u64 { + for offset in 0..8u64 { + insert_draining_writer( + &pool, + wave * 100 + offset, + now_epoch_secs.saturating_sub(400 + offset), + 1, + 0, + ) + .await; + } + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.len() <= writer_count(&pool).await); + + let ids = sorted_writer_ids(&pool).await; + for writer_id in ids.into_iter().take(3) { + let _ = pool.remove_writer_and_close_clients(writer_id).await; + } + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.len() <= writer_count(&pool).await); + } +} + +#[tokio::test] +async fn reap_draining_writers_budgeted_cleanup_never_increases_pool_size() { + let (pool, _rng) = make_pool(5, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=200u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(240).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + let mut previous = writer_count(&pool).await; + for _ in 0..32 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + let current = writer_count(&pool).await; + assert!(current <= previous); + previous = current; + } +} + +#[tokio::test] +async fn me_health_monitor_converges_to_threshold_under_live_injection_churn() { + let threshold = 7u64; + let (pool, rng) = make_pool(threshold, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=40u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + + for wave in 0..8u64 { + for offset in 0..10u64 { + insert_draining_writer( + &pool, + 1000 + wave * 100 + offset, + now_epoch_secs.saturating_sub(120).saturating_add(offset), + 1, + 0, + ) + .await; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + + tokio::time::sleep(Duration::from_millis(120)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(writer_count(&pool).await <= threshold as usize); +} + +#[tokio::test] +async fn me_health_monitor_drains_deadline_storm_with_budgeted_progress() { + let (pool, rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=220u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + tokio::time::sleep(Duration::from_millis(120)).await; + monitor.abort(); + let _ = monitor.await; + + assert_eq!(writer_count(&pool).await, 0); +} + +#[tokio::test] +async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() { + let threshold = 12u64; + let (pool, rng) = make_pool(threshold, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=180u64 { + let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 }; + let deadline = if writer_id % 2 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(250).saturating_add(writer_id), + bound_clients, + deadline, + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + tokio::time::sleep(Duration::from_millis(140)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(writer_count(&pool).await <= threshold as usize); +} + +#[tokio::test] +async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() { + let threshold = 9u64; + let (pool, _rng) = make_pool(threshold, 1, 1).await; + let mut warn_next_allowed = HashMap::new(); + let mut seed = 0x9E37_79B9_7F4A_7C15u64; + let mut next_writer_id = 20_000u64; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=72u64 { + let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 }; + let deadline = if writer_id % 5 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(500).saturating_add(writer_id), + bound_clients, + deadline, + ) + .await; + } + + for _round in 0..90 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let draining_ids = draining_writer_ids(&pool).await; + assert!( + warn_next_allowed.keys().all(|id| draining_ids.contains(id)), + "warn-state keys must always be a subset of live draining writers" + ); + + let writer_ids = sorted_writer_ids(&pool).await; + if writer_ids.is_empty() { + continue; + } + + let remove_n = (lcg_next(&mut seed) % 3) as usize; + for writer_id in writer_ids.iter().copied().take(remove_n) { + let _ = pool.remove_writer_and_close_clients(writer_id).await; + } + + let survivors = sorted_writer_ids(&pool).await; + if !survivors.is_empty() { + let idx = (lcg_next(&mut seed) as usize) % survivors.len(); + let target = survivors[idx]; + set_writer_runtime_state(&pool, target, false, 0, 0).await; + } + + let survivors = sorted_writer_ids(&pool).await; + if survivors.len() > 1 { + let idx = (lcg_next(&mut seed) as usize) % survivors.len(); + let target = survivors[idx]; + let expired_deadline = if lcg_next(&mut seed) & 1 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + set_writer_runtime_state( + &pool, + target, + true, + now_epoch_secs.saturating_sub(120), + expired_deadline, + ) + .await; + } + + let inject_n = (lcg_next(&mut seed) % 4) as usize; + for _ in 0..inject_n { + let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 }; + let deadline = if lcg_next(&mut seed) & 1 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + next_writer_id, + now_epoch_secs.saturating_sub(240), + bound_clients, + deadline, + ) + .await; + next_writer_id = next_writer_id.saturating_add(1); + } + } + + for _ in 0..64 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if writer_count(&pool).await <= threshold as usize { + break; + } + } + + assert!(writer_count(&pool).await <= threshold as usize); + let draining_ids = draining_writer_ids(&pool).await; + assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id))); +} + +#[tokio::test] +async fn reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() { + let (pool, _rng) = make_pool(64, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=24u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(240), + 1, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _round in 0..48u64 { + for writer_id in 1..=24u64 { + let draining = (writer_id + _round) % 3 != 0; + set_writer_runtime_state( + &pool, + writer_id, + draining, + now_epoch_secs.saturating_sub(120), + 0, + ) + .await; + } + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let draining_ids = draining_writer_ids(&pool).await; + assert!( + warn_next_allowed.keys().all(|id| draining_ids.contains(id)), + "warn-state map must not retain entries for writers outside draining set" + ); + } +} + +#[test] +fn health_drain_close_budget_is_within_expected_bounds() { + let budget = health_drain_close_budget(); + assert!((16..=256).contains(&budget)); +} diff --git a/src/transport/middle_proxy/health_integration_tests.rs b/src/transport/middle_proxy/health_integration_tests.rs new file mode 100644 index 0000000..476b549 --- /dev/null +++ b/src/transport/middle_proxy/health_integration_tests.rs @@ -0,0 +1,241 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use super::health::health_drain_close_budget; +use super::pool::{MePool, MeWriter, WriterContour}; +use super::registry::ConnMeta; +use super::me_health_monitor; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool( + me_pool_drain_threshold: u64, + me_health_interval_ms_unhealthy: u64, + me_health_interval_ms_healthy: u64, +) -> (Arc, Arc) { + let general = GeneralConfig { + me_pool_drain_threshold, + me_health_interval_ms_unhealthy, + me_health_interval_ms_healthy, + ..GeneralConfig::default() + }; + let rng = Arc::new(SecureRandom::new()); + let pool = MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + rng.clone(), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ); + (pool, rng) +} + +async fn insert_draining_writer( + pool: &Arc, + writer_id: u64, + drain_started_at_epoch_secs: u64, + bound_clients: usize, + drain_deadline_epoch_secs: u64, +) { + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5500 + writer_id as u16), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc: 2, + generation: 1, + contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())), + created_at: Instant::now() - Duration::from_secs(writer_id), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(true)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + for idx in 0..bound_clients { + let (conn_id, _rx) = pool.registry.register().await; + assert!( + pool.registry + .bind_writer( + conn_id, + writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::LOCALHOST), + 7200 + idx as u16, + ), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await + ); + } +} + +async fn wait_for_pool_empty(pool: &Arc, timeout: Duration) { + let start = Instant::now(); + loop { + if pool.writers.read().await.is_empty() { + return; + } + assert!( + start.elapsed() < timeout, + "timed out waiting for pool.writers to become empty" + ); + tokio::time::sleep(Duration::from_millis(5)).await; + } +} + +#[tokio::test] +async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() { + let (pool, rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let writer_total = health_drain_close_budget().saturating_mul(2).saturating_add(9); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(pool.writers.read().await.is_empty()); +} + +#[tokio::test] +async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() { + let (pool, rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + for writer_id in 1..=24u64 { + insert_draining_writer(&pool, writer_id, now_epoch_secs.saturating_sub(60), 0, 0).await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(pool.writers.read().await.is_empty()); +} + +#[tokio::test] +async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() { + let threshold = 4u64; + let (pool, rng) = make_pool(threshold, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let writer_total = threshold as usize + health_drain_close_budget().saturating_add(11); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(pool.writers.read().await.is_empty()); +} diff --git a/src/transport/middle_proxy/health_regression_tests.rs b/src/transport/middle_proxy/health_regression_tests.rs new file mode 100644 index 0000000..6b6b12a --- /dev/null +++ b/src/transport/middle_proxy/health_regression_tests.rs @@ -0,0 +1,658 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use super::health::{health_drain_close_budget, reap_draining_writers}; +use super::pool::{MePool, MeWriter, WriterContour}; +use super::registry::ConnMeta; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool(me_pool_drain_threshold: u64) -> Arc { + let general = GeneralConfig { + me_pool_drain_threshold, + ..GeneralConfig::default() + }; + + MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + Arc::new(SecureRandom::new()), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ) +} + +async fn insert_draining_writer( + pool: &Arc, + writer_id: u64, + drain_started_at_epoch_secs: u64, + bound_clients: usize, + drain_deadline_epoch_secs: u64, +) -> Vec { + let mut conn_ids = Vec::with_capacity(bound_clients); + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4500 + writer_id as u16), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc: 2, + generation: 1, + contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())), + created_at: Instant::now() - Duration::from_secs(writer_id), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(true)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + for idx in 0..bound_clients { + let (conn_id, _rx) = pool.registry.register().await; + assert!( + pool.registry + .bind_writer( + conn_id, + writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::LOCALHOST), + 6200 + idx as u16, + ), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await + ); + conn_ids.push(conn_id); + } + conn_ids +} + +async fn current_writer_ids(pool: &Arc) -> Vec { + let mut writer_ids = pool + .writers + .read() + .await + .iter() + .map(|writer| writer.id) + .collect::>(); + writer_ids.sort_unstable(); + writer_ids +} + +async fn writer_exists(pool: &Arc, writer_id: u64) -> bool { + pool.writers + .read() + .await + .iter() + .any(|writer| writer.id == writer_id) +} + +async fn set_writer_draining(pool: &Arc, writer_id: u64, draining: bool) { + let writers = pool.writers.read().await; + if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) { + writer.draining.store(draining, Ordering::Relaxed); + } +} + +#[tokio::test] +async fn reap_draining_writers_drops_warn_state_for_removed_writer() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_ids = + insert_draining_writer(&pool, 7, now_epoch_secs.saturating_sub(180), 1, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.contains_key(&7)); + + let _ = pool.remove_writer_and_close_clients(7).await; + assert!(pool.registry.get_writer(conn_ids[0]).await.is_none()); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(!warn_next_allowed.contains_key(&7)); +} + +#[tokio::test] +async fn reap_draining_writers_removes_empty_draining_writers() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(40), 0, 0).await; + insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(30), 0, 0).await; + insert_draining_writer(&pool, 3, now_epoch_secs.saturating_sub(20), 1, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(current_writer_ids(&pool).await, vec![3]); +} + +#[tokio::test] +async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() { + let pool = make_pool(2).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 11, now_epoch_secs.saturating_sub(40), 1, 0).await; + insert_draining_writer(&pool, 22, now_epoch_secs.saturating_sub(30), 1, 0).await; + insert_draining_writer(&pool, 33, now_epoch_secs.saturating_sub(20), 1, 0).await; + insert_draining_writer(&pool, 44, now_epoch_secs.saturating_sub(10), 1, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(current_writer_ids(&pool).await, vec![33, 44]); +} + +#[tokio::test] +async fn reap_draining_writers_deadline_force_close_applies_under_threshold() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer( + &pool, + 50, + now_epoch_secs.saturating_sub(15), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(current_writer_ids(&pool).await.is_empty()); +} + +#[tokio::test] +async fn reap_draining_writers_limits_closes_per_health_tick() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_add(19); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(20), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(pool.writers.read().await.len(), writer_total - close_budget); +} + +#[tokio::test] +async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() { + let pool = make_pool(0).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_add(5); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(60), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let target_writer_id = writer_total as u64; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert( + target_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, target_writer_id).await); + assert!(warn_next_allowed.contains_key(&target_writer_id)); +} + +#[tokio::test] +async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() { + let pool = make_pool(1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_add(6); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + let target_writer_id = writer_total.saturating_sub(1) as u64; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert( + target_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, target_writer_id).await); + assert!(warn_next_allowed.contains_key(&target_writer_id)); +} + +#[tokio::test] +async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await; + + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300)); + + set_writer_draining(&pool, 71, false).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, 71).await); + assert!( + !warn_next_allowed.contains_key(&71), + "warn cooldown state must be dropped after writer leaves draining state" + ); +} + +#[tokio::test] +async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() { + let pool = make_pool(0).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_mul(2).saturating_add(1); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let tail_writer_id = writer_total as u64; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert( + tail_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(writer_exists(&pool, tail_writer_id).await); + assert!(warn_next_allowed.contains_key(&tail_writer_id)); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(writer_exists(&pool, tail_writer_id).await); + assert!(warn_next_allowed.contains_key(&tail_writer_id)); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(!writer_exists(&pool, tail_writer_id).await); + assert!( + !warn_next_allowed.contains_key(&tail_writer_id), + "warn cooldown state must clear once writer is actually removed" + ); +} + +#[tokio::test] +async fn reap_draining_writers_backlog_drains_across_ticks() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_mul(2).saturating_add(7); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(20), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let mut warn_next_allowed = HashMap::new(); + + for _ in 0..8 { + if pool.writers.read().await.is_empty() { + break; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + } + + assert!(pool.writers.read().await.is_empty()); +} + +#[tokio::test] +async fn reap_draining_writers_threshold_backlog_converges_to_threshold() { + let threshold = 5u64; + let pool = make_pool(threshold).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = threshold as usize + close_budget.saturating_add(12); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(200).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + let mut warn_next_allowed = HashMap::new(); + + for _ in 0..16 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if pool.writers.read().await.len() <= threshold as usize { + break; + } + } + + assert_eq!(pool.writers.read().await.len(), threshold as usize); +} + +#[tokio::test] +async fn reap_draining_writers_threshold_zero_preserves_non_expired_non_empty_writers() { + let pool = make_pool(0).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(40), 1, 0).await; + insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(30), 1, 0).await; + insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(20), 1, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(current_writer_ids(&pool).await, vec![10, 20, 30]); +} + +#[tokio::test] +async fn reap_draining_writers_prioritizes_force_close_before_empty_cleanup() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + for writer_id in 1..=close_budget as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(20), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let empty_writer_id = close_budget as u64 + 1; + insert_draining_writer(&pool, empty_writer_id, now_epoch_secs.saturating_sub(20), 0, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(current_writer_ids(&pool).await, vec![empty_writer_id]); +} + +#[tokio::test] +async fn reap_draining_writers_empty_cleanup_does_not_increment_force_close_metric() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(60), 0, 0).await; + insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(50), 0, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(current_writer_ids(&pool).await.is_empty()); + assert_eq!(pool.stats.get_pool_force_close_total(), 0); +} + +#[tokio::test] +async fn reap_draining_writers_handles_duplicate_force_close_requests_for_same_writer() { + let pool = make_pool(1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer( + &pool, + 10, + now_epoch_secs.saturating_sub(30), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + insert_draining_writer( + &pool, + 20, + now_epoch_secs.saturating_sub(20), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(current_writer_ids(&pool).await.is_empty()); +} + +#[tokio::test] +async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population_under_churn() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let mut warn_next_allowed = HashMap::new(); + + for wave in 0..12u64 { + for offset in 0..9u64 { + insert_draining_writer( + &pool, + wave * 100 + offset, + now_epoch_secs.saturating_sub(120 + offset), + 1, + 0, + ) + .await; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); + + let existing_writer_ids = current_writer_ids(&pool).await; + for writer_id in existing_writer_ids.into_iter().take(4) { + let _ = pool.remove_writer_and_close_clients(writer_id).await; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); + } +} + +#[tokio::test] +async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_state() { + let pool = make_pool(6).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let mut warn_next_allowed = HashMap::new(); + + for writer_id in 1..=18u64 { + let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 }; + let deadline = if writer_id % 2 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + bound_clients, + deadline, + ) + .await; + } + + for _ in 0..16 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if pool.writers.read().await.len() <= 6 { + break; + } + } + + assert!(pool.writers.read().await.len() <= 6); + assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); +} + +#[test] +fn general_config_default_drain_threshold_remains_enabled() { + assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128); +} + +#[tokio::test] +async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + let empty_writer_id = 700u64; + insert_draining_writer( + &pool, + empty_writer_id, + now_epoch_secs.saturating_sub(60), + 0, + 0, + ) + .await; + + let stale_empty_snapshot = vec![empty_writer_id]; + let (rebound_conn_id, _rx) = pool.registry.register().await; + assert!( + pool.registry + .bind_writer( + rebound_conn_id, + empty_writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await, + "writer should accept a new bind after stale empty snapshot" + ); + + for writer_id in stale_empty_snapshot { + assert!( + !pool.remove_writer_if_empty(writer_id).await, + "atomic empty cleanup must reject writers that gained bound clients" + ); + } + + assert!( + writer_exists(&pool, empty_writer_id).await, + "empty-path cleanup must not remove a writer that gained a bound client" + ); + assert_eq!( + pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id), + Some(empty_writer_id) + ); + + let _ = pool.registry.unregister(rebound_conn_id).await; +} + +#[tokio::test] +async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await; + + pool.prune_closed_writers().await; + + assert!(!writer_exists(&pool, 910).await); + assert!(pool.registry.get_writer(conn_ids[0]).await.is_none()); +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 92e222d..590c996 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -21,6 +21,12 @@ mod secret; mod selftest; mod wire; mod pool_status; +#[cfg(test)] +mod health_regression_tests; +#[cfg(test)] +mod health_integration_tests; +#[cfg(test)] +mod health_adversarial_tests; use bytes::Bytes; diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 7490a98..5b23d7f 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -42,11 +42,10 @@ impl MePool { } for writer_id in closed_writer_ids { - if self.registry.is_writer_empty(writer_id).await { - let _ = self.remove_writer_only(writer_id).await; - } else { - let _ = self.remove_writer_and_close_clients(writer_id).await; + if self.remove_writer_if_empty(writer_id).await { + continue; } + let _ = self.remove_writer_and_close_clients(writer_id).await; } } @@ -501,6 +500,17 @@ impl MePool { } } + pub(crate) async fn remove_writer_if_empty(self: &Arc, writer_id: u64) -> bool { + if !self.registry.unregister_writer_if_empty(writer_id).await { + return false; + } + + // The registry empty-check and unregister are atomic with respect to binds, + // so remove_writer_only cannot return active bound sessions here. + let _ = self.remove_writer_only(writer_id).await; + true + } + async fn remove_writer_only(self: &Arc, writer_id: u64) -> Vec { let mut close_tx: Option> = None; let mut removed_addr: Option = None; diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index cbe1d9a..ea968b5 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -437,6 +437,23 @@ impl ConnRegistry { .unwrap_or(true) } + pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool { + let mut inner = self.inner.write().await; + let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else { + // Writer is already absent from the registry. + return true; + }; + if !conn_ids.is_empty() { + return false; + } + + inner.writers.remove(&writer_id); + inner.last_meta_for_writer.remove(&writer_id); + inner.writer_idle_since_epoch_secs.remove(&writer_id); + inner.conns_for_writer.remove(&writer_id); + true + } + pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet { let inner = self.inner.read().await; let mut out = HashSet::::with_capacity(writer_ids.len());