mirror of https://github.com/telemt/telemt.git
Implement user connection reservation management and enhance relay task handling in proxy
This commit is contained in:
parent
4808a30185
commit
c540a6657f
|
|
@ -24,6 +24,39 @@ enum HandshakeOutcome {
|
|||
Handled,
|
||||
}
|
||||
|
||||
struct UserConnectionReservation {
|
||||
stats: Arc<Stats>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
user: String,
|
||||
ip: IpAddr,
|
||||
}
|
||||
|
||||
impl UserConnectionReservation {
|
||||
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
|
||||
Self {
|
||||
stats,
|
||||
ip_tracker,
|
||||
user,
|
||||
ip,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UserConnectionReservation {
|
||||
fn drop(&mut self) {
|
||||
self.stats.decrement_user_curr_connects(&self.user);
|
||||
|
||||
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||
let ip_tracker = self.ip_tracker.clone();
|
||||
let user = self.user.clone();
|
||||
let ip = self.ip;
|
||||
handle.spawn(async move {
|
||||
ip_tracker.remove_ip(&user, ip).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
|
||||
|
|
@ -90,6 +123,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
|
|||
trusted.iter().any(|cidr| cidr.contains(peer_ip))
|
||||
}
|
||||
|
||||
fn synthetic_local_addr(port: u16) -> SocketAddr {
|
||||
SocketAddr::from(([0, 0, 0, 0], port))
|
||||
}
|
||||
|
||||
pub async fn handle_client_stream<S>(
|
||||
mut stream: S,
|
||||
peer: SocketAddr,
|
||||
|
|
@ -113,9 +150,7 @@ where
|
|||
let mut real_peer = normalize_ip(peer);
|
||||
|
||||
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
|
||||
let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
|
||||
.parse()
|
||||
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
|
||||
let mut local_addr = synthetic_local_addr(config.server.port);
|
||||
|
||||
if proxy_protocol_enabled {
|
||||
let proxy_header_timeout = Duration::from_millis(
|
||||
|
|
@ -798,10 +833,22 @@ impl RunningClientHandler {
|
|||
{
|
||||
let user = success.user.clone();
|
||||
|
||||
if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await {
|
||||
warn!(user = %user, error = %e, "User limit exceeded");
|
||||
return Err(e);
|
||||
}
|
||||
let _user_limit_reservation =
|
||||
match Self::acquire_user_connection_reservation_static(
|
||||
&user,
|
||||
&config,
|
||||
stats.clone(),
|
||||
peer_addr,
|
||||
ip_tracker,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(reservation) => reservation,
|
||||
Err(e) => {
|
||||
warn!(user = %user, error = %e, "User limit exceeded");
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let route_snapshot = route_runtime.snapshot();
|
||||
let session_id = rng.u64();
|
||||
|
|
@ -858,12 +905,64 @@ impl RunningClientHandler {
|
|||
)
|
||||
.await
|
||||
};
|
||||
|
||||
stats.decrement_user_curr_connects(&user);
|
||||
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
|
||||
relay_result
|
||||
}
|
||||
|
||||
async fn acquire_user_connection_reservation_static(
|
||||
user: &str,
|
||||
config: &ProxyConfig,
|
||||
stats: Arc<Stats>,
|
||||
peer_addr: SocketAddr,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
) -> Result<UserConnectionReservation> {
|
||||
if let Some(expiration) = config.access.user_expirations.get(user)
|
||||
&& chrono::Utc::now() > *expiration
|
||||
{
|
||||
return Err(ProxyError::UserExpired {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||
&& stats.get_user_total_octets(user) >= *quota
|
||||
{
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
|
||||
if !stats.try_acquire_user_curr_connects(user, limit) {
|
||||
return Err(ProxyError::ConnectionLimitExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||
Ok(()) => {}
|
||||
Err(reason) => {
|
||||
stats.decrement_user_curr_connects(user);
|
||||
warn!(
|
||||
user = %user,
|
||||
ip = %peer_addr.ip(),
|
||||
reason = %reason,
|
||||
"IP limit exceeded"
|
||||
);
|
||||
return Err(ProxyError::ConnectionLimitExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(UserConnectionReservation::new(
|
||||
stats,
|
||||
ip_tracker,
|
||||
user.to_string(),
|
||||
peer_addr.ip(),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn check_user_limits_static(
|
||||
user: &str,
|
||||
config: &ProxyConfig,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,279 @@
|
|||
use super::*;
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use crate::crypto::AesCtr;
|
||||
use crate::crypto::sha256_hmac;
|
||||
use crate::protocol::constants::ProtoTag;
|
||||
use crate::protocol::tls;
|
||||
use crate::proxy::handshake::HandshakeSuccess;
|
||||
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
|
||||
use crate::stream::{CryptoReader, CryptoWriter};
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
#[test]
|
||||
fn synthetic_local_addr_uses_configured_port_for_zero() {
|
||||
let addr = synthetic_local_addr(0);
|
||||
assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0]));
|
||||
assert_eq!(addr.port(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn synthetic_local_addr_uses_configured_port_for_max() {
|
||||
let addr = synthetic_local_addr(u16::MAX);
|
||||
assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0]));
|
||||
assert_eq!(addr.port(), u16::MAX);
|
||||
}
|
||||
|
||||
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||
}
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
|
||||
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let tg_addr = tg_listener.local_addr().unwrap();
|
||||
|
||||
let tg_accept_task = tokio::spawn(async move {
|
||||
let (stream, _) = tg_listener.accept().await.unwrap();
|
||||
let _hold_stream = stream;
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
});
|
||||
|
||||
let user = "abort-user";
|
||||
let peer_addr: SocketAddr = "198.51.100.230:50000".parse().unwrap();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
|
||||
cfg.dc_overrides
|
||||
.insert("2".to_string(), vec![tg_addr.to_string()]);
|
||||
let config = Arc::new(cfg);
|
||||
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
|
||||
let (server_side, client_side) = duplex(64 * 1024);
|
||||
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||
let client_reader = make_crypto_reader(server_reader);
|
||||
let client_writer = make_crypto_writer(server_writer);
|
||||
|
||||
let success = HandshakeSuccess {
|
||||
user: user.to_string(),
|
||||
dc_idx: 2,
|
||||
proto_tag: ProtoTag::Intermediate,
|
||||
dec_key: [0u8; 32],
|
||||
dec_iv: 0,
|
||||
enc_key: [0u8; 32],
|
||||
enc_iv: 0,
|
||||
peer: peer_addr,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
upstream_manager,
|
||||
stats.clone(),
|
||||
config,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
"127.0.0.1:443".parse().unwrap(),
|
||||
peer_addr,
|
||||
ip_tracker.clone(),
|
||||
));
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
if stats.get_user_curr_connects(user) == 1
|
||||
&& ip_tracker.get_active_ip_count(user).await == 1
|
||||
{
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("relay must reserve user slot and IP before abort");
|
||||
|
||||
relay_task.abort();
|
||||
let joined = relay_task.await;
|
||||
assert!(joined.is_err(), "aborted relay task must return join error");
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
assert_eq!(
|
||||
stats.get_user_curr_connects(user),
|
||||
0,
|
||||
"task abort must release user current-connection slot"
|
||||
);
|
||||
assert_eq!(
|
||||
ip_tracker.get_active_ip_count(user).await,
|
||||
0,
|
||||
"task abort must release reserved user IP footprint"
|
||||
);
|
||||
|
||||
drop(client_side);
|
||||
tg_accept_task.abort();
|
||||
let _ = tg_accept_task.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_cutover_releases_user_gate_and_ip_reservation() {
|
||||
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let tg_addr = tg_listener.local_addr().unwrap();
|
||||
|
||||
let tg_accept_task = tokio::spawn(async move {
|
||||
let (stream, _) = tg_listener.accept().await.unwrap();
|
||||
let _hold_stream = stream;
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
});
|
||||
|
||||
let user = "cutover-user";
|
||||
let peer_addr: SocketAddr = "198.51.100.231:50001".parse().unwrap();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
|
||||
cfg.dc_overrides
|
||||
.insert("2".to_string(), vec![tg_addr.to_string()]);
|
||||
let config = Arc::new(cfg);
|
||||
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
|
||||
let (server_side, client_side) = duplex(64 * 1024);
|
||||
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||
let client_reader = make_crypto_reader(server_reader);
|
||||
let client_writer = make_crypto_writer(server_writer);
|
||||
|
||||
let success = HandshakeSuccess {
|
||||
user: user.to_string(),
|
||||
dc_idx: 2,
|
||||
proto_tag: ProtoTag::Intermediate,
|
||||
dec_key: [0u8; 32],
|
||||
dec_iv: 0,
|
||||
enc_key: [0u8; 32],
|
||||
enc_iv: 0,
|
||||
peer: peer_addr,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
upstream_manager,
|
||||
stats.clone(),
|
||||
config,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime.clone(),
|
||||
"127.0.0.1:443".parse().unwrap(),
|
||||
peer_addr,
|
||||
ip_tracker.clone(),
|
||||
));
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
if stats.get_user_curr_connects(user) == 1
|
||||
&& ip_tracker.get_active_ip_count(user).await == 1
|
||||
{
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("relay must reserve user slot and IP before cutover");
|
||||
|
||||
assert!(
|
||||
route_runtime.set_mode(RelayRouteMode::Middle).is_some(),
|
||||
"cutover must advance route generation"
|
||||
);
|
||||
|
||||
let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task)
|
||||
.await
|
||||
.expect("relay must terminate after cutover")
|
||||
.expect("relay task must not panic");
|
||||
assert!(relay_result.is_err(), "cutover must terminate direct relay session");
|
||||
|
||||
assert_eq!(
|
||||
stats.get_user_curr_connects(user),
|
||||
0,
|
||||
"cutover exit must release user current-connection slot"
|
||||
);
|
||||
assert_eq!(
|
||||
ip_tracker.get_active_ip_count(user).await,
|
||||
0,
|
||||
"cutover exit must release reserved user IP footprint"
|
||||
);
|
||||
|
||||
drop(client_side);
|
||||
tg_accept_task.abort();
|
||||
let _ = tg_accept_task.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn short_tls_probe_is_masked_through_client_pipeline() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
|
|
|
|||
|
|
@ -178,3 +178,112 @@ async fn direct_relay_abort_midflight_releases_route_gauge() {
|
|||
tg_accept_task.abort();
|
||||
let _ = tg_accept_task.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_relay_cutover_midflight_releases_route_gauge() {
|
||||
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let tg_addr = tg_listener.local_addr().unwrap();
|
||||
|
||||
let tg_accept_task = tokio::spawn(async move {
|
||||
let (stream, _) = tg_listener.accept().await.unwrap();
|
||||
let _hold_stream = stream;
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
});
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut config = ProxyConfig::default();
|
||||
config
|
||||
.dc_overrides
|
||||
.insert("2".to_string(), vec![tg_addr.to_string()]);
|
||||
let config = Arc::new(config);
|
||||
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let route_snapshot = route_runtime.snapshot();
|
||||
|
||||
let (server_side, client_side) = duplex(64 * 1024);
|
||||
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||
let client_reader = make_crypto_reader(server_reader);
|
||||
let client_writer = make_crypto_writer(server_writer);
|
||||
|
||||
let success = HandshakeSuccess {
|
||||
user: "cutover-direct-user".to_string(),
|
||||
dc_idx: 2,
|
||||
proto_tag: ProtoTag::Intermediate,
|
||||
dec_key: [0u8; 32],
|
||||
dec_iv: 0,
|
||||
enc_key: [0u8; 32],
|
||||
enc_iv: 0,
|
||||
peer: "127.0.0.1:50002".parse().unwrap(),
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let relay_task = tokio::spawn(handle_via_direct(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
upstream_manager,
|
||||
stats.clone(),
|
||||
config,
|
||||
buffer_pool,
|
||||
rng,
|
||||
route_runtime.subscribe(),
|
||||
route_snapshot,
|
||||
0xface_cafe,
|
||||
));
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
if stats.get_current_connections_direct() == 1 {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("direct relay must increment route gauge before cutover");
|
||||
|
||||
assert!(
|
||||
route_runtime.set_mode(RelayRouteMode::Middle).is_some(),
|
||||
"cutover must advance route generation"
|
||||
);
|
||||
|
||||
let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task)
|
||||
.await
|
||||
.expect("direct relay must terminate after cutover")
|
||||
.expect("direct relay task must not panic");
|
||||
assert!(
|
||||
relay_result.is_err(),
|
||||
"cutover should terminate direct relay session"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
stats.get_current_connections_direct(),
|
||||
0,
|
||||
"route gauge must be released when direct relay exits on cutover"
|
||||
);
|
||||
|
||||
drop(client_side);
|
||||
tg_accept_task.abort();
|
||||
let _ = tg_accept_task.await;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -942,3 +942,80 @@ async fn middle_relay_abort_midflight_releases_route_gauge() {
|
|||
|
||||
drop(client_side);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn middle_relay_cutover_midflight_releases_route_gauge() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
|
||||
let config = Arc::new(ProxyConfig::default());
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
|
||||
let route_snapshot = route_runtime.snapshot();
|
||||
|
||||
let (server_side, client_side) = duplex(64 * 1024);
|
||||
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||
let crypto_reader = make_crypto_reader(server_reader);
|
||||
let crypto_writer = make_crypto_writer(server_writer);
|
||||
|
||||
let success = HandshakeSuccess {
|
||||
user: "cutover-middle-user".to_string(),
|
||||
dc_idx: 2,
|
||||
proto_tag: ProtoTag::Intermediate,
|
||||
dec_key: [0u8; 32],
|
||||
dec_iv: 0,
|
||||
enc_key: [0u8; 32],
|
||||
enc_iv: 0,
|
||||
peer: "127.0.0.1:50003".parse().unwrap(),
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let relay_task = tokio::spawn(handle_via_middle_proxy(
|
||||
crypto_reader,
|
||||
crypto_writer,
|
||||
success,
|
||||
me_pool,
|
||||
stats.clone(),
|
||||
config,
|
||||
buffer_pool,
|
||||
"127.0.0.1:443".parse().unwrap(),
|
||||
rng,
|
||||
route_runtime.subscribe(),
|
||||
route_snapshot,
|
||||
0xfeed_beef,
|
||||
));
|
||||
|
||||
tokio::time::timeout(TokioDuration::from_secs(2), async {
|
||||
loop {
|
||||
if stats.get_current_connections_me() == 1 {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("middle relay must increment route gauge before cutover");
|
||||
|
||||
assert!(
|
||||
route_runtime.set_mode(RelayRouteMode::Direct).is_some(),
|
||||
"cutover must advance route generation"
|
||||
);
|
||||
|
||||
let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task)
|
||||
.await
|
||||
.expect("middle relay must terminate after cutover")
|
||||
.expect("middle relay task must not panic");
|
||||
assert!(
|
||||
relay_result.is_err(),
|
||||
"cutover should terminate middle relay session"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
stats.get_current_connections_me(),
|
||||
0,
|
||||
"route gauge must be released when middle relay exits on cutover"
|
||||
);
|
||||
|
||||
drop(client_side);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use super::*;
|
|||
use std::panic::{self, AssertUnwindSafe};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Barrier;
|
||||
|
||||
#[test]
|
||||
fn direct_connection_lease_balances_on_drop() {
|
||||
|
|
@ -63,6 +64,156 @@ fn direct_connection_lease_balances_on_panic_unwind() {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn middle_connection_lease_balances_on_panic_unwind() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let stats_for_panic = stats.clone();
|
||||
|
||||
let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
|
||||
let _lease = stats_for_panic.acquire_me_connection_lease();
|
||||
panic!("intentional panic to verify middle lease drop path");
|
||||
}));
|
||||
|
||||
assert!(panic_result.is_err(), "panic must propagate from test closure");
|
||||
assert_eq!(
|
||||
stats.get_current_connections_me(),
|
||||
0,
|
||||
"panic unwind must release middle route gauge"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_mixed_route_lease_churn_balances_to_zero() {
|
||||
const TASKS: usize = 48;
|
||||
const ITERATIONS_PER_TASK: usize = 256;
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let barrier = Arc::new(Barrier::new(TASKS));
|
||||
let mut workers = Vec::with_capacity(TASKS);
|
||||
|
||||
for task_idx in 0..TASKS {
|
||||
let stats_for_task = stats.clone();
|
||||
let barrier_for_task = barrier.clone();
|
||||
workers.push(tokio::spawn(async move {
|
||||
barrier_for_task.wait().await;
|
||||
for iter in 0..ITERATIONS_PER_TASK {
|
||||
if (task_idx + iter) % 2 == 0 {
|
||||
let _lease = stats_for_task.acquire_direct_connection_lease();
|
||||
tokio::task::yield_now().await;
|
||||
} else {
|
||||
let _lease = stats_for_task.acquire_me_connection_lease();
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for worker in workers {
|
||||
worker
|
||||
.await
|
||||
.expect("lease churn worker must not panic");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
stats.get_current_connections_direct(),
|
||||
0,
|
||||
"direct route gauge must return to zero after concurrent lease churn"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_current_connections_me(),
|
||||
0,
|
||||
"middle route gauge must return to zero after concurrent lease churn"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() {
|
||||
const TASKS: usize = 64;
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut workers = Vec::with_capacity(TASKS);
|
||||
|
||||
for task_idx in 0..TASKS {
|
||||
let stats_for_task = stats.clone();
|
||||
workers.push(tokio::spawn(async move {
|
||||
if task_idx % 2 == 0 {
|
||||
let _lease = stats_for_task.acquire_direct_connection_lease();
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
} else {
|
||||
let _lease = stats_for_task.acquire_me_connection_lease();
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
let total = stats.get_current_connections_direct() + stats.get_current_connections_me();
|
||||
if total == TASKS as u64 {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all storm tasks must acquire route leases before abort");
|
||||
|
||||
for worker in &workers {
|
||||
worker.abort();
|
||||
}
|
||||
for worker in workers {
|
||||
let joined = worker.await;
|
||||
assert!(joined.is_err(), "aborted worker must return join error");
|
||||
}
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(2), async {
|
||||
loop {
|
||||
if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all route gauges must drain to zero after abort storm");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn saturating_route_decrements_do_not_underflow_under_race() {
|
||||
const THREADS: usize = 16;
|
||||
const DECREMENTS_PER_THREAD: usize = 4096;
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut workers = Vec::with_capacity(THREADS);
|
||||
|
||||
for _ in 0..THREADS {
|
||||
let stats_for_thread = stats.clone();
|
||||
workers.push(std::thread::spawn(move || {
|
||||
for _ in 0..DECREMENTS_PER_THREAD {
|
||||
stats_for_thread.decrement_current_connections_direct();
|
||||
stats_for_thread.decrement_current_connections_me();
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for worker in workers {
|
||||
worker
|
||||
.join()
|
||||
.expect("decrement race worker must not panic");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
stats.get_current_connections_direct(),
|
||||
0,
|
||||
"direct route decrement races must never underflow"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_current_connections_me(),
|
||||
0,
|
||||
"middle route decrement races must never underflow"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_connection_lease_balances_on_task_abort() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
|
|
|||
Loading…
Reference in New Issue