diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index c48ec9c..b2a8e84 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -2,6 +2,7 @@ use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; +use crate::protocol::tls; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; #[cfg(unix)] @@ -328,6 +329,89 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) { } } +#[cfg(test)] +mod tls_domain_mask_host_tests { + use super::{mask_host_for_initial_data, matching_tls_domain_for_sni}; + use crate::config::ProxyConfig; + + fn client_hello_with_sni(sni_host: &str) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&[0x03, 0x03]); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + + let mut extensions = Vec::new(); + extensions.extend_from_slice(&0x0000u16.to_be_bytes()); + extensions.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + extensions.extend_from_slice(&sni_payload); + body.extend_from_slice(&(extensions.len() as u16).to_be_bytes()); + body.extend_from_slice(&extensions); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(0x16); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record + } + + fn config_with_tls_domains() -> ProxyConfig { + let mut config = ProxyConfig::default(); + config.censorship.tls_domain = "a.com".to_string(); + config.censorship.tls_domains = vec!["b.com".to_string(), "c.com".to_string()]; + config.censorship.mask_host = Some("a.com".to_string()); + config + } + + #[test] + fn matching_tls_domain_accepts_primary_and_extra_domains_case_insensitively() { + let config = config_with_tls_domains(); + + assert_eq!(matching_tls_domain_for_sni(&config, "A.COM"), Some("a.com")); + assert_eq!(matching_tls_domain_for_sni(&config, "B.COM"), Some("b.com")); + assert_eq!(matching_tls_domain_for_sni(&config, "unknown.com"), None); + } + + #[test] + fn mask_host_preserves_explicit_non_primary_origin() { + let mut config = config_with_tls_domains(); + config.censorship.mask_host = Some("origin.example".to_string()); + + let initial_data = client_hello_with_sni("b.com"); + + assert_eq!( + mask_host_for_initial_data(&config, &initial_data), + "origin.example" + ); + } + + #[test] + fn mask_host_uses_matching_tls_domain_when_mask_host_is_primary_default() { + let config = config_with_tls_domains(); + let initial_data = client_hello_with_sni("b.com"); + + assert_eq!(mask_host_for_initial_data(&config, &initial_data), "b.com"); + } +} + /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request @@ -360,6 +444,37 @@ fn parse_mask_host_ip_literal(host: &str) -> Option { host.parse::().ok() } +fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> { + if config.censorship.tls_domain.eq_ignore_ascii_case(sni) { + return Some(config.censorship.tls_domain.as_str()); + } + + for domain in &config.censorship.tls_domains { + if domain.eq_ignore_ascii_case(sni) { + return Some(domain.as_str()); + } + } + + None +} + +fn mask_host_for_initial_data<'a>(config: &'a ProxyConfig, initial_data: &[u8]) -> &'a str { + let configured_mask_host = config + .censorship + .mask_host + .as_deref() + .unwrap_or(&config.censorship.tls_domain); + + if !configured_mask_host.eq_ignore_ascii_case(&config.censorship.tls_domain) { + return configured_mask_host; + } + + tls::extract_sni_from_client_hello(initial_data) + .as_deref() + .and_then(|sni| matching_tls_domain_for_sni(config, sni)) + .unwrap_or(configured_mask_host) +} + fn canonical_ip(ip: IpAddr) -> IpAddr { match ip { IpAddr::V6(v6) => v6 @@ -734,11 +849,7 @@ pub async fn handle_bad_client( return; } - let mask_host = config - .censorship - .mask_host - .as_deref() - .unwrap_or(&config.censorship.tls_domain); + let mask_host = mask_host_for_initial_data(config, initial_data); let mask_port = config.censorship.mask_port; // Fail closed when fallback points at our own listener endpoint.