Use handshake timeout only for handshake, not for the whole session.

This commit is contained in:
AndreyAkifev 2026-02-15 19:42:17 +07:00
parent a80db2ddbc
commit 2c4016d792
1 changed files with 153 additions and 58 deletions

View File

@ -4,7 +4,7 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::{TcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, warn}; use tracing::{debug, warn};
@ -15,7 +15,7 @@ use crate::ip_tracker::UserIpTracker;
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::{UpstreamManager, configure_client_socket}; use crate::transport::{UpstreamManager, configure_client_socket};
@ -39,6 +39,51 @@ pub struct RunningClientHandler {
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
} }
type TcpReader = OwnedReadHalf;
type TcpWriter = OwnedWriteHalf;
type TlsReader = FakeTlsReader<TcpReader>;
type TlsWriter = FakeTlsWriter<TcpWriter>;
type DirectReader = CryptoReader<TcpReader>;
type DirectWriter = CryptoWriter<TcpWriter>;
type TlsCryptoReader = CryptoReader<TlsReader>;
type TlsCryptoWriter = CryptoWriter<TlsWriter>;
enum HandshakeContext {
Direct {
client_reader: DirectReader,
client_writer: DirectWriter,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
},
Tls {
client_reader: TlsCryptoReader,
client_writer: TlsCryptoWriter,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
},
}
enum HandshakeOutcome {
Completed,
Ready(HandshakeContext),
}
impl ClientHandler { impl ClientHandler {
pub fn new( pub fn new(
stream: TcpStream, stream: TcpStream,
@ -88,10 +133,72 @@ impl RunningClientHandler {
let result = timeout(handshake_timeout, self.do_handshake()).await; let result = timeout(handshake_timeout, self.do_handshake()).await;
match result { match result {
Ok(Ok(())) => { Ok(Ok(HandshakeOutcome::Completed)) => {
debug!(peer = %peer, "Connection handled successfully"); debug!(peer = %peer, "Connection handled successfully");
Ok(()) Ok(())
} }
Ok(Ok(HandshakeOutcome::Ready(ctx))) => match ctx {
HandshakeContext::Direct {
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer_addr,
ip_tracker,
} => {
Self::handle_authenticated_static(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer_addr,
ip_tracker,
)
.await
}
HandshakeContext::Tls {
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer_addr,
ip_tracker,
} => {
Self::handle_authenticated_static(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer_addr,
ip_tracker,
)
.await
}
},
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed"); debug!(peer = %peer, error = %e, "Handshake failed");
Err(e) Err(e)
@ -104,7 +211,7 @@ impl RunningClientHandler {
} }
} }
async fn do_handshake(mut self) -> Result<()> { async fn do_handshake(mut self) -> Result<HandshakeOutcome> {
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
self.stream.read_exact(&mut first_bytes).await?; self.stream.read_exact(&mut first_bytes).await?;
@ -120,7 +227,7 @@ impl RunningClientHandler {
} }
} }
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
@ -132,18 +239,13 @@ impl RunningClientHandler {
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
let (reader, writer) = self.stream.into_split(); let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(HandshakeOutcome::Completed);
} }
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
@ -152,17 +254,17 @@ impl RunningClientHandler {
read_half, read_half,
write_half, write_half,
peer, peer,
&config, &self.config,
&replay_checker, &self.replay_checker,
&self.rng, &self.rng,
) )
.await .await
{ {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await; handle_bad_client(reader, writer, &handshake, &self.config).await;
return Ok(()); return Ok(HandshakeOutcome::Completed);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
@ -178,8 +280,8 @@ impl RunningClientHandler {
tls_reader, tls_reader,
tls_writer, tls_writer,
peer, peer,
&config, &self.config,
&replay_checker, &self.replay_checker,
true, true,
) )
.await .await
@ -189,31 +291,30 @@ impl RunningClientHandler {
reader: _, reader: _,
writer: _, writer: _,
} => { } => {
stats.increment_connects_bad(); self.stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(()); return Ok(HandshakeOutcome::Completed);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
Self::handle_authenticated_static( Ok(HandshakeOutcome::Ready(HandshakeContext::Tls {
crypto_reader, client_reader: crypto_reader,
crypto_writer, client_writer: crypto_writer,
success, success,
self.upstream_manager, upstream_manager: self.upstream_manager,
self.stats, stats: self.stats,
self.config, config: self.config,
buffer_pool, buffer_pool: self.buffer_pool,
self.rng, rng: self.rng,
self.me_pool, me_pool: self.me_pool,
local_addr, local_addr,
peer, peer_addr: peer,
self.ip_tracker, ip_tracker: self.ip_tracker,
) }))
.await
} }
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
if !self.config.general.modes.classic && !self.config.general.modes.secure { if !self.config.general.modes.classic && !self.config.general.modes.secure {
@ -221,18 +322,13 @@ impl RunningClientHandler {
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
let (reader, writer) = self.stream.into_split(); let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(HandshakeOutcome::Completed);
} }
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
@ -241,36 +337,35 @@ impl RunningClientHandler {
read_half, read_half,
write_half, write_half,
peer, peer,
&config, &self.config,
&replay_checker, &self.replay_checker,
false, false,
) )
.await .await
{ {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await; handle_bad_client(reader, writer, &handshake, &self.config).await;
return Ok(()); return Ok(HandshakeOutcome::Completed);
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
Self::handle_authenticated_static( Ok(HandshakeOutcome::Ready(HandshakeContext::Direct {
crypto_reader, client_reader: crypto_reader,
crypto_writer, client_writer: crypto_writer,
success, success,
self.upstream_manager, upstream_manager: self.upstream_manager,
self.stats, stats: self.stats,
self.config, config: self.config,
buffer_pool, buffer_pool: self.buffer_pool,
self.rng, rng: self.rng,
self.me_pool, me_pool: self.me_pool,
local_addr, local_addr,
peer, peer_addr: peer,
self.ip_tracker, ip_tracker: self.ip_tracker,
) }))
.await
} }
/// Main dispatch after successful handshake. /// Main dispatch after successful handshake.