From 8821e38013e94fc0e8e1fae087a8fbf30976d60b Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 02:19:14 +0400 Subject: [PATCH] feat(proxy): enhance auth probe capacity with stale entry pruning and new tests --- src/proxy/handshake.rs | 25 ++++- src/proxy/handshake_security_tests.rs | 144 +++++++++++++++++++++++--- 2 files changed, 152 insertions(+), 17 deletions(-) diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index a26a722..ef98144 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -27,7 +27,11 @@ const ACCESS_SECRET_BYTES: usize = 16; static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60; +#[cfg(test)] +const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256; +#[cfg(not(test))] const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536; +const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024; const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4; #[cfg(test)] @@ -85,6 +89,14 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { let state = auth_probe_state_map(); + auth_probe_record_failure_with_state(state, peer_ip, now); +} + +fn auth_probe_record_failure_with_state( + state: &DashMap, + peer_ip: IpAddr, + now: Instant, +) { if let Some(mut entry) = state.get_mut(&peer_ip) { if auth_probe_state_expired(&entry, now) { *entry = AuthProbeState { @@ -101,7 +113,18 @@ fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { }; if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - return; + let mut stale_keys = Vec::new(); + for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(*entry.key()); + } + } + for stale_key in stale_keys { + state.remove(&stale_key); + } + if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + return; + } } state.insert(peer_ip, AuthProbeState { diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 5f62048..f2d7d03 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1,5 +1,7 @@ use super::*; use crate::crypto::sha256_hmac; +use dashmap::DashMap; +use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -145,7 +147,7 @@ fn test_handshake_success_drop_does_not_panic() { dec_iv: 0xBBBBBBBB, enc_key: [0xCC; 32], enc_iv: 0xDDDDDDDD, - peer: "127.0.0.1:1234".parse().unwrap(), + peer: "198.51.100.10:1234".parse().unwrap(), is_tls: true, }; @@ -261,7 +263,7 @@ async fn tls_replay_second_identical_handshake_is_rejected() { let config = test_config_with_secret_hex("11111111111111111111111111111111"); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44321".parse().unwrap(); + let peer: SocketAddr = "198.51.100.21:44321".parse().unwrap(); let handshake = make_valid_tls_handshake(&secret, 0); let first = handle_tls_handshake( @@ -310,7 +312,7 @@ async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() &handshake, tokio::io::empty(), tokio::io::sink(), - "127.0.0.1:45000".parse().unwrap(), + "198.51.100.22:45000".parse().unwrap(), &config, &replay_checker, &rng, @@ -341,7 +343,7 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() { let config = test_config_with_secret_hex("11111111111111111111111111111111"); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44322".parse().unwrap(); + let peer: SocketAddr = "198.51.100.23:44322".parse().unwrap(); let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; @@ -371,7 +373,7 @@ async fn empty_decoded_secret_is_rejected() { let config = test_config_with_secret_hex(""); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44323".parse().unwrap(); + let peer: SocketAddr = "198.51.100.24:44323".parse().unwrap(); let handshake = make_valid_tls_handshake(&[], 0); let result = handle_tls_handshake( @@ -395,7 +397,7 @@ async fn wrong_length_decoded_secret_is_rejected() { let config = test_config_with_secret_hex("aa"); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44324".parse().unwrap(); + let peer: SocketAddr = "198.51.100.25:44324".parse().unwrap(); let handshake = make_valid_tls_handshake(&[0xaau8], 0); let result = handle_tls_handshake( @@ -417,7 +419,7 @@ async fn wrong_length_decoded_secret_is_rejected() { async fn invalid_mtproto_probe_does_not_pollute_replay_cache() { let config = test_config_with_secret_hex("11111111111111111111111111111111"); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); - let peer: SocketAddr = "127.0.0.1:44325".parse().unwrap(); + let peer: SocketAddr = "198.51.100.26:44325".parse().unwrap(); let handshake = [0u8; HANDSHAKE_LEN]; let before = replay_checker.stats(); @@ -458,7 +460,7 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() { let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44326".parse().unwrap(); + let peer: SocketAddr = "198.51.100.27:44326".parse().unwrap(); let handshake = make_valid_tls_handshake(&good_secret, 0); let result = handle_tls_handshake( @@ -484,7 +486,7 @@ async fn alpn_enforce_rejects_unsupported_client_alpn() { let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44327".parse().unwrap(); + let peer: SocketAddr = "198.51.100.28:44327".parse().unwrap(); let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); let result = handle_tls_handshake( @@ -510,7 +512,7 @@ async fn alpn_enforce_accepts_h2() { let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44328".parse().unwrap(); + let peer: SocketAddr = "198.51.100.29:44328".parse().unwrap(); let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2", b"h3"]); let result = handle_tls_handshake( @@ -536,7 +538,7 @@ async fn malformed_tls_classes_complete_within_bounded_time() { let replay_checker = ReplayChecker::new(512, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44329".parse().unwrap(); + let peer: SocketAddr = "198.51.100.30:44329".parse().unwrap(); let too_short = vec![0x16, 0x03, 0x01]; @@ -578,7 +580,7 @@ async fn malformed_tls_classes_share_close_latency_buckets() { let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44330".parse().unwrap(); + let peer: SocketAddr = "198.51.100.31:44330".parse().unwrap(); let too_short = vec![0x16, 0x03, 0x01]; @@ -667,6 +669,43 @@ fn secure_tag_requires_secure_mode_on_direct_transport() { ); } +#[test] +fn mode_policy_matrix_is_stable_for_all_tag_transport_mode_combinations() { + let tags = [ProtoTag::Secure, ProtoTag::Intermediate, ProtoTag::Abridged]; + + for classic in [false, true] { + for secure in [false, true] { + for tls in [false, true] { + let mut config = ProxyConfig::default(); + config.general.modes.classic = classic; + config.general.modes.secure = secure; + config.general.modes.tls = tls; + + for is_tls in [false, true] { + for tag in tags { + let expected = match (tag, is_tls) { + (ProtoTag::Secure, true) => tls, + (ProtoTag::Secure, false) => secure, + (ProtoTag::Intermediate | ProtoTag::Abridged, _) => classic, + }; + + assert_eq!( + mode_enabled_for_proto(&config, tag, is_tls), + expected, + "mode policy drifted for tag={:?}, transport_tls={}, modes=(classic={}, secure={}, tls={})", + tag, + is_tls, + classic, + secure, + tls + ); + } + } + } + } + } +} + #[test] fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() { clear_warned_secrets_for_testing(); @@ -689,13 +728,13 @@ fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() { async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() { let _guard = auth_probe_test_lock() .lock() - .expect("auth probe test lock must be available"); + .unwrap_or_else(|poisoned| poisoned.into_inner()); clear_auth_probe_state_for_testing(); let config = test_config_with_secret_hex("11111111111111111111111111111111"); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44361".parse().unwrap(); + let peer: SocketAddr = "198.51.100.61:44361".parse().unwrap(); let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; @@ -725,14 +764,14 @@ async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() { async fn successful_tls_handshake_clears_pre_auth_failure_streak() { let _guard = auth_probe_test_lock() .lock() - .expect("auth probe test lock must be available"); + .unwrap_or_else(|poisoned| poisoned.into_inner()); clear_auth_probe_state_for_testing(); let secret = [0x23u8; 16]; let config = test_config_with_secret_hex("23232323232323232323232323232323"); let replay_checker = ReplayChecker::new(256, Duration::from_secs(60)); let rng = SecureRandom::new(); - let peer: SocketAddr = "127.0.0.1:44362".parse().unwrap(); + let peer: SocketAddr = "198.51.100.62:44362".parse().unwrap(); let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; @@ -777,3 +816,76 @@ async fn successful_tls_handshake_clears_pre_auth_failure_streak() { "successful authentication must clear accumulated pre-auth failures" ); } + +#[test] +fn auth_probe_capacity_prunes_stale_entries_for_new_ips() { + let state = DashMap::new(); + let now = Instant::now(); + let stale_seen = now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 1, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: stale_seen, + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert_eq!( + state.get(&newcomer).map(|entry| entry.fail_streak), + Some(1), + "stale-entry pruning must admit and track a new probe source" + ); + assert!( + state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map must remain bounded after stale pruning" + ); +} + +#[test] +fn auth_probe_capacity_stays_fail_closed_when_map_is_fresh_and_full() { + let state = DashMap::new(); + let now = Instant::now(); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 16, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now, + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 55)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!( + state.get(&newcomer).is_none(), + "when all entries are fresh and full, new probes must not be admitted" + ); + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map must stay at the configured cap" + ); +}