Implement user connection reservation management and enhance relay task handling in proxy

Author: David Osipov — 2026-03-17 19:05:26 +04:00
Parent commit: 4808a30185
This commit: c540a6657f
Signature: GPG Key ID 0E55C4A47454E82E (no known key found for this signature in database)
5 changed files with 714 additions and 10 deletions

View File

@ -24,6 +24,39 @@ enum HandshakeOutcome {
Handled,
}
/// RAII guard for one accounted user connection.
///
/// Created only after both the per-user connection slot (in `Stats`) and the
/// per-user IP entry (in `UserIpTracker`) have been successfully reserved;
/// its `Drop` impl releases both, so the reservation cannot leak even when
/// the owning task is aborted or unwinds.
struct UserConnectionReservation {
    // Gauge holder; `decrement_user_curr_connects` is called on drop.
    stats: Arc<Stats>,
    // Tracker holding this connection's IP footprint for `user`.
    ip_tracker: Arc<UserIpTracker>,
    // Authenticated user this reservation was acquired for.
    user: String,
    // Client IP that was added to `ip_tracker` for `user`.
    ip: IpAddr,
}
impl UserConnectionReservation {
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
Self {
stats,
ip_tracker,
user,
ip,
}
}
}
impl Drop for UserConnectionReservation {
    /// Releases the reservation: the connection-slot decrement happens
    /// synchronously, while the async IP removal is spawned onto the current
    /// Tokio runtime (Drop itself cannot `.await`).
    fn drop(&mut self) {
        self.stats.decrement_user_curr_connects(&self.user);
        // NOTE(review): if the guard is ever dropped outside a Tokio runtime,
        // `try_current` fails and the IP entry is NOT removed — confirm that
        // drops only happen on runtime threads in practice.
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            // Clone out of `&mut self` so the spawned future is 'static.
            let ip_tracker = self.ip_tracker.clone();
            let user = self.user.clone();
            let ip = self.ip;
            handle.spawn(async move {
                ip_tracker.remove_ip(&user, ip).await;
            });
        }
    }
}
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
@ -90,6 +123,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
trusted.iter().any(|cidr| cidr.contains(peer_ip))
}
/// Builds the placeholder local address `0.0.0.0:<port>` used when a stream
/// has no real socket-level local address (e.g. non-TCP transports).
fn synthetic_local_addr(port: u16) -> SocketAddr {
    // Wildcard IPv4 address paired with the configured server port.
    let wildcard_v4 = [0u8, 0, 0, 0];
    SocketAddr::from((wildcard_v4, port))
}
pub async fn handle_client_stream<S>(
mut stream: S,
peer: SocketAddr,
@ -113,9 +150,7 @@ where
let mut real_peer = normalize_ip(peer);
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
let mut local_addr = synthetic_local_addr(config.server.port);
if proxy_protocol_enabled {
let proxy_header_timeout = Duration::from_millis(
@ -798,10 +833,22 @@ impl RunningClientHandler {
{
let user = success.user.clone();
if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await {
warn!(user = %user, error = %e, "User limit exceeded");
return Err(e);
}
let _user_limit_reservation =
match Self::acquire_user_connection_reservation_static(
&user,
&config,
stats.clone(),
peer_addr,
ip_tracker,
)
.await
{
Ok(reservation) => reservation,
Err(e) => {
warn!(user = %user, error = %e, "User limit exceeded");
return Err(e);
}
};
let route_snapshot = route_runtime.snapshot();
let session_id = rng.u64();
@ -858,12 +905,64 @@ impl RunningClientHandler {
)
.await
};
stats.decrement_user_curr_connects(&user);
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
relay_result
}
    /// Reserves a per-user connection slot and IP entry, returning a guard
    /// that releases both on drop.
    ///
    /// Checks run in order: account expiration, data quota, concurrent
    /// connection limit, then per-user IP limit. The connection slot is held
    /// from the moment `try_acquire_user_curr_connects` succeeds; if the
    /// subsequent IP check fails, the slot is rolled back before the error
    /// is returned, so no partial reservation can leak.
    async fn acquire_user_connection_reservation_static(
        user: &str,
        config: &ProxyConfig,
        stats: Arc<Stats>,
        peer_addr: SocketAddr,
        ip_tracker: Arc<UserIpTracker>,
    ) -> Result<UserConnectionReservation> {
        // Reject accounts whose configured expiration timestamp has passed.
        if let Some(expiration) = config.access.user_expirations.get(user)
            && chrono::Utc::now() > *expiration
        {
            return Err(ProxyError::UserExpired {
                user: user.to_string(),
            });
        }
        // Reject users that have already consumed their configured data quota.
        if let Some(quota) = config.access.user_data_quota.get(user)
            && stats.get_user_total_octets(user) >= *quota
        {
            return Err(ProxyError::DataQuotaExceeded {
                user: user.to_string(),
            });
        }
        // `None` means unlimited; otherwise the configured max TCP connections.
        let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
        if !stats.try_acquire_user_curr_connects(user, limit) {
            return Err(ProxyError::ConnectionLimitExceeded {
                user: user.to_string(),
            });
        }
        // From here on we hold a connection slot and must release it on any
        // failure path before returning.
        match ip_tracker.check_and_add(user, peer_addr.ip()).await {
            Ok(()) => {}
            Err(reason) => {
                // Roll back the slot acquired above before surfacing the error.
                stats.decrement_user_curr_connects(user);
                warn!(
                    user = %user,
                    ip = %peer_addr.ip(),
                    reason = %reason,
                    "IP limit exceeded"
                );
                return Err(ProxyError::ConnectionLimitExceeded {
                    user: user.to_string(),
                });
            }
        }
        // Both resources are reserved; ownership moves into the drop guard.
        Ok(UserConnectionReservation::new(
            stats,
            ip_tracker,
            user.to_string(),
            peer_addr.ip(),
        ))
    }
#[cfg(test)]
async fn check_user_limits_static(
user: &str,
config: &ProxyConfig,

View File

@ -1,11 +1,279 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::AesCtr;
use crate::crypto::sha256_hmac;
use crate::protocol::constants::ProtoTag;
use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess;
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use crate::stream::{CryptoReader, CryptoWriter};
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
#[test]
fn synthetic_local_addr_uses_configured_port_for_zero() {
    // Port 0 must pass through unchanged on the wildcard IPv4 address.
    let expected: SocketAddr = SocketAddr::from(([0, 0, 0, 0], 0));
    assert_eq!(synthetic_local_addr(0), expected);
}
#[test]
fn synthetic_local_addr_uses_configured_port_for_max() {
    // The largest representable port must also be preserved verbatim.
    let expected: SocketAddr = SocketAddr::from(([0, 0, 0, 0], u16::MAX));
    assert_eq!(synthetic_local_addr(u16::MAX), expected);
}
/// Wraps `reader` in a `CryptoReader` keyed with an all-zero AES-CTR key and
/// IV — a deterministic cipher suitable only for tests.
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where
    R: tokio::io::AsyncRead + Unpin,
{
    let zero_key = [0u8; 32];
    CryptoReader::new(reader, AesCtr::new(&zero_key, 0u128))
}
/// Wraps `writer` in a `CryptoWriter` keyed with an all-zero AES-CTR key and
/// IV and an 8 KiB write buffer — a deterministic setup for tests only.
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    // 8 KiB buffer matches the reader-side test fixtures.
    const WRITE_BUF_BYTES: usize = 8 * 1024;
    let zero_key = [0u8; 32];
    CryptoWriter::new(writer, AesCtr::new(&zero_key, 0u128), WRITE_BUF_BYTES)
}
/// Aborting the relay task mid-flight must release both the per-user
/// connection slot and the per-user IP reservation — this exercises the
/// `UserConnectionReservation` drop path under `JoinHandle::abort`.
#[tokio::test]
async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
    // Stand-in Telegram upstream: accept one connection and hold it open so
    // the relay stays mid-flight until we abort it.
    let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let tg_addr = tg_listener.local_addr().unwrap();
    let tg_accept_task = tokio::spawn(async move {
        let (stream, _) = tg_listener.accept().await.unwrap();
        let _hold_stream = stream;
        tokio::time::sleep(Duration::from_secs(60)).await;
    });
    let user = "abort-user";
    let peer_addr: SocketAddr = "198.51.100.230:50000".parse().unwrap();
    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());
    let mut cfg = ProxyConfig::default();
    // A generous limit so the single test connection is always admitted.
    cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
    // Route DC 2 (the handshake's dc_idx) to our local stand-in listener.
    cfg.dc_overrides
        .insert("2".to_string(), vec![tg_addr.to_string()]);
    let config = Arc::new(cfg);
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    // In-memory duplex pipe plays the client socket; split so the relay gets
    // independent read/write halves.
    let (server_side, client_side) = duplex(64 * 1024);
    let (server_reader, server_writer) = tokio::io::split(server_side);
    let client_reader = make_crypto_reader(server_reader);
    let client_writer = make_crypto_writer(server_writer);
    // Pre-authenticated handshake result: skips the wire handshake entirely.
    let success = HandshakeSuccess {
        user: user.to_string(),
        dc_idx: 2,
        proto_tag: ProtoTag::Intermediate,
        dec_key: [0u8; 32],
        dec_iv: 0,
        enc_key: [0u8; 32],
        enc_iv: 0,
        peer: peer_addr,
        is_tls: false,
    };
    let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static(
        client_reader,
        client_writer,
        success,
        upstream_manager,
        stats.clone(),
        config,
        buffer_pool,
        rng,
        None,
        route_runtime,
        "127.0.0.1:443".parse().unwrap(),
        peer_addr,
        ip_tracker.clone(),
    ));
    // Wait until the relay has actually taken both reservations; aborting
    // earlier would not exercise the release path.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_user_curr_connects(user) == 1
                && ip_tracker.get_active_ip_count(user).await == 1
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("relay must reserve user slot and IP before abort");
    relay_task.abort();
    let joined = relay_task.await;
    assert!(joined.is_err(), "aborted relay task must return join error");
    // The IP removal is spawned from Drop onto the runtime; give that
    // spawned task a moment to run before asserting.
    tokio::time::sleep(Duration::from_millis(50)).await;
    assert_eq!(
        stats.get_user_curr_connects(user),
        0,
        "task abort must release user current-connection slot"
    );
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        0,
        "task abort must release reserved user IP footprint"
    );
    drop(client_side);
    tg_accept_task.abort();
    let _ = tg_accept_task.await;
}
/// A route cutover (Direct -> Middle) must terminate an in-flight direct
/// relay session and, on that orderly exit, release both the per-user
/// connection slot and the per-user IP reservation.
#[tokio::test]
async fn relay_cutover_releases_user_gate_and_ip_reservation() {
    // Stand-in Telegram upstream: accept one connection and hold it open so
    // the relay stays mid-flight until the cutover forces it out.
    let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let tg_addr = tg_listener.local_addr().unwrap();
    let tg_accept_task = tokio::spawn(async move {
        let (stream, _) = tg_listener.accept().await.unwrap();
        let _hold_stream = stream;
        tokio::time::sleep(Duration::from_secs(60)).await;
    });
    let user = "cutover-user";
    let peer_addr: SocketAddr = "198.51.100.231:50001".parse().unwrap();
    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());
    let mut cfg = ProxyConfig::default();
    cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
    // Route DC 2 (the handshake's dc_idx) to our local stand-in listener.
    cfg.dc_overrides
        .insert("2".to_string(), vec![tg_addr.to_string()]);
    let config = Arc::new(cfg);
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    // In-memory duplex pipe plays the client socket.
    let (server_side, client_side) = duplex(64 * 1024);
    let (server_reader, server_writer) = tokio::io::split(server_side);
    let client_reader = make_crypto_reader(server_reader);
    let client_writer = make_crypto_writer(server_writer);
    // Pre-authenticated handshake result: skips the wire handshake entirely.
    let success = HandshakeSuccess {
        user: user.to_string(),
        dc_idx: 2,
        proto_tag: ProtoTag::Intermediate,
        dec_key: [0u8; 32],
        dec_iv: 0,
        enc_key: [0u8; 32],
        enc_iv: 0,
        peer: peer_addr,
        is_tls: false,
    };
    let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static(
        client_reader,
        client_writer,
        success,
        upstream_manager,
        stats.clone(),
        config,
        buffer_pool,
        rng,
        None,
        route_runtime.clone(),
        "127.0.0.1:443".parse().unwrap(),
        peer_addr,
        ip_tracker.clone(),
    ));
    // Wait until the relay has taken both reservations before cutting over.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_user_curr_connects(user) == 1
                && ip_tracker.get_active_ip_count(user).await == 1
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("relay must reserve user slot and IP before cutover");
    // Flip the route mode; `Some` confirms the generation actually advanced.
    assert!(
        route_runtime.set_mode(RelayRouteMode::Middle).is_some(),
        "cutover must advance route generation"
    );
    // The relay should observe the cutover and exit with an error well
    // within the 6-second bound.
    let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task)
        .await
        .expect("relay must terminate after cutover")
        .expect("relay task must not panic");
    assert!(relay_result.is_err(), "cutover must terminate direct relay session");
    assert_eq!(
        stats.get_user_curr_connects(user),
        0,
        "cutover exit must release user current-connection slot"
    );
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        0,
        "cutover exit must release reserved user IP footprint"
    );
    drop(client_side);
    tg_accept_task.abort();
    let _ = tg_accept_task.await;
}
#[tokio::test]
async fn short_tls_probe_is_masked_through_client_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

