TLS Fetch on unix-socket

This commit is contained in:
Alexey 2026-02-28 02:55:21 +03:00
parent e0d5561095
commit a61882af6e
No known key found for this signature in database
2 changed files with 171 additions and 37 deletions

View File

@ -285,17 +285,20 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
.mask_host .mask_host
.clone() .clone()
.unwrap_or_else(|| config.censorship.tls_domain.clone()); .unwrap_or_else(|| config.censorship.tls_domain.clone());
let mask_unix_sock = config.censorship.mask_unix_sock.clone();
let fetch_timeout = Duration::from_secs(5); let fetch_timeout = Duration::from_secs(5);
let cache_initial = cache.clone(); let cache_initial = cache.clone();
let domains_initial = tls_domains.clone(); let domains_initial = tls_domains.clone();
let host_initial = mask_host.clone(); let host_initial = mask_host.clone();
let unix_sock_initial = mask_unix_sock.clone();
let upstream_initial = upstream_manager.clone(); let upstream_initial = upstream_manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut join = tokio::task::JoinSet::new(); let mut join = tokio::task::JoinSet::new();
for domain in domains_initial { for domain in domains_initial {
let cache_domain = cache_initial.clone(); let cache_domain = cache_initial.clone();
let host_domain = host_initial.clone(); let host_domain = host_initial.clone();
let unix_sock_domain = unix_sock_initial.clone();
let upstream_domain = upstream_initial.clone(); let upstream_domain = upstream_initial.clone();
join.spawn(async move { join.spawn(async move {
match crate::tls_front::fetcher::fetch_real_tls( match crate::tls_front::fetcher::fetch_real_tls(
@ -305,6 +308,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
fetch_timeout, fetch_timeout,
Some(upstream_domain), Some(upstream_domain),
proxy_protocol, proxy_protocol,
unix_sock_domain.as_deref(),
) )
.await .await
{ {
@ -344,6 +348,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let cache_refresh = cache.clone(); let cache_refresh = cache.clone();
let domains_refresh = tls_domains.clone(); let domains_refresh = tls_domains.clone();
let host_refresh = mask_host.clone(); let host_refresh = mask_host.clone();
let unix_sock_refresh = mask_unix_sock.clone();
let upstream_refresh = upstream_manager.clone(); let upstream_refresh = upstream_manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
@ -355,6 +360,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
for domain in domains_refresh.clone() { for domain in domains_refresh.clone() {
let cache_domain = cache_refresh.clone(); let cache_domain = cache_refresh.clone();
let host_domain = host_refresh.clone(); let host_domain = host_refresh.clone();
let unix_sock_domain = unix_sock_refresh.clone();
let upstream_domain = upstream_refresh.clone(); let upstream_domain = upstream_refresh.clone();
join.spawn(async move { join.spawn(async move {
match crate::tls_front::fetcher::fetch_real_tls( match crate::tls_front::fetcher::fetch_real_tls(
@ -364,6 +370,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
fetch_timeout, fetch_timeout,
Some(upstream_domain), Some(upstream_domain),
proxy_protocol, proxy_protocol,
unix_sock_domain.as_deref(),
) )
.await .await
{ {

View File

@ -2,8 +2,10 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::time::timeout; use tokio::time::timeout;
use tokio_rustls::client::TlsStream; use tokio_rustls::client::TlsStream;
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
@ -212,7 +214,10 @@ fn gen_key_share(rng: &SecureRandom) -> [u8; 32] {
key key
} }
async fn read_tls_record(stream: &mut TcpStream) -> Result<(u8, Vec<u8>)> { async fn read_tls_record<S>(stream: &mut S) -> Result<(u8, Vec<u8>)>
where
S: AsyncRead + Unpin,
{
let mut header = [0u8; 5]; let mut header = [0u8; 5];
stream.read_exact(&mut header).await?; stream.read_exact(&mut header).await?;
let len = u16::from_be_bytes([header[3], header[4]]) as usize; let len = u16::from_be_bytes([header[3], header[4]]) as usize;
@ -345,6 +350,44 @@ async fn connect_with_dns_override(
Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??) Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??)
} }
async fn connect_tcp_with_upstream(
host: &str,
port: u16,
connect_timeout: Duration,
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
) -> Result<TcpStream> {
if let Some(manager) = upstream {
if let Some(addr) = resolve_socket_addr(host, port) {
match manager.connect(addr, None, None).await {
Ok(stream) => return Ok(stream),
Err(e) => {
warn!(
host = %host,
port = port,
error = %e,
"Upstream connect failed, using direct connect"
);
}
}
} else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
if let Some(addr) = addrs.find(|a| a.is_ipv4()) {
match manager.connect(addr, None, None).await {
Ok(stream) => return Ok(stream),
Err(e) => {
warn!(
host = %host,
port = port,
error = %e,
"Upstream connect failed, using direct connect"
);
}
}
}
}
}
connect_with_dns_override(host, port, connect_timeout).await
}
fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8>> { fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8>> {
if cert_chain_der.is_empty() { if cert_chain_der.is_empty() {
return None; return None;
@ -374,15 +417,15 @@ fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8
Some(message) Some(message)
} }
async fn fetch_via_raw_tls( async fn fetch_via_raw_tls_stream<S>(
host: &str, mut stream: S,
port: u16,
sni: &str, sni: &str,
connect_timeout: Duration, connect_timeout: Duration,
proxy_protocol: u8, proxy_protocol: u8,
) -> Result<TlsFetchResult> { ) -> Result<TlsFetchResult>
let mut stream = connect_with_dns_override(host, port, connect_timeout).await?; where
S: AsyncRead + AsyncWrite + Unpin,
{
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let client_hello = build_client_hello(sni, &rng); let client_hello = build_client_hello(sni, &rng);
timeout(connect_timeout, async { timeout(connect_timeout, async {
@ -438,43 +481,61 @@ async fn fetch_via_raw_tls(
}) })
} }
async fn fetch_via_rustls( async fn fetch_via_raw_tls(
host: &str, host: &str,
port: u16, port: u16,
sni: &str, sni: &str,
connect_timeout: Duration, connect_timeout: Duration,
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>, upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
proxy_protocol: u8, proxy_protocol: u8,
unix_sock: Option<&str>,
) -> Result<TlsFetchResult> { ) -> Result<TlsFetchResult> {
// rustls handshake path for certificate and basic negotiated metadata. #[cfg(unix)]
let mut stream = if let Some(manager) = upstream { if let Some(sock_path) = unix_sock {
if let Some(addr) = resolve_socket_addr(host, port) { match timeout(connect_timeout, UnixStream::connect(sock_path)).await {
match manager.connect(addr, None, None).await { Ok(Ok(stream)) => {
Ok(s) => s, debug!(
Err(e) => { sni = %sni,
warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect"); sock = %sock_path,
connect_with_dns_override(host, port, connect_timeout).await? "Raw TLS fetch using mask unix socket"
);
return fetch_via_raw_tls_stream(stream, sni, connect_timeout, 0).await;
}
Ok(Err(e)) => {
warn!(
sni = %sni,
sock = %sock_path,
error = %e,
"Raw TLS unix socket connect failed, falling back to TCP"
);
}
Err(_) => {
warn!(
sni = %sni,
sock = %sock_path,
"Raw TLS unix socket connect timed out, falling back to TCP"
);
} }
} }
} else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
if let Some(addr) = addrs.find(|a| a.is_ipv4()) {
match manager.connect(addr, None, None).await {
Ok(s) => s,
Err(e) => {
warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect");
connect_with_dns_override(host, port, connect_timeout).await?
} }
}
} else {
connect_with_dns_override(host, port, connect_timeout).await?
}
} else {
connect_with_dns_override(host, port, connect_timeout).await?
}
} else {
connect_with_dns_override(host, port, connect_timeout).await?
};
#[cfg(not(unix))]
let _ = unix_sock;
let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream).await?;
fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await
}
async fn fetch_via_rustls_stream<S>(
mut stream: S,
host: &str,
sni: &str,
proxy_protocol: u8,
) -> Result<TlsFetchResult>
where
S: AsyncRead + AsyncWrite + Unpin,
{
// rustls handshake path for certificate and basic negotiated metadata.
if proxy_protocol > 0 { if proxy_protocol > 0 {
let header = match proxy_protocol { let header = match proxy_protocol {
2 => ProxyProtocolV2Builder::new().build(), 2 => ProxyProtocolV2Builder::new().build(),
@ -491,7 +552,7 @@ async fn fetch_via_rustls(
.or_else(|_| ServerName::try_from(host.to_owned())) .or_else(|_| ServerName::try_from(host.to_owned()))
.map_err(|_| RustlsError::General("invalid SNI".into()))?; .map_err(|_| RustlsError::General("invalid SNI".into()))?;
let tls_stream: TlsStream<TcpStream> = connector.connect(server_name, stream).await?; let tls_stream: TlsStream<S> = connector.connect(server_name, stream).await?;
// Extract negotiated parameters and certificates // Extract negotiated parameters and certificates
let (_io, session) = tls_stream.get_ref(); let (_io, session) = tls_stream.get_ref();
@ -552,6 +613,51 @@ async fn fetch_via_rustls(
}) })
} }
async fn fetch_via_rustls(
host: &str,
port: u16,
sni: &str,
connect_timeout: Duration,
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
proxy_protocol: u8,
unix_sock: Option<&str>,
) -> Result<TlsFetchResult> {
#[cfg(unix)]
if let Some(sock_path) = unix_sock {
match timeout(connect_timeout, UnixStream::connect(sock_path)).await {
Ok(Ok(stream)) => {
debug!(
sni = %sni,
sock = %sock_path,
"Rustls fetch using mask unix socket"
);
return fetch_via_rustls_stream(stream, host, sni, 0).await;
}
Ok(Err(e)) => {
warn!(
sni = %sni,
sock = %sock_path,
error = %e,
"Rustls unix socket connect failed, falling back to TCP"
);
}
Err(_) => {
warn!(
sni = %sni,
sock = %sock_path,
"Rustls unix socket connect timed out, falling back to TCP"
);
}
}
}
#[cfg(not(unix))]
let _ = unix_sock;
let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream).await?;
fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await
}
/// Fetch real TLS metadata for the given SNI. /// Fetch real TLS metadata for the given SNI.
/// ///
/// Strategy: /// Strategy:
@ -565,8 +671,19 @@ pub async fn fetch_real_tls(
connect_timeout: Duration, connect_timeout: Duration,
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>, upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
proxy_protocol: u8, proxy_protocol: u8,
unix_sock: Option<&str>,
) -> Result<TlsFetchResult> { ) -> Result<TlsFetchResult> {
let raw_result = match fetch_via_raw_tls(host, port, sni, connect_timeout, proxy_protocol).await { let raw_result = match fetch_via_raw_tls(
host,
port,
sni,
connect_timeout,
upstream.clone(),
proxy_protocol,
unix_sock,
)
.await
{
Ok(res) => Some(res), Ok(res) => Some(res),
Err(e) => { Err(e) => {
warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); warn!(sni = %sni, error = %e, "Raw TLS fetch failed");
@ -574,7 +691,17 @@ pub async fn fetch_real_tls(
} }
}; };
match fetch_via_rustls(host, port, sni, connect_timeout, upstream, proxy_protocol).await { match fetch_via_rustls(
host,
port,
sni,
connect_timeout,
upstream,
proxy_protocol,
unix_sock,
)
.await
{
Ok(rustls_result) => { Ok(rustls_result) => {
if let Some(mut raw) = raw_result { if let Some(mut raw) = raw_result {
raw.cert_info = rustls_result.cert_info; raw.cert_info = rustls_result.cert_info;