diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index fa42c55..5582e9b 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -540,6 +540,10 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.access.user_max_unique_ips_mode = new.access.user_max_unique_ips_mode; cfg.access.user_max_unique_ips_window_secs = new.access.user_max_unique_ips_window_secs; + if cfg.rebuild_runtime_user_auth().is_err() { + cfg.runtime_user_auth = None; + } + cfg } diff --git a/src/config/load.rs b/src/config/load.rs index 2e27edb..58b5143 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -4,6 +4,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; +use std::sync::Arc; use rand::RngExt; use serde::{Deserialize, Serialize}; @@ -15,6 +16,8 @@ use crate::error::{ProxyError, Result}; use super::defaults::*; use super::types::*; +const ACCESS_SECRET_BYTES: usize = 16; + #[derive(Debug, Clone)] pub(crate) struct LoadedConfig { pub(crate) config: ProxyConfig, @@ -22,6 +25,104 @@ pub(crate) struct LoadedConfig { pub(crate) rendered_hash: u64, } +/// Precomputed, immutable user authentication data used by handshake hot paths. +#[derive(Debug, Clone, Default)] +pub(crate) struct UserAuthSnapshot { + entries: Vec, + by_name: HashMap, + sni_index: HashMap>, + sni_initial_index: HashMap>, +} + +#[derive(Debug, Clone)] +pub(crate) struct UserAuthEntry { + pub(crate) user: String, + pub(crate) secret: [u8; ACCESS_SECRET_BYTES], +} + +impl UserAuthSnapshot { + fn from_users(users: &HashMap) -> Result { + let mut entries = Vec::with_capacity(users.len()); + let mut by_name = HashMap::with_capacity(users.len()); + let mut sni_index = HashMap::with_capacity(users.len()); + let mut sni_initial_index = HashMap::with_capacity(users.len()); + + for (user, secret_hex) in users { + let decoded = hex::decode(secret_hex).map_err(|_| ProxyError::InvalidSecret { + user: user.clone(), + reason: "Must be 32 hex characters".to_string(), + })?; + if decoded.len() != ACCESS_SECRET_BYTES { + return Err(ProxyError::InvalidSecret { + user: user.clone(), + reason: "Must be 32 hex characters".to_string(), + }); + } + + let user_id = u32::try_from(entries.len()).map_err(|_| { + ProxyError::Config("Too many users for runtime auth snapshot".to_string()) + })?; + + let mut secret = [0u8; ACCESS_SECRET_BYTES]; + secret.copy_from_slice(&decoded); + entries.push(UserAuthEntry { + user: user.clone(), + secret, + }); + by_name.insert(user.clone(), user_id); + sni_index + .entry(Self::sni_lookup_hash(user)) + .or_insert_with(Vec::new) + .push(user_id); + if let Some(initial) = user.as_bytes().first().map(|byte| byte.to_ascii_lowercase()) { + sni_initial_index + .entry(initial) + .or_insert_with(Vec::new) + .push(user_id); + } + } + + Ok(Self { + entries, + by_name, + sni_index, + sni_initial_index, + }) + } + + pub(crate) fn entries(&self) -> &[UserAuthEntry] { + &self.entries + } + + pub(crate) fn user_id_by_name(&self, user: &str) -> Option { + self.by_name.get(user).copied() + } + + pub(crate) fn entry_by_id(&self, user_id: u32) -> Option<&UserAuthEntry> { + let idx = usize::try_from(user_id).ok()?; + self.entries.get(idx) + } + + pub(crate) fn sni_candidates(&self, sni: &str) -> Option<&[u32]> { + self.sni_index + .get(&Self::sni_lookup_hash(sni)) + .map(Vec::as_slice) + } + + pub(crate) fn sni_initial_candidates(&self, sni: &str) -> Option<&[u32]> { + let initial = sni.as_bytes().first().map(|byte| byte.to_ascii_lowercase())?; + self.sni_initial_index.get(&initial).map(Vec::as_slice) + } + + fn sni_lookup_hash(value: &str) -> u64 { + let mut hasher = DefaultHasher::new(); + for byte in value.bytes() { + hasher.write_u8(byte.to_ascii_lowercase()); + } + hasher.finish() + } +} + fn normalize_config_path(path: &Path) -> PathBuf { path.canonicalize().unwrap_or_else(|_| { if path.is_absolute() { @@ -196,6 +297,10 @@ pub struct ProxyConfig { /// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf). #[serde(default)] pub default_dc: Option, + + /// Precomputed authentication snapshot for handshake hot paths. + #[serde(skip)] + pub(crate) runtime_user_auth: Option>, } impl ProxyConfig { @@ -1164,6 +1269,7 @@ impl ProxyConfig { .or_insert_with(|| vec!["91.105.192.100:443".to_string()]); validate_upstreams(&config)?; + config.rebuild_runtime_user_auth()?; Ok(LoadedConfig { config, @@ -1172,6 +1278,16 @@ impl ProxyConfig { }) } + pub(crate) fn rebuild_runtime_user_auth(&mut self) -> Result<()> { + let snapshot = UserAuthSnapshot::from_users(&self.access.users)?; + self.runtime_user_auth = Some(Arc::new(snapshot)); + Ok(()) + } + + pub(crate) fn runtime_user_auth(&self) -> Option<&UserAuthSnapshot> { + self.runtime_user_auth.as_deref() + } + pub fn validate(&self) -> Result<()> { if self.access.users.is_empty() { return Err(ProxyError::Config("No users configured".to_string())); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 8524cff..ad58e93 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -4,13 +4,16 @@ use dashmap::DashMap; use dashmap::mapref::entry::Entry; +use hmac::{Hmac, Mac}; #[cfg(test)] use std::collections::HashSet; #[cfg(test)] use std::collections::hash_map::RandomState; +use std::collections::hash_map::DefaultHasher; use std::hash::{BuildHasher, Hash, Hasher}; use std::net::SocketAddr; use std::net::{IpAddr, Ipv6Addr}; +use std::sync::atomic::Ordering; use std::sync::Arc; #[cfg(test)] use std::sync::Mutex; @@ -30,6 +33,8 @@ use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::tls_front::{TlsFrontCache, emulator}; #[cfg(test)] use rand::RngExt; +use sha2::Sha256; +use subtle::ConstantTimeEq; const ACCESS_SECRET_BYTES: usize = 16; const UNKNOWN_SNI_WARN_COOLDOWN_SECS: u64 = 5; @@ -46,6 +51,13 @@ const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536; const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024; const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4; const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2; +const STICKY_HINT_MAX_ENTRIES: usize = 65_536; +const CANDIDATE_HINT_TRACK_CAP: usize = 64; +const OVERLOAD_CANDIDATE_BUDGET_HINTED: usize = 16; +const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8; +const RECENT_USER_RING_SCAN_LIMIT: usize = 32; + +type HmacSha256 = Hmac; #[cfg(test)] const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1; @@ -91,6 +103,304 @@ fn should_emit_unknown_sni_warn_in(shared: &ProxySharedState, now: Instant) -> b true } +#[derive(Clone, Copy)] +struct ParsedTlsAuthMaterial { + digest: [u8; tls::TLS_DIGEST_LEN], + session_id: [u8; 32], + session_id_len: usize, + now: i64, + ignore_time_skew: bool, + boot_time_cap_secs: u32, +} + +#[derive(Clone, Copy)] +struct TlsCandidateValidation { + digest: [u8; tls::TLS_DIGEST_LEN], + session_id: [u8; 32], + session_id_len: usize, +} + +struct MtprotoCandidateValidation { + proto_tag: ProtoTag, + dc_idx: i16, + dec_key: [u8; 32], + dec_iv: u128, + enc_key: [u8; 32], + enc_iv: u128, + decryptor: AesCtr, + encryptor: AesCtr, +} + +fn sni_hint_hash(sni: &str) -> u64 { + let mut hasher = DefaultHasher::new(); + for byte in sni.bytes() { + hasher.write_u8(byte.to_ascii_lowercase()); + } + hasher.finish() +} + +fn ip_prefix_hint_key(peer_ip: IpAddr) -> u64 { + match peer_ip { + // Keep /24 granularity for IPv4 to avoid over-merging unrelated clients. + IpAddr::V4(ip) => { + let [a, b, c, _] = ip.octets(); + u64::from_be_bytes([0x04, a, b, c, 0, 0, 0, 0]) + } + // Keep /56 granularity for IPv6 to retain stability while limiting bucket size. + IpAddr::V6(ip) => { + let octets = ip.octets(); + u64::from_be_bytes([ + 0x06, octets[0], octets[1], octets[2], octets[3], octets[4], octets[5], octets[6], + ]) + } + } +} + +fn sticky_hint_get_by_ip(shared: &ProxySharedState, peer_ip: IpAddr) -> Option { + shared + .handshake + .sticky_user_by_ip + .get(&peer_ip) + .map(|entry| *entry) +} + +fn sticky_hint_get_by_ip_prefix(shared: &ProxySharedState, peer_ip: IpAddr) -> Option { + shared + .handshake + .sticky_user_by_ip_prefix + .get(&ip_prefix_hint_key(peer_ip)) + .map(|entry| *entry) +} + +fn sticky_hint_get_by_sni(shared: &ProxySharedState, sni: &str) -> Option { + let key = sni_hint_hash(sni); + shared + .handshake + .sticky_user_by_sni_hash + .get(&key) + .map(|entry| *entry) +} + +fn sticky_hint_record_success_in( + shared: &ProxySharedState, + peer_ip: IpAddr, + user_id: u32, + sni: Option<&str>, +) { + if shared.handshake.sticky_user_by_ip.len() > STICKY_HINT_MAX_ENTRIES { + shared.handshake.sticky_user_by_ip.clear(); + } + shared.handshake.sticky_user_by_ip.insert(peer_ip, user_id); + + if shared.handshake.sticky_user_by_ip_prefix.len() > STICKY_HINT_MAX_ENTRIES { + shared.handshake.sticky_user_by_ip_prefix.clear(); + } + shared + .handshake + .sticky_user_by_ip_prefix + .insert(ip_prefix_hint_key(peer_ip), user_id); + + if let Some(sni) = sni { + if shared.handshake.sticky_user_by_sni_hash.len() > STICKY_HINT_MAX_ENTRIES { + shared.handshake.sticky_user_by_sni_hash.clear(); + } + shared + .handshake + .sticky_user_by_sni_hash + .insert(sni_hint_hash(sni), user_id); + } +} + +fn record_recent_user_success_in(shared: &ProxySharedState, user_id: u32) { + let ring = &shared.handshake.recent_user_ring; + if ring.is_empty() { + return; + } + let seq = shared + .handshake + .recent_user_ring_seq + .fetch_add(1, Ordering::Relaxed); + let idx = (seq as usize) % ring.len(); + ring[idx].store(user_id.saturating_add(1), Ordering::Relaxed); +} + +fn mark_candidate_if_new(tried_user_ids: &mut [u32], tried_len: &mut usize, user_id: u32) -> bool { + if tried_user_ids[..*tried_len].contains(&user_id) { + return false; + } + if *tried_len < tried_user_ids.len() { + tried_user_ids[*tried_len] = user_id; + *tried_len += 1; + } + true +} + +fn budget_for_validation(total_users: usize, overload: bool, has_hint: bool) -> usize { + if total_users == 0 { + return 0; + } + if !overload { + return total_users; + } + let cap = if has_hint { + OVERLOAD_CANDIDATE_BUDGET_HINTED + } else { + OVERLOAD_CANDIDATE_BUDGET_UNHINTED + }; + total_users.min(cap.max(1)) +} + +fn parse_tls_auth_material( + handshake: &[u8], + ignore_time_skew: bool, + replay_window_secs: u64, +) -> Option { + if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { + return None; + } + + let digest: [u8; tls::TLS_DIGEST_LEN] = handshake + [tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .try_into() + .ok()?; + + let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN; + let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?); + if session_id_len > 32 { + return None; + } + let session_id_start = session_id_len_pos + 1; + if handshake.len() < session_id_start + session_id_len { + return None; + } + + let mut session_id = [0u8; 32]; + session_id[..session_id_len] + .copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]); + + let now = if !ignore_time_skew { + let d = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .ok()?; + i64::try_from(d.as_secs()).ok()? + } else { + 0_i64 + }; + + let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); + let boot_time_cap_secs = if ignore_time_skew { + 0 + } else { + tls::BOOT_TIME_MAX_SECS + .min(replay_window_u32) + .min(tls::BOOT_TIME_COMPAT_MAX_SECS) + }; + + Some(ParsedTlsAuthMaterial { + digest, + session_id, + session_id_len, + now, + ignore_time_skew, + boot_time_cap_secs, + }) +} + +fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> [u8; 32] { + let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length"); + mac.update(&handshake[..tls::TLS_DIGEST_POS]); + mac.update(&[0u8; tls::TLS_DIGEST_LEN]); + mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]); + mac.finalize().into_bytes().into() +} + +fn validate_tls_secret_candidate( + parsed: &ParsedTlsAuthMaterial, + handshake: &[u8], + secret: &[u8], +) -> Option { + let computed = compute_tls_hmac_zeroed_digest(secret, handshake); + if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) { + return None; + } + + let timestamp = u32::from_le_bytes([ + parsed.digest[28] ^ computed[28], + parsed.digest[29] ^ computed[29], + parsed.digest[30] ^ computed[30], + parsed.digest[31] ^ computed[31], + ]); + + if !parsed.ignore_time_skew { + let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs; + if !is_boot_time { + let time_diff = parsed.now - i64::from(timestamp); + if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) { + return None; + } + } + } + + Some(TlsCandidateValidation { + digest: parsed.digest, + session_id: parsed.session_id, + session_id_len: parsed.session_id_len, + }) +} + +fn validate_mtproto_secret_candidate( + handshake: &[u8; HANDSHAKE_LEN], + dec_prekey: &[u8; PREKEY_LEN], + dec_iv: u128, + enc_prekey: &[u8; PREKEY_LEN], + enc_iv: u128, + secret: &[u8; ACCESS_SECRET_BYTES], + config: &ProxyConfig, + is_tls: bool, +) -> Option { + let mut dec_key_input = [0u8; PREKEY_LEN + ACCESS_SECRET_BYTES]; + dec_key_input[..PREKEY_LEN].copy_from_slice(dec_prekey); + dec_key_input[PREKEY_LEN..].copy_from_slice(secret); + let dec_key = sha256(&dec_key_input); + dec_key_input.zeroize(); + + let mut decryptor = AesCtr::new(&dec_key, dec_iv); + let mut decrypted = *handshake; + decryptor.apply(&mut decrypted); + + let tag_bytes: [u8; 4] = [ + decrypted[PROTO_TAG_POS], + decrypted[PROTO_TAG_POS + 1], + decrypted[PROTO_TAG_POS + 2], + decrypted[PROTO_TAG_POS + 3], + ]; + let proto_tag = ProtoTag::from_bytes(tag_bytes)?; + if !mode_enabled_for_proto(config, proto_tag, is_tls) { + return None; + } + + let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]); + + let mut enc_key_input = [0u8; PREKEY_LEN + ACCESS_SECRET_BYTES]; + enc_key_input[..PREKEY_LEN].copy_from_slice(enc_prekey); + enc_key_input[PREKEY_LEN..].copy_from_slice(secret); + let enc_key = sha256(&enc_key_input); + enc_key_input.zeroize(); + + let encryptor = AesCtr::new(&enc_key, enc_iv); + + Some(MtprotoCandidateValidation { + proto_tag, + dc_idx, + dec_key, + dec_iv, + enc_key, + enc_iv, + decryptor, + encryptor, + }) +} + fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr { match peer_ip { IpAddr::V4(ip) => IpAddr::V4(ip), @@ -854,29 +1164,231 @@ where } } - let secrets = decode_user_secrets_in(shared, config, preferred_user_hint); + let mut validation_digest = [0u8; tls::TLS_DIGEST_LEN]; + let mut validation_session_id = [0u8; 32]; + let mut validation_session_id_len = 0usize; + let mut validated_user = String::new(); + let mut validated_secret = [0u8; ACCESS_SECRET_BYTES]; + let mut validated_user_id: Option = None; - let validation = match tls::validate_tls_handshake_with_replay_window( - handshake, - &secrets, - config.access.ignore_time_skew, - config.access.replay_window_secs, - ) { - Some(v) => v, - None => { + if let Some(snapshot) = config.runtime_user_auth() { + let parsed = match parse_tls_auth_material( + handshake, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ) { + Some(parsed) => parsed, + None => { + auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "TLS handshake auth material parsing failed"); + return HandshakeResult::BadClient { reader, writer }; + } + }; + + let sticky_ip_hint = sticky_hint_get_by_ip(shared, peer.ip()); + let preferred_user_id = preferred_user_hint.and_then(|user| snapshot.user_id_by_name(user)); + let sticky_sni_hint = client_sni + .as_deref() + .and_then(|sni| sticky_hint_get_by_sni(shared, sni)); + let sticky_prefix_hint = sticky_hint_get_by_ip_prefix(shared, peer.ip()); + let sni_candidates = client_sni.as_deref().and_then(|sni| snapshot.sni_candidates(sni)); + let sni_initial_candidates = client_sni + .as_deref() + .and_then(|sni| snapshot.sni_initial_candidates(sni)); + + let has_hint = sticky_ip_hint.is_some() + || preferred_user_id.is_some() + || sticky_sni_hint.is_some() + || sticky_prefix_hint.is_some() + || sni_candidates.is_some_and(|ids| !ids.is_empty()) + || sni_initial_candidates.is_some_and(|ids| !ids.is_empty()); + let overload = auth_probe_saturation_is_throttled_in(shared, Instant::now()); + let candidate_budget = budget_for_validation(snapshot.entries().len(), overload, has_hint); + + let mut tried_user_ids = [u32::MAX; CANDIDATE_HINT_TRACK_CAP]; + let mut tried_len = 0usize; + let mut validation_checks = 0usize; + let mut budget_exhausted = false; + + macro_rules! try_user_id { + ($user_id:expr) => {{ + if validation_checks >= candidate_budget { + budget_exhausted = true; + false + } else if !mark_candidate_if_new(&mut tried_user_ids, &mut tried_len, $user_id) { + false + } else if let Some(entry) = snapshot.entry_by_id($user_id) { + validation_checks = validation_checks.saturating_add(1); + if let Some(candidate) = + validate_tls_secret_candidate(&parsed, handshake, &entry.secret) + { + validation_digest = candidate.digest; + validation_session_id = candidate.session_id; + validation_session_id_len = candidate.session_id_len; + validated_secret.copy_from_slice(&entry.secret); + validated_user = entry.user.clone(); + validated_user_id = Some($user_id); + true + } else { + false + } + } else { + false + } + }}; + } + + let mut matched = false; + if let Some(user_id) = sticky_ip_hint { + matched = try_user_id!(user_id); + } + + if !matched && let Some(user_id) = preferred_user_id { + matched = try_user_id!(user_id); + } + + if !matched && let Some(user_id) = sticky_sni_hint { + matched = try_user_id!(user_id); + } + + if !matched && let Some(user_id) = sticky_prefix_hint { + matched = try_user_id!(user_id); + } + + if !matched && !budget_exhausted + && let Some(candidate_ids) = sni_candidates + { + for &user_id in candidate_ids { + if try_user_id!(user_id) { + matched = true; + break; + } + if budget_exhausted { + break; + } + } + } + + if !matched && !budget_exhausted + && let Some(candidate_ids) = sni_initial_candidates + { + for &user_id in candidate_ids { + if try_user_id!(user_id) { + matched = true; + break; + } + if budget_exhausted { + break; + } + } + } + + if !matched && !budget_exhausted { + let ring = &shared.handshake.recent_user_ring; + if !ring.is_empty() { + let next_seq = shared + .handshake + .recent_user_ring_seq + .load(Ordering::Relaxed); + let scan_limit = ring.len().min(RECENT_USER_RING_SCAN_LIMIT); + for offset in 0..scan_limit { + let idx = (next_seq as usize + ring.len() - 1 - offset) % ring.len(); + let encoded_user_id = ring[idx].load(Ordering::Relaxed); + if encoded_user_id == 0 { + continue; + } + if try_user_id!(encoded_user_id - 1) { + matched = true; + break; + } + if budget_exhausted { + break; + } + } + } + } + + if !matched && !budget_exhausted { + for idx in 0..snapshot.entries().len() { + let Some(user_id) = u32::try_from(idx).ok() else { + break; + }; + if try_user_id!(user_id) { + matched = true; + break; + } + if budget_exhausted { + break; + } + } + } + + shared + .handshake + .auth_expensive_checks_total + .fetch_add(validation_checks as u64, Ordering::Relaxed); + if budget_exhausted { + shared + .handshake + .auth_budget_exhausted_total + .fetch_add(1, Ordering::Relaxed); + } + + if !matched { auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; debug!( peer = %peer, ignore_time_skew = config.access.ignore_time_skew, - "TLS handshake validation failed - no matching user or time skew" + budget_exhausted = budget_exhausted, + candidate_budget = candidate_budget, + validation_checks = validation_checks, + "TLS handshake validation failed - no matching user, time skew, or budget exhausted" ); return HandshakeResult::BadClient { reader, writer }; } - }; + } else { + let secrets = decode_user_secrets_in(shared, config, preferred_user_hint); + let validation = match tls::validate_tls_handshake_with_replay_window( + handshake, + &secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ) { + Some(v) => v, + None => { + auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!( + peer = %peer, + ignore_time_skew = config.access.ignore_time_skew, + "TLS handshake validation failed - no matching user or time skew" + ); + return HandshakeResult::BadClient { reader, writer }; + } + }; + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { + Some((_, s)) if s.len() == ACCESS_SECRET_BYTES => s, + _ => { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } + }; + + validation_digest = validation.digest; + validation_session_id_len = validation.session_id.len(); + if validation_session_id_len > validation_session_id.len() { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } + validation_session_id[..validation_session_id_len].copy_from_slice(&validation.session_id); + validated_user = validation.user; + validated_secret.copy_from_slice(secret); + } // Reject known replay digests before expensive cache/domain/ALPN policy work. - let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + let digest_half = &validation_digest[..tls::TLS_DIGEST_HALF_LEN]; if replay_checker.check_tls_digest(digest_half) { auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; @@ -884,14 +1396,6 @@ where return HandshakeResult::BadClient { reader, writer }; } - let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { - Some((_, s)) => s, - None => { - maybe_apply_server_hello_delay(config).await; - return HandshakeResult::BadClient { reader, writer }; - } - }; - let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { let selected_domain = @@ -914,11 +1418,13 @@ where // Add replay digest only for policy-valid handshakes. replay_checker.add_tls_digest(digest_half); + let validation_session_id_slice = &validation_session_id[..validation_session_id_len]; + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( - secret, - &validation.digest, - &validation.session_id, + &validated_secret, + &validation_digest, + validation_session_id_slice, &cached_entry, use_full_cert_payload, rng, @@ -927,9 +1433,9 @@ where ) } else { tls::build_server_hello( - secret, - &validation.digest, - &validation.session_id, + &validated_secret, + &validation_digest, + validation_session_id_slice, config.censorship.fake_cert_len, rng, selected_alpn.clone(), @@ -955,16 +1461,21 @@ where debug!( peer = %peer, - user = %validation.user, + user = %validated_user, "TLS handshake successful" ); auth_probe_record_success_in(shared, peer.ip()); + if let Some(user_id) = validated_user_id { + sticky_hint_record_success_in(shared, peer.ip(), user_id, client_sni.as_deref()); + record_recent_user_success_in(shared, user_id); + } + HandshakeResult::Success(( FakeTlsReader::new(reader), FakeTlsWriter::new(writer), - validation.user, + validated_user, )) } @@ -1061,61 +1572,150 @@ where } let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + let mut dec_prekey = [0u8; PREKEY_LEN]; + dec_prekey.copy_from_slice(&dec_prekey_iv[..PREKEY_LEN]); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(&dec_prekey_iv[PREKEY_LEN..]); + let dec_iv = u128::from_be_bytes(dec_iv_arr); - let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); + let mut enc_prekey_iv = [0u8; PREKEY_LEN + IV_LEN]; + for idx in 0..enc_prekey_iv.len() { + enc_prekey_iv[idx] = dec_prekey_iv[dec_prekey_iv.len() - 1 - idx]; + } + let mut enc_prekey = [0u8; PREKEY_LEN]; + enc_prekey.copy_from_slice(&enc_prekey_iv[..PREKEY_LEN]); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(&enc_prekey_iv[PREKEY_LEN..]); + let enc_iv = u128::from_be_bytes(enc_iv_arr); - let decoded_users = decode_user_secrets_in(shared, config, preferred_user); + if let Some(snapshot) = config.runtime_user_auth() { + let sticky_ip_hint = sticky_hint_get_by_ip(shared, peer.ip()); + let sticky_prefix_hint = sticky_hint_get_by_ip_prefix(shared, peer.ip()); + let preferred_user_id = preferred_user.and_then(|user| snapshot.user_id_by_name(user)); + let has_hint = + sticky_ip_hint.is_some() || sticky_prefix_hint.is_some() || preferred_user_id.is_some(); + let overload = auth_probe_saturation_is_throttled_in(shared, Instant::now()); + let candidate_budget = budget_for_validation(snapshot.entries().len(), overload, has_hint); - for (user, secret) in decoded_users { - let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; - let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; + let mut tried_user_ids = [u32::MAX; CANDIDATE_HINT_TRACK_CAP]; + let mut tried_len = 0usize; + let mut validation_checks = 0usize; + let mut budget_exhausted = false; - let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); - dec_key_input.extend_from_slice(dec_prekey); - dec_key_input.extend_from_slice(&secret); - let dec_key = Zeroizing::new(sha256(&dec_key_input)); + let mut matched_user = String::new(); + let mut matched_user_id = None; + let mut matched_validation = None; - let mut dec_iv_arr = [0u8; IV_LEN]; - dec_iv_arr.copy_from_slice(dec_iv_bytes); - let dec_iv = u128::from_be_bytes(dec_iv_arr); - - let mut decryptor = AesCtr::new(&dec_key, dec_iv); - let decrypted = decryptor.decrypt(handshake); - - let tag_bytes: [u8; 4] = [ - decrypted[PROTO_TAG_POS], - decrypted[PROTO_TAG_POS + 1], - decrypted[PROTO_TAG_POS + 2], - decrypted[PROTO_TAG_POS + 3], - ]; - - let proto_tag = match ProtoTag::from_bytes(tag_bytes) { - Some(tag) => tag, - None => continue, - }; - - let mode_ok = mode_enabled_for_proto(config, proto_tag, is_tls); - - if !mode_ok { - debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled"); - continue; + macro_rules! try_user_id { + ($user_id:expr) => {{ + if validation_checks >= candidate_budget { + budget_exhausted = true; + false + } else if !mark_candidate_if_new(&mut tried_user_ids, &mut tried_len, $user_id) { + false + } else if let Some(entry) = snapshot.entry_by_id($user_id) { + validation_checks = validation_checks.saturating_add(1); + if let Some(validation) = validate_mtproto_secret_candidate( + handshake, + &dec_prekey, + dec_iv, + &enc_prekey, + enc_iv, + &entry.secret, + config, + is_tls, + ) { + matched_user = entry.user.clone(); + matched_user_id = Some($user_id); + matched_validation = Some(validation); + true + } else { + false + } + } else { + false + } + }}; } - let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]); + let mut matched = false; + if let Some(user_id) = sticky_ip_hint { + matched = try_user_id!(user_id); + } - let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; - let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; + if !matched && let Some(user_id) = preferred_user_id { + matched = try_user_id!(user_id); + } - let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); - enc_key_input.extend_from_slice(enc_prekey); - enc_key_input.extend_from_slice(&secret); - let enc_key = Zeroizing::new(sha256(&enc_key_input)); + if !matched && let Some(user_id) = sticky_prefix_hint { + matched = try_user_id!(user_id); + } - let mut enc_iv_arr = [0u8; IV_LEN]; - enc_iv_arr.copy_from_slice(enc_iv_bytes); - let enc_iv = u128::from_be_bytes(enc_iv_arr); + if !matched && !budget_exhausted { + let ring = &shared.handshake.recent_user_ring; + if !ring.is_empty() { + let next_seq = shared + .handshake + .recent_user_ring_seq + .load(Ordering::Relaxed); + let scan_limit = ring.len().min(RECENT_USER_RING_SCAN_LIMIT); + for offset in 0..scan_limit { + let idx = (next_seq as usize + ring.len() - 1 - offset) % ring.len(); + let encoded_user_id = ring[idx].load(Ordering::Relaxed); + if encoded_user_id == 0 { + continue; + } + if try_user_id!(encoded_user_id - 1) { + matched = true; + break; + } + if budget_exhausted { + break; + } + } + } + } - let encryptor = AesCtr::new(&enc_key, enc_iv); + if !matched && !budget_exhausted { + for idx in 0..snapshot.entries().len() { + let Some(user_id) = u32::try_from(idx).ok() else { + break; + }; + if try_user_id!(user_id) { + matched = true; + break; + } + if budget_exhausted { + break; + } + } + } + + shared + .handshake + .auth_expensive_checks_total + .fetch_add(validation_checks as u64, Ordering::Relaxed); + if budget_exhausted { + shared + .handshake + .auth_budget_exhausted_total + .fetch_add(1, Ordering::Relaxed); + } + + if !matched { + auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!( + peer = %peer, + budget_exhausted = budget_exhausted, + candidate_budget = candidate_budget, + validation_checks = validation_checks, + "MTProto handshake: no matching user found" + ); + return HandshakeResult::BadClient { reader, writer }; + } + + let validation = matched_validation.expect("validation must exist when matched"); // Apply replay tracking only after successful authentication. // @@ -1126,39 +1726,121 @@ where if replay_checker.check_and_add_handshake(dec_prekey_iv) { auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; - warn!(peer = %peer, user = %user, "MTProto replay attack detected"); + warn!(peer = %peer, user = %matched_user, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; } let success = HandshakeSuccess { - user: user.clone(), - dc_idx, - proto_tag, - dec_key: *dec_key, - dec_iv, - enc_key: *enc_key, - enc_iv, + user: matched_user.clone(), + dc_idx: validation.dc_idx, + proto_tag: validation.proto_tag, + dec_key: validation.dec_key, + dec_iv: validation.dec_iv, + enc_key: validation.enc_key, + enc_iv: validation.enc_iv, peer, is_tls, }; debug!( peer = %peer, - user = %user, - dc = dc_idx, - proto = ?proto_tag, + user = %matched_user, + dc = validation.dc_idx, + proto = ?validation.proto_tag, tls = is_tls, "MTProto handshake successful" ); auth_probe_record_success_in(shared, peer.ip()); + if let Some(user_id) = matched_user_id { + sticky_hint_record_success_in(shared, peer.ip(), user_id, None); + record_recent_user_success_in(shared, user_id); + } let max_pending = config.general.crypto_pending_buffer; return HandshakeResult::Success(( - CryptoReader::new(reader, decryptor), - CryptoWriter::new(writer, encryptor, max_pending), + CryptoReader::new(reader, validation.decryptor), + CryptoWriter::new(writer, validation.encryptor, max_pending), success, )); + } else { + let decoded_users = decode_user_secrets_in(shared, config, preferred_user); + let mut validation_checks = 0usize; + + for (user, secret) in decoded_users { + if secret.len() != ACCESS_SECRET_BYTES { + continue; + } + validation_checks = validation_checks.saturating_add(1); + + let mut secret_arr = [0u8; ACCESS_SECRET_BYTES]; + secret_arr.copy_from_slice(&secret); + let Some(validation) = validate_mtproto_secret_candidate( + handshake, + &dec_prekey, + dec_iv, + &enc_prekey, + enc_iv, + &secret_arr, + config, + is_tls, + ) else { + continue; + }; + + shared + .handshake + .auth_expensive_checks_total + .fetch_add(validation_checks as u64, Ordering::Relaxed); + + // Apply replay tracking only after successful authentication. + // + // This ordering prevents an attacker from producing invalid handshakes that + // still collide with a valid handshake's replay slot and thus evict a valid + // entry from the cache. We accept the cost of performing the full + // authentication check first to avoid poisoning the replay cache. + if replay_checker.check_and_add_handshake(dec_prekey_iv) { + auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, user = %user, "MTProto replay attack detected"); + return HandshakeResult::BadClient { reader, writer }; + } + + let success = HandshakeSuccess { + user: user.clone(), + dc_idx: validation.dc_idx, + proto_tag: validation.proto_tag, + dec_key: validation.dec_key, + dec_iv: validation.dec_iv, + enc_key: validation.enc_key, + enc_iv: validation.enc_iv, + peer, + is_tls, + }; + + debug!( + peer = %peer, + user = %user, + dc = validation.dc_idx, + proto = ?validation.proto_tag, + tls = is_tls, + "MTProto handshake successful" + ); + + auth_probe_record_success_in(shared, peer.ip()); + + let max_pending = config.general.crypto_pending_buffer; + return HandshakeResult::Success(( + CryptoReader::new(reader, validation.decryptor), + CryptoWriter::new(writer, validation.encryptor, max_pending), + success, + )); + } + + shared + .handshake + .auth_expensive_checks_total + .fetch_add(validation_checks as u64, Ordering::Relaxed); } auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); diff --git a/src/proxy/shared_state.rs b/src/proxy/shared_state.rs index dd49806..4fef497 100644 --- a/src/proxy/shared_state.rs +++ b/src/proxy/shared_state.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use std::collections::hash_map::RandomState; use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Instant; @@ -11,6 +11,8 @@ use tokio::sync::mpsc; use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState}; use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry}; +const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum ConntrackCloseReason { NormalEof, @@ -41,6 +43,13 @@ pub(crate) struct HandshakeSharedState { pub(crate) auth_probe_eviction_hasher: RandomState, pub(crate) invalid_secret_warned: Mutex>, pub(crate) unknown_sni_warn_next_allowed: Mutex>, + pub(crate) sticky_user_by_ip: DashMap, + pub(crate) sticky_user_by_ip_prefix: DashMap, + pub(crate) sticky_user_by_sni_hash: DashMap, + pub(crate) recent_user_ring: Box<[AtomicU32]>, + pub(crate) recent_user_ring_seq: AtomicU64, + pub(crate) auth_expensive_checks_total: AtomicU64, + pub(crate) auth_budget_exhausted_total: AtomicU64, } pub(crate) struct MiddleRelaySharedState { @@ -69,6 +78,16 @@ impl ProxySharedState { auth_probe_eviction_hasher: RandomState::new(), invalid_secret_warned: Mutex::new(HashSet::new()), unknown_sni_warn_next_allowed: Mutex::new(None), + sticky_user_by_ip: DashMap::new(), + sticky_user_by_ip_prefix: DashMap::new(), + sticky_user_by_sni_hash: DashMap::new(), + recent_user_ring: std::iter::repeat_with(|| AtomicU32::new(0)) + .take(HANDSHAKE_RECENT_USER_RING_LEN) + .collect::>() + .into_boxed_slice(), + recent_user_ring_seq: AtomicU64::new(0), + auth_expensive_checks_total: AtomicU64::new(0), + auth_budget_exhausted_total: AtomicU64::new(0), }, middle_relay: MiddleRelaySharedState { desync_dedup: DashMap::new(), diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs index df3bbe0..937740c 100644 --- a/src/proxy/tests/handshake_security_tests.rs +++ b/src/proxy/tests/handshake_security_tests.rs @@ -4,6 +4,7 @@ use dashmap::DashMap; use rand::rngs::StdRng; use rand::{RngExt, SeedableRng}; use std::net::{IpAddr, Ipv4Addr}; +use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::Barrier; @@ -1090,6 +1091,172 @@ async fn tls_missing_sni_keeps_legacy_auth_path() { assert!(matches!(result, HandshakeResult::Success(_))); } +#[tokio::test] +async fn tls_runtime_snapshot_updates_sticky_and_recent_hints() { + let secret = [0x5Au8; 16]; + let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + config.rebuild_runtime_user_auth().unwrap(); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let shared = ProxySharedState::new(); + let peer: SocketAddr = "198.51.100.212:44326".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "user", &[b"h2"]); + + let result = handle_tls_handshake_with_shared( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + shared.as_ref(), + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); + assert_eq!( + shared + .handshake + .sticky_user_by_ip + .get(&peer.ip()) + .map(|entry| *entry), + Some(0), + "successful runtime-snapshot auth must seed sticky ip cache" + ); + assert_eq!( + shared + .handshake + .sticky_user_by_ip_prefix + .len(), + 1, + "successful runtime-snapshot auth must seed sticky prefix cache" + ); + assert!( + shared + .handshake + .auth_expensive_checks_total + .load(Ordering::Relaxed) + >= 1, + "runtime-snapshot path must account expensive candidate checks" + ); +} + +#[tokio::test] +async fn tls_overload_budget_limits_candidate_scan_depth() { + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.ignore_time_skew = true; + for idx in 0..32u8 { + config + .access + .users + .insert(format!("user-{idx}"), format!("{:032x}", u128::from(idx) + 1)); + } + config.rebuild_runtime_user_auth().unwrap(); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let shared = ProxySharedState::new(); + let now = Instant::now(); + { + let mut saturation = shared.handshake.auth_probe_saturation.lock().unwrap(); + *saturation = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_millis(200), + last_seen: now, + }); + } + + let peer: SocketAddr = "198.51.100.213:44326".parse().unwrap(); + let attacker_secret = [0xEFu8; 16]; + let handshake = make_valid_tls_handshake(&attacker_secret, 0); + + let result = handle_tls_handshake_with_shared( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + shared.as_ref(), + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + shared + .handshake + .auth_budget_exhausted_total + .load(Ordering::Relaxed), + 1, + "overload mode must account budget exhaustion when scan is capped" + ); + assert_eq!( + shared + .handshake + .auth_expensive_checks_total + .load(Ordering::Relaxed), + OVERLOAD_CANDIDATE_BUDGET_UNHINTED as u64, + "overload scan depth must stay within capped candidate budget" + ); +} + +#[tokio::test] +async fn mtproto_runtime_snapshot_prefers_preferred_user_hint() { + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.ignore_time_skew = true; + config.access.users.insert( + "alpha".to_string(), + "11111111111111111111111111111111".to_string(), + ); + config.access.users.insert( + "beta".to_string(), + "22222222222222222222222222222222".to_string(), + ); + config.rebuild_runtime_user_auth().unwrap(); + + let handshake = + make_valid_mtproto_handshake("22222222222222222222222222222222", ProtoTag::Secure, 2); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.214:44326".parse().unwrap(); + let shared = ProxySharedState::new(); + + let result = handle_mtproto_handshake_with_shared( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + Some("beta"), + shared.as_ref(), + ) + .await; + + match result { + HandshakeResult::Success((_, _, success)) => { + assert_eq!(success.user, "beta"); + } + _ => panic!("mtproto runtime snapshot auth must succeed for preferred user"), + } + + assert_eq!( + shared + .handshake + .auth_expensive_checks_total + .load(Ordering::Relaxed), + 1, + "preferred user hint must produce single-candidate success in snapshot path" + ); +} + #[tokio::test] async fn alpn_enforce_rejects_unsupported_client_alpn() { let secret = [0x33u8; 16];