View File

@ -178,3 +178,112 @@ async fn direct_relay_abort_midflight_releases_route_gauge() {
tg_accept_task.abort();
let _ = tg_accept_task.await;
}
/// A route cutover (Direct -> Middle) must terminate an in-flight
/// `handle_via_direct` session and release the direct-route connection gauge
/// on exit.
#[tokio::test]
async fn direct_relay_cutover_midflight_releases_route_gauge() {
    // Stand-in Telegram upstream: accept one connection and hold it open so
    // the relay stays mid-flight until the cutover forces it out.
    let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let tg_addr = tg_listener.local_addr().unwrap();
    let tg_accept_task = tokio::spawn(async move {
        let (stream, _) = tg_listener.accept().await.unwrap();
        let _hold_stream = stream;
        tokio::time::sleep(Duration::from_secs(60)).await;
    });
    let stats = Arc::new(Stats::new());
    let mut config = ProxyConfig::default();
    // Route DC 2 (the handshake's dc_idx) to our local stand-in listener.
    config
        .dc_overrides
        .insert("2".to_string(), vec![tg_addr.to_string()]);
    let config = Arc::new(config);
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let rng = Arc::new(SecureRandom::new());
    let buffer_pool = Arc::new(BufferPool::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    // Snapshot taken before spawning so the relay starts on the Direct route.
    let route_snapshot = route_runtime.snapshot();
    // In-memory duplex pipe plays the client socket.
    let (server_side, client_side) = duplex(64 * 1024);
    let (server_reader, server_writer) = tokio::io::split(server_side);
    let client_reader = make_crypto_reader(server_reader);
    let client_writer = make_crypto_writer(server_writer);
    // Pre-authenticated handshake result: skips the wire handshake entirely.
    let success = HandshakeSuccess {
        user: "cutover-direct-user".to_string(),
        dc_idx: 2,
        proto_tag: ProtoTag::Intermediate,
        dec_key: [0u8; 32],
        dec_iv: 0,
        enc_key: [0u8; 32],
        enc_iv: 0,
        peer: "127.0.0.1:50002".parse().unwrap(),
        is_tls: false,
    };
    let relay_task = tokio::spawn(handle_via_direct(
        client_reader,
        client_writer,
        success,
        upstream_manager,
        stats.clone(),
        config,
        buffer_pool,
        rng,
        route_runtime.subscribe(),
        route_snapshot,
        0xface_cafe,
    ));
    // Wait until the direct-route gauge reflects the in-flight session.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_current_connections_direct() == 1 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("direct relay must increment route gauge before cutover");
    // Flip the route mode; `Some` confirms the generation actually advanced.
    assert!(
        route_runtime.set_mode(RelayRouteMode::Middle).is_some(),
        "cutover must advance route generation"
    );
    let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task)
        .await
        .expect("direct relay must terminate after cutover")
        .expect("direct relay task must not panic");
    assert!(
        relay_result.is_err(),
        "cutover should terminate direct relay session"
    );
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "route gauge must be released when direct relay exits on cutover"
    );
    drop(client_side);
    tg_accept_task.abort();
    let _ = tg_accept_task.await;
}

