mirror of
https://github.com/telemt/telemt.git
synced 2026-06-30 06:41:11 +03:00
Exclusive Mask + Startup Speed-up
Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com>
This commit is contained in:
+71
-2
@@ -11,6 +11,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
@@ -452,7 +453,50 @@ where
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[allow(dead_code)]
|
||||
pub async fn handle_client_stream_with_shared<S>(
|
||||
stream: S,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
proxy_protocol_enabled: bool,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
handle_client_stream_with_shared_and_pool_runtime(
|
||||
stream,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
None,
|
||||
route_runtime,
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
shared,
|
||||
proxy_protocol_enabled,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn handle_client_stream_with_shared_and_pool_runtime<S>(
|
||||
mut stream: S,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
@@ -462,6 +506,7 @@ pub async fn handle_client_stream_with_shared<S>(
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
me_pool_runtime: Option<Arc<RwLock<Option<Arc<MePool>>>>>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
@@ -731,6 +776,7 @@ where
|
||||
RunningClientHandler::handle_authenticated_static_with_shared(
|
||||
crypto_reader, crypto_writer, success,
|
||||
upstream_manager, stats, config, buffer_pool, rng, me_pool,
|
||||
me_pool_runtime,
|
||||
route_runtime.clone(),
|
||||
local_addr, real_peer, ip_tracker.clone(),
|
||||
shared.clone(),
|
||||
@@ -791,6 +837,7 @@ where
|
||||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
me_pool_runtime,
|
||||
route_runtime.clone(),
|
||||
local_addr,
|
||||
real_peer,
|
||||
@@ -846,6 +893,7 @@ pub struct RunningClientHandler {
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
me_pool_runtime: Option<Arc<RwLock<Option<Arc<MePool>>>>>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
@@ -891,6 +939,7 @@ impl ClientHandler {
|
||||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
None,
|
||||
route_runtime,
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
@@ -915,6 +964,7 @@ impl ClientHandler {
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
me_pool_runtime: Option<Arc<RwLock<Option<Arc<MePool>>>>>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
@@ -938,6 +988,7 @@ impl ClientHandler {
|
||||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
me_pool_runtime,
|
||||
route_runtime,
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
@@ -1345,6 +1396,7 @@ impl RunningClientHandler {
|
||||
buffer_pool,
|
||||
self.rng,
|
||||
self.me_pool,
|
||||
self.me_pool_runtime,
|
||||
self.route_runtime.clone(),
|
||||
local_addr,
|
||||
peer,
|
||||
@@ -1429,6 +1481,7 @@ impl RunningClientHandler {
|
||||
buffer_pool,
|
||||
self.rng,
|
||||
self.me_pool,
|
||||
self.me_pool_runtime,
|
||||
self.route_runtime.clone(),
|
||||
local_addr,
|
||||
peer,
|
||||
@@ -1472,6 +1525,7 @@ impl RunningClientHandler {
|
||||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
None,
|
||||
route_runtime,
|
||||
local_addr,
|
||||
peer_addr,
|
||||
@@ -1491,6 +1545,7 @@ impl RunningClientHandler {
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
me_pool_runtime: Option<Arc<RwLock<Option<Arc<MePool>>>>>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
local_addr: SocketAddr,
|
||||
peer_addr: SocketAddr,
|
||||
@@ -1521,15 +1576,29 @@ impl RunningClientHandler {
|
||||
|
||||
let route_snapshot = route_runtime.snapshot();
|
||||
let session_id = rng.u64();
|
||||
let relay_result = if config.general.use_middle_proxy
|
||||
let selected_me_pool = if config.general.use_middle_proxy
|
||||
&& matches!(route_snapshot.mode, RelayRouteMode::Middle)
|
||||
{
|
||||
if let Some(ref pool) = me_pool {
|
||||
Some(pool.clone())
|
||||
} else if let Some(pool_runtime) = me_pool_runtime.as_ref() {
|
||||
pool_runtime.read().await.clone()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let relay_result = if config.general.use_middle_proxy
|
||||
&& matches!(route_snapshot.mode, RelayRouteMode::Middle)
|
||||
{
|
||||
if let Some(pool) = selected_me_pool {
|
||||
handle_via_middle_proxy(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
pool.clone(),
|
||||
pool,
|
||||
stats.clone(),
|
||||
config,
|
||||
buffer_pool,
|
||||
|
||||
+103
-7
@@ -47,6 +47,12 @@ struct CopyOutcome {
|
||||
ended_by_eof: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct MaskTcpTarget<'a> {
|
||||
host: &'a str,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
async fn copy_with_idle_timeout<R, W>(
|
||||
reader: &mut R,
|
||||
writer: &mut W,
|
||||
@@ -331,7 +337,9 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tls_domain_mask_host_tests {
|
||||
use super::{mask_host_for_initial_data, matching_tls_domain_for_sni};
|
||||
use super::{
|
||||
mask_host_for_initial_data, mask_tcp_target_for_initial_data, matching_tls_domain_for_sni,
|
||||
};
|
||||
use crate::config::ProxyConfig;
|
||||
|
||||
fn client_hello_with_sni(sni_host: &str) -> Vec<u8> {
|
||||
@@ -410,6 +418,25 @@ mod tls_domain_mask_host_tests {
|
||||
|
||||
assert_eq!(mask_host_for_initial_data(&config, &initial_data), "b.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exclusive_mask_target_overrides_only_matching_sni() {
|
||||
let mut config = config_with_tls_domains();
|
||||
config
|
||||
.censorship
|
||||
.exclusive_mask
|
||||
.insert("b.com".to_string(), "origin-b.example:8443".to_string());
|
||||
let b_initial_data = client_hello_with_sni("B.COM");
|
||||
let c_initial_data = client_hello_with_sni("c.com");
|
||||
|
||||
let b_target = mask_tcp_target_for_initial_data(&config, &b_initial_data);
|
||||
let c_target = mask_tcp_target_for_initial_data(&config, &c_initial_data);
|
||||
|
||||
assert_eq!(b_target.host, "origin-b.example");
|
||||
assert_eq!(b_target.port, 8443);
|
||||
assert_eq!(c_target.host, "c.com");
|
||||
assert_eq!(c_target.port, config.censorship.mask_port);
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect client type based on initial data
|
||||
@@ -458,7 +485,61 @@ fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option
|
||||
None
|
||||
}
|
||||
|
||||
fn parse_exclusive_mask_target(target: &str) -> Option<MaskTcpTarget<'_>> {
|
||||
let target = target.trim();
|
||||
if target.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if target.starts_with('[') {
|
||||
let end = target.find(']')?;
|
||||
if target.get(end + 1..end + 2)? != ":" {
|
||||
return None;
|
||||
}
|
||||
let port = target[end + 2..].parse::<u16>().ok()?;
|
||||
return (port > 0).then_some(MaskTcpTarget {
|
||||
host: &target[..=end],
|
||||
port,
|
||||
});
|
||||
}
|
||||
|
||||
let (host, port) = target.rsplit_once(':')?;
|
||||
if host.is_empty() || host.contains(':') {
|
||||
return None;
|
||||
}
|
||||
let port = port.parse::<u16>().ok()?;
|
||||
(port > 0).then_some(MaskTcpTarget { host, port })
|
||||
}
|
||||
|
||||
fn exclusive_mask_target_for_sni<'a>(
|
||||
config: &'a ProxyConfig,
|
||||
sni: &str,
|
||||
) -> Option<MaskTcpTarget<'a>> {
|
||||
for (domain, target) in &config.censorship.exclusive_mask {
|
||||
if domain.eq_ignore_ascii_case(sni) {
|
||||
return parse_exclusive_mask_target(target);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn mask_host_for_initial_data<'a>(config: &'a ProxyConfig, initial_data: &[u8]) -> &'a str {
|
||||
mask_tcp_target_for_initial_data(config, initial_data).host
|
||||
}
|
||||
|
||||
fn mask_tcp_target_for_initial_data<'a>(
|
||||
config: &'a ProxyConfig,
|
||||
initial_data: &[u8],
|
||||
) -> MaskTcpTarget<'a> {
|
||||
if let Some(target) = tls::extract_sni_from_client_hello(initial_data)
|
||||
.as_deref()
|
||||
.and_then(|sni| exclusive_mask_target_for_sni(config, sni))
|
||||
{
|
||||
return target;
|
||||
}
|
||||
|
||||
let configured_mask_host = config
|
||||
.censorship
|
||||
.mask_host
|
||||
@@ -466,13 +547,20 @@ fn mask_host_for_initial_data<'a>(config: &'a ProxyConfig, initial_data: &[u8])
|
||||
.unwrap_or(&config.censorship.tls_domain);
|
||||
|
||||
if !configured_mask_host.eq_ignore_ascii_case(&config.censorship.tls_domain) {
|
||||
return configured_mask_host;
|
||||
return MaskTcpTarget {
|
||||
host: configured_mask_host,
|
||||
port: config.censorship.mask_port,
|
||||
};
|
||||
}
|
||||
|
||||
tls::extract_sni_from_client_hello(initial_data)
|
||||
let host = tls::extract_sni_from_client_hello(initial_data)
|
||||
.as_deref()
|
||||
.and_then(|sni| matching_tls_domain_for_sni(config, sni))
|
||||
.unwrap_or(configured_mask_host)
|
||||
.unwrap_or(configured_mask_host);
|
||||
MaskTcpTarget {
|
||||
host,
|
||||
port: config.censorship.mask_port,
|
||||
}
|
||||
}
|
||||
|
||||
fn canonical_ip(ip: IpAddr) -> IpAddr {
|
||||
@@ -770,9 +858,15 @@ pub async fn handle_bad_client<R, W>(
|
||||
return;
|
||||
}
|
||||
|
||||
let exclusive_tcp_target = tls::extract_sni_from_client_hello(initial_data)
|
||||
.as_deref()
|
||||
.and_then(|sni| exclusive_mask_target_for_sni(config, sni));
|
||||
|
||||
// Connect via Unix socket or TCP
|
||||
#[cfg(unix)]
|
||||
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
|
||||
if exclusive_tcp_target.is_none()
|
||||
&& let Some(ref sock_path) = config.censorship.mask_unix_sock
|
||||
{
|
||||
let outcome_started = Instant::now();
|
||||
let connect_started = Instant::now();
|
||||
debug!(
|
||||
@@ -849,8 +943,10 @@ pub async fn handle_bad_client<R, W>(
|
||||
return;
|
||||
}
|
||||
|
||||
let mask_host = mask_host_for_initial_data(config, initial_data);
|
||||
let mask_port = config.censorship.mask_port;
|
||||
let mask_target = exclusive_tcp_target
|
||||
.unwrap_or_else(|| mask_tcp_target_for_initial_data(config, initial_data));
|
||||
let mask_host = mask_target.host;
|
||||
let mask_port = mask_target.port;
|
||||
|
||||
// Fail closed when fallback points at our own listener endpoint.
|
||||
// Self-referential masking can create recursive proxy loops under
|
||||
|
||||
Reference in New Issue
Block a user