mirror of
https://github.com/telemt/telemt.git
synced 2026-06-18 08:58:30 +03:00
Harden masking fallback and frame readers after flow sync
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
@@ -1940,6 +1940,13 @@ impl ProxyConfig {
|
||||
));
|
||||
}
|
||||
|
||||
if config.server.listen_backlog == 0 || config.server.listen_backlog > i32::MAX as u32 {
|
||||
return Err(ProxyError::Config(format!(
|
||||
"server.listen_backlog must be within [1, {}]",
|
||||
i32::MAX
|
||||
)));
|
||||
}
|
||||
|
||||
config
|
||||
.server
|
||||
.client_mss_value()
|
||||
|
||||
@@ -95,6 +95,44 @@ max_client_frame = 16777217
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_listen_backlog_above_i32_upper_bound() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[server]
|
||||
listen_backlog = 2147483648
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path).expect_err("listen_backlog above socket cap must fail");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("server.listen_backlog must be within [1, 2147483647]"),
|
||||
"error must explain listen_backlog hard cap, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_zero_listen_backlog() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[server]
|
||||
listen_backlog = 0
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path).expect_err("zero listen_backlog must fail");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("server.listen_backlog must be within [1, 2147483647]"),
|
||||
"error must explain listen_backlog lower bound, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_accepts_memory_limits_at_hard_upper_bounds() {
|
||||
let path = write_temp_config(
|
||||
|
||||
@@ -113,7 +113,7 @@ use crate::proxy::handshake::{
|
||||
};
|
||||
#[cfg(test)]
|
||||
use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake};
|
||||
use crate::proxy::masking::handle_bad_client;
|
||||
use crate::proxy::masking::handle_bad_client_with_shared;
|
||||
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
||||
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
@@ -310,6 +310,7 @@ fn masking_outcome<R, W>(
|
||||
local_addr: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
) -> HandshakeOutcome
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
@@ -325,7 +326,7 @@ where
|
||||
)
|
||||
.await;
|
||||
|
||||
handle_bad_client(
|
||||
handle_bad_client_with_shared(
|
||||
reader,
|
||||
writer,
|
||||
&initial_data,
|
||||
@@ -333,6 +334,7 @@ where
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
shared.as_ref(),
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
@@ -718,6 +720,7 @@ where
|
||||
local_addr,
|
||||
config.clone(),
|
||||
beobachten.clone(),
|
||||
shared.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -739,6 +742,7 @@ where
|
||||
local_addr,
|
||||
config.clone(),
|
||||
beobachten.clone(),
|
||||
shared.clone(),
|
||||
));
|
||||
}
|
||||
};
|
||||
@@ -757,6 +761,7 @@ where
|
||||
local_addr,
|
||||
config.clone(),
|
||||
beobachten.clone(),
|
||||
shared.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -787,6 +792,7 @@ where
|
||||
local_addr,
|
||||
config.clone(),
|
||||
beobachten.clone(),
|
||||
shared.clone(),
|
||||
));
|
||||
}
|
||||
HandshakeResult::Error(e) => {
|
||||
@@ -844,6 +850,7 @@ where
|
||||
local_addr,
|
||||
config.clone(),
|
||||
beobachten.clone(),
|
||||
shared.clone(),
|
||||
));
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
@@ -873,6 +880,7 @@ where
|
||||
local_addr,
|
||||
config.clone(),
|
||||
beobachten.clone(),
|
||||
shared.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -898,6 +906,7 @@ where
|
||||
local_addr,
|
||||
config.clone(),
|
||||
beobachten.clone(),
|
||||
shared.clone(),
|
||||
));
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
@@ -1329,6 +1338,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
self.config.clone(),
|
||||
self.beobachten.clone(),
|
||||
self.shared.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -1350,6 +1360,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
self.config.clone(),
|
||||
self.beobachten.clone(),
|
||||
self.shared.clone(),
|
||||
));
|
||||
}
|
||||
};
|
||||
@@ -1369,6 +1380,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
self.config.clone(),
|
||||
self.beobachten.clone(),
|
||||
self.shared.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -1416,6 +1428,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
config.clone(),
|
||||
self.beobachten.clone(),
|
||||
self.shared.clone(),
|
||||
));
|
||||
}
|
||||
HandshakeResult::Error(e) => {
|
||||
@@ -1483,6 +1496,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
config.clone(),
|
||||
self.beobachten.clone(),
|
||||
self.shared.clone(),
|
||||
));
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
@@ -1530,6 +1544,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
self.config.clone(),
|
||||
self.beobachten.clone(),
|
||||
self.shared.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -1568,6 +1583,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
config.clone(),
|
||||
self.beobachten.clone(),
|
||||
self.shared.clone(),
|
||||
));
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::network::dns_overrides::resolve_socket_addr;
|
||||
use crate::protocol::tls;
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
use crate::stats::beobachten::BeobachtenStore;
|
||||
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
|
||||
use crate::transport::socket::configure_tcp_socket;
|
||||
@@ -10,6 +11,7 @@ use crate::transport::socket::configure_tcp_socket;
|
||||
use nix::ifaddrs::getifaddrs;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, RngExt, SeedableRng};
|
||||
use std::io::{Error as IoError, ErrorKind};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str;
|
||||
#[cfg(test)]
|
||||
@@ -18,7 +20,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::{Duration, Instant as StdInstant};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
#[cfg(unix)]
|
||||
use tokio::net::UnixStream;
|
||||
#[cfg(unix)]
|
||||
@@ -271,6 +273,32 @@ async fn consume_client_data_with_timeout_and_cap<R>(
|
||||
}
|
||||
}
|
||||
|
||||
fn mask_failure_drain_cap(config: &ProxyConfig) -> usize {
|
||||
let configured_cap = config.censorship.mask_relay_max_bytes;
|
||||
if configured_cap == 0 {
|
||||
return MASK_BUFFER_SIZE;
|
||||
}
|
||||
|
||||
configured_cap.min(MASK_BUFFER_SIZE)
|
||||
}
|
||||
|
||||
async fn consume_mask_failure_path<R>(
|
||||
reader: R,
|
||||
config: &ProxyConfig,
|
||||
relay_timeout: Duration,
|
||||
idle_timeout: Duration,
|
||||
) where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
mask_failure_drain_cap(config),
|
||||
relay_timeout,
|
||||
idle_timeout,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn wait_mask_connect_budget(started: Instant) {
|
||||
let elapsed = started.elapsed();
|
||||
if elapsed < MASK_TIMEOUT {
|
||||
@@ -501,6 +529,32 @@ fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
|
||||
host.parse::<IpAddr>().ok()
|
||||
}
|
||||
|
||||
async fn resolve_mask_target_addrs(
|
||||
mask_host: &str,
|
||||
mask_port: u16,
|
||||
) -> std::io::Result<Vec<SocketAddr>> {
|
||||
if let Some(addr) = resolve_socket_addr(mask_host, mask_port) {
|
||||
return Ok(vec![addr]);
|
||||
}
|
||||
|
||||
if let Some(ip) = parse_mask_host_ip_literal(mask_host) {
|
||||
return Ok(vec![SocketAddr::new(ip, mask_port)]);
|
||||
}
|
||||
|
||||
let addrs = timeout(MASK_TIMEOUT, lookup_host((mask_host, mask_port)))
|
||||
.await
|
||||
.map_err(|_| IoError::new(ErrorKind::TimedOut, "mask target DNS lookup timed out"))??;
|
||||
let addrs = addrs.collect::<Vec<_>>();
|
||||
if addrs.is_empty() {
|
||||
return Err(IoError::new(
|
||||
ErrorKind::NotFound,
|
||||
"mask target DNS lookup returned no addresses",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(addrs)
|
||||
}
|
||||
|
||||
fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
|
||||
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
|
||||
return Some(config.censorship.tls_domain.as_str());
|
||||
@@ -782,7 +836,7 @@ fn is_mask_target_local_listener_with_interfaces(
|
||||
mask_host: &str,
|
||||
mask_port: u16,
|
||||
local_addr: SocketAddr,
|
||||
resolved_override: Option<SocketAddr>,
|
||||
resolved_addrs: &[SocketAddr],
|
||||
interface_ips: &[IpAddr],
|
||||
) -> bool {
|
||||
if mask_port != local_addr.port() {
|
||||
@@ -792,7 +846,7 @@ fn is_mask_target_local_listener_with_interfaces(
|
||||
let local_ip = canonical_ip(local_addr.ip());
|
||||
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
|
||||
|
||||
if let Some(addr) = resolved_override {
|
||||
for addr in resolved_addrs {
|
||||
let resolved_ip = canonical_ip(addr.ip());
|
||||
if resolved_ip == local_ip {
|
||||
return true;
|
||||
@@ -829,7 +883,7 @@ fn is_mask_target_local_listener(
|
||||
mask_host: &str,
|
||||
mask_port: u16,
|
||||
local_addr: SocketAddr,
|
||||
resolved_override: Option<SocketAddr>,
|
||||
resolved_addrs: &[SocketAddr],
|
||||
) -> bool {
|
||||
if mask_port != local_addr.port() {
|
||||
return false;
|
||||
@@ -840,7 +894,7 @@ fn is_mask_target_local_listener(
|
||||
mask_host,
|
||||
mask_port,
|
||||
local_addr,
|
||||
resolved_override,
|
||||
resolved_addrs,
|
||||
&interfaces,
|
||||
)
|
||||
}
|
||||
@@ -849,7 +903,7 @@ async fn is_mask_target_local_listener_async(
|
||||
mask_host: &str,
|
||||
mask_port: u16,
|
||||
local_addr: SocketAddr,
|
||||
resolved_override: Option<SocketAddr>,
|
||||
resolved_addrs: &[SocketAddr],
|
||||
) -> bool {
|
||||
if mask_port != local_addr.port() {
|
||||
return false;
|
||||
@@ -860,7 +914,7 @@ async fn is_mask_target_local_listener_async(
|
||||
mask_host,
|
||||
mask_port,
|
||||
local_addr,
|
||||
resolved_override,
|
||||
resolved_addrs,
|
||||
&interfaces,
|
||||
)
|
||||
}
|
||||
@@ -904,7 +958,7 @@ fn configure_mask_backend_socket(stream: &TcpStream) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a bad client by forwarding to mask host
|
||||
/// Handles a bad client by forwarding it to the configured mask target.
|
||||
pub async fn handle_bad_client<R, W>(
|
||||
reader: R,
|
||||
writer: W,
|
||||
@@ -916,6 +970,34 @@ pub async fn handle_bad_client<R, W>(
|
||||
) where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let shared = ProxySharedState::new();
|
||||
handle_bad_client_with_shared(
|
||||
reader,
|
||||
writer,
|
||||
initial_data,
|
||||
peer,
|
||||
local_addr,
|
||||
config,
|
||||
beobachten,
|
||||
shared.as_ref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Handles a bad client with shared pre-auth fallback admission state.
|
||||
pub(crate) async fn handle_bad_client_with_shared<R, W>(
|
||||
reader: R,
|
||||
writer: W,
|
||||
initial_data: &[u8],
|
||||
peer: SocketAddr,
|
||||
local_addr: SocketAddr,
|
||||
config: &ProxyConfig,
|
||||
beobachten: &BeobachtenStore,
|
||||
shared: &ProxySharedState,
|
||||
) where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let client_type = detect_client_type(initial_data);
|
||||
if config.general.beobachten {
|
||||
@@ -938,6 +1020,17 @@ pub async fn handle_bad_client<R, W>(
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(_masking_permit) = shared.try_acquire_masking_fallback_permit() else {
|
||||
let outcome_started = Instant::now();
|
||||
debug!(
|
||||
client_type = client_type,
|
||||
"Masking fallback concurrency limit reached"
|
||||
);
|
||||
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let client_sni = tls::extract_sni_from_client_hello(initial_data);
|
||||
let exclusive_tcp_target = client_sni
|
||||
.as_deref()
|
||||
@@ -1000,24 +1093,12 @@ pub async fn handle_bad_client<R, W>(
|
||||
Ok(Err(e)) => {
|
||||
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
||||
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
relay_timeout,
|
||||
idle_timeout,
|
||||
)
|
||||
.await;
|
||||
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask unix socket");
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
relay_timeout,
|
||||
idle_timeout,
|
||||
)
|
||||
.await;
|
||||
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
}
|
||||
@@ -1030,11 +1111,27 @@ pub async fn handle_bad_client<R, W>(
|
||||
let mask_host = mask_target.host;
|
||||
let mask_port = mask_target.port;
|
||||
|
||||
let resolved_mask_addrs = match resolve_mask_target_addrs(mask_host, mask_port).await {
|
||||
Ok(addrs) => addrs,
|
||||
Err(e) => {
|
||||
let outcome_started = Instant::now();
|
||||
debug!(
|
||||
client_type = client_type,
|
||||
host = %mask_host,
|
||||
port = mask_port,
|
||||
error = %e,
|
||||
"Failed to resolve mask target"
|
||||
);
|
||||
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Fail closed when fallback points at our own listener endpoint.
|
||||
// Self-referential masking can create recursive proxy loops under
|
||||
// misconfiguration and leak distinguishable load spikes to adversaries.
|
||||
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
|
||||
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
|
||||
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, &resolved_mask_addrs)
|
||||
.await
|
||||
{
|
||||
let outcome_started = Instant::now();
|
||||
@@ -1045,13 +1142,7 @@ pub async fn handle_bad_client<R, W>(
|
||||
local = %local_addr,
|
||||
"Mask target resolves to local listener; refusing self-referential masking fallback"
|
||||
);
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
relay_timeout,
|
||||
idle_timeout,
|
||||
)
|
||||
.await;
|
||||
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
return;
|
||||
}
|
||||
@@ -1066,12 +1157,12 @@ pub async fn handle_bad_client<R, W>(
|
||||
"Forwarding bad client to mask host"
|
||||
);
|
||||
|
||||
// Apply runtime DNS override for mask target when configured.
|
||||
let mask_addr = resolved_mask_addr
|
||||
.map(|addr| addr.to_string())
|
||||
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
|
||||
let connect_started = Instant::now();
|
||||
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
|
||||
let connect_result = timeout(
|
||||
MASK_TIMEOUT,
|
||||
TcpStream::connect(resolved_mask_addrs.as_slice()),
|
||||
)
|
||||
.await;
|
||||
match connect_result {
|
||||
Ok(Ok(stream)) => {
|
||||
configure_mask_backend_socket(&stream);
|
||||
@@ -1113,24 +1204,12 @@ pub async fn handle_bad_client<R, W>(
|
||||
Ok(Err(e)) => {
|
||||
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
||||
debug!(error = %e, "Failed to connect to mask host");
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
relay_timeout,
|
||||
idle_timeout,
|
||||
)
|
||||
.await;
|
||||
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask host");
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
relay_timeout,
|
||||
idle_timeout,
|
||||
)
|
||||
.await;
|
||||
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,8 @@ use self::c2me::{
|
||||
};
|
||||
use self::d2c::{
|
||||
MeD2cFlushPolicy, MeWriterResponseOutcome, classify_me_d2c_flush_reason,
|
||||
flush_client_or_cancel, observe_me_d2c_flush_event,
|
||||
flush_client_or_cancel, me_d2c_flush_reason_requires_client_flush,
|
||||
observe_me_d2c_flush_event,
|
||||
process_me_writer_response_with_traffic_lease,
|
||||
};
|
||||
use self::desync::{RelayForensicsState, hash_ip_in, report_desync_frame_too_large_in};
|
||||
|
||||
@@ -55,6 +55,37 @@ pub(super) fn classify_me_d2c_flush_reason(
|
||||
MeD2cFlushReason::QueueDrain
|
||||
}
|
||||
|
||||
pub(super) fn me_d2c_flush_reason_requires_client_flush(reason: MeD2cFlushReason) -> bool {
|
||||
!matches!(reason, MeD2cFlushReason::QueueDrain)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn queue_drain_is_not_a_physical_flush_trigger() {
|
||||
assert!(!me_d2c_flush_reason_requires_client_flush(
|
||||
MeD2cFlushReason::QueueDrain
|
||||
));
|
||||
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||
MeD2cFlushReason::AckImmediate
|
||||
));
|
||||
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||
MeD2cFlushReason::BatchFrames
|
||||
));
|
||||
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||
MeD2cFlushReason::BatchBytes
|
||||
));
|
||||
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||
MeD2cFlushReason::MaxDelay
|
||||
));
|
||||
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||
MeD2cFlushReason::Close
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn observe_me_d2c_flush_event(
|
||||
stats: &Stats,
|
||||
reason: MeD2cFlushReason,
|
||||
|
||||
@@ -491,12 +491,18 @@ where
|
||||
d2c_flush_policy.max_bytes,
|
||||
max_delay_fired,
|
||||
);
|
||||
let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() {
|
||||
let physical_flush =
|
||||
me_d2c_flush_reason_requires_client_flush(flush_reason);
|
||||
let flush_started_at = if physical_flush
|
||||
&& stats_clone.telemetry_policy().me_level.allows_debug()
|
||||
{
|
||||
Some(Instant::now())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
|
||||
if physical_flush {
|
||||
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
|
||||
}
|
||||
let flush_duration_us = flush_started_at.map(|started| {
|
||||
started
|
||||
.elapsed()
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
|
||||
@@ -14,6 +14,7 @@ use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateReg
|
||||
use crate::proxy::traffic_limiter::TrafficLimiter;
|
||||
|
||||
const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64;
|
||||
const MASKING_FALLBACK_MAX_CONCURRENT: usize = 512;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum ConntrackCloseReason {
|
||||
@@ -72,6 +73,7 @@ pub(crate) struct ProxySharedState {
|
||||
active_user_sessions: DashMap<(String, u64), CancellationToken>,
|
||||
pub(crate) conntrack_pressure_active: AtomicBool,
|
||||
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
|
||||
masking_fallback_permits: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
#[must_use = "registered user sessions must be kept alive until relay completion"]
|
||||
@@ -131,9 +133,18 @@ impl ProxySharedState {
|
||||
active_user_sessions: DashMap::new(),
|
||||
conntrack_pressure_active: AtomicBool::new(false),
|
||||
conntrack_close_tx: Mutex::new(None),
|
||||
masking_fallback_permits: Arc::new(Semaphore::new(MASKING_FALLBACK_MAX_CONCURRENT)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Attempts to reserve one masking fallback slot for a pre-auth connection.
|
||||
pub(crate) fn try_acquire_masking_fallback_permit(&self) -> Option<OwnedSemaphorePermit> {
|
||||
self.masking_fallback_permits
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub(crate) fn is_user_enabled(&self, user: &str) -> bool {
|
||||
!self.disabled_users.contains_key(user)
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ fn loop_guard_unspecified_bind_uses_interface_inventory() {
|
||||
"mask.example",
|
||||
443,
|
||||
local,
|
||||
Some(resolved),
|
||||
&[resolved],
|
||||
&interfaces,
|
||||
));
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
|
||||
let barrier = std::sync::Arc::clone(&barrier);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
barrier.wait().await;
|
||||
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
|
||||
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await
|
||||
}));
|
||||
}
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_
|
||||
|
||||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||
|
||||
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
|
||||
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
|
||||
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await;
|
||||
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await;
|
||||
|
||||
assert_eq!(
|
||||
local_interface_enumerations_for_tests(),
|
||||
@@ -35,7 +35,7 @@ async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
|
||||
reset_local_interface_enumerations_for_tests();
|
||||
|
||||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
|
||||
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, &[]).await;
|
||||
|
||||
assert!(
|
||||
!is_local,
|
||||
|
||||
@@ -15,38 +15,49 @@ fn closed_local_port() -> u16 {
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_matches_literal_ipv4_listener() {
|
||||
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await);
|
||||
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, &[],).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_matches_bracketed_ipv6_listener() {
|
||||
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await);
|
||||
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, &[],).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
|
||||
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
|
||||
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await);
|
||||
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, &[],).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
|
||||
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await);
|
||||
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, &[],).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
|
||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await);
|
||||
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, &[],).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
|
||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
|
||||
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await);
|
||||
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, &[remote],).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_checks_all_resolved_addresses() {
|
||||
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
|
||||
let loopback: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
|
||||
assert!(
|
||||
is_mask_target_local_listener_async("mask.example", 443, local, &[remote, loopback],).await
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -11,16 +11,41 @@ use std::io::{Error, ErrorKind, Result};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
|
||||
|
||||
fn reject_oversize_frame(len: usize, max_frame_size: usize, protocol: &str) -> Result<()> {
|
||||
if len > max_frame_size {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("{protocol} frame too large: {len} bytes (max {max_frame_size})"),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============= Abridged (Compact) Frame =============
|
||||
|
||||
/// Reader for abridged MTProto framing
|
||||
pub struct AbridgedFrameReader<R> {
|
||||
upstream: R,
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl<R> AbridgedFrameReader<R> {
|
||||
/// Creates a reader with the default maximum frame size.
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,10 +73,12 @@ impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
|
||||
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
|
||||
}
|
||||
|
||||
// Length is in 4-byte words
|
||||
let byte_len = len * 4;
|
||||
// Length is in 4-byte words.
|
||||
let byte_len = len
|
||||
.checked_mul(4)
|
||||
.ok_or_else(|| Error::new(ErrorKind::InvalidData, "abridged frame length overflow"))?;
|
||||
reject_oversize_frame(byte_len, self.max_frame_size, "abridged")?;
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; byte_len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
@@ -152,11 +179,23 @@ impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
|
||||
/// Reader for intermediate MTProto framing
|
||||
pub struct IntermediateFrameReader<R> {
|
||||
upstream: R,
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl<R> IntermediateFrameReader<R> {
|
||||
/// Creates a reader with the default maximum frame size.
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,8 +210,8 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
|
||||
let header = parse_intermediate_header(len_bytes);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
reject_oversize_frame(len, self.max_frame_size, "intermediate")?;
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
@@ -243,11 +282,23 @@ impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
|
||||
/// Reader for secure intermediate MTProto framing (with padding)
|
||||
pub struct SecureIntermediateFrameReader<R> {
|
||||
upstream: R,
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl<R> SecureIntermediateFrameReader<R> {
|
||||
/// Creates a reader with the default maximum frame size.
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,17 +313,16 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
||||
let header = parse_intermediate_header(len_bytes);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Read data (including padding)
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
reject_oversize_frame(len, self.max_frame_size, "secure intermediate")?;
|
||||
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Invalid secure frame length: {len}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
data.truncate(payload_len);
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
@@ -321,7 +371,9 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
let padding_len = secure_padding_len(data.len(), &self.rng);
|
||||
let padding = self.rng.bytes(padding_len);
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
let total_len = data.len().checked_add(padding_len).ok_or_else(|| {
|
||||
Error::new(ErrorKind::InvalidInput, "secure frame length overflow")
|
||||
})?;
|
||||
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
@@ -507,15 +559,26 @@ pub enum FrameReaderKind<R> {
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
|
||||
/// Creates a frame reader with the default maximum frame size.
|
||||
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
|
||||
Self::with_max_frame_size(upstream, proto_tag, DEFAULT_MAX_FRAME_SIZE)
|
||||
}
|
||||
|
||||
fn with_max_frame_size(
|
||||
upstream: R,
|
||||
proto_tag: ProtoTag,
|
||||
max_frame_size: usize,
|
||||
) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
|
||||
ProtoTag::Intermediate => {
|
||||
FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream))
|
||||
}
|
||||
ProtoTag::Secure => {
|
||||
FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream))
|
||||
}
|
||||
ProtoTag::Abridged => FrameReaderKind::Abridged(
|
||||
AbridgedFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||
),
|
||||
ProtoTag::Intermediate => FrameReaderKind::Intermediate(
|
||||
IntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||
),
|
||||
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(
|
||||
SecureIntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -569,7 +632,8 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::crypto::SecureRandom;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::duplex;
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
|
||||
assert!(decoded.starts_with(original));
|
||||
@@ -672,6 +736,55 @@ mod tests {
|
||||
assert!(meta.quickack);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abridged_reader_rejects_oversize_frame_before_body_read() {
|
||||
let (mut client, server) = duplex(1024);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
let len_words = (DEFAULT_MAX_FRAME_SIZE / 4) + 1;
|
||||
let encoded = (len_words as u32).to_le_bytes();
|
||||
|
||||
client
|
||||
.write_all(&[0x7f, encoded[0], encoded[1], encoded[2]])
|
||||
.await
|
||||
.unwrap();
|
||||
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_err();
|
||||
|
||||
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn intermediate_reader_rejects_oversize_frame_before_body_read() {
|
||||
let (mut client, server) = duplex(1024);
|
||||
let mut reader = IntermediateFrameReader::new(server);
|
||||
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 1, false).unwrap();
|
||||
|
||||
client.write_all(&len.to_le_bytes()).await.unwrap();
|
||||
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_err();
|
||||
|
||||
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn secure_reader_rejects_oversize_frame_before_body_read() {
|
||||
let (mut client, server) = duplex(1024);
|
||||
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 4, false).unwrap();
|
||||
|
||||
client.write_all(&len.to_le_bytes()).await.unwrap();
|
||||
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_err();
|
||||
|
||||
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_secure_intermediate_padding() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
Reference in New Issue
Block a user