From 72800e4aa7dbfe687654277f1c3dcdc3989c18ac Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Wed, 17 Jun 2026 21:48:57 +0300 Subject: [PATCH] Harden masking fallback and frame readers after flow sync Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/load.rs | 7 + .../tests/load_memory_envelope_tests.rs | 38 ++++ src/proxy/client.rs | 20 +- src/proxy/masking.rs | 179 +++++++++++++----- src/proxy/middle_relay.rs | 3 +- src/proxy/middle_relay/d2c.rs | 31 +++ src/proxy/middle_relay/session.rs | 10 +- src/proxy/shared_state.rs | 13 +- ...ing_additional_hardening_security_tests.rs | 2 +- ...erface_cache_concurrency_security_tests.rs | 2 +- .../masking_interface_cache_security_tests.rs | 6 +- ...masking_self_target_loop_security_tests.rs | 23 ++- src/stream/frame_stream.rs | 155 +++++++++++++-- 13 files changed, 401 insertions(+), 88 deletions(-) diff --git a/src/config/load.rs b/src/config/load.rs index 7b976c8..a89d42f 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1940,6 +1940,13 @@ impl ProxyConfig { )); } + if config.server.listen_backlog == 0 || config.server.listen_backlog > i32::MAX as u32 { + return Err(ProxyError::Config(format!( + "server.listen_backlog must be within [1, {}]", + i32::MAX + ))); + } + config .server .client_mss_value() diff --git a/src/config/tests/load_memory_envelope_tests.rs b/src/config/tests/load_memory_envelope_tests.rs index ea78498..06681d3 100644 --- a/src/config/tests/load_memory_envelope_tests.rs +++ b/src/config/tests/load_memory_envelope_tests.rs @@ -95,6 +95,44 @@ max_client_frame = 16777217 remove_temp_config(&path); } +#[test] +fn load_rejects_listen_backlog_above_i32_upper_bound() { + let path = write_temp_config( + r#" +[server] +listen_backlog = 2147483648 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("listen_backlog above socket cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("server.listen_backlog must be within [1, 2147483647]"), + "error must explain listen_backlog hard cap, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_zero_listen_backlog() { + let path = write_temp_config( + r#" +[server] +listen_backlog = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("zero listen_backlog must fail"); + let msg = err.to_string(); + assert!( + msg.contains("server.listen_backlog must be within [1, 2147483647]"), + "error must explain listen_backlog lower bound, got: {msg}" + ); + + remove_temp_config(&path); +} + #[test] fn load_accepts_memory_limits_at_hard_upper_bounds() { let path = write_temp_config( diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 34b540b..a180e07 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -113,7 +113,7 @@ use crate::proxy::handshake::{ }; #[cfg(test)] use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake}; -use crate::proxy::masking::handle_bad_client; +use crate::proxy::masking::handle_bad_client_with_shared; use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::proxy::shared_state::ProxySharedState; @@ -310,6 +310,7 @@ fn masking_outcome( local_addr: SocketAddr, config: Arc, beobachten: Arc, + shared: Arc, ) -> HandshakeOutcome where R: AsyncRead + Unpin + Send + 'static, @@ -325,7 +326,7 @@ where ) .await; - handle_bad_client( + handle_bad_client_with_shared( reader, writer, &initial_data, @@ -333,6 +334,7 @@ where local_addr, &config, &beobachten, + shared.as_ref(), ) .await; Ok(()) @@ -718,6 +720,7 @@ where local_addr, config.clone(), beobachten.clone(), + shared.clone(), )); } @@ -739,6 +742,7 @@ where local_addr, config.clone(), beobachten.clone(), + shared.clone(), )); } }; @@ -757,6 +761,7 @@ where local_addr, config.clone(), beobachten.clone(), + shared.clone(), )); } @@ -787,6 +792,7 @@ where local_addr, config.clone(), beobachten.clone(), + shared.clone(), )); } HandshakeResult::Error(e) => { @@ -844,6 +850,7 @@ where local_addr, config.clone(), beobachten.clone(), + shared.clone(), )); } HandshakeResult::Error(e) => return Err(e), @@ -873,6 +880,7 @@ where local_addr, config.clone(), beobachten.clone(), + shared.clone(), )); } @@ -898,6 +906,7 @@ where local_addr, config.clone(), beobachten.clone(), + shared.clone(), )); } HandshakeResult::Error(e) => return Err(e), @@ -1329,6 +1338,7 @@ impl RunningClientHandler { local_addr, self.config.clone(), self.beobachten.clone(), + self.shared.clone(), )); } @@ -1350,6 +1360,7 @@ impl RunningClientHandler { local_addr, self.config.clone(), self.beobachten.clone(), + self.shared.clone(), )); } }; @@ -1369,6 +1380,7 @@ impl RunningClientHandler { local_addr, self.config.clone(), self.beobachten.clone(), + self.shared.clone(), )); } @@ -1416,6 +1428,7 @@ impl RunningClientHandler { local_addr, config.clone(), self.beobachten.clone(), + self.shared.clone(), )); } HandshakeResult::Error(e) => { @@ -1483,6 +1496,7 @@ impl RunningClientHandler { local_addr, config.clone(), self.beobachten.clone(), + self.shared.clone(), )); } HandshakeResult::Error(e) => return Err(e), @@ -1530,6 +1544,7 @@ impl RunningClientHandler { local_addr, self.config.clone(), self.beobachten.clone(), + self.shared.clone(), )); } @@ -1568,6 +1583,7 @@ impl RunningClientHandler { local_addr, config.clone(), self.beobachten.clone(), + self.shared.clone(), )); } HandshakeResult::Error(e) => return Err(e), diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index fa29529..fbc766a 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -3,6 +3,7 @@ use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::tls; +use crate::proxy::shared_state::ProxySharedState; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; use crate::transport::socket::configure_tcp_socket; @@ -10,6 +11,7 @@ use crate::transport::socket::configure_tcp_socket; use nix::ifaddrs::getifaddrs; use rand::rngs::StdRng; use rand::{Rng, RngExt, SeedableRng}; +use std::io::{Error as IoError, ErrorKind}; use std::net::{IpAddr, SocketAddr}; use std::str; #[cfg(test)] @@ -18,7 +20,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Mutex, OnceLock}; use std::time::{Duration, Instant as StdInstant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::net::TcpStream; +use tokio::net::{TcpStream, lookup_host}; #[cfg(unix)] use tokio::net::UnixStream; #[cfg(unix)] @@ -271,6 +273,32 @@ async fn consume_client_data_with_timeout_and_cap( } } +fn mask_failure_drain_cap(config: &ProxyConfig) -> usize { + let configured_cap = config.censorship.mask_relay_max_bytes; + if configured_cap == 0 { + return MASK_BUFFER_SIZE; + } + + configured_cap.min(MASK_BUFFER_SIZE) +} + +async fn consume_mask_failure_path( + reader: R, + config: &ProxyConfig, + relay_timeout: Duration, + idle_timeout: Duration, +) where + R: AsyncRead + Unpin, +{ + consume_client_data_with_timeout_and_cap( + reader, + mask_failure_drain_cap(config), + relay_timeout, + idle_timeout, + ) + .await; +} + async fn wait_mask_connect_budget(started: Instant) { let elapsed = started.elapsed(); if elapsed < MASK_TIMEOUT { @@ -501,6 +529,32 @@ fn parse_mask_host_ip_literal(host: &str) -> Option { host.parse::().ok() } +async fn resolve_mask_target_addrs( + mask_host: &str, + mask_port: u16, +) -> std::io::Result> { + if let Some(addr) = resolve_socket_addr(mask_host, mask_port) { + return Ok(vec![addr]); + } + + if let Some(ip) = parse_mask_host_ip_literal(mask_host) { + return Ok(vec![SocketAddr::new(ip, mask_port)]); + } + + let addrs = timeout(MASK_TIMEOUT, lookup_host((mask_host, mask_port))) + .await + .map_err(|_| IoError::new(ErrorKind::TimedOut, "mask target DNS lookup timed out"))??; + let addrs = addrs.collect::>(); + if addrs.is_empty() { + return Err(IoError::new( + ErrorKind::NotFound, + "mask target DNS lookup returned no addresses", + )); + } + + Ok(addrs) +} + fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> { if config.censorship.tls_domain.eq_ignore_ascii_case(sni) { return Some(config.censorship.tls_domain.as_str()); @@ -782,7 +836,7 @@ fn is_mask_target_local_listener_with_interfaces( mask_host: &str, mask_port: u16, local_addr: SocketAddr, - resolved_override: Option, + resolved_addrs: &[SocketAddr], interface_ips: &[IpAddr], ) -> bool { if mask_port != local_addr.port() { @@ -792,7 +846,7 @@ fn is_mask_target_local_listener_with_interfaces( let local_ip = canonical_ip(local_addr.ip()); let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip); - if let Some(addr) = resolved_override { + for addr in resolved_addrs { let resolved_ip = canonical_ip(addr.ip()); if resolved_ip == local_ip { return true; @@ -829,7 +883,7 @@ fn is_mask_target_local_listener( mask_host: &str, mask_port: u16, local_addr: SocketAddr, - resolved_override: Option, + resolved_addrs: &[SocketAddr], ) -> bool { if mask_port != local_addr.port() { return false; @@ -840,7 +894,7 @@ fn is_mask_target_local_listener( mask_host, mask_port, local_addr, - resolved_override, + resolved_addrs, &interfaces, ) } @@ -849,7 +903,7 @@ async fn is_mask_target_local_listener_async( mask_host: &str, mask_port: u16, local_addr: SocketAddr, - resolved_override: Option, + resolved_addrs: &[SocketAddr], ) -> bool { if mask_port != local_addr.port() { return false; @@ -860,7 +914,7 @@ async fn is_mask_target_local_listener_async( mask_host, mask_port, local_addr, - resolved_override, + resolved_addrs, &interfaces, ) } @@ -904,7 +958,7 @@ fn configure_mask_backend_socket(stream: &TcpStream) { } } -/// Handle a bad client by forwarding to mask host +/// Handles a bad client by forwarding it to the configured mask target. pub async fn handle_bad_client( reader: R, writer: W, @@ -916,6 +970,34 @@ pub async fn handle_bad_client( ) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, +{ + let shared = ProxySharedState::new(); + handle_bad_client_with_shared( + reader, + writer, + initial_data, + peer, + local_addr, + config, + beobachten, + shared.as_ref(), + ) + .await; +} + +/// Handles a bad client with shared pre-auth fallback admission state. +pub(crate) async fn handle_bad_client_with_shared( + reader: R, + writer: W, + initial_data: &[u8], + peer: SocketAddr, + local_addr: SocketAddr, + config: &ProxyConfig, + beobachten: &BeobachtenStore, + shared: &ProxySharedState, +) where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, { let client_type = detect_client_type(initial_data); if config.general.beobachten { @@ -938,6 +1020,17 @@ pub async fn handle_bad_client( return; } + let Some(_masking_permit) = shared.try_acquire_masking_fallback_permit() else { + let outcome_started = Instant::now(); + debug!( + client_type = client_type, + "Masking fallback concurrency limit reached" + ); + consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await; + wait_mask_outcome_budget(outcome_started, config).await; + return; + }; + let client_sni = tls::extract_sni_from_client_hello(initial_data); let exclusive_tcp_target = client_sni .as_deref() @@ -1000,24 +1093,12 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data_with_timeout_and_cap( - reader, - config.censorship.mask_relay_max_bytes, - relay_timeout, - idle_timeout, - ) - .await; + consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data_with_timeout_and_cap( - reader, - config.censorship.mask_relay_max_bytes, - relay_timeout, - idle_timeout, - ) - .await; + consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -1030,11 +1111,27 @@ pub async fn handle_bad_client( let mask_host = mask_target.host; let mask_port = mask_target.port; + let resolved_mask_addrs = match resolve_mask_target_addrs(mask_host, mask_port).await { + Ok(addrs) => addrs, + Err(e) => { + let outcome_started = Instant::now(); + debug!( + client_type = client_type, + host = %mask_host, + port = mask_port, + error = %e, + "Failed to resolve mask target" + ); + consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await; + wait_mask_outcome_budget(outcome_started, config).await; + return; + } + }; + // Fail closed when fallback points at our own listener endpoint. // Self-referential masking can create recursive proxy loops under // misconfiguration and leak distinguishable load spikes to adversaries. - let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port); - if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr) + if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, &resolved_mask_addrs) .await { let outcome_started = Instant::now(); @@ -1045,13 +1142,7 @@ pub async fn handle_bad_client( local = %local_addr, "Mask target resolves to local listener; refusing self-referential masking fallback" ); - consume_client_data_with_timeout_and_cap( - reader, - config.censorship.mask_relay_max_bytes, - relay_timeout, - idle_timeout, - ) - .await; + consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await; wait_mask_outcome_budget(outcome_started, config).await; return; } @@ -1066,12 +1157,12 @@ pub async fn handle_bad_client( "Forwarding bad client to mask host" ); - // Apply runtime DNS override for mask target when configured. - let mask_addr = resolved_mask_addr - .map(|addr| addr.to_string()) - .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); let connect_started = Instant::now(); - let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; + let connect_result = timeout( + MASK_TIMEOUT, + TcpStream::connect(resolved_mask_addrs.as_slice()), + ) + .await; match connect_result { Ok(Ok(stream)) => { configure_mask_backend_socket(&stream); @@ -1113,24 +1204,12 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data_with_timeout_and_cap( - reader, - config.censorship.mask_relay_max_bytes, - relay_timeout, - idle_timeout, - ) - .await; + consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data_with_timeout_and_cap( - reader, - config.censorship.mask_relay_max_bytes, - relay_timeout, - idle_timeout, - ) - .await; + consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await; wait_mask_outcome_budget(outcome_started, config).await; } } diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 2c61c86..25f3787 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -52,7 +52,8 @@ use self::c2me::{ }; use self::d2c::{ MeD2cFlushPolicy, MeWriterResponseOutcome, classify_me_d2c_flush_reason, - flush_client_or_cancel, observe_me_d2c_flush_event, + flush_client_or_cancel, me_d2c_flush_reason_requires_client_flush, + observe_me_d2c_flush_event, process_me_writer_response_with_traffic_lease, }; use self::desync::{RelayForensicsState, hash_ip_in, report_desync_frame_too_large_in}; diff --git a/src/proxy/middle_relay/d2c.rs b/src/proxy/middle_relay/d2c.rs index d227aa9..a6f043b 100644 --- a/src/proxy/middle_relay/d2c.rs +++ b/src/proxy/middle_relay/d2c.rs @@ -55,6 +55,37 @@ pub(super) fn classify_me_d2c_flush_reason( MeD2cFlushReason::QueueDrain } +pub(super) fn me_d2c_flush_reason_requires_client_flush(reason: MeD2cFlushReason) -> bool { + !matches!(reason, MeD2cFlushReason::QueueDrain) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn queue_drain_is_not_a_physical_flush_trigger() { + assert!(!me_d2c_flush_reason_requires_client_flush( + MeD2cFlushReason::QueueDrain + )); + assert!(me_d2c_flush_reason_requires_client_flush( + MeD2cFlushReason::AckImmediate + )); + assert!(me_d2c_flush_reason_requires_client_flush( + MeD2cFlushReason::BatchFrames + )); + assert!(me_d2c_flush_reason_requires_client_flush( + MeD2cFlushReason::BatchBytes + )); + assert!(me_d2c_flush_reason_requires_client_flush( + MeD2cFlushReason::MaxDelay + )); + assert!(me_d2c_flush_reason_requires_client_flush( + MeD2cFlushReason::Close + )); + } +} + pub(super) fn observe_me_d2c_flush_event( stats: &Stats, reason: MeD2cFlushReason, diff --git a/src/proxy/middle_relay/session.rs b/src/proxy/middle_relay/session.rs index 4865993..acded2b 100644 --- a/src/proxy/middle_relay/session.rs +++ b/src/proxy/middle_relay/session.rs @@ -491,12 +491,18 @@ where d2c_flush_policy.max_bytes, max_delay_fired, ); - let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { + let physical_flush = + me_d2c_flush_reason_requires_client_flush(flush_reason); + let flush_started_at = if physical_flush + && stats_clone.telemetry_policy().me_level.allows_debug() + { Some(Instant::now()) } else { None }; - flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?; + if physical_flush { + flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?; + } let flush_duration_us = flush_started_at.map(|started| { started .elapsed() diff --git a/src/proxy/shared_state.rs b/src/proxy/shared_state.rs index 9ed319b..6a47761 100644 --- a/src/proxy/shared_state.rs +++ b/src/proxy/shared_state.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex}; use std::time::Instant; use dashmap::DashMap; -use tokio::sync::mpsc; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc}; use tokio_util::sync::CancellationToken; use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState}; @@ -14,6 +14,7 @@ use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateReg use crate::proxy::traffic_limiter::TrafficLimiter; const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64; +const MASKING_FALLBACK_MAX_CONCURRENT: usize = 512; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum ConntrackCloseReason { @@ -72,6 +73,7 @@ pub(crate) struct ProxySharedState { active_user_sessions: DashMap<(String, u64), CancellationToken>, pub(crate) conntrack_pressure_active: AtomicBool, pub(crate) conntrack_close_tx: Mutex>>, + masking_fallback_permits: Arc, } #[must_use = "registered user sessions must be kept alive until relay completion"] @@ -131,9 +133,18 @@ impl ProxySharedState { active_user_sessions: DashMap::new(), conntrack_pressure_active: AtomicBool::new(false), conntrack_close_tx: Mutex::new(None), + masking_fallback_permits: Arc::new(Semaphore::new(MASKING_FALLBACK_MAX_CONCURRENT)), }) } + /// Attempts to reserve one masking fallback slot for a pre-auth connection. + pub(crate) fn try_acquire_masking_fallback_permit(&self) -> Option { + self.masking_fallback_permits + .clone() + .try_acquire_owned() + .ok() + } + pub(crate) fn is_user_enabled(&self, user: &str) -> bool { !self.disabled_users.contains_key(user) } diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs index 1b8ca2e..22ee8e3 100644 --- a/src/proxy/tests/masking_additional_hardening_security_tests.rs +++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs @@ -34,7 +34,7 @@ fn loop_guard_unspecified_bind_uses_interface_inventory() { "mask.example", 443, local, - Some(resolved), + &[resolved], &interfaces, )); } diff --git a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs index ed6d1ab..a1584fc 100644 --- a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs +++ b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs @@ -25,7 +25,7 @@ async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() { let barrier = std::sync::Arc::clone(&barrier); tasks.push(tokio::spawn(async move { barrier.wait().await; - is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await + is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await })); } diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs index 17debb0..4be2857 100644 --- a/src/proxy/tests/masking_interface_cache_security_tests.rs +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -17,8 +17,8 @@ async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_ let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); - let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; - let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await; + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await; assert_eq!( local_interface_enumerations_for_tests(), @@ -35,7 +35,7 @@ async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { reset_local_interface_enumerations_for_tests(); let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); - let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; + let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, &[]).await; assert!( !is_local, diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs index 975b4fc..0510d44 100644 --- a/src/proxy/tests/masking_self_target_loop_security_tests.rs +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -15,38 +15,49 @@ fn closed_local_port() -> u16 { #[tokio::test] async fn self_target_detection_matches_literal_ipv4_listener() { let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await); + assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, &[],).await); } #[tokio::test] async fn self_target_detection_matches_bracketed_ipv6_listener() { let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); - assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await); + assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, &[],).await); } #[tokio::test] async fn self_target_detection_keeps_same_ip_different_port_forwardable() { let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await); + assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, &[],).await); } #[tokio::test] async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await); + assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, &[],).await); } #[tokio::test] async fn self_target_detection_unspecified_bind_blocks_loopback_target() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await); + assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, &[],).await); } #[tokio::test] async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await); + assert!(!is_mask_target_local_listener_async("mask.example", 443, local, &[remote],).await); +} + +#[tokio::test] +async fn self_target_detection_checks_all_resolved_addresses() { + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); + let loopback: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + assert!( + is_mask_target_local_listener_async("mask.example", 443, local, &[remote, loopback],).await + ); } #[tokio::test] diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs index 42faaec..dc87bd0 100644 --- a/src/stream/frame_stream.rs +++ b/src/stream/frame_stream.rs @@ -11,16 +11,41 @@ use std::io::{Error, ErrorKind, Result}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024 * 1024; + +fn reject_oversize_frame(len: usize, max_frame_size: usize, protocol: &str) -> Result<()> { + if len > max_frame_size { + return Err(Error::new( + ErrorKind::InvalidData, + format!("{protocol} frame too large: {len} bytes (max {max_frame_size})"), + )); + } + + Ok(()) +} + // ============= Abridged (Compact) Frame ============= /// Reader for abridged MTProto framing pub struct AbridgedFrameReader { upstream: R, + max_frame_size: usize, } impl AbridgedFrameReader { + /// Creates a reader with the default maximum frame size. pub fn new(upstream: R) -> Self { - Self { upstream } + Self { + upstream, + max_frame_size: DEFAULT_MAX_FRAME_SIZE, + } + } + + fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self { + Self { + upstream, + max_frame_size, + } } } @@ -48,10 +73,12 @@ impl AbridgedFrameReader { len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize; } - // Length is in 4-byte words - let byte_len = len * 4; + // Length is in 4-byte words. + let byte_len = len + .checked_mul(4) + .ok_or_else(|| Error::new(ErrorKind::InvalidData, "abridged frame length overflow"))?; + reject_oversize_frame(byte_len, self.max_frame_size, "abridged")?; - // Read data let mut data = vec![0u8; byte_len]; self.upstream.read_exact(&mut data).await?; @@ -152,11 +179,23 @@ impl LayeredStream for AbridgedFrameWriter { /// Reader for intermediate MTProto framing pub struct IntermediateFrameReader { upstream: R, + max_frame_size: usize, } impl IntermediateFrameReader { + /// Creates a reader with the default maximum frame size. pub fn new(upstream: R) -> Self { - Self { upstream } + Self { + upstream, + max_frame_size: DEFAULT_MAX_FRAME_SIZE, + } + } + + fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self { + Self { + upstream, + max_frame_size, + } } } @@ -171,8 +210,8 @@ impl IntermediateFrameReader { let header = parse_intermediate_header(len_bytes); let len = header.wire_len; meta.quickack = header.quickack; + reject_oversize_frame(len, self.max_frame_size, "intermediate")?; - // Read data let mut data = vec![0u8; len]; self.upstream.read_exact(&mut data).await?; @@ -243,11 +282,23 @@ impl LayeredStream for IntermediateFrameWriter { /// Reader for secure intermediate MTProto framing (with padding) pub struct SecureIntermediateFrameReader { upstream: R, + max_frame_size: usize, } impl SecureIntermediateFrameReader { + /// Creates a reader with the default maximum frame size. pub fn new(upstream: R) -> Self { - Self { upstream } + Self { + upstream, + max_frame_size: DEFAULT_MAX_FRAME_SIZE, + } + } + + fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self { + Self { + upstream, + max_frame_size, + } } } @@ -262,17 +313,16 @@ impl SecureIntermediateFrameReader { let header = parse_intermediate_header(len_bytes); let len = header.wire_len; meta.quickack = header.quickack; - - // Read data (including padding) - let mut data = vec![0u8; len]; - self.upstream.read_exact(&mut data).await?; - + reject_oversize_frame(len, self.max_frame_size, "secure intermediate")?; let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| { Error::new( ErrorKind::InvalidData, format!("Invalid secure frame length: {len}"), ) })?; + + let mut data = vec![0u8; len]; + self.upstream.read_exact(&mut data).await?; data.truncate(payload_len); Ok((Bytes::from(data), meta)) @@ -321,7 +371,9 @@ impl SecureIntermediateFrameWriter { let padding_len = secure_padding_len(data.len(), &self.rng); let padding = self.rng.bytes(padding_len); - let total_len = data.len() + padding_len; + let total_len = data.len().checked_add(padding_len).ok_or_else(|| { + Error::new(ErrorKind::InvalidInput, "secure frame length overflow") + })?; let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| { Error::new( ErrorKind::InvalidInput, @@ -507,15 +559,26 @@ pub enum FrameReaderKind { } impl FrameReaderKind { + /// Creates a frame reader with the default maximum frame size. pub fn new(upstream: R, proto_tag: ProtoTag) -> Self { + Self::with_max_frame_size(upstream, proto_tag, DEFAULT_MAX_FRAME_SIZE) + } + + fn with_max_frame_size( + upstream: R, + proto_tag: ProtoTag, + max_frame_size: usize, + ) -> Self { match proto_tag { - ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)), - ProtoTag::Intermediate => { - FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)) - } - ProtoTag::Secure => { - FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)) - } + ProtoTag::Abridged => FrameReaderKind::Abridged( + AbridgedFrameReader::with_max_frame_size(upstream, max_frame_size), + ), + ProtoTag::Intermediate => FrameReaderKind::Intermediate( + IntermediateFrameReader::with_max_frame_size(upstream, max_frame_size), + ), + ProtoTag::Secure => FrameReaderKind::SecureIntermediate( + SecureIntermediateFrameReader::with_max_frame_size(upstream, max_frame_size), + ), } } @@ -569,7 +632,8 @@ mod tests { use super::*; use crate::crypto::SecureRandom; use std::sync::Arc; - use tokio::io::duplex; + use tokio::io::{AsyncWriteExt, duplex}; + use tokio::time::{Duration, timeout}; fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) { assert!(decoded.starts_with(original)); @@ -672,6 +736,55 @@ mod tests { assert!(meta.quickack); } + #[tokio::test] + async fn abridged_reader_rejects_oversize_frame_before_body_read() { + let (mut client, server) = duplex(1024); + let mut reader = AbridgedFrameReader::new(server); + let len_words = (DEFAULT_MAX_FRAME_SIZE / 4) + 1; + let encoded = (len_words as u32).to_le_bytes(); + + client + .write_all(&[0x7f, encoded[0], encoded[1], encoded[2]]) + .await + .unwrap(); + let err = timeout(Duration::from_millis(50), reader.read_frame()) + .await + .unwrap() + .unwrap_err(); + + assert_eq!(err.kind(), ErrorKind::InvalidData); + } + + #[tokio::test] + async fn intermediate_reader_rejects_oversize_frame_before_body_read() { + let (mut client, server) = duplex(1024); + let mut reader = IntermediateFrameReader::new(server); + let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 1, false).unwrap(); + + client.write_all(&len.to_le_bytes()).await.unwrap(); + let err = timeout(Duration::from_millis(50), reader.read_frame()) + .await + .unwrap() + .unwrap_err(); + + assert_eq!(err.kind(), ErrorKind::InvalidData); + } + + #[tokio::test] + async fn secure_reader_rejects_oversize_frame_before_body_read() { + let (mut client, server) = duplex(1024); + let mut reader = SecureIntermediateFrameReader::new(server); + let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 4, false).unwrap(); + + client.write_all(&len.to_le_bytes()).await.unwrap(); + let err = timeout(Duration::from_millis(50), reader.read_frame()) + .await + .unwrap() + .unwrap_err(); + + assert_eq!(err.kind(), ErrorKind::InvalidData); + } + #[tokio::test] async fn test_secure_intermediate_padding() { let (client, server) = duplex(1024);