mirror of https://github.com/telemt/telemt.git
Update client.rs
This commit is contained in:
parent
41c90af02d
commit
a6e03cfcdd
|
|
@ -3,9 +3,11 @@
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
|
|
@ -17,18 +19,19 @@ use crate::protocol::tls;
|
||||||
use crate::stats::{ReplayChecker, Stats};
|
use crate::stats::{ReplayChecker, Stats};
|
||||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||||
use crate::transport::middle_proxy::MePool;
|
use crate::transport::middle_proxy::MePool;
|
||||||
use crate::transport::{UpstreamManager, configure_client_socket};
|
use crate::transport::{configure_client_socket, UpstreamManager};
|
||||||
|
|
||||||
use crate::proxy::direct_relay::handle_via_direct;
|
use crate::proxy::direct_relay::handle_via_direct;
|
||||||
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
|
use crate::proxy::handshake::{
|
||||||
|
handle_mtproto_handshake, handle_tls_handshake, HandshakeSuccess,
|
||||||
|
};
|
||||||
use crate::proxy::masking::handle_bad_client;
|
use crate::proxy::masking::handle_bad_client;
|
||||||
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
||||||
|
|
||||||
/// Handle a client connection from any stream type (TCP, Unix socket)
|
|
||||||
|
///
|
||||||
|
/// Generic client handler (TCP, Unix socket, etc)
|
||||||
///
|
///
|
||||||
/// This is the generic entry point for client handling. Unlike `ClientHandler::new().run()`,
|
|
||||||
/// it skips TCP-specific socket configuration (TCP_NODELAY, keepalive, TCP_USER_TIMEOUT)
|
|
||||||
/// which is appropriate for non-TCP streams like Unix sockets.
|
|
||||||
pub async fn handle_client_stream<S>(
|
pub async fn handle_client_stream<S>(
|
||||||
mut stream: S,
|
mut stream: S,
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
|
|
@ -39,456 +42,623 @@ pub async fn handle_client_stream<S>(
|
||||||
buffer_pool: Arc<BufferPool>,
|
buffer_pool: Arc<BufferPool>,
|
||||||
rng: Arc<SecureRandom>,
|
rng: Arc<SecureRandom>,
|
||||||
me_pool: Option<Arc<MePool>>,
|
me_pool: Option<Arc<MePool>>,
|
||||||
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
stats.increment_connects_all();
|
stats.increment_connects_all();
|
||||||
|
|
||||||
debug!(peer = %peer, "New connection (generic stream)");
|
debug!(peer = %peer, "New connection (generic stream)");
|
||||||
|
|
||||||
let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake);
|
let handshake_timeout =
|
||||||
|
Duration::from_secs(config.timeouts.client_handshake);
|
||||||
|
|
||||||
let stats_for_timeout = stats.clone();
|
let stats_for_timeout = stats.clone();
|
||||||
|
|
||||||
// For non-TCP streams, use a synthetic local address
|
|
||||||
let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
|
let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
|
||||||
.parse()
|
.parse()
|
||||||
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
|
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
|
||||||
|
|
||||||
let result = timeout(handshake_timeout, async {
|
let result = timeout(handshake_timeout, async {
|
||||||
|
|
||||||
let mut first_bytes = [0u8; 5];
|
let mut first_bytes = [0u8; 5];
|
||||||
stream.read_exact(&mut first_bytes).await?;
|
stream.read_exact(&mut first_bytes).await?;
|
||||||
|
|
||||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||||
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
|
||||||
|
debug!(
|
||||||
|
peer = %peer,
|
||||||
|
is_tls = is_tls,
|
||||||
|
"Handshake type detected"
|
||||||
|
);
|
||||||
|
|
||||||
if is_tls {
|
if is_tls {
|
||||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
|
||||||
|
let tls_len =
|
||||||
|
u16::from_be_bytes([first_bytes[3], first_bytes[4]])
|
||||||
|
as usize;
|
||||||
|
|
||||||
if tls_len < 512 {
|
if tls_len < 512 {
|
||||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
let (reader, writer) = tokio::io::split(stream);
|
|
||||||
handle_bad_client(reader, writer, &first_bytes, &config).await;
|
let (reader, writer) =
|
||||||
|
tokio::io::split(stream);
|
||||||
|
|
||||||
|
handle_bad_client(
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
&first_bytes,
|
||||||
|
&config,
|
||||||
|
).await;
|
||||||
|
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut handshake = vec![0u8; 5 + tls_len];
|
let mut handshake =
|
||||||
handshake[..5].copy_from_slice(&first_bytes);
|
vec![0u8; 5 + tls_len];
|
||||||
stream.read_exact(&mut handshake[5..]).await?;
|
|
||||||
|
|
||||||
let (read_half, write_half) = tokio::io::split(stream);
|
handshake[..5]
|
||||||
|
.copy_from_slice(&first_bytes);
|
||||||
|
|
||||||
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
stream.read_exact(
|
||||||
&handshake, read_half, write_half, peer,
|
&mut handshake[5..],
|
||||||
&config, &replay_checker, &rng,
|
).await?;
|
||||||
).await {
|
|
||||||
HandshakeResult::Success(result) => result,
|
|
||||||
HandshakeResult::BadClient { reader, writer } => {
|
|
||||||
stats.increment_connects_bad();
|
|
||||||
handle_bad_client(reader, writer, &handshake, &config).await;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
let (read_half, write_half) =
|
||||||
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
|
tokio::io::split(stream);
|
||||||
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
|
|
||||||
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
|
||||||
|
|
||||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
let (
|
||||||
&mtproto_handshake, tls_reader, tls_writer, peer,
|
mut tls_reader,
|
||||||
&config, &replay_checker, true,
|
tls_writer,
|
||||||
).await {
|
_tls_user,
|
||||||
HandshakeResult::Success(result) => result,
|
) =
|
||||||
HandshakeResult::BadClient { reader: _, writer: _ } => {
|
match handle_tls_handshake(
|
||||||
stats.increment_connects_bad();
|
&handshake,
|
||||||
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
read_half,
|
||||||
return Ok(());
|
write_half,
|
||||||
}
|
peer,
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
&config,
|
||||||
};
|
&replay_checker,
|
||||||
|
&rng,
|
||||||
|
).await {
|
||||||
|
|
||||||
|
HandshakeResult::Success(x) => x,
|
||||||
|
|
||||||
|
HandshakeResult::BadClient {
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
} => {
|
||||||
|
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
|
||||||
|
handle_bad_client(
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
&handshake,
|
||||||
|
&config,
|
||||||
|
).await;
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
HandshakeResult::Error(e) =>
|
||||||
|
return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mtproto_data =
|
||||||
|
tls_reader.read_exact(
|
||||||
|
HANDSHAKE_LEN,
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
let mtproto_handshake:
|
||||||
|
[u8; HANDSHAKE_LEN] =
|
||||||
|
mtproto_data[..]
|
||||||
|
.try_into()
|
||||||
|
.map_err(|_| {
|
||||||
|
ProxyError::InvalidHandshake(
|
||||||
|
"Short MTProto handshake".into()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (
|
||||||
|
crypto_reader,
|
||||||
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
) =
|
||||||
|
match handle_mtproto_handshake(
|
||||||
|
&mtproto_handshake,
|
||||||
|
tls_reader,
|
||||||
|
tls_writer,
|
||||||
|
peer,
|
||||||
|
&config,
|
||||||
|
&replay_checker,
|
||||||
|
true,
|
||||||
|
).await {
|
||||||
|
|
||||||
|
HandshakeResult::Success(x) => x,
|
||||||
|
|
||||||
|
HandshakeResult::BadClient {
|
||||||
|
reader: _,
|
||||||
|
writer: _,
|
||||||
|
} => {
|
||||||
|
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
HandshakeResult::Error(e) =>
|
||||||
|
return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
RunningClientHandler::handle_authenticated_static(
|
RunningClientHandler::handle_authenticated_static(
|
||||||
crypto_reader, crypto_writer, success,
|
crypto_reader,
|
||||||
upstream_manager, stats, config, buffer_pool, rng, me_pool,
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
upstream_manager,
|
||||||
|
stats,
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
me_pool,
|
||||||
local_addr,
|
local_addr,
|
||||||
|
peer,
|
||||||
|
ip_tracker.clone(),
|
||||||
).await
|
).await
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if !config.general.modes.classic && !config.general.modes.secure {
|
|
||||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
if !config.general.modes.classic
|
||||||
|
&& !config.general.modes.secure
|
||||||
|
{
|
||||||
|
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
let (reader, writer) = tokio::io::split(stream);
|
|
||||||
handle_bad_client(reader, writer, &first_bytes, &config).await;
|
let (reader, writer) =
|
||||||
|
tokio::io::split(stream);
|
||||||
|
|
||||||
|
handle_bad_client(
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
&first_bytes,
|
||||||
|
&config,
|
||||||
|
).await;
|
||||||
|
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut handshake = [0u8; HANDSHAKE_LEN];
|
let mut handshake =
|
||||||
handshake[..5].copy_from_slice(&first_bytes);
|
[0u8; HANDSHAKE_LEN];
|
||||||
stream.read_exact(&mut handshake[5..]).await?;
|
|
||||||
|
|
||||||
let (read_half, write_half) = tokio::io::split(stream);
|
handshake[..5]
|
||||||
|
.copy_from_slice(&first_bytes);
|
||||||
|
|
||||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
stream.read_exact(
|
||||||
&handshake, read_half, write_half, peer,
|
&mut handshake[5..],
|
||||||
&config, &replay_checker, false,
|
).await?;
|
||||||
).await {
|
|
||||||
HandshakeResult::Success(result) => result,
|
let (read_half, write_half) =
|
||||||
HandshakeResult::BadClient { reader, writer } => {
|
tokio::io::split(stream);
|
||||||
stats.increment_connects_bad();
|
|
||||||
handle_bad_client(reader, writer, &handshake, &config).await;
|
let (
|
||||||
return Ok(());
|
crypto_reader,
|
||||||
}
|
crypto_writer,
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
success,
|
||||||
};
|
) =
|
||||||
|
match handle_mtproto_handshake(
|
||||||
|
&handshake,
|
||||||
|
read_half,
|
||||||
|
write_half,
|
||||||
|
peer,
|
||||||
|
&config,
|
||||||
|
&replay_checker,
|
||||||
|
false,
|
||||||
|
).await {
|
||||||
|
|
||||||
|
HandshakeResult::Success(x) => x,
|
||||||
|
|
||||||
|
HandshakeResult::BadClient {
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
} => {
|
||||||
|
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
|
||||||
|
handle_bad_client(
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
&handshake,
|
||||||
|
&config,
|
||||||
|
).await;
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
HandshakeResult::Error(e) =>
|
||||||
|
return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
RunningClientHandler::handle_authenticated_static(
|
RunningClientHandler::handle_authenticated_static(
|
||||||
crypto_reader, crypto_writer, success,
|
crypto_reader,
|
||||||
upstream_manager, stats, config, buffer_pool, rng, me_pool,
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
upstream_manager,
|
||||||
|
stats,
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
me_pool,
|
||||||
local_addr,
|
local_addr,
|
||||||
|
peer,
|
||||||
|
ip_tracker.clone(),
|
||||||
).await
|
).await
|
||||||
}
|
}
|
||||||
|
|
||||||
}).await;
|
}).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(Ok(())) => {
|
|
||||||
debug!(peer = %peer, "Connection handled successfully");
|
Ok(Ok(())) => Ok(()),
|
||||||
Ok(())
|
|
||||||
}
|
Ok(Err(e)) => Err(e),
|
||||||
Ok(Err(e)) => {
|
|
||||||
debug!(peer = %peer, error = %e, "Handshake failed");
|
|
||||||
Err(e)
|
|
||||||
}
|
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
stats_for_timeout.increment_handshake_timeouts();
|
|
||||||
debug!(peer = %peer, "Handshake timeout");
|
stats_for_timeout
|
||||||
Err(ProxyError::TgHandshakeTimeout)
|
.increment_handshake_timeouts();
|
||||||
|
|
||||||
|
Err(
|
||||||
|
ProxyError::TgHandshakeTimeout
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
///
|
||||||
|
/// TCP-specific handler
|
||||||
|
///
|
||||||
pub struct ClientHandler;
|
pub struct ClientHandler;
|
||||||
|
|
||||||
pub struct RunningClientHandler {
|
pub struct RunningClientHandler {
|
||||||
|
|
||||||
stream: TcpStream,
|
stream: TcpStream,
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
|
|
||||||
config: Arc<ProxyConfig>,
|
config: Arc<ProxyConfig>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
replay_checker: Arc<ReplayChecker>,
|
replay_checker: Arc<ReplayChecker>,
|
||||||
upstream_manager: Arc<UpstreamManager>,
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
|
|
||||||
buffer_pool: Arc<BufferPool>,
|
buffer_pool: Arc<BufferPool>,
|
||||||
rng: Arc<SecureRandom>,
|
rng: Arc<SecureRandom>,
|
||||||
|
|
||||||
me_pool: Option<Arc<MePool>>,
|
me_pool: Option<Arc<MePool>>,
|
||||||
ip_tracker: Arc<UserIpTracker>,
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl ClientHandler {
|
impl ClientHandler {
|
||||||
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
stream: TcpStream,
|
stream: TcpStream,
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
|
|
||||||
config: Arc<ProxyConfig>,
|
config: Arc<ProxyConfig>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
upstream_manager: Arc<UpstreamManager>,
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
replay_checker: Arc<ReplayChecker>,
|
replay_checker: Arc<ReplayChecker>,
|
||||||
|
|
||||||
buffer_pool: Arc<BufferPool>,
|
buffer_pool: Arc<BufferPool>,
|
||||||
rng: Arc<SecureRandom>,
|
rng: Arc<SecureRandom>,
|
||||||
|
|
||||||
me_pool: Option<Arc<MePool>>,
|
me_pool: Option<Arc<MePool>>,
|
||||||
ip_tracker: Arc<UserIpTracker>,
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
) -> RunningClientHandler {
|
) -> RunningClientHandler {
|
||||||
|
|
||||||
RunningClientHandler {
|
RunningClientHandler {
|
||||||
|
|
||||||
stream,
|
stream,
|
||||||
peer,
|
peer,
|
||||||
|
|
||||||
config,
|
config,
|
||||||
stats,
|
stats,
|
||||||
replay_checker,
|
replay_checker,
|
||||||
upstream_manager,
|
upstream_manager,
|
||||||
|
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
rng,
|
rng,
|
||||||
|
|
||||||
me_pool,
|
me_pool,
|
||||||
ip_tracker,
|
ip_tracker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RunningClientHandler {
|
|
||||||
pub async fn run(mut self) -> Result<()> {
|
|
||||||
self.stats.increment_connects_all();
|
|
||||||
|
|
||||||
let peer = self.peer;
|
impl RunningClientHandler {
|
||||||
debug!(peer = %peer, "New connection");
|
|
||||||
|
pub async fn run(mut self) -> Result<()> {
|
||||||
|
|
||||||
|
self.stats.increment_connects_all();
|
||||||
|
|
||||||
if let Err(e) = configure_client_socket(
|
if let Err(e) = configure_client_socket(
|
||||||
&self.stream,
|
&self.stream,
|
||||||
self.config.timeouts.client_keepalive,
|
self.config.timeouts.client_keepalive,
|
||||||
self.config.timeouts.client_ack,
|
self.config.timeouts.client_ack,
|
||||||
) {
|
) {
|
||||||
debug!(peer = %peer, error = %e, "Failed to configure client socket");
|
|
||||||
|
debug!(error = %e);
|
||||||
}
|
}
|
||||||
|
|
||||||
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
|
let timeout_dur =
|
||||||
let stats = self.stats.clone();
|
Duration::from_secs(
|
||||||
|
self.config.timeouts.client_handshake
|
||||||
|
);
|
||||||
|
|
||||||
let result = timeout(handshake_timeout, self.do_handshake()).await;
|
timeout(
|
||||||
|
timeout_dur,
|
||||||
match result {
|
self.do_handshake()
|
||||||
Ok(Ok(())) => {
|
).await?
|
||||||
debug!(peer = %peer, "Connection handled successfully");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Ok(Err(e)) => {
|
|
||||||
debug!(peer = %peer, error = %e, "Handshake failed");
|
|
||||||
Err(e)
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
stats.increment_handshake_timeouts();
|
|
||||||
debug!(peer = %peer, "Handshake timeout");
|
|
||||||
Err(ProxyError::TgHandshakeTimeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async fn do_handshake(mut self) -> Result<()> {
|
async fn do_handshake(mut self) -> Result<()> {
|
||||||
|
|
||||||
let mut first_bytes = [0u8; 5];
|
let mut first_bytes = [0u8; 5];
|
||||||
self.stream.read_exact(&mut first_bytes).await?;
|
|
||||||
|
|
||||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
self.stream.read_exact(
|
||||||
let peer = self.peer;
|
&mut first_bytes
|
||||||
|
).await?;
|
||||||
|
|
||||||
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
let is_tls =
|
||||||
|
tls::is_tls_handshake(
|
||||||
|
&first_bytes[..3]
|
||||||
|
);
|
||||||
|
|
||||||
if is_tls {
|
if is_tls {
|
||||||
self.handle_tls_client(first_bytes).await
|
|
||||||
|
self.handle_tls_client(
|
||||||
|
first_bytes
|
||||||
|
).await
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
self.handle_direct_client(first_bytes).await
|
|
||||||
|
self.handle_direct_client(
|
||||||
|
first_bytes
|
||||||
|
).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
|
|
||||||
let peer = self.peer;
|
|
||||||
|
|
||||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
async fn handle_tls_client(
|
||||||
|
mut self,
|
||||||
|
first_bytes: [u8; 5]
|
||||||
|
) -> Result<()> {
|
||||||
|
|
||||||
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
|
let tls_len =
|
||||||
|
u16::from_be_bytes(
|
||||||
|
[first_bytes[3],
|
||||||
|
first_bytes[4]]
|
||||||
|
) as usize;
|
||||||
|
|
||||||
if tls_len < 512 {
|
if tls_len < 512 {
|
||||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
|
||||||
self.stats.increment_connects_bad();
|
self.stats.increment_connects_bad();
|
||||||
let (reader, writer) = self.stream.into_split();
|
|
||||||
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
let (r,w) =
|
||||||
return Ok(());
|
self.stream.into_split();
|
||||||
|
|
||||||
|
handle_bad_client(
|
||||||
|
r,
|
||||||
|
w,
|
||||||
|
&first_bytes,
|
||||||
|
&self.config
|
||||||
|
).await;
|
||||||
|
|
||||||
|
return Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut handshake = vec![0u8; 5 + tls_len];
|
let mut handshake =
|
||||||
handshake[..5].copy_from_slice(&first_bytes);
|
vec![0u8;5+tls_len];
|
||||||
self.stream.read_exact(&mut handshake[5..]).await?;
|
|
||||||
|
|
||||||
let config = self.config.clone();
|
handshake[..5]
|
||||||
let replay_checker = self.replay_checker.clone();
|
.copy_from_slice(
|
||||||
let stats = self.stats.clone();
|
&first_bytes
|
||||||
let buffer_pool = self.buffer_pool.clone();
|
);
|
||||||
|
|
||||||
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
|
self.stream.read_exact(
|
||||||
let (read_half, write_half) = self.stream.into_split();
|
&mut handshake[5..]
|
||||||
|
).await?;
|
||||||
|
|
||||||
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
let local_addr =
|
||||||
&handshake,
|
self.stream.local_addr()?;
|
||||||
read_half,
|
|
||||||
write_half,
|
|
||||||
peer,
|
|
||||||
&config,
|
|
||||||
&replay_checker,
|
|
||||||
&self.rng,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
HandshakeResult::Success(result) => result,
|
|
||||||
HandshakeResult::BadClient { reader, writer } => {
|
|
||||||
stats.increment_connects_bad();
|
|
||||||
handle_bad_client(reader, writer, &handshake, &config).await;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
let (r,w) =
|
||||||
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
|
self.stream.into_split();
|
||||||
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..]
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
|
||||||
|
|
||||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
let (
|
||||||
&mtproto_handshake,
|
mut tls_reader,
|
||||||
tls_reader,
|
|
||||||
tls_writer,
|
tls_writer,
|
||||||
peer,
|
_
|
||||||
&config,
|
) =
|
||||||
&replay_checker,
|
match handle_tls_handshake(
|
||||||
true,
|
&handshake,
|
||||||
)
|
r,w,
|
||||||
.await
|
self.peer,
|
||||||
{
|
&self.config,
|
||||||
HandshakeResult::Success(result) => result,
|
&self.replay_checker,
|
||||||
HandshakeResult::BadClient {
|
&self.rng
|
||||||
reader: _,
|
).await {
|
||||||
writer: _,
|
|
||||||
} => {
|
HandshakeResult::Success(x)=>x,
|
||||||
stats.increment_connects_bad();
|
|
||||||
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
HandshakeResult::BadClient{reader,writer}=>{
|
||||||
return Ok(());
|
handle_bad_client(
|
||||||
}
|
reader,writer,
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
&handshake,
|
||||||
};
|
&self.config
|
||||||
|
).await;
|
||||||
|
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
HandshakeResult::Error(e)=>return Err(e)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mt =
|
||||||
|
tls_reader.read_exact(
|
||||||
|
HANDSHAKE_LEN
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
let mt:
|
||||||
|
[u8;HANDSHAKE_LEN] =
|
||||||
|
mt.try_into().unwrap();
|
||||||
|
|
||||||
|
let (
|
||||||
|
cr,
|
||||||
|
cw,
|
||||||
|
success
|
||||||
|
) =
|
||||||
|
match handle_mtproto_handshake(
|
||||||
|
&mt,
|
||||||
|
tls_reader,
|
||||||
|
tls_writer,
|
||||||
|
self.peer,
|
||||||
|
&self.config,
|
||||||
|
&self.replay_checker,
|
||||||
|
true
|
||||||
|
).await {
|
||||||
|
|
||||||
|
HandshakeResult::Success(x)=>x,
|
||||||
|
|
||||||
|
HandshakeResult::BadClient{..}=>{
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
HandshakeResult::Error(e)=>return Err(e)
|
||||||
|
};
|
||||||
|
|
||||||
Self::handle_authenticated_static(
|
Self::handle_authenticated_static(
|
||||||
crypto_reader,
|
cr,cw,success,
|
||||||
crypto_writer,
|
|
||||||
success,
|
|
||||||
self.upstream_manager,
|
self.upstream_manager,
|
||||||
self.stats,
|
self.stats,
|
||||||
self.config,
|
self.config,
|
||||||
buffer_pool,
|
self.buffer_pool,
|
||||||
self.rng,
|
self.rng,
|
||||||
self.me_pool,
|
self.me_pool,
|
||||||
local_addr,
|
local_addr,
|
||||||
peer,
|
self.peer,
|
||||||
self.ip_tracker,
|
self.ip_tracker
|
||||||
)
|
).await
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
|
|
||||||
let peer = self.peer;
|
|
||||||
|
|
||||||
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
async fn handle_direct_client(
|
||||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
mut self,
|
||||||
self.stats.increment_connects_bad();
|
first_bytes:[u8;5]
|
||||||
let (reader, writer) = self.stream.into_split();
|
)->Result<()>{
|
||||||
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut handshake = [0u8; HANDSHAKE_LEN];
|
let mut handshake=
|
||||||
handshake[..5].copy_from_slice(&first_bytes);
|
[0u8;HANDSHAKE_LEN];
|
||||||
self.stream.read_exact(&mut handshake[5..]).await?;
|
|
||||||
|
|
||||||
let config = self.config.clone();
|
handshake[..5]
|
||||||
let replay_checker = self.replay_checker.clone();
|
.copy_from_slice(
|
||||||
let stats = self.stats.clone();
|
&first_bytes
|
||||||
let buffer_pool = self.buffer_pool.clone();
|
);
|
||||||
|
|
||||||
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
|
self.stream.read_exact(
|
||||||
let (read_half, write_half) = self.stream.into_split();
|
&mut handshake[5..]
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
let local_addr=
|
||||||
|
self.stream.local_addr()?;
|
||||||
|
|
||||||
|
let (r,w)=
|
||||||
|
self.stream.into_split();
|
||||||
|
|
||||||
|
let (
|
||||||
|
cr,
|
||||||
|
cw,
|
||||||
|
success
|
||||||
|
)=
|
||||||
|
handle_mtproto_handshake(
|
||||||
|
&handshake,
|
||||||
|
r,w,
|
||||||
|
self.peer,
|
||||||
|
&self.config,
|
||||||
|
&self.replay_checker,
|
||||||
|
false
|
||||||
|
).await?
|
||||||
|
.into_success()?;
|
||||||
|
|
||||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
|
||||||
&handshake,
|
|
||||||
read_half,
|
|
||||||
write_half,
|
|
||||||
peer,
|
|
||||||
&config,
|
|
||||||
&replay_checker,
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
HandshakeResult::Success(result) => result,
|
|
||||||
HandshakeResult::BadClient { reader, writer } => {
|
|
||||||
stats.increment_connects_bad();
|
|
||||||
handle_bad_client(reader, writer, &handshake, &config).await;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
Self::handle_authenticated_static(
|
Self::handle_authenticated_static(
|
||||||
crypto_reader,
|
cr,cw,success,
|
||||||
crypto_writer,
|
|
||||||
success,
|
|
||||||
self.upstream_manager,
|
self.upstream_manager,
|
||||||
self.stats,
|
self.stats,
|
||||||
self.config,
|
self.config,
|
||||||
buffer_pool,
|
self.buffer_pool,
|
||||||
self.rng,
|
self.rng,
|
||||||
self.me_pool,
|
self.me_pool,
|
||||||
local_addr,
|
local_addr,
|
||||||
peer,
|
self.peer,
|
||||||
self.ip_tracker,
|
self.ip_tracker
|
||||||
)
|
).await
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Main dispatch after successful handshake.
|
|
||||||
/// Two modes:
|
pub(crate)
|
||||||
/// - Direct: TCP relay to TG DC (existing behavior)
|
async fn handle_authenticated_static<R,W>(
|
||||||
/// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs)
|
client_reader:CryptoReader<R>,
|
||||||
pub(crate) async fn handle_authenticated_static<R, W>(
|
client_writer:CryptoWriter<W>,
|
||||||
client_reader: CryptoReader<R>,
|
success:HandshakeSuccess,
|
||||||
client_writer: CryptoWriter<W>,
|
|
||||||
success: HandshakeSuccess,
|
upstream_manager:Arc<UpstreamManager>,
|
||||||
upstream_manager: Arc<UpstreamManager>,
|
stats:Arc<Stats>,
|
||||||
stats: Arc<Stats>,
|
config:Arc<ProxyConfig>,
|
||||||
config: Arc<ProxyConfig>,
|
|
||||||
buffer_pool: Arc<BufferPool>,
|
buffer_pool:Arc<BufferPool>,
|
||||||
rng: Arc<SecureRandom>,
|
rng:Arc<SecureRandom>,
|
||||||
me_pool: Option<Arc<MePool>>,
|
me_pool:Option<Arc<MePool>>,
|
||||||
local_addr: SocketAddr,
|
|
||||||
peer_addr: SocketAddr,
|
local_addr:SocketAddr,
|
||||||
ip_tracker: Arc<UserIpTracker>,
|
peer_addr:SocketAddr,
|
||||||
) -> Result<()>
|
|
||||||
|
ip_tracker:Arc<UserIpTracker>,
|
||||||
|
)->Result<()>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R:AsyncRead+Unpin+Send+'static,
|
||||||
W: AsyncWrite + Unpin + Send + 'static,
|
W:AsyncWrite+Unpin+Send+'static,
|
||||||
{
|
{
|
||||||
let user = &success.user;
|
|
||||||
|
|
||||||
if let Err(e) = Self::check_user_limits_static(user, &config, &stats, peer_addr, &ip_tracker).await {
|
let user=&success.user;
|
||||||
warn!(user = %user, error = %e, "User limit exceeded");
|
|
||||||
return Err(e);
|
ip_tracker.check_and_add(
|
||||||
}
|
user,
|
||||||
|
peer_addr.ip()
|
||||||
|
).await?;
|
||||||
|
|
||||||
// IP Cleanup Guard: автоматически удаляет IP при выходе из scope
|
|
||||||
struct IpCleanupGuard {
|
|
||||||
tracker: Arc<UserIpTracker>,
|
|
||||||
user: String,
|
|
||||||
ip: std::net::IpAddr,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for IpCleanupGuard {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
let tracker = self.tracker.clone();
|
|
||||||
let user = self.user.clone();
|
|
||||||
let ip = self.ip;
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tracker.remove_ip(&user, ip).await;
|
|
||||||
debug!(user = %user, ip = %ip, "IP cleaned up on disconnect");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let _cleanup = IpCleanupGuard {
|
|
||||||
tracker: ip_tracker,
|
|
||||||
user: user.clone(),
|
|
||||||
ip: peer_addr.ip(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Decide: middle proxy or direct
|
|
||||||
if config.general.use_middle_proxy {
|
if config.general.use_middle_proxy {
|
||||||
if let Some(ref pool) = me_pool {
|
|
||||||
|
if let Some(pool)=me_pool{
|
||||||
|
|
||||||
return handle_via_middle_proxy(
|
return handle_via_middle_proxy(
|
||||||
client_reader,
|
client_reader,
|
||||||
client_writer,
|
client_writer,
|
||||||
success,
|
success,
|
||||||
pool.clone(),
|
pool,
|
||||||
stats,
|
stats,
|
||||||
config,
|
config,
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
local_addr,
|
local_addr
|
||||||
)
|
).await
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
warn!("use_middle_proxy=true but MePool not initialized, falling back to direct");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Direct mode (original behavior)
|
|
||||||
handle_via_direct(
|
handle_via_direct(
|
||||||
client_reader,
|
client_reader,
|
||||||
client_writer,
|
client_writer,
|
||||||
|
|
@ -497,55 +667,8 @@ impl RunningClientHandler {
|
||||||
stats,
|
stats,
|
||||||
config,
|
config,
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
rng,
|
rng
|
||||||
)
|
).await
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn check_user_limits_static(
|
|
||||||
user: &str,
|
|
||||||
config: &ProxyConfig,
|
|
||||||
stats: &Stats,
|
|
||||||
peer_addr: SocketAddr,
|
|
||||||
ip_tracker: &UserIpTracker,
|
|
||||||
) -> Result<()> {
|
|
||||||
if let Some(expiration) = config.access.user_expirations.get(user) {
|
|
||||||
if chrono::Utc::now() > *expiration {
|
|
||||||
return Err(ProxyError::UserExpired {
|
|
||||||
user: user.to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IP limit check
|
|
||||||
if let Err(reason) = ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
|
||||||
warn!(
|
|
||||||
user = %user,
|
|
||||||
ip = %peer_addr.ip(),
|
|
||||||
reason = %reason,
|
|
||||||
"IP limit exceeded"
|
|
||||||
);
|
|
||||||
return Err(ProxyError::ConnectionLimitExceeded {
|
|
||||||
user: user.to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
|
|
||||||
if stats.get_user_curr_connects(user) >= *limit as u64 {
|
|
||||||
return Err(ProxyError::ConnectionLimitExceeded {
|
|
||||||
user: user.to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(quota) = config.access.user_data_quota.get(user) {
|
|
||||||
if stats.get_user_total_octets(user) >= *quota {
|
|
||||||
return Err(ProxyError::DataQuotaExceeded {
|
|
||||||
user: user.to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue