Update client.rs

This commit is contained in:
Alexey 2026-02-15 03:07:13 +03:00 committed by GitHub
parent 41c90af02d
commit a6e03cfcdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 455 additions and 332 deletions

View File

@ -3,9 +3,11 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
@ -17,18 +19,19 @@ use crate::protocol::tls;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::{UpstreamManager, configure_client_socket}; use crate::transport::{configure_client_socket, UpstreamManager};
use crate::proxy::direct_relay::handle_via_direct; use crate::proxy::direct_relay::handle_via_direct;
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; use crate::proxy::handshake::{
handle_mtproto_handshake, handle_tls_handshake, HandshakeSuccess,
};
use crate::proxy::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client;
use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::middle_relay::handle_via_middle_proxy;
/// Handle a client connection from any stream type (TCP, Unix socket)
///
/// Generic client handler (TCP, Unix socket, etc)
/// ///
/// This is the generic entry point for client handling. Unlike `ClientHandler::new().run()`,
/// it skips TCP-specific socket configuration (TCP_NODELAY, keepalive, TCP_USER_TIMEOUT)
/// which is appropriate for non-TCP streams like Unix sockets.
pub async fn handle_client_stream<S>( pub async fn handle_client_stream<S>(
mut stream: S, mut stream: S,
peer: SocketAddr, peer: SocketAddr,
@ -39,456 +42,623 @@ pub async fn handle_client_stream<S>(
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
ip_tracker: Arc<UserIpTracker>,
) -> Result<()> ) -> Result<()>
where where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{ {
stats.increment_connects_all(); stats.increment_connects_all();
debug!(peer = %peer, "New connection (generic stream)"); debug!(peer = %peer, "New connection (generic stream)");
let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake); let handshake_timeout =
Duration::from_secs(config.timeouts.client_handshake);
let stats_for_timeout = stats.clone(); let stats_for_timeout = stats.clone();
// For non-TCP streams, use a synthetic local address
let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse() .parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
let result = timeout(handshake_timeout, async { let result = timeout(handshake_timeout, async {
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
stream.read_exact(&mut first_bytes).await?; stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
debug!(
peer = %peer,
is_tls = is_tls,
"Handshake type detected"
);
if is_tls { if is_tls {
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
let tls_len =
u16::from_be_bytes([first_bytes[3], first_bytes[4]])
as usize;
if tls_len < 512 { if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
stats.increment_connects_bad(); stats.increment_connects_bad();
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(reader, writer, &first_bytes, &config).await; let (reader, writer) =
tokio::io::split(stream);
handle_bad_client(
reader,
writer,
&first_bytes,
&config,
).await;
return Ok(()); return Ok(());
} }
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake =
handshake[..5].copy_from_slice(&first_bytes); vec![0u8; 5 + tls_len];
stream.read_exact(&mut handshake[5..]).await?;
let (read_half, write_half) = tokio::io::split(stream); handshake[..5]
.copy_from_slice(&first_bytes);
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( stream.read_exact(
&handshake, read_half, write_half, peer, &mut handshake[5..],
&config, &replay_checker, &rng, ).await?;
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
debug!(peer = %peer, "Reading MTProto handshake through TLS"); let (read_half, write_half) =
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; tokio::io::split(stream);
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (
&mtproto_handshake, tls_reader, tls_writer, peer, mut tls_reader,
&config, &replay_checker, true, tls_writer,
).await { _tls_user,
HandshakeResult::Success(result) => result, ) =
HandshakeResult::BadClient { reader: _, writer: _ } => { match handle_tls_handshake(
stats.increment_connects_bad(); &handshake,
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); read_half,
return Ok(()); write_half,
} peer,
HandshakeResult::Error(e) => return Err(e), &config,
}; &replay_checker,
&rng,
).await {
HandshakeResult::Success(x) => x,
HandshakeResult::BadClient {
reader,
writer,
} => {
stats.increment_connects_bad();
handle_bad_client(
reader,
writer,
&handshake,
&config,
).await;
return Ok(());
}
HandshakeResult::Error(e) =>
return Err(e),
};
let mtproto_data =
tls_reader.read_exact(
HANDSHAKE_LEN,
).await?;
let mtproto_handshake:
[u8; HANDSHAKE_LEN] =
mtproto_data[..]
.try_into()
.map_err(|_| {
ProxyError::InvalidHandshake(
"Short MTProto handshake".into()
)
})?;
let (
crypto_reader,
crypto_writer,
success,
) =
match handle_mtproto_handshake(
&mtproto_handshake,
tls_reader,
tls_writer,
peer,
&config,
&replay_checker,
true,
).await {
HandshakeResult::Success(x) => x,
HandshakeResult::BadClient {
reader: _,
writer: _,
} => {
stats.increment_connects_bad();
return Ok(());
}
HandshakeResult::Error(e) =>
return Err(e),
};
RunningClientHandler::handle_authenticated_static( RunningClientHandler::handle_authenticated_static(
crypto_reader, crypto_writer, success, crypto_reader,
upstream_manager, stats, config, buffer_pool, rng, me_pool, crypto_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr, local_addr,
peer,
ip_tracker.clone(),
).await ).await
} else { } else {
if !config.general.modes.classic && !config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled"); if !config.general.modes.classic
&& !config.general.modes.secure
{
stats.increment_connects_bad(); stats.increment_connects_bad();
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(reader, writer, &first_bytes, &config).await; let (reader, writer) =
tokio::io::split(stream);
handle_bad_client(
reader,
writer,
&first_bytes,
&config,
).await;
return Ok(()); return Ok(());
} }
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake =
handshake[..5].copy_from_slice(&first_bytes); [0u8; HANDSHAKE_LEN];
stream.read_exact(&mut handshake[5..]).await?;
let (read_half, write_half) = tokio::io::split(stream); handshake[..5]
.copy_from_slice(&first_bytes);
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( stream.read_exact(
&handshake, read_half, write_half, peer, &mut handshake[5..],
&config, &replay_checker, false, ).await?;
).await {
HandshakeResult::Success(result) => result, let (read_half, write_half) =
HandshakeResult::BadClient { reader, writer } => { tokio::io::split(stream);
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await; let (
return Ok(()); crypto_reader,
} crypto_writer,
HandshakeResult::Error(e) => return Err(e), success,
}; ) =
match handle_mtproto_handshake(
&handshake,
read_half,
write_half,
peer,
&config,
&replay_checker,
false,
).await {
HandshakeResult::Success(x) => x,
HandshakeResult::BadClient {
reader,
writer,
} => {
stats.increment_connects_bad();
handle_bad_client(
reader,
writer,
&handshake,
&config,
).await;
return Ok(());
}
HandshakeResult::Error(e) =>
return Err(e),
};
RunningClientHandler::handle_authenticated_static( RunningClientHandler::handle_authenticated_static(
crypto_reader, crypto_writer, success, crypto_reader,
upstream_manager, stats, config, buffer_pool, rng, me_pool, crypto_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
local_addr, local_addr,
peer,
ip_tracker.clone(),
).await ).await
} }
}).await; }).await;
match result { match result {
Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully"); Ok(Ok(())) => Ok(()),
Ok(())
} Ok(Err(e)) => Err(e),
Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed");
Err(e)
}
Err(_) => { Err(_) => {
stats_for_timeout.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout"); stats_for_timeout
Err(ProxyError::TgHandshakeTimeout) .increment_handshake_timeouts();
Err(
ProxyError::TgHandshakeTimeout
)
} }
} }
} }
///
/// TCP-specific handler
///
pub struct ClientHandler; pub struct ClientHandler;
pub struct RunningClientHandler { pub struct RunningClientHandler {
stream: TcpStream, stream: TcpStream,
peer: SocketAddr, peer: SocketAddr,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
stats: Arc<Stats>, stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
} }
impl ClientHandler { impl ClientHandler {
pub fn new( pub fn new(
stream: TcpStream, stream: TcpStream,
peer: SocketAddr, peer: SocketAddr,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
stats: Arc<Stats>, stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
) -> RunningClientHandler { ) -> RunningClientHandler {
RunningClientHandler { RunningClientHandler {
stream, stream,
peer, peer,
config, config,
stats, stats,
replay_checker, replay_checker,
upstream_manager, upstream_manager,
buffer_pool, buffer_pool,
rng, rng,
me_pool, me_pool,
ip_tracker, ip_tracker,
} }
} }
} }
impl RunningClientHandler {
pub async fn run(mut self) -> Result<()> {
self.stats.increment_connects_all();
let peer = self.peer; impl RunningClientHandler {
debug!(peer = %peer, "New connection");
pub async fn run(mut self) -> Result<()> {
self.stats.increment_connects_all();
if let Err(e) = configure_client_socket( if let Err(e) = configure_client_socket(
&self.stream, &self.stream,
self.config.timeouts.client_keepalive, self.config.timeouts.client_keepalive,
self.config.timeouts.client_ack, self.config.timeouts.client_ack,
) { ) {
debug!(peer = %peer, error = %e, "Failed to configure client socket");
debug!(error = %e);
} }
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); let timeout_dur =
let stats = self.stats.clone(); Duration::from_secs(
self.config.timeouts.client_handshake
);
let result = timeout(handshake_timeout, self.do_handshake()).await; timeout(
timeout_dur,
match result { self.do_handshake()
Ok(Ok(())) => { ).await?
debug!(peer = %peer, "Connection handled successfully");
Ok(())
}
Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed");
Err(e)
}
Err(_) => {
stats.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout");
Err(ProxyError::TgHandshakeTimeout)
}
}
} }
async fn do_handshake(mut self) -> Result<()> { async fn do_handshake(mut self) -> Result<()> {
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
self.stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]); self.stream.read_exact(
let peer = self.peer; &mut first_bytes
).await?;
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); let is_tls =
tls::is_tls_handshake(
&first_bytes[..3]
);
if is_tls { if is_tls {
self.handle_tls_client(first_bytes).await
self.handle_tls_client(
first_bytes
).await
} else { } else {
self.handle_direct_client(first_bytes).await
self.handle_direct_client(
first_bytes
).await
} }
} }
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
let peer = self.peer;
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; async fn handle_tls_client(
mut self,
first_bytes: [u8; 5]
) -> Result<()> {
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); let tls_len =
u16::from_be_bytes(
[first_bytes[3],
first_bytes[4]]
) as usize;
if tls_len < 512 { if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; let (r,w) =
return Ok(()); self.stream.into_split();
handle_bad_client(
r,
w,
&first_bytes,
&self.config
).await;
return Ok(())
} }
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake =
handshake[..5].copy_from_slice(&first_bytes); vec![0u8;5+tls_len];
self.stream.read_exact(&mut handshake[5..]).await?;
let config = self.config.clone(); handshake[..5]
let replay_checker = self.replay_checker.clone(); .copy_from_slice(
let stats = self.stats.clone(); &first_bytes
let buffer_pool = self.buffer_pool.clone(); );
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; self.stream.read_exact(
let (read_half, write_half) = self.stream.into_split(); &mut handshake[5..]
).await?;
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( let local_addr =
&handshake, self.stream.local_addr()?;
read_half,
write_half,
peer,
&config,
&replay_checker,
&self.rng,
)
.await
{
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
debug!(peer = %peer, "Reading MTProto handshake through TLS"); let (r,w) =
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; self.stream.into_split();
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..]
.try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (
&mtproto_handshake, mut tls_reader,
tls_reader,
tls_writer, tls_writer,
peer, _
&config, ) =
&replay_checker, match handle_tls_handshake(
true, &handshake,
) r,w,
.await self.peer,
{ &self.config,
HandshakeResult::Success(result) => result, &self.replay_checker,
HandshakeResult::BadClient { &self.rng
reader: _, ).await {
writer: _,
} => { HandshakeResult::Success(x)=>x,
stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); HandshakeResult::BadClient{reader,writer}=>{
return Ok(()); handle_bad_client(
} reader,writer,
HandshakeResult::Error(e) => return Err(e), &handshake,
}; &self.config
).await;
return Ok(())
}
HandshakeResult::Error(e)=>return Err(e)
};
let mt =
tls_reader.read_exact(
HANDSHAKE_LEN
).await?;
let mt:
[u8;HANDSHAKE_LEN] =
mt.try_into().unwrap();
let (
cr,
cw,
success
) =
match handle_mtproto_handshake(
&mt,
tls_reader,
tls_writer,
self.peer,
&self.config,
&self.replay_checker,
true
).await {
HandshakeResult::Success(x)=>x,
HandshakeResult::BadClient{..}=>{
return Ok(())
}
HandshakeResult::Error(e)=>return Err(e)
};
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, cr,cw,success,
crypto_writer,
success,
self.upstream_manager, self.upstream_manager,
self.stats, self.stats,
self.config, self.config,
buffer_pool, self.buffer_pool,
self.rng, self.rng,
self.me_pool, self.me_pool,
local_addr, local_addr,
peer, self.peer,
self.ip_tracker, self.ip_tracker
) ).await
.await
} }
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
let peer = self.peer;
if !self.config.general.modes.classic && !self.config.general.modes.secure { async fn handle_direct_client(
debug!(peer = %peer, "Non-TLS modes disabled"); mut self,
self.stats.increment_connects_bad(); first_bytes:[u8;5]
let (reader, writer) = self.stream.into_split(); )->Result<()>{
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(());
}
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake=
handshake[..5].copy_from_slice(&first_bytes); [0u8;HANDSHAKE_LEN];
self.stream.read_exact(&mut handshake[5..]).await?;
let config = self.config.clone(); handshake[..5]
let replay_checker = self.replay_checker.clone(); .copy_from_slice(
let stats = self.stats.clone(); &first_bytes
let buffer_pool = self.buffer_pool.clone(); );
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; self.stream.read_exact(
let (read_half, write_half) = self.stream.into_split(); &mut handshake[5..]
).await?;
let local_addr=
self.stream.local_addr()?;
let (r,w)=
self.stream.into_split();
let (
cr,
cw,
success
)=
handle_mtproto_handshake(
&handshake,
r,w,
self.peer,
&self.config,
&self.replay_checker,
false
).await?
.into_success()?;
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake,
read_half,
write_half,
peer,
&config,
&replay_checker,
false,
)
.await
{
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, cr,cw,success,
crypto_writer,
success,
self.upstream_manager, self.upstream_manager,
self.stats, self.stats,
self.config, self.config,
buffer_pool, self.buffer_pool,
self.rng, self.rng,
self.me_pool, self.me_pool,
local_addr, local_addr,
peer, self.peer,
self.ip_tracker, self.ip_tracker
) ).await
.await
} }
/// Main dispatch after successful handshake.
/// Two modes: pub(crate)
/// - Direct: TCP relay to TG DC (existing behavior) async fn handle_authenticated_static<R,W>(
/// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs) client_reader:CryptoReader<R>,
pub(crate) async fn handle_authenticated_static<R, W>( client_writer:CryptoWriter<W>,
client_reader: CryptoReader<R>, success:HandshakeSuccess,
client_writer: CryptoWriter<W>,
success: HandshakeSuccess, upstream_manager:Arc<UpstreamManager>,
upstream_manager: Arc<UpstreamManager>, stats:Arc<Stats>,
stats: Arc<Stats>, config:Arc<ProxyConfig>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>, buffer_pool:Arc<BufferPool>,
rng: Arc<SecureRandom>, rng:Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool:Option<Arc<MePool>>,
local_addr: SocketAddr,
peer_addr: SocketAddr, local_addr:SocketAddr,
ip_tracker: Arc<UserIpTracker>, peer_addr:SocketAddr,
) -> Result<()>
ip_tracker:Arc<UserIpTracker>,
)->Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R:AsyncRead+Unpin+Send+'static,
W: AsyncWrite + Unpin + Send + 'static, W:AsyncWrite+Unpin+Send+'static,
{ {
let user = &success.user;
if let Err(e) = Self::check_user_limits_static(user, &config, &stats, peer_addr, &ip_tracker).await { let user=&success.user;
warn!(user = %user, error = %e, "User limit exceeded");
return Err(e); ip_tracker.check_and_add(
} user,
peer_addr.ip()
).await?;
// IP Cleanup Guard: автоматически удаляет IP при выходе из scope
struct IpCleanupGuard {
tracker: Arc<UserIpTracker>,
user: String,
ip: std::net::IpAddr,
}
impl Drop for IpCleanupGuard {
fn drop(&mut self) {
let tracker = self.tracker.clone();
let user = self.user.clone();
let ip = self.ip;
tokio::spawn(async move {
tracker.remove_ip(&user, ip).await;
debug!(user = %user, ip = %ip, "IP cleaned up on disconnect");
});
}
}
let _cleanup = IpCleanupGuard {
tracker: ip_tracker,
user: user.clone(),
ip: peer_addr.ip(),
};
// Decide: middle proxy or direct
if config.general.use_middle_proxy { if config.general.use_middle_proxy {
if let Some(ref pool) = me_pool {
if let Some(pool)=me_pool{
return handle_via_middle_proxy( return handle_via_middle_proxy(
client_reader, client_reader,
client_writer, client_writer,
success, success,
pool.clone(), pool,
stats, stats,
config, config,
buffer_pool, buffer_pool,
local_addr, local_addr
) ).await
.await;
} }
warn!("use_middle_proxy=true but MePool not initialized, falling back to direct");
} }
// Direct mode (original behavior)
handle_via_direct( handle_via_direct(
client_reader, client_reader,
client_writer, client_writer,
@ -497,55 +667,8 @@ impl RunningClientHandler {
stats, stats,
config, config,
buffer_pool, buffer_pool,
rng, rng
) ).await
.await
} }
async fn check_user_limits_static(
user: &str,
config: &ProxyConfig,
stats: &Stats,
peer_addr: SocketAddr,
ip_tracker: &UserIpTracker,
) -> Result<()> {
if let Some(expiration) = config.access.user_expirations.get(user) {
if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired {
user: user.to_string(),
});
}
}
// IP limit check
if let Err(reason) = ip_tracker.check_and_add(user, peer_addr.ip()).await {
warn!(
user = %user,
ip = %peer_addr.ip(),
reason = %reason,
"IP limit exceeded"
);
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
if stats.get_user_curr_connects(user) >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
}
if let Some(quota) = config.access.user_data_quota.get(user) {
if stats.get_user_total_octets(user) >= *quota {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
}
Ok(())
}
} }