Exclusive Mask + Startup Speed-up

Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com>
This commit is contained in:
Alexey
2026-05-19 21:56:26 +03:00
parent 9e877e45c9
commit 914f141715
14 changed files with 529 additions and 109 deletions
+71 -2
View File
@@ -11,6 +11,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tracing::{debug, warn};
@@ -452,7 +453,50 @@ where
}
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub async fn handle_client_stream_with_shared<S>(
stream: S,
peer: SocketAddr,
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
shared: Arc<ProxySharedState>,
proxy_protocol_enabled: bool,
) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
handle_client_stream_with_shared_and_pool_runtime(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
me_pool,
None,
route_runtime,
tls_cache,
ip_tracker,
beobachten,
shared,
proxy_protocol_enabled,
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn handle_client_stream_with_shared_and_pool_runtime<S>(
mut stream: S,
peer: SocketAddr,
config: Arc<ProxyConfig>,
@@ -462,6 +506,7 @@ pub async fn handle_client_stream_with_shared<S>(
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
me_pool_runtime: Option<Arc<RwLock<Option<Arc<MePool>>>>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>,
@@ -731,6 +776,7 @@ where
RunningClientHandler::handle_authenticated_static_with_shared(
crypto_reader, crypto_writer, success,
upstream_manager, stats, config, buffer_pool, rng, me_pool,
me_pool_runtime,
route_runtime.clone(),
local_addr, real_peer, ip_tracker.clone(),
shared.clone(),
@@ -791,6 +837,7 @@ where
buffer_pool,
rng,
me_pool,
me_pool_runtime,
route_runtime.clone(),
local_addr,
real_peer,
@@ -846,6 +893,7 @@ pub struct RunningClientHandler {
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
me_pool_runtime: Option<Arc<RwLock<Option<Arc<MePool>>>>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>,
@@ -891,6 +939,7 @@ impl ClientHandler {
buffer_pool,
rng,
me_pool,
None,
route_runtime,
tls_cache,
ip_tracker,
@@ -915,6 +964,7 @@ impl ClientHandler {
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
me_pool_runtime: Option<Arc<RwLock<Option<Arc<MePool>>>>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>,
@@ -938,6 +988,7 @@ impl ClientHandler {
buffer_pool,
rng,
me_pool,
me_pool_runtime,
route_runtime,
tls_cache,
ip_tracker,
@@ -1345,6 +1396,7 @@ impl RunningClientHandler {
buffer_pool,
self.rng,
self.me_pool,
self.me_pool_runtime,
self.route_runtime.clone(),
local_addr,
peer,
@@ -1429,6 +1481,7 @@ impl RunningClientHandler {
buffer_pool,
self.rng,
self.me_pool,
self.me_pool_runtime,
self.route_runtime.clone(),
local_addr,
peer,
@@ -1472,6 +1525,7 @@ impl RunningClientHandler {
buffer_pool,
rng,
me_pool,
None,
route_runtime,
local_addr,
peer_addr,
@@ -1491,6 +1545,7 @@ impl RunningClientHandler {
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
me_pool_runtime: Option<Arc<RwLock<Option<Arc<MePool>>>>>,
route_runtime: Arc<RouteRuntimeController>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
@@ -1521,15 +1576,29 @@ impl RunningClientHandler {
let route_snapshot = route_runtime.snapshot();
let session_id = rng.u64();
let relay_result = if config.general.use_middle_proxy
let selected_me_pool = if config.general.use_middle_proxy
&& matches!(route_snapshot.mode, RelayRouteMode::Middle)
{
if let Some(ref pool) = me_pool {
Some(pool.clone())
} else if let Some(pool_runtime) = me_pool_runtime.as_ref() {
pool_runtime.read().await.clone()
} else {
None
}
} else {
None
};
let relay_result = if config.general.use_middle_proxy
&& matches!(route_snapshot.mode, RelayRouteMode::Middle)
{
if let Some(pool) = selected_me_pool {
handle_via_middle_proxy(
client_reader,
client_writer,
success,
pool.clone(),
pool,
stats.clone(),
config,
buffer_pool,
+103 -7
View File
@@ -47,6 +47,12 @@ struct CopyOutcome {
ended_by_eof: bool,
}
#[derive(Clone, Copy)]
struct MaskTcpTarget<'a> {
host: &'a str,
port: u16,
}
async fn copy_with_idle_timeout<R, W>(
reader: &mut R,
writer: &mut W,
@@ -331,7 +337,9 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
#[cfg(test)]
mod tls_domain_mask_host_tests {
use super::{mask_host_for_initial_data, matching_tls_domain_for_sni};
use super::{
mask_host_for_initial_data, mask_tcp_target_for_initial_data, matching_tls_domain_for_sni,
};
use crate::config::ProxyConfig;
fn client_hello_with_sni(sni_host: &str) -> Vec<u8> {
@@ -410,6 +418,25 @@ mod tls_domain_mask_host_tests {
assert_eq!(mask_host_for_initial_data(&config, &initial_data), "b.com");
}
#[test]
fn exclusive_mask_target_overrides_only_matching_sni() {
let mut config = config_with_tls_domains();
config
.censorship
.exclusive_mask
.insert("b.com".to_string(), "origin-b.example:8443".to_string());
let b_initial_data = client_hello_with_sni("B.COM");
let c_initial_data = client_hello_with_sni("c.com");
let b_target = mask_tcp_target_for_initial_data(&config, &b_initial_data);
let c_target = mask_tcp_target_for_initial_data(&config, &c_initial_data);
assert_eq!(b_target.host, "origin-b.example");
assert_eq!(b_target.port, 8443);
assert_eq!(c_target.host, "c.com");
assert_eq!(c_target.port, config.censorship.mask_port);
}
}
/// Detect client type based on initial data
@@ -458,7 +485,61 @@ fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option
None
}
fn parse_exclusive_mask_target(target: &str) -> Option<MaskTcpTarget<'_>> {
let target = target.trim();
if target.is_empty() {
return None;
}
if target.starts_with('[') {
let end = target.find(']')?;
if target.get(end + 1..end + 2)? != ":" {
return None;
}
let port = target[end + 2..].parse::<u16>().ok()?;
return (port > 0).then_some(MaskTcpTarget {
host: &target[..=end],
port,
});
}
let (host, port) = target.rsplit_once(':')?;
if host.is_empty() || host.contains(':') {
return None;
}
let port = port.parse::<u16>().ok()?;
(port > 0).then_some(MaskTcpTarget { host, port })
}
fn exclusive_mask_target_for_sni<'a>(
config: &'a ProxyConfig,
sni: &str,
) -> Option<MaskTcpTarget<'a>> {
for (domain, target) in &config.censorship.exclusive_mask {
if domain.eq_ignore_ascii_case(sni) {
return parse_exclusive_mask_target(target);
}
}
None
}
#[cfg(test)]
fn mask_host_for_initial_data<'a>(config: &'a ProxyConfig, initial_data: &[u8]) -> &'a str {
mask_tcp_target_for_initial_data(config, initial_data).host
}
fn mask_tcp_target_for_initial_data<'a>(
config: &'a ProxyConfig,
initial_data: &[u8],
) -> MaskTcpTarget<'a> {
if let Some(target) = tls::extract_sni_from_client_hello(initial_data)
.as_deref()
.and_then(|sni| exclusive_mask_target_for_sni(config, sni))
{
return target;
}
let configured_mask_host = config
.censorship
.mask_host
@@ -466,13 +547,20 @@ fn mask_host_for_initial_data<'a>(config: &'a ProxyConfig, initial_data: &[u8])
.unwrap_or(&config.censorship.tls_domain);
if !configured_mask_host.eq_ignore_ascii_case(&config.censorship.tls_domain) {
return configured_mask_host;
return MaskTcpTarget {
host: configured_mask_host,
port: config.censorship.mask_port,
};
}
tls::extract_sni_from_client_hello(initial_data)
let host = tls::extract_sni_from_client_hello(initial_data)
.as_deref()
.and_then(|sni| matching_tls_domain_for_sni(config, sni))
.unwrap_or(configured_mask_host)
.unwrap_or(configured_mask_host);
MaskTcpTarget {
host,
port: config.censorship.mask_port,
}
}
fn canonical_ip(ip: IpAddr) -> IpAddr {
@@ -770,9 +858,15 @@ pub async fn handle_bad_client<R, W>(
return;
}
let exclusive_tcp_target = tls::extract_sni_from_client_hello(initial_data)
.as_deref()
.and_then(|sni| exclusive_mask_target_for_sni(config, sni));
// Connect via Unix socket or TCP
#[cfg(unix)]
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
if exclusive_tcp_target.is_none()
&& let Some(ref sock_path) = config.censorship.mask_unix_sock
{
let outcome_started = Instant::now();
let connect_started = Instant::now();
debug!(
@@ -849,8 +943,10 @@ pub async fn handle_bad_client<R, W>(
return;
}
let mask_host = mask_host_for_initial_data(config, initial_data);
let mask_port = config.censorship.mask_port;
let mask_target = exclusive_tcp_target
.unwrap_or_else(|| mask_tcp_target_for_initial_data(config, initial_data));
let mask_host = mask_target.host;
let mask_port = mask_target.port;
// Fail closed when fallback points at our own listener endpoint.
// Self-referential masking can create recursive proxy loops under