View File

@ -942,3 +942,80 @@ async fn middle_relay_abort_midflight_releases_route_gauge() {
drop(client_side);
}
/// A route cutover (Middle -> Direct) must terminate an in-flight
/// `handle_via_middle_proxy` session and release the middle-route connection
/// gauge on exit.
#[tokio::test]
async fn middle_relay_cutover_midflight_releases_route_gauge() {
    let stats = Arc::new(Stats::new());
    // Shared helper builds a middle-end pool wired to a stalled stand-in, as
    // used by the abort-path test above.
    let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
    let config = Arc::new(ProxyConfig::default());
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
    // Snapshot taken before spawning so the relay starts on the Middle route.
    let route_snapshot = route_runtime.snapshot();
    // In-memory duplex pipe plays the client socket.
    let (server_side, client_side) = duplex(64 * 1024);
    let (server_reader, server_writer) = tokio::io::split(server_side);
    let crypto_reader = make_crypto_reader(server_reader);
    let crypto_writer = make_crypto_writer(server_writer);
    // Pre-authenticated handshake result: skips the wire handshake entirely.
    let success = HandshakeSuccess {
        user: "cutover-middle-user".to_string(),
        dc_idx: 2,
        proto_tag: ProtoTag::Intermediate,
        dec_key: [0u8; 32],
        dec_iv: 0,
        enc_key: [0u8; 32],
        enc_iv: 0,
        peer: "127.0.0.1:50003".parse().unwrap(),
        is_tls: false,
    };
    let relay_task = tokio::spawn(handle_via_middle_proxy(
        crypto_reader,
        crypto_writer,
        success,
        me_pool,
        stats.clone(),
        config,
        buffer_pool,
        "127.0.0.1:443".parse().unwrap(),
        rng,
        route_runtime.subscribe(),
        route_snapshot,
        0xfeed_beef,
    ));
    // Wait until the middle-route gauge reflects the in-flight session.
    tokio::time::timeout(TokioDuration::from_secs(2), async {
        loop {
            if stats.get_current_connections_me() == 1 {
                break;
            }
            tokio::time::sleep(TokioDuration::from_millis(10)).await;
        }
    })
    .await
    .expect("middle relay must increment route gauge before cutover");
    // Flip the route mode; `Some` confirms the generation actually advanced.
    assert!(
        route_runtime.set_mode(RelayRouteMode::Direct).is_some(),
        "cutover must advance route generation"
    );
    let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task)
        .await
        .expect("middle relay must terminate after cutover")
        .expect("middle relay task must not panic");
    assert!(
        relay_result.is_err(),
        "cutover should terminate middle relay session"
    );
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "route gauge must be released when middle relay exits on cutover"
    );
    drop(client_side);
}

