Harden masking fallback and frame readers after flow sync

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-06-17 21:48:57 +03:00
parent 49742d38a7
commit 72800e4aa7
13 changed files with 401 additions and 88 deletions

View File

@@ -1940,6 +1940,13 @@ impl ProxyConfig {
)); ));
} }
if config.server.listen_backlog == 0 || config.server.listen_backlog > i32::MAX as u32 {
return Err(ProxyError::Config(format!(
"server.listen_backlog must be within [1, {}]",
i32::MAX
)));
}
config config
.server .server
.client_mss_value() .client_mss_value()

View File

@@ -95,6 +95,44 @@ max_client_frame = 16777217
remove_temp_config(&path); remove_temp_config(&path);
} }
#[test]
fn load_rejects_listen_backlog_above_i32_upper_bound() {
let path = write_temp_config(
r#"
[server]
listen_backlog = 2147483648
"#,
);
let err = ProxyConfig::load(&path).expect_err("listen_backlog above socket cap must fail");
let msg = err.to_string();
assert!(
msg.contains("server.listen_backlog must be within [1, 2147483647]"),
"error must explain listen_backlog hard cap, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_zero_listen_backlog() {
let path = write_temp_config(
r#"
[server]
listen_backlog = 0
"#,
);
let err = ProxyConfig::load(&path).expect_err("zero listen_backlog must fail");
let msg = err.to_string();
assert!(
msg.contains("server.listen_backlog must be within [1, 2147483647]"),
"error must explain listen_backlog lower bound, got: {msg}"
);
remove_temp_config(&path);
}
#[test] #[test]
fn load_accepts_memory_limits_at_hard_upper_bounds() { fn load_accepts_memory_limits_at_hard_upper_bounds() {
let path = write_temp_config( let path = write_temp_config(

View File

@@ -113,7 +113,7 @@ use crate::proxy::handshake::{
}; };
#[cfg(test)] #[cfg(test)]
use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake}; use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake};
use crate::proxy::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client_with_shared;
use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::middle_relay::handle_via_middle_proxy;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::proxy::shared_state::ProxySharedState; use crate::proxy::shared_state::ProxySharedState;
@@ -310,6 +310,7 @@ fn masking_outcome<R, W>(
local_addr: SocketAddr, local_addr: SocketAddr,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
beobachten: Arc<BeobachtenStore>, beobachten: Arc<BeobachtenStore>,
shared: Arc<ProxySharedState>,
) -> HandshakeOutcome ) -> HandshakeOutcome
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@@ -325,7 +326,7 @@ where
) )
.await; .await;
handle_bad_client( handle_bad_client_with_shared(
reader, reader,
writer, writer,
&initial_data, &initial_data,
@@ -333,6 +334,7 @@ where
local_addr, local_addr,
&config, &config,
&beobachten, &beobachten,
shared.as_ref(),
) )
.await; .await;
Ok(()) Ok(())
@@ -718,6 +720,7 @@ where
local_addr, local_addr,
config.clone(), config.clone(),
beobachten.clone(), beobachten.clone(),
shared.clone(),
)); ));
} }
@@ -739,6 +742,7 @@ where
local_addr, local_addr,
config.clone(), config.clone(),
beobachten.clone(), beobachten.clone(),
shared.clone(),
)); ));
} }
}; };
@@ -757,6 +761,7 @@ where
local_addr, local_addr,
config.clone(), config.clone(),
beobachten.clone(), beobachten.clone(),
shared.clone(),
)); ));
} }
@@ -787,6 +792,7 @@ where
local_addr, local_addr,
config.clone(), config.clone(),
beobachten.clone(), beobachten.clone(),
shared.clone(),
)); ));
} }
HandshakeResult::Error(e) => { HandshakeResult::Error(e) => {
@@ -844,6 +850,7 @@ where
local_addr, local_addr,
config.clone(), config.clone(),
beobachten.clone(), beobachten.clone(),
shared.clone(),
)); ));
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
@@ -873,6 +880,7 @@ where
local_addr, local_addr,
config.clone(), config.clone(),
beobachten.clone(), beobachten.clone(),
shared.clone(),
)); ));
} }
@@ -898,6 +906,7 @@ where
local_addr, local_addr,
config.clone(), config.clone(),
beobachten.clone(), beobachten.clone(),
shared.clone(),
)); ));
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
@@ -1329,6 +1338,7 @@ impl RunningClientHandler {
local_addr, local_addr,
self.config.clone(), self.config.clone(),
self.beobachten.clone(), self.beobachten.clone(),
self.shared.clone(),
)); ));
} }
@@ -1350,6 +1360,7 @@ impl RunningClientHandler {
local_addr, local_addr,
self.config.clone(), self.config.clone(),
self.beobachten.clone(), self.beobachten.clone(),
self.shared.clone(),
)); ));
} }
}; };
@@ -1369,6 +1380,7 @@ impl RunningClientHandler {
local_addr, local_addr,
self.config.clone(), self.config.clone(),
self.beobachten.clone(), self.beobachten.clone(),
self.shared.clone(),
)); ));
} }
@@ -1416,6 +1428,7 @@ impl RunningClientHandler {
local_addr, local_addr,
config.clone(), config.clone(),
self.beobachten.clone(), self.beobachten.clone(),
self.shared.clone(),
)); ));
} }
HandshakeResult::Error(e) => { HandshakeResult::Error(e) => {
@@ -1483,6 +1496,7 @@ impl RunningClientHandler {
local_addr, local_addr,
config.clone(), config.clone(),
self.beobachten.clone(), self.beobachten.clone(),
self.shared.clone(),
)); ));
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
@@ -1530,6 +1544,7 @@ impl RunningClientHandler {
local_addr, local_addr,
self.config.clone(), self.config.clone(),
self.beobachten.clone(), self.beobachten.clone(),
self.shared.clone(),
)); ));
} }
@@ -1568,6 +1583,7 @@ impl RunningClientHandler {
local_addr, local_addr,
config.clone(), config.clone(),
self.beobachten.clone(), self.beobachten.clone(),
self.shared.clone(),
)); ));
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),

