diff --git a/src/config/defaults.rs b/src/config/defaults.rs index a136539..73b12d8 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -73,6 +73,8 @@ pub(crate) fn default_replay_check_len() -> usize { } pub(crate) fn default_replay_window_secs() -> u64 { + // Keep replay cache TTL tight by default to reduce replay surface. + // Deployments with higher RTT or longer reconnect jitter can override this in config. 120 } diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 3a22214..5ff38ae 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -27,6 +27,10 @@ pub const TLS_DIGEST_POS: usize = 11; pub const TLS_DIGEST_HALF_LEN: usize = 16; /// Time skew limits for anti-replay (in seconds) +/// +/// The default window is intentionally narrow to reduce replay acceptance. +/// Operators with known clock-drifted clients should tune deployment config +/// (for example replay-window policy) to match their environment. pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after /// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. @@ -316,7 +320,14 @@ pub fn validate_tls_handshake_with_replay_window( }; let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); - let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32); + // Boot-time bypass and ignore_time_skew serve different compatibility paths. + // When skew checks are disabled, force boot-time cap to zero to prevent + // accidental future coupling of boot-time logic into the ignore-skew path. + let boot_time_cap_secs = if ignore_time_skew { + 0 + } else { + BOOT_TIME_MAX_SECS.min(replay_window_u32) + }; validate_tls_handshake_at_time_with_boot_cap( handshake, @@ -411,7 +422,7 @@ fn validate_tls_handshake_at_time_with_boot_cap( if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) // This is a quirk in some clients that use uptime instead of real time - let is_boot_time = timestamp < boot_time_cap_secs; + let is_boot_time = boot_time_cap_secs > 0 && timestamp < boot_time_cap_secs; if !is_boot_time { let time_diff = now - i64::from(timestamp); if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { @@ -705,10 +716,10 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { return false; } - // TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0) + // TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303. first_bytes[0] == TLS_RECORD_HANDSHAKE && first_bytes[1] == 0x03 - && first_bytes[2] == 0x01 + && (first_bytes[2] == 0x01 || first_bytes[2] == 0x03) } /// Parse TLS record header, returns (record_type, length) diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index bfc8f0d..74baa2f 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -731,6 +731,246 @@ fn replay_window_cap_still_allows_small_boot_timestamp() { ); } +#[test] +fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() { + let secret = b"ignore_skew_boot_cap_decouple_test"; + let ts: u32 = 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0); + let cap_nonzero = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_MAX_SECS); + + assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC"); + assert!( + cap_nonzero.is_some(), + "ignore_time_skew path must not depend on boot-time cap" + ); + + let a = cap_zero.unwrap(); + let b = cap_nonzero.unwrap(); + assert_eq!(a.user, b.user); + assert_eq!(a.timestamp, b.timestamp); +} + +#[test] +fn adversarial_small_boot_timestamp_matrix_rejected_when_boot_cap_forced_zero() { + let secret = b"boot_cap_zero_matrix_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for ts in 0u32..1024u32 { + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "boot cap=0 must reject timestamp {ts} when skew checks are active" + ); + } +} + +#[test] +fn light_fuzz_boot_cap_zero_rejects_small_timestamp_space() { + let secret = b"boot_cap_zero_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = (s as u32) % 2048; + + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "fuzzed boot-range timestamp {ts} must be rejected when cap=0" + ); + } +} + +#[test] +fn stress_boot_cap_zero_rejection_is_deterministic_under_high_iteration_count() { + let secret = b"boot_cap_zero_stress_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for i in 0u32..20_000u32 { + let ts = i % 4096; + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "iteration {i}: timestamp {ts} must be rejected with cap=0" + ); + } +} + +#[test] +fn replay_window_one_allows_only_zero_timestamp_boot_bypass() { + let secret = b"replay_window_one_boot_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts0 = make_valid_tls_handshake(secret, 0); + let ts1 = make_valid_tls_handshake(secret, 1); + + assert!( + validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 1).is_some(), + "replay_window=1 must allow timestamp 0 via boot-time compatibility" + ); + assert!( + validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 1).is_none(), + "replay_window=1 must reject timestamp 1 on normal wall-clock systems" + ); +} + +#[test] +fn replay_window_two_allows_ts0_ts1_but_rejects_ts2() { + let secret = b"replay_window_two_boot_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts0 = make_valid_tls_handshake(secret, 0); + let ts1 = make_valid_tls_handshake(secret, 1); + let ts2 = make_valid_tls_handshake(secret, 2); + + assert!(validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 2).is_some()); + assert!(validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 2).is_some()); + assert!( + validate_tls_handshake_with_replay_window(&ts2, &secrets, false, 2).is_none(), + "timestamp equal to replay-window cap must not use boot-time bypass" + ); +} + +#[test] +fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disabled() { + let secret = b"skew_boundary_matrix_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for offset in -1500i64..=1500i64 { + let ts_i64 = now - offset; + let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix"); + let h = make_valid_tls_handshake(secret, ts); + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset); + assert_eq!( + accepted, expected, + "offset {offset} must match inclusive skew window when boot bypass is disabled" + ); + } +} + +#[test] +fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() { + let secret = b"skew_outside_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0x0123_4567_89AB_CDEF; + + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let magnitude = 1300i64 + ((s % 2000u64) as i64); + let sign = if (s & 1) == 0 { 1i64 } else { -1i64 }; + let offset = sign * magnitude; + let ts_i64 = now - offset; + let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test"); + + let h = make_valid_tls_handshake(secret, ts); + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + assert!( + !accepted, + "offset {offset} must be rejected outside strict skew window" + ); + } +} + +#[test] +fn stress_boot_disabled_validation_matches_time_diff_oracle() { + let secret = b"boot_disabled_oracle_stress_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0xBADC_0FFE_EE11_2233; + + for _ in 0..25_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = s as u32; + let h = make_valid_tls_handshake(secret, ts); + + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + let time_diff = now - i64::from(ts); + let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff); + assert_eq!( + accepted, expected, + "boot-disabled validation must match pure time-diff oracle" + ); + } +} + +#[test] +fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() { + let now: i64 = 1_700_000_000; + let target_secret = b"target_user_secret"; + let target_ts = (now - 1) as u32; + let handshake = make_valid_tls_handshake(target_secret, target_ts); + + let mut secrets = Vec::new(); + for i in 0..512u32 { + secrets.push((format!("noise-{i}"), format!("noise-secret-{i}").into_bytes())); + } + secrets.push(("target-user".to_string(), target_secret.to_vec())); + + let result = validate_tls_handshake_at_time_with_boot_cap(&handshake, &secrets, false, now, 0) + .expect("matching user should validate within strict skew window"); + assert_eq!(result.user, "target-user"); +} + +#[test] +fn light_fuzz_ignore_time_skew_accepts_wide_timestamp_range_with_valid_hmac() { + let secret = b"ignore_skew_fuzz_accept_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let mut s: u64 = 0xC0FF_EE11_2233_4455; + + for _ in 0..2048 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = s as u32; + + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_with_replay_window(&h, &secrets, true, 60); + assert!( + result.is_some(), + "ignore_time_skew=true must accept valid HMAC for arbitrary timestamp" + ); + } +} + +#[test] +fn light_fuzz_small_replay_window_rejects_far_timestamps_when_skew_enabled() { + let secret = b"replay_window_reject_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + for ts in 300u32..=1323u32 { + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, 0, 300); + assert!( + result.is_none(), + "with skew checks enabled and boot cap=300, timestamp >=300 at now=0 must be rejected" + ); + } +} + // ------------------------------------------------------------------ // Extreme timestamp values // ------------------------------------------------------------------ @@ -897,7 +1137,9 @@ fn first_matching_user_wins_over_later_duplicate_secret() { #[test] fn test_is_tls_handshake() { assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x03, 0x02, 0x00])); assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); assert!(!is_tls_handshake(&[0x16, 0x03])); @@ -1502,3 +1744,83 @@ fn server_hello_new_session_ticket_count_matches_configuration() { "response must contain one main application record plus configured ticket-like tail records" ); } + +#[test] +fn exhaustive_tls_minor_version_classification_matches_policy() { + for minor in 0u8..=u8::MAX { + let first = [TLS_RECORD_HANDSHAKE, 0x03, minor]; + let expected = minor == 0x01 || minor == 0x03; + assert_eq!( + is_tls_handshake(&first), + expected, + "minor version {minor:#04x} classification mismatch" + ); + } +} + +#[test] +fn light_fuzz_tls_header_classifier_and_parser_policy_consistency() { + // Deterministic xorshift state keeps this fuzz test reproducible. + let mut s: u64 = 0x9E37_79B9_AA95_5A5D; + + for _ in 0..10_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let header = [ + (s & 0xff) as u8, + ((s >> 8) & 0xff) as u8, + ((s >> 16) & 0xff) as u8, + ((s >> 24) & 0xff) as u8, + ((s >> 32) & 0xff) as u8, + ]; + + let classified = is_tls_handshake(&header[..3]); + let expected_classified = header[0] == TLS_RECORD_HANDSHAKE + && header[1] == 0x03 + && (header[2] == 0x01 || header[2] == 0x03); + assert_eq!( + classified, + expected_classified, + "classifier policy mismatch for header {header:02x?}" + ); + + let parsed = parse_tls_record_header(&header); + let expected_parsed = header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]); + assert_eq!( + parsed.is_some(), + expected_parsed, + "parser policy mismatch for header {header:02x?}" + ); + } +} + +#[test] +fn stress_random_noise_handshakes_never_authenticate() { + let secret = b"stress_noise_secret"; + let secrets = vec![("noise-user".to_string(), secret.to_vec())]; + + // Deterministic xorshift state keeps this stress test reproducible. + let mut s: u64 = 0xD1B5_4A32_9C6E_77F1; + + for _ in 0..5_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let len = 1 + ((s as usize) % 196); + let mut buf = vec![0u8; len]; + for b in &mut buf { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + *b = (s & 0xff) as u8; + } + + assert!( + validate_tls_handshake(&buf, &secrets, true).is_none(), + "random noise must never authenticate" + ); + } +} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 5d32e34..199f775 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -51,9 +51,9 @@ impl UserConnectionReservation { if !self.active { return; } + self.ip_tracker.remove_ip(&self.user, self.ip).await; self.active = false; self.stats.decrement_user_curr_connects(&self.user); - self.ip_tracker.remove_ip(&self.user, self.ip).await; } } @@ -111,7 +111,19 @@ use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; fn beobachten_ttl(config: &ProxyConfig) -> Duration { - Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)) + let minutes = config.general.beobachten_minutes; + if minutes == 0 { + static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock = OnceLock::new(); + let warned = BEOBACHTEN_ZERO_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + "general.beobachten_minutes=0 is insecure because entries expire immediately; forcing minimum TTL to 1 minute" + ); + } + return Duration::from_secs(60); + } + + Duration::from_secs(minutes.saturating_mul(60)) } fn record_beobachten_class( @@ -494,7 +506,6 @@ impl RunningClientHandler { pub async fn run(self) -> Result<()> { self.stats.increment_connects_all(); let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, "New connection"); if let Err(e) = configure_client_socket( @@ -625,7 +636,6 @@ impl RunningClientHandler { let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); @@ -638,7 +648,6 @@ impl RunningClientHandler { async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -762,7 +771,6 @@ impl RunningClientHandler { async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); @@ -1032,7 +1040,10 @@ impl RunningClientHandler { } match ip_tracker.check_and_add(user, peer_addr.ip()).await { - Ok(()) => {} + Ok(()) => { + ip_tracker.remove_ip(user, peer_addr.ip()).await; + stats.decrement_user_curr_connects(user); + } Err(reason) => { stats.decrement_user_curr_connects(user); warn!( diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 6ca2d4b..6b236aa 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -361,6 +361,93 @@ async fn short_tls_probe_is_masked_through_client_pipeline() { accept_task.await.unwrap(); } +#[tokio::test] +async fn tls12_record_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x03, 0x00, 0x10]; + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.78:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + #[tokio::test] async fn handle_client_stream_increments_connects_all_exactly_once() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -1381,6 +1468,34 @@ fn non_eof_error_is_classified_as_other() { ); } +#[test] +fn beobachten_ttl_zero_minutes_is_floored_to_one_minute() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 0; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(60), + "beobachten_minutes=0 must be fail-closed to a one-minute minimum TTL" + ); +} + +#[test] +fn beobachten_ttl_positive_minutes_remain_unchanged() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 7; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(7 * 60), + "configured positive beobacten TTL must be preserved" + ); +} + #[tokio::test] async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { let mut config = ProxyConfig::default(); @@ -1449,6 +1564,83 @@ async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); } +#[tokio::test] +async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { + let user = "check-helper-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + + let first = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(first.is_ok(), "first check-only limit validation must succeed"); + + let second = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(second.is_ok(), "second check-only validation must not fail from leaked state"); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn stress_check_user_limits_static_success_never_leaks_state() { + let user = "check-helper-stress-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + + for i in 0..4096u16 { + let peer_addr = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 110, (i % 250) as u8 + 1)), + 40000 + (i % 1024), + ); + + let result = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(result.is_ok(), "check-only helper must remain leak-free under stress"); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "stress success loop must not leak user connection counters" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "stress success loop must not leak active IP reservations" + ); +} + #[tokio::test] async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak() { let user = "rollback-storm-user"; @@ -1678,6 +1870,249 @@ async fn explicit_release_allows_immediate_cross_ip_reacquire_under_limit() { assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); } +#[tokio::test] +async fn release_abort_storm_does_not_leak_user_or_ip_reservations() { + const ATTEMPTS: usize = 256; + + let user = "release-abort-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), ATTEMPTS + 16); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + for idx in 0..ATTEMPTS { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 114, (idx % 250 + 1) as u8)), + 52000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed in abort storm"); + + let release_task = tokio::spawn(async move { + reservation.release().await; + }); + release_task.abort(); + let _ = release_task.await; + } + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("release abort storm must not leak user slots or active IP entries"); +} + +#[tokio::test] +async fn release_abort_loop_preserves_immediate_same_ip_reacquire() { + const ITERATIONS: usize = 128; + + let user = "release-abort-reacquire-user"; + let peer: SocketAddr = "198.51.100.246:53001".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + for _ in 0..ITERATIONS { + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("baseline acquisition must succeed"); + + let release_task = tokio::spawn(async move { + reservation.release().await; + }); + release_task.abort(); + let _ = release_task.await; + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("aborted release must still converge to zero footprint"); + } + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("same-ip reacquire must succeed after repeated abort-release churn"); + reservation.release().await; +} + +#[tokio::test] +async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() { + const RESERVATIONS: usize = 192; + + let user = "mixed-wave-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), RESERVATIONS + 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut reservations = Vec::with_capacity(RESERVATIONS); + for idx in 0..RESERVATIONS { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 115, (idx % 250 + 1) as u8)), + 54000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("mixed-wave acquisition must succeed"); + reservations.push(reservation); + } + + let mut seed: u64 = 0xDEAD_BEEF_CAFE_BA5E; + let mut join_set = tokio::task::JoinSet::new(); + for reservation in reservations { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + match seed % 3 { + 0 => { + join_set.spawn(async move { + reservation.release().await; + }); + } + 1 => { + drop(reservation); + } + _ => { + let task = tokio::spawn(async move { + reservation.release().await; + }); + task.abort(); + let _ = task.await; + } + } + } + + while let Some(result) = join_set.join_next().await { + result.expect("release subtask must not panic"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("mixed release/drop/abort wave must converge to zero footprint"); +} + +#[tokio::test] +async fn parallel_users_abort_release_isolation_preserves_independent_cleanup() { + let user_a = "abort-isolation-a"; + let user_b = "abort-isolation-b"; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user_a.to_string(), 64); + config.access.user_max_tcp_conns.insert(user_b.to_string(), 64); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for idx in 0..64usize { + let user = if idx % 2 == 0 { user_a } else { user_b }; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 18, 0, (idx % 250 + 1) as u8)), + 55000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("parallel-user acquisition must succeed"); + + tasks.spawn(async move { + let t = tokio::spawn(async move { + reservation.release().await; + }); + t.abort(); + let _ = t.await; + }); + } + + while let Some(result) = tasks.join_next().await { + result.expect("parallel-user abort task must not panic"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user_a) == 0 + && stats.get_user_curr_connects(user_b) == 0 + && ip_tracker.get_active_ip_count(user_a).await == 0 + && ip_tracker.get_active_ip_count(user_b).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("parallel users must cleanup independently under abort churn"); +} + #[tokio::test] async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() { const RESERVATIONS: usize = 64; @@ -2301,16 +2736,24 @@ async fn atomic_limit_gate_allows_only_one_concurrent_acquire() { IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (i + 1) as u8)), 30000 + i, ); - RunningClientHandler::check_user_limits_static("user", &config, &stats, peer, &ip_tracker) - .await - .is_ok() + RunningClientHandler::acquire_user_connection_reservation_static( + "user", + &config, + stats, + peer, + ip_tracker, + ) + .await + .ok() }); } let mut successes = 0u64; + let mut held_reservations = Vec::new(); while let Some(joined) = tasks.join_next().await { - if joined.unwrap() { + if let Some(reservation) = joined.unwrap() { successes += 1; + held_reservations.push(reservation); } } @@ -2319,6 +2762,8 @@ async fn atomic_limit_gate_allows_only_one_concurrent_acquire() { "exactly one concurrent acquire must pass for a limit=1 user" ); assert_eq!(stats.get_user_curr_connects("user"), 1); + + drop(held_reservations); } #[tokio::test] diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index d7d5f64..72a5c91 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -1,6 +1,7 @@ use std::fs::OpenOptions; use std::io::Write; use std::net::SocketAddr; +use std::path::{Component, Path, PathBuf}; use std::sync::Arc; use std::collections::HashSet; use std::sync::{Mutex, OnceLock}; @@ -32,6 +33,10 @@ static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); // deterministic under parallel execution. fn should_log_unknown_dc(dc_idx: i16) -> bool { let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new())); + should_log_unknown_dc_with_set(set, dc_idx) +} + +fn should_log_unknown_dc_with_set(set: &Mutex>, dc_idx: i16) -> bool { match set.lock() { Ok(mut guard) => { if guard.contains(&dc_idx) { @@ -42,12 +47,39 @@ fn should_log_unknown_dc(dc_idx: i16) -> bool { } guard.insert(dc_idx) } - // If the lock is poisoned, keep logging rather than silently dropping - // operator-visible diagnostics. - Err(_) => true, + // Fail closed on poisoned state to avoid unbounded blocking log writes. + Err(_) => false, } } +fn sanitize_unknown_dc_log_path(path: &str) -> Option { + let candidate = Path::new(path); + if candidate.as_os_str().is_empty() { + return None; + } + if candidate + .components() + .any(|component| matches!(component, Component::ParentDir)) + { + return None; + } + + let cwd = std::env::current_dir().ok()?; + let file_name = candidate.file_name()?; + let parent = candidate.parent().unwrap_or_else(|| Path::new(".")); + let parent_path = if parent.is_absolute() { + parent.to_path_buf() + } else { + cwd.join(parent) + }; + let canonical_parent = parent_path.canonicalize().ok()?; + if !canonical_parent.is_dir() { + return None; + } + + Some(canonical_parent.join(file_name)) +} + #[cfg(test)] fn clear_unknown_dc_log_cache_for_testing() { if let Some(set) = LOGGED_UNKNOWN_DCS.get() @@ -200,12 +232,15 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { && should_log_unknown_dc(dc_idx) && let Ok(handle) = tokio::runtime::Handle::try_current() { - let path = path.clone(); - handle.spawn_blocking(move || { - if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { - let _ = writeln!(file, "dc_idx={dc_idx}"); - } - }); + if let Some(path) = sanitize_unknown_dc_log_path(path) { + handle.spawn_blocking(move || { + if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { + let _ = writeln!(file, "dc_idx={dc_idx}"); + } + }); + } else { + warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path"); + } } } diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index 7390fb8..d967da3 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -6,7 +6,10 @@ use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; +use std::fs; +use std::path::Path; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; use tokio::io::duplex; use tokio::net::TcpListener; @@ -29,6 +32,10 @@ where CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +fn nonempty_line_count(text: &str) -> usize { + text.lines().filter(|line| !line.trim().is_empty()).count() +} + #[test] fn unknown_dc_log_is_deduplicated_per_dc_idx() { let _guard = unknown_dc_test_lock() @@ -67,6 +74,431 @@ fn unknown_dc_log_respects_distinct_limit() { ); } +#[test] +fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() { + let poisoned = Arc::new(std::sync::Mutex::new(std::collections::HashSet::::new())); + let poisoned_for_thread = poisoned.clone(); + + let _ = std::thread::spawn(move || { + let _guard = poisoned_for_thread + .lock() + .expect("poison setup lock must be available"); + panic!("intentional poison for fail-closed regression"); + }) + .join(); + + assert!( + !should_log_unknown_dc_with_set(poisoned.as_ref(), 4242), + "poisoned unknown-DC dedup lock must fail closed" + ); +} + +#[test] +fn stress_unknown_dc_log_concurrent_unique_churn_respects_cap() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let accepted = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + + // Adversarial model: many concurrent peers rotate dc_idx values rapidly. + for worker in 0..16usize { + let accepted = Arc::clone(&accepted); + workers.push(std::thread::spawn(move || { + let base = (worker * 2048) as i32; + for offset in 0..512i32 { + let raw = base + offset; + let dc = (raw % i16::MAX as i32) as i16; + if should_log_unknown_dc(dc) { + accepted.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must not panic"); + } + + assert_eq!( + accepted.load(Ordering::Relaxed), + UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "concurrent unique churn must never admit more than the configured distinct cap" + ); +} + +#[test] +fn light_fuzz_unknown_dc_log_mixed_duplicates_never_exceeds_cap() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + // Deterministic xorshift sequence for reproducible mixed duplicate fuzzing. + let mut s: u64 = 0xA5A5_5A5A_C3C3_3C3C; + let mut admitted = 0usize; + + for _ in 0..20_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let dc = (s as i16).wrapping_sub(i16::MAX / 2); + if should_log_unknown_dc(dc) { + admitted += 1; + } + } + + assert!( + admitted <= UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "mixed-duplicate fuzzed inputs must not admit more than cap" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_parent_traversal_inputs() { + assert!( + sanitize_unknown_dc_log_path("../unknown-dc.txt").is_none(), + "parent traversal paths must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path("logs/../unknown-dc.txt").is_none(), + "embedded parent traversal must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path("./../unknown-dc.txt").is_none(), + "relative parent traversal must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_absolute_paths_with_existing_parent() { + let absolute = std::env::temp_dir().join("unknown-dc.txt"); + let absolute_str = absolute + .to_str() + .expect("temp absolute path must be valid UTF-8"); + + let sanitized = sanitize_unknown_dc_log_path(absolute_str) + .expect("absolute paths with existing parent must be accepted"); + assert_eq!(sanitized, absolute); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_absolute_parent_traversal() { + assert!( + sanitize_unknown_dc_log_path("/tmp/../etc/passwd").is_none(), + "absolute parent traversal must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_safe_relative_path() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-{}", std::process::id())); + fs::create_dir_all(&base).expect("temp test directory must be creatable"); + + let candidate = base.join("unknown-dc.txt"); + let candidate_relative = format!("target/telemt-unknown-dc-log-{}/unknown-dc.txt", std::process::id()); + + let sanitized = sanitize_unknown_dc_log_path(&candidate_relative) + .expect("safe relative path with existing parent must be accepted"); + assert_eq!(sanitized, candidate); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_empty_or_dot_only_inputs() { + assert!( + sanitize_unknown_dc_log_path("").is_none(), + "empty path must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path(".").is_none(), + "dot-only path without filename must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_directory_only_as_filename_projection() { + let sanitized = sanitize_unknown_dc_log_path("target/") + .expect("directory-only input is interpreted as filename projection in current sanitizer"); + assert!( + sanitized.ends_with("target"), + "directory-only input should resolve to canonical parent plus filename projection" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_dot_prefixed_relative_path() { + let rel_dir = format!("target/telemt-unknown-dc-dot-{}", std::process::id()); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("dot-prefixed test directory must be creatable"); + + let rel_candidate = format!("./{rel_dir}/unknown-dc.log"); + let expected = abs_dir.join("unknown-dc.log"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("dot-prefixed safe path must be accepted"); + assert_eq!(sanitized, expected); +} + +#[test] +fn light_fuzz_unknown_dc_path_parentdir_inputs_always_rejected() { + let mut s: u64 = 0xD00D_BAAD_1234_5678; + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let a = (s as usize) % 32; + let b = ((s >> 8) as usize) % 32; + let candidate = format!("target/{a}/../{b}/unknown-dc.log"); + assert!( + sanitize_unknown_dc_log_path(&candidate).is_none(), + "parent-dir candidate must be rejected: {candidate}" + ); + } +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_nonexistent_parent_directory() { + let rel_candidate = format!( + "target/telemt-unknown-dc-missing-{}/nested/unknown-dc.txt", + std::process::id() + ); + + assert!( + sanitize_unknown_dc_log_path(&rel_candidate).is_none(), + "path with missing parent must be rejected to avoid implicit directory creation" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_sanitizer_accepts_symlinked_parent_inside_workspace() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-symlink-internal-{}", std::process::id())); + let real_parent = base.join("real_parent"); + fs::create_dir_all(&real_parent).expect("real parent dir must be creatable"); + + let symlink_parent = base.join("internal_link"); + let _ = fs::remove_file(&symlink_parent); + symlink(&real_parent, &symlink_parent).expect("internal symlink must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-log-symlink-internal-{}/internal_link/unknown-dc.txt", + std::process::id() + ); + + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("symlinked parent that resolves inside workspace must be accepted"); + assert!( + sanitized.starts_with(&real_parent), + "sanitized path must resolve to canonical internal parent" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_sanitizer_accepts_symlink_parent_escape_as_canonical_path() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-symlink-{}", std::process::id())); + fs::create_dir_all(&base).expect("symlink test directory must be creatable"); + + let symlink_parent = base.join("escape_link"); + let _ = fs::remove_file(&symlink_parent); + symlink("/tmp", &symlink_parent).expect("symlink parent must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-log-symlink-{}/escape_link/unknown-dc.txt", + std::process::id() + ); + + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("symlinked parent must canonicalize to target path"); + assert!( + sanitized.starts_with(Path::new("/tmp")), + "sanitized path must resolve to canonical symlink target" + ); +} + +#[tokio::test] +async fn unknown_dc_absolute_log_path_writes_one_entry() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_001; + let file_path = std::env::temp_dir().join(format!( + "telemt-unknown-dc-abs-{}-{}.log", + std::process::id(), + dc_idx + )); + let _ = fs::remove_file(&file_path); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some( + file_path + .to_str() + .expect("temp file path must be valid UTF-8") + .to_string(), + ); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + let mut content = None; + for _ in 0..20 { + if let Ok(text) = fs::read_to_string(&file_path) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(15)).await; + } + + let text = content.expect("absolute unknown-DC log path must produce exactly one log write"); + assert!( + text.contains(&format!("dc_idx={dc_idx}")), + "absolute unknown-DC integration log must contain requested dc_idx" + ); +} + +#[tokio::test] +async fn unknown_dc_safe_relative_log_path_writes_one_entry() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_002; + let rel_dir = format!("target/telemt-unknown-dc-int-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("integration test log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + let mut content = None; + for _ in 0..20 { + if let Ok(text) = fs::read_to_string(&abs_file) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(15)).await; + } + + let text = content.expect("safe relative path must produce exactly one log write"); + assert!( + text.contains(&format!("dc_idx={dc_idx}")), + "unknown-DC integration log must contain requested dc_idx" + ); +} + +#[tokio::test] +async fn unknown_dc_same_index_burst_writes_only_once() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_010; + let rel_dir = format!("target/telemt-unknown-dc-same-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir().unwrap().join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("same-index log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + for _ in 0..64 { + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + } + + let mut content = None; + for _ in 0..30 { + if let Ok(text) = fs::read_to_string(&abs_file) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let text = content.expect("same-index burst must produce at least one log write"); + assert_eq!( + nonempty_line_count(&text), + 1, + "same unknown dc index must be deduplicated to one file line" + ); +} + +#[tokio::test] +async fn unknown_dc_distinct_burst_is_hard_capped_on_file_writes() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-unknown-dc-cap-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir().unwrap().join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("cap log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + for i in 0..(UNKNOWN_DC_LOG_DISTINCT_LIMIT + 128) { + let dc_idx = 20_000i16.wrapping_add(i as i16); + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + } + + let mut final_text = String::new(); + for _ in 0..80 { + if let Ok(text) = fs::read_to_string(&abs_file) { + final_text = text; + if nonempty_line_count(&final_text) >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + break; + } + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let line_count = nonempty_line_count(&final_text); + assert!( + line_count > 0, + "distinct unknown-dc burst must write at least one line" + ); + assert!( + line_count <= UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "distinct unknown-dc writes must stay within dedup hard cap" + ); +} + #[test] fn fallback_dc_never_panics_with_single_dc_list() { let mut cfg = ProxyConfig::default(); @@ -276,6 +708,13 @@ async fn direct_relay_cutover_midflight_releases_route_gauge() { relay_result.is_err(), "cutover should terminate direct relay session" ); + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "client-visible cutover error must stay generic and avoid route-internal metadata" + ); assert_eq!( stats.get_current_connections_direct(), @@ -287,3 +726,143 @@ async fn direct_relay_cutover_midflight_releases_route_gauge() { tg_accept_task.abort(); let _ = tg_accept_task.await; } + +#[tokio::test] +async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { + let session_count = 6usize; + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let mut held_streams = Vec::with_capacity(session_count); + for _ in 0..session_count { + let (stream, _) = tg_listener.accept().await.unwrap(); + held_streams.push(stream); + } + tokio::time::sleep(Duration::from_secs(60)).await; + drop(held_streams); + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let mut relay_tasks = Vec::with_capacity(session_count); + let mut client_sides = Vec::with_capacity(session_count); + + for idx in 0..session_count { + let (server_side, client_side) = duplex(64 * 1024); + client_sides.push(client_side); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: format!("cutover-storm-direct-user-{idx}"), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + 51000 + idx as u16, + ), + is_tls: false, + }; + + relay_tasks.push(tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + rng.clone(), + route_runtime.subscribe(), + route_snapshot, + 0xA000_0000 + idx as u64, + ))); + } + + tokio::time::timeout(Duration::from_secs(4), async { + loop { + if stats.get_current_connections_direct() == session_count as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all direct sessions must become active before cutover storm"); + + let route_runtime_flipper = route_runtime.clone(); + let flipper = tokio::spawn(async move { + for step in 0..64u32 { + let mode = if (step & 1) == 0 { + RelayRouteMode::Middle + } else { + RelayRouteMode::Direct + }; + let _ = route_runtime_flipper.set_mode(mode); + tokio::time::sleep(Duration::from_millis(15)).await; + } + }); + + for relay_task in relay_tasks { + let relay_result = tokio::time::timeout(Duration::from_secs(10), relay_task) + .await + .expect("direct relay task must finish under cutover storm") + .expect("direct relay task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "storm-cutover termination must remain generic for all direct sessions" + ); + } + + flipper.abort(); + let _ = flipper.await; + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after cutover storm" + ); + + drop(client_sides); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index e25fe39..03b5012 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -235,23 +235,31 @@ fn auth_probe_record_failure_with_state( if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { let mut stale_keys = Vec::new(); - let mut eviction_candidates = Vec::new(); + let mut oldest_candidate: Option<(IpAddr, Instant)> = None; for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { - eviction_candidates.push(*entry.key()); + let key = *entry.key(); + let last_seen = entry.value().last_seen; + match oldest_candidate { + Some((_, oldest_seen)) if last_seen >= oldest_seen => {} + _ => oldest_candidate = Some((key, last_seen)), + } if auth_probe_state_expired(entry.value(), now) { - stale_keys.push(*entry.key()); + stale_keys.push(key); } } for stale_key in stale_keys { state.remove(&stale_key); } if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - if eviction_candidates.is_empty() { + let Some((evict_key, _)) = oldest_candidate else { auth_probe_note_saturation(now); return; - } + }; + state.remove(&evict_key); auth_probe_note_saturation(now); - return; + if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + return; + } } } @@ -300,6 +308,11 @@ fn auth_probe_saturation_is_throttled_for_testing() -> bool { auth_probe_saturation_is_throttled(Instant::now()) } +#[cfg(test)] +fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool { + auth_probe_saturation_is_throttled(now) +} + #[cfg(test)] fn auth_probe_test_lock() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); @@ -498,7 +511,8 @@ where return HandshakeResult::BadClient { reader, writer }; } - let secrets = decode_user_secrets(config, None); + let client_sni = tls::extract_sni_from_client_hello(handshake); + let secrets = decode_user_secrets(config, client_sni.as_deref()); let validation = match tls::validate_tls_handshake_with_replay_window( handshake, @@ -539,9 +553,9 @@ where let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { - let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { + let selected_domain = if let Some(sni) = client_sni.as_ref() { if cache.contains_domain(&sni).await { - sni + sni.clone() } else { config.censorship.tls_domain.clone() } diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index b14ab58..1823167 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -86,6 +86,77 @@ fn make_valid_tls_client_hello_with_alpn( record } +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + record +} + fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { let mut cfg = ProxyConfig::default(); cfg.access.users.clear(); @@ -332,6 +403,45 @@ async fn tls_replay_second_identical_handshake_is_rejected() { assert!(matches!(second, HandshakeResult::BadClient { .. })); } +#[tokio::test] +async fn tls_replay_with_ignore_time_skew_and_small_boot_timestamp_is_still_blocked() { + let secret = [0x19u8; 16]; + let config = test_config_with_secret_hex("19191919191919191919191919191919"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.121:44321".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 1); + + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(replay, HandshakeResult::BadClient { .. }), + "ignore_time_skew must not weaken replay rejection for small boot timestamps" + ); +} + #[tokio::test] async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { let secret = [0x77u8; 16]; @@ -377,6 +487,177 @@ async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() ); } +#[tokio::test] +async fn tls_replay_matrix_rotating_peers_first_accept_then_rejects() { + let secret = [0x52u8; 16]; + let config = test_config_with_secret_hex("52525252525252525252525252525252"); + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let handshake = make_valid_tls_handshake(&secret, 17); + + let first_peer: SocketAddr = "198.51.100.31:44001".parse().unwrap(); + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + first_peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + for i in 0..128u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 51, 100, ((i % 250) + 1) as u8)), + 45000 + i, + ); + let replay = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(replay, HandshakeResult::BadClient { .. }), + "replay digest must be rejected regardless of source peer rotation" + ); + } +} + +#[tokio::test] +async fn adversarial_tls_replay_churn_allows_only_unique_digests() { + let secret = [0x5Au8; 16]; + let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + config.access.ignore_time_skew = true; + let config = Arc::new(config); + let replay_checker = Arc::new(ReplayChecker::new(8192, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + + let make_tagged_handshake = |timestamp: u32, tag: u8| { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![tag; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(&secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake + }; + + let mut tasks = Vec::new(); + + // 128 exact duplicates: only one should pass. + let duplicated = Arc::new(make_valid_tls_handshake(&secret, 999)); + for i in 0..128u16 { + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let rng = Arc::clone(&rng); + let duplicated = Arc::clone(&duplicated); + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, ((i % 250) + 1) as u8)), + 46000 + i, + ); + handle_tls_handshake( + &duplicated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + // 128 unique timestamps: all should pass because HMAC digest differs. + for i in 0..128u16 { + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let rng = Arc::clone(&rng); + let handshake = make_tagged_handshake(10_000 + i as u32, (i as u8).wrapping_add(0x80)); + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 18, 0, ((i % 250) + 1) as u8)), + 47000 + i, + ); + handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let mut duplicate_success = 0usize; + let mut duplicate_reject = 0usize; + let mut unique_success = 0usize; + let mut unique_reject = 0usize; + + for (idx, task) in tasks.into_iter().enumerate() { + let result = task.await.unwrap(); + let is_duplicate_group = idx < 128; + match result { + HandshakeResult::Success(_) => { + if is_duplicate_group { + duplicate_success += 1; + } else { + unique_success += 1; + } + } + HandshakeResult::BadClient { .. } => { + if is_duplicate_group { + duplicate_reject += 1; + } else { + unique_reject += 1; + } + } + HandshakeResult::Error(e) => panic!("unexpected handshake error in churn test: {e}"), + } + } + + assert_eq!( + duplicate_success, 1, + "duplicate replay group must allow exactly one successful handshake" + ); + assert_eq!( + duplicate_reject, 127, + "duplicate replay group must reject all remaining replays" + ); + assert_eq!( + unique_success, 128, + "unique digest group must fully pass under replay churn" + ); + assert_eq!( + unique_reject, 0, + "unique digest group must not be falsely rejected as replay" + ); +} + #[tokio::test] async fn invalid_tls_probe_does_not_pollute_replay_cache() { let config = test_config_with_secret_hex("11111111111111111111111111111111"); @@ -530,6 +811,149 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() { assert!(matches!(result, HandshakeResult::Success(_))); } +#[tokio::test] +async fn tls_sni_preferred_user_hint_selects_matching_identity_first() { + let shared_secret = [0x3Bu8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.users.insert( + "user-a".to_string(), + "3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b".to_string(), + ); + config.access.users.insert( + "user-b".to_string(), + "3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b".to_string(), + ); + config.access.ignore_time_skew = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn( + &shared_secret, + 0, + "user-b", + &[b"h2"], + ); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + match result { + HandshakeResult::Success((_, _, user)) => { + assert_eq!( + user, "user-b", + "TLS SNI preferred-user hint must select matching identity before equivalent decoys" + ); + } + _ => panic!("TLS handshake must succeed for valid shared-secret SNI case"), + } +} + +#[test] +fn stress_decode_user_secrets_keeps_preferred_user_first_in_large_set() { + let mut config = ProxyConfig::default(); + config.access.users.clear(); + + let preferred_user = "target-user.example".to_string(); + let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); + + for i in 0..4096usize { + config.access.users.insert( + format!("decoy-{i:04}.example"), + secret_hex.clone(), + ); + } + config + .access + .users + .insert(preferred_user.clone(), secret_hex.clone()); + + let decoded = decode_user_secrets(&config, Some(preferred_user.as_str())); + assert_eq!( + decoded.len(), + config.access.users.len(), + "decoded secret set must preserve full user cardinality under stress" + ); + assert_eq!( + decoded.first().map(|(name, _)| name.as_str()), + Some(preferred_user.as_str()), + "preferred user must be first even under adversarial large user sets" + ); + assert_eq!( + decoded + .iter() + .filter(|(name, _)| name == &preferred_user) + .count(), + 1, + "preferred user must appear exactly once in decoded list" + ); +} + +#[tokio::test] +async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { + let shared_secret = [0x7Fu8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.ignore_time_skew = true; + + let preferred_user = "target-user.example".to_string(); + let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); + + for i in 0..4096usize { + config.access.users.insert( + format!("decoy-{i:04}.example"), + secret_hex.clone(), + ); + } + config + .access + .users + .insert(preferred_user.clone(), secret_hex); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.189:44326".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn( + &shared_secret, + 0, + preferred_user.as_str(), + &[b"h2"], + ); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + match result { + HandshakeResult::Success((_, _, user)) => { + assert_eq!( + user, + preferred_user, + "SNI preferred-user hint must remain stable under large user cardinality" + ); + } + _ => panic!("TLS handshake must succeed for valid preferred-user stress case"), + } +} + #[tokio::test] async fn alpn_enforce_rejects_unsupported_client_alpn() { let secret = [0x33u8; 16]; @@ -1053,7 +1477,12 @@ fn auth_probe_capacity_prunes_stale_entries_for_new_ips() { } #[test] -fn auth_probe_capacity_saturation_enables_global_throttle_when_map_is_fresh_and_full() { +fn auth_probe_capacity_fresh_full_map_still_tracks_newcomer_with_bounded_eviction() { + let _guard = auth_probe_test_lock() + .lock() + .expect("auth probe test lock must be available"); + clear_auth_probe_state_for_testing(); + let state = DashMap::new(); let now = Instant::now(); @@ -1069,29 +1498,92 @@ fn auth_probe_capacity_saturation_enables_global_throttle_when_map_is_fresh_and_ AuthProbeState { fail_streak: 1, blocked_until: now, - last_seen: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), }, ); } + let oldest = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 0)); + state.insert( + oldest, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 55)); auth_probe_record_failure_with_state(&state, newcomer, now); assert!( - state.get(&newcomer).is_none(), - "fresh-at-cap auth probe state must not churn by evicting tracked sources" + state.get(&newcomer).is_some(), + "fresh-at-cap auth probe map must still track a new source after bounded eviction" + ); + assert!( + state.get(&oldest).is_none(), + "capacity eviction must remove the oldest tracked source first" ); assert_eq!( state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES, - "auth probe map must stay exactly at the configured cap under saturation" + "auth probe map must stay at configured cap after bounded eviction" ); assert!( - auth_probe_saturation_is_throttled_for_testing(), - "capacity saturation must activate coarse global pre-auth throttling" + auth_probe_saturation_is_throttled_at_for_testing(now), + "capacity pressure should still activate coarse global pre-auth throttling" ); } +#[test] +fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() { + let _guard = auth_probe_test_lock() + .lock() + .expect("auth probe test lock must be available"); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 2, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 2048) as u64), + }, + ); + } + + for step in 0..1024usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 0, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + let now = base_now + Duration::from_millis(10_000 + step as u64); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!( + state.get(&newcomer).is_some(), + "new source must still be tracked under sustained at-capacity churn" + ); + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map size must stay hard-bounded at capacity" + ); + } +} + #[test] fn auth_probe_ipv6_is_bucketed_by_prefix_64() { let state = DashMap::new(); diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 636f637..b0f6985 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -35,7 +35,7 @@ where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; + let mut buf = [0u8; MASK_BUFFER_SIZE]; loop { let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; let n = match read_res { diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index affa4cd..bf23045 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,4 +1,5 @@ -use std::collections::hash_map::DefaultHasher; +use std::collections::hash_map::RandomState; +use std::hash::BuildHasher; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicU64, Ordering}; @@ -41,6 +42,7 @@ const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); +static DESYNC_HASHER: OnceLock = OnceLock::new(); struct RelayForensicsState { trace_id: u64, @@ -80,7 +82,8 @@ impl MeD2cFlushPolicy { } fn hash_value(value: &T) -> u64 { - let mut hasher = DefaultHasher::new(); + let state = DESYNC_HASHER.get_or_init(RandomState::new); + let mut hasher = state.build_hasher(); value.hash(&mut hasher); hasher.finish() } @@ -106,12 +109,17 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { let mut stale_keys = Vec::new(); - let mut eviction_candidate = None; + let mut oldest_candidate: Option<(u64, Instant)> = None; for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) { - if eviction_candidate.is_none() { - eviction_candidate = Some(*entry.key()); + let key = *entry.key(); + let seen_at = *entry.value(); + + match oldest_candidate { + Some((_, oldest_seen)) if seen_at >= oldest_seen => {} + _ => oldest_candidate = Some((key, seen_at)), } - if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW { + + if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW { stale_keys.push(*entry.key()); } } @@ -119,7 +127,7 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { dedup.remove(&stale_key); } if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { - let Some(evict_key) = eviction_candidate else { + let Some((evict_key, _)) = oldest_candidate else { return false; }; dedup.remove(&evict_key); diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index f88b5a0..441595e 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -8,7 +8,9 @@ use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::MePool; -use std::collections::HashMap; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::AtomicU64; @@ -220,6 +222,190 @@ fn desync_dedup_full_cache_churn_stays_suppressed() { } } +#[test] +fn dedup_hash_is_stable_for_same_input_within_process() { + let sample = ( + "scope_user", + hash_ip("198.51.100.7".parse().unwrap()), + ProtoTag::Secure, + ); + let first = hash_value(&sample); + let second = hash_value(&sample); + assert_eq!( + first, second, + "dedup hash must be stable within a process for cache lookups" + ); +} + +#[test] +fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() { + let mut seen = HashSet::new(); + + for octet in 1u16..=2048 { + let third = ((octet / 256) & 0xff) as u8; + let fourth = (octet & 0xff) as u8; + let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth)); + let key = hash_value(&( + "scope_user", + hash_ip(ip), + ProtoTag::Secure, + DESYNC_ERROR_CLASS, + )); + seen.insert(key); + } + + assert_eq!( + seen.len(), + 2048, + "adversarial peer-IP burst should not collapse dedup keys via trivial collisions" + ); +} + +#[test] +fn light_fuzz_dedup_hash_collision_rate_stays_negligible() { + let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4); + let mut seen = HashSet::new(); + let samples = 8192usize; + + for _ in 0..samples { + let user_seed: u64 = rng.random(); + let peer_seed: u64 = rng.random(); + let proto = if (peer_seed & 1) == 0 { + ProtoTag::Secure + } else { + ProtoTag::Intermediate + }; + let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS)); + seen.insert(key); + } + + let collisions = samples - seen.len(); + assert!( + collisions <= 1, + "light fuzz collision count should remain negligible for 64-bit dedup keys" + ); +} + +#[test] +fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + let total = DESYNC_DEDUP_MAX_ENTRIES + 8192; + + for key in 0..total as u64 { + let emitted = should_emit_full_desync(key, false, now); + if key < DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!(emitted, "keys below cap must be admitted initially"); + } else { + assert!( + !emitted, + "new keys above cap must stay suppressed under sustained churn" + ); + } + } + + let len = DESYNC_DEDUP + .get() + .expect("dedup cache must be initialized by stress run") + .len(); + assert!( + len <= DESYNC_DEDUP_MAX_ENTRIES, + "dedup cache must stay bounded under stress churn" + ); +} + +#[test] +fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + + // Fill with fresh entries so stale-pruning does not apply. + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let before_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); + + let newcomer_key = u64::MAX; + let emitted = should_emit_full_desync(newcomer_key, false, base_now); + assert!( + !emitted, + "new entry under full fresh cache must stay suppressed" + ); + assert!( + dedup.get(&newcomer_key).is_some(), + "new key must be inserted after bounded eviction" + ); + + let after_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); + let removed_count = before_keys.difference(&after_keys).count(); + let added_count = after_keys.difference(&before_keys).count(); + + assert_eq!( + removed_count, 1, + "full-cache insertion must evict exactly one prior key" + ); + assert_eq!( + added_count, 1, + "full-cache insertion must add exactly one newcomer key" + ); + assert!( + dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, + "dedup cache must remain hard-bounded after full-cache churn" + ); +} + +#[test] +fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let key = 0xC0DE_CAFE_u64; + let start = Instant::now(); + + assert!( + should_emit_full_desync(key, false, start), + "first event for key must emit full forensic record" + ); + + // Deterministic pseudo-random time deltas around dedup window edge. + let mut s: u64 = 0x1234_5678_9ABC_DEF0; + for _ in 0..2048 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1); + let now = start + TokioDuration::from_millis(delta_ms); + let emitted = should_emit_full_desync(key, false, now); + + if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 { + assert!( + !emitted, + "events inside dedup window must remain suppressed" + ); + } else { + // Once window elapsed for this key, at least one sample should re-emit and refresh. + if emitted { + return; + } + } + } + + panic!("expected at least one post-window sample to re-emit forensic record"); +} + fn make_forensics_state() -> RelayForensicsState { RelayForensicsState { trace_id: 1, @@ -1010,6 +1196,13 @@ async fn middle_relay_cutover_midflight_releases_route_gauge() { relay_result.is_err(), "cutover should terminate middle relay session" ); + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "client-visible cutover error must stay generic and avoid route-internal metadata" + ); assert_eq!( stats.get_current_connections_me(), @@ -1019,3 +1212,107 @@ async fn middle_relay_cutover_midflight_releases_route_gauge() { drop(client_side); } + +#[tokio::test] +async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { + let session_count = 6usize; + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let mut relay_tasks = Vec::with_capacity(session_count); + let mut client_sides = Vec::with_capacity(session_count); + + for idx in 0..session_count { + let (server_side, client_side) = duplex(64 * 1024); + client_sides.push(client_side); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: format!("cutover-storm-middle-user-{idx}"), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + 52000 + idx as u16, + ), + is_tls: false, + }; + + relay_tasks.push(tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + "127.0.0.1:443".parse().unwrap(), + rng.clone(), + route_runtime.subscribe(), + route_snapshot, + 0xB000_0000 + idx as u64, + ))); + } + + tokio::time::timeout(TokioDuration::from_secs(4), async { + loop { + if stats.get_current_connections_me() == session_count as u64 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("all middle sessions must become active before cutover storm"); + + let route_runtime_flipper = route_runtime.clone(); + let flipper = tokio::spawn(async move { + for step in 0..64u32 { + let mode = if (step & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = route_runtime_flipper.set_mode(mode); + tokio::time::sleep(TokioDuration::from_millis(15)).await; + } + }); + + for relay_task in relay_tasks { + let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task) + .await + .expect("middle relay task must finish under cutover storm") + .expect("middle relay task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "storm-cutover termination must remain generic for all middle sessions" + ); + } + + flipper.abort(); + let _ = flipper.await; + + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after cutover storm" + ); + + drop(client_sides); +} diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs index 306c536..2b109d1 100644 --- a/src/proxy/route_mode.rs +++ b/src/proxy/route_mode.rs @@ -4,7 +4,7 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::watch; -pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover"; +pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated"; #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[repr(u8)] @@ -140,3 +140,7 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio let ms = 1000 + (value % 1000); Duration::from_millis(ms) } + +#[cfg(test)] +#[path = "route_mode_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/route_mode_security_tests.rs b/src/proxy/route_mode_security_tests.rs new file mode 100644 index 0000000..36ab5c3 --- /dev/null +++ b/src/proxy/route_mode_security_tests.rs @@ -0,0 +1,106 @@ +use super::*; + +#[test] +fn cutover_stagger_delay_is_deterministic_for_same_inputs() { + let d1 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42); + let d2 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42); + assert_eq!( + d1, d2, + "stagger delay must be deterministic for identical session/generation inputs" + ); +} + +#[test] +fn cutover_stagger_delay_stays_within_budget_bounds() { + // Black-hat model: censors trigger many cutovers and correlate disconnect timing. + // Keep delay inside a narrow coarse window to avoid long-tail spikes. + for generation in [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX] { + for session_id in [ + 0u64, + 1, + 2, + 0xdead_beef, + 0xfeed_face_cafe_babe, + u64::MAX, + ] { + let delay = cutover_stagger_delay(session_id, generation); + assert!( + (1000..=1999).contains(&delay.as_millis()), + "stagger delay must remain in fixed 1000..=1999ms budget" + ); + } + } +} + +#[test] +fn cutover_stagger_delay_changes_with_generation_for_same_session() { + let session_id = 0x0123_4567_89ab_cdef; + let first = cutover_stagger_delay(session_id, 100); + let second = cutover_stagger_delay(session_id, 101); + assert_ne!( + first, second, + "adjacent cutover generations should decorrelate disconnect delays" + ); +} + +#[test] +fn route_runtime_set_mode_is_idempotent_for_same_mode() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let first = runtime.snapshot(); + let changed = runtime.set_mode(RelayRouteMode::Direct); + let second = runtime.snapshot(); + + assert!( + changed.is_none(), + "setting already-active mode must not produce a cutover event" + ); + assert_eq!( + first.generation, second.generation, + "idempotent mode set must not bump generation" + ); +} + +#[test] +fn affected_cutover_state_triggers_only_for_newer_generation() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + let initial = runtime.snapshot(); + + assert!( + affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation).is_none(), + "current generation must not be considered a cutover for existing session" + ); + + let next = runtime + .set_mode(RelayRouteMode::Middle) + .expect("mode change must produce cutover state"); + let seen = affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation) + .expect("newer generation must be observed as cutover"); + + assert_eq!(seen.generation, next.generation); + assert_eq!(seen.mode, RelayRouteMode::Middle); +} + +#[test] +fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() { + // Deterministic xorshift fuzzing keeps this test stable across runs. + let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + + for _ in 0..20_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let session_id = s; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let generation = s; + + let delay = cutover_stagger_delay(session_id, generation); + assert!( + (1000..=1999).contains(&delay.as_millis()), + "fuzzed inputs must always map into fixed stagger window" + ); + } +}