diff --git a/src/config/defaults.rs b/src/config/defaults.rs index f02403e..b0aaf5b 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -71,6 +71,22 @@ pub(crate) fn default_tls_fetch_scope() -> String { String::new() } +pub(crate) fn default_tls_fetch_attempt_timeout_ms() -> u64 { + 5_000 +} + +pub(crate) fn default_tls_fetch_total_budget_ms() -> u64 { + 15_000 +} + +pub(crate) fn default_tls_fetch_strict_route() -> bool { + true +} + +pub(crate) fn default_tls_fetch_profile_cache_ttl_secs() -> u64 { + 600 +} + pub(crate) fn default_mask_port() -> u16 { 443 } diff --git a/src/config/load.rs b/src/config/load.rs index 2c46766..3cb6627 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; @@ -977,6 +977,28 @@ impl ProxyConfig { // Normalize optional TLS fetch scope: whitespace-only values disable scoped routing. config.censorship.tls_fetch_scope = config.censorship.tls_fetch_scope.trim().to_string(); + if config.censorship.tls_fetch.profiles.is_empty() { + config.censorship.tls_fetch.profiles = TlsFetchConfig::default().profiles; + } else { + let mut seen = HashSet::new(); + config + .censorship + .tls_fetch + .profiles + .retain(|profile| seen.insert(*profile)); + } + + if config.censorship.tls_fetch.attempt_timeout_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.attempt_timeout_ms must be > 0".to_string(), + )); + } + if config.censorship.tls_fetch.total_budget_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.total_budget_ms must be > 0".to_string(), + )); + } + // Merge primary + extra TLS domains, deduplicate (primary always first). if !config.censorship.tls_domains.is_empty() { let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len()); @@ -2459,6 +2481,94 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn tls_fetch_defaults_are_applied() { + let toml = r#" + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_defaults_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + TlsFetchConfig::default().profiles + ); + assert!(cfg.censorship.tls_fetch.strict_route); + assert_eq!(cfg.censorship.tls_fetch.attempt_timeout_ms, 5_000); + assert_eq!(cfg.censorship.tls_fetch.total_budget_ms, 15_000); + assert_eq!(cfg.censorship.tls_fetch.profile_cache_ttl_secs, 600); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_profiles_are_deduplicated_preserving_order() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + profiles = ["compat_tls12", "modern_chrome_like", "compat_tls12", "legacy_minimal"] + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_profiles_dedup_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + vec![ + TlsFetchProfile::CompatTls12, + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::LegacyMinimal + ] + ); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_attempt_timeout_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + attempt_timeout_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_attempt_timeout_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.attempt_timeout_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_total_budget_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + total_budget_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_total_budget_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.total_budget_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn invalid_ad_tag_is_disabled_during_load() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index 68ba278..3939664 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1367,6 +1367,82 @@ pub enum UnknownSniAction { Mask, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TlsFetchProfile { + ModernChromeLike, + ModernFirefoxLike, + CompatTls12, + LegacyMinimal, +} + +impl TlsFetchProfile { + pub fn as_str(self) -> &'static str { + match self { + TlsFetchProfile::ModernChromeLike => "modern_chrome_like", + TlsFetchProfile::ModernFirefoxLike => "modern_firefox_like", + TlsFetchProfile::CompatTls12 => "compat_tls12", + TlsFetchProfile::LegacyMinimal => "legacy_minimal", + } + } +} + +fn default_tls_fetch_profiles() -> Vec { + vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ] +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsFetchConfig { + /// Ordered list of ClientHello profiles used for adaptive fallback. + #[serde(default = "default_tls_fetch_profiles")] + pub profiles: Vec, + + /// When true and upstream route is configured, TLS fetch fails closed on + /// upstream connect errors and does not fallback to direct TCP. + #[serde(default = "default_tls_fetch_strict_route")] + pub strict_route: bool, + + /// Timeout per one profile attempt in milliseconds. + #[serde(default = "default_tls_fetch_attempt_timeout_ms")] + pub attempt_timeout_ms: u64, + + /// Total wall-clock budget in milliseconds across all profile attempts. + #[serde(default = "default_tls_fetch_total_budget_ms")] + pub total_budget_ms: u64, + + /// Adds GREASE-style values into selected ClientHello extensions. + #[serde(default)] + pub grease_enabled: bool, + + /// Produces deterministic ClientHello randomness for debugging/tests. + #[serde(default)] + pub deterministic: bool, + + /// TTL for winner-profile cache entries in seconds. + /// Set to 0 to disable profile cache. + #[serde(default = "default_tls_fetch_profile_cache_ttl_secs")] + pub profile_cache_ttl_secs: u64, +} + +impl Default for TlsFetchConfig { + fn default() -> Self { + Self { + profiles: default_tls_fetch_profiles(), + strict_route: default_tls_fetch_strict_route(), + attempt_timeout_ms: default_tls_fetch_attempt_timeout_ms(), + total_budget_ms: default_tls_fetch_total_budget_ms(), + grease_enabled: false, + deterministic: false, + profile_cache_ttl_secs: default_tls_fetch_profile_cache_ttl_secs(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] @@ -1385,6 +1461,10 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_tls_fetch_scope")] pub tls_fetch_scope: String, + /// Fetch strategy for TLS front metadata bootstrap and periodic refresh. + #[serde(default)] + pub tls_fetch: TlsFetchConfig, + #[serde(default = "default_true")] pub mask: bool, @@ -1492,6 +1572,7 @@ impl Default for AntiCensorshipConfig { tls_domains: Vec::new(), unknown_sni_action: UnknownSniAction::Drop, tls_fetch_scope: default_tls_fetch_scope(), + tls_fetch: TlsFetchConfig::default(), mask: default_true(), mask_host: None, mask_port: default_mask_port(), diff --git a/src/maestro/tls_bootstrap.rs b/src/maestro/tls_bootstrap.rs index 342a2f9..7cf3039 100644 --- a/src/maestro/tls_bootstrap.rs +++ b/src/maestro/tls_bootstrap.rs @@ -7,6 +7,7 @@ use tracing::warn; use crate::config::ProxyConfig; use crate::startup::{COMPONENT_TLS_FRONT_BOOTSTRAP, StartupTracker}; use crate::tls_front::TlsFrontCache; +use crate::tls_front::fetcher::TlsFetchStrategy; use crate::transport::UpstreamManager; pub(crate) async fn bootstrap_tls_front( @@ -40,7 +41,17 @@ pub(crate) async fn bootstrap_tls_front( let mask_unix_sock = config.censorship.mask_unix_sock.clone(); let tls_fetch_scope = (!config.censorship.tls_fetch_scope.is_empty()) .then(|| config.censorship.tls_fetch_scope.clone()); - let fetch_timeout = Duration::from_secs(5); + let tls_fetch = config.censorship.tls_fetch.clone(); + let fetch_strategy = TlsFetchStrategy { + profiles: tls_fetch.profiles, + strict_route: tls_fetch.strict_route, + attempt_timeout: Duration::from_millis(tls_fetch.attempt_timeout_ms.max(1)), + total_budget: Duration::from_millis(tls_fetch.total_budget_ms.max(1)), + grease_enabled: tls_fetch.grease_enabled, + deterministic: tls_fetch.deterministic, + profile_cache_ttl: Duration::from_secs(tls_fetch.profile_cache_ttl_secs), + }; + let fetch_timeout = fetch_strategy.total_budget; let cache_initial = cache.clone(); let domains_initial = tls_domains.to_vec(); @@ -48,6 +59,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_initial = mask_unix_sock.clone(); let scope_initial = tls_fetch_scope.clone(); let upstream_initial = upstream_manager.clone(); + let strategy_initial = fetch_strategy.clone(); tokio::spawn(async move { let mut join = tokio::task::JoinSet::new(); for domain in domains_initial { @@ -56,12 +68,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_initial.clone(); let scope_domain = scope_initial.clone(); let upstream_domain = upstream_initial.clone(); + let strategy_domain = strategy_initial.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, @@ -107,6 +120,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_refresh = mask_unix_sock.clone(); let scope_refresh = tls_fetch_scope.clone(); let upstream_refresh = upstream_manager.clone(); + let strategy_refresh = fetch_strategy.clone(); tokio::spawn(async move { loop { let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600); @@ -120,12 +134,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_refresh.clone(); let scope_domain = scope_refresh.clone(); let upstream_domain = upstream_refresh.clone(); + let strategy_domain = strategy_refresh.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 2356a93..503b79c 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -1,7 +1,9 @@ #![allow(clippy::too_many_arguments)] +use dashmap::DashMap; use std::sync::Arc; -use std::time::Duration; +use std::sync::OnceLock; +use std::time::{Duration, Instant}; use anyhow::{Result, anyhow}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -21,7 +23,8 @@ use rustls::{DigitallySignedStruct, Error as RustlsError}; use x509_parser::certificate::X509Certificate; use x509_parser::prelude::FromDer; -use crate::crypto::SecureRandom; +use crate::config::TlsFetchProfile; +use crate::crypto::{SecureRandom, sha256}; use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, @@ -78,6 +81,197 @@ impl ServerCertVerifier for NoVerify { } } +#[derive(Debug, Clone)] +pub struct TlsFetchStrategy { + pub profiles: Vec, + pub strict_route: bool, + pub attempt_timeout: Duration, + pub total_budget: Duration, + pub grease_enabled: bool, + pub deterministic: bool, + pub profile_cache_ttl: Duration, +} + +impl TlsFetchStrategy { + #[allow(dead_code)] + pub fn single_attempt(connect_timeout: Duration) -> Self { + Self { + profiles: vec![TlsFetchProfile::CompatTls12], + strict_route: false, + attempt_timeout: connect_timeout.max(Duration::from_millis(1)), + total_budget: connect_timeout.max(Duration::from_millis(1)), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::ZERO, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ProfileCacheKey { + host: String, + port: u16, + sni: String, + scope: Option, + proxy_protocol: u8, + route_hint: RouteHint, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum RouteHint { + Direct, + Upstream, + Unix, +} + +#[derive(Debug, Clone, Copy)] +struct ProfileCacheValue { + profile: TlsFetchProfile, + updated_at: Instant, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FetchErrorKind { + Connect, + Route, + EarlyEof, + Timeout, + ServerHelloMissing, + TlsAlert, + Parse, + Other, +} + +static PROFILE_CACHE: OnceLock> = OnceLock::new(); + +fn profile_cache() -> &'static DashMap { + PROFILE_CACHE.get_or_init(DashMap::new) +} + +fn route_hint( + upstream: Option<&std::sync::Arc>, + unix_sock: Option<&str>, +) -> RouteHint { + if unix_sock.is_some() { + RouteHint::Unix + } else if upstream.is_some() { + RouteHint::Upstream + } else { + RouteHint::Direct + } +} + +fn profile_cache_key( + host: &str, + port: u16, + sni: &str, + upstream: Option<&std::sync::Arc>, + scope: Option<&str>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> ProfileCacheKey { + ProfileCacheKey { + host: host.to_string(), + port, + sni: sni.to_string(), + scope: scope.map(ToString::to_string), + proxy_protocol, + route_hint: route_hint(upstream, unix_sock), + } +} + +fn classify_fetch_error(err: &anyhow::Error) -> FetchErrorKind { + for cause in err.chain() { + if let Some(io) = cause.downcast_ref::() { + return match io.kind() { + std::io::ErrorKind::TimedOut => FetchErrorKind::Timeout, + std::io::ErrorKind::UnexpectedEof => FetchErrorKind::EarlyEof, + std::io::ErrorKind::ConnectionRefused + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::AddrNotAvailable => FetchErrorKind::Connect, + _ => FetchErrorKind::Other, + }; + } + } + + let message = err.to_string().to_lowercase(); + if message.contains("upstream route") { + FetchErrorKind::Route + } else if message.contains("serverhello not received") { + FetchErrorKind::ServerHelloMissing + } else if message.contains("alert") { + FetchErrorKind::TlsAlert + } else if message.contains("parse") { + FetchErrorKind::Parse + } else if message.contains("timed out") || message.contains("deadline has elapsed") { + FetchErrorKind::Timeout + } else if message.contains("eof") { + FetchErrorKind::EarlyEof + } else { + FetchErrorKind::Other + } +} + +fn order_profiles( + strategy: &TlsFetchStrategy, + cache_key: Option<&ProfileCacheKey>, + now: Instant, +) -> Vec { + let mut ordered = if strategy.profiles.is_empty() { + vec![TlsFetchProfile::CompatTls12] + } else { + strategy.profiles.clone() + }; + + if strategy.profile_cache_ttl.is_zero() { + return ordered; + } + + let Some(key) = cache_key else { + return ordered; + }; + + if let Some(cached) = profile_cache().get(key) { + let age = now.saturating_duration_since(cached.updated_at); + if age > strategy.profile_cache_ttl { + drop(cached); + profile_cache().remove(key); + return ordered; + } + + if let Some(pos) = ordered.iter().position(|profile| *profile == cached.profile) { + if pos != 0 { + ordered.swap(0, pos); + } + } + } + + ordered +} + +fn remember_profile_success( + strategy: &TlsFetchStrategy, + cache_key: Option, + profile: TlsFetchProfile, + now: Instant, +) { + if strategy.profile_cache_ttl.is_zero() { + return; + } + let Some(key) = cache_key else { + return; + }; + profile_cache().insert( + key, + ProfileCacheValue { + profile, + updated_at: now, + }, + ); +} + fn build_client_config() -> Arc { let root = rustls::RootCertStore::empty(); @@ -95,7 +289,114 @@ fn build_client_config() -> Arc { Arc::new(config) } -fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { +fn deterministic_bytes(seed: &str, len: usize) -> Vec { + let mut out = Vec::with_capacity(len); + let mut counter: u32 = 0; + while out.len() < len { + let mut chunk_seed = Vec::with_capacity(seed.len() + std::mem::size_of::()); + chunk_seed.extend_from_slice(seed.as_bytes()); + chunk_seed.extend_from_slice(&counter.to_le_bytes()); + out.extend_from_slice(&sha256(&chunk_seed)); + counter = counter.wrapping_add(1); + } + out.truncate(len); + out +} + +fn profile_cipher_suites(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN_CHROME: &[u16] = &[ + 0x1301, 0x1302, 0x1303, 0xc02b, 0xc02c, 0xcca9, 0xc02f, 0xc030, 0xcca8, 0x009e, 0x00ff, + ]; + const MODERN_FIREFOX: &[u16] = &[ + 0x1301, 0x1303, 0x1302, 0xc02b, 0xcca9, 0xc02c, 0xc02f, 0xcca8, 0xc030, 0x009e, 0x00ff, + ]; + const COMPAT_TLS12: &[u16] = &[ + 0xc02b, 0xc02c, 0xc02f, 0xc030, 0xcca9, 0xcca8, 0x1301, 0x1302, 0x1303, 0x009e, 0x00ff, + ]; + const LEGACY_MINIMAL: &[u16] = &[0xc02b, 0xc02f, 0x1301, 0x1302, 0x00ff]; + + match profile { + TlsFetchProfile::ModernChromeLike => MODERN_CHROME, + TlsFetchProfile::ModernFirefoxLike => MODERN_FIREFOX, + TlsFetchProfile::CompatTls12 => COMPAT_TLS12, + TlsFetchProfile::LegacyMinimal => LEGACY_MINIMAL, + } +} + +fn profile_groups(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x001d, 0x0017, 0x0018]; // x25519, secp256r1, secp384r1 + const COMPAT: &[u16] = &[0x001d, 0x0017]; + const LEGACY: &[u16] = &[0x0017]; + + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_sig_algs(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x0804, 0x0805, 0x0403, 0x0503, 0x0806]; + const COMPAT: &[u16] = &[0x0403, 0x0503, 0x0804, 0x0805]; + const LEGACY: &[u16] = &[0x0403, 0x0804]; + + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_alpn(profile: TlsFetchProfile) -> &'static [&'static [u8]] { + const H2_HTTP11: &[&[u8]] = &[b"h2", b"http/1.1"]; + const HTTP11: &[&[u8]] = &[b"http/1.1"]; + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => H2_HTTP11, + TlsFetchProfile::CompatTls12 | TlsFetchProfile::LegacyMinimal => HTTP11, + } +} + +fn profile_supported_versions(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x0304, 0x0303]; + const COMPAT: &[u16] = &[0x0303, 0x0304]; + const LEGACY: &[u16] = &[0x0303]; + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_padding_target(profile: TlsFetchProfile) -> usize { + match profile { + TlsFetchProfile::ModernChromeLike => 220, + TlsFetchProfile::ModernFirefoxLike => 200, + TlsFetchProfile::CompatTls12 => 180, + TlsFetchProfile::LegacyMinimal => 64, + } +} + +fn grease_value(rng: &SecureRandom, deterministic: bool, seed: &str) -> u16 { + const GREASE_VALUES: [u16; 16] = [ + 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, 0x8a8a, 0x9a9a, 0xaaaa, + 0xbaba, 0xcaca, 0xdada, 0xeaea, 0xfafa, + ]; + if deterministic { + let idx = deterministic_bytes(seed, 1)[0] as usize % GREASE_VALUES.len(); + GREASE_VALUES[idx] + } else { + let idx = (rng.bytes(1)[0] as usize) % GREASE_VALUES.len(); + GREASE_VALUES[idx] + } +} + +fn build_client_hello( + sni: &str, + rng: &SecureRandom, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, +) -> Vec { // === ClientHello body === let mut body = Vec::new(); @@ -103,29 +404,20 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { body.extend_from_slice(&[0x03, 0x03]); // Random - body.extend_from_slice(&rng.bytes(32)); + if deterministic { + body.extend_from_slice(&deterministic_bytes(&format!("tls-fetch-random:{sni}"), 32)); + } else { + body.extend_from_slice(&rng.bytes(32)); + } // Session ID: empty body.push(0); - // Cipher suites: - // - TLS1.3 set - // - broad TLS1.2 ECDHE set for RSA/ECDSA cert chains - // This keeps raw probing compatible with common production frontends that - // still negotiate TLS1.2. - let cipher_suites: [u16; 11] = [ - 0x1301, // TLS_AES_128_GCM_SHA256 - 0x1302, // TLS_AES_256_GCM_SHA384 - 0x1303, // TLS_CHACHA20_POLY1305_SHA256 - 0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 - 0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 - 0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 - 0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 - 0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 - 0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - 0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 - 0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV - ]; + let mut cipher_suites = profile_cipher_suites(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("cipher:{sni}")); + cipher_suites.insert(0, grease); + } body.extend_from_slice(&((cipher_suites.len() * 2) as u16).to_be_bytes()); for suite in cipher_suites { body.extend_from_slice(&suite.to_be_bytes()); @@ -150,7 +442,11 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&sni_ext); // supported_groups - let groups: [u16; 2] = [0x001d, 0x0017]; // x25519, secp256r1 + let mut groups = profile_groups(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("group:{sni}")); + groups.insert(0, grease); + } exts.extend_from_slice(&0x000au16.to_be_bytes()); exts.extend_from_slice(&((2 + groups.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes()); @@ -159,7 +455,11 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // signature_algorithms - let sig_algs: [u16; 4] = [0x0804, 0x0805, 0x0403, 0x0503]; // rsa_pss_rsae_sha256/384, ecdsa_secp256r1_sha256, ecdsa_secp384r1_sha384 + let mut sig_algs = profile_sig_algs(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("sigalg:{sni}")); + sig_algs.insert(0, grease); + } exts.extend_from_slice(&0x000du16.to_be_bytes()); exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes()); @@ -167,8 +467,12 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&a.to_be_bytes()); } - // supported_versions (TLS1.3 + TLS1.2) - let versions: [u16; 2] = [0x0304, 0x0303]; + // supported_versions + let mut versions = profile_supported_versions(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("version:{sni}")); + versions.insert(0, grease); + } exts.extend_from_slice(&0x002bu16.to_be_bytes()); exts.extend_from_slice(&((1 + versions.len() * 2) as u16).to_be_bytes()); exts.push((versions.len() * 2) as u8); @@ -177,7 +481,14 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // key_share (x25519) - let key = gen_key_share(rng); + let key = if deterministic { + let det = deterministic_bytes(&format!("keyshare:{sni}"), 32); + let mut key = [0u8; 32]; + key.copy_from_slice(&det); + key + } else { + gen_key_share(rng) + }; let mut keyshare = Vec::with_capacity(4 + key.len()); keyshare.extend_from_slice(&0x001du16.to_be_bytes()); // group keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes()); @@ -187,18 +498,29 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&(keyshare.len() as u16).to_be_bytes()); exts.extend_from_slice(&keyshare); - // ALPN (http/1.1) - let alpn_proto = b"http/1.1"; - exts.extend_from_slice(&0x0010u16.to_be_bytes()); - exts.extend_from_slice(&((2 + 1 + alpn_proto.len()) as u16).to_be_bytes()); - exts.extend_from_slice(&((1 + alpn_proto.len()) as u16).to_be_bytes()); - exts.push(alpn_proto.len() as u8); - exts.extend_from_slice(alpn_proto); + // ALPN + let mut alpn_list = Vec::new(); + for proto in profile_alpn(profile) { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + if !alpn_list.is_empty() { + exts.extend_from_slice(&0x0010u16.to_be_bytes()); + exts.extend_from_slice(&((2 + alpn_list.len()) as u16).to_be_bytes()); + exts.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + exts.extend_from_slice(&alpn_list); + } + + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("ext:{sni}")); + exts.extend_from_slice(&grease.to_be_bytes()); + exts.extend_from_slice(&0u16.to_be_bytes()); + } // padding to reduce recognizability and keep length ~500 bytes - const TARGET_EXT_LEN: usize = 180; - if exts.len() < TARGET_EXT_LEN { - let remaining = TARGET_EXT_LEN - exts.len(); + let target_ext_len = profile_padding_target(profile); + if exts.len() < target_ext_len { + let remaining = target_ext_len - exts.len(); if remaining > 4 { let pad_len = remaining - 4; // minus type+len exts.extend_from_slice(&0x0015u16.to_be_bytes()); // padding extension @@ -414,27 +736,41 @@ async fn connect_tcp_with_upstream( connect_timeout: Duration, upstream: Option>, scope: Option<&str>, + strict_route: bool, ) -> Result { if let Some(manager) = upstream { - if let Some(addr) = resolve_socket_addr(host, port) { - match manager.connect(addr, None, scope).await { - Ok(stream) => return Ok(stream), - Err(e) => { - warn!( - host = %host, - port = port, - scope = ?scope, - error = %e, - "Upstream connect failed, using direct connect" - ); - } - } - } else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await - && let Some(addr) = addrs.find(|a| a.is_ipv4()) - { + let resolved = if let Some(addr) = resolve_socket_addr(host, port) { + Some(addr) + } else { + match tokio::net::lookup_host((host, port)).await { + Ok(mut addrs) => addrs.find(|a| a.is_ipv4()), + Err(e) => { + if strict_route { + return Err(anyhow!( + "upstream route DNS resolution failed for {host}:{port}: {e}" + )); + } + warn!( + host = %host, + port = port, + scope = ?scope, + error = %e, + "Upstream DNS resolution failed, using direct connect" + ); + None + } + } + }; + + if let Some(addr) = resolved { match manager.connect(addr, None, scope).await { Ok(stream) => return Ok(stream), Err(e) => { + if strict_route { + return Err(anyhow!( + "upstream route connect failed for {host}:{port}: {e}" + )); + } warn!( host = %host, port = port, @@ -444,6 +780,10 @@ async fn connect_tcp_with_upstream( ); } } + } else if strict_route { + return Err(anyhow!( + "upstream route resolution produced no usable address for {host}:{port}" + )); } } Ok(UpstreamStream::Tcp( @@ -483,12 +823,15 @@ async fn fetch_via_raw_tls_stream( sni: &str, connect_timeout: Duration, proxy_protocol: u8, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, ) -> Result where S: AsyncRead + AsyncWrite + Unpin, { let rng = SecureRandom::new(); - let client_hello = build_client_hello(sni, &rng); + let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic); timeout(connect_timeout, async { if proxy_protocol > 0 { let header = match proxy_protocol { @@ -562,6 +905,10 @@ async fn fetch_via_raw_tls( scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, + strict_route: bool, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { @@ -572,8 +919,16 @@ async fn fetch_via_raw_tls( sock = %sock_path, "Raw TLS fetch using mask unix socket" ); - return fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol) - .await; + return fetch_via_raw_tls_stream( + stream, + sni, + connect_timeout, + proxy_protocol, + profile, + grease_enabled, + deterministic, + ) + .await; } Ok(Err(e)) => { warn!( @@ -596,8 +951,25 @@ async fn fetch_via_raw_tls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?; - fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await + let stream = connect_tcp_with_upstream( + host, + port, + connect_timeout, + upstream, + scope, + strict_route, + ) + .await?; + fetch_via_raw_tls_stream( + stream, + sni, + connect_timeout, + proxy_protocol, + profile, + grease_enabled, + deterministic, + ) + .await } async fn fetch_via_rustls_stream( @@ -703,6 +1075,7 @@ async fn fetch_via_rustls( scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, + strict_route: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { @@ -736,16 +1109,159 @@ async fn fetch_via_rustls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?; + let stream = connect_tcp_with_upstream( + host, + port, + connect_timeout, + upstream, + scope, + strict_route, + ) + .await?; fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await } -/// Fetch real TLS metadata for the given SNI. -/// -/// Strategy: -/// 1) Probe raw TLS for realistic ServerHello and ApplicationData record sizes. -/// 2) Fetch certificate chain via rustls to build cert payload. -/// 3) Merge both when possible; otherwise auto-fallback to whichever succeeded. +/// Fetch real TLS metadata with an adaptive multi-profile strategy. +pub async fn fetch_real_tls_with_strategy( + host: &str, + port: u16, + sni: &str, + strategy: &TlsFetchStrategy, + upstream: Option>, + scope: Option<&str>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> Result { + let attempt_timeout = strategy.attempt_timeout.max(Duration::from_millis(1)); + let total_budget = strategy.total_budget.max(Duration::from_millis(1)); + let started_at = Instant::now(); + let cache_key = profile_cache_key( + host, + port, + sni, + upstream.as_ref(), + scope, + proxy_protocol, + unix_sock, + ); + let profiles = order_profiles(strategy, Some(&cache_key), started_at); + + let mut raw_result = None; + let mut raw_last_error: Option = None; + let mut raw_last_error_kind = FetchErrorKind::Other; + let mut selected_profile = None; + + for profile in profiles { + let elapsed = started_at.elapsed(); + if elapsed >= total_budget { + break; + } + let timeout_for_attempt = attempt_timeout.min(total_budget - elapsed); + + match fetch_via_raw_tls( + host, + port, + sni, + timeout_for_attempt, + upstream.clone(), + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + profile, + strategy.grease_enabled, + strategy.deterministic, + ) + .await + { + Ok(res) => { + selected_profile = Some(profile); + raw_result = Some(res); + break; + } + Err(err) => { + let kind = classify_fetch_error(&err); + warn!( + sni = %sni, + profile = profile.as_str(), + error_kind = ?kind, + error = %err, + "Raw TLS fetch attempt failed" + ); + raw_last_error_kind = kind; + raw_last_error = Some(err); + if strategy.strict_route && matches!(kind, FetchErrorKind::Route) { + break; + } + } + } + } + + if let Some(profile) = selected_profile { + remember_profile_success(strategy, Some(cache_key), profile, Instant::now()); + } + + if raw_result.is_none() + && strategy.strict_route + && matches!(raw_last_error_kind, FetchErrorKind::Route) + { + if let Some(err) = raw_last_error { + return Err(err); + } + return Err(anyhow!("TLS fetch strict-route failure")); + } + + let elapsed = started_at.elapsed(); + if elapsed >= total_budget { + return match raw_result { + Some(raw) => Ok(raw), + None => Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))), + }; + } + + let rustls_timeout = attempt_timeout.min(total_budget - elapsed); + let rustls_result = fetch_via_rustls( + host, + port, + sni, + rustls_timeout, + upstream, + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + ) + .await; + + match rustls_result { + Ok(rustls) => { + if let Some(mut raw) = raw_result { + raw.cert_info = rustls.cert_info; + raw.cert_payload = rustls.cert_payload; + raw.behavior_profile.source = TlsProfileSource::Merged; + debug!(sni = %sni, "Fetched TLS metadata via adaptive raw probe + rustls cert chain"); + Ok(raw) + } else { + Ok(rustls) + } + } + Err(err) => { + if let Some(raw) = raw_result { + warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); + Ok(raw) + } else if let Some(raw_err) = raw_last_error { + Err(anyhow!( + "TLS fetch failed (raw: {raw_err}; rustls: {err})" + )) + } else { + Err(err) + } + } + } +} + +/// Fetch real TLS metadata for the given SNI using a single-attempt compatibility strategy. +#[allow(dead_code)] pub async fn fetch_real_tls( host: &str, port: u16, @@ -756,62 +1272,30 @@ pub async fn fetch_real_tls( proxy_protocol: u8, unix_sock: Option<&str>, ) -> Result { - let raw_result = match fetch_via_raw_tls( + let strategy = TlsFetchStrategy::single_attempt(connect_timeout); + fetch_real_tls_with_strategy( host, port, sni, - connect_timeout, - upstream.clone(), - scope, - proxy_protocol, - unix_sock, - ) - .await - { - Ok(res) => Some(res), - Err(e) => { - warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); - None - } - }; - - match fetch_via_rustls( - host, - port, - sni, - connect_timeout, + &strategy, upstream, scope, proxy_protocol, unix_sock, ) .await - { - Ok(rustls_result) => { - if let Some(mut raw) = raw_result { - raw.cert_info = rustls_result.cert_info; - raw.cert_payload = rustls_result.cert_payload; - raw.behavior_profile.source = TlsProfileSource::Merged; - debug!(sni = %sni, "Fetched TLS metadata via raw probe + rustls cert chain"); - Ok(raw) - } else { - Ok(rustls_result) - } - } - Err(e) => { - if let Some(raw) = raw_result { - warn!(sni = %sni, error = %e, "Rustls cert fetch failed, using raw TLS metadata only"); - Ok(raw) - } else { - Err(e) - } - } - } } #[cfg(test)] mod tests { - use super::{derive_behavior_profile, encode_tls13_certificate_message}; + use std::time::{Duration, Instant}; + + use super::{ + ProfileCacheValue, TlsFetchStrategy, build_client_hello, derive_behavior_profile, + encode_tls13_certificate_message, order_profiles, profile_cache, profile_cache_key, + }; + use crate::config::TlsFetchProfile; + use crate::crypto::SecureRandom; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, }; @@ -860,4 +1344,89 @@ mod tests { assert_eq!(profile.ticket_record_sizes, vec![220, 180]); assert_eq!(profile.source, TlsProfileSource::Raw); } + + #[test] + fn test_order_profiles_prioritizes_fresh_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ], + strict_route: true, + attempt_timeout: Duration::from_secs(1), + total_budget: Duration::from_secs(2), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::from_secs(60), + }; + let cache_key = profile_cache_key( + "mask.example", + 443, + "tls.example", + None, + Some("tls"), + 0, + None, + ); + profile_cache().remove(&cache_key); + profile_cache().insert( + cache_key.clone(), + ProfileCacheValue { + profile: TlsFetchProfile::CompatTls12, + updated_at: Instant::now(), + }, + ); + + let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); + assert_eq!(ordered[0], TlsFetchProfile::CompatTls12); + profile_cache().remove(&cache_key); + } + + #[test] + fn test_order_profiles_drops_expired_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![TlsFetchProfile::ModernFirefoxLike, TlsFetchProfile::CompatTls12], + strict_route: true, + attempt_timeout: Duration::from_secs(1), + total_budget: Duration::from_secs(2), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::from_secs(5), + }; + let cache_key = profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); + profile_cache().remove(&cache_key); + profile_cache().insert( + cache_key.clone(), + ProfileCacheValue { + profile: TlsFetchProfile::CompatTls12, + updated_at: Instant::now() - Duration::from_secs(6), + }, + ); + + let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); + assert_eq!(ordered[0], TlsFetchProfile::ModernFirefoxLike); + assert!(profile_cache().get(&cache_key).is_none()); + } + + #[test] + fn test_deterministic_client_hello_is_stable() { + let rng = SecureRandom::new(); + let first = build_client_hello( + "stable.example", + &rng, + TlsFetchProfile::ModernChromeLike, + true, + true, + ); + let second = build_client_hello( + "stable.example", + &rng, + TlsFetchProfile::ModernChromeLike, + true, + true, + ); + + assert_eq!(first, second); + } }