diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 041e7cb..dbaa504 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -4,7 +4,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; -use tokio::net::TcpStream; +use tokio::net::{TcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}}; use tokio::time::timeout; use tracing::{debug, warn}; @@ -15,7 +15,7 @@ use crate::ip_tracker::UserIpTracker; use crate::protocol::constants::*; use crate::protocol::tls; use crate::stats::{ReplayChecker, Stats}; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::transport::middle_proxy::MePool; use crate::transport::{UpstreamManager, configure_client_socket}; @@ -39,6 +39,51 @@ pub struct RunningClientHandler { ip_tracker: Arc, } +type TcpReader = OwnedReadHalf; +type TcpWriter = OwnedWriteHalf; +type TlsReader = FakeTlsReader; +type TlsWriter = FakeTlsWriter; +type DirectReader = CryptoReader; +type DirectWriter = CryptoWriter; +type TlsCryptoReader = CryptoReader; +type TlsCryptoWriter = CryptoWriter; + +enum HandshakeContext { + Direct { + client_reader: DirectReader, + client_writer: DirectWriter, + success: HandshakeSuccess, + upstream_manager: Arc, + stats: Arc, + config: Arc, + buffer_pool: Arc, + rng: Arc, + me_pool: Option>, + local_addr: SocketAddr, + peer_addr: SocketAddr, + ip_tracker: Arc, + }, + Tls { + client_reader: TlsCryptoReader, + client_writer: TlsCryptoWriter, + success: HandshakeSuccess, + upstream_manager: Arc, + stats: Arc, + config: Arc, + buffer_pool: Arc, + rng: Arc, + me_pool: Option>, + local_addr: SocketAddr, + peer_addr: SocketAddr, + ip_tracker: Arc, + }, +} + +enum HandshakeOutcome { + Completed, + Ready(HandshakeContext), +} + impl ClientHandler { pub fn new( stream: TcpStream, @@ -88,10 +133,72 @@ impl RunningClientHandler { let result = timeout(handshake_timeout, self.do_handshake()).await; match result { - Ok(Ok(())) => { + Ok(Ok(HandshakeOutcome::Completed)) => { debug!(peer = %peer, "Connection handled successfully"); Ok(()) } + Ok(Ok(HandshakeOutcome::Ready(ctx))) => match ctx { + HandshakeContext::Direct { + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + me_pool, + local_addr, + peer_addr, + ip_tracker, + } => { + Self::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + me_pool, + local_addr, + peer_addr, + ip_tracker, + ) + .await + } + HandshakeContext::Tls { + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + me_pool, + local_addr, + peer_addr, + ip_tracker, + } => { + Self::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + me_pool, + local_addr, + peer_addr, + ip_tracker, + ) + .await + } + }, Ok(Err(e)) => { debug!(peer = %peer, error = %e, "Handshake failed"); Err(e) @@ -104,7 +211,7 @@ impl RunningClientHandler { } } - async fn do_handshake(mut self) -> Result<()> { + async fn do_handshake(mut self) -> Result { let mut first_bytes = [0u8; 5]; self.stream.read_exact(&mut first_bytes).await?; @@ -120,7 +227,7 @@ impl RunningClientHandler { } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result { let peer = self.peer; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -132,18 +239,13 @@ impl RunningClientHandler { self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + return Ok(HandshakeOutcome::Completed); } let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); self.stream.read_exact(&mut handshake[5..]).await?; - let config = self.config.clone(); - let replay_checker = self.replay_checker.clone(); - let stats = self.stats.clone(); - let buffer_pool = self.buffer_pool.clone(); - let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let (read_half, write_half) = self.stream.into_split(); @@ -152,17 +254,17 @@ impl RunningClientHandler { read_half, write_half, peer, - &config, - &replay_checker, + &self.config, + &self.replay_checker, &self.rng, ) .await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { - stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + self.stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &self.config).await; + return Ok(HandshakeOutcome::Completed); } HandshakeResult::Error(e) => return Err(e), }; @@ -178,8 +280,8 @@ impl RunningClientHandler { tls_reader, tls_writer, peer, - &config, - &replay_checker, + &self.config, + &self.replay_checker, true, ) .await @@ -189,31 +291,30 @@ impl RunningClientHandler { reader: _, writer: _, } => { - stats.increment_connects_bad(); + self.stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(()); + return Ok(HandshakeOutcome::Completed); } HandshakeResult::Error(e) => return Err(e), }; - Self::handle_authenticated_static( - crypto_reader, - crypto_writer, + Ok(HandshakeOutcome::Ready(HandshakeContext::Tls { + client_reader: crypto_reader, + client_writer: crypto_writer, success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool, - self.rng, - self.me_pool, + upstream_manager: self.upstream_manager, + stats: self.stats, + config: self.config, + buffer_pool: self.buffer_pool, + rng: self.rng, + me_pool: self.me_pool, local_addr, - peer, - self.ip_tracker, - ) - .await + peer_addr: peer, + ip_tracker: self.ip_tracker, + })) } - async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result { let peer = self.peer; if !self.config.general.modes.classic && !self.config.general.modes.secure { @@ -221,18 +322,13 @@ impl RunningClientHandler { self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + return Ok(HandshakeOutcome::Completed); } let mut handshake = [0u8; HANDSHAKE_LEN]; handshake[..5].copy_from_slice(&first_bytes); self.stream.read_exact(&mut handshake[5..]).await?; - let config = self.config.clone(); - let replay_checker = self.replay_checker.clone(); - let stats = self.stats.clone(); - let buffer_pool = self.buffer_pool.clone(); - let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let (read_half, write_half) = self.stream.into_split(); @@ -241,36 +337,35 @@ impl RunningClientHandler { read_half, write_half, peer, - &config, - &replay_checker, + &self.config, + &self.replay_checker, false, ) .await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { - stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + self.stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &self.config).await; + return Ok(HandshakeOutcome::Completed); } HandshakeResult::Error(e) => return Err(e), }; - Self::handle_authenticated_static( - crypto_reader, - crypto_writer, + Ok(HandshakeOutcome::Ready(HandshakeContext::Direct { + client_reader: crypto_reader, + client_writer: crypto_writer, success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool, - self.rng, - self.me_pool, + upstream_manager: self.upstream_manager, + stats: self.stats, + config: self.config, + buffer_pool: self.buffer_pool, + rng: self.rng, + me_pool: self.me_pool, local_addr, - peer, - self.ip_tracker, - ) - .await + peer_addr: peer, + ip_tracker: self.ip_tracker, + })) } /// Main dispatch after successful handshake.