diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index d060dc7..ea36aca 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -108,11 +108,18 @@ where let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { - if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { - Some(cache.get(&sni).await) + let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { + if cache.contains_domain(&sni).await { + sni + } else { + config.censorship.tls_domain.clone() + } } else { - Some(cache.get(&config.censorship.tls_domain).await) - } + config.censorship.tls_domain.clone() + }; + let cached_entry = cache.get(&selected_domain).await; + let use_full_cert_payload = cache.take_full_cert_budget(&selected_domain).await; + Some((cached_entry, use_full_cert_payload)) } else { None } @@ -137,12 +144,13 @@ where None }; - let response = if let Some(cached_entry) = cached { + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( secret, &validation.digest, &validation.session_id, &cached_entry, + use_full_cert_payload, rng, selected_alpn.clone(), config.censorship.tls_new_session_tickets, diff --git a/src/tls_front/cache.rs b/src/tls_front/cache.rs index 803cdf9..1d3fd88 100644 --- a/src/tls_front/cache.rs +++ b/src/tls_front/cache.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::{SystemTime, Duration}; @@ -14,6 +14,7 @@ use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsFetchResult}; pub struct TlsFrontCache { memory: RwLock>>, default: Arc, + full_cert_sent: RwLock>, disk_path: PathBuf, } @@ -46,6 +47,7 @@ impl TlsFrontCache { Self { memory: RwLock::new(map), default, + full_cert_sent: RwLock::new(HashSet::new()), disk_path: disk_path.as_ref().to_path_buf(), } } @@ -55,6 +57,15 @@ impl TlsFrontCache { guard.get(sni).cloned().unwrap_or_else(|| self.default.clone()) } + pub async fn contains_domain(&self, domain: &str) -> bool { + self.memory.read().await.contains_key(domain) + } + + /// Returns true only on first request for a domain after process start. + pub async fn take_full_cert_budget(&self, domain: &str) -> bool { + self.full_cert_sent.write().await.insert(domain.to_string()) + } + pub async fn set(&self, domain: &str, data: CachedTlsData) { let mut guard = self.memory.write().await; guard.insert(domain.to_string(), Arc::new(data)); diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index d2a4697..25d2a8c 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -3,7 +3,7 @@ use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION, }; use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key}; -use crate::tls_front::types::CachedTlsData; +use crate::tls_front::types::{CachedTlsData, ParsedCertificateInfo}; const MIN_APP_DATA: usize = 64; const MAX_APP_DATA: usize = 16640; // RFC 8446 ยง5.2 allows up to 2^14 + 256 @@ -27,39 +27,81 @@ fn jitter_and_clamp_sizes(sizes: &[usize], rng: &SecureRandom) -> Vec { .collect() } +fn app_data_body_capacity(sizes: &[usize]) -> usize { + sizes.iter().map(|&size| size.saturating_sub(17)).sum() +} + fn ensure_payload_capacity(mut sizes: Vec, payload_len: usize) -> Vec { if payload_len == 0 { return sizes; } - let mut total = sizes.iter().sum::(); - if total >= payload_len { + let mut body_total = app_data_body_capacity(&sizes); + if body_total >= payload_len { return sizes; } if let Some(last) = sizes.last_mut() { let free = MAX_APP_DATA.saturating_sub(*last); - let grow = free.min(payload_len - total); + let grow = free.min(payload_len - body_total); *last += grow; - total += grow; + body_total += grow; } - while total < payload_len { - let remaining = payload_len - total; - let chunk = remaining.min(MAX_APP_DATA).max(MIN_APP_DATA); + while body_total < payload_len { + let remaining = payload_len - body_total; + let chunk = (remaining + 17).min(MAX_APP_DATA).max(MIN_APP_DATA); sizes.push(chunk); - total += chunk; + body_total += chunk.saturating_sub(17); } sizes } +fn build_compact_cert_info_payload(cert_info: &ParsedCertificateInfo) -> Option> { + let mut fields = Vec::new(); + + if let Some(subject) = cert_info.subject_cn.as_deref() { + fields.push(format!("CN={subject}")); + } + if let Some(issuer) = cert_info.issuer_cn.as_deref() { + fields.push(format!("ISSUER={issuer}")); + } + if let Some(not_before) = cert_info.not_before_unix { + fields.push(format!("NB={not_before}")); + } + if let Some(not_after) = cert_info.not_after_unix { + fields.push(format!("NA={not_after}")); + } + if !cert_info.san_names.is_empty() { + let san = cert_info + .san_names + .iter() + .take(8) + .map(String::as_str) + .collect::>() + .join(","); + fields.push(format!("SAN={san}")); + } + + if fields.is_empty() { + return None; + } + + let mut payload = fields.join(";").into_bytes(); + if payload.len() > 512 { + payload.truncate(512); + } + Some(payload) +} + /// Build a ServerHello + CCS + ApplicationData sequence using cached TLS metadata. pub fn build_emulated_server_hello( secret: &[u8], client_digest: &[u8; TLS_DIGEST_LEN], session_id: &[u8], cached: &CachedTlsData, + use_full_cert_payload: bool, rng: &SecureRandom, alpn: Option>, new_session_tickets: u8, @@ -137,13 +179,22 @@ pub fn build_emulated_server_hello( sizes.push(cached.total_app_data_len.max(1024)); } let mut sizes = jitter_and_clamp_sizes(&sizes, rng); - let cert_payload = cached - .cert_payload + let compact_payload = cached + .cert_info .as_ref() - .map(|payload| payload.certificate_message.as_slice()) - .filter(|payload| !payload.is_empty()); + .and_then(build_compact_cert_info_payload); + let selected_payload: Option<&[u8]> = if use_full_cert_payload { + cached + .cert_payload + .as_ref() + .map(|payload| payload.certificate_message.as_slice()) + .filter(|payload| !payload.is_empty()) + .or_else(|| compact_payload.as_deref()) + } else { + compact_payload.as_deref() + }; - if let Some(payload) = cert_payload { + if let Some(payload) = selected_payload { sizes = ensure_payload_capacity(sizes, payload.len()); } @@ -155,15 +206,22 @@ pub fn build_emulated_server_hello( rec.extend_from_slice(&TLS_VERSION); rec.extend_from_slice(&(size as u16).to_be_bytes()); - if let Some(payload) = cert_payload { - let remaining = payload.len().saturating_sub(payload_offset); - let copy_len = remaining.min(size); - if copy_len > 0 { - rec.extend_from_slice(&payload[payload_offset..payload_offset + copy_len]); - payload_offset += copy_len; - } - if size > copy_len { - rec.extend_from_slice(&rng.bytes(size - copy_len)); + if let Some(payload) = selected_payload { + if size > 17 { + let body_len = size - 17; + let remaining = payload.len().saturating_sub(payload_offset); + let copy_len = remaining.min(body_len); + if copy_len > 0 { + rec.extend_from_slice(&payload[payload_offset..payload_offset + copy_len]); + payload_offset += copy_len; + } + if body_len > copy_len { + rec.extend_from_slice(&rng.bytes(body_len - copy_len)); + } + rec.push(0x16); // inner content type marker (handshake) + rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag + } else { + rec.extend_from_slice(&rng.bytes(size)); } } else { if size > 17 { @@ -262,6 +320,7 @@ mod tests { &[0x11; 32], &[0x22; 16], &cached, + true, &rng, None, 0, @@ -287,6 +346,7 @@ mod tests { &[0x22; 32], &[0x33; 16], &cached, + true, &rng, None, 0, @@ -296,4 +356,35 @@ mod tests { assert!(payload.len() >= 64); assert_eq!(payload[payload.len() - 17], 0x16); } + + #[test] + fn test_build_emulated_server_hello_uses_compact_payload_after_first() { + let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd]; + let mut cached = make_cached(Some(TlsCertPayload { + cert_chain_der: vec![vec![0x30, 0x01, 0x00]], + certificate_message: cert_msg, + })); + cached.cert_info = Some(crate::tls_front::types::ParsedCertificateInfo { + not_after_unix: Some(1_900_000_000), + not_before_unix: Some(1_700_000_000), + issuer_cn: Some("Issuer".to_string()), + subject_cn: Some("example.com".to_string()), + san_names: vec!["example.com".to_string(), "www.example.com".to_string()], + }); + + let rng = SecureRandom::new(); + let response = build_emulated_server_hello( + b"secret", + &[0x44; 32], + &[0x55; 16], + &cached, + false, + &rng, + None, + 0, + ); + + let payload = first_app_data_payload(&response); + assert!(payload.starts_with(b"CN=example.com")); + } }