mirror of https://github.com/telemt/telemt.git
TLS Fetch on unix-socket
This commit is contained in:
parent
e0d5561095
commit
a61882af6e
|
|
@ -285,17 +285,20 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||
.mask_host
|
||||
.clone()
|
||||
.unwrap_or_else(|| config.censorship.tls_domain.clone());
|
||||
let mask_unix_sock = config.censorship.mask_unix_sock.clone();
|
||||
let fetch_timeout = Duration::from_secs(5);
|
||||
|
||||
let cache_initial = cache.clone();
|
||||
let domains_initial = tls_domains.clone();
|
||||
let host_initial = mask_host.clone();
|
||||
let unix_sock_initial = mask_unix_sock.clone();
|
||||
let upstream_initial = upstream_manager.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut join = tokio::task::JoinSet::new();
|
||||
for domain in domains_initial {
|
||||
let cache_domain = cache_initial.clone();
|
||||
let host_domain = host_initial.clone();
|
||||
let unix_sock_domain = unix_sock_initial.clone();
|
||||
let upstream_domain = upstream_initial.clone();
|
||||
join.spawn(async move {
|
||||
match crate::tls_front::fetcher::fetch_real_tls(
|
||||
|
|
@ -305,6 +308,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||
fetch_timeout,
|
||||
Some(upstream_domain),
|
||||
proxy_protocol,
|
||||
unix_sock_domain.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
@ -344,6 +348,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||
let cache_refresh = cache.clone();
|
||||
let domains_refresh = tls_domains.clone();
|
||||
let host_refresh = mask_host.clone();
|
||||
let unix_sock_refresh = mask_unix_sock.clone();
|
||||
let upstream_refresh = upstream_manager.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
|
|
@ -355,6 +360,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||
for domain in domains_refresh.clone() {
|
||||
let cache_domain = cache_refresh.clone();
|
||||
let host_domain = host_refresh.clone();
|
||||
let unix_sock_domain = unix_sock_refresh.clone();
|
||||
let upstream_domain = upstream_refresh.clone();
|
||||
join.spawn(async move {
|
||||
match crate::tls_front::fetcher::fetch_real_tls(
|
||||
|
|
@ -364,6 +370,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||
fetch_timeout,
|
||||
Some(upstream_domain),
|
||||
proxy_protocol,
|
||||
unix_sock_domain.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ use std::sync::Arc;
|
|||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
#[cfg(unix)]
|
||||
use tokio::net::UnixStream;
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
|
@ -212,7 +214,10 @@ fn gen_key_share(rng: &SecureRandom) -> [u8; 32] {
|
|||
key
|
||||
}
|
||||
|
||||
async fn read_tls_record(stream: &mut TcpStream) -> Result<(u8, Vec<u8>)> {
|
||||
async fn read_tls_record<S>(stream: &mut S) -> Result<(u8, Vec<u8>)>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
let mut header = [0u8; 5];
|
||||
stream.read_exact(&mut header).await?;
|
||||
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
|
|
@ -345,6 +350,44 @@ async fn connect_with_dns_override(
|
|||
Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??)
|
||||
}
|
||||
|
||||
async fn connect_tcp_with_upstream(
|
||||
host: &str,
|
||||
port: u16,
|
||||
connect_timeout: Duration,
|
||||
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
||||
) -> Result<TcpStream> {
|
||||
if let Some(manager) = upstream {
|
||||
if let Some(addr) = resolve_socket_addr(host, port) {
|
||||
match manager.connect(addr, None, None).await {
|
||||
Ok(stream) => return Ok(stream),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
host = %host,
|
||||
port = port,
|
||||
error = %e,
|
||||
"Upstream connect failed, using direct connect"
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
|
||||
if let Some(addr) = addrs.find(|a| a.is_ipv4()) {
|
||||
match manager.connect(addr, None, None).await {
|
||||
Ok(stream) => return Ok(stream),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
host = %host,
|
||||
port = port,
|
||||
error = %e,
|
||||
"Upstream connect failed, using direct connect"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
connect_with_dns_override(host, port, connect_timeout).await
|
||||
}
|
||||
|
||||
fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8>> {
|
||||
if cert_chain_der.is_empty() {
|
||||
return None;
|
||||
|
|
@ -374,15 +417,15 @@ fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8
|
|||
Some(message)
|
||||
}
|
||||
|
||||
async fn fetch_via_raw_tls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
async fn fetch_via_raw_tls_stream<S>(
|
||||
mut stream: S,
|
||||
sni: &str,
|
||||
connect_timeout: Duration,
|
||||
proxy_protocol: u8,
|
||||
) -> Result<TlsFetchResult> {
|
||||
let mut stream = connect_with_dns_override(host, port, connect_timeout).await?;
|
||||
|
||||
) -> Result<TlsFetchResult>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let rng = SecureRandom::new();
|
||||
let client_hello = build_client_hello(sni, &rng);
|
||||
timeout(connect_timeout, async {
|
||||
|
|
@ -438,43 +481,61 @@ async fn fetch_via_raw_tls(
|
|||
})
|
||||
}
|
||||
|
||||
async fn fetch_via_rustls(
|
||||
async fn fetch_via_raw_tls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
sni: &str,
|
||||
connect_timeout: Duration,
|
||||
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
||||
proxy_protocol: u8,
|
||||
unix_sock: Option<&str>,
|
||||
) -> Result<TlsFetchResult> {
|
||||
// rustls handshake path for certificate and basic negotiated metadata.
|
||||
let mut stream = if let Some(manager) = upstream {
|
||||
if let Some(addr) = resolve_socket_addr(host, port) {
|
||||
match manager.connect(addr, None, None).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect");
|
||||
connect_with_dns_override(host, port, connect_timeout).await?
|
||||
#[cfg(unix)]
|
||||
if let Some(sock_path) = unix_sock {
|
||||
match timeout(connect_timeout, UnixStream::connect(sock_path)).await {
|
||||
Ok(Ok(stream)) => {
|
||||
debug!(
|
||||
sni = %sni,
|
||||
sock = %sock_path,
|
||||
"Raw TLS fetch using mask unix socket"
|
||||
);
|
||||
return fetch_via_raw_tls_stream(stream, sni, connect_timeout, 0).await;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
warn!(
|
||||
sni = %sni,
|
||||
sock = %sock_path,
|
||||
error = %e,
|
||||
"Raw TLS unix socket connect failed, falling back to TCP"
|
||||
);
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(
|
||||
sni = %sni,
|
||||
sock = %sock_path,
|
||||
"Raw TLS unix socket connect timed out, falling back to TCP"
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
|
||||
if let Some(addr) = addrs.find(|a| a.is_ipv4()) {
|
||||
match manager.connect(addr, None, None).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect");
|
||||
connect_with_dns_override(host, port, connect_timeout).await?
|
||||
}
|
||||
}
|
||||
} else {
|
||||
connect_with_dns_override(host, port, connect_timeout).await?
|
||||
}
|
||||
} else {
|
||||
connect_with_dns_override(host, port, connect_timeout).await?
|
||||
}
|
||||
} else {
|
||||
connect_with_dns_override(host, port, connect_timeout).await?
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let _ = unix_sock;
|
||||
|
||||
let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream).await?;
|
||||
fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await
|
||||
}
|
||||
|
||||
async fn fetch_via_rustls_stream<S>(
|
||||
mut stream: S,
|
||||
host: &str,
|
||||
sni: &str,
|
||||
proxy_protocol: u8,
|
||||
) -> Result<TlsFetchResult>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
// rustls handshake path for certificate and basic negotiated metadata.
|
||||
if proxy_protocol > 0 {
|
||||
let header = match proxy_protocol {
|
||||
2 => ProxyProtocolV2Builder::new().build(),
|
||||
|
|
@ -491,7 +552,7 @@ async fn fetch_via_rustls(
|
|||
.or_else(|_| ServerName::try_from(host.to_owned()))
|
||||
.map_err(|_| RustlsError::General("invalid SNI".into()))?;
|
||||
|
||||
let tls_stream: TlsStream<TcpStream> = connector.connect(server_name, stream).await?;
|
||||
let tls_stream: TlsStream<S> = connector.connect(server_name, stream).await?;
|
||||
|
||||
// Extract negotiated parameters and certificates
|
||||
let (_io, session) = tls_stream.get_ref();
|
||||
|
|
@ -552,6 +613,51 @@ async fn fetch_via_rustls(
|
|||
})
|
||||
}
|
||||
|
||||
async fn fetch_via_rustls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
sni: &str,
|
||||
connect_timeout: Duration,
|
||||
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
||||
proxy_protocol: u8,
|
||||
unix_sock: Option<&str>,
|
||||
) -> Result<TlsFetchResult> {
|
||||
#[cfg(unix)]
|
||||
if let Some(sock_path) = unix_sock {
|
||||
match timeout(connect_timeout, UnixStream::connect(sock_path)).await {
|
||||
Ok(Ok(stream)) => {
|
||||
debug!(
|
||||
sni = %sni,
|
||||
sock = %sock_path,
|
||||
"Rustls fetch using mask unix socket"
|
||||
);
|
||||
return fetch_via_rustls_stream(stream, host, sni, 0).await;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
warn!(
|
||||
sni = %sni,
|
||||
sock = %sock_path,
|
||||
error = %e,
|
||||
"Rustls unix socket connect failed, falling back to TCP"
|
||||
);
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(
|
||||
sni = %sni,
|
||||
sock = %sock_path,
|
||||
"Rustls unix socket connect timed out, falling back to TCP"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let _ = unix_sock;
|
||||
|
||||
let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream).await?;
|
||||
fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await
|
||||
}
|
||||
|
||||
/// Fetch real TLS metadata for the given SNI.
|
||||
///
|
||||
/// Strategy:
|
||||
|
|
@ -565,8 +671,19 @@ pub async fn fetch_real_tls(
|
|||
connect_timeout: Duration,
|
||||
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
||||
proxy_protocol: u8,
|
||||
unix_sock: Option<&str>,
|
||||
) -> Result<TlsFetchResult> {
|
||||
let raw_result = match fetch_via_raw_tls(host, port, sni, connect_timeout, proxy_protocol).await {
|
||||
let raw_result = match fetch_via_raw_tls(
|
||||
host,
|
||||
port,
|
||||
sni,
|
||||
connect_timeout,
|
||||
upstream.clone(),
|
||||
proxy_protocol,
|
||||
unix_sock,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(res) => Some(res),
|
||||
Err(e) => {
|
||||
warn!(sni = %sni, error = %e, "Raw TLS fetch failed");
|
||||
|
|
@ -574,7 +691,17 @@ pub async fn fetch_real_tls(
|
|||
}
|
||||
};
|
||||
|
||||
match fetch_via_rustls(host, port, sni, connect_timeout, upstream, proxy_protocol).await {
|
||||
match fetch_via_rustls(
|
||||
host,
|
||||
port,
|
||||
sni,
|
||||
connect_timeout,
|
||||
upstream,
|
||||
proxy_protocol,
|
||||
unix_sock,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(rustls_result) => {
|
||||
if let Some(mut raw) = raw_result {
|
||||
raw.cert_info = rustls_result.cert_info;
|
||||
|
|
|
|||
Loading…
Reference in New Issue