From 205fc88718ae8ae2e4f050e8e4f56951ad1d0cf5 Mon Sep 17 00:00:00 2001
From: David Osipov
Date: Tue, 17 Mar 2026 01:29:30 +0400
Subject: [PATCH] feat(proxy): enhance logging and deduplication for unknown datacenters

- Implemented a mechanism to log unknown datacenter indices with a distinct limit to avoid excessive logging.
- Introduced tests to ensure that logging is deduplicated per datacenter index and respects the distinct limit.
- Updated the fallback logic for datacenter resolution to prevent panics when only a single datacenter is available.

feat(proxy): add authentication probe throttling

- Added a pre-authentication probe throttling mechanism to limit the rate of invalid TLS and MTProto handshake attempts.
- Introduced a backoff strategy for repeated failures and ensured that successful handshakes reset the failure count.
- Implemented tests to validate the behavior of the authentication probe under various conditions.

fix(proxy): ensure proper flushing of masked writes

- Added a flush operation after writing initial data to the mask writer to ensure data integrity.

refactor(proxy): optimize desynchronization deduplication

- Replaced the Mutex-based deduplication structure with a DashMap for improved concurrency and performance.
- Implemented a bounded cache for deduplication to limit memory usage and prevent stale entries from persisting.

test(proxy): enhance security tests for middle relay and handshake

