Revert "Use handshake timeout only for handshake, not for the whole session."

This reverts commit 2c4016d792.
This commit is contained in:
AndreyAkifev 2026-02-16 10:20:42 +07:00
parent 2c4016d792
commit 4b0d4f74d9
1 changed files with 58 additions and 153 deletions

View File

@ -4,7 +4,7 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::{TcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}}; use tokio::net::TcpStream;
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, warn}; use tracing::{debug, warn};
@ -15,7 +15,7 @@ use crate::ip_tracker::UserIpTracker;
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::{UpstreamManager, configure_client_socket}; use crate::transport::{UpstreamManager, configure_client_socket};
@ -39,51 +39,6 @@ pub struct RunningClientHandler {
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
} }
type TcpReader = OwnedReadHalf;
type TcpWriter = OwnedWriteHalf;
type TlsReader = FakeTlsReader<TcpReader>;
type TlsWriter = FakeTlsWriter<TcpWriter>;
type DirectReader = CryptoReader<TcpReader>;
type DirectWriter = CryptoWriter<TcpWriter>;
type TlsCryptoReader = CryptoReader<TlsReader>;
type TlsCryptoWriter = CryptoWriter<TlsWriter>;
enum HandshakeContext {
Direct {
client_reader: DirectReader,
client_writer: DirectWriter,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
},
Tls {
client_reader: TlsCryptoReader,
client_writer: TlsCryptoWriter,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
},
}
enum HandshakeOutcome {
Completed,
Ready(HandshakeContext),
}
impl ClientHandler { impl ClientHandler {
pub fn new( pub fn new(
stream: TcpStream, stream: TcpStream,
@ -133,72 +88,10 @@ impl RunningClientHandler {
let result = timeout(handshake_timeout, self.do_handshake()).await; let result = timeout(handshake_timeout, self.do_handshake()).await;
match result { match result {
Ok(Ok(HandshakeOutcome::Completed)) => { Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully"); debug!(peer = %peer, "Connection handled successfully");
Ok(()) Ok(())
} }
Ok(Ok(HandshakeOutcome::Ready(ctx))) => match ctx {
HandshakeContext::Direct {
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer_addr,
ip_tracker,
} => {
Self::handle_authenticated_static(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer_addr,
ip_tracker,
)
.await
}
HandshakeContext::Tls {
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer_addr,
ip_tracker,
} => {
Self::handle_authenticated_static(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr,
peer_addr,
ip_tracker,
)
.await
}
},
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed"); debug!(peer = %peer, error = %e, "Handshake failed");
Err(e) Err(e)
@ -211,7 +104,7 @@ impl RunningClientHandler {
} }
} }
async fn do_handshake(mut self) -> Result<HandshakeOutcome> { async fn do_handshake(mut self) -> Result<()> {
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
self.stream.read_exact(&mut first_bytes).await?; self.stream.read_exact(&mut first_bytes).await?;
@ -227,7 +120,7 @@ impl RunningClientHandler {
} }
} }
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> { async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
let peer = self.peer; let peer = self.peer;
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
@ -239,13 +132,18 @@ impl RunningClientHandler {
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
let (reader, writer) = self.stream.into_split(); let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(HandshakeOutcome::Completed); return Ok(());
} }
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
@ -254,17 +152,17 @@ impl RunningClientHandler {
read_half, read_half,
write_half, write_half,
peer, peer,
&self.config, &config,
&self.replay_checker, &replay_checker,
&self.rng, &self.rng,
) )
.await .await
{ {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
self.stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &self.config).await; handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(HandshakeOutcome::Completed); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
@ -280,8 +178,8 @@ impl RunningClientHandler {
tls_reader, tls_reader,
tls_writer, tls_writer,
peer, peer,
&self.config, &config,
&self.replay_checker, &replay_checker,
true, true,
) )
.await .await
@ -291,30 +189,31 @@ impl RunningClientHandler {
reader: _, reader: _,
writer: _, writer: _,
} => { } => {
self.stats.increment_connects_bad(); stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(HandshakeOutcome::Completed); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
Ok(HandshakeOutcome::Ready(HandshakeContext::Tls { Self::handle_authenticated_static(
client_reader: crypto_reader, crypto_reader,
client_writer: crypto_writer, crypto_writer,
success, success,
upstream_manager: self.upstream_manager, self.upstream_manager,
stats: self.stats, self.stats,
config: self.config, self.config,
buffer_pool: self.buffer_pool, buffer_pool,
rng: self.rng, self.rng,
me_pool: self.me_pool, self.me_pool,
local_addr, local_addr,
peer_addr: peer, peer,
ip_tracker: self.ip_tracker, self.ip_tracker,
})) )
.await
} }
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> { async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
let peer = self.peer; let peer = self.peer;
if !self.config.general.modes.classic && !self.config.general.modes.secure { if !self.config.general.modes.classic && !self.config.general.modes.secure {
@ -322,13 +221,18 @@ impl RunningClientHandler {
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
let (reader, writer) = self.stream.into_split(); let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(HandshakeOutcome::Completed); return Ok(());
} }
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
@ -337,35 +241,36 @@ impl RunningClientHandler {
read_half, read_half,
write_half, write_half,
peer, peer,
&self.config, &config,
&self.replay_checker, &replay_checker,
false, false,
) )
.await .await
{ {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
self.stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &self.config).await; handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(HandshakeOutcome::Completed); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
Ok(HandshakeOutcome::Ready(HandshakeContext::Direct { Self::handle_authenticated_static(
client_reader: crypto_reader, crypto_reader,
client_writer: crypto_writer, crypto_writer,
success, success,
upstream_manager: self.upstream_manager, self.upstream_manager,
stats: self.stats, self.stats,
config: self.config, self.config,
buffer_pool: self.buffer_pool, buffer_pool,
rng: self.rng, self.rng,
me_pool: self.me_pool, self.me_pool,
local_addr, local_addr,
peer_addr: peer, peer,
ip_tracker: self.ip_tracker, self.ip_tracker,
})) )
.await
} }
/// Main dispatch after successful handshake. /// Main dispatch after successful handshake.