From c7cf37898b9fabb5137a73a5c912fc7523fecd3e Mon Sep 17 00:00:00 2001 From: David Osipov Date: Wed, 18 Mar 2026 23:55:08 +0400 Subject: [PATCH] feat: enhance quota user lock management and testing - Adjusted QUOTA_USER_LOCKS_MAX based on test and non-test configurations to improve flexibility. - Implemented logic to retain existing locks when the maximum quota is reached, ensuring efficient memory usage. - Added comprehensive tests for quota user lock functionality, including cache reuse, saturation behavior, and race conditions. - Enhanced StatsIo struct to manage wake scheduling for read and write operations, preventing unnecessary self-wakes. - Introduced separate replay checker domains for handshake and TLS to ensure isolation and prevent cross-pollution of keys. - Added security tests for replay checker to validate domain separation and window clamping behavior. --- .gitignore | 1 + src/protocol/tls_security_tests.rs | 132 +++++++ src/proxy/client.rs | 4 +- src/proxy/client_security_tests.rs | 341 ++++++++++++++++- src/proxy/direct_relay.rs | 42 +- src/proxy/direct_relay_security_tests.rs | 115 ++++++ src/proxy/handshake.rs | 37 +- ...short_tls_probe_throttle_security_tests.rs | 50 +++ src/proxy/handshake_security_tests.rs | 182 +++++++++ src/proxy/masking.rs | 2 + src/proxy/masking_security_tests.rs | 248 ++++++++++++ src/proxy/middle_relay.rs | 14 +- src/proxy/middle_relay_security_tests.rs | 359 ++++++++++++++++++ src/proxy/relay.rs | 32 +- src/proxy/relay_security_tests.rs | 170 +++++++++ src/proxy/route_mode_security_tests.rs | 66 ++++ src/stats/mod.rs | 70 +++- src/stats/replay_checker_security_tests.rs | 80 ++++ 18 files changed, 1896 insertions(+), 49 deletions(-) create mode 100644 src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs create mode 100644 src/stats/replay_checker_security_tests.rs diff --git a/.gitignore b/.gitignore index 3a45e41..bc782ca 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ target #.idea/ 
proxy-secret +coverage-html/ \ No newline at end of file diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index f8f2695..e551cca 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -1949,6 +1949,138 @@ fn server_hello_new_session_ticket_count_is_safely_capped() { ); } +#[test] +fn boot_time_handshake_replay_remains_blocked_after_cache_window_expires() { + let secret = b"gap_t01_boot_replay"; + let secrets = vec![("user".to_string(), secret.to_vec())]; + let handshake = make_valid_tls_handshake(secret, 1); + + let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) + .expect("boot-time handshake must validate on first use"); + + let checker = crate::stats::ReplayChecker::new(128, std::time::Duration::from_millis(40)); + let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN]; + + assert!( + !checker.check_and_add_tls_digest(digest_half), + "first use must not be treated as replay" + ); + assert!( + checker.check_and_add_tls_digest(digest_half), + "immediate second use must be detected as replay" + ); + + std::thread::sleep(std::time::Duration::from_millis(70)); + + let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) + .expect("boot-time handshake must still cryptographically validate after cache expiry"); + let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN]; + assert_eq!(digest_half, digest_half_after_expiry, "replay key must be stable for same handshake"); + + assert!( + checker.check_and_add_tls_digest(digest_half_after_expiry), + "after cache window expiry, the same boot-time handshake must still be treated as replay" + ); +} + +#[test] +fn adversarial_boot_time_handshake_should_not_be_replayable_after_cache_expiry() { + let secret = b"gap_t01_boot_replay_adversarial"; + let secrets = vec![("user".to_string(), secret.to_vec())]; + let handshake = 
make_valid_tls_handshake(secret, 1); + + let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) + .expect("boot-time handshake must validate on first use"); + + let checker = crate::stats::ReplayChecker::new(128, std::time::Duration::from_millis(40)); + let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN]; + + assert!( + !checker.check_and_add_tls_digest(digest_half), + "first use must not be treated as replay" + ); + assert!( + checker.check_and_add_tls_digest(digest_half), + "immediate reuse must be rejected as replay" + ); + + std::thread::sleep(std::time::Duration::from_millis(70)); + + let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) + .expect("boot-time handshake still validates cryptographically after cache expiry"); + let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN]; + + assert_eq!( + digest_half, digest_half_after_expiry, + "replay key must remain stable for the same captured handshake" + ); + + assert!( + checker.check_and_add_tls_digest(digest_half_after_expiry), + "security expectation: a boot-time handshake should remain replay-protected even after cache expiry" + ); +} + +#[test] +fn stress_short_replay_window_boot_timestamp_replay_cycles_remain_fail_closed_in_window() { + let secret = b"gap_t01_boot_replay_stress"; + let secrets = vec![("user".to_string(), secret.to_vec())]; + let handshake = make_valid_tls_handshake(secret, 1); + + let checker = crate::stats::ReplayChecker::new(256, std::time::Duration::from_millis(25)); + + for cycle in 0..64 { + let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) + .expect("boot-time handshake must validate"); + let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN]; + + if cycle == 0 { + assert!( + !checker.check_and_add_tls_digest(digest_half), + "cycle 0: first use must be fresh" + ); + assert!( + 
checker.check_and_add_tls_digest(digest_half), + "cycle 0: second use must be replay" + ); + } else { + assert!( + checker.check_and_add_tls_digest(digest_half), + "cycle {cycle}: digest must remain replay-protected across short-window churn" + ); + } + + std::thread::sleep(std::time::Duration::from_millis(30)); + } +} + +#[test] +fn light_fuzz_boot_time_timestamp_matrix_with_short_replay_window_obeys_boot_cap() { + let secret = b"gap_t01_boot_replay_fuzz"; + let secrets = vec![("user".to_string(), secret.to_vec())]; + + let mut s: u64 = 0xA1B2_C3D4_55AA_7733; + for _ in 0..2048 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = (s as u32) % 8; + + let handshake = make_valid_tls_handshake(secret, ts); + let accepted = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) + .is_some(); + + if ts < 2 { + assert!(accepted, "timestamp {ts} must remain boot-time compatible under 2s cap"); + } else { + assert!( + !accepted, + "timestamp {ts} must be rejected when outside replay-window boot cap" + ); + } + } +} + #[test] fn server_hello_application_data_contains_alpn_marker_when_selected() { let secret = b"alpn_marker_test"; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index d7b3660..6c64a94 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -300,7 +300,7 @@ where handle_bad_client( reader, writer, - &mtproto_handshake, + &handshake, real_peer, local_addr, &config, @@ -713,7 +713,7 @@ impl RunningClientHandler { handle_bad_client( reader, writer, - &mtproto_handshake, + &handshake, peer, local_addr, &config, diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 7e34f4b..abd6266 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -5,8 +5,8 @@ use crate::crypto::sha256_hmac; use crate::protocol::constants::ProtoTag; use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; -use crate::transport::proxy_protocol::ProxyProtocolV1Builder; 
use crate::stream::{CryptoReader, CryptoWriter}; +use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; @@ -303,6 +303,333 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() { let _ = tg_accept_task.await; } +#[tokio::test] +async fn integration_route_cutover_and_quota_overlap_fails_closed_and_releases_state() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (mut stream, _) = tg_listener.accept().await.unwrap(); + stream.write_all(&[0x41, 0x42]).await.unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + }); + + let user = "cutover-quota-overlap-user"; + let peer_addr: SocketAddr = "198.51.100.240:50010".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.access.user_data_quota.insert(user.to_string(), 1); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = 
HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime.clone(), + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + let observed_progress = tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) >= 1 + || ip_tracker.get_active_ip_count(user).await >= 1 + || relay_task.is_finished() + { + return true; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap_or(false); + assert!( + observed_progress, + "overlap race test precondition must observe activation or bounded early termination" + ); + + tokio::time::sleep(Duration::from_millis(5)).await; + let _ = route_runtime.set_mode(RelayRouteMode::Middle); + + let relay_result = tokio::time::timeout(Duration::from_secs(3), relay_task) + .await + .expect("overlap race relay must terminate") + .expect("overlap race relay task must not panic"); + + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})) + || matches!(relay_result, Err(ProxyError::Proxy(ref msg)) if msg == crate::proxy::route_mode::ROUTE_SWITCH_ERROR_MSG), + "overlap race must fail closed via quota enforcement or generic cutover termination" + ); + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "overlap race exit must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "overlap race exit must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn stress_drop_without_release_converges_to_zero_user_and_ip_state() { + let user = "gap-t05-drop-stress-user"; + let mut config = crate::config::ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4096); + + let stats = std::sync::Arc::new(crate::stats::Stats::new()); + let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); + + let mut reservations = Vec::new(); + for idx in 0..512u16 { + let peer = std::net::SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(198, 51, (idx >> 8) as u8, (idx & 0xff) as u8)), + 30_000 + idx, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed in stress precondition"); + reservations.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), 512); + + for reservation in reservations { + std::thread::spawn(move || drop(reservation)) + .join() + .expect("drop thread must not panic"); + } + + tokio::time::timeout(std::time::Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(5)).await; + } + }) + .await + .expect("drop-only path must eventually release 
all user/IP reservations"); +} + +#[tokio::test] +async fn proxy_protocol_header_is_rejected_when_trust_list_is_empty() { + let mut cfg = crate::config::ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs.clear(); + + let config = std::sync::Arc::new(cfg); + let stats = std::sync::Arc::new(crate::stats::Stats::new()); + let upstream_manager = std::sync::Arc::new(crate::transport::UpstreamManager::new( + vec![crate::config::UpstreamConfig { + upstream_type: crate::config::UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(128, std::time::Duration::from_secs(60))); + let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); + let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); + let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct)); + let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); + let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: std::net::SocketAddr = "198.51.100.80:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(std::time::Duration::from_secs(3), handler) + 
.await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); +} + +#[tokio::test] +async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load() { + let mut cfg = crate::config::ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs = vec!["10.0.0.0/8".parse().unwrap()]; + + let config = std::sync::Arc::new(cfg); + + for idx in 0..32u16 { + let stats = std::sync::Arc::new(crate::stats::Stats::new()); + let upstream_manager = std::sync::Arc::new(crate::transport::UpstreamManager::new( + vec![crate::config::UpstreamConfig { + upstream_type: crate::config::UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(64, std::time::Duration::from_secs(60))); + let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); + let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); + let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct)); + let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); + let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(1024); + let peer = std::net::SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (idx + 1) as u8)), + 55_000 + idx, + ); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config.clone(), + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.10:32000".parse().unwrap(), + 
"192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(std::time::Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!( + matches!(result, Err(ProxyError::InvalidProxyProtocol)), + "burst idx {idx}: untrusted source must be rejected" + ); + } +} + #[tokio::test] async fn short_tls_probe_is_masked_through_client_pipeline() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -888,7 +1215,7 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() { let ip_tracker = Arc::new(UserIpTracker::new()); let beobachten = Arc::new(BeobachtenStore::new()); - let (server_side, mut client_side) = duplex(8192); + let (server_side, mut client_side) = duplex(131072); let peer: SocketAddr = "198.51.100.80:55002".parse().unwrap(); let stats_for_assert = stats.clone(); let bad_before = stats_for_assert.get_connects_bad(); @@ -947,11 +1274,12 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + let expected_fallback = client_hello.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); - let mut got = vec![0u8; invalid_mtproto.len()]; + let mut got = vec![0u8; expected_fallback.len()]; stream.read_exact(&mut got).await.unwrap(); - assert_eq!(got, invalid_mtproto); + assert_eq!(got, expected_fallback); }); let mut cfg = ProxyConfig::default(); @@ -1045,11 +1373,12 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + let expected_fallback = client_hello.clone(); let mask_accept_task = tokio::spawn(async move { let (mut stream, _) = 
mask_listener.accept().await.unwrap(); - let mut got = vec![0u8; invalid_mtproto.len()]; + let mut got = vec![0u8; expected_fallback.len()]; stream.read_exact(&mut got).await.unwrap(); - assert_eq!(got, invalid_mtproto); + assert_eq!(got, expected_fallback); }); let mut cfg = ProxyConfig::default(); diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index d36856d..ede908e 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -31,6 +31,22 @@ use std::os::unix::fs::OpenOptionsExt; const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); +const MAX_SCOPE_HINT_LEN: usize = 64; + +fn validated_scope_hint(user: &str) -> Option<&str> { + let scope = user.strip_prefix("scope_")?; + if scope.is_empty() || scope.len() > MAX_SCOPE_HINT_LEN { + return None; + } + if scope + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + { + Some(scope) + } else { + None + } +} #[derive(Clone)] struct SanitizedUnknownDcLogPath { @@ -185,8 +201,15 @@ where "Connecting to Telegram DC" ); + let scope_hint = validated_scope_hint(user); + if user.starts_with("scope_") && scope_hint.is_none() { + warn!( + user = %user, + "Ignoring invalid scope hint and falling back to default upstream selection" + ); + } let tg_stream = upstream_manager - .connect(dc_addr, Some(success.dc_idx), user.strip_prefix("scope_").filter(|s| !s.is_empty())) + .connect(dc_addr, Some(success.dc_idx), scope_hint) .await?; debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); @@ -290,17 +313,18 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster"); if config.general.unknown_dc_file_log_enabled && let Some(path) = &config.general.unknown_dc_log_path - && should_log_unknown_dc(dc_idx) && let Ok(handle) = tokio::runtime::Handle::try_current() { if let Some(path) = 
sanitize_unknown_dc_log_path(path) { - handle.spawn_blocking(move || { - if unknown_dc_log_path_is_still_safe(&path) - && let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path) - { - let _ = writeln!(file, "dc_idx={dc_idx}"); - } - }); + if should_log_unknown_dc(dc_idx) { + handle.spawn_blocking(move || { + if unknown_dc_log_path_is_still_safe(&path) + && let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path) + { + let _ = writeln!(file, "dc_idx={dc_idx}"); + } + }); + } } else { warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path"); } diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index e47164f..6c25068 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -94,6 +94,26 @@ fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() { ); } +#[test] +fn unsafe_unknown_dc_log_path_does_not_consume_dedup_slot() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_123; + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some("../telemt-unknown-dc-unsafe.log".to_string()); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + assert!( + should_log_unknown_dc(dc_idx), + "rejected unsafe log path must not consume unknown-dc dedup entry" + ); +} + #[test] fn stress_unknown_dc_log_concurrent_unique_churn_respects_cap() { let _guard = unknown_dc_test_lock() @@ -158,6 +178,24 @@ fn light_fuzz_unknown_dc_log_mixed_duplicates_never_exceeds_cap() { ); } +#[test] +fn scope_hint_accepts_ascii_alnum_and_dash_within_limit() { + assert_eq!(validated_scope_hint("scope_alpha-1"), Some("alpha-1")); + assert_eq!(validated_scope_hint("scope_AZ09"), Some("AZ09")); +} + +#[test] +fn 
scope_hint_rejects_invalid_or_oversized_values() { + assert_eq!(validated_scope_hint("plain_user"), None); + assert_eq!(validated_scope_hint("scope_"), None); + assert_eq!(validated_scope_hint("scope_a/b"), None); + assert_eq!(validated_scope_hint("scope_bad space"), None); + assert_eq!(validated_scope_hint("scope_bad.dot"), None); + + let oversized = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN + 1)); + assert_eq!(validated_scope_hint(&oversized), None); +} + #[test] fn unknown_dc_log_path_sanitizer_rejects_parent_traversal_inputs() { assert!( @@ -1207,3 +1245,80 @@ async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea tg_accept_task.abort(); let _ = tg_accept_task.await; } + +#[test] +fn prefer_v6_override_matrix_prefers_matching_family_then_degrades_safely() { + let dc_idx: i16 = 2; + + let mut cfg_a = ProxyConfig::default(); + cfg_a.network.prefer = 6; + cfg_a.network.ipv6 = Some(true); + cfg_a.dc_overrides.insert( + dc_idx.to_string(), + vec![ + "203.0.113.90:443".to_string(), + "[2001:db8::90]:443".to_string(), + ], + ); + let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve"); + assert!(a.is_ipv6(), "prefer_v6 should choose v6 override when present"); + + let mut cfg_b = ProxyConfig::default(); + cfg_b.network.prefer = 6; + cfg_b.network.ipv6 = Some(true); + cfg_b.dc_overrides + .insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]); + let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve"); + assert!(b.is_ipv4(), "when no v6 override exists, v4 override must be used"); + + let mut cfg_c = ProxyConfig::default(); + cfg_c.network.prefer = 6; + cfg_c.network.ipv6 = Some(true); + let c = get_dc_addr_static(dc_idx, &cfg_c).expect("table fallback must resolve"); + assert_eq!( + c, + SocketAddr::new(TG_DATACENTERS_V6[(dc_idx as usize) - 1], TG_DATACENTER_PORT), + "without overrides, prefer_v6 path must resolve from static v6 datacenter table" + ); +} + 
+#[test] +fn prefer_v6_override_matrix_ignores_invalid_entries_and_keeps_fail_closed_fallback() { + let dc_idx: i16 = 3; + + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.dc_overrides.insert( + dc_idx.to_string(), + vec![ + "not-an-addr".to_string(), + "also:bad".to_string(), + "203.0.113.55:443".to_string(), + ], + ); + + let addr = get_dc_addr_static(dc_idx, &cfg).expect("at least one valid override must keep resolution alive"); + assert_eq!(addr, "203.0.113.55:443".parse::().unwrap()); +} + +#[test] +fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() { + for idx in 1..=5i16 { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.dc_overrides.insert( + idx.to_string(), + vec![ + format!("203.0.113.{}:443", 100 + idx), + format!("[2001:db8::{}]:443", 100 + idx), + ], + ); + + let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve"); + let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve"); + assert_eq!(first, second, "override resolution must stay deterministic for dc {idx}"); + assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); + } +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index dc83ccc..6886e65 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -14,7 +14,7 @@ use dashmap::DashMap; use dashmap::mapref::entry::Entry; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace}; -use zeroize::Zeroize; +use zeroize::{Zeroize, Zeroizing}; use crate::crypto::{sha256, AesCtr, SecureRandom}; use rand::Rng; @@ -28,6 +28,10 @@ use crate::tls_front::{TlsFrontCache, emulator}; const ACCESS_SECRET_BYTES: usize = 16; static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); +#[cfg(test)] +const WARNED_SECRET_MAX_ENTRIES: usize = 64; +#[cfg(not(test))] +const WARNED_SECRET_MAX_ENTRIES: usize = 1_024; const 
AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60; #[cfg(test)] @@ -406,7 +410,13 @@ fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Opti let key = (name.to_string(), reason.to_string()); let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new())); let should_warn = match warned.lock() { - Ok(mut guard) => guard.insert(key), + Ok(mut guard) => { + if !guard.contains(&key) && guard.len() >= WARNED_SECRET_MAX_ENTRIES { + false + } else { + guard.insert(key) + } + } Err(_) => true, }; @@ -575,6 +585,7 @@ where } if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { + auth_probe_record_failure(peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; @@ -736,9 +747,13 @@ where R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, { + let handshake_fingerprint = { + let digest = sha256(&handshake[..8]); + hex::encode(&digest[..4]) + }; trace!( peer = %peer, - handshake_head = %hex::encode(&handshake[..8]), + handshake_fingerprint = %handshake_fingerprint, "MTProto handshake prefix" ); @@ -760,7 +775,7 @@ where let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; - let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); let dec_key = sha256(&dec_key_input); @@ -796,7 +811,7 @@ where let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; - let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); let enc_key = sha256(&enc_key_input); @@ 
-885,7 +900,7 @@ pub fn generate_tg_nonce( nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); if fast_mode { - let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN); + let mut key_iv = Zeroizing::new(Vec::with_capacity(KEY_LEN + IV_LEN)); key_iv.extend_from_slice(client_enc_key); key_iv.extend_from_slice(&client_enc_iv.to_be_bytes()); key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce @@ -893,7 +908,7 @@ pub fn generate_tg_nonce( } let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; - let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::>()); let mut tg_enc_key = [0u8; 32]; tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); @@ -914,7 +929,7 @@ pub fn generate_tg_nonce( /// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, AesCtr, AesCtr) { let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; - let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::>()); let mut enc_key = [0u8; 32]; enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); @@ -935,6 +950,8 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, A result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); let decryptor = AesCtr::new(&dec_key, dec_iv); + enc_key.zeroize(); + dec_key.zeroize(); (result, encryptor, decryptor) } @@ -950,6 +967,10 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { #[path = "handshake_security_tests.rs"] mod security_tests; +#[cfg(test)] +#[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"] +mod gap_short_tls_probe_throttle_security_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. 
A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. diff --git a/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs b/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs new file mode 100644 index 0000000..2ea32bc --- /dev/null +++ b/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs @@ -0,0 +1,50 @@ +use super::*; +use crate::stats::ReplayChecker; +use std::net::SocketAddr; +use std::time::Duration; + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg +} + +#[tokio::test] +async fn gap_t01_short_tls_probe_burst_is_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &too_short, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + } + + assert!( + auth_probe_fail_streak_for_testing(peer.ip()) + .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS), + "short TLS probe bursts must increase auth-probe fail streak" + ); +} diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 7af7192..c93d18e 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1345,6 +1345,29 @@ fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() { ); } +#[test] +fn invalid_secret_warning_cache_is_bounded() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + for idx in 0..(WARNED_SECRET_MAX_ENTRIES + 32) { + let user = format!("warned_user_{idx}"); + warn_invalid_secret_once(&user, "invalid_length", ACCESS_SECRET_BYTES, Some(idx)); + } + + let warned = INVALID_SECRET_WARNED + .get() + .expect("warned set must be initialized"); + let guard = warned.lock().expect("warned set lock must be available"); + assert_eq!( + guard.len(), + WARNED_SECRET_MAX_ENTRIES, + "invalid-secret warning cache must remain bounded" + ); +} + #[tokio::test] async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() { let _guard = auth_probe_test_lock() @@ -1921,6 +1944,165 @@ fn auth_probe_eviction_offset_changes_with_time_component() { ); } + +#[test] +fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 64; + + let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250)); + state.insert( + sentinel, + AuthProbeState { + fail_streak: 25, + blocked_until: now, + last_seen: now - Duration::from_secs(30), + }, + ); + + for idx in 0..(initial - 1) { + 
let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 20, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_millis((idx % 1024) as u64), + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40)); + auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1)); + + assert!(state.get(&newcomer).is_some(), "newcomer must still be tracked under over-cap pressure"); + assert!( + state.get(&sentinel).is_some(), + "high fail-streak sentinel must survive round-limited eviction" + ); + assert!( + auth_probe_saturation_is_throttled_at_for_testing(now + Duration::from_millis(1)), + "round-limited over-cap path must activate saturation throttle marker" + ); +} + +#[test] +fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200)); + state.insert( + sentinel, + AuthProbeState { + fail_streak: 30, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(60), + }, + ); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 80) { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 22, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 2048) as u64), + }, + ); + } + + for step in 0..512usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 2, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + auth_probe_record_failure_with_state(&state, newcomer, base_now + Duration::from_millis(step as u64 + 1)); + + assert!( + state.get(&sentinel).is_some(), + "step {step}: high-threat sentinel 
must not be starved by newcomer churn" + ); + assert!(state.get(&newcomer).is_some(), "step {step}: newcomer must be tracked"); + } +} + +#[test] +fn light_fuzz_auth_probe_overcap_eviction_prefers_less_threatening_entries() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + let mut s: u64 = 0xBADC_0FFE_EE11_2233; + + for round in 0..128usize { + let state = DashMap::new(); + let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 180)); + state.insert( + sentinel, + AuthProbeState { + fail_streak: 18, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + (s & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_millis((s & 1023) as u64), + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 10, ((round >> 8) & 0xff) as u8, (round & 0xff) as u8)); + auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(round as u64 + 1)); + + assert!(state.get(&newcomer).is_some(), "round {round}: newcomer should be tracked"); + assert!( + state.get(&sentinel).is_some(), + "round {round}: high fail-streak sentinel should survive mixed low-threat pool" + ); + } +} #[test] fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() { let mut rng = StdRng::seed_from_u64(0xA11CE5EED); diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index b0f6985..030fb2f 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -181,6 +181,7 @@ where }; if let Some(header) = proxy_header { if !write_proxy_header_with_timeout(&mut mask_write, &header).await { + wait_mask_outcome_budget(outcome_started).await; return; } } @@ -246,6 
+247,7 @@ where let (mask_read, mut mask_write) = stream.into_split(); if let Some(header) = proxy_header { if !write_proxy_header_with_timeout(&mut mask_write, &header).await { + wait_mask_outcome_budget(outcome_started).await; return; } } diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 1cee108..893b3e5 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -317,6 +317,254 @@ async fn backend_reachable_fast_response_waits_mask_outcome_budget() { accept_task.await.unwrap(); } +#[tokio::test] +async fn proxy_header_write_error_on_tcp_path_still_honors_coarse_outcome_budget() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /proxy-hdr-err HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.88:42430".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task) + .await + .expect_err("proxy-header write error path should remain inside coarse masking 
budget window"); + assert!( + started.elapsed() >= Duration::from_millis(35), + "proxy-header write error path should avoid immediate-return timing signature" + ); + + accept_task.await.unwrap(); +} + +#[cfg(unix)] +#[tokio::test] +async fn proxy_header_write_error_on_unix_path_still_honors_coarse_outcome_budget() { + let sock_path = format!( + "/tmp/telemt-mask-unix-hdr-err-{}-{}.sock", + std::process::id(), + rand::random::() + ); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix-hdr-err HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.89:42431".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task) + .await + .expect_err("unix proxy-header write error path should remain inside coarse masking budget window"); + assert!( + started.elapsed() >= Duration::from_millis(35), + "unix proxy-header write error path should avoid immediate-return timing signature" + ); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[cfg(unix)] +#[tokio::test] +async fn 
unix_socket_proxy_protocol_v1_header_is_sent_before_probe() { + let sock_path = format!( + "/tmp/telemt-mask-unix-v1-{}-{}.sock", + std::process::id(), + rand::random::() + ); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix-v1 HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line).unwrap(); + assert!(header_text.starts_with("PROXY "), "must start with PROXY prefix"); + assert!(header_text.ends_with("\r\n"), "v1 header must end with CRLF"); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.51:51010".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut 
observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[cfg(unix)] +#[tokio::test] +async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() { + let sock_path = format!( + "/tmp/telemt-mask-unix-v2-{}-{}.sock", + std::process::id(), + rand::random::() + ); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix-v2 HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut sig = [0u8; 12]; + stream.read_exact(&mut sig).await.unwrap(); + assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n", "v2 signature must match spec"); + + let mut fixed = [0u8; 4]; + stream.read_exact(&mut fixed).await.unwrap(); + let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize; + let mut addr_block = vec![0u8; addr_len]; + stream.read_exact(&mut addr_block).await.unwrap(); + + let mut received_probe = vec![0u8; probe.len()]; + stream.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.52:51011".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + 
client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + #[tokio::test] async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() { let mut config = ProxyConfig::default(); diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 1dbbbfd..7298cb4 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -44,6 +44,10 @@ const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5); const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; +#[cfg(test)] +const QUOTA_USER_LOCKS_MAX: usize = 64; +#[cfg(not(test))] +const QUOTA_USER_LOCKS_MAX: usize = 4_096; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); static DESYNC_HASHER: OnceLock = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); @@ -336,6 +340,14 @@ fn quota_user_lock(user: &str) -> Arc> { return Arc::clone(existing.value()); } + if locks.len() >= QUOTA_USER_LOCKS_MAX { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } + + if locks.len() >= QUOTA_USER_LOCKS_MAX { + return Arc::new(AsyncMutex::new(())); + } + let created = Arc::new(AsyncMutex::new(())); match locks.entry(user.to_string()) { dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), @@ -405,7 +417,7 @@ where ); let (conn_id, me_rx) = me_pool.registry().register().await; - let trace_id = conn_id; + let trace_id = session_id; let bytes_me2c = Arc::new(AtomicU64::new(0)); let mut forensics = RelayForensicsState { trace_id, diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index 4dd1178..896e465 100644 --- 
a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -15,7 +15,9 @@ use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::thread; +use tokio::sync::Barrier; use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, timeout}; @@ -233,6 +235,219 @@ fn desync_dedup_cache_is_bounded() { ); } +#[test] +fn quota_user_lock_cache_reuses_entry_for_same_user() { + let a = quota_user_lock("quota-user-a"); + let b = quota_user_lock("quota-user-a"); + assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock"); +} + +#[test] +fn quota_user_lock_cache_is_bounded_under_unique_churn() { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + for idx in 0..(QUOTA_USER_LOCKS_MAX + 128) { + let user = format!("quota-user-{idx}"); + let lock = quota_user_lock(&user); + drop(lock); + } + + assert!( + map.len() <= QUOTA_USER_LOCKS_MAX, + "quota lock cache must stay within configured bound" + ); +} + +#[test] +fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + let user = format!("quota-held-user-{idx}"); + retained.push(quota_user_lock(&user)); + } + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "precondition: cache should be full before overflow acquisition" + ); + + let overflow_a = quota_user_lock("quota-overflow-user"); + let overflow_b = quota_user_lock("quota-overflow-user"); + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow acquisition must not grow cache past hard limit" + ); + assert!( + map.get("quota-overflow-user").is_none(), + "overflow path should not cache new user lock when map is saturated and all entries are retained" + ); + assert!( + 
!Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user lock should be ephemeral under saturation to preserve bounded cache size" + ); + + drop(retained); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_quota_race_under_lock_cache_saturation_still_allows_only_one_winner() { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + let user = format!("quota-saturated-user-{idx}"); + retained.push(quota_user_lock(&user)); + } + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "precondition: cache must be saturated for overflow-user race test" + ); + + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let user = "gap-t04-saturated-lock-race-user"; + let barrier = Arc::new(Barrier::new(2)); + + let one = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x55, 9101, barrier.clone()); + let two = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x66, 9102, barrier); + let (r1, r2) = tokio::join!(one, two); + + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) + && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "both racers must resolve cleanly without unexpected errors" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), + "at least one racer must be quota-rejected even when lock cache is saturated" + ); + assert_eq!( + stats.get_user_total_octets(user), + 1, + "saturated lock cache must not permit double-success quota overshoot" + ); + + drop(retained); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_success() { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + let user = format!("quota-saturated-stress-holder-{idx}"); + retained.push(quota_user_lock(&user)); + } + + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + for round in 0..128u64 { + let user = format!("gap-t04-saturated-race-round-{round}"); + let barrier = Arc::new(Barrier::new(2)); + + let one = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x71, + 12_000 + round, + barrier.clone(), + ); + let two = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x72, + 13_000 + round, + barrier, + ); + + let (r1, r2) = tokio::join!(one, two); + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) + && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "round {round}: racers must resolve cleanly" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), + "round {round}: at least one racer must be quota-rejected" + ); + assert_eq!( + stats.get_user_total_octets(&user), + 1, + "round {round}: saturated cache must still enforce exactly one forwarded byte" + ); + } + + drop(retained); +} + +#[test] +fn adversarial_forensics_trace_id_should_not_alias_conn_id() { + let now = Instant::now(); + let trace_id = 0x1122_3344_5566_7788; + let conn_id = 0x8877_6655_4433_2211; + let state = RelayForensicsState { + trace_id, + conn_id, + user: "trace-user".to_string(), + peer: "198.51.100.17:443".parse().unwrap(), + peer_hash: 0x8877_6655_4433_2211, + started_at: now, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + }; + + assert_ne!( + state.trace_id, state.conn_id, + "security expectation: trace correlation should be independent of connection identity" + ); + assert_eq!(state.trace_id, trace_id); + assert_eq!(state.conn_id, conn_id); +} + +#[tokio::test] +async fn abridged_ack_uses_big_endian_confirm_bytes_after_decryption() { + let (mut writer_side, reader_side) = duplex(8); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(reader_side, AesCtr::new(&key, iv), 8 * 1024); + + write_client_ack(&mut writer, ProtoTag::Abridged, 0x11_22_33_44) + .await + .expect("ack write must succeed"); + + let mut observed = [0u8; 4]; + writer_side + .read_exact(&mut observed) + .await + .expect("ack bytes must be readable"); + let mut decryptor = AesCtr::new(&key, iv); + let decrypted = decryptor.decrypt(&observed); + + assert_eq!( + decrypted, + 0x11_22_33_44u32.to_be_bytes(), + "abridged ACK should encode confirm bytes in big-endian order" + ); +} + #[test] fn desync_dedup_full_cache_churn_stays_suppressed() { let _guard = desync_dedup_test_lock() @@ -1707,6 +1922,150 @@ async fn middle_relay_cutover_midflight_releases_route_gauge() { drop(client_side); } +async fn run_quota_race_attempt( + stats: &Stats, + bytes_me2c: &AtomicU64, + user: &str, + payload: u8, + conn_id: 
u64, + barrier: Arc, +) -> Result { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + barrier.wait().await; + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![payload]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + stats, + user, + Some(1), + bytes_me2c, + conn_id, + false, + false, + ) + .await +} + +#[tokio::test] +async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(256); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let plaintext = vec![0x7f, 0xff, 0xff, 0xff]; + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Abridged, + 4096, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!(result.is_err(), "oversized abridged length must fail closed"); + assert_eq!(frame_counter, 0, "oversized frame must not be counted as accepted"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn deterministic_quota_race_exactly_one_succeeds_and_one_is_rejected() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let user = "gap-t04-race-user"; + let barrier = Arc::new(Barrier::new(2)); + + let f1 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x11, 5001, barrier.clone()); + let f2 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x22, 5002, barrier); + + let (r1, r2) = tokio::join!(f1, f2); + + assert!( + matches!(r1, Ok(_) | 
Err(ProxyError::DataQuotaExceeded { .. })), + "first racer must either finish or fail closed on quota" + ); + assert!( + matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "second racer must either finish or fail closed on quota" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), + "at least one racer must be quota-rejected" + ); + assert_eq!( + stats.get_user_total_octets(user), + 1, + "same-user race must forward/account exactly one payload byte" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_quota_race_bursts_never_allow_double_success_per_round() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + for round in 0..128u64 { + let user = format!("gap-t04-race-burst-{round}"); + let barrier = Arc::new(Barrier::new(2)); + + let one = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x33, + 6000 + round, + barrier.clone(), + ); + let two = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x44, + 7000 + round, + barrier, + ); + + let (r1, r2) = tokio::join!(one, two); + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) + && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "round {round}: racers must resolve cleanly without unexpected errors" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), + "round {round}: at least one racer must be quota-rejected" + ); + assert_eq!( + stats.get_user_total_octets(&user), + 1, + "round {round}: same-user total octets must remain exactly 1 (single forwarded winner)" + ); + } +} + #[tokio::test] async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { let session_count = 6usize; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 46a2b21..8b4c87f 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -208,6 +208,8 @@ struct StatsIo { user: String, quota_limit: Option, quota_exceeded: Arc, + quota_read_wake_scheduled: bool, + quota_write_wake_scheduled: bool, epoch: Instant, } @@ -230,6 +232,8 @@ impl StatsIo { user, quota_limit, quota_exceeded, + quota_read_wake_scheduled: false, + quota_write_wake_scheduled: false, epoch, } } @@ -293,9 +297,19 @@ impl AsyncRead for StatsIo { .then(|| quota_user_lock(&this.user)); let _quota_guard = if let Some(lock) = quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => Some(guard), + Ok(guard) => { + this.quota_read_wake_scheduled = false; + Some(guard) + } Err(_) => { - cx.waker().wake_by_ref(); + if !this.quota_read_wake_scheduled { + this.quota_read_wake_scheduled = true; + let waker = cx.waker().clone(); + tokio::task::spawn(async move { + tokio::task::yield_now().await; + waker.wake(); + }); + } return Poll::Pending; } } @@ -356,9 +370,19 @@ impl AsyncWrite for StatsIo { .then(|| quota_user_lock(&this.user)); let _quota_guard = if let Some(lock) = quota_lock.as_ref() { match lock.try_lock() { - Ok(guard) => Some(guard), + Ok(guard) => { + this.quota_write_wake_scheduled = false; + Some(guard) + } Err(_) => { - cx.waker().wake_by_ref(); + if !this.quota_write_wake_scheduled { + this.quota_write_wake_scheduled = true; + let waker = cx.waker().clone(); + tokio::task::spawn(async move { + tokio::task::yield_now().await; + waker.wake(); + }); + } return Poll::Pending; } } diff --git a/src/proxy/relay_security_tests.rs 
b/src/proxy/relay_security_tests.rs index 7b985cb..9ba8295 100644 --- a/src/proxy/relay_security_tests.rs +++ b/src/proxy/relay_security_tests.rs @@ -14,6 +14,176 @@ use tokio::io::{AsyncRead, ReadBuf}; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; use tokio::time::{Duration, timeout}; +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +#[tokio::test] +async fn quota_lock_contention_does_not_self_wake_pending_writer() { + let stats = Arc::new(Stats::new()); + let user = "quota-lock-contention-user"; + + let lock = super::quota_user_lock(user); + let _held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock before polling writer"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(poll.is_pending(), "writer must remain pending while lock is contended"); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "contended quota lock must not self-wake immediately and spin the executor" + ); +} + +#[tokio::test] +async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { + let stats = Arc::new(Stats::new()); + let user = "quota-lock-writer-liveness-user"; + + let lock = super::quota_user_lock(user); + let held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock 
before polling writer"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(first.is_pending(), "writer must remain pending while lock is contended"); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "deferred wake must not fire synchronously" + ); + + timeout(Duration::from_millis(50), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("contended writer must schedule a deferred wake in bounded time"); + let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes_after_first_yield >= 1, + "contended writer must schedule at least one deferred wake for liveness" + ); + + let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); + assert!(second.is_pending(), "writer remains pending while lock is still held"); + + for _ in 0..8 { + tokio::task::yield_now().await; + } + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + wakes_after_first_yield, + "writer contention should not schedule unbounded wake storms before lock acquisition" + ); + + drop(held_lock); + let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); + assert!(released.is_ready(), "writer must make progress once quota lock is released"); +} + +#[tokio::test] +async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { + let stats = Arc::new(Stats::new()); + let user = "quota-lock-read-liveness-user"; + + let lock = 
super::quota_user_lock(user); + let held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock before polling reader"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + + let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(first.is_pending(), "reader must remain pending while lock is contended"); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "read contention wake must not fire synchronously" + ); + + timeout(Duration::from_millis(50), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("read contention must schedule a deferred wake in bounded time"); + + drop(held_lock); + let mut buf_after_release = ReadBuf::new(&mut storage); + let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); + assert!(released.is_ready(), "reader must make progress once quota lock is released"); +} + #[tokio::test] async fn relay_bidirectional_enforces_live_user_quota() { let stats = Arc::new(Stats::new()); diff --git a/src/proxy/route_mode_security_tests.rs b/src/proxy/route_mode_security_tests.rs index e86d574..2926615 100644 --- a/src/proxy/route_mode_security_tests.rs +++ b/src/proxy/route_mode_security_tests.rs @@ -338,3 +338,69 @@ fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() { ); } } + +#[test] +fn cutover_stagger_delay_distribution_has_no_empty_buckets_under_sequential_sessions() { + let mut 
buckets = [0usize; 1000]; + let generation = 4242u64; + + for session_id in 0..250_000u64 { + let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize; + let idx = delay_ms - 1000; + buckets[idx] += 1; + } + + let empty = buckets.iter().filter(|&&count| count == 0).count(); + assert_eq!( + empty, 0, + "all 1000 delay buckets must be exercised to avoid cutover herd clustering" + ); +} + +#[test] +fn light_fuzz_cutover_stagger_delay_distribution_stays_reasonably_uniform() { + let mut buckets = [0usize; 1000]; + let mut s: u64 = 0x1BAD_B002_CAFE_F00D; + + for _ in 0..300_000usize { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let session_id = s; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let generation = s; + + let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize; + buckets[delay_ms - 1000] += 1; + } + + let min = *buckets.iter().min().unwrap_or(&0); + let max = *buckets.iter().max().unwrap_or(&0); + assert!(min > 0, "fuzzed distribution must not leave empty buckets"); + assert!( + max <= min.saturating_mul(3), + "bucket skew is too high for anti-herd staggering (max={max}, min={min})" + ); +} + +#[test] +fn stress_cutover_stagger_delay_distribution_remains_stable_across_generations() { + for generation in [0u64, 1, 7, 31, 255, 1024, u32::MAX as u64, u64::MAX - 1] { + let mut buckets = [0usize; 1000]; + for session_id in 0..100_000u64 { + let delay_ms = cutover_stagger_delay(session_id ^ 0x9E37_79B9, generation) + .as_millis() as usize; + buckets[delay_ms - 1000] += 1; + } + + let min = *buckets.iter().min().unwrap_or(&0); + let max = *buckets.iter().max().unwrap_or(&0); + assert!( + max <= min.saturating_mul(4).max(1), + "generation={generation}: distribution collapsed (max={max}, min={min})" + ); + } +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 3ad361f..3c79448 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -1508,9 +1508,11 @@ impl Stats { // ============= Replay Checker 
============= pub struct ReplayChecker { - shards: Vec<Mutex<ReplayShard>>, + handshake_shards: Vec<Mutex<ReplayShard>>, + tls_shards: Vec<Mutex<ReplayShard>>, shard_mask: usize, window: Duration, + tls_window: Duration, checks: AtomicU64, hits: AtomicU64, additions: AtomicU64, @@ -1587,19 +1589,24 @@ impl ReplayShard { impl ReplayChecker { pub fn new(total_capacity: usize, window: Duration) -> Self { + const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120); let num_shards = 64; let shard_capacity = (total_capacity / num_shards).max(1); let cap = NonZeroUsize::new(shard_capacity).unwrap(); - let mut shards = Vec::with_capacity(num_shards); + let mut handshake_shards = Vec::with_capacity(num_shards); + let mut tls_shards = Vec::with_capacity(num_shards); for _ in 0..num_shards { - shards.push(Mutex::new(ReplayShard::new(cap))); + handshake_shards.push(Mutex::new(ReplayShard::new(cap))); + tls_shards.push(Mutex::new(ReplayShard::new(cap))); } Self { - shards, + handshake_shards, + tls_shards, shard_mask: num_shards - 1, window, + tls_window: window.max(MIN_TLS_REPLAY_WINDOW), checks: AtomicU64::new(0), hits: AtomicU64::new(0), additions: AtomicU64::new(0), @@ -1613,46 +1620,60 @@ impl ReplayChecker { (hasher.finish() as usize) & self.shard_mask } - fn check_and_add_internal(&self, data: &[u8]) -> bool { + fn check_and_add_internal( + &self, + data: &[u8], + shards: &[Mutex<ReplayShard>], + window: Duration, + ) -> bool { self.checks.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); - let mut shard = self.shards[idx].lock(); + let mut shard = shards[idx].lock(); let now = Instant::now(); - let found = shard.check(data, now, self.window); + let found = shard.check(data, now, window); if found { self.hits.fetch_add(1, Ordering::Relaxed); } else { - shard.add(data, now, self.window); + shard.add(data, now, window); self.additions.fetch_add(1, Ordering::Relaxed); } found } - fn add_only(&self, data: &[u8]) { + fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) { + self.additions.fetch_add(1, 
Ordering::Relaxed); let idx = self.get_shard_idx(data); - let mut shard = self.shards[idx].lock(); - shard.add(data, Instant::now(), self.window); + let mut shard = shards[idx].lock(); + shard.add(data, Instant::now(), window); } pub fn check_and_add_handshake(&self, data: &[u8]) -> bool { - self.check_and_add_internal(data) + self.check_and_add_internal(data, &self.handshake_shards, self.window) } pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool { - self.check_and_add_internal(data) + self.check_and_add_internal(data, &self.tls_shards, self.tls_window) } // Compatibility helpers (non-atomic split operations) — prefer check_and_add_*. pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) } - pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) } + pub fn add_handshake(&self, data: &[u8]) { + self.add_only(data, &self.handshake_shards, self.window) + } pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) } - pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) } + pub fn add_tls_digest(&self, data: &[u8]) { + self.add_only(data, &self.tls_shards, self.tls_window) + } pub fn stats(&self) -> ReplayStats { let mut total_entries = 0; let mut total_queue_len = 0; - for shard in &self.shards { + for shard in &self.handshake_shards { + let s = shard.lock(); + total_entries += s.cache.len(); + total_queue_len += s.queue.len(); + } + for shard in &self.tls_shards { let s = shard.lock(); total_entries += s.cache.len(); total_queue_len += s.queue.len(); @@ -1665,7 +1686,7 @@ impl ReplayChecker { total_hits: self.hits.load(Ordering::Relaxed), total_additions: self.additions.load(Ordering::Relaxed), total_cleanups: self.cleanups.load(Ordering::Relaxed), - num_shards: self.shards.len(), + num_shards: self.handshake_shards.len() + self.tls_shards.len(), window_secs: self.window.as_secs(), } } @@ -1683,13 +1704,20 @@ impl ReplayChecker { let now = Instant::now(); let mut cleaned = 
0usize; - for shard_mutex in &self.shards { + for shard_mutex in &self.handshake_shards { let mut shard = shard_mutex.lock(); let before = shard.len(); shard.cleanup(now, self.window); let after = shard.len(); cleaned += before.saturating_sub(after); } + for shard_mutex in &self.tls_shards { + let mut shard = shard_mutex.lock(); + let before = shard.len(); + shard.cleanup(now, self.tls_window); + let after = shard.len(); + cleaned += before.saturating_sub(after); + } self.cleanups.fetch_add(1, Ordering::Relaxed); @@ -1815,7 +1843,7 @@ mod tests { fn test_replay_checker_many_keys() { let checker = ReplayChecker::new(10_000, Duration::from_secs(60)); for i in 0..500u32 { - checker.add_only(&i.to_le_bytes()); + checker.add_handshake(&i.to_le_bytes()); } for i in 0..500u32 { assert!(checker.check_handshake(&i.to_le_bytes())); @@ -1827,3 +1855,7 @@ mod tests { #[cfg(test)] #[path = "connection_lease_security_tests.rs"] mod connection_lease_security_tests; + +#[cfg(test)] +#[path = "replay_checker_security_tests.rs"] +mod replay_checker_security_tests; diff --git a/src/stats/replay_checker_security_tests.rs b/src/stats/replay_checker_security_tests.rs new file mode 100644 index 0000000..8e73204 --- /dev/null +++ b/src/stats/replay_checker_security_tests.rs @@ -0,0 +1,80 @@ +use super::*; +use std::time::Duration; + +#[test] +fn replay_checker_keeps_tls_and_handshake_domains_isolated_for_same_key() { + let checker = ReplayChecker::new(128, Duration::from_millis(20)); + let key = b"same-key-domain-separation"; + + assert!( + !checker.check_and_add_handshake(key), + "first handshake use should be fresh" + ); + assert!( + !checker.check_and_add_tls_digest(key), + "same bytes in TLS domain should still be fresh" + ); + + assert!( + checker.check_and_add_handshake(key), + "second handshake use should be replay-hit" + ); + assert!( + checker.check_and_add_tls_digest(key), + "second TLS use should be replay-hit independently" + ); +} + +#[test] +fn 
replay_checker_tls_window_is_clamped_beyond_small_handshake_window() { + let checker = ReplayChecker::new(128, Duration::from_millis(20)); + let handshake_key = b"short-window-handshake"; + let tls_key = b"short-window-tls"; + + assert!(!checker.check_and_add_handshake(handshake_key)); + assert!(!checker.check_and_add_tls_digest(tls_key)); + + std::thread::sleep(Duration::from_millis(80)); + + assert!( + !checker.check_and_add_handshake(handshake_key), + "handshake key should expire under short configured window" + ); + assert!( + checker.check_and_add_tls_digest(tls_key), + "TLS key should still be replay-hit because TLS window is clamped to a secure minimum" + ); +} + +#[test] +fn replay_checker_compat_add_paths_do_not_cross_pollute_domains() { + let checker = ReplayChecker::new(128, Duration::from_secs(1)); + let key = b"compat-domain-separation"; + + checker.add_handshake(key); + assert!( + checker.check_and_add_handshake(key), + "handshake add helper must populate handshake domain" + ); + assert!( + !checker.check_and_add_tls_digest(key), + "handshake add helper must not pollute TLS domain" + ); + + checker.add_tls_digest(key); + assert!( + checker.check_and_add_tls_digest(key), + "TLS add helper must populate TLS domain" + ); +} + +#[test] +fn replay_checker_stats_reflect_dual_shard_domains() { + let checker = ReplayChecker::new(128, Duration::from_secs(1)); + let stats = checker.stats(); + + assert_eq!( + stats.num_shards, 128, + "stats should expose both shard domains (handshake + TLS)" + ); +}