#![allow(clippy::too_many_arguments)] use dashmap::DashMap; use std::sync::Arc; use std::sync::OnceLock; use std::time::{Duration, Instant}; use anyhow::{Result, anyhow}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; use tokio::time::timeout; use tokio_rustls::TlsConnector; use tokio_rustls::client::TlsStream; use tracing::{debug, warn}; use rustls::client::ClientConfig; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::{DigitallySignedStruct, Error as RustlsError}; use x509_parser::certificate::X509Certificate; use x509_parser::prelude::FromDer; use crate::config::TlsFetchProfile; use crate::crypto::{SecureRandom, sha256}; use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, }; use crate::tls_front::types::{ ParsedCertificateInfo, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsExtension, TlsFetchResult, TlsProfileSource, }; use crate::transport::UpstreamStream; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; /// No-op verifier: accept any certificate (we only need lengths and metadata). #[derive(Debug)] struct NoVerify; impl ServerCertVerifier for NoVerify { fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>], _server_name: &ServerName<'_>, _ocsp: &[u8], _now: UnixTime, ) -> Result { Ok(ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { use rustls::SignatureScheme::*; vec![ RSA_PKCS1_SHA256, RSA_PSS_SHA256, ECDSA_NISTP256_SHA256, ECDSA_NISTP384_SHA384, ] } } #[derive(Debug, Clone)] pub struct TlsFetchStrategy { pub profiles: Vec, pub strict_route: bool, pub attempt_timeout: Duration, pub total_budget: Duration, pub grease_enabled: bool, pub deterministic: bool, pub profile_cache_ttl: Duration, } impl TlsFetchStrategy { #[allow(dead_code)] pub fn single_attempt(connect_timeout: Duration) -> Self { Self { profiles: vec![TlsFetchProfile::CompatTls12], strict_route: false, attempt_timeout: connect_timeout.max(Duration::from_millis(1)), total_budget: connect_timeout.max(Duration::from_millis(1)), grease_enabled: false, deterministic: false, profile_cache_ttl: Duration::ZERO, } } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct ProfileCacheKey { host: String, port: u16, sni: String, scope: Option, proxy_protocol: u8, route_hint: RouteHint, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum RouteHint { Direct, Upstream, Unix, } #[derive(Debug, Clone, Copy)] struct ProfileCacheValue { profile: TlsFetchProfile, updated_at: Instant, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum FetchErrorKind { Connect, Route, EarlyEof, Timeout, ServerHelloMissing, TlsAlert, Parse, Other, } static PROFILE_CACHE: OnceLock> = OnceLock::new(); fn profile_cache() -> &'static DashMap { PROFILE_CACHE.get_or_init(DashMap::new) } fn route_hint( upstream: Option<&std::sync::Arc>, unix_sock: Option<&str>, ) -> RouteHint { if unix_sock.is_some() { RouteHint::Unix } else if upstream.is_some() { RouteHint::Upstream } else { RouteHint::Direct } } fn profile_cache_key( host: &str, port: u16, sni: &str, upstream: Option<&std::sync::Arc>, scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, ) -> ProfileCacheKey { ProfileCacheKey { host: host.to_string(), port, sni: sni.to_string(), scope: scope.map(ToString::to_string), proxy_protocol, route_hint: route_hint(upstream, unix_sock), } } fn classify_fetch_error(err: &anyhow::Error) -> FetchErrorKind { for cause in err.chain() { if let Some(io) = cause.downcast_ref::() { return match io.kind() { std::io::ErrorKind::TimedOut => FetchErrorKind::Timeout, std::io::ErrorKind::UnexpectedEof => FetchErrorKind::EarlyEof, std::io::ErrorKind::ConnectionRefused | std::io::ErrorKind::ConnectionAborted | std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::NotConnected | std::io::ErrorKind::AddrNotAvailable => FetchErrorKind::Connect, _ => FetchErrorKind::Other, }; } } let message = err.to_string().to_lowercase(); if message.contains("upstream route") { FetchErrorKind::Route } else if message.contains("serverhello not received") { FetchErrorKind::ServerHelloMissing } else if message.contains("alert") { FetchErrorKind::TlsAlert } else if message.contains("parse") { FetchErrorKind::Parse } else if message.contains("timed out") || message.contains("deadline has elapsed") { FetchErrorKind::Timeout } else if message.contains("eof") { FetchErrorKind::EarlyEof } else { FetchErrorKind::Other } } fn order_profiles( strategy: &TlsFetchStrategy, cache_key: Option<&ProfileCacheKey>, now: Instant, ) -> Vec { let mut ordered = if strategy.profiles.is_empty() { vec![TlsFetchProfile::CompatTls12] } else { strategy.profiles.clone() }; if strategy.profile_cache_ttl.is_zero() { return ordered; } let Some(key) = cache_key else { return ordered; }; if let Some(cached) = profile_cache().get(key) { let age = now.saturating_duration_since(cached.updated_at); if age > strategy.profile_cache_ttl { drop(cached); profile_cache().remove(key); return ordered; } if let Some(pos) = ordered .iter() .position(|profile| *profile == cached.profile) { if pos != 0 { ordered.swap(0, pos); } } } ordered } fn remember_profile_success( strategy: &TlsFetchStrategy, cache_key: Option, profile: TlsFetchProfile, now: Instant, ) { if strategy.profile_cache_ttl.is_zero() { return; } let Some(key) = cache_key else { return; }; profile_cache().insert( key, ProfileCacheValue { profile, updated_at: now, }, ); } fn build_client_config() -> Arc { let root = rustls::RootCertStore::empty(); let provider = rustls::crypto::ring::default_provider(); let mut config = ClientConfig::builder_with_provider(Arc::new(provider)) .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) .expect("protocol versions") .with_root_certificates(root) .with_no_client_auth(); config .dangerous() .set_certificate_verifier(Arc::new(NoVerify)); Arc::new(config) } fn deterministic_bytes(seed: &str, len: usize) -> Vec { let mut out = Vec::with_capacity(len); let mut counter: u32 = 0; while out.len() < len { let mut chunk_seed = Vec::with_capacity(seed.len() + std::mem::size_of::()); chunk_seed.extend_from_slice(seed.as_bytes()); chunk_seed.extend_from_slice(&counter.to_le_bytes()); out.extend_from_slice(&sha256(&chunk_seed)); counter = counter.wrapping_add(1); } out.truncate(len); out } fn profile_cipher_suites(profile: TlsFetchProfile) -> &'static [u16] { const MODERN_CHROME: &[u16] = &[ 0x1301, 0x1302, 0x1303, 0xc02b, 0xc02c, 0xcca9, 0xc02f, 0xc030, 0xcca8, 0x009e, 0x00ff, ]; const MODERN_FIREFOX: &[u16] = &[ 0x1301, 0x1303, 0x1302, 0xc02b, 0xcca9, 0xc02c, 0xc02f, 0xcca8, 0xc030, 0x009e, 0x00ff, ]; const COMPAT_TLS12: &[u16] = &[ 0xc02b, 0xc02c, 0xc02f, 0xc030, 0xcca9, 0xcca8, 0x1301, 0x1302, 0x1303, 0x009e, 0x00ff, ]; const LEGACY_MINIMAL: &[u16] = &[0xc02b, 0xc02f, 0x1301, 0x1302, 0x00ff]; match profile { TlsFetchProfile::ModernChromeLike => MODERN_CHROME, TlsFetchProfile::ModernFirefoxLike => MODERN_FIREFOX, TlsFetchProfile::CompatTls12 => COMPAT_TLS12, TlsFetchProfile::LegacyMinimal => LEGACY_MINIMAL, } } fn profile_groups(profile: TlsFetchProfile) -> &'static [u16] { const MODERN: &[u16] = &[0x001d, 0x0017, 0x0018]; // x25519, secp256r1, secp384r1 const COMPAT: &[u16] = &[0x001d, 0x0017]; const LEGACY: &[u16] = &[0x0017]; match profile { TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, TlsFetchProfile::CompatTls12 => COMPAT, TlsFetchProfile::LegacyMinimal => LEGACY, } } fn profile_sig_algs(profile: TlsFetchProfile) -> &'static [u16] { const MODERN: &[u16] = &[0x0804, 0x0805, 0x0403, 0x0503, 0x0806]; const COMPAT: &[u16] = &[0x0403, 0x0503, 0x0804, 0x0805]; const LEGACY: &[u16] = &[0x0403, 0x0804]; match profile { TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, TlsFetchProfile::CompatTls12 => COMPAT, TlsFetchProfile::LegacyMinimal => LEGACY, } } fn profile_alpn(profile: TlsFetchProfile) -> &'static [&'static [u8]] { const H2_HTTP11: &[&[u8]] = &[b"h2", b"http/1.1"]; const HTTP11: &[&[u8]] = &[b"http/1.1"]; match profile { TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => H2_HTTP11, TlsFetchProfile::CompatTls12 | TlsFetchProfile::LegacyMinimal => HTTP11, } } fn profile_supported_versions(profile: TlsFetchProfile) -> &'static [u16] { const MODERN: &[u16] = &[0x0304, 0x0303]; const COMPAT: &[u16] = &[0x0303, 0x0304]; const LEGACY: &[u16] = &[0x0303]; match profile { TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, TlsFetchProfile::CompatTls12 => COMPAT, TlsFetchProfile::LegacyMinimal => LEGACY, } } fn profile_padding_target(profile: TlsFetchProfile) -> usize { match profile { TlsFetchProfile::ModernChromeLike => 220, TlsFetchProfile::ModernFirefoxLike => 200, TlsFetchProfile::CompatTls12 => 180, TlsFetchProfile::LegacyMinimal => 64, } } fn grease_value(rng: &SecureRandom, deterministic: bool, seed: &str) -> u16 { const GREASE_VALUES: [u16; 16] = [ 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, 0x8a8a, 0x9a9a, 0xaaaa, 0xbaba, 0xcaca, 0xdada, 0xeaea, 0xfafa, ]; if deterministic { let idx = deterministic_bytes(seed, 1)[0] as usize % GREASE_VALUES.len(); GREASE_VALUES[idx] } else { let idx = (rng.bytes(1)[0] as usize) % GREASE_VALUES.len(); GREASE_VALUES[idx] } } fn build_client_hello( sni: &str, rng: &SecureRandom, profile: TlsFetchProfile, grease_enabled: bool, deterministic: bool, ) -> Vec { // === ClientHello body === let mut body = Vec::new(); // Legacy version (TLS 1.0) as in real ClientHello headers body.extend_from_slice(&[0x03, 0x03]); // Random if deterministic { body.extend_from_slice(&deterministic_bytes(&format!("tls-fetch-random:{sni}"), 32)); } else { body.extend_from_slice(&rng.bytes(32)); } // Session ID: empty body.push(0); let mut cipher_suites = profile_cipher_suites(profile).to_vec(); if grease_enabled { let grease = grease_value(rng, deterministic, &format!("cipher:{sni}")); cipher_suites.insert(0, grease); } body.extend_from_slice(&((cipher_suites.len() * 2) as u16).to_be_bytes()); for suite in cipher_suites { body.extend_from_slice(&suite.to_be_bytes()); } // Compression methods: null only body.push(1); body.push(0); // === Extensions === let mut exts = Vec::new(); // server_name (SNI) let sni_bytes = sni.as_bytes(); let mut sni_ext = Vec::with_capacity(5 + sni_bytes.len()); sni_ext.extend_from_slice(&(sni_bytes.len() as u16 + 3).to_be_bytes()); sni_ext.push(0); // host_name sni_ext.extend_from_slice(&(sni_bytes.len() as u16).to_be_bytes()); sni_ext.extend_from_slice(sni_bytes); exts.extend_from_slice(&0x0000u16.to_be_bytes()); exts.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); exts.extend_from_slice(&sni_ext); // supported_groups let mut groups = profile_groups(profile).to_vec(); if grease_enabled { let grease = grease_value(rng, deterministic, &format!("group:{sni}")); groups.insert(0, grease); } exts.extend_from_slice(&0x000au16.to_be_bytes()); exts.extend_from_slice(&((2 + groups.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes()); for g in groups { exts.extend_from_slice(&g.to_be_bytes()); } // signature_algorithms let mut sig_algs = profile_sig_algs(profile).to_vec(); if grease_enabled { let grease = grease_value(rng, deterministic, &format!("sigalg:{sni}")); sig_algs.insert(0, grease); } exts.extend_from_slice(&0x000du16.to_be_bytes()); exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes()); for a in sig_algs { exts.extend_from_slice(&a.to_be_bytes()); } // supported_versions let mut versions = profile_supported_versions(profile).to_vec(); if grease_enabled { let grease = grease_value(rng, deterministic, &format!("version:{sni}")); versions.insert(0, grease); } exts.extend_from_slice(&0x002bu16.to_be_bytes()); exts.extend_from_slice(&((1 + versions.len() * 2) as u16).to_be_bytes()); exts.push((versions.len() * 2) as u8); for v in versions { exts.extend_from_slice(&v.to_be_bytes()); } // key_share (x25519) let key = if deterministic { let det = deterministic_bytes(&format!("keyshare:{sni}"), 32); let mut key = [0u8; 32]; key.copy_from_slice(&det); key } else { gen_key_share(rng) }; let mut keyshare = Vec::with_capacity(4 + key.len()); keyshare.extend_from_slice(&0x001du16.to_be_bytes()); // group keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes()); keyshare.extend_from_slice(&key); exts.extend_from_slice(&0x0033u16.to_be_bytes()); exts.extend_from_slice(&((2 + keyshare.len()) as u16).to_be_bytes()); exts.extend_from_slice(&(keyshare.len() as u16).to_be_bytes()); exts.extend_from_slice(&keyshare); // ALPN let mut alpn_list = Vec::new(); for proto in profile_alpn(profile) { alpn_list.push(proto.len() as u8); alpn_list.extend_from_slice(proto); } if !alpn_list.is_empty() { exts.extend_from_slice(&0x0010u16.to_be_bytes()); exts.extend_from_slice(&((2 + alpn_list.len()) as u16).to_be_bytes()); exts.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); exts.extend_from_slice(&alpn_list); } if grease_enabled { let grease = grease_value(rng, deterministic, &format!("ext:{sni}")); exts.extend_from_slice(&grease.to_be_bytes()); exts.extend_from_slice(&0u16.to_be_bytes()); } // padding to reduce recognizability and keep length ~500 bytes let target_ext_len = profile_padding_target(profile); if exts.len() < target_ext_len { let remaining = target_ext_len - exts.len(); if remaining > 4 { let pad_len = remaining - 4; // minus type+len exts.extend_from_slice(&0x0015u16.to_be_bytes()); // padding extension exts.extend_from_slice(&(pad_len as u16).to_be_bytes()); exts.resize(exts.len() + pad_len, 0); } } // Extensions length prefix body.extend_from_slice(&(exts.len() as u16).to_be_bytes()); body.extend_from_slice(&exts); // === Handshake wrapper === let mut handshake = Vec::new(); handshake.push(0x01); // ClientHello let len_bytes = (body.len() as u32).to_be_bytes(); handshake.extend_from_slice(&len_bytes[1..4]); handshake.extend_from_slice(&body); // === Record === let mut record = Vec::new(); record.push(TLS_RECORD_HANDSHAKE); record.extend_from_slice(&[0x03, 0x01]); // legacy record version record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); record.extend_from_slice(&handshake); record } fn gen_key_share(rng: &SecureRandom) -> [u8; 32] { let mut key = [0u8; 32]; key.copy_from_slice(&rng.bytes(32)); key } async fn read_tls_record(stream: &mut S) -> Result<(u8, Vec)> where S: AsyncRead + Unpin, { let mut header = [0u8; 5]; stream.read_exact(&mut header).await?; let len = u16::from_be_bytes([header[3], header[4]]) as usize; let mut body = vec![0u8; len]; stream.read_exact(&mut body).await?; Ok((header[0], body)) } fn parse_server_hello(body: &[u8]) -> Option { if body.len() < 4 || body[0] != 0x02 { return None; } let msg_len = u32::from_be_bytes([0, body[1], body[2], body[3]]) as usize; if msg_len + 4 > body.len() { return None; } let mut pos = 4; let version = [*body.get(pos)?, *body.get(pos + 1)?]; pos += 2; let mut random = [0u8; 32]; random.copy_from_slice(body.get(pos..pos + 32)?); pos += 32; let session_len = *body.get(pos)? as usize; pos += 1; let session_id = body.get(pos..pos + session_len)?.to_vec(); pos += session_len; let cipher_suite = [*body.get(pos)?, *body.get(pos + 1)?]; pos += 2; let compression = *body.get(pos)?; pos += 1; let ext_len = u16::from_be_bytes([*body.get(pos)?, *body.get(pos + 1)?]) as usize; pos += 2; let ext_end = pos.checked_add(ext_len)?; if ext_end > body.len() { return None; } let mut extensions = Vec::new(); while pos + 4 <= ext_end { let etype = u16::from_be_bytes([body[pos], body[pos + 1]]); let elen = u16::from_be_bytes([body[pos + 2], body[pos + 3]]) as usize; pos += 4; let data = body.get(pos..pos + elen)?.to_vec(); pos += elen; extensions.push(TlsExtension { ext_type: etype, data, }); } Some(ParsedServerHello { version, random, session_id, cipher_suite, compression, extensions, }) } fn derive_behavior_profile(records: &[(u8, Vec)]) -> TlsBehaviorProfile { let mut change_cipher_spec_count = 0u8; let mut app_data_record_sizes = Vec::new(); for (record_type, body) in records { match *record_type { TLS_RECORD_CHANGE_CIPHER => { change_cipher_spec_count = change_cipher_spec_count.saturating_add(1); } TLS_RECORD_APPLICATION => { app_data_record_sizes.push(body.len()); } _ => {} } } let mut ticket_record_sizes = Vec::new(); while app_data_record_sizes .last() .is_some_and(|size| *size <= 256 && ticket_record_sizes.len() < 2) { if let Some(size) = app_data_record_sizes.pop() { ticket_record_sizes.push(size); } } ticket_record_sizes.reverse(); TlsBehaviorProfile { change_cipher_spec_count: change_cipher_spec_count.max(1), app_data_record_sizes, ticket_record_sizes, source: TlsProfileSource::Raw, } } fn parse_cert_info(certs: &[CertificateDer<'static>]) -> Option { let first = certs.first()?; let (_rem, cert) = X509Certificate::from_der(first.as_ref()).ok()?; let not_before = Some(cert.validity().not_before.to_datetime().unix_timestamp()); let not_after = Some(cert.validity().not_after.to_datetime().unix_timestamp()); let issuer_cn = cert .issuer() .iter_common_name() .next() .and_then(|cn| cn.as_str().ok()) .map(|s| s.to_string()); let subject_cn = cert .subject() .iter_common_name() .next() .and_then(|cn| cn.as_str().ok()) .map(|s| s.to_string()); let san_names = cert .subject_alternative_name() .ok() .flatten() .map(|san| { san.value .general_names .iter() .filter_map(|gn| match gn { x509_parser::extensions::GeneralName::DNSName(n) => Some(n.to_string()), _ => None, }) .collect::>() }) .unwrap_or_default(); Some(ParsedCertificateInfo { not_after_unix: not_after, not_before_unix: not_before, issuer_cn, subject_cn, san_names, }) } fn u24_bytes(value: usize) -> Option<[u8; 3]> { if value > 0x00ff_ffff { return None; } Some([ ((value >> 16) & 0xff) as u8, ((value >> 8) & 0xff) as u8, (value & 0xff) as u8, ]) } async fn connect_with_dns_override( host: &str, port: u16, connect_timeout: Duration, ) -> Result { if let Some(addr) = resolve_socket_addr(host, port) { return Ok(timeout(connect_timeout, TcpStream::connect(addr)).await??); } Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??) } async fn connect_tcp_with_upstream( host: &str, port: u16, connect_timeout: Duration, upstream: Option>, scope: Option<&str>, strict_route: bool, ) -> Result { if let Some(manager) = upstream { let resolved = if let Some(addr) = resolve_socket_addr(host, port) { Some(addr) } else { match tokio::net::lookup_host((host, port)).await { Ok(mut addrs) => addrs.find(|a| a.is_ipv4()), Err(e) => { if strict_route { return Err(anyhow!( "upstream route DNS resolution failed for {host}:{port}: {e}" )); } warn!( host = %host, port = port, scope = ?scope, error = %e, "Upstream DNS resolution failed, using direct connect" ); None } } }; if let Some(addr) = resolved { match manager.connect(addr, None, scope).await { Ok(stream) => return Ok(stream), Err(e) => { if strict_route { return Err(anyhow!( "upstream route connect failed for {host}:{port}: {e}" )); } warn!( host = %host, port = port, scope = ?scope, error = %e, "Upstream connect failed, using direct connect" ); } } } else if strict_route { return Err(anyhow!( "upstream route resolution produced no usable address for {host}:{port}" )); } } Ok(UpstreamStream::Tcp( connect_with_dns_override(host, port, connect_timeout).await?, )) } fn encode_tls13_certificate_message(cert_chain_der: &[Vec]) -> Option> { if cert_chain_der.is_empty() { return None; } let mut certificate_list = Vec::new(); for cert in cert_chain_der { if cert.is_empty() { return None; } certificate_list.extend_from_slice(&u24_bytes(cert.len())?); certificate_list.extend_from_slice(cert); certificate_list.extend_from_slice(&0u16.to_be_bytes()); // cert_entry extensions } // Certificate = context_len(1) + certificate_list_len(3) + entries let body_len = 1usize.checked_add(3)?.checked_add(certificate_list.len())?; let mut message = Vec::with_capacity(4 + body_len); message.push(0x0b); // HandshakeType::certificate message.extend_from_slice(&u24_bytes(body_len)?); message.push(0x00); // certificate_request_context length message.extend_from_slice(&u24_bytes(certificate_list.len())?); message.extend_from_slice(&certificate_list); Some(message) } async fn fetch_via_raw_tls_stream( mut stream: S, sni: &str, connect_timeout: Duration, proxy_protocol: u8, profile: TlsFetchProfile, grease_enabled: bool, deterministic: bool, ) -> Result where S: AsyncRead + AsyncWrite + Unpin, { let rng = SecureRandom::new(); let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic); timeout(connect_timeout, async { if proxy_protocol > 0 { let header = match proxy_protocol { 2 => ProxyProtocolV2Builder::new().build(), _ => ProxyProtocolV1Builder::new().build(), }; stream.write_all(&header).await?; } stream.write_all(&client_hello).await?; stream.flush().await?; Ok::<(), std::io::Error>(()) }) .await??; let mut records = Vec::new(); let mut app_records_seen = 0usize; // Read a bounded encrypted flight: ServerHello, CCS, certificate-like data, // and a small number of ticket-like tail records. for _ in 0..8 { match timeout(connect_timeout, read_tls_record(&mut stream)).await { Ok(Ok(rec)) => { if rec.0 == TLS_RECORD_APPLICATION { app_records_seen += 1; } records.push(rec); } Ok(Err(e)) => return Err(e), Err(_) => break, } if app_records_seen >= 4 { break; } } let mut server_hello = None; for (t, body) in &records { if *t == TLS_RECORD_HANDSHAKE && server_hello.is_none() { server_hello = parse_server_hello(body); } } let parsed = server_hello.ok_or_else(|| anyhow!("ServerHello not received"))?; let behavior_profile = derive_behavior_profile(&records); let mut app_sizes = behavior_profile.app_data_record_sizes.clone(); app_sizes.extend_from_slice(&behavior_profile.ticket_record_sizes); let total_app_data_len = app_sizes.iter().sum::().max(1024); let app_data_records_sizes = behavior_profile .app_data_record_sizes .first() .copied() .or_else(|| behavior_profile.ticket_record_sizes.first().copied()) .map(|size| vec![size]) .unwrap_or_else(|| vec![total_app_data_len]); Ok(TlsFetchResult { server_hello_parsed: parsed, app_data_records_sizes, total_app_data_len, behavior_profile, cert_info: None, cert_payload: None, }) } async fn fetch_via_raw_tls( host: &str, port: u16, sni: &str, connect_timeout: Duration, upstream: Option>, scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, strict_route: bool, profile: TlsFetchProfile, grease_enabled: bool, deterministic: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { match timeout(connect_timeout, UnixStream::connect(sock_path)).await { Ok(Ok(stream)) => { debug!( sni = %sni, sock = %sock_path, "Raw TLS fetch using mask unix socket" ); return fetch_via_raw_tls_stream( stream, sni, connect_timeout, proxy_protocol, profile, grease_enabled, deterministic, ) .await; } Ok(Err(e)) => { warn!( sni = %sni, sock = %sock_path, error = %e, "Raw TLS unix socket connect failed, falling back to TCP" ); } Err(_) => { warn!( sni = %sni, sock = %sock_path, "Raw TLS unix socket connect timed out, falling back to TCP" ); } } } #[cfg(not(unix))] let _ = unix_sock; let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) .await?; fetch_via_raw_tls_stream( stream, sni, connect_timeout, proxy_protocol, profile, grease_enabled, deterministic, ) .await } async fn fetch_via_rustls_stream( mut stream: S, host: &str, sni: &str, proxy_protocol: u8, ) -> Result where S: AsyncRead + AsyncWrite + Unpin, { // rustls handshake path for certificate and basic negotiated metadata. if proxy_protocol > 0 { let header = match proxy_protocol { 2 => ProxyProtocolV2Builder::new().build(), _ => ProxyProtocolV1Builder::new().build(), }; stream.write_all(&header).await?; stream.flush().await?; } let config = build_client_config(); let connector = TlsConnector::from(config); let server_name = ServerName::try_from(sni.to_owned()) .or_else(|_| ServerName::try_from(host.to_owned())) .map_err(|_| RustlsError::General("invalid SNI".into()))?; let tls_stream: TlsStream = connector.connect(server_name, stream).await?; // Extract negotiated parameters and certificates let (_io, session) = tls_stream.get_ref(); let cipher_suite = session .negotiated_cipher_suite() .map(|s| u16::from(s.suite()).to_be_bytes()) .unwrap_or([0x13, 0x01]); let certs: Vec> = session .peer_certificates() .map(|slice| slice.to_vec()) .unwrap_or_default(); let cert_chain_der: Vec> = certs.iter().map(|c| c.as_ref().to_vec()).collect(); let cert_payload = encode_tls13_certificate_message(&cert_chain_der).map(|certificate_message| { TlsCertPayload { cert_chain_der: cert_chain_der.clone(), certificate_message, } }); let total_cert_len = cert_payload .as_ref() .map(|payload| payload.certificate_message.len()) .unwrap_or_else(|| cert_chain_der.iter().map(Vec::len).sum::()) .max(1024); let cert_info = parse_cert_info(&certs); // Heuristic: split across two records if large to mimic real servers a bit. let app_data_records_sizes = if total_cert_len > 3000 { vec![total_cert_len / 2, total_cert_len - total_cert_len / 2] } else { vec![total_cert_len] }; let parsed = ParsedServerHello { version: [0x03, 0x03], random: [0u8; 32], session_id: Vec::new(), cipher_suite, compression: 0, extensions: Vec::new(), }; debug!( sni = %sni, len = total_cert_len, cipher = format!("0x{:04x}", u16::from_be_bytes(cipher_suite)), has_cert_payload = cert_payload.is_some(), "Fetched TLS metadata via rustls" ); Ok(TlsFetchResult { server_hello_parsed: parsed, app_data_records_sizes: app_data_records_sizes.clone(), total_app_data_len: app_data_records_sizes.iter().sum(), behavior_profile: TlsBehaviorProfile { change_cipher_spec_count: 1, app_data_record_sizes: app_data_records_sizes, ticket_record_sizes: Vec::new(), source: TlsProfileSource::Rustls, }, cert_info, cert_payload, }) } async fn fetch_via_rustls( host: &str, port: u16, sni: &str, connect_timeout: Duration, upstream: Option>, scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, strict_route: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { match timeout(connect_timeout, UnixStream::connect(sock_path)).await { Ok(Ok(stream)) => { debug!( sni = %sni, sock = %sock_path, "Rustls fetch using mask unix socket" ); return fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await; } Ok(Err(e)) => { warn!( sni = %sni, sock = %sock_path, error = %e, "Rustls unix socket connect failed, falling back to TCP" ); } Err(_) => { warn!( sni = %sni, sock = %sock_path, "Rustls unix socket connect timed out, falling back to TCP" ); } } } #[cfg(not(unix))] let _ = unix_sock; let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) .await?; fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await } /// Fetch real TLS metadata with an adaptive multi-profile strategy. pub async fn fetch_real_tls_with_strategy( host: &str, port: u16, sni: &str, strategy: &TlsFetchStrategy, upstream: Option>, scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, ) -> Result { let attempt_timeout = strategy.attempt_timeout.max(Duration::from_millis(1)); let total_budget = strategy.total_budget.max(Duration::from_millis(1)); let started_at = Instant::now(); let cache_key = profile_cache_key( host, port, sni, upstream.as_ref(), scope, proxy_protocol, unix_sock, ); let profiles = order_profiles(strategy, Some(&cache_key), started_at); let mut raw_result = None; let mut raw_last_error: Option = None; let mut raw_last_error_kind = FetchErrorKind::Other; let mut selected_profile = None; for profile in profiles { let elapsed = started_at.elapsed(); if elapsed >= total_budget { break; } let timeout_for_attempt = attempt_timeout.min(total_budget - elapsed); match fetch_via_raw_tls( host, port, sni, timeout_for_attempt, upstream.clone(), scope, proxy_protocol, unix_sock, strategy.strict_route, profile, strategy.grease_enabled, strategy.deterministic, ) .await { Ok(res) => { selected_profile = Some(profile); raw_result = Some(res); break; } Err(err) => { let kind = classify_fetch_error(&err); warn!( sni = %sni, profile = profile.as_str(), error_kind = ?kind, error = %err, "Raw TLS fetch attempt failed" ); raw_last_error_kind = kind; raw_last_error = Some(err); if strategy.strict_route && matches!(kind, FetchErrorKind::Route) { break; } } } } if let Some(profile) = selected_profile { remember_profile_success(strategy, Some(cache_key), profile, Instant::now()); } if raw_result.is_none() && strategy.strict_route && matches!(raw_last_error_kind, FetchErrorKind::Route) { if let Some(err) = raw_last_error { return Err(err); } return Err(anyhow!("TLS fetch strict-route failure")); } let elapsed = started_at.elapsed(); if elapsed >= total_budget { return match raw_result { Some(raw) => Ok(raw), None => { Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))) } }; } let rustls_timeout = attempt_timeout.min(total_budget - elapsed); let rustls_result = fetch_via_rustls( host, port, sni, rustls_timeout, upstream, scope, proxy_protocol, unix_sock, strategy.strict_route, ) .await; match rustls_result { Ok(rustls) => { if let Some(mut raw) = raw_result { raw.cert_info = rustls.cert_info; raw.cert_payload = rustls.cert_payload; raw.behavior_profile.source = TlsProfileSource::Merged; debug!(sni = %sni, "Fetched TLS metadata via adaptive raw probe + rustls cert chain"); Ok(raw) } else { Ok(rustls) } } Err(err) => { if let Some(raw) = raw_result { warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); Ok(raw) } else if let Some(raw_err) = raw_last_error { Err(anyhow!("TLS fetch failed (raw: {raw_err}; rustls: {err})")) } else { Err(err) } } } } /// Fetch real TLS metadata for the given SNI using a single-attempt compatibility strategy. #[allow(dead_code)] pub async fn fetch_real_tls( host: &str, port: u16, sni: &str, connect_timeout: Duration, upstream: Option>, scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, ) -> Result { let strategy = TlsFetchStrategy::single_attempt(connect_timeout); fetch_real_tls_with_strategy( host, port, sni, &strategy, upstream, scope, proxy_protocol, unix_sock, ) .await } #[cfg(test)] mod tests { use std::time::{Duration, Instant}; use super::{ ProfileCacheValue, TlsFetchStrategy, build_client_hello, derive_behavior_profile, encode_tls13_certificate_message, order_profiles, profile_cache, profile_cache_key, }; use crate::config::TlsFetchProfile; use crate::crypto::SecureRandom; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, }; use crate::tls_front::types::TlsProfileSource; fn read_u24(bytes: &[u8]) -> usize { ((bytes[0] as usize) << 16) | ((bytes[1] as usize) << 8) | (bytes[2] as usize) } #[test] fn test_encode_tls13_certificate_message_single_cert() { let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01]; let message = encode_tls13_certificate_message(std::slice::from_ref(&cert)).expect("message"); assert_eq!(message[0], 0x0b); assert_eq!(read_u24(&message[1..4]), message.len() - 4); assert_eq!(message[4], 0x00); let cert_list_len = read_u24(&message[5..8]); assert_eq!(cert_list_len, cert.len() + 5); let cert_len = read_u24(&message[8..11]); assert_eq!(cert_len, cert.len()); assert_eq!(&message[11..11 + cert.len()], cert.as_slice()); assert_eq!(&message[11 + cert.len()..13 + cert.len()], &[0x00, 0x00]); } #[test] fn test_encode_tls13_certificate_message_empty_chain() { assert!(encode_tls13_certificate_message(&[]).is_none()); } #[test] fn test_derive_behavior_profile_splits_ticket_like_tail_records() { let profile = derive_behavior_profile(&[ (TLS_RECORD_HANDSHAKE, vec![0u8; 90]), (TLS_RECORD_CHANGE_CIPHER, vec![0x01]), (TLS_RECORD_APPLICATION, vec![0u8; 1400]), (TLS_RECORD_APPLICATION, vec![0u8; 220]), (TLS_RECORD_APPLICATION, vec![0u8; 180]), ]); assert_eq!(profile.change_cipher_spec_count, 1); assert_eq!(profile.app_data_record_sizes, vec![1400]); assert_eq!(profile.ticket_record_sizes, vec![220, 180]); assert_eq!(profile.source, TlsProfileSource::Raw); } #[test] fn test_order_profiles_prioritizes_fresh_cached_winner() { let strategy = TlsFetchStrategy { profiles: vec![ TlsFetchProfile::ModernChromeLike, TlsFetchProfile::CompatTls12, TlsFetchProfile::LegacyMinimal, ], strict_route: true, attempt_timeout: Duration::from_secs(1), total_budget: Duration::from_secs(2), grease_enabled: false, deterministic: false, profile_cache_ttl: Duration::from_secs(60), }; let cache_key = profile_cache_key( "mask.example", 443, "tls.example", None, Some("tls"), 0, None, ); profile_cache().remove(&cache_key); profile_cache().insert( cache_key.clone(), ProfileCacheValue { profile: TlsFetchProfile::CompatTls12, updated_at: Instant::now(), }, ); let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); assert_eq!(ordered[0], TlsFetchProfile::CompatTls12); profile_cache().remove(&cache_key); } #[test] fn test_order_profiles_drops_expired_cached_winner() { let strategy = TlsFetchStrategy { profiles: vec![ TlsFetchProfile::ModernFirefoxLike, TlsFetchProfile::CompatTls12, ], strict_route: true, attempt_timeout: Duration::from_secs(1), total_budget: Duration::from_secs(2), grease_enabled: false, deterministic: false, profile_cache_ttl: Duration::from_secs(5), }; let cache_key = profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); profile_cache().remove(&cache_key); profile_cache().insert( cache_key.clone(), ProfileCacheValue { profile: TlsFetchProfile::CompatTls12, updated_at: Instant::now() - Duration::from_secs(6), }, ); let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); assert_eq!(ordered[0], TlsFetchProfile::ModernFirefoxLike); assert!(profile_cache().get(&cache_key).is_none()); } #[test] fn test_deterministic_client_hello_is_stable() { let rng = SecureRandom::new(); let first = build_client_hello( "stable.example", &rng, TlsFetchProfile::ModernChromeLike, true, true, ); let second = build_client_hello( "stable.example", &rng, TlsFetchProfile::ModernChromeLike, true, true, ); assert_eq!(first, second); } }