View File

@ -2,6 +2,7 @@ use super::*;
use std::panic::{self, AssertUnwindSafe};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Barrier;
#[test]
fn direct_connection_lease_balances_on_drop() {
@ -63,6 +64,156 @@ fn direct_connection_lease_balances_on_panic_unwind() {
);
}
/// A panic that unwinds through a held middle-route lease must still drop the
/// lease and return the gauge to zero.
#[test]
fn middle_connection_lease_balances_on_panic_unwind() {
    let stats = Arc::new(Stats::new());
    let stats_in_closure = stats.clone();
    let outcome = panic::catch_unwind(AssertUnwindSafe(move || {
        // Lease is dropped by unwinding, not by normal scope exit.
        let _lease = stats_in_closure.acquire_me_connection_lease();
        panic!("intentional panic to verify middle lease drop path");
    }));
    assert!(outcome.is_err(), "panic must propagate from test closure");
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "panic unwind must release middle route gauge"
    );
}
/// Churns direct and middle route leases from many concurrent tasks and
/// verifies both gauges balance back to zero once every lease has dropped.
#[tokio::test]
async fn concurrent_mixed_route_lease_churn_balances_to_zero() {
    const TASKS: usize = 48;
    const ITERATIONS_PER_TASK: usize = 256;
    let stats = Arc::new(Stats::new());
    // Barrier releases all workers at once to maximize contention.
    let barrier = Arc::new(Barrier::new(TASKS));
    let mut workers = Vec::with_capacity(TASKS);
    for task_idx in 0..TASKS {
        let stats_for_task = stats.clone();
        let barrier_for_task = barrier.clone();
        workers.push(tokio::spawn(async move {
            barrier_for_task.wait().await;
            for iter in 0..ITERATIONS_PER_TASK {
                // Alternate route kinds per task and per iteration so both
                // gauges see interleaved acquire/release traffic.
                if (task_idx + iter) % 2 == 0 {
                    let _lease = stats_for_task.acquire_direct_connection_lease();
                    // Yield while holding the lease to force interleaving.
                    tokio::task::yield_now().await;
                } else {
                    let _lease = stats_for_task.acquire_me_connection_lease();
                    tokio::task::yield_now().await;
                }
            }
        }));
    }
    for worker in workers {
        worker
            .await
            .expect("lease churn worker must not panic");
    }
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "direct route gauge must return to zero after concurrent lease churn"
    );
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "middle route gauge must return to zero after concurrent lease churn"
    );
}
/// Aborting a storm of tasks that each hold a route lease must still drain
/// both route gauges to zero via the leases' drop paths.
#[tokio::test]
async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() {
    const TASKS: usize = 64;
    let stats = Arc::new(Stats::new());
    let mut workers = Vec::with_capacity(TASKS);
    for task_idx in 0..TASKS {
        let stats_for_task = stats.clone();
        workers.push(tokio::spawn(async move {
            // Each worker parks while holding its lease; the long sleep keeps
            // the lease alive until the task is aborted.
            if task_idx % 2 == 0 {
                let _lease = stats_for_task.acquire_direct_connection_lease();
                tokio::time::sleep(Duration::from_secs(60)).await;
            } else {
                let _lease = stats_for_task.acquire_me_connection_lease();
                tokio::time::sleep(Duration::from_secs(60)).await;
            }
        }));
    }
    // Wait until every worker actually holds a lease; aborting earlier would
    // not exercise the release-on-abort path.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            let total = stats.get_current_connections_direct() + stats.get_current_connections_me();
            if total == TASKS as u64 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("all storm tasks must acquire route leases before abort");
    for worker in &workers {
        worker.abort();
    }
    for worker in workers {
        let joined = worker.await;
        assert!(joined.is_err(), "aborted worker must return join error");
    }
    // Drops may lag the aborts slightly; poll until both gauges drain.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("all route gauges must drain to zero after abort storm");
}
/// Racing decrements against already-zero gauges from many OS threads must
/// saturate at zero rather than underflow.
#[test]
fn saturating_route_decrements_do_not_underflow_under_race() {
    const THREADS: usize = 16;
    const DECREMENTS_PER_THREAD: usize = 4096;
    let stats = Arc::new(Stats::new());
    // Spawn all decrement threads up front, then join them.
    let handles: Vec<_> = (0..THREADS)
        .map(|_| {
            let stats_handle = stats.clone();
            std::thread::spawn(move || {
                for _ in 0..DECREMENTS_PER_THREAD {
                    stats_handle.decrement_current_connections_direct();
                    stats_handle.decrement_current_connections_me();
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("decrement race worker must not panic");
    }
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "direct route decrement races must never underflow"
    );
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "middle route decrement races must never underflow"
    );
}
#[tokio::test]
async fn direct_connection_lease_balances_on_task_abort() {
let stats = Arc::new(Stats::new());