View File

@@ -3,6 +3,7 @@
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr; use crate::network::dns_overrides::resolve_socket_addr;
use crate::protocol::tls; use crate::protocol::tls;
use crate::proxy::shared_state::ProxySharedState;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use crate::transport::socket::configure_tcp_socket; use crate::transport::socket::configure_tcp_socket;
@@ -10,6 +11,7 @@ use crate::transport::socket::configure_tcp_socket;
use nix::ifaddrs::getifaddrs; use nix::ifaddrs::getifaddrs;
use rand::rngs::StdRng; use rand::rngs::StdRng;
use rand::{Rng, RngExt, SeedableRng}; use rand::{Rng, RngExt, SeedableRng};
use std::io::{Error as IoError, ErrorKind};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::str; use std::str;
#[cfg(test)] #[cfg(test)]
@@ -18,7 +20,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant as StdInstant}; use std::time::{Duration, Instant as StdInstant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::{TcpStream, lookup_host};
#[cfg(unix)] #[cfg(unix)]
use tokio::net::UnixStream; use tokio::net::UnixStream;
#[cfg(unix)] #[cfg(unix)]
@@ -271,6 +273,32 @@ async fn consume_client_data_with_timeout_and_cap<R>(
} }
} }
fn mask_failure_drain_cap(config: &ProxyConfig) -> usize {
let configured_cap = config.censorship.mask_relay_max_bytes;
if configured_cap == 0 {
return MASK_BUFFER_SIZE;
}
configured_cap.min(MASK_BUFFER_SIZE)
}
async fn consume_mask_failure_path<R>(
reader: R,
config: &ProxyConfig,
relay_timeout: Duration,
idle_timeout: Duration,
) where
R: AsyncRead + Unpin,
{
consume_client_data_with_timeout_and_cap(
reader,
mask_failure_drain_cap(config),
relay_timeout,
idle_timeout,
)
.await;
}
async fn wait_mask_connect_budget(started: Instant) { async fn wait_mask_connect_budget(started: Instant) {
let elapsed = started.elapsed(); let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT { if elapsed < MASK_TIMEOUT {
@@ -501,6 +529,32 @@ fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
host.parse::<IpAddr>().ok() host.parse::<IpAddr>().ok()
} }
async fn resolve_mask_target_addrs(
mask_host: &str,
mask_port: u16,
) -> std::io::Result<Vec<SocketAddr>> {
if let Some(addr) = resolve_socket_addr(mask_host, mask_port) {
return Ok(vec![addr]);
}
if let Some(ip) = parse_mask_host_ip_literal(mask_host) {
return Ok(vec![SocketAddr::new(ip, mask_port)]);
}
let addrs = timeout(MASK_TIMEOUT, lookup_host((mask_host, mask_port)))
.await
.map_err(|_| IoError::new(ErrorKind::TimedOut, "mask target DNS lookup timed out"))??;
let addrs = addrs.collect::<Vec<_>>();
if addrs.is_empty() {
return Err(IoError::new(
ErrorKind::NotFound,
"mask target DNS lookup returned no addresses",
));
}
Ok(addrs)
}
fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> { fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) { if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
return Some(config.censorship.tls_domain.as_str()); return Some(config.censorship.tls_domain.as_str());
@@ -782,7 +836,7 @@ fn is_mask_target_local_listener_with_interfaces(
mask_host: &str, mask_host: &str,
mask_port: u16, mask_port: u16,
local_addr: SocketAddr, local_addr: SocketAddr,
resolved_override: Option<SocketAddr>, resolved_addrs: &[SocketAddr],
interface_ips: &[IpAddr], interface_ips: &[IpAddr],
) -> bool { ) -> bool {
if mask_port != local_addr.port() { if mask_port != local_addr.port() {
@@ -792,7 +846,7 @@ fn is_mask_target_local_listener_with_interfaces(
let local_ip = canonical_ip(local_addr.ip()); let local_ip = canonical_ip(local_addr.ip());
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip); let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
if let Some(addr) = resolved_override { for addr in resolved_addrs {
let resolved_ip = canonical_ip(addr.ip()); let resolved_ip = canonical_ip(addr.ip());
if resolved_ip == local_ip { if resolved_ip == local_ip {
return true; return true;
@@ -829,7 +883,7 @@ fn is_mask_target_local_listener(
mask_host: &str, mask_host: &str,
mask_port: u16, mask_port: u16,
local_addr: SocketAddr, local_addr: SocketAddr,
resolved_override: Option<SocketAddr>, resolved_addrs: &[SocketAddr],
) -> bool { ) -> bool {
if mask_port != local_addr.port() { if mask_port != local_addr.port() {
return false; return false;
@@ -840,7 +894,7 @@ fn is_mask_target_local_listener(
mask_host, mask_host,
mask_port, mask_port,
local_addr, local_addr,
resolved_override, resolved_addrs,
&interfaces, &interfaces,
) )
} }
@@ -849,7 +903,7 @@ async fn is_mask_target_local_listener_async(
mask_host: &str, mask_host: &str,
mask_port: u16, mask_port: u16,
local_addr: SocketAddr, local_addr: SocketAddr,
resolved_override: Option<SocketAddr>, resolved_addrs: &[SocketAddr],
) -> bool { ) -> bool {
if mask_port != local_addr.port() { if mask_port != local_addr.port() {
return false; return false;
@@ -860,7 +914,7 @@ async fn is_mask_target_local_listener_async(
mask_host, mask_host,
mask_port, mask_port,
local_addr, local_addr,
resolved_override, resolved_addrs,
&interfaces, &interfaces,
) )
} }
@@ -904,7 +958,7 @@ fn configure_mask_backend_socket(stream: &TcpStream) {
} }
} }
/// Handle a bad client by forwarding to mask host /// Handles a bad client by forwarding it to the configured mask target.
pub async fn handle_bad_client<R, W>( pub async fn handle_bad_client<R, W>(
reader: R, reader: R,
writer: W, writer: W,
@@ -916,6 +970,34 @@ pub async fn handle_bad_client<R, W>(
) where ) where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
{
let shared = ProxySharedState::new();
handle_bad_client_with_shared(
reader,
writer,
initial_data,
peer,
local_addr,
config,
beobachten,
shared.as_ref(),
)
.await;
}
/// Handles a bad client with shared pre-auth fallback admission state.
pub(crate) async fn handle_bad_client_with_shared<R, W>(
reader: R,
writer: W,
initial_data: &[u8],
peer: SocketAddr,
local_addr: SocketAddr,
config: &ProxyConfig,
beobachten: &BeobachtenStore,
shared: &ProxySharedState,
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{ {
let client_type = detect_client_type(initial_data); let client_type = detect_client_type(initial_data);
if config.general.beobachten { if config.general.beobachten {
@@ -938,6 +1020,17 @@ pub async fn handle_bad_client<R, W>(
return; return;
} }
let Some(_masking_permit) = shared.try_acquire_masking_fallback_permit() else {
let outcome_started = Instant::now();
debug!(
client_type = client_type,
"Masking fallback concurrency limit reached"
);
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
};
let client_sni = tls::extract_sni_from_client_hello(initial_data); let client_sni = tls::extract_sni_from_client_hello(initial_data);
let exclusive_tcp_target = client_sni let exclusive_tcp_target = client_sni
.as_deref() .as_deref()
@@ -1000,24 +1093,12 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => { Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await; wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask unix socket"); debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout_and_cap( consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await; wait_mask_outcome_budget(outcome_started, config).await;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask unix socket"); debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout_and_cap( consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await; wait_mask_outcome_budget(outcome_started, config).await;
} }
} }
@@ -1030,11 +1111,27 @@ pub async fn handle_bad_client<R, W>(
let mask_host = mask_target.host; let mask_host = mask_target.host;
let mask_port = mask_target.port; let mask_port = mask_target.port;
let resolved_mask_addrs = match resolve_mask_target_addrs(mask_host, mask_port).await {
Ok(addrs) => addrs,
Err(e) => {
let outcome_started = Instant::now();
debug!(
client_type = client_type,
host = %mask_host,
port = mask_port,
error = %e,
"Failed to resolve mask target"
);
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
};
// Fail closed when fallback points at our own listener endpoint. // Fail closed when fallback points at our own listener endpoint.
// Self-referential masking can create recursive proxy loops under // Self-referential masking can create recursive proxy loops under
// misconfiguration and leak distinguishable load spikes to adversaries. // misconfiguration and leak distinguishable load spikes to adversaries.
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port); if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, &resolved_mask_addrs)
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
.await .await
{ {
let outcome_started = Instant::now(); let outcome_started = Instant::now();
@@ -1045,13 +1142,7 @@ pub async fn handle_bad_client<R, W>(
local = %local_addr, local = %local_addr,
"Mask target resolves to local listener; refusing self-referential masking fallback" "Mask target resolves to local listener; refusing self-referential masking fallback"
); );
consume_client_data_with_timeout_and_cap( consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await; wait_mask_outcome_budget(outcome_started, config).await;
return; return;
} }
@@ -1066,12 +1157,12 @@ pub async fn handle_bad_client<R, W>(
"Forwarding bad client to mask host" "Forwarding bad client to mask host"
); );
// Apply runtime DNS override for mask target when configured.
let mask_addr = resolved_mask_addr
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let connect_started = Instant::now(); let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; let connect_result = timeout(
MASK_TIMEOUT,
TcpStream::connect(resolved_mask_addrs.as_slice()),
)
.await;
match connect_result { match connect_result {
Ok(Ok(stream)) => { Ok(Ok(stream)) => {
configure_mask_backend_socket(&stream); configure_mask_backend_socket(&stream);
@@ -1113,24 +1204,12 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => { Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await; wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask host"); debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout_and_cap( consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await; wait_mask_outcome_budget(outcome_started, config).await;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask host"); debug!("Timeout connecting to mask host");
consume_client_data_with_timeout_and_cap( consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await; wait_mask_outcome_budget(outcome_started, config).await;
} }
} }

View File

@@ -52,7 +52,8 @@ use self::c2me::{
}; };
use self::d2c::{ use self::d2c::{
MeD2cFlushPolicy, MeWriterResponseOutcome, classify_me_d2c_flush_reason, MeD2cFlushPolicy, MeWriterResponseOutcome, classify_me_d2c_flush_reason,
flush_client_or_cancel, observe_me_d2c_flush_event, flush_client_or_cancel, me_d2c_flush_reason_requires_client_flush,
observe_me_d2c_flush_event,
process_me_writer_response_with_traffic_lease, process_me_writer_response_with_traffic_lease,
}; };
use self::desync::{RelayForensicsState, hash_ip_in, report_desync_frame_too_large_in}; use self::desync::{RelayForensicsState, hash_ip_in, report_desync_frame_too_large_in};

View File

@@ -55,6 +55,37 @@ pub(super) fn classify_me_d2c_flush_reason(
MeD2cFlushReason::QueueDrain MeD2cFlushReason::QueueDrain
} }
pub(super) fn me_d2c_flush_reason_requires_client_flush(reason: MeD2cFlushReason) -> bool {
!matches!(reason, MeD2cFlushReason::QueueDrain)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn queue_drain_is_not_a_physical_flush_trigger() {
assert!(!me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::QueueDrain
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::AckImmediate
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::BatchFrames
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::BatchBytes
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::MaxDelay
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::Close
));
}
}
pub(super) fn observe_me_d2c_flush_event( pub(super) fn observe_me_d2c_flush_event(
stats: &Stats, stats: &Stats,
reason: MeD2cFlushReason, reason: MeD2cFlushReason,

View File

@@ -491,12 +491,18 @@ where
d2c_flush_policy.max_bytes, d2c_flush_policy.max_bytes,
max_delay_fired, max_delay_fired,
); );
let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { let physical_flush =
me_d2c_flush_reason_requires_client_flush(flush_reason);
let flush_started_at = if physical_flush
&& stats_clone.telemetry_policy().me_level.allows_debug()
{
Some(Instant::now()) Some(Instant::now())
} else { } else {
None None
}; };
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?; if physical_flush {
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
}
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()

View File

@@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
use std::time::Instant; use std::time::Instant;
use dashmap::DashMap; use dashmap::DashMap;
use tokio::sync::mpsc; use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState}; use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
@@ -14,6 +14,7 @@ use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateReg
use crate::proxy::traffic_limiter::TrafficLimiter; use crate::proxy::traffic_limiter::TrafficLimiter;
const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64; const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64;
const MASKING_FALLBACK_MAX_CONCURRENT: usize = 512;
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ConntrackCloseReason { pub(crate) enum ConntrackCloseReason {
@@ -72,6 +73,7 @@ pub(crate) struct ProxySharedState {
active_user_sessions: DashMap<(String, u64), CancellationToken>, active_user_sessions: DashMap<(String, u64), CancellationToken>,
pub(crate) conntrack_pressure_active: AtomicBool, pub(crate) conntrack_pressure_active: AtomicBool,
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>, pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
masking_fallback_permits: Arc<Semaphore>,
} }
#[must_use = "registered user sessions must be kept alive until relay completion"] #[must_use = "registered user sessions must be kept alive until relay completion"]
@@ -131,9 +133,18 @@ impl ProxySharedState {
active_user_sessions: DashMap::new(), active_user_sessions: DashMap::new(),
conntrack_pressure_active: AtomicBool::new(false), conntrack_pressure_active: AtomicBool::new(false),
conntrack_close_tx: Mutex::new(None), conntrack_close_tx: Mutex::new(None),
masking_fallback_permits: Arc::new(Semaphore::new(MASKING_FALLBACK_MAX_CONCURRENT)),
}) })
} }
/// Attempts to reserve one masking fallback slot for a pre-auth connection.
pub(crate) fn try_acquire_masking_fallback_permit(&self) -> Option<OwnedSemaphorePermit> {
self.masking_fallback_permits
.clone()
.try_acquire_owned()
.ok()
}
pub(crate) fn is_user_enabled(&self, user: &str) -> bool { pub(crate) fn is_user_enabled(&self, user: &str) -> bool {
!self.disabled_users.contains_key(user) !self.disabled_users.contains_key(user)
} }

View File

@@ -34,7 +34,7 @@ fn loop_guard_unspecified_bind_uses_interface_inventory() {
"mask.example", "mask.example",
443, 443,
local, local,
Some(resolved), &[resolved],
&interfaces, &interfaces,
)); ));
} }

