diff --git a/src/proxy/client.rs b/src/proxy/client.rs index dbaa504..041e7cb 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -4,7 +4,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; -use tokio::net::{TcpStream, tcp::{OwnedReadHalf, OwnedWriteHalf}}; +use tokio::net::TcpStream; use tokio::time::timeout; use tracing::{debug, warn}; @@ -15,7 +15,7 @@ use crate::ip_tracker::UserIpTracker; use crate::protocol::constants::*; use crate::protocol::tls; use crate::stats::{ReplayChecker, Stats}; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::middle_proxy::MePool; use crate::transport::{UpstreamManager, configure_client_socket}; @@ -39,51 +39,6 @@ pub struct RunningClientHandler { ip_tracker: Arc, } -type TcpReader = OwnedReadHalf; -type TcpWriter = OwnedWriteHalf; -type TlsReader = FakeTlsReader; -type TlsWriter = FakeTlsWriter; -type DirectReader = CryptoReader; -type DirectWriter = CryptoWriter; -type TlsCryptoReader = CryptoReader; -type TlsCryptoWriter = CryptoWriter; - -enum HandshakeContext { - Direct { - client_reader: DirectReader, - client_writer: DirectWriter, - success: HandshakeSuccess, - upstream_manager: Arc, - stats: Arc, - config: Arc, - buffer_pool: Arc, - rng: Arc, - me_pool: Option>, - local_addr: SocketAddr, - peer_addr: SocketAddr, - ip_tracker: Arc, - }, - Tls { - client_reader: TlsCryptoReader, - client_writer: TlsCryptoWriter, - success: HandshakeSuccess, - upstream_manager: Arc, - stats: Arc, - config: Arc, - buffer_pool: Arc, - rng: Arc, - me_pool: Option>, - local_addr: SocketAddr, - peer_addr: SocketAddr, - ip_tracker: Arc, - }, -} - -enum HandshakeOutcome { - Completed, - Ready(HandshakeContext), -} - impl ClientHandler { pub fn new( stream: TcpStream, @@ -133,72 +88,10 @@ impl RunningClientHandler { let result = timeout(handshake_timeout, self.do_handshake()).await; match result { - Ok(Ok(HandshakeOutcome::Completed)) => { + Ok(Ok(())) => { debug!(peer = %peer, "Connection handled successfully"); Ok(()) } - Ok(Ok(HandshakeOutcome::Ready(ctx))) => match ctx { - HandshakeContext::Direct { - client_reader, - client_writer, - success, - upstream_manager, - stats, - config, - buffer_pool, - rng, - me_pool, - local_addr, - peer_addr, - ip_tracker, - } => { - Self::handle_authenticated_static( - client_reader, - client_writer, - success, - upstream_manager, - stats, - config, - buffer_pool, - rng, - me_pool, - local_addr, - peer_addr, - ip_tracker, - ) - .await - } - HandshakeContext::Tls { - client_reader, - client_writer, - success, - upstream_manager, - stats, - config, - buffer_pool, - rng, - me_pool, - local_addr, - peer_addr, - ip_tracker, - } => { - Self::handle_authenticated_static( - client_reader, - client_writer, - success, - upstream_manager, - stats, - config, - buffer_pool, - rng, - me_pool, - local_addr, - peer_addr, - ip_tracker, - ) - .await - } - }, Ok(Err(e)) => { debug!(peer = %peer, error = %e, "Handshake failed"); Err(e) @@ -211,7 +104,7 @@ impl RunningClientHandler { } } - async fn do_handshake(mut self) -> Result { + async fn do_handshake(mut self) -> Result<()> { let mut first_bytes = [0u8; 5]; self.stream.read_exact(&mut first_bytes).await?; @@ -227,7 +120,7 @@ impl RunningClientHandler { } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result { + async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { let peer = self.peer; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -239,13 +132,18 @@ impl RunningClientHandler { self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(HandshakeOutcome::Completed); + return Ok(()); } let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); self.stream.read_exact(&mut handshake[5..]).await?; + let config = self.config.clone(); + let replay_checker = self.replay_checker.clone(); + let stats = self.stats.clone(); + let buffer_pool = self.buffer_pool.clone(); + let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let (read_half, write_half) = self.stream.into_split(); @@ -254,17 +152,17 @@ impl RunningClientHandler { read_half, write_half, peer, - &self.config, - &self.replay_checker, + &config, + &replay_checker, &self.rng, ) .await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { - self.stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &self.config).await; - return Ok(HandshakeOutcome::Completed); + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(()); } HandshakeResult::Error(e) => return Err(e), }; @@ -280,8 +178,8 @@ impl RunningClientHandler { tls_reader, tls_writer, peer, - &self.config, - &self.replay_checker, + &config, + &replay_checker, true, ) .await @@ -291,30 +189,31 @@ impl RunningClientHandler { reader: _, writer: _, } => { - self.stats.increment_connects_bad(); + stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(HandshakeOutcome::Completed); + return Ok(()); } HandshakeResult::Error(e) => return Err(e), }; - Ok(HandshakeOutcome::Ready(HandshakeContext::Tls { - client_reader: crypto_reader, - client_writer: crypto_writer, + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, success, - upstream_manager: self.upstream_manager, - stats: self.stats, - config: self.config, - buffer_pool: self.buffer_pool, - rng: self.rng, - me_pool: self.me_pool, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, local_addr, - peer_addr: peer, - ip_tracker: self.ip_tracker, - })) + peer, + self.ip_tracker, + ) + .await } - async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result { + async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { let peer = self.peer; if !self.config.general.modes.classic && !self.config.general.modes.secure { @@ -322,13 +221,18 @@ impl RunningClientHandler { self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(HandshakeOutcome::Completed); + return Ok(()); } let mut handshake = [0u8; HANDSHAKE_LEN]; handshake[..5].copy_from_slice(&first_bytes); self.stream.read_exact(&mut handshake[5..]).await?; + let config = self.config.clone(); + let replay_checker = self.replay_checker.clone(); + let stats = self.stats.clone(); + let buffer_pool = self.buffer_pool.clone(); + let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let (read_half, write_half) = self.stream.into_split(); @@ -337,35 +241,36 @@ impl RunningClientHandler { read_half, write_half, peer, - &self.config, - &self.replay_checker, + &config, + &replay_checker, false, ) .await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { - self.stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &self.config).await; - return Ok(HandshakeOutcome::Completed); + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(()); } HandshakeResult::Error(e) => return Err(e), }; - Ok(HandshakeOutcome::Ready(HandshakeContext::Direct { - client_reader: crypto_reader, - client_writer: crypto_writer, + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, success, - upstream_manager: self.upstream_manager, - stats: self.stats, - config: self.config, - buffer_pool: self.buffer_pool, - rng: self.rng, - me_pool: self.me_pool, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, local_addr, - peer_addr: peer, - ip_tracker: self.ip_tracker, - })) + peer, + self.ip_tracker, + ) + .await } /// Main dispatch after successful handshake.