From 096f27020b06c6a723b7583f260ccba9ef078c80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=96=D0=BE=D1=80=D0=B0=20=D0=97=D0=BC=D0=B5=D0=B9=D0=BA?= =?UTF-8?q?=D0=B8=D0=BD?= Date: Sat, 14 Feb 2026 23:16:30 +0300 Subject: [PATCH] Fix: handshake timeout no longer kills active relay sessions Split handshake phase (with timeout) from relay phase (without handshake timeout). Relay has its own activity timeouts (keepalive, client_ack). --- src/main.rs | 5 +- src/proxy/client.rs | 168 +++++++++++++++++++++++++------------------- 2 files changed, 99 insertions(+), 74 deletions(-) diff --git a/src/main.rs b/src/main.rs index 5ae4f66..6440f8a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,7 +29,6 @@ use crate::proxy::{ClientHandler, handle_client_stream}; use crate::transport::{create_unix_listener, cleanup_unix_socket}; use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; -use crate::proxy::ClientHandler; use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; use crate::transport::middle_proxy::MePool; @@ -482,6 +481,7 @@ async fn main() -> std::result::Result<(), Box> { let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let ip_tracker = ip_tracker.clone(); let unix_conn_counter = std::sync::Arc::new( std::sync::atomic::AtomicU64::new(1) @@ -502,12 +502,13 @@ async fn main() -> std::result::Result<(), Box> { let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let ip_tracker = ip_tracker.clone(); tokio::spawn(async move { if let Err(e) = handle_client_stream( stream, fake_peer, config, stats, upstream_manager, replay_checker, buffer_pool, rng, - me_pool, + me_pool, ip_tracker, ).await { debug!(error = %e, "Unix socket connection error"); } diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 8b4e8bc..db24cc9 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1,6 +1,8 @@ //! Client Handler +use std::future::Future; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; @@ -8,6 +10,17 @@ use tokio::net::TcpStream; use tokio::time::timeout; use tracing::{debug, warn}; +/// Post-handshake future (relay phase, runs outside handshake timeout) +type PostHandshakeFuture = Pin> + Send>>; + +/// Result of the handshake phase +enum HandshakeOutcome { + /// Handshake succeeded, relay work to do (outside timeout) + NeedsRelay(PostHandshakeFuture), + /// Already fully handled (bad client masking, etc.) + Handled, +} + use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{HandshakeResult, ProxyError, Result}; @@ -39,6 +52,7 @@ pub async fn handle_client_stream( buffer_pool: Arc, rng: Arc, me_pool: Option>, + ip_tracker: Arc, ) -> Result<()> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -54,7 +68,8 @@ where .parse() .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); - let result = timeout(handshake_timeout, async { + // Phase 1: handshake (with timeout) + let outcome = match timeout(handshake_timeout, async { let mut first_bytes = [0u8; 5]; stream.read_exact(&mut first_bytes).await?; @@ -69,7 +84,7 @@ where stats.increment_connects_bad(); let (reader, writer) = tokio::io::split(stream); handle_bad_client(reader, writer, &first_bytes, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } let mut handshake = vec![0u8; 5 + tls_len]; @@ -86,7 +101,7 @@ where HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; @@ -104,23 +119,25 @@ where HandshakeResult::BadClient { reader: _, writer: _ } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; - RunningClientHandler::handle_authenticated_static( - crypto_reader, crypto_writer, success, - upstream_manager, stats, config, buffer_pool, rng, me_pool, - local_addr, - ).await + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + RunningClientHandler::handle_authenticated_static( + crypto_reader, crypto_writer, success, + upstream_manager, stats, config, buffer_pool, rng, me_pool, + local_addr, peer, ip_tracker.clone(), + ), + ))) } else { if !config.general.modes.classic && !config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); stats.increment_connects_bad(); let (reader, writer) = tokio::io::split(stream); handle_bad_client(reader, writer, &first_bytes, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } let mut handshake = [0u8; HANDSHAKE_LEN]; @@ -137,33 +154,36 @@ where HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; - RunningClientHandler::handle_authenticated_static( - crypto_reader, crypto_writer, success, - upstream_manager, stats, config, buffer_pool, rng, me_pool, - local_addr, - ).await - } - }).await; - - match result { - Ok(Ok(())) => { - debug!(peer = %peer, "Connection handled successfully"); - Ok(()) + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + RunningClientHandler::handle_authenticated_static( + crypto_reader, crypto_writer, success, + upstream_manager, stats, config, buffer_pool, rng, me_pool, + local_addr, peer, ip_tracker, + ), + ))) } + }).await { + Ok(Ok(outcome)) => outcome, Ok(Err(e)) => { debug!(peer = %peer, error = %e, "Handshake failed"); - Err(e) + return Err(e); } Err(_) => { stats_for_timeout.increment_handshake_timeouts(); debug!(peer = %peer, "Handshake timeout"); - Err(ProxyError::TgHandshakeTimeout) + return Err(ProxyError::TgHandshakeTimeout); } + }; + + // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) + match outcome { + HandshakeOutcome::NeedsRelay(fut) => fut.await, + HandshakeOutcome::Handled => Ok(()), } } @@ -228,26 +248,28 @@ impl RunningClientHandler { let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); let stats = self.stats.clone(); - let result = timeout(handshake_timeout, self.do_handshake()).await; - - match result { - Ok(Ok(())) => { - debug!(peer = %peer, "Connection handled successfully"); - Ok(()) - } + // Phase 1: handshake (with timeout) + let outcome = match timeout(handshake_timeout, self.do_handshake()).await { + Ok(Ok(outcome)) => outcome, Ok(Err(e)) => { debug!(peer = %peer, error = %e, "Handshake failed"); - Err(e) + return Err(e); } Err(_) => { stats.increment_handshake_timeouts(); debug!(peer = %peer, "Handshake timeout"); - Err(ProxyError::TgHandshakeTimeout) + return Err(ProxyError::TgHandshakeTimeout); } + }; + + // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) + match outcome { + HandshakeOutcome::NeedsRelay(fut) => fut.await, + HandshakeOutcome::Handled => Ok(()), } } - async fn do_handshake(mut self) -> Result<()> { + async fn do_handshake(mut self) -> Result { let mut first_bytes = [0u8; 5]; self.stream.read_exact(&mut first_bytes).await?; @@ -263,7 +285,7 @@ impl RunningClientHandler { } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result { let peer = self.peer; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -275,7 +297,7 @@ impl RunningClientHandler { self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } let mut handshake = vec![0u8; 5 + tls_len]; @@ -305,7 +327,7 @@ impl RunningClientHandler { HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; @@ -334,29 +356,30 @@ impl RunningClientHandler { } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; - Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool, - self.rng, - self.me_pool, - local_addr, - peer, - self.ip_tracker, - ) - .await + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + peer, + self.ip_tracker, + ), + ))) } - async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result { let peer = self.peer; if !self.config.general.modes.classic && !self.config.general.modes.secure { @@ -364,7 +387,7 @@ impl RunningClientHandler { self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } let mut handshake = [0u8; HANDSHAKE_LEN]; @@ -394,26 +417,27 @@ impl RunningClientHandler { HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; - Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool, - self.rng, - self.me_pool, - local_addr, - peer, - self.ip_tracker, - ) - .await + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + peer, + self.ip_tracker, + ), + ))) } /// Main dispatch after successful handshake.