diff --git a/src/cli.rs b/src/cli.rs index a1182a7..8ea9c9f 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -239,7 +239,7 @@ tls_full_cert_ttl_secs = 90 [access] replay_check_len = 65536 -replay_window_secs = 1800 +replay_window_secs = 120 ignore_time_skew = false [access.users] diff --git a/src/config/defaults.rs b/src/config/defaults.rs index ea9250d..a136539 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -73,7 +73,7 @@ pub(crate) fn default_replay_check_len() -> usize { } pub(crate) fn default_replay_window_secs() -> u64 { - 1800 + 120 } pub(crate) fn default_handshake_timeout() -> u64 { @@ -456,11 +456,11 @@ pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 { } pub(crate) fn default_server_hello_delay_min_ms() -> u64 { - 0 + 8 } pub(crate) fn default_server_hello_delay_max_ms() -> u64 { - 0 + 24 } pub(crate) fn default_alpn_enforce() -> bool { diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 0f54245..3a22214 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -27,8 +27,8 @@ pub const TLS_DIGEST_POS: usize = 11; pub const TLS_DIGEST_HALF_LEN: usize = 16; /// Time skew limits for anti-replay (in seconds) -pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before -pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after +pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before +pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after /// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index 98d7319..bfc8f0d 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -1394,3 +1394,111 @@ fn server_hello_application_data_payload_varies_across_runs() { "ApplicationData payload should vary across runs to reduce fingerprintability" ); } + +#[test] +fn replay_window_zero_disables_boot_bypass_for_any_nonzero_timestamp() { + let secret = b"window_zero_boot_bypass_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts1 = make_valid_tls_handshake(secret, 1); + assert!( + validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 0).is_none(), + "replay_window_secs=0 must reject nonzero timestamps even in boot-time range" + ); + + let ts0 = make_valid_tls_handshake(secret, 0); + assert!( + validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 0).is_none(), + "replay_window_secs=0 enforces strict skew check and rejects timestamp=0 on normal wall-clock systems" + ); +} + +#[test] +fn large_replay_window_does_not_expand_time_skew_acceptance() { + let secret = b"large_replay_window_skew_bound_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + let ts_far_past = (now - 600) as u32; + let valid = make_valid_tls_handshake(secret, ts_far_past); + assert!( + validate_tls_handshake_with_replay_window(&valid, &secrets, false, 86_400).is_none(), + "large replay window must not relax strict skew check once boot-time bypass is not in play" + ); +} + +#[test] +fn parse_tls_record_header_accepts_tls_version_constant() { + let header = [TLS_RECORD_HANDSHAKE, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x2A]; + let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted"); + assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE); + assert_eq!(parsed.1, 42); +} + +#[test] +fn server_hello_clamps_fake_cert_len_lower_bound() { + let secret = b"fake_cert_lower_bound_test"; + let client_digest = [0x11u8; TLS_DIGEST_LEN]; + let session_id = vec![0x77; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello(secret, &client_digest, &session_id, 1, &rng, None, 0); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + + assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); + assert_eq!(app_len, 64, "fake cert payload must be clamped to minimum 64 bytes"); +} + +#[test] +fn server_hello_clamps_fake_cert_len_upper_bound() { + let secret = b"fake_cert_upper_bound_test"; + let client_digest = [0x22u8; TLS_DIGEST_LEN]; + let session_id = vec![0x66; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello(secret, &client_digest, &session_id, 65_535, &rng, None, 0); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + + assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); + assert_eq!(app_len, 16_640, "fake cert payload must be clamped to TLS record max bound"); +} + +#[test] +fn server_hello_new_session_ticket_count_matches_configuration() { + let secret = b"ticket_count_surface_test"; + let client_digest = [0x33u8; TLS_DIGEST_LEN]; + let session_id = vec![0x55; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let tickets: u8 = 3; + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, tickets); + + let mut pos = 0usize; + let mut app_records = 0usize; + while pos + 5 <= response.len() { + let rtype = response[pos]; + let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; + let next = pos + 5 + rlen; + assert!(next <= response.len(), "TLS record must stay inside response bounds"); + if rtype == TLS_RECORD_APPLICATION { + app_records += 1; + } + pos = next; + } + + assert_eq!( + app_records, + 1 + tickets as usize, + "response must contain one main application record plus configured ticket-like tail records" + ); +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index a1b3eb7..e25fe39 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -36,6 +36,7 @@ const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256; const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536; const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024; const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4; +const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2; #[cfg(test)] const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1; @@ -54,12 +55,24 @@ struct AuthProbeState { last_seen: Instant, } +#[derive(Clone, Copy)] +struct AuthProbeSaturationState { + fail_streak: u32, + blocked_until: Instant, + last_seen: Instant, +} + static AUTH_PROBE_STATE: OnceLock> = OnceLock::new(); +static AUTH_PROBE_SATURATION_STATE: OnceLock>> = OnceLock::new(); fn auth_probe_state_map() -> &'static DashMap { AUTH_PROBE_STATE.get_or_init(DashMap::new) } +fn auth_probe_saturation_state() -> &'static Mutex> { + AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None)) +} + fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr { match peer_ip { IpAddr::V4(ip) => IpAddr::V4(ip), @@ -108,6 +121,83 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { now < entry.blocked_until } +fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + let Some(entry) = state.get(&peer_ip) else { + return false; + }; + if auth_probe_state_expired(&entry, now) { + drop(entry); + state.remove(&peer_ip); + return false; + } + + entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS +} + +fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool { + if !auth_probe_is_throttled(peer_ip, now) { + return false; + } + + if !auth_probe_saturation_is_throttled(now) { + return true; + } + + auth_probe_saturation_grace_exhausted(peer_ip, now) +} + +fn auth_probe_saturation_is_throttled(now: Instant) -> bool { + let saturation = auth_probe_saturation_state(); + let mut guard = match saturation.lock() { + Ok(guard) => guard, + Err(_) => return false, + }; + + let Some(state) = guard.as_mut() else { + return false; + }; + + if now.duration_since(state.last_seen) > Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) { + *guard = None; + return false; + } + + if now < state.blocked_until { + return true; + } + + false +} + +fn auth_probe_note_saturation(now: Instant) { + let saturation = auth_probe_saturation_state(); + let mut guard = match saturation.lock() { + Ok(guard) => guard, + Err(_) => return, + }; + + match guard.as_mut() { + Some(state) + if now.duration_since(state.last_seen) + <= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) => + { + state.fail_streak = state.fail_streak.saturating_add(1); + state.last_seen = now; + state.blocked_until = now + auth_probe_backoff(state.fail_streak); + } + _ => { + let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS; + *guard = Some(AuthProbeSaturationState { + fail_streak, + blocked_until: now + auth_probe_backoff(fail_streak), + last_seen: now, + }); + } + } +} + fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); @@ -157,11 +247,11 @@ fn auth_probe_record_failure_with_state( } if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { if eviction_candidates.is_empty() { + auth_probe_note_saturation(now); return; } - let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len(); - let evict_key = eviction_candidates[idx]; - state.remove(&evict_key); + auth_probe_note_saturation(now); + return; } } @@ -186,6 +276,11 @@ fn clear_auth_probe_state_for_testing() { if let Some(state) = AUTH_PROBE_STATE.get() { state.clear(); } + if let Some(saturation) = AUTH_PROBE_SATURATION_STATE.get() + && let Ok(mut guard) = saturation.lock() + { + *guard = None; + } } #[cfg(test)] @@ -200,6 +295,11 @@ fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool { auth_probe_is_throttled(peer_ip, Instant::now()) } +#[cfg(test)] +fn auth_probe_saturation_is_throttled_for_testing() -> bool { + auth_probe_saturation_is_throttled(Instant::now()) +} + #[cfg(test)] fn auth_probe_test_lock() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); @@ -385,7 +485,8 @@ where { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); - if auth_probe_is_throttled(peer.ip(), Instant::now()) { + let throttle_now = Instant::now(); + if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; @@ -554,7 +655,8 @@ where { trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); - if auth_probe_is_throttled(peer.ip(), Instant::now()) { + let throttle_now = Instant::now(); + if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 7040025..b14ab58 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1,6 +1,8 @@ use super::*; -use crate::crypto::sha256_hmac; +use crate::crypto::{sha256, sha256_hmac}; use dashmap::DashMap; +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -94,6 +96,43 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { cfg } +fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper"); + + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + #[test] fn test_generate_tg_nonce() { let client_enc_key = [0x24u8; 32]; @@ -349,6 +388,7 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() { invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; let before = replay_checker.stats(); + let result = handle_tls_handshake( &invalid, tokio::io::empty(), @@ -1013,7 +1053,7 @@ fn auth_probe_capacity_prunes_stale_entries_for_new_ips() { } #[test] -fn auth_probe_capacity_forces_bounded_eviction_when_map_is_fresh_and_full() { +fn auth_probe_capacity_saturation_enables_global_throttle_when_map_is_fresh_and_full() { let state = DashMap::new(); let now = Instant::now(); @@ -1038,13 +1078,17 @@ fn auth_probe_capacity_forces_bounded_eviction_when_map_is_fresh_and_full() { auth_probe_record_failure_with_state(&state, newcomer, now); assert!( - state.get(&newcomer).is_some(), - "when all entries are fresh and full, one bounded eviction must admit a new probe source" + state.get(&newcomer).is_none(), + "fresh-at-cap auth probe state must not churn by evicting tracked sources" ); assert_eq!( state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES, - "auth probe map must stay at the configured cap after forced eviction" + "auth probe map must stay exactly at the configured cap under saturation" + ); + assert!( + auth_probe_saturation_is_throttled_for_testing(), + "capacity saturation must activate coarse global pre-auth throttling" ); } @@ -1250,3 +1294,1118 @@ async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() "successful victim handshake must not retain pre-auth failure streak" ); } + +#[test] +fn auth_probe_saturation_state_expires_after_retention_window() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(30), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + assert!( + !auth_probe_saturation_is_throttled_for_testing(), + "expired saturation state must stop throttling and self-clear" + ); + + let guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + assert!(guard.is_none(), "expired saturation state must be removed"); +} + +#[tokio::test] +async fn global_saturation_marker_does_not_block_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x41u8; 16]; + let config = test_config_with_secret_hex("41414141414141414141414141414141"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.101:45101".parse().unwrap(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "global saturation marker must not block valid authenticated TLS handshakes" + ); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful handshake under saturation marker must not retain per-ip probe failures" + ); +} + +#[tokio::test] +async fn expired_global_saturation_allows_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x55u8; 16]; + let config = test_config_with_secret_hex("55555555555555555555555555555555"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.102:45102".parse().unwrap(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "expired saturation marker must not block valid handshake" + ); +} + +#[tokio::test] +async fn valid_tls_is_blocked_by_per_ip_preauth_throttle_without_saturation() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x61u8; 16]; + let config = test_config_with_secret_hex("61616161616161616161616161616161"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.103:45103".parse().unwrap(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: Instant::now() + Duration::from_secs(5), + last_seen: Instant::now(), + }, + ); + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn saturation_allows_valid_tls_even_when_peer_ip_is_currently_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x62u8; 16]; + let config = test_config_with_secret_hex("62626262626262626262626262626262"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.104:45104".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful auth under saturation must clear the peer's throttled state" + ); +} + +#[tokio::test] +async fn saturation_still_rejects_invalid_tls_probe_and_records_failure() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("63636363636363636363636363636363"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.105:45105".parse().unwrap(); + let now = Instant::now(); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(1), + "invalid TLS during saturation must still increment per-ip failure tracking" + ); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_preauth_throttles_repeated_invalid_tls_probe() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("63636363636363636363636363636363"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.205:45205".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "pre-auth throttle under exhausted saturation grace must reject without re-processing invalid TLS" + ); +} + +#[tokio::test] +async fn saturation_allows_valid_mtproto_even_when_peer_ip_is_currently_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret_hex = "64646464646464646464646464646464"; + let mut config = test_config_with_secret_hex(secret_hex); + config.general.modes.secure = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.106:45106".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let result = handle_mtproto_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful mtproto auth under saturation must clear the peer's throttled state" + ); +} + +#[tokio::test] +async fn saturation_still_rejects_invalid_mtproto_probe_and_records_failure() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("65656565656565656565656565656565"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.107:45107".parse().unwrap(); + let now = Instant::now(); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(1), + "invalid mtproto during saturation must still increment per-ip failure tracking" + ); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_preauth_throttles_repeated_invalid_mtproto_probe() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("65656565656565656565656565656565"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.206:45206".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "pre-auth throttle under exhausted saturation grace must reject without re-processing invalid MTProto" + ); +} + +#[tokio::test] +async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("70707070707070707070707070707070"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.207:45207".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for expected in [ + AUTH_PROBE_BACKOFF_START_FAILS + 1, + AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + ] { + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + } + + { + let mut entry = auth_probe_state_map() + .get_mut(&normalize_auth_probe_ip(peer.ip())) + .expect("peer state must exist before exhaustion recheck"); + entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; + entry.blocked_until = Instant::now() + Duration::from_secs(1); + entry.last_seen = Instant::now(); + } + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "once grace is exhausted, repeated invalid TLS must be pre-auth throttled without further fail-streak growth" + ); +} + +#[tokio::test] +async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementing() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("71717171717171717171717171717171"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.208:45208".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + + for expected in [ + AUTH_PROBE_BACKOFF_START_FAILS + 1, + AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + ] { + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + } + + { + let mut entry = auth_probe_state_map() + .get_mut(&normalize_auth_probe_ip(peer.ip())) + .expect("peer state must exist before exhaustion recheck"); + entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; + entry.blocked_until = Instant::now() + Duration::from_secs(1); + entry.last_seen = Instant::now(); + } + + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "once grace is exhausted, repeated invalid MTProto must be pre-auth throttled without further fail-streak growth" + ); +} + +#[tokio::test] +async fn saturation_grace_boundary_still_admits_valid_tls_before_exhaustion() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x72u8; 16]; + let config = test_config_with_secret_hex("72727272727272727272727272727272"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.209:45209".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "valid TLS should still pass while peer remains within saturation grace budget" + ); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), None); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_blocks_valid_tls_until_backoff_expires() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x73u8; 16]; + let config = test_config_with_secret_hex("73737373737373737373737373737373"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.210:45210".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_millis(200), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let blocked = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(blocked, HandshakeResult::BadClient { .. })); + + tokio::time::sleep(Duration::from_millis(230)).await; + + let allowed = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(allowed, HandshakeResult::Success(_)), + "valid TLS should recover after peer-specific pre-auth backoff has elapsed" + ); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), None); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_is_shared_across_tls_and_mtproto_for_same_peer() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("74747474747474747474747474747474"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.211:45211".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_mtproto = [0u8; HANDSHAKE_LEN]; + + let tls_result = handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(tls_result, HandshakeResult::BadClient { .. })); + + let mtproto_result = handle_mtproto_handshake( + &invalid_mtproto, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(mtproto_result, HandshakeResult::BadClient { .. })); + + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "saturation grace exhaustion must gate both TLS and MTProto pre-auth paths for one peer" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_same_peer_invalid_tls_storm_does_not_bypass_saturation_grace_cap() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = Arc::new(test_config_with_secret_hex("75757575757575757575757575757575")); + let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_tls = Arc::new(invalid_tls); + + let mut tasks = Vec::new(); + for _ in 0..64usize { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid_tls = invalid_tls.clone(); + tasks.push(tokio::spawn(async move { + handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + for task in tasks { + let result = task.await.unwrap(); + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "same-peer invalid storm under exhausted grace must stay pre-auth throttled without fail-streak growth" + ); +} + +#[tokio::test] +async fn light_fuzz_saturation_grace_tls_invalid_inputs_never_authenticate_or_panic() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("76767676767676767676767676767676"); + let replay_checker = ReplayChecker::new(2048, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.213:45213".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut seeded = StdRng::seed_from_u64(0xD15EA5E5_u64); + for _ in 0..128usize { + let len = seeded.random_range(0usize..96usize); + let mut probe = vec![0u8; len]; + seeded.fill(&mut probe[..]); + + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + let streak = auth_probe_fail_streak_for_testing(peer.ip()) + .expect("peer should remain tracked after repeated invalid fuzz probes"); + assert!( + streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + "fuzzed invalid TLS probes under saturation must not reduce fail-streak below exhaustion threshold" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshakes() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret_hex = "66666666666666666666666666666666"; + let secret = [0x66u8; 16]; + let mut cfg = test_config_with_secret_hex(secret_hex); + cfg.general.modes.secure = true; + let config = Arc::new(cfg); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let now = Instant::now(); + + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0)); + let valid_mtproto = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 3)); + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_tls = Arc::new(invalid_tls); + + let mut invalid_tls_tasks = Vec::new(); + for idx in 0..48u16 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid_tls = invalid_tls.clone(); + invalid_tls_tasks.push(tokio::spawn(async move { + let octet = ((idx % 200) + 1) as u8; + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 46000 + idx); + handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let valid_tls_task = { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let valid_tls = valid_tls.clone(); + tokio::spawn(async move { + handle_tls_handshake( + &valid_tls, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.108:45108".parse().unwrap(), + &config, + &replay_checker, + &rng, + None, + ) + .await + }) + }; + + let valid_mtproto_task = { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let valid_mtproto = valid_mtproto.clone(); + tokio::spawn(async move { + handle_mtproto_handshake( + &valid_mtproto, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.109:45109".parse().unwrap(), + &config, + &replay_checker, + false, + None, + ) + .await + }) + }; + + let mut bad_clients = 0usize; + for task in invalid_tls_tasks { + match task.await.unwrap() { + HandshakeResult::BadClient { .. } => bad_clients += 1, + HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"), + HandshakeResult::Error(err) => panic!("unexpected error in invalid TLS saturation burst test: {err}"), + } + } + + let valid_tls_result = valid_tls_task.await.unwrap(); + assert!( + matches!(valid_tls_result, HandshakeResult::Success(_)), + "valid TLS probe must authenticate during saturation burst" + ); + + let valid_mtproto_result = valid_mtproto_task.await.unwrap(); + assert!( + matches!(valid_mtproto_result, HandshakeResult::Success(_)), + "valid MTProto probe must authenticate during saturation burst" + ); + + assert_eq!( + bad_clients, + 48, + "all invalid TLS probes in mixed saturation burst must be rejected" + ); +} + +#[tokio::test] +async fn expired_saturation_keeps_per_ip_throttle_enforced_for_valid_tls() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x67u8; 16]; + let config = test_config_with_secret_hex("67676767676767676767676767676767"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.110:45110".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::BadClient { .. }), + "expired saturation marker must not disable per-ip pre-auth throttle" + ); +} diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index eb6f6da..636f637 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -24,8 +24,36 @@ const MASK_TIMEOUT: Duration = Duration::from_millis(50); const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60); #[cfg(test)] const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200); +#[cfg(not(test))] +const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +async fn copy_with_idle_timeout(reader: &mut R, writer: &mut W) +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut buf = vec![0u8; MASK_BUFFER_SIZE]; + loop { + let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; + let n = match read_res { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { + break; + } + + let write_res = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&buf[..n])).await; + match write_res { + Ok(Ok(())) => {} + Ok(Err(_)) | Err(_) => break, + } + } +} + async fn write_proxy_header_with_timeout(mask_write: &mut W, header: &[u8]) -> bool where W: AsyncWrite + Unpin, @@ -264,11 +292,11 @@ where let _ = tokio::join!( async { - let _ = tokio::io::copy(&mut reader, &mut mask_write).await; + copy_with_idle_timeout(&mut reader, &mut mask_write).await; let _ = mask_write.shutdown().await; }, async { - let _ = tokio::io::copy(&mut mask_read, &mut writer).await; + copy_with_idle_timeout(&mut mask_read, &mut writer).await; let _ = writer.shutdown().await; } ); diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 2310846..1cee108 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -234,8 +234,9 @@ async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() { let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; - // Keep reader open so fallback path does not terminate immediately on EOF. - let (_client_reader_side, client_reader) = duplex(256); + // Close client reader immediately to force the refusal path to rely on masking budget timing. + let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); let (_client_visible_reader, client_visible_writer) = duplex(256); let beobachten = BeobachtenStore::new(); @@ -890,6 +891,59 @@ async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() { timeout(Duration::from_secs(1), task).await.unwrap().unwrap(); } +#[tokio::test] +async fn mask_enabled_idle_relay_is_closed_by_idle_timeout_before_global_relay_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /idle HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + sleep(Duration::from_millis(300)).await; + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.34:45456".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(512); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed < Duration::from_millis(150), + "idle unauth relay must terminate on idle timeout instead of waiting for full relay timeout" + ); + + accept_task.await.unwrap(); +} + struct PendingWriter; impl tokio::io::AsyncWrite for PendingWriter { @@ -1250,3 +1304,166 @@ async fn timing_matrix_masking_classes_under_controlled_inputs() { (reachable_mean as u128) / BUCKET_MS ); } + +#[tokio::test] +async fn backend_connect_refusal_completes_within_bounded_mask_budget() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.41:51001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /bounded HTTP/1.1\r\nHost: x\r\n\r\n"; + + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(45), + "connect refusal path must respect minimum masking budget" + ); + assert!( + elapsed < Duration::from_millis(500), + "connect refusal path must stay bounded and avoid unbounded stall" + ); +} + +#[tokio::test] +async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /oneshot HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let response = response.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&response).await.unwrap(); + sleep(Duration::from_millis(300)).await; + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.42:51002".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + let elapsed = started.elapsed(); + + let mut observed = vec![0u8; response.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, response); + assert!( + elapsed < Duration::from_millis(190), + "idle backend silence after first response must be cut by relay idle timeout" + ); + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /drip HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut extra = [0u8; 1]; + let read_res = timeout(Duration::from_millis(220), stream.read_exact(&mut extra)).await; + assert!( + read_res.is_err() || read_res.unwrap().is_err(), + "drip-fed post-probe byte arriving after idle timeout should not be forwarded" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.43:51003".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (mut client_writer_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let relay_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + sleep(Duration::from_millis(160)).await; + let _ = client_writer_side.write_all(b"X").await; + drop(client_writer_side); + + timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap(); + accept_task.await.unwrap(); +}