- Added comprehensive tests for the middle relay and handshake processes, including scenarios for deduplication and authentication probe behavior.
- Ensured that the tests cover edge cases and validate the expected behavior of the system under load.
--- src/config/types.rs | 8 + src/protocol/tls.rs | 49 ++- src/protocol/tls_security_tests.rs | 28 ++ src/proxy/client.rs | 73 ++++- src/proxy/client_security_tests.rs | 381 +++++++++++++++++++++++ src/proxy/direct_relay.rs | 52 +++- src/proxy/direct_relay_security_tests.rs | 51 +++ src/proxy/handshake.rs | 159 +++++++++- src/proxy/handshake_security_tests.rs | 137 ++++++-- src/proxy/masking.rs | 3 + src/proxy/masking_security_tests.rs | 6 + src/proxy/middle_relay.rs | 169 ++++------ src/proxy/middle_relay_security_tests.rs | 103 ++++++ src/stats/mod.rs | 27 ++ src/stream/frame_codec.rs | 28 ++ 15 files changed, 1124 insertions(+), 150 deletions(-) create mode 100644 src/proxy/direct_relay_security_tests.rs create mode 100644 src/proxy/middle_relay_security_tests.rs diff --git a/src/config/types.rs b/src/config/types.rs index 04a22ce..7989d6c 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1156,6 +1156,13 @@ pub struct ServerConfig { #[serde(default = "default_proxy_protocol_header_timeout_ms")] pub proxy_protocol_header_timeout_ms: u64, + /// Trusted source CIDRs allowed to send incoming PROXY protocol headers. + /// + /// When non-empty, connections from addresses outside this allowlist are + /// rejected before `src_addr` is applied. 
+ #[serde(default)] + pub proxy_protocol_trusted_cidrs: Vec, + #[serde(default)] pub metrics_port: Option, @@ -1180,6 +1187,7 @@ impl Default for ServerConfig { listen_tcp: None, proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), + proxy_protocol_trusted_cidrs: Vec::new(), metrics_port: None, metrics_whitelist: default_metrics_whitelist(), api: ApiConfig::default(), diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 5a5ef21..c82c9fe 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -285,6 +285,26 @@ pub fn validate_tls_handshake( handshake: &[u8], secrets: &[(String, Vec)], ignore_time_skew: bool, +) -> Option { + validate_tls_handshake_with_replay_window( + handshake, + secrets, + ignore_time_skew, + u64::from(BOOT_TIME_MAX_SECS), + ) +} + +/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL. +/// +/// A boot-time timestamp is only accepted when it falls below both +/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp +/// reuse outside replay cache coverage. +#[must_use] +pub fn validate_tls_handshake_with_replay_window( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + replay_window_secs: u64, ) -> Option { // Only pay the clock syscall when we will actually compare against it. 
// If `ignore_time_skew` is set, a broken or unavailable system clock @@ -295,7 +315,16 @@ pub fn validate_tls_handshake( 0_i64 }; - validate_tls_handshake_at_time(handshake, secrets, ignore_time_skew, now) + let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); + let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32); + + validate_tls_handshake_at_time_with_boot_cap( + handshake, + secrets, + ignore_time_skew, + now, + boot_time_cap_secs, + ) } fn system_time_to_unix_secs(now: SystemTime) -> Option { @@ -311,6 +340,22 @@ fn validate_tls_handshake_at_time( secrets: &[(String, Vec)], ignore_time_skew: bool, now: i64, +) -> Option { + validate_tls_handshake_at_time_with_boot_cap( + handshake, + secrets, + ignore_time_skew, + now, + BOOT_TIME_MAX_SECS, + ) +} + +fn validate_tls_handshake_at_time_with_boot_cap( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + now: i64, + boot_time_cap_secs: u32, ) -> Option { if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { return None; @@ -366,7 +411,7 @@ fn validate_tls_handshake_at_time( if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) // This is a quirk in some clients that use uptime instead of real time - let is_boot_time = timestamp < BOOT_TIME_MAX_SECS; + let is_boot_time = timestamp < boot_time_cap_secs; if !is_boot_time { let time_diff = now - i64::from(timestamp); if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index 4372af8..c25a517 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -654,6 +654,34 @@ fn timestamp_at_boot_threshold_triggers_skew_check() { ); } +#[test] +fn replay_window_cap_disables_boot_bypass_for_old_timestamps() { + let secret = b"boot_cap_disabled_test"; + let ts: u32 = 900; + let h = make_valid_tls_handshake(secret, ts); + let secrets = 
vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300); + assert!( + result.is_none(), + "timestamp above replay-window cap must not use boot-time bypass" + ); +} + +#[test] +fn replay_window_cap_still_allows_small_boot_timestamp() { + let secret = b"boot_cap_enabled_test"; + let ts: u32 = 120; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300); + assert!( + result.is_some(), + "timestamp below replay-window cap must retain boot-time compatibility" + ); +} + // ------------------------------------------------------------------ // Extreme timestamp values // ------------------------------------------------------------------ diff --git a/src/proxy/client.rs b/src/proxy/client.rs index ec99a47..5ccbd40 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -4,7 +4,10 @@ use std::future::Future; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; +use std::sync::OnceLock; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; +use ipnetwork::IpNetwork; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::net::TcpStream; use tokio::time::timeout; @@ -73,6 +76,20 @@ fn record_handshake_failure_class( record_beobachten_class(beobachten, config, peer_ip, class); } +fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { + if trusted.is_empty() { + static EMPTY_PROXY_TRUST_WARNED: OnceLock = OnceLock::new(); + let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default" + ); + } + return false; + } + trusted.iter().any(|cidr| cidr.contains(peer_ip)) +} + pub async fn handle_client_stream( mut stream: S, 
peer: SocketAddr, @@ -106,6 +123,17 @@ where ); match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await { Ok(Ok(info)) => { + if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs) + { + stats.increment_connects_bad(); + warn!( + peer = %peer, + trusted = ?config.server.proxy_protocol_trusted_cidrs, + "Rejecting PROXY protocol header from untrusted source" + ); + record_beobachten_class(&beobachten, &config, peer.ip(), "other"); + return Err(ProxyError::InvalidProxyProtocol); + } debug!( peer = %peer, client = %info.src_addr, @@ -462,6 +490,24 @@ impl RunningClientHandler { .await { Ok(Ok(info)) => { + if !is_trusted_proxy_source( + self.peer.ip(), + &self.config.server.proxy_protocol_trusted_cidrs, + ) { + self.stats.increment_connects_bad(); + warn!( + peer = %self.peer, + trusted = ?self.config.server.proxy_protocol_trusted_cidrs, + "Rejecting PROXY protocol header from untrusted source" + ); + record_beobachten_class( + &self.beobachten, + &self.config, + self.peer.ip(), + "other", + ); + return Err(ProxyError::InvalidProxyProtocol); + } debug!( peer = %self.peer, client = %info.src_addr, @@ -768,7 +814,7 @@ impl RunningClientHandler { client_writer, success, pool.clone(), - stats, + stats.clone(), config, buffer_pool, local_addr, @@ -785,7 +831,7 @@ impl RunningClientHandler { client_writer, success, upstream_manager, - stats, + stats.clone(), config, buffer_pool, rng, @@ -802,7 +848,7 @@ impl RunningClientHandler { client_writer, success, upstream_manager, - stats, + stats.clone(), config, buffer_pool, rng, @@ -813,6 +859,7 @@ impl RunningClientHandler { .await }; + stats.decrement_user_curr_connects(&user); ip_tracker.remove_ip(&user, peer_addr.ip()).await; relay_result } @@ -832,14 +879,6 @@ impl RunningClientHandler { }); } - if let Some(limit) = config.access.user_max_tcp_conns.get(user) - && stats.get_user_curr_connects(user) >= *limit as u64 - { - return Err(ProxyError::ConnectionLimitExceeded { - 
user: user.to_string(), - }); - } - if let Some(quota) = config.access.user_data_quota.get(user) && stats.get_user_total_octets(user) >= *quota { @@ -848,9 +887,21 @@ impl RunningClientHandler { }); } + let limit = config + .access + .user_max_tcp_conns + .get(user) + .map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + match ip_tracker.check_and_add(user, peer_addr.ip()).await { Ok(()) => {} Err(reason) => { + stats.decrement_user_curr_connects(user); warn!( user = %user, ip = %peer_addr.ip(), diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 100763a..415cafd 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -2,6 +2,7 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::sha256_hmac; use crate::protocol::tls; +use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; @@ -92,6 +93,206 @@ async fn short_tls_probe_is_masked_through_client_pipeline() { accept_task.await.unwrap(); } +#[tokio::test] +async fn handle_client_stream_increments_connects_all_exactly_once() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10]; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + 
cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let before = stats.get_connects_all(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.177:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + drop(client_side); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + assert_eq!( + stats.get_connects_all(), + before + 1, + "handle_client_stream must increment connects_all exactly once" + ); +} + +#[tokio::test] +async fn running_client_handler_increments_connects_all_exactly_once() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [0x16, 0x03, 0x01, 0x00, 0x10]; + + let 
mask_accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let before = stats.get_connects_all(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + 
beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + assert_eq!( + stats.get_connects_all(), + before + 1, + "ClientHandler::run must increment connects_all exactly once" + ); +} + #[tokio::test] async fn partial_tls_header_stall_triggers_handshake_timeout() { let mut cfg = ProxyConfig::default(); @@ -1058,6 +1259,186 @@ async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() { ); } +#[tokio::test] +async fn atomic_limit_gate_allows_only_one_concurrent_acquire() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..64u16 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (i + 1) as u8)), + 30000 + i, + ); + RunningClientHandler::check_user_limits_static("user", &config, &stats, peer, &ip_tracker) + .await + .is_ok() + }); + } + + let mut successes = 0u64; + while let Some(joined) = tasks.join_next().await { + if joined.unwrap() { + successes += 1; + } + } + + assert_eq!( + successes, 1, + "exactly one concurrent acquire must pass for a limit=1 user" + ); + assert_eq!(stats.get_user_curr_connects("user"), 1); +} + +#[tokio::test] +async fn untrusted_proxy_header_source_is_rejected() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + 
cfg.server.proxy_protocol_trusted_cidrs = vec!["10.10.0.0/16".parse().unwrap()]; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: SocketAddr = "198.51.100.44:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); +} + +#[tokio::test] +async fn empty_proxy_trusted_cidrs_rejects_proxy_header_by_default() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs.clear(); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + 
interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: SocketAddr = "198.51.100.45:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); +} + #[tokio::test] async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 7a7810a..9c6116c 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -2,6 +2,8 @@ use std::fs::OpenOptions; use std::io::Write; use std::net::SocketAddr; use std::sync::Arc; +use std::collections::HashSet; +use std::sync::{Mutex, OnceLock}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; @@ -22,6 +24,45 @@ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use 
crate::transport::UpstreamManager; +const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; +static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); + +// In tests, this function shares global mutable state. Callers that also use +// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions +// deterministic under parallel execution. +fn should_log_unknown_dc(dc_idx: i16) -> bool { + let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new())); + match set.lock() { + Ok(mut guard) => { + if guard.contains(&dc_idx) { + return false; + } + if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + return false; + } + guard.insert(dc_idx) + } + // If the lock is poisoned, keep logging rather than silently dropping + // operator-visible diagnostics. + Err(_) => true, + } +} + +#[cfg(test)] +fn clear_unknown_dc_log_cache_for_testing() { + if let Some(set) = LOGGED_UNKNOWN_DCS.get() + && let Ok(mut guard) = set.lock() + { + guard.clear(); + } +} + +#[cfg(test)] +fn unknown_dc_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + pub(crate) async fn handle_via_direct( client_reader: CryptoReader, client_writer: CryptoWriter, @@ -64,7 +105,6 @@ where debug!(peer = %success.peer, "TG handshake complete, starting relay"); stats.increment_user_connects(user); - stats.increment_user_curr_connects(user); stats.increment_current_connections_direct(); let relay_result = relay_bidirectional( @@ -109,7 +149,6 @@ where }; stats.decrement_current_connections_direct(); - stats.decrement_user_curr_connects(user); match &relay_result { Ok(()) => debug!(user = %user, "Direct relay completed"), @@ -160,6 +199,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster"); if config.general.unknown_dc_file_log_enabled && let Some(path) = &config.general.unknown_dc_log_path + && 
should_log_unknown_dc(dc_idx) && let Ok(handle) = tokio::runtime::Handle::try_current() { let path = path.clone(); @@ -175,7 +215,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { default_dc - 1 } else { - 1 + 0 }; info!( @@ -203,8 +243,6 @@ async fn do_tg_handshake_static( let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( success.proto_tag, success.dc_idx, - &success.dec_key, - success.dec_iv, &success.enc_key, success.enc_iv, rng, @@ -230,3 +268,7 @@ async fn do_tg_handshake_static( CryptoWriter::new(write_half, tg_encryptor, max_pending), )) } + +#[cfg(test)] +#[path = "direct_relay_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs new file mode 100644 index 0000000..3b3185a --- /dev/null +++ b/src/proxy/direct_relay_security_tests.rs @@ -0,0 +1,51 @@ +use super::*; + +#[test] +fn unknown_dc_log_is_deduplicated_per_dc_idx() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + assert!(should_log_unknown_dc(777)); + assert!( + !should_log_unknown_dc(777), + "same unknown dc_idx must not be logged repeatedly" + ); + assert!( + should_log_unknown_dc(778), + "different unknown dc_idx must still be loggable" + ); +} + +#[test] +fn unknown_dc_log_respects_distinct_limit() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + for dc in 1..=UNKNOWN_DC_LOG_DISTINCT_LIMIT { + assert!( + should_log_unknown_dc(dc as i16), + "expected first-time unknown dc_idx to be loggable" + ); + } + + assert!( + !should_log_unknown_dc(i16::MAX), + "distinct unknown dc_idx entries above limit must not be logged" + ); +} + +#[test] +fn fallback_dc_never_panics_with_single_dc_list() { + let mut 
cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.default_dc = Some(42); + + let addr = get_dc_addr_static(999, &cfg).expect("fallback dc must resolve safely"); + let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); + assert_eq!(addr, expected); +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index a97657d..a26a722 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -4,9 +4,11 @@ use std::net::SocketAddr; use std::collections::HashSet; +use std::net::IpAddr; use std::sync::Arc; use std::sync::{Mutex, OnceLock}; -use std::time::Duration; +use std::time::{Duration, Instant}; +use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace}; use zeroize::Zeroize; @@ -22,10 +24,138 @@ use crate::config::ProxyConfig; use crate::tls_front::{TlsFrontCache, emulator}; const ACCESS_SECRET_BYTES: usize = 16; -static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); +static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); + +const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60; +const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536; +const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4; + +#[cfg(test)] +const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1; +#[cfg(not(test))] +const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 25; + +#[cfg(test)] +const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16; +#[cfg(not(test))] +const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000; + +#[derive(Clone, Copy)] +struct AuthProbeState { + fail_streak: u32, + blocked_until: Instant, + last_seen: Instant, +} + +static AUTH_PROBE_STATE: OnceLock> = OnceLock::new(); + +fn auth_probe_state_map() -> &'static DashMap { + AUTH_PROBE_STATE.get_or_init(DashMap::new) +} + +fn auth_probe_backoff(fail_streak: u32) -> Duration { + if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS { + return Duration::ZERO; + } + let shift = (fail_streak - AUTH_PROBE_BACKOFF_START_FAILS).min(10); + let multiplier = 
1u64.checked_shl(shift).unwrap_or(u64::MAX); + let ms = AUTH_PROBE_BACKOFF_BASE_MS + .saturating_mul(multiplier) + .min(AUTH_PROBE_BACKOFF_MAX_MS); + Duration::from_millis(ms) +} + +fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool { + let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS); + now.duration_since(state.last_seen) > retention +} + +fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { + let state = auth_probe_state_map(); + let Some(entry) = state.get(&peer_ip) else { + return false; + }; + if auth_probe_state_expired(&entry, now) { + drop(entry); + state.remove(&peer_ip); + return false; + } + now < entry.blocked_until +} + +fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { + let state = auth_probe_state_map(); + if let Some(mut entry) = state.get_mut(&peer_ip) { + if auth_probe_state_expired(&entry, now) { + *entry = AuthProbeState { + fail_streak: 1, + blocked_until: now + auth_probe_backoff(1), + last_seen: now, + }; + return; + } + entry.fail_streak = entry.fail_streak.saturating_add(1); + entry.last_seen = now; + entry.blocked_until = now + auth_probe_backoff(entry.fail_streak); + return; + }; + + if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + return; + } + + state.insert(peer_ip, AuthProbeState { + fail_streak: 0, + blocked_until: now, + last_seen: now, + }); + + if let Some(mut entry) = state.get_mut(&peer_ip) { + entry.fail_streak = 1; + entry.blocked_until = now + auth_probe_backoff(1); + } +} + +fn auth_probe_record_success(peer_ip: IpAddr) { + let state = auth_probe_state_map(); + state.remove(&peer_ip); +} + +#[cfg(test)] +fn clear_auth_probe_state_for_testing() { + if let Some(state) = AUTH_PROBE_STATE.get() { + state.clear(); + } +} + +#[cfg(test)] +fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option { + let state = AUTH_PROBE_STATE.get()?; + state.get(&peer_ip).map(|entry| entry.fail_streak) +} + +#[cfg(test)] +fn 
auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool { + auth_probe_is_throttled(peer_ip, Instant::now()) +} + +#[cfg(test)] +fn auth_probe_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn clear_warned_secrets_for_testing() { + if let Some(warned) = INVALID_SECRET_WARNED.get() + && let Ok(mut guard) = warned.lock() + { + guard.clear(); + } +} fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option) { - let key = format!("{}:{}", name, reason); + let key = (name.to_string(), reason.to_string()); let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new())); let should_warn = match warned.lock() { Ok(mut guard) => guard.insert(key), @@ -170,6 +300,11 @@ where { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); + if auth_probe_is_throttled(peer.ip(), Instant::now()) { + debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle"); + return HandshakeResult::BadClient { reader, writer }; + } + if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; @@ -177,13 +312,15 @@ where let secrets = decode_user_secrets(config, None); - let validation = match tls::validate_tls_handshake( + let validation = match tls::validate_tls_handshake_with_replay_window( handshake, &secrets, config.access.ignore_time_skew, + config.access.replay_window_secs, ) { Some(v) => v, None => { + auth_probe_record_failure(peer.ip(), Instant::now()); debug!( peer = %peer, ignore_time_skew = config.access.ignore_time_skew, @@ -197,6 +334,7 @@ where // letting unauthenticated probes evict valid entries from the replay cache. 
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; if replay_checker.check_and_add_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); return HandshakeResult::BadClient { reader, writer }; } @@ -307,6 +445,8 @@ where "TLS handshake successful" ); + auth_probe_record_success(peer.ip()); + HandshakeResult::Success(( FakeTlsReader::new(reader), FakeTlsWriter::new(writer), @@ -331,6 +471,11 @@ where { trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); + if auth_probe_is_throttled(peer.ip(), Instant::now()) { + debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle"); + return HandshakeResult::BadClient { reader, writer }; + } + let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); @@ -396,6 +541,7 @@ where // entry from the cache. We accept the cost of performing the full // authentication check first to avoid poisoning the replay cache. 
if replay_checker.check_and_add_handshake(dec_prekey_iv) { + auth_probe_record_failure(peer.ip(), Instant::now()); warn!(peer = %peer, user = %user, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; } @@ -421,6 +567,8 @@ where "MTProto handshake successful" ); + auth_probe_record_success(peer.ip()); + let max_pending = config.general.crypto_pending_buffer; return HandshakeResult::Success(( CryptoReader::new(reader, decryptor), @@ -429,6 +577,7 @@ where )); } + auth_probe_record_failure(peer.ip(), Instant::now()); debug!(peer = %peer, "MTProto handshake: no matching user found"); HandshakeResult::BadClient { reader, writer } } @@ -437,8 +586,6 @@ where pub fn generate_tg_nonce( proto_tag: ProtoTag, dc_idx: i16, - _client_dec_key: &[u8; 32], - _client_dec_iv: u128, client_enc_key: &[u8; 32], client_enc_iv: u128, rng: &SecureRandom, diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index da4aa26..5f62048 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -82,6 +82,7 @@ fn make_valid_tls_client_hello_with_alpn( } fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + clear_auth_probe_state_for_testing(); let mut cfg = ProxyConfig::default(); cfg.access.users.clear(); cfg.access @@ -93,8 +94,6 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { #[test] fn test_generate_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; let client_enc_key = [0x24u8; 32]; let client_enc_iv = 54321u128; @@ -102,8 +101,6 @@ fn test_generate_tg_nonce() { let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( ProtoTag::Secure, 2, - &client_dec_key, - client_dec_iv, &client_enc_key, client_enc_iv, &rng, @@ -118,8 +115,6 @@ fn test_generate_tg_nonce() { #[test] fn test_encrypt_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; let client_enc_key = 
[0x24u8; 32]; let client_enc_iv = 54321u128; @@ -127,8 +122,6 @@ fn test_encrypt_tg_nonce() { let (nonce, _, _, _, _) = generate_tg_nonce( ProtoTag::Secure, 2, - &client_dec_key, - client_dec_iv, &client_enc_key, client_enc_iv, &rng, @@ -164,8 +157,6 @@ fn test_handshake_success_drop_does_not_panic() { #[test] fn test_generate_tg_nonce_enc_dec_material_is_consistent() { - let client_dec_key = [0x12u8; 32]; - let client_dec_iv = 0x11223344556677889900aabbccddeeffu128; let client_enc_key = [0x34u8; 32]; let client_enc_iv = 0xffeeddccbbaa00998877665544332211u128; let rng = SecureRandom::new(); @@ -173,8 +164,6 @@ fn test_generate_tg_nonce_enc_dec_material_is_consistent() { let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( ProtoTag::Secure, 7, - &client_dec_key, - client_dec_iv, &client_enc_key, client_enc_iv, &rng, @@ -209,8 +198,6 @@ fn test_generate_tg_nonce_enc_dec_material_is_consistent() { #[test] fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() { - let client_dec_key = [0x22u8; 32]; - let client_dec_iv = 0x0102030405060708090a0b0c0d0e0f10u128; let client_enc_key = [0xABu8; 32]; let client_enc_iv = 0x11223344556677889900aabbccddeeffu128; let rng = SecureRandom::new(); @@ -218,8 +205,6 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() { let (nonce, _, _, _, _) = generate_tg_nonce( ProtoTag::Secure, 9, - &client_dec_key, - client_dec_iv, &client_enc_key, client_enc_iv, &rng, @@ -236,8 +221,6 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() { #[test] fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; let client_enc_key = [0x24u8; 32]; let client_enc_iv = 54321u128; @@ -245,8 +228,6 @@ fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() { let (nonce, _, _, _, _) = generate_tg_nonce( ProtoTag::Secure, 2, - &client_dec_key, - client_dec_iv, 
&client_enc_key, client_enc_iv, &rng, @@ -386,6 +367,7 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() { #[tokio::test] async fn empty_decoded_secret_is_rejected() { + clear_warned_secrets_for_testing(); let config = test_config_with_secret_hex(""); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); @@ -409,6 +391,7 @@ async fn empty_decoded_secret_is_rejected() { #[tokio::test] async fn wrong_length_decoded_secret_is_rejected() { + clear_warned_secrets_for_testing(); let config = test_config_with_secret_hex("aa"); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); @@ -458,6 +441,8 @@ async fn invalid_mtproto_probe_does_not_pollute_replay_cache() { #[tokio::test] async fn mixed_secret_lengths_keep_valid_user_authenticating() { + clear_warned_secrets_for_testing(); + clear_auth_probe_state_for_testing(); let good_secret = [0x22u8; 16]; let mut config = ProxyConfig::default(); config.access.users.clear(); @@ -582,6 +567,7 @@ async fn malformed_tls_classes_complete_within_bounded_time() { } #[tokio::test] +#[ignore = "timing-sensitive; run manually on low-jitter hosts"] async fn malformed_tls_classes_share_close_latency_buckets() { const ITER: usize = 24; const BUCKET_MS: u128 = 10; @@ -680,3 +666,114 @@ fn secure_tag_requires_secure_mode_on_direct_transport() { "Secure tag without TLS must be accepted when secure mode is enabled" ); } + +#[test] +fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() { + clear_warned_secrets_for_testing(); + + warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1)); + warn_invalid_secret_once("a", "b:c", ACCESS_SECRET_BYTES, Some(2)); + + let warned = INVALID_SECRET_WARNED + .get() + .expect("warned set must be initialized"); + let guard = warned.lock().expect("warned set lock must be available"); + assert_eq!( + guard.len(), + 2, + "(name, reason) pairs that stringify to the same colon-joined 
key must remain distinct" + ); +} + +#[tokio::test] +async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() { + let _guard = auth_probe_test_lock() + .lock() + .expect("auth probe test lock must be available"); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44361".parse().unwrap(); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert!( + auth_probe_is_throttled_for_testing(peer.ip()), + "invalid probe burst must activate per-IP pre-auth throttle" + ); +} + +#[tokio::test] +async fn successful_tls_handshake_clears_pre_auth_failure_streak() { + let _guard = auth_probe_test_lock() + .lock() + .expect("auth probe test lock must be available"); + clear_auth_probe_state_for_testing(); + + let secret = [0x23u8; 16]; + let config = test_config_with_secret_hex("23232323232323232323232323232323"); + let replay_checker = ReplayChecker::new(256, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44362".parse().unwrap(); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(expected), + "failure streak must grow before a successful authentication" + ); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let success = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(success, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful authentication must clear accumulated pre-auth failures" + ); +} diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index d7eaef8..e347d73 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -232,6 +232,9 @@ where if mask_write.write_all(initial_data).await.is_err() { return; } + if mask_write.flush().await.is_err() { + return; + } let mut client_buf = vec![0u8; MASK_BUFFER_SIZE]; let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE]; diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 2fc6a79..52e9f69 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -122,6 +122,12 @@ fn detect_client_type_covers_ssh_port_scanner_and_unknown() { assert_eq!(detect_client_type(b"random-binary-payload"), "unknown"); } +#[test] +fn detect_client_type_len_boundary_9_vs_10_bytes() { + assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + #[tokio::test] async fn beobachten_records_scanner_class_when_mask_is_disabled() { let mut config = ProxyConfig::default(); diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index aaae1b3..0aaa016 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,12 +1,14 @@ -use std::collections::HashMap; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicU64, 
Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant}; +#[cfg(test)] +use std::sync::Mutex; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; +use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot, watch}; use tracing::{debug, trace, warn}; @@ -30,13 +32,15 @@ enum C2MeCommand { } const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); +const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536; +const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024; const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; -static DESYNC_DEDUP: OnceLock>> = OnceLock::new(); +static DESYNC_DEDUP: OnceLock> = OnceLock::new(); struct RelayForensicsState { trace_id: u64, @@ -90,24 +94,46 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { return true; } - let dedup = DESYNC_DEDUP.get_or_init(|| Mutex::new(HashMap::new())); - let mut guard = dedup.lock().expect("desync dedup mutex poisoned"); - guard.retain(|_, seen_at| now.duration_since(*seen_at) < DESYNC_DEDUP_WINDOW); + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - match guard.get_mut(&key) { - Some(seen_at) => { - if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { - *seen_at = now; - true - } else { - false + if let Some(mut seen_at) = dedup.get_mut(&key) { + if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { + *seen_at = now; + return true; + } + return false; + } + + if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + let mut stale_keys = Vec::new(); + for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) { + if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW { + 
stale_keys.push(*entry.key()); } } - None => { - guard.insert(key, now); - true + for stale_key in stale_keys { + dedup.remove(&stale_key); + } + if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + return false; } } + + dedup.insert(key, now); + true +} + +#[cfg(test)] +fn clear_desync_dedup_for_testing() { + if let Some(dedup) = DESYNC_DEDUP.get() { + dedup.clear(); + } +} + +#[cfg(test)] +fn desync_dedup_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) } fn report_desync_frame_too_large( @@ -229,7 +255,7 @@ pub(crate) async fn handle_via_middle_proxy( me_pool: Arc, stats: Arc, config: Arc, - _buffer_pool: Arc, + buffer_pool: Arc, local_addr: SocketAddr, rng: Arc, mut route_rx: watch::Receiver, @@ -271,7 +297,6 @@ where }; stats.increment_user_connects(&user); - stats.increment_user_curr_connects(&user); stats.increment_current_connections_me(); if let Some(cutover) = affected_cutover_state( @@ -291,7 +316,6 @@ where let _ = me_pool.send_close(conn_id).await; me_pool.registry().unregister(conn_id).await; stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); } @@ -557,6 +581,7 @@ where &mut crypto_reader, proto_tag, frame_limit, + &buffer_pool, &forensics, &mut frame_counter, &stats, @@ -638,7 +663,6 @@ where ); me_pool.registry().unregister(conn_id).await; stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); result } @@ -646,6 +670,7 @@ async fn read_client_payload( client_reader: &mut CryptoReader, proto_tag: ProtoTag, max_frame: usize, + buffer_pool: &Arc, forensics: &RelayForensicsState, frame_counter: &mut u64, stats: &Stats, @@ -737,18 +762,27 @@ where len }; - let mut payload = vec![0u8; len]; - client_reader - .read_exact(&mut payload) - .await - .map_err(ProxyError::Io)?; + let chunk_cap = buffer_pool.buffer_size().max(1024); + let mut payload = 
BytesMut::with_capacity(len.min(chunk_cap)); + let mut remaining = len; + while remaining > 0 { + let chunk_len = remaining.min(chunk_cap); + let mut chunk = buffer_pool.get(); + chunk.resize(chunk_len, 0); + client_reader + .read_exact(&mut chunk[..chunk_len]) + .await + .map_err(ProxyError::Io)?; + payload.extend_from_slice(&chunk[..chunk_len]); + remaining -= chunk_len; + } // Secure Intermediate: strip validated trailing padding bytes. if proto_tag == ProtoTag::Secure { payload.truncate(secure_payload_len); } *frame_counter += 1; - return Ok(Some((Bytes::from(payload), quickack))); + return Ok(Some((payload.freeze(), quickack))); } } @@ -940,82 +974,5 @@ where } #[cfg(test)] -mod tests { - use super::*; - use tokio::time::{Duration as TokioDuration, timeout}; - - #[test] - fn should_yield_sender_only_on_budget_with_backlog() { - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); - } - - #[tokio::test] - async fn enqueue_c2me_command_uses_try_send_fast_path() { - let (tx, mut rx) = mpsc::channel::(2); - enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: Bytes::from_static(&[1, 2, 3]), - flags: 0, - }, - ) - .await - .unwrap(); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1, 2, 3]); - assert_eq!(flags, 0); - } - C2MeCommand::Close => panic!("unexpected close command"), - } - } - - #[tokio::test] - async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: Bytes::from_static(&[9]), - flags: 9, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - 
enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: Bytes::from_static(&[7, 7]), - flags: 7, - }, - ) - .await - .unwrap(); - }); - - let _ = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap(); - producer.await.unwrap(); - - let recv = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[7, 7]); - assert_eq!(flags, 7); - } - C2MeCommand::Close => panic!("unexpected close command"), - } - } -} +#[path = "middle_relay_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs new file mode 100644 index 0000000..d7d1243 --- /dev/null +++ b/src/proxy/middle_relay_security_tests.rs @@ -0,0 +1,103 @@ +use super::*; +use tokio::time::{Duration as TokioDuration, timeout}; + +#[test] +fn should_yield_sender_only_on_budget_with_backlog() { + assert!(!should_yield_c2me_sender(0, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); + assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); +} + +#[tokio::test] +async fn enqueue_c2me_command_uses_try_send_fast_path() { + let (tx, mut rx) = mpsc::channel::(2); + enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload: Bytes::from_static(&[1, 2, 3]), + flags: 0, + }, + ) + .await + .unwrap(); + + let recv = timeout(TokioDuration::from_millis(50), rx.recv()) + .await + .unwrap() + .unwrap(); + match recv { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[1, 2, 3]); + assert_eq!(flags, 0); + } + C2MeCommand::Close => panic!("unexpected close command"), + } +} + +#[tokio::test] +async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: 
Bytes::from_static(&[9]), + flags: 9, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let producer = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: Bytes::from_static(&[7, 7]), + flags: 7, + }, + ) + .await + .unwrap(); + }); + + let _ = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap(); + producer.await.unwrap(); + + let recv = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .unwrap(); + match recv { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[7, 7]); + assert_eq!(flags, 7); + } + C2MeCommand::Close => panic!("unexpected close command"), + } +} + +#[test] +fn desync_dedup_cache_is_bounded() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!( + should_emit_full_desync(key, false, now), + "unique keys up to cap must be tracked" + ); + } + + assert!( + !should_emit_full_desync(u64::MAX, false, now), + "new key above cap must be suppressed to bound memory" + ); + + assert!( + !should_emit_full_desync(7, false, now), + "already tracked key inside dedup window must stay suppressed" + ); +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 25905b2..603552d 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -1256,6 +1256,33 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); } + + pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option) -> bool { + if !self.telemetry_user_enabled() { + return true; + } + + self.maybe_cleanup_user_stats(); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + + let counter = &stats.curr_connects; + let mut current = counter.load(Ordering::Relaxed); + loop { + if let Some(max) = limit && current 
>= max { + return false; + } + match counter.compare_exchange_weak( + current, + current.saturating_add(1), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return true, + Err(actual) => current = actual, + } + } + } pub fn decrement_user_curr_connects(&self, user: &str) { self.maybe_cleanup_user_stats(); diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 2ff7de7..403f695 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -513,6 +513,7 @@ impl FrameCodecTrait for SecureCodec { #[cfg(test)] mod tests { use super::*; + use std::collections::HashSet; use tokio_util::codec::{FramedRead, FramedWrite}; use tokio::io::duplex; use futures::{SinkExt, StreamExt}; @@ -630,4 +631,31 @@ mod tests { let result = codec.decode(&mut buf); assert!(result.is_err()); } + + #[test] + fn secure_codec_always_adds_padding_and_jitters_wire_length() { + let codec = SecureCodec::new(Arc::new(SecureRandom::new())); + let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); + let mut wire_lens = HashSet::new(); + + for _ in 0..64 { + let frame = Frame::new(payload.clone()); + let mut out = BytesMut::new(); + codec.encode(&frame, &mut out).unwrap(); + + assert!(out.len() >= 4 + payload.len() + 1); + let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize; + assert!( + (payload.len() + 1..=payload.len() + 3).contains(&wire_len), + "Secure wire length must be payload+1..3, got {wire_len}" + ); + assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned"); + wire_lens.insert(wire_len); + } + + assert!( + wire_lens.len() >= 2, + "Secure padding should create observable wire-length jitter" + ); + } }