diff --git a/src/main.rs b/src/main.rs index b065d4e..e759095 100644 --- a/src/main.rs +++ b/src/main.rs @@ -285,17 +285,20 @@ async fn main() -> std::result::Result<(), Box> { .mask_host .clone() .unwrap_or_else(|| config.censorship.tls_domain.clone()); + let mask_unix_sock = config.censorship.mask_unix_sock.clone(); let fetch_timeout = Duration::from_secs(5); let cache_initial = cache.clone(); let domains_initial = tls_domains.clone(); let host_initial = mask_host.clone(); + let unix_sock_initial = mask_unix_sock.clone(); let upstream_initial = upstream_manager.clone(); tokio::spawn(async move { let mut join = tokio::task::JoinSet::new(); for domain in domains_initial { let cache_domain = cache_initial.clone(); let host_domain = host_initial.clone(); + let unix_sock_domain = unix_sock_initial.clone(); let upstream_domain = upstream_initial.clone(); join.spawn(async move { match crate::tls_front::fetcher::fetch_real_tls( @@ -305,6 +308,7 @@ async fn main() -> std::result::Result<(), Box> { fetch_timeout, Some(upstream_domain), proxy_protocol, + unix_sock_domain.as_deref(), ) .await { @@ -344,6 +348,7 @@ async fn main() -> std::result::Result<(), Box> { let cache_refresh = cache.clone(); let domains_refresh = tls_domains.clone(); let host_refresh = mask_host.clone(); + let unix_sock_refresh = mask_unix_sock.clone(); let upstream_refresh = upstream_manager.clone(); tokio::spawn(async move { loop { @@ -355,6 +360,7 @@ async fn main() -> std::result::Result<(), Box> { for domain in domains_refresh.clone() { let cache_domain = cache_refresh.clone(); let host_domain = host_refresh.clone(); + let unix_sock_domain = unix_sock_refresh.clone(); let upstream_domain = upstream_refresh.clone(); join.spawn(async move { match crate::tls_front::fetcher::fetch_real_tls( @@ -364,6 +370,7 @@ async fn main() -> std::result::Result<(), Box> { fetch_timeout, Some(upstream_domain), proxy_protocol, + unix_sock_domain.as_deref(), ) .await { diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index ba80332..1731cdc 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -2,8 +2,10 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{Result, anyhow}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; use tokio::time::timeout; use tokio_rustls::client::TlsStream; use tokio_rustls::TlsConnector; @@ -212,7 +214,10 @@ fn gen_key_share(rng: &SecureRandom) -> [u8; 32] { key } -async fn read_tls_record(stream: &mut TcpStream) -> Result<(u8, Vec)> { +async fn read_tls_record(stream: &mut S) -> Result<(u8, Vec)> +where + S: AsyncRead + Unpin, +{ let mut header = [0u8; 5]; stream.read_exact(&mut header).await?; let len = u16::from_be_bytes([header[3], header[4]]) as usize; @@ -345,6 +350,44 @@ async fn connect_with_dns_override( Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??) } +async fn connect_tcp_with_upstream( + host: &str, + port: u16, + connect_timeout: Duration, + upstream: Option>, +) -> Result { + if let Some(manager) = upstream { + if let Some(addr) = resolve_socket_addr(host, port) { + match manager.connect(addr, None, None).await { + Ok(stream) => return Ok(stream), + Err(e) => { + warn!( + host = %host, + port = port, + error = %e, + "Upstream connect failed, using direct connect" + ); + } + } + } else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await { + if let Some(addr) = addrs.find(|a| a.is_ipv4()) { + match manager.connect(addr, None, None).await { + Ok(stream) => return Ok(stream), + Err(e) => { + warn!( + host = %host, + port = port, + error = %e, + "Upstream connect failed, using direct connect" + ); + } + } + } + } + } + connect_with_dns_override(host, port, connect_timeout).await +} + fn encode_tls13_certificate_message(cert_chain_der: &[Vec]) -> Option> { if cert_chain_der.is_empty() { return None; @@ -374,15 +417,15 @@ fn encode_tls13_certificate_message(cert_chain_der: &[Vec]) -> Option( + mut stream: S, sni: &str, connect_timeout: Duration, proxy_protocol: u8, -) -> Result { - let mut stream = connect_with_dns_override(host, port, connect_timeout).await?; - +) -> Result +where + S: AsyncRead + AsyncWrite + Unpin, +{ let rng = SecureRandom::new(); let client_hello = build_client_hello(sni, &rng); timeout(connect_timeout, async { @@ -438,43 +481,61 @@ async fn fetch_via_raw_tls( }) } -async fn fetch_via_rustls( +async fn fetch_via_raw_tls( host: &str, port: u16, sni: &str, connect_timeout: Duration, upstream: Option>, proxy_protocol: u8, + unix_sock: Option<&str>, ) -> Result { - // rustls handshake path for certificate and basic negotiated metadata. - let mut stream = if let Some(manager) = upstream { - if let Some(addr) = resolve_socket_addr(host, port) { - match manager.connect(addr, None, None).await { - Ok(s) => s, - Err(e) => { - warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect"); - connect_with_dns_override(host, port, connect_timeout).await? - } + #[cfg(unix)] + if let Some(sock_path) = unix_sock { + match timeout(connect_timeout, UnixStream::connect(sock_path)).await { + Ok(Ok(stream)) => { + debug!( + sni = %sni, + sock = %sock_path, + "Raw TLS fetch using mask unix socket" + ); + return fetch_via_raw_tls_stream(stream, sni, connect_timeout, 0).await; } - } else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await { - if let Some(addr) = addrs.find(|a| a.is_ipv4()) { - match manager.connect(addr, None, None).await { - Ok(s) => s, - Err(e) => { - warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect"); - connect_with_dns_override(host, port, connect_timeout).await? - } - } - } else { - connect_with_dns_override(host, port, connect_timeout).await? + Ok(Err(e)) => { + warn!( + sni = %sni, + sock = %sock_path, + error = %e, + "Raw TLS unix socket connect failed, falling back to TCP" + ); + } + Err(_) => { + warn!( + sni = %sni, + sock = %sock_path, + "Raw TLS unix socket connect timed out, falling back to TCP" + ); } - } else { - connect_with_dns_override(host, port, connect_timeout).await? } - } else { - connect_with_dns_override(host, port, connect_timeout).await? - }; + } + #[cfg(not(unix))] + let _ = unix_sock; + + let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream).await?; + fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await +} + +async fn fetch_via_rustls_stream( + mut stream: S, + host: &str, + sni: &str, + proxy_protocol: u8, +) -> Result +where + S: AsyncRead + AsyncWrite + Unpin, +{ + // rustls handshake path for certificate and basic negotiated metadata. if proxy_protocol > 0 { let header = match proxy_protocol { 2 => ProxyProtocolV2Builder::new().build(), @@ -491,7 +552,7 @@ async fn fetch_via_rustls( .or_else(|_| ServerName::try_from(host.to_owned())) .map_err(|_| RustlsError::General("invalid SNI".into()))?; - let tls_stream: TlsStream = connector.connect(server_name, stream).await?; + let tls_stream: TlsStream = connector.connect(server_name, stream).await?; // Extract negotiated parameters and certificates let (_io, session) = tls_stream.get_ref(); @@ -552,6 +613,51 @@ async fn fetch_via_rustls( }) } +async fn fetch_via_rustls( + host: &str, + port: u16, + sni: &str, + connect_timeout: Duration, + upstream: Option>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> Result { + #[cfg(unix)] + if let Some(sock_path) = unix_sock { + match timeout(connect_timeout, UnixStream::connect(sock_path)).await { + Ok(Ok(stream)) => { + debug!( + sni = %sni, + sock = %sock_path, + "Rustls fetch using mask unix socket" + ); + return fetch_via_rustls_stream(stream, host, sni, 0).await; + } + Ok(Err(e)) => { + warn!( + sni = %sni, + sock = %sock_path, + error = %e, + "Rustls unix socket connect failed, falling back to TCP" + ); + } + Err(_) => { + warn!( + sni = %sni, + sock = %sock_path, + "Rustls unix socket connect timed out, falling back to TCP" + ); + } + } + } + + #[cfg(not(unix))] + let _ = unix_sock; + + let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream).await?; + fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await +} + /// Fetch real TLS metadata for the given SNI. /// /// Strategy: @@ -565,8 +671,19 @@ pub async fn fetch_real_tls( connect_timeout: Duration, upstream: Option>, proxy_protocol: u8, + unix_sock: Option<&str>, ) -> Result { - let raw_result = match fetch_via_raw_tls(host, port, sni, connect_timeout, proxy_protocol).await { + let raw_result = match fetch_via_raw_tls( + host, + port, + sni, + connect_timeout, + upstream.clone(), + proxy_protocol, + unix_sock, + ) + .await + { Ok(res) => Some(res), Err(e) => { warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); @@ -574,7 +691,17 @@ pub async fn fetch_real_tls( } }; - match fetch_via_rustls(host, port, sni, connect_timeout, upstream, proxy_protocol).await { + match fetch_via_rustls( + host, + port, + sni, + connect_timeout, + upstream, + proxy_protocol, + unix_sock, + ) + .await + { Ok(rustls_result) => { if let Some(mut raw) = raw_result { raw.cert_info = rustls_result.cert_info;