From c540a6657fdfa0974cea22783f1f82b6f922d429 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 19:05:26 +0400 Subject: [PATCH] Implement user connection reservation management and enhance relay task handling in proxy --- src/proxy/client.rs | 119 +++++++- src/proxy/client_security_tests.rs | 268 +++++++++++++++++++ src/proxy/direct_relay_security_tests.rs | 109 ++++++++ src/proxy/middle_relay_security_tests.rs | 77 ++++++ src/stats/connection_lease_security_tests.rs | 151 +++++++++++ 5 files changed, 714 insertions(+), 10 deletions(-) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 5ccbd40..254d922 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -24,6 +24,39 @@ enum HandshakeOutcome { Handled, } +struct UserConnectionReservation { + stats: Arc, + ip_tracker: Arc, + user: String, + ip: IpAddr, +} + +impl UserConnectionReservation { + fn new(stats: Arc, ip_tracker: Arc, user: String, ip: IpAddr) -> Self { + Self { + stats, + ip_tracker, + user, + ip, + } + } +} + +impl Drop for UserConnectionReservation { + fn drop(&mut self) { + self.stats.decrement_user_curr_connects(&self.user); + + if let Ok(handle) = tokio::runtime::Handle::try_current() { + let ip_tracker = self.ip_tracker.clone(); + let user = self.user.clone(); + let ip = self.ip; + handle.spawn(async move { + ip_tracker.remove_ip(&user, ip).await; + }); + } + } +} + use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{HandshakeResult, ProxyError, Result, StreamError}; @@ -90,6 +123,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { trusted.iter().any(|cidr| cidr.contains(peer_ip)) } +fn synthetic_local_addr(port: u16) -> SocketAddr { + SocketAddr::from(([0, 0, 0, 0], port)) +} + pub async fn handle_client_stream( mut stream: S, peer: SocketAddr, @@ -113,9 +150,7 @@ where let mut real_peer = normalize_ip(peer); // For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst - let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) - .parse() - .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + let mut local_addr = synthetic_local_addr(config.server.port); if proxy_protocol_enabled { let proxy_header_timeout = Duration::from_millis( @@ -798,10 +833,22 @@ impl RunningClientHandler { { let user = success.user.clone(); - if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await { - warn!(user = %user, error = %e, "User limit exceeded"); - return Err(e); - } + let _user_limit_reservation = + match Self::acquire_user_connection_reservation_static( + &user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + { + Ok(reservation) => reservation, + Err(e) => { + warn!(user = %user, error = %e, "User limit exceeded"); + return Err(e); + } + }; let route_snapshot = route_runtime.snapshot(); let session_id = rng.u64(); @@ -858,12 +905,64 @@ impl RunningClientHandler { ) .await }; - - stats.decrement_user_curr_connects(&user); - ip_tracker.remove_ip(&user, peer_addr.ip()).await; relay_result } + async fn acquire_user_connection_reservation_static( + user: &str, + config: &ProxyConfig, + stats: Arc, + peer_addr: SocketAddr, + ip_tracker: Arc, + ) -> Result { + if let Some(expiration) = config.access.user_expirations.get(user) + && chrono::Utc::now() > *expiration + { + return Err(ProxyError::UserExpired { + user: user.to_string(), + }); + } + + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => {} + Err(reason) => { + stats.decrement_user_curr_connects(user); + warn!( + user = %user, + ip = %peer_addr.ip(), + reason = %reason, + "IP limit exceeded" + ); + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + } + + Ok(UserConnectionReservation::new( + stats, + ip_tracker, + user.to_string(), + peer_addr.ip(), + )) + } + + #[cfg(test)] async fn check_user_limits_static( user: &str, config: &ProxyConfig, diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 415cafd..defb3c0 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -1,11 +1,279 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::AesCtr; use crate::crypto::sha256_hmac; +use crate::protocol::constants::ProtoTag; use crate::protocol::tls; +use crate::proxy::handshake::HandshakeSuccess; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use crate::stream::{CryptoReader, CryptoWriter}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; +#[test] +fn synthetic_local_addr_uses_configured_port_for_zero() { + let addr = synthetic_local_addr(0); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), 0); +} + +#[test] +fn synthetic_local_addr_uses_configured_port_for_max() { + let addr = synthetic_local_addr(u16::MAX); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), u16::MAX); +} + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn relay_task_abort_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "abort-user"; + let peer_addr: SocketAddr = "198.51.100.230:50000".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime, + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "task abort must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "task abort must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn relay_cutover_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "cutover-user"; + let peer_addr: SocketAddr = "198.51.100.231:50001".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime.clone(), + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("relay must terminate after cutover") + .expect("relay task must not panic"); + assert!(relay_result.is_err(), "cutover must terminate direct relay session"); + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "cutover exit must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "cutover exit must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + #[tokio::test] async fn short_tls_probe_is_masked_through_client_pipeline() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index 1e2d673..7390fb8 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -178,3 +178,112 @@ async fn direct_relay_abort_midflight_releases_route_gauge() { tg_accept_task.abort(); let _ = tg_accept_task.await; } + +#[tokio::test] +async fn direct_relay_cutover_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50002".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xface_cafe, + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("direct relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("direct relay must terminate after cutover") + .expect("direct relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate direct relay session" + ); + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay exits on cutover" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index 509ba95..f88b5a0 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -942,3 +942,80 @@ async fn middle_relay_abort_midflight_releases_route_gauge() { drop(client_side); } + +#[tokio::test] +async fn middle_relay_cutover_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50003".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xfeed_beef, + )); + + tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("middle relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) + .await + .expect("middle relay must terminate after cutover") + .expect("middle relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate middle relay session" + ); + + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay exits on cutover" + ); + + drop(client_side); +} diff --git a/src/stats/connection_lease_security_tests.rs b/src/stats/connection_lease_security_tests.rs index 2d942c2..69ae89a 100644 --- a/src/stats/connection_lease_security_tests.rs +++ b/src/stats/connection_lease_security_tests.rs @@ -2,6 +2,7 @@ use super::*; use std::panic::{self, AssertUnwindSafe}; use std::sync::Arc; use std::time::Duration; +use tokio::sync::Barrier; #[test] fn direct_connection_lease_balances_on_drop() { @@ -63,6 +64,156 @@ fn direct_connection_lease_balances_on_panic_unwind() { ); } +#[test] +fn middle_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_me_connection_lease(); + panic!("intentional panic to verify middle lease drop path"); + })); + + assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert_eq!( + stats.get_current_connections_me(), + 0, + "panic unwind must release middle route gauge" + ); +} + +#[tokio::test] +async fn concurrent_mixed_route_lease_churn_balances_to_zero() { + const TASKS: usize = 48; + const ITERATIONS_PER_TASK: usize = 256; + + let stats = Arc::new(Stats::new()); + let barrier = Arc::new(Barrier::new(TASKS)); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + let barrier_for_task = barrier.clone(); + workers.push(tokio::spawn(async move { + barrier_for_task.wait().await; + for iter in 0..ITERATIONS_PER_TASK { + if (task_idx + iter) % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::task::yield_now().await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::task::yield_now().await; + } + } + })); + } + + for worker in workers { + worker + .await + .expect("lease churn worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after concurrent lease churn" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after concurrent lease churn" + ); +} + +#[tokio::test] +async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() { + const TASKS: usize = 64; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + workers.push(tokio::spawn(async move { + if task_idx % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } + })); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + let total = stats.get_current_connections_direct() + stats.get_current_connections_me(); + if total == TASKS as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all storm tasks must acquire route leases before abort"); + + for worker in &workers { + worker.abort(); + } + for worker in workers { + let joined = worker.await; + assert!(joined.is_err(), "aborted worker must return join error"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all route gauges must drain to zero after abort storm"); +} + +#[test] +fn saturating_route_decrements_do_not_underflow_under_race() { + const THREADS: usize = 16; + const DECREMENTS_PER_THREAD: usize = 4096; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(THREADS); + + for _ in 0..THREADS { + let stats_for_thread = stats.clone(); + workers.push(std::thread::spawn(move || { + for _ in 0..DECREMENTS_PER_THREAD { + stats_for_thread.decrement_current_connections_direct(); + stats_for_thread.decrement_current_connections_me(); + } + })); + } + + for worker in workers { + worker + .join() + .expect("decrement race worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route decrement races must never underflow" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route decrement races must never underflow" + ); +} + #[tokio::test] async fn direct_connection_lease_balances_on_task_abort() { let stats = Arc::new(Stats::new());