mirror of https://github.com/telemt/telemt.git
TLS Fetch on unix-socket
This commit is contained in:
parent
e0d5561095
commit
a61882af6e
|
|
@ -285,17 +285,20 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||||
.mask_host
|
.mask_host
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| config.censorship.tls_domain.clone());
|
.unwrap_or_else(|| config.censorship.tls_domain.clone());
|
||||||
|
let mask_unix_sock = config.censorship.mask_unix_sock.clone();
|
||||||
let fetch_timeout = Duration::from_secs(5);
|
let fetch_timeout = Duration::from_secs(5);
|
||||||
|
|
||||||
let cache_initial = cache.clone();
|
let cache_initial = cache.clone();
|
||||||
let domains_initial = tls_domains.clone();
|
let domains_initial = tls_domains.clone();
|
||||||
let host_initial = mask_host.clone();
|
let host_initial = mask_host.clone();
|
||||||
|
let unix_sock_initial = mask_unix_sock.clone();
|
||||||
let upstream_initial = upstream_manager.clone();
|
let upstream_initial = upstream_manager.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut join = tokio::task::JoinSet::new();
|
let mut join = tokio::task::JoinSet::new();
|
||||||
for domain in domains_initial {
|
for domain in domains_initial {
|
||||||
let cache_domain = cache_initial.clone();
|
let cache_domain = cache_initial.clone();
|
||||||
let host_domain = host_initial.clone();
|
let host_domain = host_initial.clone();
|
||||||
|
let unix_sock_domain = unix_sock_initial.clone();
|
||||||
let upstream_domain = upstream_initial.clone();
|
let upstream_domain = upstream_initial.clone();
|
||||||
join.spawn(async move {
|
join.spawn(async move {
|
||||||
match crate::tls_front::fetcher::fetch_real_tls(
|
match crate::tls_front::fetcher::fetch_real_tls(
|
||||||
|
|
@ -305,6 +308,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||||
fetch_timeout,
|
fetch_timeout,
|
||||||
Some(upstream_domain),
|
Some(upstream_domain),
|
||||||
proxy_protocol,
|
proxy_protocol,
|
||||||
|
unix_sock_domain.as_deref(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
@ -344,6 +348,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||||
let cache_refresh = cache.clone();
|
let cache_refresh = cache.clone();
|
||||||
let domains_refresh = tls_domains.clone();
|
let domains_refresh = tls_domains.clone();
|
||||||
let host_refresh = mask_host.clone();
|
let host_refresh = mask_host.clone();
|
||||||
|
let unix_sock_refresh = mask_unix_sock.clone();
|
||||||
let upstream_refresh = upstream_manager.clone();
|
let upstream_refresh = upstream_manager.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
|
|
@ -355,6 +360,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||||
for domain in domains_refresh.clone() {
|
for domain in domains_refresh.clone() {
|
||||||
let cache_domain = cache_refresh.clone();
|
let cache_domain = cache_refresh.clone();
|
||||||
let host_domain = host_refresh.clone();
|
let host_domain = host_refresh.clone();
|
||||||
|
let unix_sock_domain = unix_sock_refresh.clone();
|
||||||
let upstream_domain = upstream_refresh.clone();
|
let upstream_domain = upstream_refresh.clone();
|
||||||
join.spawn(async move {
|
join.spawn(async move {
|
||||||
match crate::tls_front::fetcher::fetch_real_tls(
|
match crate::tls_front::fetcher::fetch_real_tls(
|
||||||
|
|
@ -364,6 +370,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||||
fetch_timeout,
|
fetch_timeout,
|
||||||
Some(upstream_domain),
|
Some(upstream_domain),
|
||||||
proxy_protocol,
|
proxy_protocol,
|
||||||
|
unix_sock_domain.as_deref(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,10 @@ use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{Result, anyhow};
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
|
#[cfg(unix)]
|
||||||
|
use tokio::net::UnixStream;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tokio_rustls::client::TlsStream;
|
use tokio_rustls::client::TlsStream;
|
||||||
use tokio_rustls::TlsConnector;
|
use tokio_rustls::TlsConnector;
|
||||||
|
|
@ -212,7 +214,10 @@ fn gen_key_share(rng: &SecureRandom) -> [u8; 32] {
|
||||||
key
|
key
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn read_tls_record(stream: &mut TcpStream) -> Result<(u8, Vec<u8>)> {
|
async fn read_tls_record<S>(stream: &mut S) -> Result<(u8, Vec<u8>)>
|
||||||
|
where
|
||||||
|
S: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
let mut header = [0u8; 5];
|
let mut header = [0u8; 5];
|
||||||
stream.read_exact(&mut header).await?;
|
stream.read_exact(&mut header).await?;
|
||||||
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
|
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||||
|
|
@ -345,6 +350,44 @@ async fn connect_with_dns_override(
|
||||||
Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??)
|
Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn connect_tcp_with_upstream(
|
||||||
|
host: &str,
|
||||||
|
port: u16,
|
||||||
|
connect_timeout: Duration,
|
||||||
|
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
||||||
|
) -> Result<TcpStream> {
|
||||||
|
if let Some(manager) = upstream {
|
||||||
|
if let Some(addr) = resolve_socket_addr(host, port) {
|
||||||
|
match manager.connect(addr, None, None).await {
|
||||||
|
Ok(stream) => return Ok(stream),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
host = %host,
|
||||||
|
port = port,
|
||||||
|
error = %e,
|
||||||
|
"Upstream connect failed, using direct connect"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
|
||||||
|
if let Some(addr) = addrs.find(|a| a.is_ipv4()) {
|
||||||
|
match manager.connect(addr, None, None).await {
|
||||||
|
Ok(stream) => return Ok(stream),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
host = %host,
|
||||||
|
port = port,
|
||||||
|
error = %e,
|
||||||
|
"Upstream connect failed, using direct connect"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
connect_with_dns_override(host, port, connect_timeout).await
|
||||||
|
}
|
||||||
|
|
||||||
fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8>> {
|
fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8>> {
|
||||||
if cert_chain_der.is_empty() {
|
if cert_chain_der.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
|
|
@ -374,15 +417,15 @@ fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8
|
||||||
Some(message)
|
Some(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn fetch_via_raw_tls(
|
async fn fetch_via_raw_tls_stream<S>(
|
||||||
host: &str,
|
mut stream: S,
|
||||||
port: u16,
|
|
||||||
sni: &str,
|
sni: &str,
|
||||||
connect_timeout: Duration,
|
connect_timeout: Duration,
|
||||||
proxy_protocol: u8,
|
proxy_protocol: u8,
|
||||||
) -> Result<TlsFetchResult> {
|
) -> Result<TlsFetchResult>
|
||||||
let mut stream = connect_with_dns_override(host, port, connect_timeout).await?;
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
let rng = SecureRandom::new();
|
let rng = SecureRandom::new();
|
||||||
let client_hello = build_client_hello(sni, &rng);
|
let client_hello = build_client_hello(sni, &rng);
|
||||||
timeout(connect_timeout, async {
|
timeout(connect_timeout, async {
|
||||||
|
|
@ -438,43 +481,61 @@ async fn fetch_via_raw_tls(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn fetch_via_rustls(
|
async fn fetch_via_raw_tls(
|
||||||
host: &str,
|
host: &str,
|
||||||
port: u16,
|
port: u16,
|
||||||
sni: &str,
|
sni: &str,
|
||||||
connect_timeout: Duration,
|
connect_timeout: Duration,
|
||||||
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
||||||
proxy_protocol: u8,
|
proxy_protocol: u8,
|
||||||
|
unix_sock: Option<&str>,
|
||||||
) -> Result<TlsFetchResult> {
|
) -> Result<TlsFetchResult> {
|
||||||
// rustls handshake path for certificate and basic negotiated metadata.
|
#[cfg(unix)]
|
||||||
let mut stream = if let Some(manager) = upstream {
|
if let Some(sock_path) = unix_sock {
|
||||||
if let Some(addr) = resolve_socket_addr(host, port) {
|
match timeout(connect_timeout, UnixStream::connect(sock_path)).await {
|
||||||
match manager.connect(addr, None, None).await {
|
Ok(Ok(stream)) => {
|
||||||
Ok(s) => s,
|
debug!(
|
||||||
Err(e) => {
|
sni = %sni,
|
||||||
warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect");
|
sock = %sock_path,
|
||||||
connect_with_dns_override(host, port, connect_timeout).await?
|
"Raw TLS fetch using mask unix socket"
|
||||||
|
);
|
||||||
|
return fetch_via_raw_tls_stream(stream, sni, connect_timeout, 0).await;
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
warn!(
|
||||||
|
sni = %sni,
|
||||||
|
sock = %sock_path,
|
||||||
|
error = %e,
|
||||||
|
"Raw TLS unix socket connect failed, falling back to TCP"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
warn!(
|
||||||
|
sni = %sni,
|
||||||
|
sock = %sock_path,
|
||||||
|
"Raw TLS unix socket connect timed out, falling back to TCP"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
|
|
||||||
if let Some(addr) = addrs.find(|a| a.is_ipv4()) {
|
|
||||||
match manager.connect(addr, None, None).await {
|
|
||||||
Ok(s) => s,
|
|
||||||
Err(e) => {
|
|
||||||
warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect");
|
|
||||||
connect_with_dns_override(host, port, connect_timeout).await?
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else {
|
|
||||||
connect_with_dns_override(host, port, connect_timeout).await?
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
connect_with_dns_override(host, port, connect_timeout).await?
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
connect_with_dns_override(host, port, connect_timeout).await?
|
|
||||||
};
|
|
||||||
|
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
let _ = unix_sock;
|
||||||
|
|
||||||
|
let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream).await?;
|
||||||
|
fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_via_rustls_stream<S>(
|
||||||
|
mut stream: S,
|
||||||
|
host: &str,
|
||||||
|
sni: &str,
|
||||||
|
proxy_protocol: u8,
|
||||||
|
) -> Result<TlsFetchResult>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
// rustls handshake path for certificate and basic negotiated metadata.
|
||||||
if proxy_protocol > 0 {
|
if proxy_protocol > 0 {
|
||||||
let header = match proxy_protocol {
|
let header = match proxy_protocol {
|
||||||
2 => ProxyProtocolV2Builder::new().build(),
|
2 => ProxyProtocolV2Builder::new().build(),
|
||||||
|
|
@ -491,7 +552,7 @@ async fn fetch_via_rustls(
|
||||||
.or_else(|_| ServerName::try_from(host.to_owned()))
|
.or_else(|_| ServerName::try_from(host.to_owned()))
|
||||||
.map_err(|_| RustlsError::General("invalid SNI".into()))?;
|
.map_err(|_| RustlsError::General("invalid SNI".into()))?;
|
||||||
|
|
||||||
let tls_stream: TlsStream<TcpStream> = connector.connect(server_name, stream).await?;
|
let tls_stream: TlsStream<S> = connector.connect(server_name, stream).await?;
|
||||||
|
|
||||||
// Extract negotiated parameters and certificates
|
// Extract negotiated parameters and certificates
|
||||||
let (_io, session) = tls_stream.get_ref();
|
let (_io, session) = tls_stream.get_ref();
|
||||||
|
|
@ -552,6 +613,51 @@ async fn fetch_via_rustls(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn fetch_via_rustls(
|
||||||
|
host: &str,
|
||||||
|
port: u16,
|
||||||
|
sni: &str,
|
||||||
|
connect_timeout: Duration,
|
||||||
|
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
||||||
|
proxy_protocol: u8,
|
||||||
|
unix_sock: Option<&str>,
|
||||||
|
) -> Result<TlsFetchResult> {
|
||||||
|
#[cfg(unix)]
|
||||||
|
if let Some(sock_path) = unix_sock {
|
||||||
|
match timeout(connect_timeout, UnixStream::connect(sock_path)).await {
|
||||||
|
Ok(Ok(stream)) => {
|
||||||
|
debug!(
|
||||||
|
sni = %sni,
|
||||||
|
sock = %sock_path,
|
||||||
|
"Rustls fetch using mask unix socket"
|
||||||
|
);
|
||||||
|
return fetch_via_rustls_stream(stream, host, sni, 0).await;
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
warn!(
|
||||||
|
sni = %sni,
|
||||||
|
sock = %sock_path,
|
||||||
|
error = %e,
|
||||||
|
"Rustls unix socket connect failed, falling back to TCP"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
warn!(
|
||||||
|
sni = %sni,
|
||||||
|
sock = %sock_path,
|
||||||
|
"Rustls unix socket connect timed out, falling back to TCP"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
let _ = unix_sock;
|
||||||
|
|
||||||
|
let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream).await?;
|
||||||
|
fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Fetch real TLS metadata for the given SNI.
|
/// Fetch real TLS metadata for the given SNI.
|
||||||
///
|
///
|
||||||
/// Strategy:
|
/// Strategy:
|
||||||
|
|
@ -565,8 +671,19 @@ pub async fn fetch_real_tls(
|
||||||
connect_timeout: Duration,
|
connect_timeout: Duration,
|
||||||
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
|
||||||
proxy_protocol: u8,
|
proxy_protocol: u8,
|
||||||
|
unix_sock: Option<&str>,
|
||||||
) -> Result<TlsFetchResult> {
|
) -> Result<TlsFetchResult> {
|
||||||
let raw_result = match fetch_via_raw_tls(host, port, sni, connect_timeout, proxy_protocol).await {
|
let raw_result = match fetch_via_raw_tls(
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
sni,
|
||||||
|
connect_timeout,
|
||||||
|
upstream.clone(),
|
||||||
|
proxy_protocol,
|
||||||
|
unix_sock,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(res) => Some(res),
|
Ok(res) => Some(res),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(sni = %sni, error = %e, "Raw TLS fetch failed");
|
warn!(sni = %sni, error = %e, "Raw TLS fetch failed");
|
||||||
|
|
@ -574,7 +691,17 @@ pub async fn fetch_real_tls(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match fetch_via_rustls(host, port, sni, connect_timeout, upstream, proxy_protocol).await {
|
match fetch_via_rustls(
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
sni,
|
||||||
|
connect_timeout,
|
||||||
|
upstream,
|
||||||
|
proxy_protocol,
|
||||||
|
unix_sock,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(rustls_result) => {
|
Ok(rustls_result) => {
|
||||||
if let Some(mut raw) = raw_result {
|
if let Some(mut raw) = raw_result {
|
||||||
raw.cert_info = rustls_result.cert_info;
|
raw.cert_info = rustls_result.cert_info;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue