diff --git a/src/tls_front/cache.rs b/src/tls_front/cache.rs new file mode 100644 index 0000000..3fddd07 --- /dev/null +++ b/src/tls_front/cache.rs @@ -0,0 +1,103 @@ +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::{SystemTime, Duration}; + +use tokio::sync::RwLock; +use tokio::time::sleep; +use tracing::{debug, warn, info}; + +use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsFetchResult}; + +/// Lightweight in-memory + optional on-disk cache for TLS fronting data. +#[derive(Debug)] +pub struct TlsFrontCache { + memory: RwLock>>, + default: Arc, + disk_path: PathBuf, +} + +impl TlsFrontCache { + pub fn new(domains: &[String], default_len: usize, disk_path: impl AsRef) -> Self { + let default_template = ParsedServerHello { + version: [0x03, 0x03], + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }; + + let default = Arc::new(CachedTlsData { + server_hello_template: default_template, + cert_info: None, + app_data_records_sizes: vec![default_len], + total_app_data_len: default_len, + fetched_at: SystemTime::now(), + domain: "default".to_string(), + }); + + let mut map = HashMap::new(); + for d in domains { + map.insert(d.clone(), default.clone()); + } + + Self { + memory: RwLock::new(map), + default, + disk_path: disk_path.as_ref().to_path_buf(), + } + } + + pub async fn get(&self, sni: &str) -> Arc { + let guard = self.memory.read().await; + guard.get(sni).cloned().unwrap_or_else(|| self.default.clone()) + } + + pub async fn set(&self, domain: &str, data: CachedTlsData) { + let mut guard = self.memory.write().await; + guard.insert(domain.to_string(), Arc::new(data)); + } + + /// Spawn background updater that periodically refreshes cached domains using provided fetcher. + pub fn spawn_updater( + self: Arc, + domains: Vec, + interval: Duration, + fetcher: F, + ) where + F: Fn(String) -> tokio::task::JoinHandle<()> + Send + Sync + 'static, + { + tokio::spawn(async move { + loop { + for domain in &domains { + fetcher(domain.clone()).await; + } + sleep(interval).await; + } + }); + } + + /// Replace cached entry from a fetch result. + pub async fn update_from_fetch(&self, domain: &str, fetched: TlsFetchResult) { + let data = CachedTlsData { + server_hello_template: fetched.server_hello_parsed, + cert_info: None, + app_data_records_sizes: fetched.app_data_records_sizes.clone(), + total_app_data_len: fetched.total_app_data_len, + fetched_at: SystemTime::now(), + domain: domain.to_string(), + }; + + self.set(domain, data).await; + debug!(domain = %domain, len = fetched.total_app_data_len, "TLS cache updated"); + } + + pub fn default_entry(&self) -> Arc { + self.default.clone() + } + + pub fn disk_path(&self) -> &Path { + &self.disk_path + } +} diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs new file mode 100644 index 0000000..39a4609 --- /dev/null +++ b/src/tls_front/fetcher.rs @@ -0,0 +1,128 @@ +use std::sync::Arc; +use std::time::Duration; + +use tokio::net::TcpStream; +use tokio::time::timeout; +use tokio_rustls::client::TlsStream; +use tokio_rustls::TlsConnector; +use tracing::{debug, warn}; + +use rustls::client::{ClientConfig, ServerCertVerifier, ServerName}; +use rustls::{DigitallySignedStruct, Error as RustlsError}; +use rustls::pki_types::{ServerName as PkiServerName, UnixTime, CertificateDer}; + +use crate::tls_front::types::{ParsedServerHello, TlsFetchResult}; + +/// No-op verifier: accept any certificate (we only need lengths and metadata). +struct NoVerify; + +impl ServerCertVerifier for NoVerify { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &PkiServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result { + Ok(rustls::client::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::HandshakeSignatureValid::assertion()) + } +} + +fn build_client_config() -> Arc { + let mut root = rustls::RootCertStore::empty(); + // Optionally load system roots; failure is non-fatal. + let _ = root.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + Arc::new( + ClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .unwrap() + .with_custom_certificate_verifier(Arc::new(NoVerify)) + .with_root_certificates(root) + .with_no_client_auth(), + ) +} + +/// Try to fetch real TLS metadata for the given SNI. +pub async fn fetch_real_tls( + host: &str, + port: u16, + sni: &str, + connect_timeout: Duration, +) -> anyhow::Result { + let addr = format!("{}:{}", host, port); + let stream = timeout(connect_timeout, TcpStream::connect(addr)).await??; + + let config = build_client_config(); + let connector = TlsConnector::from(config); + + let server_name = ServerName::try_from(sni) + .or_else(|_| ServerName::try_from(host)) + .map_err(|_| RustlsError::General("invalid SNI".into()))?; + + let mut tls_stream: TlsStream = connector.connect(server_name, stream).await?; + + // Extract negotiated parameters and certificates + let (session, _io) = tls_stream.get_ref(); + let cipher_suite = session + .negotiated_cipher_suite() + .map(|s| s.suite().get_u16().to_be_bytes()) + .unwrap_or([0x13, 0x01]); + + let certs: Vec> = session + .peer_certificates() + .map(|slice| slice.iter().cloned().collect()) + .unwrap_or_default(); + + let total_cert_len: usize = certs.iter().map(|c| c.len()).sum::().max(1024); + + // Heuristic: split across two records if large to mimic real servers a bit. + let app_data_records_sizes = if total_cert_len > 3000 { + vec![total_cert_len / 2, total_cert_len - total_cert_len / 2] + } else { + vec![total_cert_len] + }; + + let parsed = ParsedServerHello { + version: [0x03, 0x03], + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite, + compression: 0, + extensions: Vec::new(), + }; + + debug!( + sni = %sni, + len = total_cert_len, + cipher = format!("0x{:04x}", u16::from_be_bytes(cipher_suite)), + "Fetched TLS metadata" + ); + + Ok(TlsFetchResult { + server_hello_parsed: parsed, + app_data_records_sizes: app_data_records_sizes.clone(), + total_app_data_len: app_data_records_sizes.iter().sum(), + }) +} diff --git a/src/tls_front/types.rs b/src/tls_front/types.rs new file mode 100644 index 0000000..7f346db --- /dev/null +++ b/src/tls_front/types.rs @@ -0,0 +1,48 @@ +use std::time::SystemTime; + +/// Parsed representation of an unencrypted TLS ServerHello. +#[derive(Debug, Clone)] +pub struct ParsedServerHello { + pub version: [u8; 2], + pub random: [u8; 32], + pub session_id: Vec, + pub cipher_suite: [u8; 2], + pub compression: u8, + pub extensions: Vec, +} + +/// Generic TLS extension container. +#[derive(Debug, Clone)] +pub struct TlsExtension { + pub ext_type: u16, + pub data: Vec, +} + +/// Basic certificate metadata (optional, informative). +#[derive(Debug, Clone)] +pub struct ParsedCertificateInfo { + pub not_after_unix: Option, + pub not_before_unix: Option, + pub issuer_cn: Option, + pub subject_cn: Option, + pub san_names: Vec, +} + +/// Cached data per SNI used by the emulator. +#[derive(Debug, Clone)] +pub struct CachedTlsData { + pub server_hello_template: ParsedServerHello, + pub cert_info: Option, + pub app_data_records_sizes: Vec, + pub total_app_data_len: usize, + pub fetched_at: SystemTime, + pub domain: String, +} + +/// Result of attempting to fetch real TLS artifacts. +#[derive(Debug, Clone)] +pub struct TlsFetchResult { + pub server_hello_parsed: ParsedServerHello, + pub app_data_records_sizes: Vec, + pub total_app_data_len: usize, +}