diff --git a/src/main.rs b/src/main.rs index 0f3757b..61debb9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -265,7 +265,7 @@ async fn main() -> std::result::Result<(), Box> { } // Connection concurrency limit - let _max_connections = Arc::new(Semaphore::new(10_000)); + let max_connections = Arc::new(Semaphore::new(10_000)); if use_middle_proxy && !decision.ipv4_me && !decision.ipv6_me { warn!("No usable IP family for Middle Proxy detected; falling back to direct DC"); @@ -844,6 +844,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let me_pool = me_pool.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); + let max_connections_unix = max_connections.clone(); tokio::spawn(async move { let unix_conn_counter = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)); @@ -851,6 +852,13 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai loop { match unix_listener.accept().await { Ok((stream, _)) => { + let permit = match max_connections_unix.clone().acquire_owned().await { + Ok(permit) => permit, + Err(_) => { + error!("Connection limiter is closed"); + break; + } + }; let conn_id = unix_conn_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let fake_peer = SocketAddr::from(([127, 0, 0, 1], (conn_id % 65535) as u16)); @@ -866,6 +874,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let proxy_protocol_enabled = config.server.proxy_protocol; tokio::spawn(async move { + let _permit = permit; if let Err(e) = crate::proxy::client::handle_client_stream( stream, fake_peer, config, stats, upstream_manager, replay_checker, buffer_pool, rng, @@ -933,11 +942,19 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let me_pool = me_pool.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); + let max_connections_tcp = max_connections.clone(); tokio::spawn(async move { loop { match listener.accept().await { Ok((stream, peer_addr)) => { + let permit = match max_connections_tcp.clone().acquire_owned().await { + Ok(permit) => permit, + Err(_) => { + error!("Connection limiter is closed"); + break; + } + }; let config = config_rx.borrow_and_update().clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); @@ -950,6 +967,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let proxy_protocol_enabled = listener_proxy_protocol; tokio::spawn(async move { + let _permit = permit; if let Err(e) = ClientHandler::new( stream, peer_addr, diff --git a/src/stats/mod.rs b/src/stats/mod.rs index e480ec6..307da6d 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -109,7 +109,22 @@ impl Stats { pub fn decrement_user_curr_connects(&self, user: &str) { if let Some(stats) = self.user_stats.get(user) { - stats.curr_connects.fetch_sub(1, Ordering::Relaxed); + let counter = &stats.curr_connects; + let mut current = counter.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match counter.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } } } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 4b25e00..6a9250d 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; use tokio::sync::{mpsc, RwLock}; use tokio::sync::mpsc::error::TrySendError; @@ -9,6 +10,7 @@ use super::codec::WriterCommand; use super::MeResponse; const ROUTE_CHANNEL_CAPACITY: usize = 4096; +const ROUTE_BACKPRESSURE_TIMEOUT: Duration = Duration::from_millis(25); #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteResult { @@ -94,15 +96,26 @@ impl ConnRegistry { } pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { - let inner = self.inner.read().await; - if let Some(tx) = inner.map.get(&id) { - match tx.try_send(resp) { - Ok(()) => RouteResult::Routed, - Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, - Err(TrySendError::Full(_)) => RouteResult::QueueFull, + let tx = { + let inner = self.inner.read().await; + inner.map.get(&id).cloned() + }; + + let Some(tx) = tx else { + return RouteResult::NoConn; + }; + + match tx.try_send(resp) { + Ok(()) => RouteResult::Routed, + Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, + Err(TrySendError::Full(resp)) => { + // Absorb short bursts without dropping/closing the session immediately. + match tokio::time::timeout(ROUTE_BACKPRESSURE_TIMEOUT, tx.send(resp)).await { + Ok(Ok(())) => RouteResult::Routed, + Ok(Err(_)) => RouteResult::ChannelClosed, + Err(_) => RouteResult::QueueFull, + } } - } else { - RouteResult::NoConn } }