From a74def9561e3d6fe653ae1764b98db2589be1513 Mon Sep 17 00:00:00 2001 From: Sergey Kutovoy Date: Tue, 17 Mar 2026 12:58:40 +0500 Subject: [PATCH 01/13] Update metrics configuration to support custom listen address - Bump telemt dependency version from 3.3.15 to 3.3.19. - Add `metrics_listen` option to `config.toml` for specifying a custom address for the metrics endpoint. - Update `ServerConfig` struct to include `metrics_listen` and adjust logic in `spawn_metrics_if_configured` to prioritize this new option over `metrics_port`. - Enhance error handling for invalid listen addresses in metrics setup. --- Cargo.lock | 2 +- config.toml | 1 + src/config/types.rs | 9 +++++++++ src/maestro/runtime_tasks.rs | 28 +++++++++++++++++++++++++--- src/metrics.rs | 28 ++++++++++++++++++++++++++++ 5 files changed, 64 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 06ea5c6..a704404 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2087,7 +2087,7 @@ dependencies = [ [[package]] name = "telemt" -version = "3.3.15" +version = "3.3.19" dependencies = [ "aes", "anyhow", diff --git a/config.toml b/config.toml index 63fa4ae..f4eb3ae 100644 --- a/config.toml +++ b/config.toml @@ -32,6 +32,7 @@ show = "*" port = 443 # proxy_protocol = false # Enable if behind HAProxy/nginx with PROXY protocol # metrics_port = 9090 +# metrics_listen = "0.0.0.0:9090" # Listen address for metrics (overrides metrics_port) # metrics_whitelist = ["127.0.0.1", "::1", "0.0.0.0/0"] [server.api] diff --git a/src/config/types.rs b/src/config/types.rs index f676f54..7ea1fe7 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1156,9 +1156,17 @@ pub struct ServerConfig { #[serde(default = "default_proxy_protocol_header_timeout_ms")] pub proxy_protocol_header_timeout_ms: u64, + /// Port for the Prometheus-compatible metrics endpoint. + /// Enables metrics when set; binds on all interfaces (dual-stack) by default. #[serde(default)] pub metrics_port: Option, + /// Listen address for metrics in `IP:PORT` format (e.g. `"127.0.0.1:9090"`). + /// When set, takes precedence over `metrics_port` and binds on the specified address only. + #[serde(default)] + pub metrics_listen: Option, + + /// CIDR whitelist for the metrics endpoint. #[serde(default = "default_metrics_whitelist")] pub metrics_whitelist: Vec, @@ -1186,6 +1194,7 @@ impl Default for ServerConfig { proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), metrics_port: None, + metrics_listen: None, metrics_whitelist: default_metrics_whitelist(), api: ApiConfig::default(), listeners: Vec::new(), diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index 329e267..d9691a8 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -279,11 +279,32 @@ pub(crate) async fn spawn_metrics_if_configured( ip_tracker: Arc, config_rx: watch::Receiver>, ) { - if let Some(port) = config.server.metrics_port { + // metrics_listen takes precedence; fall back to metrics_port for backward compat. + let metrics_target: Option<(u16, Option)> = + if let Some(ref listen) = config.server.metrics_listen { + match listen.parse::() { + Ok(addr) => Some((addr.port(), Some(listen.clone()))), + Err(e) => { + startup_tracker + .skip_component( + COMPONENT_METRICS_START, + Some(format!("invalid metrics_listen \"{}\": {}", listen, e)), + ) + .await; + None + } + } + } else { + config.server.metrics_port.map(|p| (p, None)) + }; + + if let Some((port, listen)) = metrics_target { + let fallback_label = format!("port {}", port); + let label = listen.as_deref().unwrap_or(&fallback_label); startup_tracker .start_component( COMPONENT_METRICS_START, - Some(format!("spawn metrics endpoint on {}", port)), + Some(format!("spawn metrics endpoint on {}", label)), ) .await; let stats = stats.clone(); @@ -294,6 +315,7 @@ pub(crate) async fn spawn_metrics_if_configured( tokio::spawn(async move { metrics::serve( port, + listen, stats, beobachten, ip_tracker_metrics, @@ -308,7 +330,7 @@ pub(crate) async fn spawn_metrics_if_configured( Some("metrics task spawned".to_string()), ) .await; - } else { + } else if config.server.metrics_listen.is_none() { startup_tracker .skip_component( COMPONENT_METRICS_START, diff --git a/src/metrics.rs b/src/metrics.rs index 02edfd7..f4f8a2e 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -21,6 +21,7 @@ use crate::transport::{ListenOptions, create_listener}; pub async fn serve( port: u16, + listen: Option, stats: Arc, beobachten: Arc, ip_tracker: Arc, @@ -28,6 +29,33 @@ pub async fn serve( whitelist: Vec, ) { let whitelist = Arc::new(whitelist); + + // If `metrics_listen` is set, bind on that single address only. + if let Some(ref listen_addr) = listen { + let addr: SocketAddr = match listen_addr.parse() { + Ok(a) => a, + Err(e) => { + warn!(error = %e, "Invalid metrics_listen address: {}", listen_addr); + return; + } + }; + let is_ipv6 = addr.is_ipv6(); + match bind_metrics_listener(addr, is_ipv6) { + Ok(listener) => { + info!("Metrics endpoint: http://{}/metrics and /beobachten", addr); + serve_listener( + listener, stats, beobachten, ip_tracker, config_rx, whitelist, + ) + .await; + } + Err(e) => { + warn!(error = %e, "Failed to bind metrics on {}", addr); + } + } + return; + } + + // Fallback: bind on 0.0.0.0 and [::] using metrics_port. let mut listener_v4 = None; let mut listener_v6 = None; From bd0cefdb12d864ea8a25e5829515fa86eb376b3a Mon Sep 17 00:00:00 2001 From: Dimasssss Date: Tue, 17 Mar 2026 11:56:56 +0300 Subject: [PATCH 02/13] Update TLS-F-TCP-S.ru.md --- docs/fronting-splitting/TLS-F-TCP-S.ru.md | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/docs/fronting-splitting/TLS-F-TCP-S.ru.md b/docs/fronting-splitting/TLS-F-TCP-S.ru.md index 6ae6f05..1f9f872 100644 --- a/docs/fronting-splitting/TLS-F-TCP-S.ru.md +++ b/docs/fronting-splitting/TLS-F-TCP-S.ru.md @@ -38,8 +38,9 @@ umweltschutz.de -> A-запись 198.18.88.88 В конфигурации Telemt: -``` -tls_domain = umweltschutz.de +```toml +[censorship] +tls_domain = "umweltschutz.de" ``` Этот домен используется клиентом как SNI в ClientHello @@ -56,8 +57,9 @@ tls_domain = umweltschutz.de В конфигурации Telemt: -``` -mask_host = 127.0.0.1 +```toml +[censorship] +mask_host = "127.0.0.1" mask_port = 8443 ``` @@ -151,16 +153,18 @@ mask_host:mask_port Например: -``` -tls_domain = github.com -mask_host = github.com +```toml +[censorship] +tls_domain = "github.com" +mask_host = "github.com" mask_port = 443 ``` или -``` -mask_host = 140.82.121.4 +```toml +[censorship] +mask_host = "140.82.121.4" ``` В этом случае: From c9271d90837c03d684f734d46ee7f0d5539d1a5d Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 17:11:51 +0400 Subject: [PATCH 03/13] Add health monitoring tests for draining writers - Introduced adversarial tests to validate the behavior of the health monitoring system under various conditions, including the management of draining writers. - Implemented integration tests to ensure the health monitor correctly handles expired and empty draining writers. - Added regression tests to verify the functionality of the draining writers' cleanup process, ensuring it adheres to the defined thresholds and budgets. - Updated the module structure to include the new test files for better organization and maintainability. --- src/ip_tracker_regression_tests.rs | 450 +++++++++++++++++ src/main.rs | 2 + src/transport/middle_proxy/health.rs | 77 ++- .../middle_proxy/health_adversarial_tests.rs | 437 +++++++++++++++++ .../middle_proxy/health_integration_tests.rs | 227 +++++++++ .../middle_proxy/health_regression_tests.rs | 462 ++++++++++++++++++ src/transport/middle_proxy/mod.rs | 6 + 7 files changed, 1653 insertions(+), 8 deletions(-) create mode 100644 src/ip_tracker_regression_tests.rs create mode 100644 src/transport/middle_proxy/health_adversarial_tests.rs create mode 100644 src/transport/middle_proxy/health_integration_tests.rs create mode 100644 src/transport/middle_proxy/health_regression_tests.rs diff --git a/src/ip_tracker_regression_tests.rs b/src/ip_tracker_regression_tests.rs new file mode 100644 index 0000000..5d6b358 --- /dev/null +++ b/src/ip_tracker_regression_tests.rs @@ -0,0 +1,450 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::Duration; + +use crate::config::UserMaxUniqueIpsMode; +use crate::ip_tracker::UserIpTracker; + +fn ip_from_idx(idx: u32) -> IpAddr { + let a = 10u8; + let b = ((idx / 65_536) % 256) as u8; + let c = ((idx / 256) % 256) as u8; + let d = (idx % 256) as u8; + IpAddr::V4(Ipv4Addr::new(a, b, c, d)) +} + +#[tokio::test] +async fn active_window_enforces_large_unique_ip_burst() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("burst_user", 64).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30) + .await; + + for idx in 0..64 { + assert!(tracker.check_and_add("burst_user", ip_from_idx(idx)).await.is_ok()); + } + assert!(tracker.check_and_add("burst_user", ip_from_idx(9_999)).await.is_err()); + assert_eq!(tracker.get_active_ip_count("burst_user").await, 64); +} + +#[tokio::test] +async fn global_limit_applies_across_many_users() { + let tracker = UserIpTracker::new(); + tracker.load_limits(3, &HashMap::new()).await; + + for user_idx in 0..150u32 { + let user = format!("u{}", user_idx); + assert!(tracker.check_and_add(&user, ip_from_idx(user_idx * 10)).await.is_ok()); + assert!(tracker + .check_and_add(&user, ip_from_idx(user_idx * 10 + 1)) + .await + .is_ok()); + assert!(tracker + .check_and_add(&user, ip_from_idx(user_idx * 10 + 2)) + .await + .is_ok()); + assert!(tracker + .check_and_add(&user, ip_from_idx(user_idx * 10 + 3)) + .await + .is_err()); + } + + assert_eq!(tracker.get_stats().await.len(), 150); +} + +#[tokio::test] +async fn user_zero_override_falls_back_to_global_limit() { + let tracker = UserIpTracker::new(); + let mut limits = HashMap::new(); + limits.insert("target".to_string(), 0); + tracker.load_limits(2, &limits).await; + + assert!(tracker.check_and_add("target", ip_from_idx(1)).await.is_ok()); + assert!(tracker.check_and_add("target", ip_from_idx(2)).await.is_ok()); + assert!(tracker.check_and_add("target", ip_from_idx(3)).await.is_err()); + assert_eq!(tracker.get_user_limit("target").await, Some(2)); +} + +#[tokio::test] +async fn remove_ip_is_idempotent_after_counter_reaches_zero() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("u", 2).await; + let ip = ip_from_idx(42); + + tracker.check_and_add("u", ip).await.unwrap(); + tracker.remove_ip("u", ip).await; + tracker.remove_ip("u", ip).await; + tracker.remove_ip("u", ip).await; + + assert_eq!(tracker.get_active_ip_count("u").await, 0); + assert!(!tracker.is_ip_active("u", ip).await); +} + +#[tokio::test] +async fn clear_user_ips_resets_active_and_recent() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("u", 10).await; + + for idx in 0..6 { + tracker.check_and_add("u", ip_from_idx(idx)).await.unwrap(); + } + + tracker.clear_user_ips("u").await; + + assert_eq!(tracker.get_active_ip_count("u").await, 0); + let counts = tracker + .get_recent_counts_for_users(&["u".to_string()]) + .await; + assert_eq!(counts.get("u").copied().unwrap_or(0), 0); +} + +#[tokio::test] +async fn clear_all_resets_multi_user_state() { + let tracker = UserIpTracker::new(); + + for user_idx in 0..80u32 { + let user = format!("u{}", user_idx); + for ip_idx in 0..3 { + tracker + .check_and_add(&user, ip_from_idx(user_idx * 100 + ip_idx)) + .await + .unwrap(); + } + } + + tracker.clear_all().await; + + assert!(tracker.get_stats().await.is_empty()); + let users = (0..80u32) + .map(|idx| format!("u{}", idx)) + .collect::>(); + let recent = tracker.get_recent_counts_for_users(&users).await; + assert!(recent.values().all(|count| *count == 0)); +} + +#[tokio::test] +async fn get_active_ips_for_users_are_sorted() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("user", 10).await; + + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9))) + .await + .unwrap(); + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))) + .await + .unwrap(); + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5))) + .await + .unwrap(); + + let map = tracker + .get_active_ips_for_users(&["user".to_string()]) + .await; + let ips = map.get("user").cloned().unwrap_or_default(); + + assert_eq!( + ips, + vec![ + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)), + ] + ); +} + +#[tokio::test] +async fn get_recent_ips_for_users_are_sorted() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("user", 10).await; + + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9))) + .await + .unwrap(); + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1))) + .await + .unwrap(); + tracker + .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5))) + .await + .unwrap(); + + let map = tracker + .get_recent_ips_for_users(&["user".to_string()]) + .await; + let ips = map.get("user").cloned().unwrap_or_default(); + + assert_eq!( + ips, + vec![ + IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)), + IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)), + IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)), + ] + ); +} + +#[tokio::test] +async fn time_window_expires_for_large_rotation() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("tw", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) + .await; + + tracker.check_and_add("tw", ip_from_idx(1)).await.unwrap(); + tracker.remove_ip("tw", ip_from_idx(1)).await; + assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_err()); + + tokio::time::sleep(Duration::from_millis(1_100)).await; + assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_ok()); +} + +#[tokio::test] +async fn combined_mode_blocks_recent_after_disconnect() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("cmb", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::Combined, 2) + .await; + + tracker.check_and_add("cmb", ip_from_idx(11)).await.unwrap(); + tracker.remove_ip("cmb", ip_from_idx(11)).await; + + assert!(tracker.check_and_add("cmb", ip_from_idx(12)).await.is_err()); +} + +#[tokio::test] +async fn load_limits_replaces_large_limit_map() { + let tracker = UserIpTracker::new(); + let mut first = HashMap::new(); + let mut second = HashMap::new(); + + for idx in 0..300usize { + first.insert(format!("u{}", idx), 2usize); + } + for idx in 150..450usize { + second.insert(format!("u{}", idx), 4usize); + } + + tracker.load_limits(0, &first).await; + tracker.load_limits(0, &second).await; + + assert_eq!(tracker.get_user_limit("u20").await, None); + assert_eq!(tracker.get_user_limit("u200").await, Some(4)); + assert_eq!(tracker.get_user_limit("u420").await, Some(4)); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_same_user_unique_ip_pressure_stays_bounded() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("hot", 32).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30) + .await; + + let mut handles = Vec::new(); + for worker in 0..16u32 { + let tracker_cloned = tracker.clone(); + handles.push(tokio::spawn(async move { + let base = worker * 200; + for step in 0..200u32 { + let _ = tracker_cloned + .check_and_add("hot", ip_from_idx(base + step)) + .await; + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + assert!(tracker.get_active_ip_count("hot").await <= 32); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_many_users_isolate_limits() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.load_limits(4, &HashMap::new()).await; + + let mut handles = Vec::new(); + for user_idx in 0..120u32 { + let tracker_cloned = tracker.clone(); + handles.push(tokio::spawn(async move { + let user = format!("u{}", user_idx); + for ip_idx in 0..10u32 { + let _ = tracker_cloned + .check_and_add(&user, ip_from_idx(user_idx * 1_000 + ip_idx)) + .await; + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let stats = tracker.get_stats().await; + assert_eq!(stats.len(), 120); + assert!(stats.iter().all(|(_, active, limit)| *active <= 4 && *limit == 4)); +} + +#[tokio::test] +async fn same_ip_reconnect_high_frequency_keeps_single_unique() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("same", 2).await; + let ip = ip_from_idx(9); + + for _ in 0..2_000 { + tracker.check_and_add("same", ip).await.unwrap(); + } + + assert_eq!(tracker.get_active_ip_count("same").await, 1); + assert!(tracker.is_ip_active("same", ip).await); +} + +#[tokio::test] +async fn format_stats_contains_expected_limited_and_unlimited_markers() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("limited", 2).await; + tracker.check_and_add("limited", ip_from_idx(1)).await.unwrap(); + tracker.check_and_add("open", ip_from_idx(2)).await.unwrap(); + + let text = tracker.format_stats().await; + + assert!(text.contains("limited")); + assert!(text.contains("open")); + assert!(text.contains("unlimited")); +} + +#[tokio::test] +async fn stats_report_global_default_for_users_without_override() { + let tracker = UserIpTracker::new(); + tracker.load_limits(5, &HashMap::new()).await; + + tracker.check_and_add("a", ip_from_idx(1)).await.unwrap(); + tracker.check_and_add("b", ip_from_idx(2)).await.unwrap(); + + let stats = tracker.get_stats().await; + assert!(stats.iter().any(|(user, _, limit)| user == "a" && *limit == 5)); + assert!(stats.iter().any(|(user, _, limit)| user == "b" && *limit == 5)); +} + +#[tokio::test] +async fn stress_cycle_add_remove_clear_preserves_empty_end_state() { + let tracker = UserIpTracker::new(); + + for cycle in 0..50u32 { + let user = format!("cycle{}", cycle); + tracker.set_user_limit(&user, 128).await; + + for ip_idx in 0..128u32 { + tracker + .check_and_add(&user, ip_from_idx(cycle * 10_000 + ip_idx)) + .await + .unwrap(); + } + + for ip_idx in 0..128u32 { + tracker + .remove_ip(&user, ip_from_idx(cycle * 10_000 + ip_idx)) + .await; + } + + tracker.clear_user_ips(&user).await; + } + + assert!(tracker.get_stats().await.is_empty()); +} + +#[tokio::test] +async fn remove_unknown_user_or_ip_does_not_corrupt_state() { + let tracker = UserIpTracker::new(); + + tracker.remove_ip("no_user", ip_from_idx(1)).await; + tracker.check_and_add("x", ip_from_idx(2)).await.unwrap(); + tracker.remove_ip("x", ip_from_idx(3)).await; + + assert_eq!(tracker.get_active_ip_count("x").await, 1); + assert!(tracker.is_ip_active("x", ip_from_idx(2)).await); +} + +#[tokio::test] +async fn active_and_recent_views_match_after_mixed_workload() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("mix", 16).await; + + for ip_idx in 0..12u32 { + tracker.check_and_add("mix", ip_from_idx(ip_idx)).await.unwrap(); + } + for ip_idx in 0..6u32 { + tracker.remove_ip("mix", ip_from_idx(ip_idx)).await; + } + + let active = tracker + .get_active_ips_for_users(&["mix".to_string()]) + .await + .get("mix") + .cloned() + .unwrap_or_default(); + let recent_count = tracker + .get_recent_counts_for_users(&["mix".to_string()]) + .await + .get("mix") + .copied() + .unwrap_or(0); + + assert_eq!(active.len(), 6); + assert!(recent_count >= active.len()); + assert!(recent_count <= 12); +} + +#[tokio::test] +async fn global_limit_switch_updates_enforcement_immediately() { + let tracker = UserIpTracker::new(); + tracker.load_limits(2, &HashMap::new()).await; + + assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_err()); + + tracker.clear_user_ips("u").await; + tracker.load_limits(4, &HashMap::new()).await; + + assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(4)).await.is_ok()); + assert!(tracker.check_and_add("u", ip_from_idx(5)).await.is_err()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("cc", 8).await; + + let mut handles = Vec::new(); + for worker in 0..8u32 { + let tracker_cloned = tracker.clone(); + handles.push(tokio::spawn(async move { + let ip = ip_from_idx(50 + worker); + for _ in 0..500u32 { + let _ = tracker_cloned.check_and_add("cc", ip).await; + tracker_cloned.remove_ip("cc", ip).await; + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + assert!(tracker.get_active_ip_count("cc").await <= 8); +} diff --git a/src/main.rs b/src/main.rs index 73ada8c..2cfbe28 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,8 @@ mod config; mod crypto; mod error; mod ip_tracker; +#[cfg(test)] +mod ip_tracker_regression_tests; mod maestro; mod metrics; mod network; diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index e5f4260..8ac6839 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -25,6 +25,9 @@ const HEALTH_RECONNECT_BUDGET_PER_CORE: usize = 2; const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1; const HEALTH_RECONNECT_BUDGET_MIN: usize = 4; const HEALTH_RECONNECT_BUDGET_MAX: usize = 128; +const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16; +const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16; +const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256; #[derive(Debug, Clone)] struct DcFloorPlanEntry { @@ -111,7 +114,7 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c } } -async fn reap_draining_writers( +pub(super) async fn reap_draining_writers( pool: &Arc, warn_next_allowed: &mut HashMap, ) { @@ -122,14 +125,22 @@ async fn reap_draining_writers( .me_pool_drain_threshold .load(std::sync::atomic::Ordering::Relaxed); let writers = pool.writers.read().await.clone(); + let activity = pool.registry.writer_activity_snapshot().await; let mut draining_writers = Vec::new(); + let mut empty_writer_ids = Vec::::new(); + let mut force_close_writer_ids = Vec::::new(); for writer in writers { if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { continue; } - let is_empty = pool.registry.is_writer_empty(writer.id).await; - if is_empty { - pool.remove_writer_and_close_clients(writer.id).await; + if activity + .bound_clients_by_writer + .get(&writer.id) + .copied() + .unwrap_or(0) + == 0 + { + empty_writer_ids.push(writer.id); continue; } draining_writers.push(writer); @@ -156,12 +167,13 @@ async fn reap_draining_writers( "ME draining writer threshold exceeded, force-closing oldest draining writers" ); for writer in draining_writers.drain(..overflow) { - pool.stats.increment_pool_force_close_total(); - pool.remove_writer_and_close_clients(writer.id).await; + force_close_writer_ids.push(writer.id); } } + let mut active_draining_writer_ids = HashSet::with_capacity(draining_writers.len()); for writer in draining_writers { + active_draining_writer_ids.insert(writer.id); let drain_started_at_epoch_secs = writer .draining_started_at_epoch_secs .load(std::sync::atomic::Ordering::Relaxed); @@ -191,10 +203,59 @@ async fn reap_draining_writers( .load(std::sync::atomic::Ordering::Relaxed); if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs { warn!(writer_id = writer.id, "Drain timeout, force-closing"); - pool.stats.increment_pool_force_close_total(); - pool.remove_writer_and_close_clients(writer.id).await; + force_close_writer_ids.push(writer.id); + active_draining_writer_ids.remove(&writer.id); } } + + warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id)); + + let close_budget = health_drain_close_budget(); + let requested_force_close = force_close_writer_ids.len(); + let requested_empty_close = empty_writer_ids.len(); + let requested_close_total = requested_force_close.saturating_add(requested_empty_close); + let mut closed_writer_ids = HashSet::::new(); + let mut closed_total = 0usize; + for writer_id in force_close_writer_ids { + if closed_total >= close_budget { + break; + } + if !closed_writer_ids.insert(writer_id) { + continue; + } + pool.stats.increment_pool_force_close_total(); + pool.remove_writer_and_close_clients(writer_id).await; + closed_total = closed_total.saturating_add(1); + } + for writer_id in empty_writer_ids { + if closed_total >= close_budget { + break; + } + if !closed_writer_ids.insert(writer_id) { + continue; + } + pool.remove_writer_and_close_clients(writer_id).await; + closed_total = closed_total.saturating_add(1); + } + + let pending_close_total = requested_close_total.saturating_sub(closed_total); + if pending_close_total > 0 { + warn!( + close_budget, + closed_total, + pending_close_total, + "ME draining close backlog deferred to next health cycle" + ); + } +} + +pub(super) fn health_drain_close_budget() -> usize { + let cpu_cores = std::thread::available_parallelism() + .map(std::num::NonZeroUsize::get) + .unwrap_or(1); + cpu_cores + .saturating_mul(HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE) + .clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX) } fn should_emit_writer_warn( diff --git a/src/transport/middle_proxy/health_adversarial_tests.rs b/src/transport/middle_proxy/health_adversarial_tests.rs new file mode 100644 index 0000000..675005a --- /dev/null +++ b/src/transport/middle_proxy/health_adversarial_tests.rs @@ -0,0 +1,437 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use super::health::{health_drain_close_budget, reap_draining_writers}; +use super::pool::{MePool, MeWriter, WriterContour}; +use super::registry::ConnMeta; +use super::me_health_monitor; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool( + me_pool_drain_threshold: u64, + me_health_interval_ms_unhealthy: u64, + me_health_interval_ms_healthy: u64, +) -> (Arc, Arc) { + let general = GeneralConfig { + me_pool_drain_threshold, + me_health_interval_ms_unhealthy, + me_health_interval_ms_healthy, + ..GeneralConfig::default() + }; + + let rng = Arc::new(SecureRandom::new()); + let pool = MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + rng.clone(), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ); + + (pool, rng) +} + +async fn insert_draining_writer( + pool: &Arc, + writer_id: u64, + drain_started_at_epoch_secs: u64, + bound_clients: usize, + drain_deadline_epoch_secs: u64, +) { + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6000 + writer_id as u16), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc: 2, + generation: 1, + contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())), + created_at: Instant::now() - Duration::from_secs(writer_id), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(true)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + + for idx in 0..bound_clients { + let (conn_id, _rx) = pool.registry.register().await; + assert!( + pool.registry + .bind_writer( + conn_id, + writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::LOCALHOST), + 8000 + idx as u16, + ), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await + ); + } +} + +async fn writer_count(pool: &Arc) -> usize { + pool.writers.read().await.len() +} + +async fn sorted_writer_ids(pool: &Arc) -> Vec { + let mut ids = pool + .writers + .read() + .await + .iter() + .map(|writer| writer.id) + .collect::>(); + ids.sort_unstable(); + ids +} + +#[tokio::test] +async fn reap_draining_writers_clears_warn_state_when_pool_empty() { + let (pool, _rng) = make_pool(128, 1, 1).await; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert(11, Instant::now() + Duration::from_secs(5)); + warn_next_allowed.insert(22, Instant::now() + Duration::from_secs(5)); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(warn_next_allowed.is_empty()); +} + +#[tokio::test] +async fn reap_draining_writers_respects_threshold_across_multiple_overflow_cycles() { + let threshold = 3u64; + let (pool, _rng) = make_pool(threshold, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=60u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(600).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _ in 0..64 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if writer_count(&pool).await <= threshold as usize { + break; + } + } + + assert_eq!(writer_count(&pool).await, threshold as usize); + assert_eq!(sorted_writer_ids(&pool).await, vec![58, 59, 60]); +} + +#[tokio::test] +async fn reap_draining_writers_handles_large_empty_writer_population() { + let (pool, _rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let total = health_drain_close_budget().saturating_mul(3).saturating_add(27); + + for writer_id in 1..=total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 0, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _ in 0..24 { + if writer_count(&pool).await == 0 { + break; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + } + + assert_eq!(writer_count(&pool).await, 0); +} + +#[tokio::test] +async fn reap_draining_writers_processes_mass_deadline_expiry_without_unbounded_growth() { + let (pool, _rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let total = health_drain_close_budget().saturating_mul(4).saturating_add(31); + + for writer_id in 1..=total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(180), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _ in 0..40 { + if writer_count(&pool).await == 0 { + break; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + } + + assert_eq!(writer_count(&pool).await, 0); +} + +#[tokio::test] +async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_churn() { + let (pool, _rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let mut warn_next_allowed = HashMap::new(); + + for wave in 0..40u64 { + for offset in 0..8u64 { + insert_draining_writer( + &pool, + wave * 100 + offset, + now_epoch_secs.saturating_sub(400 + offset), + 1, + 0, + ) + .await; + } + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.len() <= writer_count(&pool).await); + + let ids = sorted_writer_ids(&pool).await; + for writer_id in ids.into_iter().take(3) { + let _ = pool.remove_writer_and_close_clients(writer_id).await; + } + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.len() <= writer_count(&pool).await); + } +} + +#[tokio::test] +async fn reap_draining_writers_budgeted_cleanup_never_increases_pool_size() { + let (pool, _rng) = make_pool(5, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=200u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(240).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + let mut previous = writer_count(&pool).await; + for _ in 0..32 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + let current = writer_count(&pool).await; + assert!(current <= previous); + previous = current; + } +} + +#[tokio::test] +async fn me_health_monitor_converges_to_threshold_under_live_injection_churn() { + let threshold = 7u64; + let (pool, rng) = make_pool(threshold, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=40u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + + for wave in 0..8u64 { + for offset in 0..10u64 { + insert_draining_writer( + &pool, + 1000 + wave * 100 + offset, + now_epoch_secs.saturating_sub(120).saturating_add(offset), + 1, + 0, + ) + .await; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + + tokio::time::sleep(Duration::from_millis(120)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(writer_count(&pool).await <= threshold as usize); +} + +#[tokio::test] +async fn me_health_monitor_drains_deadline_storm_with_budgeted_progress() { + let (pool, rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=220u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + tokio::time::sleep(Duration::from_millis(120)).await; + monitor.abort(); + let _ = monitor.await; + + assert_eq!(writer_count(&pool).await, 0); +} + +#[tokio::test] +async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() { + let threshold = 12u64; + let (pool, rng) = make_pool(threshold, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=180u64 { + let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 }; + let deadline = if writer_id % 2 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(250).saturating_add(writer_id), + bound_clients, + deadline, + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + tokio::time::sleep(Duration::from_millis(140)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(writer_count(&pool).await <= threshold as usize); +} + +#[test] +fn health_drain_close_budget_is_within_expected_bounds() { + let budget = health_drain_close_budget(); + assert!((16..=256).contains(&budget)); +} diff --git a/src/transport/middle_proxy/health_integration_tests.rs b/src/transport/middle_proxy/health_integration_tests.rs new file mode 100644 index 0000000..70b6411 --- /dev/null +++ b/src/transport/middle_proxy/health_integration_tests.rs @@ -0,0 +1,227 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use super::health::health_drain_close_budget; +use super::pool::{MePool, MeWriter, WriterContour}; +use super::registry::ConnMeta; +use super::me_health_monitor; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool( + me_pool_drain_threshold: u64, + me_health_interval_ms_unhealthy: u64, + me_health_interval_ms_healthy: u64, +) -> (Arc, Arc) { + let general = GeneralConfig { + me_pool_drain_threshold, + me_health_interval_ms_unhealthy, + me_health_interval_ms_healthy, + ..GeneralConfig::default() + }; + let rng = Arc::new(SecureRandom::new()); + let pool = MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + rng.clone(), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ); + (pool, rng) +} + +async fn insert_draining_writer( + pool: &Arc, + writer_id: u64, + drain_started_at_epoch_secs: u64, + bound_clients: usize, + drain_deadline_epoch_secs: u64, +) { + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5500 + writer_id as u16), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc: 2, + generation: 1, + contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())), + created_at: Instant::now() - Duration::from_secs(writer_id), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(true)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + for idx in 0..bound_clients { + let (conn_id, _rx) = pool.registry.register().await; + assert!( + pool.registry + .bind_writer( + conn_id, + writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::LOCALHOST), + 7200 + idx as u16, + ), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await + ); + } +} + +#[tokio::test] +async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() { + let (pool, rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let writer_total = health_drain_close_budget().saturating_mul(2).saturating_add(9); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + tokio::time::sleep(Duration::from_millis(60)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(pool.writers.read().await.is_empty()); +} + +#[tokio::test] +async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() { + let (pool, rng) = make_pool(128, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + for writer_id in 1..=24u64 { + insert_draining_writer(&pool, writer_id, now_epoch_secs.saturating_sub(60), 0, 0).await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + tokio::time::sleep(Duration::from_millis(30)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(pool.writers.read().await.is_empty()); +} + +#[tokio::test] +async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() { + let threshold = 4u64; + let (pool, rng) = make_pool(threshold, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let writer_total = threshold as usize + health_drain_close_budget().saturating_add(11); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + + let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); + tokio::time::sleep(Duration::from_millis(60)).await; + monitor.abort(); + let _ = monitor.await; + + assert!(pool.writers.read().await.is_empty()); +} diff --git a/src/transport/middle_proxy/health_regression_tests.rs b/src/transport/middle_proxy/health_regression_tests.rs new file mode 100644 index 0000000..05a8e6a --- /dev/null +++ b/src/transport/middle_proxy/health_regression_tests.rs @@ -0,0 +1,462 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use super::health::{health_drain_close_budget, reap_draining_writers}; +use super::pool::{MePool, MeWriter, WriterContour}; +use super::registry::ConnMeta; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool(me_pool_drain_threshold: u64) -> Arc { + let general = GeneralConfig { + me_pool_drain_threshold, + ..GeneralConfig::default() + }; + + MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + Arc::new(SecureRandom::new()), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ) +} + +async fn insert_draining_writer( + pool: &Arc, + writer_id: u64, + drain_started_at_epoch_secs: u64, + bound_clients: usize, + drain_deadline_epoch_secs: u64, +) -> Vec { + let mut conn_ids = Vec::with_capacity(bound_clients); + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4500 + writer_id as u16), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc: 2, + generation: 1, + contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())), + created_at: Instant::now() - Duration::from_secs(writer_id), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(true)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + for idx in 0..bound_clients { + let (conn_id, _rx) = pool.registry.register().await; + assert!( + pool.registry + .bind_writer( + conn_id, + writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::LOCALHOST), + 6200 + idx as u16, + ), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await + ); + conn_ids.push(conn_id); + } + conn_ids +} + +async fn current_writer_ids(pool: &Arc) -> Vec { + let mut writer_ids = pool + .writers + .read() + .await + .iter() + .map(|writer| writer.id) + .collect::>(); + writer_ids.sort_unstable(); + writer_ids +} + +#[tokio::test] +async fn reap_draining_writers_drops_warn_state_for_removed_writer() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_ids = + insert_draining_writer(&pool, 7, now_epoch_secs.saturating_sub(180), 1, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.contains_key(&7)); + + let _ = pool.remove_writer_and_close_clients(7).await; + assert!(pool.registry.get_writer(conn_ids[0]).await.is_none()); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(!warn_next_allowed.contains_key(&7)); +} + +#[tokio::test] +async fn reap_draining_writers_removes_empty_draining_writers() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(40), 0, 0).await; + insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(30), 0, 0).await; + insert_draining_writer(&pool, 3, now_epoch_secs.saturating_sub(20), 1, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(current_writer_ids(&pool).await, vec![3]); +} + +#[tokio::test] +async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() { + let pool = make_pool(2).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 11, now_epoch_secs.saturating_sub(40), 1, 0).await; + insert_draining_writer(&pool, 22, now_epoch_secs.saturating_sub(30), 1, 0).await; + insert_draining_writer(&pool, 33, now_epoch_secs.saturating_sub(20), 1, 0).await; + insert_draining_writer(&pool, 44, now_epoch_secs.saturating_sub(10), 1, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(current_writer_ids(&pool).await, vec![33, 44]); +} + +#[tokio::test] +async fn reap_draining_writers_deadline_force_close_applies_under_threshold() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer( + &pool, + 50, + now_epoch_secs.saturating_sub(15), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(current_writer_ids(&pool).await.is_empty()); +} + +#[tokio::test] +async fn reap_draining_writers_limits_closes_per_health_tick() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_add(19); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(20), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(pool.writers.read().await.len(), writer_total - close_budget); +} + +#[tokio::test] +async fn reap_draining_writers_backlog_drains_across_ticks() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_mul(2).saturating_add(7); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(20), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let mut warn_next_allowed = HashMap::new(); + + for _ in 0..8 { + if pool.writers.read().await.is_empty() { + break; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + } + + assert!(pool.writers.read().await.is_empty()); +} + +#[tokio::test] +async fn reap_draining_writers_threshold_backlog_converges_to_threshold() { + let threshold = 5u64; + let pool = make_pool(threshold).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = threshold as usize + close_budget.saturating_add(12); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(200).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + let mut warn_next_allowed = HashMap::new(); + + for _ in 0..16 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if pool.writers.read().await.len() <= threshold as usize { + break; + } + } + + assert_eq!(pool.writers.read().await.len(), threshold as usize); +} + +#[tokio::test] +async fn reap_draining_writers_threshold_zero_preserves_non_expired_non_empty_writers() { + let pool = make_pool(0).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(40), 1, 0).await; + insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(30), 1, 0).await; + insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(20), 1, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(current_writer_ids(&pool).await, vec![10, 20, 30]); +} + +#[tokio::test] +async fn reap_draining_writers_prioritizes_force_close_before_empty_cleanup() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + for writer_id in 1..=close_budget as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(20), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let empty_writer_id = close_budget as u64 + 1; + insert_draining_writer(&pool, empty_writer_id, now_epoch_secs.saturating_sub(20), 0, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert_eq!(current_writer_ids(&pool).await, vec![empty_writer_id]); +} + +#[tokio::test] +async fn reap_draining_writers_empty_cleanup_does_not_increment_force_close_metric() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(60), 0, 0).await; + insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(50), 0, 0).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(current_writer_ids(&pool).await.is_empty()); + assert_eq!(pool.stats.get_pool_force_close_total(), 0); +} + +#[tokio::test] +async fn reap_draining_writers_handles_duplicate_force_close_requests_for_same_writer() { + let pool = make_pool(1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer( + &pool, + 10, + now_epoch_secs.saturating_sub(30), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + insert_draining_writer( + &pool, + 20, + now_epoch_secs.saturating_sub(20), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(current_writer_ids(&pool).await.is_empty()); +} + +#[tokio::test] +async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population_under_churn() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let mut warn_next_allowed = HashMap::new(); + + for wave in 0..12u64 { + for offset in 0..9u64 { + insert_draining_writer( + &pool, + wave * 100 + offset, + now_epoch_secs.saturating_sub(120 + offset), + 1, + 0, + ) + .await; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); + + let existing_writer_ids = current_writer_ids(&pool).await; + for writer_id in existing_writer_ids.into_iter().take(4) { + let _ = pool.remove_writer_and_close_clients(writer_id).await; + } + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); + } +} + +#[tokio::test] +async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_state() { + let pool = make_pool(6).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let mut warn_next_allowed = HashMap::new(); + + for writer_id in 1..=18u64 { + let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 }; + let deadline = if writer_id % 2 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + bound_clients, + deadline, + ) + .await; + } + + for _ in 0..16 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if pool.writers.read().await.len() <= 6 { + break; + } + } + + assert!(pool.writers.read().await.len() <= 6); + assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); +} + +#[test] +fn general_config_default_drain_threshold_remains_enabled() { + assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128); +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 92e222d..590c996 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -21,6 +21,12 @@ mod secret; mod selftest; mod wire; mod pool_status; +#[cfg(test)] +mod health_regression_tests; +#[cfg(test)] +mod health_integration_tests; +#[cfg(test)] +mod health_adversarial_tests; use bytes::Bytes; From 35bca7d4cc07db95721b39bae399ba9425ec7bbf Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 17 Mar 2026 16:31:32 +0300 Subject: [PATCH 04/13] Update Cargo.toml --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 9374924..dad9cf0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.3.19" +version = "3.3.20" edition = "2024" [dependencies] From 1357f3cc4c18f1b59e7bfd3cdabddd2d4e39d4aa Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 18:16:17 +0400 Subject: [PATCH 05/13] bump version to 3.3.20 and implement connection lease management for direct and middle relays --- Cargo.lock | 2 +- src/proxy/direct_relay.rs | 4 +- src/proxy/direct_relay_security_tests.rs | 129 ++++++++++++++ src/proxy/middle_relay.rs | 4 +- src/proxy/middle_relay_security_tests.rs | 167 ++++++++++++++++++- src/stats/connection_lease_security_tests.rs | 114 +++++++++++++ src/stats/mod.rs | 54 ++++++ 7 files changed, 465 insertions(+), 9 deletions(-) create mode 100644 src/stats/connection_lease_security_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 89eefd6..677ab84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2093,7 +2093,7 @@ dependencies = [ [[package]] name = "telemt" -version = "3.3.19" +version = "3.3.20" dependencies = [ "aes", "anyhow", diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 9c6116c..d7d5f64 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -105,7 +105,7 @@ where debug!(peer = %success.peer, "TG handshake complete, starting relay"); stats.increment_user_connects(user); - stats.increment_current_connections_direct(); + let _direct_connection_lease = stats.acquire_direct_connection_lease(); let relay_result = relay_bidirectional( client_reader, @@ -148,8 +148,6 @@ where } }; - stats.decrement_current_connections_direct(); - match &relay_result { Ok(()) => debug!(user = %user, "Direct relay completed"), Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"), diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index 3b3185a..1e2d673 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -1,4 +1,33 @@ use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::protocol::constants::ProtoTag; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::duplex; +use tokio::net::TcpListener; + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} #[test] fn unknown_dc_log_is_deduplicated_per_dc_idx() { @@ -49,3 +78,103 @@ fn fallback_dc_never_panics_with_single_dc_list() { let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); assert_eq!(addr, expected); } + +#[tokio::test] +async fn direct_relay_abort_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50000".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xabad1dea, + )); + + let started = tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await; + assert!(started.is_ok(), "direct relay must increment route gauge before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted direct relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay task is aborted mid-flight" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 1acbdc1..affa4cd 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -306,7 +306,7 @@ where }; stats.increment_user_connects(&user); - stats.increment_current_connections_me(); + let _me_connection_lease = stats.acquire_me_connection_lease(); if let Some(cutover) = affected_cutover_state( &route_rx, @@ -324,7 +324,6 @@ where tokio::time::sleep(delay).await; let _ = me_pool.send_close(conn_id).await; me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); } @@ -672,7 +671,6 @@ where "ME relay cleanup" ); me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); result } diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index a2a6c3e..509ba95 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -2,8 +2,13 @@ use super::*; use bytes::Bytes; use crate::crypto::AesCtr; use crate::crypto::SecureRandom; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::network::probe::NetworkDecision; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; +use crate::transport::middle_proxy::MePool; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::AtomicU64; @@ -229,18 +234,108 @@ fn make_forensics_state() -> RelayForensicsState { } } -fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader { +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ let key = [0u8; 32]; let iv = 0u128; CryptoReader::new(reader, AesCtr::new(&key, iv)) } -fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter { +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ let key = [0u8; 32]; let iv = 0u128; CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +async fn make_me_pool_for_abort_test(stats: Arc) -> Arc { + let general = GeneralConfig::default(); + + MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + Arc::new(SecureRandom::new()), + stats, + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ) +} + fn encrypt_for_reader(plaintext: &[u8]) -> Vec { let key = [0u8; 32]; let iv = 0u128; @@ -779,3 +874,71 @@ async fn process_me_writer_response_data_updates_byte_accounting() { "ME->C byte accounting must increase by emitted payload size" ); } + +#[tokio::test] +async fn middle_relay_abort_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50001".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xdecafbad, + )); + + let started = tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await; + assert!(started.is_ok(), "middle relay must increment route gauge before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted middle relay task must return join error"); + + tokio::time::sleep(TokioDuration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay task is aborted mid-flight" + ); + + drop(client_side); +} diff --git a/src/stats/connection_lease_security_tests.rs b/src/stats/connection_lease_security_tests.rs new file mode 100644 index 0000000..2d942c2 --- /dev/null +++ b/src/stats/connection_lease_security_tests.rs @@ -0,0 +1,114 @@ +use super::*; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::Arc; +use std::time::Duration; + +#[test] +fn direct_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_direct(), 0); + + { + let _lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + } + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn middle_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_me(), 0); + + { + let _lease = stats.acquire_me_connection_lease(); + assert_eq!(stats.get_current_connections_me(), 1); + } + + assert_eq!(stats.get_current_connections_me(), 0); +} + +#[test] +fn connection_lease_disarm_prevents_double_release() { + let stats = Arc::new(Stats::new()); + + let mut lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + + stats.decrement_current_connections_direct(); + assert_eq!(stats.get_current_connections_direct(), 0); + + lease.disarm(); + drop(lease); + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn direct_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_direct_connection_lease(); + panic!("intentional panic to verify lease drop path"); + })); + + assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert_eq!( + stats.get_current_connections_direct(), + 0, + "panic unwind must release direct route gauge" + ); +} + +#[tokio::test] +async fn direct_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_direct(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "aborted task must release direct route gauge" + ); +} + +#[tokio::test] +async fn middle_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_me(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "aborted task must release middle route gauge" + ); +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 603552d..36241af 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -6,6 +6,7 @@ pub mod beobachten; pub mod telemetry; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; +use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use dashmap::DashMap; use parking_lot::Mutex; @@ -19,6 +20,45 @@ use tracing::debug; use crate::config::{MeTelemetryLevel, MeWriterPickMode}; use self::telemetry::TelemetryPolicy; +#[derive(Clone, Copy)] +enum RouteConnectionGauge { + Direct, + Middle, +} + +pub struct RouteConnectionLease { + stats: Arc, + gauge: RouteConnectionGauge, + active: bool, +} + +impl RouteConnectionLease { + fn new(stats: Arc, gauge: RouteConnectionGauge) -> Self { + Self { + stats, + gauge, + active: true, + } + } + + #[cfg(test)] + fn disarm(&mut self) { + self.active = false; + } +} + +impl Drop for RouteConnectionLease { + fn drop(&mut self) { + if !self.active { + return; + } + match self.gauge { + RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(), + RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(), + } + } +} + // ============= Stats ============= #[derive(Default)] @@ -285,6 +325,16 @@ impl Stats { pub fn decrement_current_connections_me(&self) { Self::decrement_atomic_saturating(&self.current_connections_me); } + + pub fn acquire_direct_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_direct(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct) + } + + pub fn acquire_me_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_me(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle) + } pub fn increment_handshake_timeouts(&self) { if self.telemetry_core_enabled() { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); @@ -1772,3 +1822,7 @@ mod tests { assert_eq!(checker.stats().total_entries, 500); } } + +#[cfg(test)] +#[path = "connection_lease_security_tests.rs"] +mod connection_lease_security_tests; From 4808a3018574ebc3d336535d4bdef1c105905eab Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 18:29:56 +0400 Subject: [PATCH 06/13] Merge upstream/main into flow-sec rehearsal: resolve config and middle-proxy health conflicts --- src/transport/middle_proxy/health.rs | 133 +++++++++++++++++++++------ 1 file changed, 106 insertions(+), 27 deletions(-) diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 8ac6839..edc9598 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -124,12 +124,12 @@ pub(super) async fn reap_draining_writers( let drain_threshold = pool .me_pool_drain_threshold .load(std::sync::atomic::Ordering::Relaxed); - let writers = pool.writers.read().await.clone(); let activity = pool.registry.writer_activity_snapshot().await; - let mut draining_writers = Vec::new(); + let mut draining_writers = Vec::::new(); let mut empty_writer_ids = Vec::::new(); let mut force_close_writer_ids = Vec::::new(); - for writer in writers { + let writers = pool.writers.read().await; + for writer in writers.iter() { if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { continue; } @@ -143,23 +143,38 @@ pub(super) async fn reap_draining_writers( empty_writer_ids.push(writer.id); continue; } - draining_writers.push(writer); + draining_writers.push(DrainingWriterSnapshot { + id: writer.id, + writer_dc: writer.writer_dc, + addr: writer.addr, + generation: writer.generation, + created_at: writer.created_at, + draining_started_at_epoch_secs: writer + .draining_started_at_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + drain_deadline_epoch_secs: writer + .drain_deadline_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback: writer + .allow_drain_fallback + .load(std::sync::atomic::Ordering::Relaxed), + }); } + drop(writers); - if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { + let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { + draining_writers.len().saturating_sub(drain_threshold as usize) + } else { + 0 + }; + + if overflow > 0 { draining_writers.sort_by(|left, right| { - let left_started = left - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - let right_started = right - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - left_started - .cmp(&right_started) + left.draining_started_at_epoch_secs + .cmp(&right.draining_started_at_epoch_secs) .then_with(|| left.created_at.cmp(&right.created_at)) .then_with(|| left.id.cmp(&right.id)) }); - let overflow = draining_writers.len().saturating_sub(drain_threshold as usize); warn!( draining_writers = draining_writers.len(), me_pool_drain_threshold = drain_threshold, @@ -174,12 +189,9 @@ pub(super) async fn reap_draining_writers( let mut active_draining_writer_ids = HashSet::with_capacity(draining_writers.len()); for writer in draining_writers { active_draining_writer_ids.insert(writer.id); - let drain_started_at_epoch_secs = writer - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); if drain_ttl_secs > 0 - && drain_started_at_epoch_secs != 0 - && now_epoch_secs.saturating_sub(drain_started_at_epoch_secs) > drain_ttl_secs + && writer.draining_started_at_epoch_secs != 0 + && now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs && should_emit_writer_warn( warn_next_allowed, writer.id, @@ -194,14 +206,12 @@ pub(super) async fn reap_draining_writers( generation = writer.generation, drain_ttl_secs, force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed), - allow_drain_fallback = writer.allow_drain_fallback.load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback = writer.allow_drain_fallback, "ME draining writer remains non-empty past drain TTL" ); } - let deadline_epoch_secs = writer - .drain_deadline_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs { + if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs + { warn!(writer_id = writer.id, "Drain timeout, force-closing"); force_close_writer_ids.push(writer.id); active_draining_writer_ids.remove(&writer.id); @@ -258,6 +268,18 @@ pub(super) fn health_drain_close_budget() -> usize { .clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX) } +#[derive(Debug, Clone)] +struct DrainingWriterSnapshot { + id: u64, + writer_dc: i32, + addr: SocketAddr, + generation: u64, + created_at: Instant, + draining_started_at_epoch_secs: u64, + drain_deadline_epoch_secs: u64, + allow_drain_fallback: bool, +} + fn should_emit_writer_warn( next_allowed: &mut HashMap, writer_id: u64, @@ -1391,6 +1413,15 @@ mod tests { me_pool_drain_threshold, ..GeneralConfig::default() }; + let mut proxy_map_v4 = HashMap::new(); + proxy_map_v4.insert( + 2, + vec![(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 443)], + ); + let decision = NetworkDecision { + ipv4_me: true, + ..NetworkDecision::default() + }; MePool::new( None, vec![1u8; 32], @@ -1402,10 +1433,10 @@ mod tests { None, 12, 1200, - HashMap::new(), + proxy_map_v4, HashMap::new(), None, - NetworkDecision::default(), + decision, None, Arc::new(SecureRandom::new()), Arc::new(Stats::default()), @@ -1516,8 +1547,55 @@ mod tests { conn_id } + async fn insert_live_writer(pool: &Arc, writer_id: u64, writer_dc: i32) { + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (writer_id as u8).saturating_add(1))), + 4000 + writer_id as u16, + ), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc, + generation: 2, + contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())), + created_at: Instant::now(), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(false)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + } + #[tokio::test] async fn reap_draining_writers_force_closes_oldest_over_threshold() { + let pool = make_pool(2).await; + insert_live_writer(&pool, 1, 2).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; + let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await; + let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + writer_ids.sort_unstable(); + assert_eq!(writer_ids, vec![1, 20, 30]); + assert!(pool.registry.get_writer(conn_a).await.is_none()); + assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); + assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30); + } + + #[tokio::test] + async fn reap_draining_writers_force_closes_overflow_without_replacement() { let pool = make_pool(2).await; let now_epoch_secs = MePool::now_epoch_secs(); let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; @@ -1527,7 +1605,8 @@ mod tests { reap_draining_writers(&pool, &mut warn_next_allowed).await; - let writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + writer_ids.sort_unstable(); assert_eq!(writer_ids, vec![20, 30]); assert!(pool.registry.get_writer(conn_a).await.is_none()); assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); From c540a6657fdfa0974cea22783f1f82b6f922d429 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 19:05:26 +0400 Subject: [PATCH 07/13] Implement user connection reservation management and enhance relay task handling in proxy --- src/proxy/client.rs | 119 +++++++- src/proxy/client_security_tests.rs | 268 +++++++++++++++++++ src/proxy/direct_relay_security_tests.rs | 109 ++++++++ src/proxy/middle_relay_security_tests.rs | 77 ++++++ src/stats/connection_lease_security_tests.rs | 151 +++++++++++ 5 files changed, 714 insertions(+), 10 deletions(-) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 5ccbd40..254d922 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -24,6 +24,39 @@ enum HandshakeOutcome { Handled, } +struct UserConnectionReservation { + stats: Arc, + ip_tracker: Arc, + user: String, + ip: IpAddr, +} + +impl UserConnectionReservation { + fn new(stats: Arc, ip_tracker: Arc, user: String, ip: IpAddr) -> Self { + Self { + stats, + ip_tracker, + user, + ip, + } + } +} + +impl Drop for UserConnectionReservation { + fn drop(&mut self) { + self.stats.decrement_user_curr_connects(&self.user); + + if let Ok(handle) = tokio::runtime::Handle::try_current() { + let ip_tracker = self.ip_tracker.clone(); + let user = self.user.clone(); + let ip = self.ip; + handle.spawn(async move { + ip_tracker.remove_ip(&user, ip).await; + }); + } + } +} + use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{HandshakeResult, ProxyError, Result, StreamError}; @@ -90,6 +123,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { trusted.iter().any(|cidr| cidr.contains(peer_ip)) } +fn synthetic_local_addr(port: u16) -> SocketAddr { + SocketAddr::from(([0, 0, 0, 0], port)) +} + pub async fn handle_client_stream( mut stream: S, peer: SocketAddr, @@ -113,9 +150,7 @@ where let mut real_peer = normalize_ip(peer); // For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst - let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) - .parse() - .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + let mut local_addr = synthetic_local_addr(config.server.port); if proxy_protocol_enabled { let proxy_header_timeout = Duration::from_millis( @@ -798,10 +833,22 @@ impl RunningClientHandler { { let user = success.user.clone(); - if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await { - warn!(user = %user, error = %e, "User limit exceeded"); - return Err(e); - } + let _user_limit_reservation = + match Self::acquire_user_connection_reservation_static( + &user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + { + Ok(reservation) => reservation, + Err(e) => { + warn!(user = %user, error = %e, "User limit exceeded"); + return Err(e); + } + }; let route_snapshot = route_runtime.snapshot(); let session_id = rng.u64(); @@ -858,12 +905,64 @@ impl RunningClientHandler { ) .await }; - - stats.decrement_user_curr_connects(&user); - ip_tracker.remove_ip(&user, peer_addr.ip()).await; relay_result } + async fn acquire_user_connection_reservation_static( + user: &str, + config: &ProxyConfig, + stats: Arc, + peer_addr: SocketAddr, + ip_tracker: Arc, + ) -> Result { + if let Some(expiration) = config.access.user_expirations.get(user) + && chrono::Utc::now() > *expiration + { + return Err(ProxyError::UserExpired { + user: user.to_string(), + }); + } + + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => {} + Err(reason) => { + stats.decrement_user_curr_connects(user); + warn!( + user = %user, + ip = %peer_addr.ip(), + reason = %reason, + "IP limit exceeded" + ); + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + } + + Ok(UserConnectionReservation::new( + stats, + ip_tracker, + user.to_string(), + peer_addr.ip(), + )) + } + + #[cfg(test)] async fn check_user_limits_static( user: &str, config: &ProxyConfig, diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 415cafd..defb3c0 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -1,11 +1,279 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::AesCtr; use crate::crypto::sha256_hmac; +use crate::protocol::constants::ProtoTag; use crate::protocol::tls; +use crate::proxy::handshake::HandshakeSuccess; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use crate::stream::{CryptoReader, CryptoWriter}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; +#[test] +fn synthetic_local_addr_uses_configured_port_for_zero() { + let addr = synthetic_local_addr(0); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), 0); +} + +#[test] +fn synthetic_local_addr_uses_configured_port_for_max() { + let addr = synthetic_local_addr(u16::MAX); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), u16::MAX); +} + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn relay_task_abort_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "abort-user"; + let peer_addr: SocketAddr = "198.51.100.230:50000".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime, + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "task abort must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "task abort must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn relay_cutover_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "cutover-user"; + let peer_addr: SocketAddr = "198.51.100.231:50001".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime.clone(), + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("relay must terminate after cutover") + .expect("relay task must not panic"); + assert!(relay_result.is_err(), "cutover must terminate direct relay session"); + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "cutover exit must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "cutover exit must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + #[tokio::test] async fn short_tls_probe_is_masked_through_client_pipeline() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index 1e2d673..7390fb8 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -178,3 +178,112 @@ async fn direct_relay_abort_midflight_releases_route_gauge() { tg_accept_task.abort(); let _ = tg_accept_task.await; } + +#[tokio::test] +async fn direct_relay_cutover_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50002".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xface_cafe, + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("direct relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("direct relay must terminate after cutover") + .expect("direct relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate direct relay session" + ); + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay exits on cutover" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index 509ba95..f88b5a0 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -942,3 +942,80 @@ async fn middle_relay_abort_midflight_releases_route_gauge() { drop(client_side); } + +#[tokio::test] +async fn middle_relay_cutover_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50003".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xfeed_beef, + )); + + tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("middle relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) + .await + .expect("middle relay must terminate after cutover") + .expect("middle relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate middle relay session" + ); + + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay exits on cutover" + ); + + drop(client_side); +} diff --git a/src/stats/connection_lease_security_tests.rs b/src/stats/connection_lease_security_tests.rs index 2d942c2..69ae89a 100644 --- a/src/stats/connection_lease_security_tests.rs +++ b/src/stats/connection_lease_security_tests.rs @@ -2,6 +2,7 @@ use super::*; use std::panic::{self, AssertUnwindSafe}; use std::sync::Arc; use std::time::Duration; +use tokio::sync::Barrier; #[test] fn direct_connection_lease_balances_on_drop() { @@ -63,6 +64,156 @@ fn direct_connection_lease_balances_on_panic_unwind() { ); } +#[test] +fn middle_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_me_connection_lease(); + panic!("intentional panic to verify middle lease drop path"); + })); + + assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert_eq!( + stats.get_current_connections_me(), + 0, + "panic unwind must release middle route gauge" + ); +} + +#[tokio::test] +async fn concurrent_mixed_route_lease_churn_balances_to_zero() { + const TASKS: usize = 48; + const ITERATIONS_PER_TASK: usize = 256; + + let stats = Arc::new(Stats::new()); + let barrier = Arc::new(Barrier::new(TASKS)); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + let barrier_for_task = barrier.clone(); + workers.push(tokio::spawn(async move { + barrier_for_task.wait().await; + for iter in 0..ITERATIONS_PER_TASK { + if (task_idx + iter) % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::task::yield_now().await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::task::yield_now().await; + } + } + })); + } + + for worker in workers { + worker + .await + .expect("lease churn worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after concurrent lease churn" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after concurrent lease churn" + ); +} + +#[tokio::test] +async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() { + const TASKS: usize = 64; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + workers.push(tokio::spawn(async move { + if task_idx % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } + })); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + let total = stats.get_current_connections_direct() + stats.get_current_connections_me(); + if total == TASKS as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all storm tasks must acquire route leases before abort"); + + for worker in &workers { + worker.abort(); + } + for worker in workers { + let joined = worker.await; + assert!(joined.is_err(), "aborted worker must return join error"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all route gauges must drain to zero after abort storm"); +} + +#[test] +fn saturating_route_decrements_do_not_underflow_under_race() { + const THREADS: usize = 16; + const DECREMENTS_PER_THREAD: usize = 4096; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(THREADS); + + for _ in 0..THREADS { + let stats_for_thread = stats.clone(); + workers.push(std::thread::spawn(move || { + for _ in 0..DECREMENTS_PER_THREAD { + stats_for_thread.decrement_current_connections_direct(); + stats_for_thread.decrement_current_connections_me(); + } + })); + } + + for worker in workers { + worker + .join() + .expect("decrement race worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route decrement races must never underflow" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route decrement races must never underflow" + ); +} + #[tokio::test] async fn direct_connection_lease_balances_on_task_abort() { let stats = Arc::new(Stats::new()); From d81140ccec601d0d65e2bfa9755d06cd099e20a4 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 19:39:29 +0400 Subject: [PATCH 08/13] Enhance UserConnectionReservation management: add active state and release method, improve cleanup on drop, and implement tests for immediate release and concurrent handling --- src/proxy/client.rs | 24 +++- src/proxy/client_security_tests.rs | 215 +++++++++++++++++++++++++++++ 2 files changed, 238 insertions(+), 1 deletion(-) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 254d922..f80f74d 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -29,6 +29,7 @@ struct UserConnectionReservation { ip_tracker: Arc, user: String, ip: IpAddr, + active: bool, } impl UserConnectionReservation { @@ -38,12 +39,26 @@ impl UserConnectionReservation { ip_tracker, user, ip, + active: true, } } + + async fn release(mut self) { + if !self.active { + return; + } + self.active = false; + self.stats.decrement_user_curr_connects(&self.user); + self.ip_tracker.remove_ip(&self.user, self.ip).await; + } } impl Drop for UserConnectionReservation { fn drop(&mut self) { + if !self.active { + return; + } + self.active = false; self.stats.decrement_user_curr_connects(&self.user); if let Ok(handle) = tokio::runtime::Handle::try_current() { @@ -53,6 +68,12 @@ impl Drop for UserConnectionReservation { handle.spawn(async move { ip_tracker.remove_ip(&user, ip).await; }); + } else { + warn!( + user = %self.user, + ip = %self.ip, + "UserConnectionReservation dropped without Tokio runtime; IP reservation cleanup skipped" + ); } } } @@ -833,7 +854,7 @@ impl RunningClientHandler { { let user = success.user.clone(); - let _user_limit_reservation = + let user_limit_reservation = match Self::acquire_user_connection_reservation_static( &user, &config, @@ -905,6 +926,7 @@ impl RunningClientHandler { ) .await }; + user_limit_reservation.release().await; relay_result } diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index defb3c0..7047987 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -1420,6 +1420,221 @@ async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { ); } +#[tokio::test] +async fn explicit_reservation_release_cleans_user_and_ip_immediately() { + let user = "release-user"; + let peer_addr: SocketAddr = "198.51.100.240:50002".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + reservation.release().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "explicit release must synchronously free user connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "explicit release must synchronously remove reserved user IP" + ); +} + +#[tokio::test] +async fn explicit_reservation_release_does_not_double_decrement_on_drop() { + let user = "release-once-user"; + let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + .expect("reservation acquisition must succeed"); + + reservation.release().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "release must disarm drop and prevent double decrement" + ); +} + +#[tokio::test] +async fn drop_fallback_eventually_cleans_user_and_ip_reservation() { + let user = "drop-fallback-user"; + let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + drop(reservation); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("drop fallback must eventually clean both user slot and active IP"); +} + +#[tokio::test] +async fn explicit_release_allows_immediate_cross_ip_reacquire_under_limit() { + let user = "cross-ip-user"; + let peer1: SocketAddr = "198.51.100.243:50005".parse().unwrap(); + let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer1, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + first.release().await; + + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer2, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed immediately after explicit release"); + second.release().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() { + const RESERVATIONS: usize = 64; + + let user = "release-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), RESERVATIONS + 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut reservations = Vec::with_capacity(RESERVATIONS); + for idx in 0..RESERVATIONS { + let ip = std::net::Ipv4Addr::new(203, 0, 113, (idx + 1) as u8); + let peer = SocketAddr::new(IpAddr::V4(ip), 51000 + idx as u16); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition in storm must succeed"); + reservations.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), RESERVATIONS as u64); + assert_eq!(ip_tracker.get_active_ip_count(user).await, RESERVATIONS); + + let mut tasks = tokio::task::JoinSet::new(); + for reservation in reservations { + tasks.spawn(async move { + reservation.release().await; + }); + } + + while let Some(result) = tasks.join_next().await { + result.expect("release task must not panic"); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "release storm must drain user current-connection counter to zero" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "release storm must clear all active IP entries" + ); +} + #[tokio::test] async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { let mut config = ProxyConfig::default(); From 4e3f42dce3ba6753305eba70b482edec5dbd87c1 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 19:55:55 +0400 Subject: [PATCH 09/13] Add must_use attribute to UserConnectionReservation and RouteConnectionLease structs for better resource management --- src/proxy/client.rs | 1 + src/proxy/client_security_tests.rs | 217 +++++++++++++++++++++++++++++ src/stats/mod.rs | 1 + 3 files changed, 219 insertions(+) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index f80f74d..0077f72 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -24,6 +24,7 @@ enum HandshakeOutcome { Handled, } +#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"] struct UserConnectionReservation { stats: Arc, ip_tracker: Arc, diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 7047987..8058c38 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -1635,6 +1635,223 @@ async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() { ); } +#[tokio::test] +async fn relay_connect_error_releases_user_and_ip_before_return() { + let user = "relay-error-user"; + let peer_addr: SocketAddr = "198.51.100.245:50007".parse().unwrap(); + + let dead_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dead_port = dead_listener.local_addr().unwrap().port(); + drop(dead_listener); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + config + .dc_overrides + .insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, _client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let result = RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime, + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + ) + .await; + + assert!(result.is_err(), "relay must fail when upstream DC is unreachable"); + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "error return must release user slot before returning" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "error return must release user IP reservation before returning" + ); +} + +#[tokio::test] +async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() { + let user = "same-ip-mixed-user"; + let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let reservation_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 2); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + reservation_a.release().await; + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "explicit release must decrement only one active reservation" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "same IP must remain active while second reservation exists" + ); + + drop(reservation_b); + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("drop fallback must clear final same-IP reservation"); +} + +#[tokio::test] +async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() { + let user = "same-ip-drop-one-user"; + let peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let reservation_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed"); + + drop(reservation_a); + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("dropping one reservation must keep same-IP activity for remaining reservation"); + + reservation_b.release().await; + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("final release must converge to zero footprint after async fallback cleanup"); +} + #[tokio::test] async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { let mut config = ProxyConfig::default(); diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 36241af..3ad361f 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -26,6 +26,7 @@ enum RouteConnectionGauge { Middle, } +#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"] pub struct RouteConnectionLease { stats: Arc, gauge: RouteConnectionGauge, From 0284b9f9e3a4754d1784cebb0aae8a645474d9c1 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 20:14:07 +0400 Subject: [PATCH 10/13] Refactor health integration tests to use wait_for_pool_empty for improved readability and timeout handling --- .../middle_proxy/health_integration_tests.rs | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/transport/middle_proxy/health_integration_tests.rs b/src/transport/middle_proxy/health_integration_tests.rs index 70b6411..476b549 100644 --- a/src/transport/middle_proxy/health_integration_tests.rs +++ b/src/transport/middle_proxy/health_integration_tests.rs @@ -161,6 +161,20 @@ async fn insert_draining_writer( } } +async fn wait_for_pool_empty(pool: &Arc, timeout: Duration) { + let start = Instant::now(); + loop { + if pool.writers.read().await.is_empty() { + return; + } + assert!( + start.elapsed() < timeout, + "timed out waiting for pool.writers to become empty" + ); + tokio::time::sleep(Duration::from_millis(5)).await; + } +} + #[tokio::test] async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() { let (pool, rng) = make_pool(128, 1, 1).await; @@ -178,7 +192,7 @@ async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() { } let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); - tokio::time::sleep(Duration::from_millis(60)).await; + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; monitor.abort(); let _ = monitor.await; @@ -194,7 +208,7 @@ async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() { } let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); - tokio::time::sleep(Duration::from_millis(30)).await; + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; monitor.abort(); let _ = monitor.await; @@ -219,7 +233,7 @@ async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() { } let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); - tokio::time::sleep(Duration::from_millis(60)).await; + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; monitor.abort(); let _ = monitor.await; From 2c06288b40c9ea733c6d0ac342d7ccd7f7c98ff9 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 20:21:01 +0400 Subject: [PATCH 11/13] Enhance UserConnectionReservation: add runtime handle for cross-thread IP cleanup and implement tests for user expiration and connection limits --- src/proxy/client.rs | 13 +- src/proxy/client_security_tests.rs | 221 +++++++++++++++++++++++++++++ 2 files changed, 233 insertions(+), 1 deletion(-) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 0077f72..849e409 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -31,16 +31,19 @@ struct UserConnectionReservation { user: String, ip: IpAddr, active: bool, + runtime_handle: Option, } impl UserConnectionReservation { fn new(stats: Arc, ip_tracker: Arc, user: String, ip: IpAddr) -> Self { + let runtime_handle = tokio::runtime::Handle::try_current().ok(); Self { stats, ip_tracker, user, ip, active: true, + runtime_handle, } } @@ -62,7 +65,15 @@ impl Drop for UserConnectionReservation { self.active = false; self.stats.decrement_user_curr_connects(&self.user); - if let Ok(handle) = tokio::runtime::Handle::try_current() { + if let Some(handle) = &self.runtime_handle { + let ip_tracker = self.ip_tracker.clone(); + let user = self.user.clone(); + let ip = self.ip; + let handle = handle.clone(); + handle.spawn(async move { + ip_tracker.remove_ip(&user, ip).await; + }); + } else if let Ok(handle) = tokio::runtime::Handle::try_current() { let ip_tracker = self.ip_tracker.clone(); let user = self.user.clone(); let ip = self.ip; diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 8058c38..8bdb234 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -1888,6 +1888,227 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { ); } +#[tokio::test] +async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() { + let mut config = ProxyConfig::default(); + config + .access + .user_expirations + .insert("user".to_string(), chrono::Utc::now() - chrono::Duration::seconds(1)); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "203.0.113.212:50002".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::UserExpired { user }) if user == "user" + )); + assert_eq!(stats.get_user_curr_connects("user"), 0); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn same_ip_second_reservation_succeeds_under_unique_ip_limit_one() { + let user = "same-ip-unique-limit-user"; + let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation from same IP must succeed under unique-ip limit=1"); + + assert_eq!(stats.get_user_curr_connects(user), 2); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + first.release().await; + second.release().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn second_distinct_ip_is_rejected_under_unique_ip_limit_one() { + let user = "distinct-ip-unique-limit-user"; + let peer1: SocketAddr = "198.51.100.249:50011".parse().unwrap(); + let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer1, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer2, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + second, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "distinct-ip-unique-limit-user" + )); + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + first.release().await; +} + +#[tokio::test] +async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() { + let user = "cross-thread-drop-user"; + let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop thread must not panic"); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cross-thread drop must still converge to zero user and IP footprint"); +} + +#[tokio::test] +async fn immediate_reacquire_after_cross_thread_drop_succeeds() { + let user = "cross-thread-reacquire-user"; + let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("initial reservation must succeed"); + + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop thread must not panic"); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cross-thread cleanup must settle before reacquire check"); + + let reacquire = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer_addr, + ip_tracker, + ) + .await; + assert!( + reacquire.is_ok(), + "reacquire must succeed after cross-thread drop cleanup" + ); +} + #[tokio::test] async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() { const PARALLEL_IPS: usize = 64; From 60953bcc2c5d940b739b50ec5d9c02d73027b203 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 20:53:37 +0400 Subject: [PATCH 12/13] Refactor user connection limit checks and enhance health monitoring tests: update warning messages, add new tests for draining writers, and improve state management --- src/proxy/client.rs | 6 +- src/proxy/client_security_tests.rs | 574 ++++++++++++++++++ src/transport/middle_proxy/health.rs | 17 +- .../middle_proxy/health_adversarial_tests.rs | 178 ++++++ .../middle_proxy/health_regression_tests.rs | 132 ++++ 5 files changed, 899 insertions(+), 8 deletions(-) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 849e409..5d32e34 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -878,7 +878,7 @@ impl RunningClientHandler { { Ok(reservation) => reservation, Err(e) => { - warn!(user = %user, error = %e, "User limit exceeded"); + warn!(user = %user, error = %e, "User admission check failed"); return Err(e); } }; @@ -998,8 +998,8 @@ impl RunningClientHandler { #[cfg(test)] async fn check_user_limits_static( - user: &str, - config: &ProxyConfig, + user: &str, + config: &ProxyConfig, stats: &Stats, peer_addr: SocketAddr, ip_tracker: &UserIpTracker, diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 8bdb234..6ca2d4b 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -1420,6 +1420,105 @@ async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { ); } +#[tokio::test] +async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 0); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!(stats.get_user_curr_connects("user"), 0); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak() { + let user = "rollback-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 128); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let keeper_peer: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + let keeper = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + keeper_peer, + ip_tracker.clone(), + ) + .await + .expect("keeper reservation must succeed"); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..64u8 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 101, i.saturating_add(1))), + 41000 + i as u16, + ); + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "rollback-storm-user" + )); + }); + } + + while let Some(joined) = tasks.join_next().await { + joined.unwrap(); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "failed distinct-IP attempts must rollback acquired user slots" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "failed distinct-IP attempts must not leave extra active IPs" + ); + + keeper.release().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + #[tokio::test] async fn explicit_reservation_release_cleans_user_and_ip_immediately() { let user = "release-user"; @@ -2990,3 +3089,478 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { "Valid max-length ClientHello must not increment bad counter" ); } + +fn lcg_next(state: &mut u64) -> u64 { + *state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + *state +} + +async fn wait_for_user_and_ip_zero( + stats: &Arc, + ip_tracker: &Arc, + user: &str, +) { + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cleanup must converge to zero user and IP footprint"); +} + +async fn burst_acquire_distinct_ips( + user: &'static str, + config: Arc, + stats: Arc, + ip_tracker: Arc, + third_octet: u8, + attempts: u16, +) -> (Vec, usize) { + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..attempts { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let host = (i as u8).saturating_add(1); + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third_octet, host)), + 55000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut successes = Vec::new(); + let mut failures = 0usize; + while let Some(joined) = tasks.join_next().await { + match joined.expect("burst acquire task must not panic") { + Ok(reservation) => successes.push(reservation), + Err(err) => { + assert!(matches!( + err, + ProxyError::ConnectionLimitExceeded { user: ref denied_user } + if denied_user == user + )); + failures = failures.saturating_add(1); + } + } + } + + (successes, failures) +} + +#[tokio::test] +async fn deterministic_mixed_reservation_churn_preserves_counter_and_eventual_cleanup() { + let user = "deterministic-churn-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 12); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; + + let mut seed = 0xD1F2_A4C8_991B_77E1u64; + let mut reservations: Vec> = Vec::new(); + + for step in 0..220u64 { + let op = (lcg_next(&mut seed) % 100) as u8; + let active = reservations.iter().filter(|entry| entry.is_some()).count(); + + if active == 0 || op < 55 { + let ip_octet = (lcg_next(&mut seed) % 16 + 1) as u8; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 120, ip_octet)), + 52000 + (step % 4000) as u16, + ); + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + if let Ok(reservation) = result { + reservations.push(Some(reservation)); + } else { + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "deterministic-churn-user" + )); + } + } else { + let selected = reservations + .iter() + .enumerate() + .filter(|(_, entry)| entry.is_some()) + .map(|(idx, _)| idx) + .nth((lcg_next(&mut seed) as usize) % active) + .unwrap(); + + let reservation = reservations[selected].take().unwrap(); + if op < 80 { + reservation.release().await; + } else { + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("cross-thread drop must not panic"); + } + } + + let live_slots = reservations.iter().filter(|entry| entry.is_some()).count() as u64; + assert_eq!( + stats.get_user_curr_connects(user), + live_slots, + "current-connects counter must match number of live reservations" + ); + assert!( + stats.get_user_curr_connects(user) <= 12, + "current-connects must stay within configured TCP limit" + ); + assert!( + ip_tracker.get_active_ip_count(user).await <= 4, + "active unique IPs must stay within configured per-user IP limit" + ); + } + + for reservation in reservations.into_iter().flatten() { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() { + let user = "drop-storm-reacquire-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 64); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; + + let mut initial = Vec::new(); + for i in 0..32u16 { + let ip_octet = (i % 8 + 1) as u8; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 120, ip_octet)), + 53000 + i, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("initial reservation must succeed"); + initial.push(reservation); + } + + let mut second_half = initial.split_off(16); + + let mut releases = Vec::new(); + for reservation in initial { + releases.push(tokio::spawn(async move { + reservation.release().await; + })); + } + for release_task in releases { + release_task.await.expect("release task must not panic"); + } + + let mut drop_threads = Vec::new(); + for reservation in second_half.drain(..) { + drop_threads.push(std::thread::spawn(move || { + drop(reservation); + })); + } + for drop_thread in drop_threads { + drop_thread + .join() + .expect("cross-thread drop worker must not panic"); + } + + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; + + let mut reacquire_tasks = tokio::task::JoinSet::new(); + for i in 0..16u16 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + reacquire_tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 121, (i + 1) as u8)), + 54000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut acquired = Vec::new(); + while let Some(joined) = reacquire_tasks.join_next().await { + match joined.expect("reacquire task must not panic") { + Ok(reservation) => acquired.push(reservation), + Err(err) => { + assert!(matches!( + err, + ProxyError::ConnectionLimitExceeded { user } + if user == "drop-storm-reacquire-user" + )); + } + } + } + + assert!( + acquired.len() <= 8, + "parallel distinct-IP reacquire wave must not exceed per-user unique IP limit" + ); + for reservation in acquired { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() { + let user: &'static str = "scheduled-attack-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 6); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 2).await; + + let mut base = Vec::new(); + for i in 0..5u16 { + let peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), 56000 + i); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("near-limit warmup reservation must succeed"); + base.push(reservation); + } + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + let (wave1_success, wave1_fail) = burst_acquire_distinct_ips( + user, + config.clone(), + stats.clone(), + ip_tracker.clone(), + 131, + 32, + ) + .await; + assert_eq!(wave1_success.len(), 1); + assert_eq!(wave1_fail, 31); + assert_eq!(stats.get_user_curr_connects(user), 6); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 2); + + let released = base.pop().expect("must have releasable reservation"); + released.release().await; + for reservation in wave1_success { + reservation.release().await; + } + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 4 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("window cleanup must settle to expected occupancy"); + + let (wave2_success, wave2_fail) = burst_acquire_distinct_ips( + user, + config, + stats.clone(), + ip_tracker.clone(), + 132, + 32, + ) + .await; + assert_eq!(wave2_success.len(), 1); + assert_eq!(wave2_fail, 31); + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 2); + + let tail = base.split_off(2); + + let mut drop_threads = Vec::new(); + for reservation in base { + drop_threads.push(std::thread::spawn(move || { + drop(reservation); + })); + } + for drop_thread in drop_threads { + drop_thread + .join() + .expect("cross-thread scheduled cleanup must not panic"); + } + + for reservation in tail { + reservation.release().await; + } + for reservation in wave2_success { + reservation.release().await; + } + + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn scheduled_mode_switch_burst_churn_preserves_limits_and_cleanup() { + let user: &'static str = "scheduled-mode-switch-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 3).await; + + let base_peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 140, 1)), 57000); + let mut base = Vec::new(); + for i in 0..7u16 { + let peer = SocketAddr::new(base_peer.ip(), base_peer.port().saturating_add(i)); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("base occupancy reservation must succeed"); + base.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), 7); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for round in 0..8u8 { + let (wave_success, wave_fail) = burst_acquire_distinct_ips( + user, + config.clone(), + stats.clone(), + ip_tracker.clone(), + 141u8.saturating_add(round), + 24, + ) + .await; + + assert!( + wave_success.len() <= 2, + "burst must not exceed available unique-IP headroom under limit=3" + ); + assert_eq!(wave_success.len() + wave_fail, 24); + assert_eq!( + stats.get_user_curr_connects(user), + 7 + wave_success.len() as u64, + "slot counter must reflect base occupancy plus successful burst leases" + ); + assert!(ip_tracker.get_active_ip_count(user).await <= 3); + + if round % 2 == 0 { + for reservation in wave_success { + reservation.release().await; + } + let rotated = base.pop().expect("base rotation reservation must exist"); + rotated.release().await; + } else { + for reservation in wave_success { + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop-heavy burst cleanup thread must not panic"); + } + let rotated = base.pop().expect("base rotation reservation must exist"); + std::thread::spawn(move || { + drop(rotated); + }) + .join() + .expect("drop-heavy base cleanup thread must not panic"); + } + + let replacement = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + base_peer, + ip_tracker.clone(), + ) + .await + .expect("base replacement reservation must succeed after each round"); + base.push(replacement); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 7 + && ip_tracker.get_active_ip_count(user).await <= 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("round cleanup must converge to steady base occupancy"); + } + + for reservation in base { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index edc9598..1c2c648 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -186,9 +186,7 @@ pub(super) async fn reap_draining_writers( } } - let mut active_draining_writer_ids = HashSet::with_capacity(draining_writers.len()); for writer in draining_writers { - active_draining_writer_ids.insert(writer.id); if drain_ttl_secs > 0 && writer.draining_started_at_epoch_secs != 0 && now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs @@ -214,12 +212,9 @@ pub(super) async fn reap_draining_writers( { warn!(writer_id = writer.id, "Drain timeout, force-closing"); force_close_writer_ids.push(writer.id); - active_draining_writer_ids.remove(&writer.id); } } - warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id)); - let close_budget = health_drain_close_budget(); let requested_force_close = force_close_writer_ids.len(); let requested_empty_close = empty_writer_ids.len(); @@ -257,6 +252,18 @@ pub(super) async fn reap_draining_writers( "ME draining close backlog deferred to next health cycle" ); } + + // Keep warn cooldown state for draining writers still present in the pool; + // drop state only once a writer is actually removed. + let active_draining_writer_ids = { + let writers = pool.writers.read().await; + writers + .iter() + .filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed)) + .map(|writer| writer.id) + .collect::>() + }; + warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id)); } pub(super) fn health_drain_close_budget() -> usize { diff --git a/src/transport/middle_proxy/health_adversarial_tests.rs b/src/transport/middle_proxy/health_adversarial_tests.rs index 675005a..cd06fdf 100644 --- a/src/transport/middle_proxy/health_adversarial_tests.rs +++ b/src/transport/middle_proxy/health_adversarial_tests.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; @@ -181,6 +182,40 @@ async fn sorted_writer_ids(pool: &Arc) -> Vec { ids } +fn lcg_next(state: &mut u64) -> u64 { + *state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + *state +} + +async fn draining_writer_ids(pool: &Arc) -> HashSet { + pool.writers + .read() + .await + .iter() + .filter(|writer| writer.draining.load(Ordering::Relaxed)) + .map(|writer| writer.id) + .collect::>() +} + +async fn set_writer_runtime_state( + pool: &Arc, + writer_id: u64, + draining: bool, + drain_started_at_epoch_secs: u64, + drain_deadline_epoch_secs: u64, +) { + let writers = pool.writers.read().await; + if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) { + writer.draining.store(draining, Ordering::Relaxed); + writer + .draining_started_at_epoch_secs + .store(drain_started_at_epoch_secs, Ordering::Relaxed); + writer + .drain_deadline_epoch_secs + .store(drain_deadline_epoch_secs, Ordering::Relaxed); + } +} + #[tokio::test] async fn reap_draining_writers_clears_warn_state_when_pool_empty() { let (pool, _rng) = make_pool(128, 1, 1).await; @@ -430,6 +465,149 @@ async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() { assert!(writer_count(&pool).await <= threshold as usize); } +#[tokio::test] +async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() { + let threshold = 9u64; + let (pool, _rng) = make_pool(threshold, 1, 1).await; + let mut warn_next_allowed = HashMap::new(); + let mut seed = 0x9E37_79B9_7F4A_7C15u64; + let mut next_writer_id = 20_000u64; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=72u64 { + let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 }; + let deadline = if writer_id % 5 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(500).saturating_add(writer_id), + bound_clients, + deadline, + ) + .await; + } + + for _round in 0..90 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let draining_ids = draining_writer_ids(&pool).await; + assert!( + warn_next_allowed.keys().all(|id| draining_ids.contains(id)), + "warn-state keys must always be a subset of live draining writers" + ); + + let writer_ids = sorted_writer_ids(&pool).await; + if writer_ids.is_empty() { + continue; + } + + let remove_n = (lcg_next(&mut seed) % 3) as usize; + for writer_id in writer_ids.iter().copied().take(remove_n) { + let _ = pool.remove_writer_and_close_clients(writer_id).await; + } + + let survivors = sorted_writer_ids(&pool).await; + if !survivors.is_empty() { + let idx = (lcg_next(&mut seed) as usize) % survivors.len(); + let target = survivors[idx]; + set_writer_runtime_state(&pool, target, false, 0, 0).await; + } + + let survivors = sorted_writer_ids(&pool).await; + if survivors.len() > 1 { + let idx = (lcg_next(&mut seed) as usize) % survivors.len(); + let target = survivors[idx]; + let expired_deadline = if lcg_next(&mut seed) & 1 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + set_writer_runtime_state( + &pool, + target, + true, + now_epoch_secs.saturating_sub(120), + expired_deadline, + ) + .await; + } + + let inject_n = (lcg_next(&mut seed) % 4) as usize; + for _ in 0..inject_n { + let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 }; + let deadline = if lcg_next(&mut seed) & 1 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + next_writer_id, + now_epoch_secs.saturating_sub(240), + bound_clients, + deadline, + ) + .await; + next_writer_id = next_writer_id.saturating_add(1); + } + } + + for _ in 0..64 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if writer_count(&pool).await <= threshold as usize { + break; + } + } + + assert!(writer_count(&pool).await <= threshold as usize); + let draining_ids = draining_writer_ids(&pool).await; + assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id))); +} + +#[tokio::test] +async fn reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() { + let (pool, _rng) = make_pool(64, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=24u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(240), + 1, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _round in 0..48u64 { + for writer_id in 1..=24u64 { + let draining = (writer_id + _round) % 3 != 0; + set_writer_runtime_state( + &pool, + writer_id, + draining, + now_epoch_secs.saturating_sub(120), + 0, + ) + .await; + } + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let draining_ids = draining_writer_ids(&pool).await; + assert!( + warn_next_allowed.keys().all(|id| draining_ids.contains(id)), + "warn-state map must not retain entries for writers outside draining set" + ); + } +} + #[test] fn health_drain_close_budget_is_within_expected_bounds() { let budget = health_drain_close_budget(); diff --git a/src/transport/middle_proxy/health_regression_tests.rs b/src/transport/middle_proxy/health_regression_tests.rs index 05a8e6a..fe73670 100644 --- a/src/transport/middle_proxy/health_regression_tests.rs +++ b/src/transport/middle_proxy/health_regression_tests.rs @@ -168,6 +168,21 @@ async fn current_writer_ids(pool: &Arc) -> Vec { writer_ids } +async fn writer_exists(pool: &Arc, writer_id: u64) -> bool { + pool.writers + .read() + .await + .iter() + .any(|writer| writer.id == writer_id) +} + +async fn set_writer_draining(pool: &Arc, writer_id: u64, draining: bool) { + let writers = pool.writers.read().await; + if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) { + writer.draining.store(draining, Ordering::Relaxed); + } +} + #[tokio::test] async fn reap_draining_writers_drops_warn_state_for_removed_writer() { let pool = make_pool(128).await; @@ -257,6 +272,123 @@ async fn reap_draining_writers_limits_closes_per_health_tick() { assert_eq!(pool.writers.read().await.len(), writer_total - close_budget); } +#[tokio::test] +async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() { + let pool = make_pool(0).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_add(5); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(60), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let target_writer_id = writer_total as u64; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert( + target_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, target_writer_id).await); + assert!(warn_next_allowed.contains_key(&target_writer_id)); +} + +#[tokio::test] +async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() { + let pool = make_pool(1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_add(6); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + let target_writer_id = writer_total.saturating_sub(1) as u64; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert( + target_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, target_writer_id).await); + assert!(warn_next_allowed.contains_key(&target_writer_id)); +} + +#[tokio::test] +async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await; + + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300)); + + set_writer_draining(&pool, 71, false).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, 71).await); + assert!( + !warn_next_allowed.contains_key(&71), + "warn cooldown state must be dropped after writer leaves draining state" + ); +} + +#[tokio::test] +async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() { + let pool = make_pool(0).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_mul(2).saturating_add(1); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let tail_writer_id = writer_total as u64; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert( + tail_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(writer_exists(&pool, tail_writer_id).await); + assert!(warn_next_allowed.contains_key(&tail_writer_id)); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(writer_exists(&pool, tail_writer_id).await); + assert!(warn_next_allowed.contains_key(&tail_writer_id)); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(!writer_exists(&pool, tail_writer_id).await); + assert!( + !warn_next_allowed.contains_key(&tail_writer_id), + "warn cooldown state must clear once writer is actually removed" + ); +} + #[tokio::test] async fn reap_draining_writers_backlog_drains_across_ticks() { let pool = make_pool(128).await; From f0c37f233e59520c9492f6712d1b4d84333e658c Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 21:38:15 +0400 Subject: [PATCH 13/13] Refactor health management: implement remove_writer_if_empty method for cleaner writer removal logic and update related functions to enhance efficiency in handling closed writers. --- src/proxy/handshake.rs | 47 +- src/proxy/handshake_security_tests.rs | 142 +++++ src/proxy/masking.rs | 28 +- src/proxy/masking_security_tests.rs | 523 +++++++++++++++++- src/transport/middle_proxy/health.rs | 4 +- .../middle_proxy/health_regression_tests.rs | 64 +++ src/transport/middle_proxy/pool_writer.rs | 18 +- src/transport/middle_proxy/registry.rs | 17 + 8 files changed, 822 insertions(+), 21 deletions(-) diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index dbd50d5..a1b3eb7 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -317,6 +317,24 @@ fn decode_user_secrets( secrets } +async fn maybe_apply_server_hello_delay(config: &ProxyConfig) { + if config.censorship.server_hello_delay_max_ms == 0 { + return; + } + + let min = config.censorship.server_hello_delay_min_ms; + let max = config.censorship.server_hello_delay_max_ms.max(min); + let delay_ms = if max == min { + max + } else { + rand::rng().random_range(min..=max) + }; + + if delay_ms > 0 { + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } +} + /// Result of successful handshake /// /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is @@ -368,11 +386,13 @@ where debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); if auth_probe_is_throttled(peer.ip(), Instant::now()) { + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; } if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; } @@ -388,6 +408,7 @@ where Some(v) => v, None => { auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!( peer = %peer, ignore_time_skew = config.access.ignore_time_skew, @@ -402,13 +423,17 @@ where let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; if replay_checker.check_and_add_tls_digest(digest_half) { auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); return HandshakeResult::BadClient { reader, writer }; } let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, - None => return HandshakeResult::BadClient { reader, writer }, + None => { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } }; let cached = if config.censorship.tls_emulation { @@ -448,6 +473,7 @@ where } else if alpn_list.iter().any(|p| p == b"http/1.1") { Some(b"http/1.1".to_vec()) } else if !alpn_list.is_empty() { + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); return HandshakeResult::BadClient { reader, writer }; } else { @@ -480,19 +506,9 @@ where ) }; - // Optional anti-fingerprint delay before sending ServerHello. - if config.censorship.server_hello_delay_max_ms > 0 { - let min = config.censorship.server_hello_delay_min_ms; - let max = config.censorship.server_hello_delay_max_ms.max(min); - let delay_ms = if max == min { - max - } else { - rand::rng().random_range(min..=max) - }; - if delay_ms > 0 { - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - } - } + // Apply the same optional delay budget used by reject paths to reduce + // distinguishability between success and fail-closed handshakes. + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); @@ -539,6 +555,7 @@ where trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); if auth_probe_is_throttled(peer.ip(), Instant::now()) { + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; } @@ -609,6 +626,7 @@ where // authentication check first to avoid poisoning the replay cache. if replay_checker.check_and_add_handshake(dec_prekey_iv) { auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; warn!(peer = %peer, user = %user, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; } @@ -645,6 +663,7 @@ where } auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "MTProto handshake: no matching user found"); HandshakeResult::BadClient { reader, writer } } diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 6bdc345..7040025 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -580,6 +580,72 @@ async fn malformed_tls_classes_complete_within_bounded_time() { } } +#[tokio::test] +async fn tls_invalid_hmac_respects_configured_anti_fingerprint_delay() { + let secret = [0x5Au8; 16]; + let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.32:44331".parse().unwrap(); + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01; + + let started = Instant::now(); + let result = handle_tls_handshake( + &bad_hmac, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert!( + started.elapsed() >= Duration::from_millis(18), + "configured anti-fingerprint delay must apply to invalid TLS handshakes" + ); +} + +#[tokio::test] +async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() { + let secret = [0x6Bu8; 16]; + let mut config = test_config_with_secret_hex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"); + config.censorship.alpn_enforce = true; + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.33:44332".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let started = Instant::now(); + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert!( + started.elapsed() >= Duration::from_millis(18), + "configured anti-fingerprint delay must apply to ALPN-mismatch rejects" + ); +} + #[tokio::test] #[ignore = "timing-sensitive; run manually on low-jitter hosts"] async fn malformed_tls_classes_share_close_latency_buckets() { @@ -643,6 +709,82 @@ async fn malformed_tls_classes_share_close_latency_buckets() { ); } +#[tokio::test] +#[ignore = "timing matrix; run manually with --ignored --nocapture"] +async fn timing_matrix_tls_classes_under_fixed_delay_budget() { + const ITER: usize = 48; + const BUCKET_MS: u128 = 10; + + let secret = [0x77u8; 16]; + let mut config = test_config_with_secret_hex("77777777777777777777777777777777"); + config.censorship.alpn_enforce = true; + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let rng = SecureRandom::new(); + let base_ip = std::net::Ipv4Addr::new(198, 51, 100, 34); + + let too_short = vec![0x16, 0x03, 0x01]; + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01; + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let valid_h2 = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2"]); + + let classes = vec![ + ("too_short", too_short), + ("bad_hmac", bad_hmac), + ("alpn_mismatch", alpn_mismatch), + ("valid_h2", valid_h2), + ]; + + for (class, probe) in classes { + let mut samples_ms = Vec::with_capacity(ITER); + for idx in 0..ITER { + clear_auth_probe_state_for_testing(); + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let peer: SocketAddr = SocketAddr::from((base_ip, 44_000 + idx as u16)); + let started = Instant::now(); + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = started.elapsed(); + samples_ms.push(elapsed.as_millis()); + + if class == "valid_h2" { + assert!(matches!(result, HandshakeResult::Success(_))); + } else { + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + } + + samples_ms.sort_unstable(); + let sum: u128 = samples_ms.iter().copied().sum(); + let mean = sum as f64 / samples_ms.len() as f64; + let min = samples_ms[0]; + let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize; + let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)]; + let max = samples_ms[samples_ms.len() - 1]; + + println!( + "TIMING_MATRIX tls class={} mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + class, + mean, + min, + p95, + max, + (mean as u128) / BUCKET_MS + ); + } +} + #[test] fn secure_tag_requires_tls_mode_on_tls_transport() { let mut config = ProxyConfig::default(); diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 9a23c5b..eb6f6da 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -7,7 +7,7 @@ use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; -use tokio::time::timeout; +use tokio::time::{Instant, timeout}; use tracing::debug; use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; @@ -49,6 +49,20 @@ where } } +async fn wait_mask_connect_budget(started: Instant) { + let elapsed = started.elapsed(); + if elapsed < MASK_TIMEOUT { + tokio::time::sleep(MASK_TIMEOUT - elapsed).await; + } +} + +async fn wait_mask_outcome_budget(started: Instant) { + let elapsed = started.elapsed(); + if elapsed < MASK_TIMEOUT { + tokio::time::sleep(MASK_TIMEOUT - elapsed).await; + } +} + /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request @@ -107,6 +121,8 @@ where // Connect via Unix socket or TCP #[cfg(unix)] if let Some(ref sock_path) = config.censorship.mask_unix_sock { + let outcome_started = Instant::now(); + let connect_started = Instant::now(); debug!( client_type = client_type, sock = %sock_path, @@ -143,14 +159,18 @@ where if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { debug!("Mask relay timed out (unix socket)"); } + wait_mask_outcome_budget(outcome_started).await; } Ok(Err(e)) => { + wait_mask_connect_budget(connect_started).await; debug!(error = %e, "Failed to connect to mask unix socket"); consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started).await; } } return; @@ -172,6 +192,8 @@ where let mask_addr = resolve_socket_addr(mask_host, mask_port) .map(|addr| addr.to_string()) .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); + let outcome_started = Instant::now(); + let connect_started = Instant::now(); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { Ok(Ok(stream)) => { @@ -202,14 +224,18 @@ where if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { debug!("Mask relay timed out"); } + wait_mask_outcome_budget(outcome_started).await; } Ok(Err(e)) => { + wait_mask_connect_budget(connect_started).await; debug!(error = %e, "Failed to connect to mask host"); consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started).await; } Err(_) => { debug!("Timeout connecting to mask host"); consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started).await; } } } diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 25b6a76..2310846 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -8,7 +8,7 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio::net::TcpListener; #[cfg(unix)] use tokio::net::UnixListener; -use tokio::time::{sleep, timeout, Duration}; +use tokio::time::{Instant, sleep, timeout, Duration}; #[tokio::test] async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { @@ -216,6 +216,372 @@ async fn backend_unavailable_falls_back_to_silent_consume() { assert_eq!(n, 0); } +#[tokio::test] +async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; + + // Keep reader open so fallback path does not terminate immediately on EOF. + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task) + .await + .expect_err("masking fallback must not complete before connect budget elapses"); + assert!( + started.elapsed() >= Duration::from_millis(35), + "fallback path must absorb immediate connect refusal into connect budget" + ); +} + +#[tokio::test] +async fn backend_reachable_fast_response_waits_mask_outcome_budget() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + assert!( + started.elapsed() >= Duration::from_millis(45), + "reachable mask path must also satisfy coarse outcome budget" + ); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"x", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + assert!( + started.elapsed() < Duration::from_millis(20), + "mask-disabled fallback should keep immediate EOF behavior" + ); +} + +#[tokio::test] +async fn backend_reachable_slow_response_not_padded_twice() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + sleep(Duration::from_millis(90)).await; + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + let elapsed = started.elapsed(); + + assert!(elapsed >= Duration::from_millis(85)); + assert!( + elapsed < Duration::from_millis(170), + "slow reachable backend should not incur an extra full budget after already exceeding it" + ); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() { + const ITER: usize = 20; + const BUCKET_MS: u128 = 10; + + let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n"; + let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let mut refused = Vec::with_capacity(ITER); + for _ in 0..ITER { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + refused.push(started.elapsed().as_millis()); + } + + let mut reachable = Vec::with_capacity(ITER); + for _ in 0..ITER { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe_vec = probe.to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut received).await.unwrap(); + stream.write_all(&backend_reply).await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + reachable.push(started.elapsed().as_millis()); + accept_task.await.unwrap(); + } + + let refused_mean = refused.iter().copied().sum::() as f64 / refused.len() as f64; + let reachable_mean = reachable.iter().copied().sum::() as f64 / reachable.len() as f64; + let refused_bucket = (refused_mean as u128) / BUCKET_MS; + let reachable_bucket = (reachable_mean as u128) / BUCKET_MS; + + assert!( + refused_bucket.abs_diff(reachable_bucket) <= 1, + "enabled refused and reachable paths must collapse into the same coarse latency bucket" + ); +} + +#[tokio::test] +async fn light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() { + let mut seed: u64 = 0xA5A5_5A5A_1337_4242; + let mut next = || { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + seed + }; + + let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + for _ in 0..40 { + let probe_len = (next() as usize % 96).saturating_add(8); + let mut probe = vec![0u8; probe_len]; + for byte in &mut probe { + *byte = next() as u8; + } + + let use_reachable = (next() & 1) == 0; + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(512); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + if use_reachable { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let probe_vec = probe.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut observed).await.unwrap(); + }); + + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + accept_task.await.unwrap(); + } else { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + } + + assert!( + started.elapsed() >= Duration::from_millis(45), + "mask-enabled fallback must preserve coarse timing budget under varied probe shapes" + ); + } +} + #[tokio::test] async fn mask_disabled_consumes_client_data_without_response() { let mut config = ProxyConfig::default(); @@ -729,3 +1095,158 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { assert!(mask_reader_dropped.load(Ordering::SeqCst)); assert!(mask_writer_dropped.load(Ordering::SeqCst)); } + +#[tokio::test] +#[ignore = "timing matrix; run manually with --ignored --nocapture"] +async fn timing_matrix_masking_classes_under_controlled_inputs() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 10; + + let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n"; + let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + // Class 1: masking disabled with immediate EOF (fast fail-closed consume path). + let mut disabled_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + disabled_samples.push(started.elapsed().as_millis()); + } + + // Class 2: masking enabled, backend connect refused. + let mut refused_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + refused_samples.push(started.elapsed().as_millis()); + } + + // Class 3: masking enabled, backend reachable and immediately responds. + let mut reachable_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + let probe_vec = probe.to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe_vec); + stream.write_all(&backend_reply).await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + reachable_samples.push(started.elapsed().as_millis()); + accept_task.await.unwrap(); + } + + fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) { + samples_ms.sort_unstable(); + let sum: u128 = samples_ms.iter().copied().sum(); + let mean = sum as f64 / samples_ms.len() as f64; + let min = samples_ms[0]; + let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize; + let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)]; + let max = samples_ms[samples_ms.len() - 1]; + (mean, min, p95, max) + } + + let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples); + let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples); + let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples); + + println!( + "TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + disabled_mean, + disabled_min, + disabled_p95, + disabled_max, + (disabled_mean as u128) / BUCKET_MS + ); + println!( + "TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + refused_mean, + refused_min, + refused_p95, + refused_max, + (refused_mean as u128) / BUCKET_MS + ); + println!( + "TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + reachable_mean, + reachable_min, + reachable_p95, + reachable_max, + (reachable_mean as u128) / BUCKET_MS + ); +} diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 1c2c648..a6b1031 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -239,7 +239,9 @@ pub(super) async fn reap_draining_writers( if !closed_writer_ids.insert(writer_id) { continue; } - pool.remove_writer_and_close_clients(writer_id).await; + if !pool.remove_writer_if_empty(writer_id).await { + continue; + } closed_total = closed_total.saturating_add(1); } diff --git a/src/transport/middle_proxy/health_regression_tests.rs b/src/transport/middle_proxy/health_regression_tests.rs index fe73670..6b6b12a 100644 --- a/src/transport/middle_proxy/health_regression_tests.rs +++ b/src/transport/middle_proxy/health_regression_tests.rs @@ -592,3 +592,67 @@ async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_stat fn general_config_default_drain_threshold_remains_enabled() { assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128); } + +#[tokio::test] +async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + let empty_writer_id = 700u64; + insert_draining_writer( + &pool, + empty_writer_id, + now_epoch_secs.saturating_sub(60), + 0, + 0, + ) + .await; + + let stale_empty_snapshot = vec![empty_writer_id]; + let (rebound_conn_id, _rx) = pool.registry.register().await; + assert!( + pool.registry + .bind_writer( + rebound_conn_id, + empty_writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await, + "writer should accept a new bind after stale empty snapshot" + ); + + for writer_id in stale_empty_snapshot { + assert!( + !pool.remove_writer_if_empty(writer_id).await, + "atomic empty cleanup must reject writers that gained bound clients" + ); + } + + assert!( + writer_exists(&pool, empty_writer_id).await, + "empty-path cleanup must not remove a writer that gained a bound client" + ); + assert_eq!( + pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id), + Some(empty_writer_id) + ); + + let _ = pool.registry.unregister(rebound_conn_id).await; +} + +#[tokio::test] +async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await; + + pool.prune_closed_writers().await; + + assert!(!writer_exists(&pool, 910).await); + assert!(pool.registry.get_writer(conn_ids[0]).await.is_none()); +} diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 7490a98..5b23d7f 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -42,11 +42,10 @@ impl MePool { } for writer_id in closed_writer_ids { - if self.registry.is_writer_empty(writer_id).await { - let _ = self.remove_writer_only(writer_id).await; - } else { - let _ = self.remove_writer_and_close_clients(writer_id).await; + if self.remove_writer_if_empty(writer_id).await { + continue; } + let _ = self.remove_writer_and_close_clients(writer_id).await; } } @@ -501,6 +500,17 @@ impl MePool { } } + pub(crate) async fn remove_writer_if_empty(self: &Arc, writer_id: u64) -> bool { + if !self.registry.unregister_writer_if_empty(writer_id).await { + return false; + } + + // The registry empty-check and unregister are atomic with respect to binds, + // so remove_writer_only cannot return active bound sessions here. + let _ = self.remove_writer_only(writer_id).await; + true + } + async fn remove_writer_only(self: &Arc, writer_id: u64) -> Vec { let mut close_tx: Option> = None; let mut removed_addr: Option = None; diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index cbe1d9a..ea968b5 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -437,6 +437,23 @@ impl ConnRegistry { .unwrap_or(true) } + pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool { + let mut inner = self.inner.write().await; + let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else { + // Writer is already absent from the registry. + return true; + }; + if !conn_ids.is_empty() { + return false; + } + + inner.writers.remove(&writer_id); + inner.last_meta_for_writer.remove(&writer_id); + inner.writer_idle_since_epoch_secs.remove(&writer_id); + inner.conns_for_writer.remove(&writer_id); + true + } + pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet { let inner = self.inner.read().await; let mut out = HashSet::::with_capacity(writer_ids.len());