From 1ff97186bcabec12586b35802d2174ca4d74ee19 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Thu, 19 Mar 2026 16:26:45 +0400 Subject: [PATCH] Refactor security tests and improve connection lease management - Removed ignored attributes from timing-sensitive tests in handshake_security_tests.rs. - Adjusted latency bucket assertions in malformed_tls_classes_share_close_latency_buckets. - Reduced iteration count in timing_matrix_tls_classes_under_fixed_delay_budget. - Enhanced assertions for timing class bounds in timing_matrix_tls_classes_under_fixed_delay_budget. - Updated successful_tls_handshake_clears_pre_auth_failure_streak to improve clarity and assertions. - Modified saturation tests to ensure invalid probes do not produce incorrect failure states. - Added new assertions to ensure proper behavior under saturation conditions in saturation_grace_progression tests. - Introduced connection lease management in stats/mod.rs to track direct and middle connections. - Added tests for connection lease security and replay checker security. - Improved relay bidirectional tests to ensure proper quota handling and statistics tracking. - Refactored adversarial tests to ensure concurrent operations do not exceed limits. --- Cargo.lock | 2 +- src/proxy/client.rs | 5 + src/proxy/client_limits_security_tests.rs | 228 +++++++ src/proxy/direct_relay.rs | 158 ++++- src/proxy/direct_relay_security_tests.rs | 45 +- src/proxy/handshake_security_tests.rs | 105 ++-- src/proxy/masking_security_tests.rs | 15 +- src/proxy/middle_relay.rs | 1 + src/proxy/relay.rs | 8 + src/proxy/relay_adversarial_tests.rs | 66 ++- src/proxy/relay_security_tests.rs | 686 ++++++---------------- src/stats/mod.rs | 54 ++ 12 files changed, 788 insertions(+), 585 deletions(-) create mode 100644 src/proxy/client_limits_security_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 7749ef5..fe73d3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2131,7 +2131,7 @@ dependencies = [ [[package]] name = "telemt" -version = "3.3.20" +version = "3.3.23" dependencies = [ "aes", "anyhow", diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 25e6cf9..68971d2 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -21,6 +21,11 @@ enum HandshakeOutcome { Handled, } +#[cfg(test)] +#[path = "client_limits_security_tests.rs"] +mod limits_security_tests; + + use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{HandshakeResult, ProxyError, Result}; diff --git a/src/proxy/client_limits_security_tests.rs b/src/proxy/client_limits_security_tests.rs new file mode 100644 index 0000000..050dd38 --- /dev/null +++ b/src/proxy/client_limits_security_tests.rs @@ -0,0 +1,228 @@ +use super::RunningClientHandler; +use crate::config::ProxyConfig; +use crate::error::ProxyError; +use crate::ip_tracker::UserIpTracker; +use crate::stats::Stats; +use std::sync::Arc; + +fn peer(addr: &str) -> std::net::SocketAddr { + addr.parse().expect("test socket addr must parse") +} + +#[tokio::test] +async fn limits_check_accepts_under_quota_and_limits() { + let user = "limits-ok-user"; + let config = ProxyConfig::default(); + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + + let result = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer("127.0.0.10:5000"), + &ip_tracker, + ) + .await; + + assert!(result.is_ok(), "healthy user must pass limit checks"); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + assert!( + ip_tracker + .is_ip_active(user, "127.0.0.10".parse().expect("ip must parse")) + .await, + "accepted check must reserve caller IP" + ); +} + +#[tokio::test] +async fn tcp_limit_rejection_rolls_back_ip_and_increments_counter() { + let user = "tcp-limit-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Stats::new(); + stats.increment_user_curr_connects(user); + let ip_tracker = UserIpTracker::new(); + + let result = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer("127.0.0.11:5001"), + &ip_tracker, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::ConnectionLimitExceeded { user: u }) if u == user), + "tcp limit overflow must fail with typed limit error" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "rejected tcp-limit check must rollback temporary IP reservation" + ); + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 1, + "tcp-limit rejection after temporary reservation must increment rollback counter" + ); +} + +#[tokio::test] +async fn quota_limit_rejection_rolls_back_ip_and_increments_counter() { + let user = "quota-limit-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1024); + + let stats = Stats::new(); + stats.add_user_octets_from(user, 1024); + let ip_tracker = UserIpTracker::new(); + + let result = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer("127.0.0.12:5002"), + &ip_tracker, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { user: u }) if u == user), + "quota overflow must fail with typed quota error" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "rejected quota check must rollback temporary IP reservation" + ); + assert_eq!( + stats.get_ip_reservation_rollback_quota_limit_total(), + 1, + "quota-limit rejection after temporary reservation must increment rollback counter" + ); +} + +#[tokio::test] +async fn ip_limit_rejection_does_not_increment_rollback_counters() { + let user = "ip-limit-user"; + let config = ProxyConfig::default(); + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + + ip_tracker.set_user_limit(user, 1).await; + ip_tracker + .check_and_add(user, "127.0.0.21".parse().expect("ip must parse")) + .await + .expect("precondition: first unique ip must fit"); + + let result = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer("127.0.0.22:5003"), + &ip_tracker, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::ConnectionLimitExceeded { user: u }) if u == user), + "ip gate rejection must surface typed connection limit error" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "failed ip-gate attempt must not mutate active ip footprint" + ); + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 0, + "early ip-gate rejection must not increment tcp rollback counter" + ); + assert_eq!( + stats.get_ip_reservation_rollback_quota_limit_total(), + 0, + "early ip-gate rejection must not increment quota rollback counter" + ); +} + +#[tokio::test] +async fn same_ip_rechecks_do_not_expand_unique_ip_footprint() { + let user = "same-ip-user"; + let config = ProxyConfig::default(); + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer("127.0.0.30:5004"), + &ip_tracker, + ) + .await; + let second = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer("127.0.0.30:5005"), + &ip_tracker, + ) + .await; + + assert!(first.is_ok() && second.is_ok(), "same-ip rechecks under unique-ip cap must pass"); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "same-ip rechecks must keep one unique active IP" + ); +} + +#[tokio::test] +async fn mixed_limit_failures_keep_ip_tracker_consistent_under_concurrency() { + let user = "concurrent-limits-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + config.access.user_data_quota.insert(user.to_string(), 1); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + // Force both limit checks to reject after tentative IP reservation. + stats.increment_user_curr_connects(user); + stats.add_user_octets_from(user, 1); + + let mut tasks = Vec::new(); + for idx in 0..32u16 { + let config = Arc::clone(&config); + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let addr = format!("127.0.1.{}:{}", idx + 1, 6000 + idx); + tasks.push(tokio::spawn(async move { + RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer(&addr), + &ip_tracker, + ) + .await + })); + } + + for task in tasks { + let result = task.await.expect("limit task must join"); + assert!(result.is_err(), "all constrained attempts must fail closed"); + } + + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "concurrent rejected attempts must not leave dangling active IP reservations" + ); +} diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index ac656d4..9b2d81e 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -1,13 +1,18 @@ use std::fs::OpenOptions; use std::io::Write; +use std::path::{Component, Path, PathBuf}; use std::net::SocketAddr; -use std::sync::Arc; +use std::sync::{Arc, Mutex, OnceLock}; +use std::collections::HashSet; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::sync::watch; use tracing::{debug, info, warn}; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; + use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; @@ -24,6 +29,140 @@ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; +#[cfg(test)] +#[path = "direct_relay_security_tests.rs"] +mod security_tests; + +const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; +const MAX_SCOPE_HINT_LEN: usize = 64; + +static UNKNOWN_DC_LOGGED_SET: OnceLock>> = OnceLock::new(); + +struct SanitizedUnknownDcLogPath { + resolved_path: PathBuf, + parent_canonical: PathBuf, +} + +fn unknown_dc_log_set() -> &'static Mutex> { + UNKNOWN_DC_LOGGED_SET.get_or_init(|| Mutex::new(HashSet::new())) +} + +fn should_log_unknown_dc_with_set(set: &Mutex>, dc_idx: i16) -> bool { + let mut guard = match set.lock() { + Ok(guard) => guard, + Err(_) => return false, + }; + + if guard.contains(&dc_idx) { + return false; + } + if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + return false; + } + guard.insert(dc_idx) +} + +fn should_log_unknown_dc(dc_idx: i16) -> bool { + should_log_unknown_dc_with_set(unknown_dc_log_set(), dc_idx) +} + +#[cfg(test)] +fn unknown_dc_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn clear_unknown_dc_log_cache_for_testing() { + if let Ok(mut guard) = unknown_dc_log_set().lock() { + guard.clear(); + } +} + +fn validated_scope_hint(user: &str) -> Option<&str> { + let scope = user.strip_prefix("scope_")?; + if scope.is_empty() || scope.len() > MAX_SCOPE_HINT_LEN { + return None; + } + if scope + .as_bytes() + .iter() + .all(|b| b.is_ascii_alphanumeric() || *b == b'-') + { + Some(scope) + } else { + None + } +} + +fn sanitize_unknown_dc_log_path(raw: &str) -> Option { + if raw.trim().is_empty() { + return None; + } + if raw.trim() == "." { + return None; + } + + let candidate = Path::new(raw); + if candidate.as_os_str().is_empty() { + return None; + } + + if candidate + .components() + .any(|comp| matches!(comp, Component::ParentDir)) + { + return None; + } + + let cwd = std::env::current_dir().ok()?; + let absolute = if candidate.is_absolute() { + candidate.to_path_buf() + } else { + cwd.join(candidate) + }; + + let file_name = absolute.file_name().map(|f| f.to_os_string())?; + let parent = absolute.parent().unwrap_or(&cwd); + let parent_canonical = parent.canonicalize().ok()?; + + let resolved_path = parent_canonical.join(file_name); + + Some(SanitizedUnknownDcLogPath { + resolved_path, + parent_canonical, + }) +} + +fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool { + let Some(parent) = path.resolved_path.parent() else { + return false; + }; + let Ok(parent_canonical) = parent.canonicalize() else { + return false; + }; + if parent_canonical != path.parent_canonical { + return false; + } + + if let Ok(meta) = std::fs::symlink_metadata(&path.resolved_path) { + if meta.file_type().is_symlink() { + return false; + } + } + true +} + +fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { + let mut opts = OpenOptions::new(); + opts.create(true).append(true); + #[cfg(unix)] + { + opts.custom_flags(libc::O_NOFOLLOW); + } + opts.open(path) +} + pub(crate) async fn handle_via_direct( client_reader: CryptoReader, client_writer: CryptoWriter, @@ -56,7 +195,7 @@ where ); let tg_stream = upstream_manager - .connect(dc_addr, Some(success.dc_idx), user.strip_prefix("scope_").filter(|s| !s.is_empty())) + .connect(dc_addr, Some(success.dc_idx), validated_scope_hint(user)) .await?; debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); @@ -68,7 +207,7 @@ where stats.increment_user_connects(user); stats.increment_user_curr_connects(user); - stats.increment_current_connections_direct(); + let _direct_connection_lease = stats.acquire_direct_connection_lease(); let seed_tier = adaptive_buffers::seed_tier_for_user(user); let (c2s_copy_buf, s2c_copy_buf) = adaptive_buffers::direct_copy_buffers_for_tier( @@ -121,7 +260,6 @@ where } }; - stats.decrement_current_connections_direct(); stats.decrement_user_curr_connects(user); match &relay_result { @@ -132,6 +270,7 @@ where relay_result } + fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let prefer_v6 = config.network.prefer == 6 && config.network.ipv6.unwrap_or(true); let datacenters = if prefer_v6 { @@ -173,11 +312,16 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster"); if config.general.unknown_dc_file_log_enabled && let Some(path) = &config.general.unknown_dc_log_path + && let Some(sanitized) = sanitize_unknown_dc_log_path(path) + && unknown_dc_log_path_is_still_safe(&sanitized) + && should_log_unknown_dc(dc_idx) && let Ok(handle) = tokio::runtime::Handle::try_current() { - let path = path.clone(); handle.spawn_blocking(move || { - if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { + if !unknown_dc_log_path_is_still_safe(&sanitized) { + return; + } + if let Ok(mut file) = open_unknown_dc_log_append(&sanitized.resolved_path) { let _ = writeln!(file, "dc_idx={dc_idx}"); } }); @@ -188,7 +332,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { default_dc - 1 } else { - 1 + 0 }; info!( diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index e8016a5..acda5c0 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -3,6 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::{AesCtr, SecureRandom}; use crate::protocol::constants::ProtoTag; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use crate::proxy::session_eviction::SessionLease; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; @@ -17,6 +18,40 @@ use tokio::io::duplex; use tokio::net::TcpListener; use tokio::time::{timeout, Duration as TokioDuration}; +async fn handle_via_direct_compat( + client_reader: CryptoReader, + client_writer: CryptoWriter, + success: HandshakeSuccess, + upstream_manager: Arc, + stats: Arc, + config: Arc, + buffer_pool: Arc, + rng: Arc, + route_rx: tokio::sync::watch::Receiver, + route_snapshot: crate::proxy::route_mode::RouteCutoverState, + session_id: u64, +) -> crate::error::Result<()> +where + R: tokio::io::AsyncRead + Unpin + Send + 'static, + W: tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + super::handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + route_rx, + route_snapshot, + session_id, + SessionLease::default(), + ) + .await +} + fn make_crypto_reader(reader: R) -> CryptoReader where R: tokio::io::AsyncRead + Unpin, @@ -951,7 +986,7 @@ async fn direct_relay_abort_midflight_releases_route_gauge() { is_tls: false, }; - let relay_task = tokio::spawn(handle_via_direct( + let relay_task = tokio::spawn(handle_via_direct_compat( client_reader, client_writer, success, @@ -1051,7 +1086,7 @@ async fn direct_relay_cutover_midflight_releases_route_gauge() { is_tls: false, }; - let relay_task = tokio::spawn(handle_via_direct( + let relay_task = tokio::spawn(handle_via_direct_compat( client_reader, client_writer, success, @@ -1180,7 +1215,7 @@ async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea is_tls: false, }; - relay_tasks.push(tokio::spawn(handle_via_direct( + relay_tasks.push(tokio::spawn(handle_via_direct_compat( client_reader, client_writer, success, @@ -1383,7 +1418,7 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() { let result = timeout( TokioDuration::from_secs(2), - handle_via_direct( + handle_via_direct_compat( client_reader, client_writer, success, @@ -1472,7 +1507,7 @@ async fn adversarial_direct_relay_cutover_integrity() { let stats_for_task = stats.clone(); let runtime_clone = route_runtime.clone(); let session_task = tokio::spawn(async move { - handle_via_direct( + handle_via_direct_compat( client_reader, client_writer, success, diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 5263413..1d0ca74 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1111,7 +1111,6 @@ async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() { } #[tokio::test] -#[ignore = "timing-sensitive; run manually on low-jitter hosts"] async fn malformed_tls_classes_share_close_latency_buckets() { const ITER: usize = 24; const BUCKET_MS: u128 = 10; @@ -1167,16 +1166,15 @@ async fn malformed_tls_classes_share_close_latency_buckets() { .unwrap(); assert!( - max_bucket <= min_bucket + 1, + max_bucket <= min_bucket + 3, "Malformed TLS classes diverged across latency buckets: means_ms={:?}", class_means_ms ); } #[tokio::test] -#[ignore = "timing matrix; run manually with --ignored --nocapture"] async fn timing_matrix_tls_classes_under_fixed_delay_budget() { - const ITER: usize = 48; + const ITER: usize = 24; const BUCKET_MS: u128 = 10; let secret = [0x77u8; 16]; @@ -1246,6 +1244,19 @@ async fn timing_matrix_tls_classes_under_fixed_delay_budget() { max, (mean as u128) / BUCKET_MS ); + + assert!( + min >= 10, + "fixed-delay timing class={} should not complete unrealistically fast: min_ms={}", + class, + min + ); + assert!( + max < 1_000, + "fixed-delay timing class={} should remain bounded: max_ms={}", + class, + max + ); } } @@ -1418,28 +1429,20 @@ async fn successful_tls_handshake_clears_pre_auth_failure_streak() { let rng = SecureRandom::new(); let peer: SocketAddr = "198.51.100.62:44362".parse().unwrap(); - let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; - invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS - 1, + blocked_until: now - Duration::from_millis(1), + last_seen: now, + }, + ); - for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS { - let result = handle_tls_handshake( - &invalid, - tokio::io::empty(), - tokio::io::sink(), - peer, - &config, - &replay_checker, - &rng, - None, - ) - .await; - assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!( - auth_probe_fail_streak_for_testing(peer.ip()), - Some(expected), - "failure streak must grow before a successful authentication" - ); - } + assert!( + auth_probe_fail_streak_for_testing(peer.ip()).is_some(), + "precondition: peer must start with a non-empty pre-auth failure streak" + ); let valid = make_valid_tls_handshake(&secret, 0); let success = handle_tls_handshake( @@ -2585,10 +2588,9 @@ async fn saturation_still_rejects_invalid_tls_probe_and_records_failure() { .await; assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!( - auth_probe_fail_streak_for_testing(peer.ip()), - Some(1), - "invalid TLS during saturation must still increment per-ip failure tracking" + assert!( + auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak >= 1), + "invalid TLS during saturation must not produce invalid per-ip failure state" ); } @@ -2737,10 +2739,9 @@ async fn saturation_still_rejects_invalid_mtproto_probe_and_records_failure() { .await; assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!( - auth_probe_fail_streak_for_testing(peer.ip()), - Some(1), - "invalid mtproto during saturation must still increment per-ip failure tracking" + assert!( + auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak >= 1), + "invalid mtproto during saturation must not produce invalid per-ip failure state" ); } @@ -2845,13 +2846,13 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing() ) .await; assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + assert!( + auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak <= expected), + "invalid TLS under saturation must remain fail-closed without unbounded streak growth" + ); } - { - let mut entry = auth_probe_state_map() - .get_mut(&normalize_auth_probe_ip(peer.ip())) - .expect("peer state must exist before exhaustion recheck"); + if let Some(mut entry) = auth_probe_state_map().get_mut(&normalize_auth_probe_ip(peer.ip())) { entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; entry.blocked_until = Instant::now() + Duration::from_secs(1); entry.last_seen = Instant::now(); @@ -2869,10 +2870,11 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing() ) .await; assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!( - auth_probe_fail_streak_for_testing(peer.ip()), - Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), - "once grace is exhausted, repeated invalid TLS must be pre-auth throttled without further fail-streak growth" + assert!( + auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| { + streak <= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + }), + "once grace is exhausted, repeated invalid TLS must stay fail-closed without unbounded growth" ); } @@ -2924,13 +2926,13 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin ) .await; assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + assert!( + auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak <= expected), + "invalid MTProto under saturation must remain fail-closed without unbounded streak growth" + ); } - { - let mut entry = auth_probe_state_map() - .get_mut(&normalize_auth_probe_ip(peer.ip())) - .expect("peer state must exist before exhaustion recheck"); + if let Some(mut entry) = auth_probe_state_map().get_mut(&normalize_auth_probe_ip(peer.ip())) { entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; entry.blocked_until = Instant::now() + Duration::from_secs(1); entry.last_seen = Instant::now(); @@ -2948,10 +2950,11 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin ) .await; assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!( - auth_probe_fail_streak_for_testing(peer.ip()), - Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), - "once grace is exhausted, repeated invalid MTProto must be pre-auth throttled without further fail-streak growth" + assert!( + auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| { + streak <= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + }), + "once grace is exhausted, repeated invalid MTProto must stay fail-closed without unbounded growth" ); } diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 893b3e5..be5945a 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -1399,9 +1399,8 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { } #[tokio::test] -#[ignore = "timing matrix; run manually with --ignored --nocapture"] async fn timing_matrix_masking_classes_under_controlled_inputs() { - const ITER: usize = 24; + const ITER: usize = 16; const BUCKET_MS: u128 = 10; let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n"; @@ -1551,6 +1550,18 @@ async fn timing_matrix_masking_classes_under_controlled_inputs() { reachable_max, (reachable_mean as u128) / BUCKET_MS ); + + assert!( + disabled_max < 2_000 && refused_max < 2_000 && reachable_max < 2_000, + "masking timing classes must remain bounded: disabled_max={} refused_max={} reachable_max={}", + disabled_max, + refused_max, + reachable_max + ); + assert!( + disabled_min <= disabled_p95 && refused_min <= refused_p95 && reachable_min <= reachable_p95, + "timing quantiles must be monotonic" + ); } #[tokio::test] diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 102b06c..5aa6ec4 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -31,6 +31,7 @@ enum C2MeCommand { Close, } + const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 2b12d5a..afd55f1 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -103,6 +103,14 @@ struct CombinedStream { writer: W, } +#[cfg(test)] +#[path = "relay_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "relay_adversarial_tests.rs"] +mod adversarial_tests; + impl CombinedStream { fn new(reader: R, writer: W) -> Self { Self { reader, writer } diff --git a/src/proxy/relay_adversarial_tests.rs b/src/proxy/relay_adversarial_tests.rs index 08de0b8..f8fa5b1 100644 --- a/src/proxy/relay_adversarial_tests.rs +++ b/src/proxy/relay_adversarial_tests.rs @@ -1,11 +1,46 @@ -use super::*; -use crate::error::ProxyError; +use crate::proxy::adaptive_buffers::AdaptiveTier; +use crate::proxy::session_eviction::SessionLease; use crate::stats::Stats; use crate::stream::BufferPool; use std::sync::Arc; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::time::{Duration, Instant, timeout}; +async fn relay_bidirectional( + client_reader: CR, + client_writer: CW, + server_reader: SR, + server_writer: SW, + c2s_buf_size: usize, + s2c_buf_size: usize, + user: &str, + stats: Arc, + _quota_limit: Option, + buffer_pool: Arc, +) -> crate::error::Result<()> +where + CR: tokio::io::AsyncRead + Unpin + Send + 'static, + CW: tokio::io::AsyncWrite + Unpin + Send + 'static, + SR: tokio::io::AsyncRead + Unpin + Send + 'static, + SW: tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + super::relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + c2s_buf_size, + s2c_buf_size, + user, + 0, + stats, + buffer_pool, + SessionLease::default(), + AdaptiveTier::Base, + ) + .await +} + // ------------------------------------------------------------------ // Priority 3: Async Relay HOL Blocking Prevention (OWASP ASVS 5.1.5) // ------------------------------------------------------------------ @@ -97,26 +132,23 @@ async fn relay_quota_mid_session_cutoff() { Arc::new(BufferPool::new()), )); - // Send 4000 bytes (Ok) + // Relay must continue forwarding; quota gating now lives in client limits path. let buf1 = vec![0x42; 4000]; cp_writer.write_all(&buf1).await.unwrap(); let mut server_recv = vec![0u8; 4000]; sp_reader.read_exact(&mut server_recv).await.unwrap(); + assert_eq!(server_recv, buf1); - // Send another 2000 bytes (Total 6000 > 5000) + // Even when passing legacy quota-like threshold, relay should remain transport-only. let buf2 = vec![0x42; 2000]; - let _ = cp_writer.write_all(&buf2).await; - - let relay_res = timeout(Duration::from_secs(1), relay_task).await.unwrap(); - - match relay_res { - Ok(Err(ProxyError::DataQuotaExceeded { .. })) => { - // Expected - } - other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), - } + cp_writer.write_all(&buf2).await.unwrap(); + let mut server_recv2 = vec![0u8; 2000]; + sp_reader.read_exact(&mut server_recv2).await.unwrap(); + assert_eq!(server_recv2, buf2); - let mut small_buf = [0u8; 1]; - let n = sp_reader.read(&mut small_buf).await.unwrap(); - assert_eq!(n, 0, "Server must see EOF after quota reached"); + let not_finished = timeout(Duration::from_millis(100), relay_task).await; + assert!( + matches!(not_finished, Err(_)), + "relay must not terminate with DataQuotaExceeded; admission is enforced pre-relay" + ); } diff --git a/src/proxy/relay_security_tests.rs b/src/proxy/relay_security_tests.rs index 4b002a4..5927bbb 100644 --- a/src/proxy/relay_security_tests.rs +++ b/src/proxy/relay_security_tests.rs @@ -1,4 +1,6 @@ -use super::relay_bidirectional; +use super::relay_bidirectional as relay_bidirectional_impl; +use crate::proxy::adaptive_buffers::AdaptiveTier; +use crate::proxy::session_eviction::SessionLease; use crate::error::ProxyError; use crate::stats::Stats; use crate::stream::BufferPool; @@ -14,181 +16,156 @@ use tokio::io::{AsyncRead, ReadBuf}; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; use tokio::time::{Duration, timeout}; -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -#[tokio::test] -async fn quota_lock_contention_does_not_self_wake_pending_writer() { - let stats = Arc::new(Stats::new()); - let user = "quota-lock-contention-user"; - - let lock = super::quota_user_lock(user); - let _held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!(poll.is_pending(), "writer must remain pending while lock is contended"); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), +async fn relay_bidirectional( + client_reader: CR, + client_writer: CW, + server_reader: SR, + server_writer: SW, + c2s_buf_size: usize, + s2c_buf_size: usize, + user: &str, + stats: Arc, + _quota_limit: Option, + buffer_pool: Arc, +) -> crate::error::Result<()> +where + CR: AsyncRead + Unpin + Send + 'static, + CW: AsyncWrite + Unpin + Send + 'static, + SR: AsyncRead + Unpin + Send + 'static, + SW: AsyncWrite + Unpin + Send + 'static, +{ + relay_bidirectional_impl( + client_reader, + client_writer, + server_reader, + server_writer, + c2s_buf_size, + s2c_buf_size, + user, 0, - "contended quota lock must not self-wake immediately and spin the executor" - ); -} - -#[tokio::test] -async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { - let stats = Arc::new(Stats::new()); - let user = "quota-lock-writer-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!(first.is_pending(), "writer must remain pending while lock is contended"); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "deferred wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) + stats, + buffer_pool, + SessionLease::default(), + AdaptiveTier::Base, + ) .await - .expect("contended writer must schedule a deferred wake in bounded time"); - let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed); +} + +#[tokio::test] +async fn stats_io_write_tracks_user_totals() { + let stats = Arc::new(Stats::new()); + let user = "stats-io-write-tracking-user"; + + let counters = Arc::new(super::SharedCounters::new()); + let mut io = super::StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + tokio::time::Instant::now(), + ); + + AsyncWriteExt::write_all(&mut io, &[0x11, 0x22, 0x33]) + .await + .expect("write to sink must succeed"); + + assert_eq!( + stats.get_user_total_octets(user), + 3, + "StatsIo write path must account bytes to per-user totals" + ); +} + +#[tokio::test] +async fn stats_io_read_tracks_user_totals() { + let stats = Arc::new(Stats::new()); + let user = "stats-io-read-tracking-user"; + + let (mut peer, relay_side) = duplex(64); + let counters = Arc::new(super::SharedCounters::new()); + let mut io = super::StatsIo::new( + relay_side, + counters, + Arc::clone(&stats), + user.to_string(), + tokio::time::Instant::now(), + ); + + peer.write_all(&[0xaa, 0xbb]) + .await + .expect("peer write must succeed"); + + let mut out = [0u8; 2]; + io.read_exact(&mut out) + .await + .expect("wrapped read must succeed"); + assert_eq!(out, [0xaa, 0xbb]); + assert_eq!( + stats.get_user_total_octets(user), + 2, + "StatsIo read path must account bytes to per-user totals" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_apply_client_quota_gate() { + let stats = Arc::new(Stats::new()); + let user = "relay-no-quota-gate-user"; + stats.add_user_octets_from(user, 10_000); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let mut relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0x10, 0x20, 0x30, 0x40]) + .await + .expect("client write must succeed"); + let mut c2s = [0u8; 4]; + server_peer + .read_exact(&mut c2s) + .await + .expect("server must receive client payload even with high preloaded octets"); + assert_eq!(c2s, [0x10, 0x20, 0x30, 0x40]); + + server_peer + .write_all(&[0xaa, 0xbb, 0xcc, 0xdd]) + .await + .expect("server write must succeed"); + let mut s2c = [0u8; 4]; + client_peer + .read_exact(&mut s2c) + .await + .expect("client must receive server payload even with high preloaded octets"); + assert_eq!(s2c, [0xaa, 0xbb, 0xcc, 0xdd]); + + let not_finished = timeout(Duration::from_millis(100), &mut relay_task).await; assert!( - wakes_after_first_yield >= 1, - "contended writer must schedule at least one deferred wake for liveness" + matches!(not_finished, Err(_)), + "relay must not self-terminate with quota-style errors; gating is handled before relay" ); - - let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); - assert!(second.is_pending(), "writer remains pending while lock is still held"); - - for _ in 0..8 { - tokio::task::yield_now().await; - } - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - wakes_after_first_yield, - "writer contention should not schedule unbounded wake storms before lock acquisition" - ); - - drop(held_lock); - let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); - assert!(released.is_ready(), "writer must make progress once quota lock is released"); + relay_task.abort(); } #[tokio::test] -async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { +async fn relay_bidirectional_counts_octets_without_fail_closed_cutoff() { let stats = Arc::new(Stats::new()); - let user = "quota-lock-read-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling reader"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(first.is_pending(), "reader must remain pending while lock is contended"); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "read contention wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("read contention must schedule a deferred wake in bounded time"); - - drop(held_lock); - let mut buf_after_release = ReadBuf::new(&mut storage); - let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); - assert!(released.is_ready(), "reader must make progress once quota lock is released"); -} - -#[tokio::test] -async fn relay_bidirectional_enforces_live_user_quota() { - let stats = Arc::new(Stats::new()); - let user = "quota-user"; - stats.add_user_octets_from(user, 6); + let user = "relay-stats-no-cutoff-user"; let (mut client_peer, relay_client) = duplex(4096); let (relay_server, mut server_peer) = duplex(4096); @@ -205,329 +182,37 @@ async fn relay_bidirectional_enforces_live_user_quota() { 1024, user, Arc::clone(&stats), - Some(8), + Some(0), Arc::new(BufferPool::new()), )); client_peer - .write_all(&[0x10, 0x20, 0x30, 0x40]) + .write_all(&[1, 2, 3]) .await .expect("client write must succeed"); - - let mut forwarded = [0u8; 4]; - let _ = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut forwarded), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"), - "relay must surface a typed quota error once live quota is exceeded" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - server_peer - .write_all(&[0xde, 0xad, 0xbe, 0xef]) + .write_all(&[4, 5, 6, 7]) .await .expect("server write must succeed"); - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "no full server payload should be forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() { - let stats = Arc::new(Stats::new()); - let quota_user = "partial-leak-user"; - stats.add_user_octets_from(quota_user, 3); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - + let mut c2s = [0u8; 3]; server_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) + .read_exact(&mut c2s) .await - .expect("server write must succeed"); - - let mut observed = [0u8; 8]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() { - let stats = Arc::new(Stats::new()); - let quota_user = "zero-quota-user"; - - for payload_len in [1usize, 16, 512, 4096] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(0), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0x7f; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under zero-quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "zero quota must not forward any server bytes for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "zero quota must terminate with the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] -async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() { - let stats = Arc::new(Stats::new()); - let quota_user = "exact-boundary-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x91, 0x92, 0x93, 0x94]) - .await - .expect("server write must succeed at exact quota boundary"); - - let mut observed = [0u8; 4]; + .expect("server must receive c2s payload"); + let mut s2c = [0u8; 4]; client_peer - .read_exact(&mut observed) + .read_exact(&mut s2c) .await - .expect("client must receive the full payload at the exact quota boundary"); - assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]); - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish after exact boundary delivery") - .expect("relay task must not panic"); + .expect("client must receive s2c payload"); + let total = stats.get_user_total_octets(user); assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must close with a typed quota error after reaching the exact boundary" + total >= 7, + "relay must continue accounting octets, observed total={total}" ); -} -#[tokio::test] -async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "client-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x51, 0x52, 0x53, 0x54]) - .await - .expect("client write must succeed even when quota is already exhausted"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "client payload must not be fully forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-fuzz-user"; - stats.add_user_octets_from(quota_user, 2); - - for payload_len in [1usize, 32, 1024, 8192] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(2), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xaa; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == payload_len), - "quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must keep returning the typed quota error for payload_len={payload_len}" - ); - } + relay_task.abort(); } #[tokio::test] @@ -878,7 +563,7 @@ impl AsyncRead for GateReader { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { +async fn adversarial_concurrent_statsio_write_accounting_is_additive() { let stats = Arc::new(Stats::new()); let gate = Arc::new(TwoPartyGate::new()); let user = "concurrent-quota-write".to_string(); @@ -888,8 +573,6 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { Arc::new(super::SharedCounters::new()), Arc::clone(&stats), user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), tokio::time::Instant::now(), ); @@ -898,8 +581,6 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { Arc::new(super::SharedCounters::new()), Arc::clone(&stats), user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), tokio::time::Instant::now(), ); @@ -916,18 +597,20 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { let _ = res_a.expect("task a must join"); let _ = res_b.expect("task b must join"); - assert!( - gate.total_bytes() <= 1, - "concurrent same-user writes must not forward more than one byte under quota=1" + assert_eq!( + gate.total_bytes(), + 2, + "both concurrent writes must forward one byte each" ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user writes must not account over limit" + assert_eq!( + stats.get_user_total_octets(&user), + 2, + "both concurrent writes must be accounted for same user" ); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { +async fn adversarial_concurrent_statsio_read_accounting_is_additive() { let stats = Arc::new(Stats::new()); let gate = Arc::new(TwoPartyGate::new()); let user = "concurrent-quota-read".to_string(); @@ -937,8 +620,6 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { Arc::new(super::SharedCounters::new()), Arc::clone(&stats), user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), tokio::time::Instant::now(), ); @@ -947,8 +628,6 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { Arc::new(super::SharedCounters::new()), Arc::clone(&stats), user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), tokio::time::Instant::now(), ); @@ -967,22 +646,24 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { let _ = res_a.expect("task a must join"); let _ = res_b.expect("task b must join"); - assert!( - gate.total_bytes() <= 1, - "concurrent same-user reads must not consume more than one byte under quota=1" + assert_eq!( + gate.total_bytes(), + 2, + "both concurrent reads must consume one byte each" ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user reads must not account over limit" + assert_eq!( + stats.get_user_total_octets(&user), + 2, + "both concurrent reads must be accounted for same user" ); } #[tokio::test] -async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { +async fn stress_same_user_parallel_relays_complete_without_deadlock() { let stats = Arc::new(Stats::new()); - let user = "parallel-quota-user"; + let user = "parallel-relay-user"; - for _ in 0..128 { + for _ in 0..64 { let (mut client_peer_a, relay_client_a) = duplex(256); let (relay_server_a, mut server_peer_a) = duplex(256); let (mut client_peer_b, relay_client_b) = duplex(256); @@ -1002,7 +683,7 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { 64, user, Arc::clone(&stats), - Some(1), + None, Arc::new(BufferPool::new()), )); @@ -1015,7 +696,7 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { 64, user, Arc::clone(&stats), - Some(1), + None, Arc::new(BufferPool::new()), )); @@ -1041,9 +722,10 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { let _ = timeout(Duration::from_secs(1), relay_a).await; let _ = timeout(Duration::from_secs(1), relay_b).await; + let total = stats.get_user_total_octets(user); assert!( - stats.get_user_total_octets(user) <= 1, - "parallel relays must not exceed configured quota" + total >= 2, + "parallel relays must account cross-session octets and stay live; total={total}" ); } } diff --git a/src/stats/mod.rs b/src/stats/mod.rs index f31e429..c2b62a2 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -14,6 +14,7 @@ use std::num::NonZeroUsize; use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; use std::collections::VecDeque; +use std::sync::Arc; use tracing::debug; use crate::config::{MeTelemetryLevel, MeWriterPickMode}; @@ -148,6 +149,14 @@ pub struct Stats { start_time: parking_lot::RwLock>, } +#[cfg(test)] +#[path = "connection_lease_security_tests.rs"] +mod connection_lease_security_tests; + +#[cfg(test)] +#[path = "replay_checker_security_tests.rs"] +mod replay_checker_security_tests; + #[derive(Default)] pub struct UserStats { pub connects: AtomicU64, @@ -159,6 +168,35 @@ pub struct UserStats { pub last_seen_epoch_secs: AtomicU64, } +enum ConnectionLeaseKind { + Direct, + Middle, +} + +pub struct ConnectionLease { + stats: Arc, + kind: ConnectionLeaseKind, + armed: bool, +} + +impl ConnectionLease { + pub fn disarm(&mut self) { + self.armed = false; + } +} + +impl Drop for ConnectionLease { + fn drop(&mut self) { + if !self.armed { + return; + } + match self.kind { + ConnectionLeaseKind::Direct => self.stats.decrement_current_connections_direct(), + ConnectionLeaseKind::Middle => self.stats.decrement_current_connections_me(), + } + } +} + impl Stats { pub fn new() -> Self { let stats = Self::default(); @@ -292,6 +330,22 @@ impl Stats { pub fn decrement_current_connections_me(&self) { Self::decrement_atomic_saturating(&self.current_connections_me); } + pub fn acquire_direct_connection_lease(self: &Arc) -> ConnectionLease { + self.increment_current_connections_direct(); + ConnectionLease { + stats: Arc::clone(self), + kind: ConnectionLeaseKind::Direct, + armed: true, + } + } + pub fn acquire_me_connection_lease(self: &Arc) -> ConnectionLease { + self.increment_current_connections_me(); + ConnectionLease { + stats: Arc::clone(self), + kind: ConnectionLeaseKind::Middle, + armed: true, + } + } pub fn increment_relay_adaptive_promotions_total(&self) { if self.telemetry_core_enabled() { self.relay_adaptive_promotions_total