Phase 2 implemented with additional guards

This commit is contained in:
David Osipov
2026-04-03 02:08:59 +04:00
parent a9f695623d
commit 6ea867ce36
27 changed files with 2513 additions and 1131 deletions

View File

@@ -81,10 +81,15 @@ use crate::transport::socket::normalize_ip;
use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol};
use crate::proxy::direct_relay::handle_via_direct;
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
use crate::proxy::handshake::{
HandshakeSuccess, handle_mtproto_handshake_with_shared, handle_tls_handshake_with_shared,
};
#[cfg(test)]
use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake};
use crate::proxy::masking::handle_bad_client;
use crate::proxy::middle_relay::handle_via_middle_proxy;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::proxy::shared_state::ProxySharedState;
fn beobachten_ttl(config: &ProxyConfig) -> Duration {
const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60;
@@ -342,7 +347,48 @@ fn synthetic_local_addr(port: u16) -> SocketAddr {
SocketAddr::from(([0, 0, 0, 0], port))
}
#[cfg(test)]
pub async fn handle_client_stream<S>(
stream: S,
peer: SocketAddr,
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
proxy_protocol_enabled: bool,
) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
handle_client_stream_with_shared(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
me_pool,
route_runtime,
tls_cache,
ip_tracker,
beobachten,
ProxySharedState::new(),
proxy_protocol_enabled,
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn handle_client_stream_with_shared<S>(
mut stream: S,
peer: SocketAddr,
config: Arc<ProxyConfig>,
@@ -356,6 +402,7 @@ pub async fn handle_client_stream<S>(
tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
shared: Arc<ProxySharedState>,
proxy_protocol_enabled: bool,
) -> Result<()>
where
@@ -550,9 +597,10 @@ where
let (read_half, write_half) = tokio::io::split(stream);
let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake(
let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake_with_shared(
&handshake, read_half, write_half, real_peer,
&config, &replay_checker, &rng, tls_cache.clone(),
shared.as_ref(),
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
@@ -578,9 +626,10 @@ where
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake_with_shared(
&mtproto_handshake, tls_reader, tls_writer, real_peer,
&config, &replay_checker, true, Some(tls_user.as_str()),
shared.as_ref(),
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
@@ -614,11 +663,12 @@ where
};
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
RunningClientHandler::handle_authenticated_static(
RunningClientHandler::handle_authenticated_static_with_shared(
crypto_reader, crypto_writer, success,
upstream_manager, stats, config, buffer_pool, rng, me_pool,
route_runtime.clone(),
local_addr, real_peer, ip_tracker.clone(),
shared.clone(),
),
)))
} else {
@@ -644,9 +694,10 @@ where
let (read_half, write_half) = tokio::io::split(stream);
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake_with_shared(
&handshake, read_half, write_half, real_peer,
&config, &replay_checker, false, None,
shared.as_ref(),
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
@@ -665,7 +716,7 @@ where
};
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
RunningClientHandler::handle_authenticated_static(
RunningClientHandler::handle_authenticated_static_with_shared(
crypto_reader,
crypto_writer,
success,
@@ -679,6 +730,7 @@ where
local_addr,
real_peer,
ip_tracker.clone(),
shared.clone(),
)
)))
}
@@ -731,10 +783,12 @@ pub struct RunningClientHandler {
tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
shared: Arc<ProxySharedState>,
proxy_protocol_enabled: bool,
}
impl ClientHandler {
#[cfg(test)]
pub fn new(
stream: TcpStream,
peer: SocketAddr,
@@ -751,6 +805,45 @@ impl ClientHandler {
beobachten: Arc<BeobachtenStore>,
proxy_protocol_enabled: bool,
real_peer_report: Arc<std::sync::Mutex<Option<SocketAddr>>>,
) -> RunningClientHandler {
Self::new_with_shared(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
me_pool,
route_runtime,
tls_cache,
ip_tracker,
beobachten,
ProxySharedState::new(),
proxy_protocol_enabled,
real_peer_report,
)
}
#[allow(clippy::too_many_arguments)]
pub fn new_with_shared(
stream: TcpStream,
peer: SocketAddr,
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
shared: Arc<ProxySharedState>,
proxy_protocol_enabled: bool,
real_peer_report: Arc<std::sync::Mutex<Option<SocketAddr>>>,
) -> RunningClientHandler {
let normalized_peer = normalize_ip(peer);
RunningClientHandler {
@@ -769,6 +862,7 @@ impl ClientHandler {
tls_cache,
ip_tracker,
beobachten,
shared,
proxy_protocol_enabled,
}
}
@@ -1058,7 +1152,7 @@ impl RunningClientHandler {
let (read_half, write_half) = self.stream.into_split();
let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake(
let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake_with_shared(
&handshake,
read_half,
write_half,
@@ -1067,6 +1161,7 @@ impl RunningClientHandler {
&replay_checker,
&self.rng,
self.tls_cache.clone(),
self.shared.as_ref(),
)
.await
{
@@ -1095,7 +1190,7 @@ impl RunningClientHandler {
.try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake_with_shared(
&mtproto_handshake,
tls_reader,
tls_writer,
@@ -1104,6 +1199,7 @@ impl RunningClientHandler {
&replay_checker,
true,
Some(tls_user.as_str()),
self.shared.as_ref(),
)
.await
{
@@ -1140,7 +1236,7 @@ impl RunningClientHandler {
};
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
Self::handle_authenticated_static(
Self::handle_authenticated_static_with_shared(
crypto_reader,
crypto_writer,
success,
@@ -1154,6 +1250,7 @@ impl RunningClientHandler {
local_addr,
peer,
self.ip_tracker,
self.shared,
),
)))
}
@@ -1192,7 +1289,7 @@ impl RunningClientHandler {
let (read_half, write_half) = self.stream.into_split();
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake_with_shared(
&handshake,
read_half,
write_half,
@@ -1201,6 +1298,7 @@ impl RunningClientHandler {
&replay_checker,
false,
None,
self.shared.as_ref(),
)
.await
{
@@ -1221,7 +1319,7 @@ impl RunningClientHandler {
};
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
Self::handle_authenticated_static(
Self::handle_authenticated_static_with_shared(
crypto_reader,
crypto_writer,
success,
@@ -1235,6 +1333,7 @@ impl RunningClientHandler {
local_addr,
peer,
self.ip_tracker,
self.shared,
),
)))
}
@@ -1243,6 +1342,7 @@ impl RunningClientHandler {
/// Two modes:
/// - Direct: TCP relay to TG DC (existing behavior)
/// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs)
#[cfg(test)]
async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
@@ -1258,6 +1358,45 @@ impl RunningClientHandler {
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
Self::handle_authenticated_static_with_shared(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
me_pool,
route_runtime,
local_addr,
peer_addr,
ip_tracker,
ProxySharedState::new(),
)
.await
}
async fn handle_authenticated_static_with_shared<R, W>(
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
route_runtime: Arc<RouteRuntimeController>,
local_addr: SocketAddr,
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
_shared: Arc<ProxySharedState>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
@@ -1299,6 +1438,7 @@ impl RunningClientHandler {
route_runtime.subscribe(),
route_snapshot,
session_id,
_shared,
)
.await
} else {