View File

@@ -25,7 +25,7 @@ async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
let barrier = std::sync::Arc::clone(&barrier); let barrier = std::sync::Arc::clone(&barrier);
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
barrier.wait().await; barrier.wait().await;
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await
})); }));
} }

View File

@@ -17,8 +17,8 @@ async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await;
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await;
assert_eq!( assert_eq!(
local_interface_enumerations_for_tests(), local_interface_enumerations_for_tests(),
@@ -35,7 +35,7 @@ async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
reset_local_interface_enumerations_for_tests(); reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, &[]).await;
assert!( assert!(
!is_local, !is_local,

View File

@@ -15,38 +15,49 @@ fn closed_local_port() -> u16 {
#[tokio::test] #[tokio::test]
async fn self_target_detection_matches_literal_ipv4_listener() { async fn self_target_detection_matches_literal_ipv4_listener() {
let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await); assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, &[],).await);
} }
#[tokio::test] #[tokio::test]
async fn self_target_detection_matches_bracketed_ipv6_listener() { async fn self_target_detection_matches_bracketed_ipv6_listener() {
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await); assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, &[],).await);
} }
#[tokio::test] #[tokio::test]
async fn self_target_detection_keeps_same_ip_different_port_forwardable() { async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await); assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, &[],).await);
} }
#[tokio::test] #[tokio::test]
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await); assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, &[],).await);
} }
#[tokio::test] #[tokio::test]
async fn self_target_detection_unspecified_bind_blocks_loopback_target() { async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await); assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, &[],).await);
} }
#[tokio::test] #[tokio::test]
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await); assert!(!is_mask_target_local_listener_async("mask.example", 443, local, &[remote],).await);
}
#[tokio::test]
async fn self_target_detection_checks_all_resolved_addresses() {
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
let loopback: SocketAddr = "127.0.0.1:443".parse().unwrap();
assert!(
is_mask_target_local_listener_async("mask.example", 443, local, &[remote, loopback],).await
);
} }
#[tokio::test] #[tokio::test]

View File

@@ -11,16 +11,41 @@ use std::io::{Error, ErrorKind, Result};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
fn reject_oversize_frame(len: usize, max_frame_size: usize, protocol: &str) -> Result<()> {
if len > max_frame_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("{protocol} frame too large: {len} bytes (max {max_frame_size})"),
));
}
Ok(())
}
// ============= Abridged (Compact) Frame ============= // ============= Abridged (Compact) Frame =============
/// Reader for abridged MTProto framing /// Reader for abridged MTProto framing
pub struct AbridgedFrameReader<R> { pub struct AbridgedFrameReader<R> {
upstream: R, upstream: R,
max_frame_size: usize,
} }
impl<R> AbridgedFrameReader<R> { impl<R> AbridgedFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self { pub fn new(upstream: R) -> Self {
Self { upstream } Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
} }
} }
@@ -48,10 +73,12 @@ impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize; len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
} }
// Length is in 4-byte words // Length is in 4-byte words.
let byte_len = len * 4; let byte_len = len
.checked_mul(4)
.ok_or_else(|| Error::new(ErrorKind::InvalidData, "abridged frame length overflow"))?;
reject_oversize_frame(byte_len, self.max_frame_size, "abridged")?;
// Read data
let mut data = vec![0u8; byte_len]; let mut data = vec![0u8; byte_len];
self.upstream.read_exact(&mut data).await?; self.upstream.read_exact(&mut data).await?;
@@ -152,11 +179,23 @@ impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
/// Reader for intermediate MTProto framing /// Reader for intermediate MTProto framing
pub struct IntermediateFrameReader<R> { pub struct IntermediateFrameReader<R> {
upstream: R, upstream: R,
max_frame_size: usize,
} }
impl<R> IntermediateFrameReader<R> { impl<R> IntermediateFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self { pub fn new(upstream: R) -> Self {
Self { upstream } Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
} }
} }
@@ -171,8 +210,8 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
let header = parse_intermediate_header(len_bytes); let header = parse_intermediate_header(len_bytes);
let len = header.wire_len; let len = header.wire_len;
meta.quickack = header.quickack; meta.quickack = header.quickack;
reject_oversize_frame(len, self.max_frame_size, "intermediate")?;
// Read data
let mut data = vec![0u8; len]; let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?; self.upstream.read_exact(&mut data).await?;
@@ -243,11 +282,23 @@ impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
/// Reader for secure intermediate MTProto framing (with padding) /// Reader for secure intermediate MTProto framing (with padding)
pub struct SecureIntermediateFrameReader<R> { pub struct SecureIntermediateFrameReader<R> {
upstream: R, upstream: R,
max_frame_size: usize,
} }
impl<R> SecureIntermediateFrameReader<R> { impl<R> SecureIntermediateFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self { pub fn new(upstream: R) -> Self {
Self { upstream } Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
} }
} }
@@ -262,17 +313,16 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
let header = parse_intermediate_header(len_bytes); let header = parse_intermediate_header(len_bytes);
let len = header.wire_len; let len = header.wire_len;
meta.quickack = header.quickack; meta.quickack = header.quickack;
reject_oversize_frame(len, self.max_frame_size, "secure intermediate")?;
// Read data (including padding)
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| { let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
Error::new( Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
format!("Invalid secure frame length: {len}"), format!("Invalid secure frame length: {len}"),
) )
})?; })?;
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
data.truncate(payload_len); data.truncate(payload_len);
Ok((Bytes::from(data), meta)) Ok((Bytes::from(data), meta))
@@ -321,7 +371,9 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
let padding_len = secure_padding_len(data.len(), &self.rng); let padding_len = secure_padding_len(data.len(), &self.rng);
let padding = self.rng.bytes(padding_len); let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len; let total_len = data.len().checked_add(padding_len).ok_or_else(|| {
Error::new(ErrorKind::InvalidInput, "secure frame length overflow")
})?;
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| { let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
Error::new( Error::new(
ErrorKind::InvalidInput, ErrorKind::InvalidInput,
@@ -507,15 +559,26 @@ pub enum FrameReaderKind<R> {
} }
impl<R: AsyncRead + Unpin> FrameReaderKind<R> { impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
/// Creates a frame reader with the default maximum frame size.
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self { pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
Self::with_max_frame_size(upstream, proto_tag, DEFAULT_MAX_FRAME_SIZE)
}
fn with_max_frame_size(
upstream: R,
proto_tag: ProtoTag,
max_frame_size: usize,
) -> Self {
match proto_tag { match proto_tag {
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)), ProtoTag::Abridged => FrameReaderKind::Abridged(
ProtoTag::Intermediate => { AbridgedFrameReader::with_max_frame_size(upstream, max_frame_size),
FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)) ),
} ProtoTag::Intermediate => FrameReaderKind::Intermediate(
ProtoTag::Secure => { IntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)) ),
} ProtoTag::Secure => FrameReaderKind::SecureIntermediate(
SecureIntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
),
} }
} }
@@ -569,7 +632,8 @@ mod tests {
use super::*; use super::*;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::duplex; use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) { fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
assert!(decoded.starts_with(original)); assert!(decoded.starts_with(original));
@@ -672,6 +736,55 @@ mod tests {
assert!(meta.quickack); assert!(meta.quickack);
} }
#[tokio::test]
async fn abridged_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = AbridgedFrameReader::new(server);
let len_words = (DEFAULT_MAX_FRAME_SIZE / 4) + 1;
let encoded = (len_words as u32).to_le_bytes();
client
.write_all(&[0x7f, encoded[0], encoded[1], encoded[2]])
.await
.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test]
async fn intermediate_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = IntermediateFrameReader::new(server);
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 1, false).unwrap();
client.write_all(&len.to_le_bytes()).await.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test]
async fn secure_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = SecureIntermediateFrameReader::new(server);
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 4, false).unwrap();
client.write_all(&len.to_le_bytes()).await.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test] #[tokio::test]
async fn test_secure_intermediate_padding() { async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024); let (client, server) = duplex(1024);