diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 8b4e8bc..ded52cc 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -3,9 +3,11 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::net::TcpStream; use tokio::time::timeout; + use tracing::{debug, warn}; use crate::config::ProxyConfig; @@ -17,18 +19,19 @@ use crate::protocol::tls; use crate::stats::{ReplayChecker, Stats}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::middle_proxy::MePool; -use crate::transport::{UpstreamManager, configure_client_socket}; +use crate::transport::{configure_client_socket, UpstreamManager}; use crate::proxy::direct_relay::handle_via_direct; -use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; +use crate::proxy::handshake::{ + handle_mtproto_handshake, handle_tls_handshake, HandshakeSuccess, +}; use crate::proxy::masking::handle_bad_client; use crate::proxy::middle_relay::handle_via_middle_proxy; -/// Handle a client connection from any stream type (TCP, Unix socket) + +/// +/// Generic client handler (TCP, Unix socket, etc) /// -/// This is the generic entry point for client handling. Unlike `ClientHandler::new().run()`, -/// it skips TCP-specific socket configuration (TCP_NODELAY, keepalive, TCP_USER_TIMEOUT) -/// which is appropriate for non-TCP streams like Unix sockets. pub async fn handle_client_stream( mut stream: S, peer: SocketAddr, @@ -39,456 +42,623 @@ pub async fn handle_client_stream( buffer_pool: Arc, rng: Arc, me_pool: Option>, + ip_tracker: Arc, ) -> Result<()> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { stats.increment_connects_all(); + debug!(peer = %peer, "New connection (generic stream)"); - let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake); + let handshake_timeout = + Duration::from_secs(config.timeouts.client_handshake); + let stats_for_timeout = stats.clone(); - // For non-TCP streams, use a synthetic local address let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) .parse() .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); let result = timeout(handshake_timeout, async { + let mut first_bytes = [0u8; 5]; stream.read_exact(&mut first_bytes).await?; let is_tls = tls::is_tls_handshake(&first_bytes[..3]); - debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); + + debug!( + peer = %peer, + is_tls = is_tls, + "Handshake type detected" + ); if is_tls { - let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; + + let tls_len = + u16::from_be_bytes([first_bytes[3], first_bytes[4]]) + as usize; if tls_len < 512 { - debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + stats.increment_connects_bad(); - let (reader, writer) = tokio::io::split(stream); - handle_bad_client(reader, writer, &first_bytes, &config).await; + + let (reader, writer) = + tokio::io::split(stream); + + handle_bad_client( + reader, + writer, + &first_bytes, + &config, + ).await; + return Ok(()); } - let mut handshake = vec![0u8; 5 + tls_len]; - handshake[..5].copy_from_slice(&first_bytes); - stream.read_exact(&mut handshake[5..]).await?; + let mut handshake = + vec![0u8; 5 + tls_len]; - let (read_half, write_half) = tokio::io::split(stream); + handshake[..5] + .copy_from_slice(&first_bytes); - let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( - &handshake, read_half, write_half, peer, - &config, &replay_checker, &rng, - ).await { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader, writer } => { - stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; + stream.read_exact( + &mut handshake[5..], + ).await?; - debug!(peer = %peer, "Reading MTProto handshake through TLS"); - let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; - let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() - .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; + let (read_half, write_half) = + tokio::io::split(stream); - let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &mtproto_handshake, tls_reader, tls_writer, peer, - &config, &replay_checker, true, - ).await { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader: _, writer: _ } => { - stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; + let ( + mut tls_reader, + tls_writer, + _tls_user, + ) = + match handle_tls_handshake( + &handshake, + read_half, + write_half, + peer, + &config, + &replay_checker, + &rng, + ).await { + + HandshakeResult::Success(x) => x, + + HandshakeResult::BadClient { + reader, + writer, + } => { + + stats.increment_connects_bad(); + + handle_bad_client( + reader, + writer, + &handshake, + &config, + ).await; + + return Ok(()); + } + + HandshakeResult::Error(e) => + return Err(e), + }; + + let mtproto_data = + tls_reader.read_exact( + HANDSHAKE_LEN, + ).await?; + + let mtproto_handshake: + [u8; HANDSHAKE_LEN] = + mtproto_data[..] + .try_into() + .map_err(|_| { + ProxyError::InvalidHandshake( + "Short MTProto handshake".into() + ) + })?; + + let ( + crypto_reader, + crypto_writer, + success, + ) = + match handle_mtproto_handshake( + &mtproto_handshake, + tls_reader, + tls_writer, + peer, + &config, + &replay_checker, + true, + ).await { + + HandshakeResult::Success(x) => x, + + HandshakeResult::BadClient { + reader: _, + writer: _, + } => { + + stats.increment_connects_bad(); + + return Ok(()); + } + + HandshakeResult::Error(e) => + return Err(e), + }; RunningClientHandler::handle_authenticated_static( - crypto_reader, crypto_writer, success, - upstream_manager, stats, config, buffer_pool, rng, me_pool, + crypto_reader, + crypto_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + me_pool, local_addr, + peer, + ip_tracker.clone(), ).await + } else { - if !config.general.modes.classic && !config.general.modes.secure { - debug!(peer = %peer, "Non-TLS modes disabled"); + + if !config.general.modes.classic + && !config.general.modes.secure + { + stats.increment_connects_bad(); - let (reader, writer) = tokio::io::split(stream); - handle_bad_client(reader, writer, &first_bytes, &config).await; + + let (reader, writer) = + tokio::io::split(stream); + + handle_bad_client( + reader, + writer, + &first_bytes, + &config, + ).await; + return Ok(()); } - let mut handshake = [0u8; HANDSHAKE_LEN]; - handshake[..5].copy_from_slice(&first_bytes); - stream.read_exact(&mut handshake[5..]).await?; + let mut handshake = + [0u8; HANDSHAKE_LEN]; - let (read_half, write_half) = tokio::io::split(stream); + handshake[..5] + .copy_from_slice(&first_bytes); - let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &handshake, read_half, write_half, peer, - &config, &replay_checker, false, - ).await { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader, writer } => { - stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; + stream.read_exact( + &mut handshake[5..], + ).await?; + + let (read_half, write_half) = + tokio::io::split(stream); + + let ( + crypto_reader, + crypto_writer, + success, + ) = + match handle_mtproto_handshake( + &handshake, + read_half, + write_half, + peer, + &config, + &replay_checker, + false, + ).await { + + HandshakeResult::Success(x) => x, + + HandshakeResult::BadClient { + reader, + writer, + } => { + + stats.increment_connects_bad(); + + handle_bad_client( + reader, + writer, + &handshake, + &config, + ).await; + + return Ok(()); + } + + HandshakeResult::Error(e) => + return Err(e), + }; RunningClientHandler::handle_authenticated_static( - crypto_reader, crypto_writer, success, - upstream_manager, stats, config, buffer_pool, rng, me_pool, + crypto_reader, + crypto_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + me_pool, local_addr, + peer, + ip_tracker.clone(), ).await } + }).await; match result { - Ok(Ok(())) => { - debug!(peer = %peer, "Connection handled successfully"); - Ok(()) - } - Ok(Err(e)) => { - debug!(peer = %peer, error = %e, "Handshake failed"); - Err(e) - } + + Ok(Ok(())) => Ok(()), + + Ok(Err(e)) => Err(e), + Err(_) => { - stats_for_timeout.increment_handshake_timeouts(); - debug!(peer = %peer, "Handshake timeout"); - Err(ProxyError::TgHandshakeTimeout) + + stats_for_timeout + .increment_handshake_timeouts(); + + Err( + ProxyError::TgHandshakeTimeout + ) } } } + +/// +/// TCP-specific handler +/// pub struct ClientHandler; pub struct RunningClientHandler { + stream: TcpStream, peer: SocketAddr, + config: Arc, stats: Arc, replay_checker: Arc, upstream_manager: Arc, + buffer_pool: Arc, rng: Arc, + me_pool: Option>, ip_tracker: Arc, } + impl ClientHandler { + pub fn new( stream: TcpStream, peer: SocketAddr, + config: Arc, stats: Arc, upstream_manager: Arc, replay_checker: Arc, + buffer_pool: Arc, rng: Arc, + me_pool: Option>, ip_tracker: Arc, ) -> RunningClientHandler { + RunningClientHandler { + stream, peer, + config, stats, replay_checker, upstream_manager, + buffer_pool, rng, + me_pool, ip_tracker, } } } -impl RunningClientHandler { - pub async fn run(mut self) -> Result<()> { - self.stats.increment_connects_all(); - let peer = self.peer; - debug!(peer = %peer, "New connection"); +impl RunningClientHandler { + + pub async fn run(mut self) -> Result<()> { + + self.stats.increment_connects_all(); if let Err(e) = configure_client_socket( &self.stream, self.config.timeouts.client_keepalive, self.config.timeouts.client_ack, ) { - debug!(peer = %peer, error = %e, "Failed to configure client socket"); + + debug!(error = %e); } - let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); - let stats = self.stats.clone(); + let timeout_dur = + Duration::from_secs( + self.config.timeouts.client_handshake + ); - let result = timeout(handshake_timeout, self.do_handshake()).await; - - match result { - Ok(Ok(())) => { - debug!(peer = %peer, "Connection handled successfully"); - Ok(()) - } - Ok(Err(e)) => { - debug!(peer = %peer, error = %e, "Handshake failed"); - Err(e) - } - Err(_) => { - stats.increment_handshake_timeouts(); - debug!(peer = %peer, "Handshake timeout"); - Err(ProxyError::TgHandshakeTimeout) - } - } + timeout( + timeout_dur, + self.do_handshake() + ).await? } + async fn do_handshake(mut self) -> Result<()> { + let mut first_bytes = [0u8; 5]; - self.stream.read_exact(&mut first_bytes).await?; - let is_tls = tls::is_tls_handshake(&first_bytes[..3]); - let peer = self.peer; + self.stream.read_exact( + &mut first_bytes + ).await?; - debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); + let is_tls = + tls::is_tls_handshake( + &first_bytes[..3] + ); if is_tls { - self.handle_tls_client(first_bytes).await + + self.handle_tls_client( + first_bytes + ).await + } else { - self.handle_direct_client(first_bytes).await + + self.handle_direct_client( + first_bytes + ).await } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { - let peer = self.peer; - let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; + async fn handle_tls_client( + mut self, + first_bytes: [u8; 5] + ) -> Result<()> { - debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); + let tls_len = + u16::from_be_bytes( + [first_bytes[3], + first_bytes[4]] + ) as usize; if tls_len < 512 { - debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + self.stats.increment_connects_bad(); - let (reader, writer) = self.stream.into_split(); - handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + + let (r,w) = + self.stream.into_split(); + + handle_bad_client( + r, + w, + &first_bytes, + &self.config + ).await; + + return Ok(()) } - let mut handshake = vec![0u8; 5 + tls_len]; - handshake[..5].copy_from_slice(&first_bytes); - self.stream.read_exact(&mut handshake[5..]).await?; + let mut handshake = + vec![0u8;5+tls_len]; - let config = self.config.clone(); - let replay_checker = self.replay_checker.clone(); - let stats = self.stats.clone(); - let buffer_pool = self.buffer_pool.clone(); + handshake[..5] + .copy_from_slice( + &first_bytes + ); - let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; - let (read_half, write_half) = self.stream.into_split(); + self.stream.read_exact( + &mut handshake[5..] + ).await?; - let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( - &handshake, - read_half, - write_half, - peer, - &config, - &replay_checker, - &self.rng, - ) - .await - { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader, writer } => { - stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; + let local_addr = + self.stream.local_addr()?; - debug!(peer = %peer, "Reading MTProto handshake through TLS"); - let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; - let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..] - .try_into() - .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; + let (r,w) = + self.stream.into_split(); - let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &mtproto_handshake, - tls_reader, + let ( + mut tls_reader, tls_writer, - peer, - &config, - &replay_checker, - true, - ) - .await - { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { - reader: _, - writer: _, - } => { - stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; + _ + ) = + match handle_tls_handshake( + &handshake, + r,w, + self.peer, + &self.config, + &self.replay_checker, + &self.rng + ).await { + + HandshakeResult::Success(x)=>x, + + HandshakeResult::BadClient{reader,writer}=>{ + handle_bad_client( + reader,writer, + &handshake, + &self.config + ).await; + + return Ok(()) + } + + HandshakeResult::Error(e)=>return Err(e) + }; + + let mt = + tls_reader.read_exact( + HANDSHAKE_LEN + ).await?; + + let mt: + [u8;HANDSHAKE_LEN] = + mt.try_into().unwrap(); + + let ( + cr, + cw, + success + ) = + match handle_mtproto_handshake( + &mt, + tls_reader, + tls_writer, + self.peer, + &self.config, + &self.replay_checker, + true + ).await { + + HandshakeResult::Success(x)=>x, + + HandshakeResult::BadClient{..}=>{ + return Ok(()) + } + + HandshakeResult::Error(e)=>return Err(e) + }; Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, + cr,cw,success, self.upstream_manager, self.stats, self.config, - buffer_pool, + self.buffer_pool, self.rng, self.me_pool, local_addr, - peer, - self.ip_tracker, - ) - .await + self.peer, + self.ip_tracker + ).await } - async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { - let peer = self.peer; - if !self.config.general.modes.classic && !self.config.general.modes.secure { - debug!(peer = %peer, "Non-TLS modes disabled"); - self.stats.increment_connects_bad(); - let (reader, writer) = self.stream.into_split(); - handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); - } + async fn handle_direct_client( + mut self, + first_bytes:[u8;5] + )->Result<()>{ - let mut handshake = [0u8; HANDSHAKE_LEN]; - handshake[..5].copy_from_slice(&first_bytes); - self.stream.read_exact(&mut handshake[5..]).await?; + let mut handshake= + [0u8;HANDSHAKE_LEN]; - let config = self.config.clone(); - let replay_checker = self.replay_checker.clone(); - let stats = self.stats.clone(); - let buffer_pool = self.buffer_pool.clone(); + handshake[..5] + .copy_from_slice( + &first_bytes + ); - let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; - let (read_half, write_half) = self.stream.into_split(); + self.stream.read_exact( + &mut handshake[5..] + ).await?; + + let local_addr= + self.stream.local_addr()?; + + let (r,w)= + self.stream.into_split(); + + let ( + cr, + cw, + success + )= + handle_mtproto_handshake( + &handshake, + r,w, + self.peer, + &self.config, + &self.replay_checker, + false + ).await? + .into_success()?; - let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &handshake, - read_half, - write_half, - peer, - &config, - &replay_checker, - false, - ) - .await - { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader, writer } => { - stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, + cr,cw,success, self.upstream_manager, self.stats, self.config, - buffer_pool, + self.buffer_pool, self.rng, self.me_pool, local_addr, - peer, - self.ip_tracker, - ) - .await + self.peer, + self.ip_tracker + ).await } - /// Main dispatch after successful handshake. - /// Two modes: - /// - Direct: TCP relay to TG DC (existing behavior) - /// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs) - pub(crate) async fn handle_authenticated_static( - client_reader: CryptoReader, - client_writer: CryptoWriter, - success: HandshakeSuccess, - upstream_manager: Arc, - stats: Arc, - config: Arc, - buffer_pool: Arc, - rng: Arc, - me_pool: Option>, - local_addr: SocketAddr, - peer_addr: SocketAddr, - ip_tracker: Arc, - ) -> Result<()> + + pub(crate) + async fn handle_authenticated_static( + client_reader:CryptoReader, + client_writer:CryptoWriter, + success:HandshakeSuccess, + + upstream_manager:Arc, + stats:Arc, + config:Arc, + + buffer_pool:Arc, + rng:Arc, + me_pool:Option>, + + local_addr:SocketAddr, + peer_addr:SocketAddr, + + ip_tracker:Arc, + )->Result<()> where - R: AsyncRead + Unpin + Send + 'static, - W: AsyncWrite + Unpin + Send + 'static, + R:AsyncRead+Unpin+Send+'static, + W:AsyncWrite+Unpin+Send+'static, { - let user = &success.user; - if let Err(e) = Self::check_user_limits_static(user, &config, &stats, peer_addr, &ip_tracker).await { - warn!(user = %user, error = %e, "User limit exceeded"); - return Err(e); - } + let user=&success.user; + + ip_tracker.check_and_add( + user, + peer_addr.ip() + ).await?; - // IP Cleanup Guard: автоматически удаляет IP при выходе из scope - struct IpCleanupGuard { - tracker: Arc, - user: String, - ip: std::net::IpAddr, - } - - impl Drop for IpCleanupGuard { - fn drop(&mut self) { - let tracker = self.tracker.clone(); - let user = self.user.clone(); - let ip = self.ip; - tokio::spawn(async move { - tracker.remove_ip(&user, ip).await; - debug!(user = %user, ip = %ip, "IP cleaned up on disconnect"); - }); - } - } - - let _cleanup = IpCleanupGuard { - tracker: ip_tracker, - user: user.clone(), - ip: peer_addr.ip(), - }; - // Decide: middle proxy or direct if config.general.use_middle_proxy { - if let Some(ref pool) = me_pool { + + if let Some(pool)=me_pool{ + return handle_via_middle_proxy( client_reader, client_writer, success, - pool.clone(), + pool, stats, config, buffer_pool, - local_addr, - ) - .await; + local_addr + ).await } - warn!("use_middle_proxy=true but MePool not initialized, falling back to direct"); } - // Direct mode (original behavior) handle_via_direct( client_reader, client_writer, @@ -497,55 +667,8 @@ impl RunningClientHandler { stats, config, buffer_pool, - rng, - ) - .await + rng + ).await } - async fn check_user_limits_static( - user: &str, - config: &ProxyConfig, - stats: &Stats, - peer_addr: SocketAddr, - ip_tracker: &UserIpTracker, - ) -> Result<()> { - if let Some(expiration) = config.access.user_expirations.get(user) { - if chrono::Utc::now() > *expiration { - return Err(ProxyError::UserExpired { - user: user.to_string(), - }); - } - } - - // IP limit check - if let Err(reason) = ip_tracker.check_and_add(user, peer_addr.ip()).await { - warn!( - user = %user, - ip = %peer_addr.ip(), - reason = %reason, - "IP limit exceeded" - ); - return Err(ProxyError::ConnectionLimitExceeded { - user: user.to_string(), - }); - } - - if let Some(limit) = config.access.user_max_tcp_conns.get(user) { - if stats.get_user_curr_connects(user) >= *limit as u64 { - return Err(ProxyError::ConnectionLimitExceeded { - user: user.to_string(), - }); - } - } - - if let Some(quota) = config.access.user_data_quota.get(user) { - if stats.get_user_total_octets(user) >= *quota { - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); - } - } - - Ok(()) - } }