Fix: handshake timeout no longer kills active relay sessions

Split handshake phase (with timeout) from relay phase (without handshake timeout).
Relay has its own activity timeouts (keepalive, client_ack).
This commit is contained in:
Жора Змейкин 2026-02-14 23:16:30 +03:00
parent 1a6b39b829
commit 096f27020b
No known key found for this signature in database
GPG Key ID: F9A576E7B79C7F61
2 changed files with 99 additions and 74 deletions

View File

@ -29,7 +29,6 @@ use crate::proxy::{ClientHandler, handle_client_stream};
use crate::transport::{create_unix_listener, cleanup_unix_socket}; use crate::transport::{create_unix_listener, cleanup_unix_socket};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::proxy::ClientHandler;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool; use crate::stream::BufferPool;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
@ -482,6 +481,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone(); let rng = rng.clone();
let me_pool = me_pool.clone(); let me_pool = me_pool.clone();
let ip_tracker = ip_tracker.clone();
let unix_conn_counter = std::sync::Arc::new( let unix_conn_counter = std::sync::Arc::new(
std::sync::atomic::AtomicU64::new(1) std::sync::atomic::AtomicU64::new(1)
@ -502,12 +502,13 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone(); let rng = rng.clone();
let me_pool = me_pool.clone(); let me_pool = me_pool.clone();
let ip_tracker = ip_tracker.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_client_stream( if let Err(e) = handle_client_stream(
stream, fake_peer, config, stats, stream, fake_peer, config, stats,
upstream_manager, replay_checker, buffer_pool, rng, upstream_manager, replay_checker, buffer_pool, rng,
me_pool, me_pool, ip_tracker,
).await { ).await {
debug!(error = %e, "Unix socket connection error"); debug!(error = %e, "Unix socket connection error");
} }

View File

@ -1,6 +1,8 @@
//! Client Handler //! Client Handler
use std::future::Future;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
@ -8,6 +10,17 @@ use tokio::net::TcpStream;
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, warn}; use tracing::{debug, warn};
/// Post-handshake future: the relay-phase work produced by a successful handshake.
///
/// It is a boxed, pinned, `Send` future resolving to the crate's `Result<()>`.
/// The handshake code only constructs it; it is awaited after the
/// timeout-wrapped handshake has returned, so the handshake timeout does not
/// apply to it.
type PostHandshakeFuture = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
/// Result of the handshake phase.
///
/// Returned from inside the timeout-wrapped handshake so the caller can decide
/// what, if anything, still has to run once the timeout no longer applies.
enum HandshakeOutcome {
/// Handshake succeeded; the carried future holds the relay work and is
/// awaited by the caller outside the handshake timeout.
NeedsRelay(PostHandshakeFuture),
/// Connection was already fully handled during the handshake phase
/// (bad-client masking, rejected handshake, etc.); the caller returns `Ok(())`.
Handled,
}
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result}; use crate::error::{HandshakeResult, ProxyError, Result};
@ -39,6 +52,7 @@ pub async fn handle_client_stream<S>(
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
ip_tracker: Arc<UserIpTracker>,
) -> Result<()> ) -> Result<()>
where where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
@ -54,7 +68,8 @@ where
.parse() .parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
let result = timeout(handshake_timeout, async { // Phase 1: handshake (with timeout)
let outcome = match timeout(handshake_timeout, async {
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
stream.read_exact(&mut first_bytes).await?; stream.read_exact(&mut first_bytes).await?;
@ -69,7 +84,7 @@ where
stats.increment_connects_bad(); stats.increment_connects_bad();
let (reader, writer) = tokio::io::split(stream); let (reader, writer) = tokio::io::split(stream);
handle_bad_client(reader, writer, &first_bytes, &config).await; handle_bad_client(reader, writer, &first_bytes, &config).await;
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake = vec![0u8; 5 + tls_len];
@ -86,7 +101,7 @@ where
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await; handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
@ -104,23 +119,25 @@ where
HandshakeResult::BadClient { reader: _, writer: _ } => { HandshakeResult::BadClient { reader: _, writer: _ } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
RunningClientHandler::handle_authenticated_static( Ok(HandshakeOutcome::NeedsRelay(Box::pin(
crypto_reader, crypto_writer, success, RunningClientHandler::handle_authenticated_static(
upstream_manager, stats, config, buffer_pool, rng, me_pool, crypto_reader, crypto_writer, success,
local_addr, upstream_manager, stats, config, buffer_pool, rng, me_pool,
).await local_addr, peer, ip_tracker.clone(),
),
)))
} else { } else {
if !config.general.modes.classic && !config.general.modes.secure { if !config.general.modes.classic && !config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled"); debug!(peer = %peer, "Non-TLS modes disabled");
stats.increment_connects_bad(); stats.increment_connects_bad();
let (reader, writer) = tokio::io::split(stream); let (reader, writer) = tokio::io::split(stream);
handle_bad_client(reader, writer, &first_bytes, &config).await; handle_bad_client(reader, writer, &first_bytes, &config).await;
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake = [0u8; HANDSHAKE_LEN];
@ -137,33 +154,36 @@ where
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await; handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
RunningClientHandler::handle_authenticated_static( Ok(HandshakeOutcome::NeedsRelay(Box::pin(
crypto_reader, crypto_writer, success, RunningClientHandler::handle_authenticated_static(
upstream_manager, stats, config, buffer_pool, rng, me_pool, crypto_reader, crypto_writer, success,
local_addr, upstream_manager, stats, config, buffer_pool, rng, me_pool,
).await local_addr, peer, ip_tracker,
} ),
}).await; )))
match result {
Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully");
Ok(())
} }
}).await {
Ok(Ok(outcome)) => outcome,
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed"); debug!(peer = %peer, error = %e, "Handshake failed");
Err(e) return Err(e);
} }
Err(_) => { Err(_) => {
stats_for_timeout.increment_handshake_timeouts(); stats_for_timeout.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout"); debug!(peer = %peer, "Handshake timeout");
Err(ProxyError::TgHandshakeTimeout) return Err(ProxyError::TgHandshakeTimeout);
} }
};
// Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts)
match outcome {
HandshakeOutcome::NeedsRelay(fut) => fut.await,
HandshakeOutcome::Handled => Ok(()),
} }
} }
@ -228,26 +248,28 @@ impl RunningClientHandler {
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
let stats = self.stats.clone(); let stats = self.stats.clone();
let result = timeout(handshake_timeout, self.do_handshake()).await; // Phase 1: handshake (with timeout)
let outcome = match timeout(handshake_timeout, self.do_handshake()).await {
match result { Ok(Ok(outcome)) => outcome,
Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully");
Ok(())
}
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed"); debug!(peer = %peer, error = %e, "Handshake failed");
Err(e) return Err(e);
} }
Err(_) => { Err(_) => {
stats.increment_handshake_timeouts(); stats.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout"); debug!(peer = %peer, "Handshake timeout");
Err(ProxyError::TgHandshakeTimeout) return Err(ProxyError::TgHandshakeTimeout);
} }
};
// Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts)
match outcome {
HandshakeOutcome::NeedsRelay(fut) => fut.await,
HandshakeOutcome::Handled => Ok(()),
} }
} }
async fn do_handshake(mut self) -> Result<()> { async fn do_handshake(mut self) -> Result<HandshakeOutcome> {
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
self.stream.read_exact(&mut first_bytes).await?; self.stream.read_exact(&mut first_bytes).await?;
@ -263,7 +285,7 @@ impl RunningClientHandler {
} }
} }
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
@ -275,7 +297,7 @@ impl RunningClientHandler {
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
let (reader, writer) = self.stream.into_split(); let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake = vec![0u8; 5 + tls_len];
@ -305,7 +327,7 @@ impl RunningClientHandler {
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await; handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
@ -334,29 +356,30 @@ impl RunningClientHandler {
} => { } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
Self::handle_authenticated_static( Ok(HandshakeOutcome::NeedsRelay(Box::pin(
crypto_reader, Self::handle_authenticated_static(
crypto_writer, crypto_reader,
success, crypto_writer,
self.upstream_manager, success,
self.stats, self.upstream_manager,
self.config, self.stats,
buffer_pool, self.config,
self.rng, buffer_pool,
self.me_pool, self.rng,
local_addr, self.me_pool,
peer, local_addr,
self.ip_tracker, peer,
) self.ip_tracker,
.await ),
)))
} }
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
if !self.config.general.modes.classic && !self.config.general.modes.secure { if !self.config.general.modes.classic && !self.config.general.modes.secure {
@ -364,7 +387,7 @@ impl RunningClientHandler {
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
let (reader, writer) = self.stream.into_split(); let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake = [0u8; HANDSHAKE_LEN];
@ -394,26 +417,27 @@ impl RunningClientHandler {
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await; handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(HandshakeOutcome::Handled);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
Self::handle_authenticated_static( Ok(HandshakeOutcome::NeedsRelay(Box::pin(
crypto_reader, Self::handle_authenticated_static(
crypto_writer, crypto_reader,
success, crypto_writer,
self.upstream_manager, success,
self.stats, self.upstream_manager,
self.config, self.stats,
buffer_pool, self.config,
self.rng, buffer_pool,
self.me_pool, self.rng,
local_addr, self.me_pool,
peer, local_addr,
self.ip_tracker, peer,
) self.ip_tracker,
.await ),
)))
} }
/// Main dispatch after successful handshake. /// Main dispatch after successful handshake.