diff --git a/src/maestro/tls_bootstrap.rs b/src/maestro/tls_bootstrap.rs index 7cf3039..4412723 100644 --- a/src/maestro/tls_bootstrap.rs +++ b/src/maestro/tls_bootstrap.rs @@ -10,6 +10,14 @@ use crate::tls_front::TlsFrontCache; use crate::tls_front::fetcher::TlsFetchStrategy; use crate::transport::UpstreamManager; +fn tls_fetch_host_for_domain(mask_host: &str, primary_tls_domain: &str, domain: &str) -> String { + if mask_host.eq_ignore_ascii_case(primary_tls_domain) { + domain.to_string() + } else { + mask_host.to_string() + } +} + pub(crate) async fn bootstrap_tls_front( config: &ProxyConfig, tls_domains: &[String], @@ -56,6 +64,7 @@ pub(crate) async fn bootstrap_tls_front( let cache_initial = cache.clone(); let domains_initial = tls_domains.to_vec(); let host_initial = mask_host.clone(); + let primary_initial = config.censorship.tls_domain.clone(); let unix_sock_initial = mask_unix_sock.clone(); let scope_initial = tls_fetch_scope.clone(); let upstream_initial = upstream_manager.clone(); @@ -64,7 +73,8 @@ pub(crate) async fn bootstrap_tls_front( let mut join = tokio::task::JoinSet::new(); for domain in domains_initial { let cache_domain = cache_initial.clone(); - let host_domain = host_initial.clone(); + let host_domain = + tls_fetch_host_for_domain(&host_initial, &primary_initial, &domain); let unix_sock_domain = unix_sock_initial.clone(); let scope_domain = scope_initial.clone(); let upstream_domain = upstream_initial.clone(); @@ -117,6 +127,7 @@ pub(crate) async fn bootstrap_tls_front( let cache_refresh = cache.clone(); let domains_refresh = tls_domains.to_vec(); let host_refresh = mask_host.clone(); + let primary_refresh = config.censorship.tls_domain.clone(); let unix_sock_refresh = mask_unix_sock.clone(); let scope_refresh = tls_fetch_scope.clone(); let upstream_refresh = upstream_manager.clone(); @@ -130,7 +141,8 @@ pub(crate) async fn bootstrap_tls_front( let mut join = tokio::task::JoinSet::new(); for domain in domains_refresh.clone() { let cache_domain = cache_refresh.clone(); - let host_domain = host_refresh.clone(); + let host_domain = + tls_fetch_host_for_domain(&host_refresh, &primary_refresh, &domain); let unix_sock_domain = unix_sock_refresh.clone(); let scope_domain = scope_refresh.clone(); let upstream_domain = upstream_refresh.clone(); @@ -186,3 +198,24 @@ pub(crate) async fn bootstrap_tls_front( tls_cache } + +#[cfg(test)] +mod tests { + use super::tls_fetch_host_for_domain; + + #[test] + fn tls_fetch_host_uses_each_domain_when_mask_host_is_primary_default() { + assert_eq!( + tls_fetch_host_for_domain("a.com", "a.com", "b.com"), + "b.com" + ); + } + + #[test] + fn tls_fetch_host_preserves_explicit_non_primary_mask_host() { + assert_eq!( + tls_fetch_host_for_domain("origin.example", "a.com", "b.com"), + "origin.example" + ); + } +} diff --git a/src/tls_front/cache.rs b/src/tls_front/cache.rs index af8addf..4f71f5a 100644 --- a/src/tls_front/cache.rs +++ b/src/tls_front/cache.rs @@ -130,6 +130,14 @@ impl TlsFrontCache { warn!(file = %name, "Skipping TLS cache entry with invalid domain"); continue; } + if !cert_info_matches_domain(&cached) { + warn!( + file = %name, + domain = %cached.domain, + "Skipping TLS cache entry with mismatched certificate metadata" + ); + continue; + } // fetched_at is skipped during deserialization; approximate with file mtime if available. if let Ok(meta) = entry.metadata().await && let Ok(modified) = meta.modified() @@ -209,10 +217,100 @@ impl TlsFrontCache { } } +fn cert_info_matches_domain(cached: &CachedTlsData) -> bool { + let Some(cert_info) = cached.cert_info.as_ref() else { + return true; + }; + if !cert_info.san_names.is_empty() { + return cert_info + .san_names + .iter() + .any(|name| dns_name_matches_domain(name, &cached.domain)); + } + cert_info + .subject_cn + .as_deref() + .map_or(true, |name| dns_name_matches_domain(name, &cached.domain)) +} + +fn dns_name_matches_domain(pattern: &str, domain: &str) -> bool { + let pattern = normalize_dns_name(pattern); + let domain = normalize_dns_name(domain); + if pattern == domain { + return true; + } + + let Some(suffix) = pattern.strip_prefix("*.") else { + return false; + }; + let Some(prefix) = domain.strip_suffix(suffix) else { + return false; + }; + prefix.ends_with('.') && !prefix[..prefix.len() - 1].contains('.') +} + +fn normalize_dns_name(value: &str) -> String { + value.trim().trim_end_matches('.').to_ascii_lowercase() +} + #[cfg(test)] mod tests { use super::*; + fn cached_with_cert_info( + domain: &str, + subject_cn: Option<&str>, + san_names: Vec<&str>, + ) -> CachedTlsData { + CachedTlsData { + server_hello_template: ParsedServerHello { + version: [0x03, 0x03], + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }, + cert_info: Some(crate::tls_front::types::ParsedCertificateInfo { + not_after_unix: None, + not_before_unix: None, + issuer_cn: None, + subject_cn: subject_cn.map(str::to_string), + san_names: san_names.into_iter().map(str::to_string).collect(), + }), + cert_payload: None, + app_data_records_sizes: vec![1024], + total_app_data_len: 1024, + behavior_profile: TlsBehaviorProfile::default(), + fetched_at: SystemTime::now(), + domain: domain.to_string(), + } + } + + #[test] + fn cert_info_domain_match_accepts_exact_san() { + let cached = cached_with_cert_info("b.com", Some("a.com"), vec!["b.com"]); + assert!(cert_info_matches_domain(&cached)); + } + + #[test] + fn cert_info_domain_match_rejects_wrong_san() { + let cached = cached_with_cert_info("b.com", Some("b.com"), vec!["a.com"]); + assert!(!cert_info_matches_domain(&cached)); + } + + #[test] + fn cert_info_domain_match_accepts_single_label_wildcard_san() { + let cached = cached_with_cert_info("api.b.com", None, vec!["*.b.com"]); + assert!(cert_info_matches_domain(&cached)); + } + + #[test] + fn cert_info_domain_match_rejects_multi_label_wildcard_san() { + let cached = cached_with_cert_info("deep.api.b.com", None, vec!["*.b.com"]); + assert!(!cert_info_matches_domain(&cached)); + } + #[tokio::test] async fn test_take_full_cert_budget_for_ip_uses_ttl() { let cache = TlsFrontCache::new(&["example.com".to_string()], 1024, "tlsfront-test-cache");