mirror of https://github.com/telemt/telemt.git
Implement user connection reservation management and enhance relay task handling in proxy
This commit is contained in:
parent
4808a30185
commit
c540a6657f
|
|
@ -24,6 +24,39 @@ enum HandshakeOutcome {
|
||||||
Handled,
|
Handled,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct UserConnectionReservation {
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
|
user: String,
|
||||||
|
ip: IpAddr,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserConnectionReservation {
|
||||||
|
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
|
||||||
|
Self {
|
||||||
|
stats,
|
||||||
|
ip_tracker,
|
||||||
|
user,
|
||||||
|
ip,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for UserConnectionReservation {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.stats.decrement_user_curr_connects(&self.user);
|
||||||
|
|
||||||
|
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||||
|
let ip_tracker = self.ip_tracker.clone();
|
||||||
|
let user = self.user.clone();
|
||||||
|
let ip = self.ip;
|
||||||
|
handle.spawn(async move {
|
||||||
|
ip_tracker.remove_ip(&user, ip).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::crypto::SecureRandom;
|
use crate::crypto::SecureRandom;
|
||||||
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
|
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
|
||||||
|
|
@ -90,6 +123,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
|
||||||
trusted.iter().any(|cidr| cidr.contains(peer_ip))
|
trusted.iter().any(|cidr| cidr.contains(peer_ip))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synthetic_local_addr(port: u16) -> SocketAddr {
|
||||||
|
SocketAddr::from(([0, 0, 0, 0], port))
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn handle_client_stream<S>(
|
pub async fn handle_client_stream<S>(
|
||||||
mut stream: S,
|
mut stream: S,
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
|
|
@ -113,9 +150,7 @@ where
|
||||||
let mut real_peer = normalize_ip(peer);
|
let mut real_peer = normalize_ip(peer);
|
||||||
|
|
||||||
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
|
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
|
||||||
let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
|
let mut local_addr = synthetic_local_addr(config.server.port);
|
||||||
.parse()
|
|
||||||
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
|
|
||||||
|
|
||||||
if proxy_protocol_enabled {
|
if proxy_protocol_enabled {
|
||||||
let proxy_header_timeout = Duration::from_millis(
|
let proxy_header_timeout = Duration::from_millis(
|
||||||
|
|
@ -798,10 +833,22 @@ impl RunningClientHandler {
|
||||||
{
|
{
|
||||||
let user = success.user.clone();
|
let user = success.user.clone();
|
||||||
|
|
||||||
if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await {
|
let _user_limit_reservation =
|
||||||
warn!(user = %user, error = %e, "User limit exceeded");
|
match Self::acquire_user_connection_reservation_static(
|
||||||
return Err(e);
|
&user,
|
||||||
}
|
&config,
|
||||||
|
stats.clone(),
|
||||||
|
peer_addr,
|
||||||
|
ip_tracker,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(reservation) => reservation,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(user = %user, error = %e, "User limit exceeded");
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let route_snapshot = route_runtime.snapshot();
|
let route_snapshot = route_runtime.snapshot();
|
||||||
let session_id = rng.u64();
|
let session_id = rng.u64();
|
||||||
|
|
@ -858,12 +905,64 @@ impl RunningClientHandler {
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
};
|
};
|
||||||
|
|
||||||
stats.decrement_user_curr_connects(&user);
|
|
||||||
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
|
|
||||||
relay_result
|
relay_result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn acquire_user_connection_reservation_static(
|
||||||
|
user: &str,
|
||||||
|
config: &ProxyConfig,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
peer_addr: SocketAddr,
|
||||||
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
|
) -> Result<UserConnectionReservation> {
|
||||||
|
if let Some(expiration) = config.access.user_expirations.get(user)
|
||||||
|
&& chrono::Utc::now() > *expiration
|
||||||
|
{
|
||||||
|
return Err(ProxyError::UserExpired {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||||
|
&& stats.get_user_total_octets(user) >= *quota
|
||||||
|
{
|
||||||
|
return Err(ProxyError::DataQuotaExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
|
||||||
|
if !stats.try_acquire_user_curr_connects(user, limit) {
|
||||||
|
return Err(ProxyError::ConnectionLimitExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(reason) => {
|
||||||
|
stats.decrement_user_curr_connects(user);
|
||||||
|
warn!(
|
||||||
|
user = %user,
|
||||||
|
ip = %peer_addr.ip(),
|
||||||
|
reason = %reason,
|
||||||
|
"IP limit exceeded"
|
||||||
|
);
|
||||||
|
return Err(ProxyError::ConnectionLimitExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(UserConnectionReservation::new(
|
||||||
|
stats,
|
||||||
|
ip_tracker,
|
||||||
|
user.to_string(),
|
||||||
|
peer_addr.ip(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
async fn check_user_limits_static(
|
async fn check_user_limits_static(
|
||||||
user: &str,
|
user: &str,
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,279 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::config::{UpstreamConfig, UpstreamType};
|
use crate::config::{UpstreamConfig, UpstreamType};
|
||||||
|
use crate::crypto::AesCtr;
|
||||||
use crate::crypto::sha256_hmac;
|
use crate::crypto::sha256_hmac;
|
||||||
|
use crate::protocol::constants::ProtoTag;
|
||||||
use crate::protocol::tls;
|
use crate::protocol::tls;
|
||||||
|
use crate::proxy::handshake::HandshakeSuccess;
|
||||||
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
|
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
|
||||||
|
use crate::stream::{CryptoReader, CryptoWriter};
|
||||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn synthetic_local_addr_uses_configured_port_for_zero() {
|
||||||
|
let addr = synthetic_local_addr(0);
|
||||||
|
assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0]));
|
||||||
|
assert_eq!(addr.port(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn synthetic_local_addr_uses_configured_port_for_max() {
|
||||||
|
let addr = synthetic_local_addr(u16::MAX);
|
||||||
|
assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0]));
|
||||||
|
assert_eq!(addr.port(), u16::MAX);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
|
||||||
|
where
|
||||||
|
R: tokio::io::AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||||
|
where
|
||||||
|
W: tokio::io::AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
|
||||||
|
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let tg_addr = tg_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let tg_accept_task = tokio::spawn(async move {
|
||||||
|
let (stream, _) = tg_listener.accept().await.unwrap();
|
||||||
|
let _hold_stream = stream;
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let user = "abort-user";
|
||||||
|
let peer_addr: SocketAddr = "198.51.100.230:50000".parse().unwrap();
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||||
|
|
||||||
|
let mut cfg = ProxyConfig::default();
|
||||||
|
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
|
||||||
|
cfg.dc_overrides
|
||||||
|
.insert("2".to_string(), vec![tg_addr.to_string()]);
|
||||||
|
let config = Arc::new(cfg);
|
||||||
|
|
||||||
|
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||||
|
vec![UpstreamConfig {
|
||||||
|
upstream_type: UpstreamType::Direct {
|
||||||
|
interface: None,
|
||||||
|
bind_addresses: None,
|
||||||
|
},
|
||||||
|
weight: 1,
|
||||||
|
enabled: true,
|
||||||
|
scopes: String::new(),
|
||||||
|
selected_scope: String::new(),
|
||||||
|
}],
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
false,
|
||||||
|
stats.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||||
|
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let client_reader = make_crypto_reader(server_reader);
|
||||||
|
let client_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: user.to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: peer_addr,
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
success,
|
||||||
|
upstream_manager,
|
||||||
|
stats.clone(),
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
None,
|
||||||
|
route_runtime,
|
||||||
|
"127.0.0.1:443".parse().unwrap(),
|
||||||
|
peer_addr,
|
||||||
|
ip_tracker.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
tokio::time::timeout(Duration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_user_curr_connects(user) == 1
|
||||||
|
&& ip_tracker.get_active_ip_count(user).await == 1
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("relay must reserve user slot and IP before abort");
|
||||||
|
|
||||||
|
relay_task.abort();
|
||||||
|
let joined = relay_task.await;
|
||||||
|
assert!(joined.is_err(), "aborted relay task must return join error");
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_user_curr_connects(user),
|
||||||
|
0,
|
||||||
|
"task abort must release user current-connection slot"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ip_tracker.get_active_ip_count(user).await,
|
||||||
|
0,
|
||||||
|
"task abort must release reserved user IP footprint"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_side);
|
||||||
|
tg_accept_task.abort();
|
||||||
|
let _ = tg_accept_task.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_cutover_releases_user_gate_and_ip_reservation() {
|
||||||
|
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let tg_addr = tg_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let tg_accept_task = tokio::spawn(async move {
|
||||||
|
let (stream, _) = tg_listener.accept().await.unwrap();
|
||||||
|
let _hold_stream = stream;
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let user = "cutover-user";
|
||||||
|
let peer_addr: SocketAddr = "198.51.100.231:50001".parse().unwrap();
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||||
|
|
||||||
|
let mut cfg = ProxyConfig::default();
|
||||||
|
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
|
||||||
|
cfg.dc_overrides
|
||||||
|
.insert("2".to_string(), vec![tg_addr.to_string()]);
|
||||||
|
let config = Arc::new(cfg);
|
||||||
|
|
||||||
|
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||||
|
vec![UpstreamConfig {
|
||||||
|
upstream_type: UpstreamType::Direct {
|
||||||
|
interface: None,
|
||||||
|
bind_addresses: None,
|
||||||
|
},
|
||||||
|
weight: 1,
|
||||||
|
enabled: true,
|
||||||
|
scopes: String::new(),
|
||||||
|
selected_scope: String::new(),
|
||||||
|
}],
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
false,
|
||||||
|
stats.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||||
|
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let client_reader = make_crypto_reader(server_reader);
|
||||||
|
let client_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: user.to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: peer_addr,
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
success,
|
||||||
|
upstream_manager,
|
||||||
|
stats.clone(),
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
None,
|
||||||
|
route_runtime.clone(),
|
||||||
|
"127.0.0.1:443".parse().unwrap(),
|
||||||
|
peer_addr,
|
||||||
|
ip_tracker.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
tokio::time::timeout(Duration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_user_curr_connects(user) == 1
|
||||||
|
&& ip_tracker.get_active_ip_count(user).await == 1
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("relay must reserve user slot and IP before cutover");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
route_runtime.set_mode(RelayRouteMode::Middle).is_some(),
|
||||||
|
"cutover must advance route generation"
|
||||||
|
);
|
||||||
|
|
||||||
|
let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("relay must terminate after cutover")
|
||||||
|
.expect("relay task must not panic");
|
||||||
|
assert!(relay_result.is_err(), "cutover must terminate direct relay session");
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_user_curr_connects(user),
|
||||||
|
0,
|
||||||
|
"cutover exit must release user current-connection slot"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
ip_tracker.get_active_ip_count(user).await,
|
||||||
|
0,
|
||||||
|
"cutover exit must release reserved user IP footprint"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_side);
|
||||||
|
tg_accept_task.abort();
|
||||||
|
let _ = tg_accept_task.await;
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn short_tls_probe_is_masked_through_client_pipeline() {
|
async fn short_tls_probe_is_masked_through_client_pipeline() {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -178,3 +178,112 @@ async fn direct_relay_abort_midflight_releases_route_gauge() {
|
||||||
tg_accept_task.abort();
|
tg_accept_task.abort();
|
||||||
let _ = tg_accept_task.await;
|
let _ = tg_accept_task.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn direct_relay_cutover_midflight_releases_route_gauge() {
|
||||||
|
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let tg_addr = tg_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let tg_accept_task = tokio::spawn(async move {
|
||||||
|
let (stream, _) = tg_listener.accept().await.unwrap();
|
||||||
|
let _hold_stream = stream;
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config
|
||||||
|
.dc_overrides
|
||||||
|
.insert("2".to_string(), vec![tg_addr.to_string()]);
|
||||||
|
let config = Arc::new(config);
|
||||||
|
|
||||||
|
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||||
|
vec![UpstreamConfig {
|
||||||
|
upstream_type: UpstreamType::Direct {
|
||||||
|
interface: None,
|
||||||
|
bind_addresses: None,
|
||||||
|
},
|
||||||
|
weight: 1,
|
||||||
|
enabled: true,
|
||||||
|
scopes: String::new(),
|
||||||
|
selected_scope: String::new(),
|
||||||
|
}],
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
false,
|
||||||
|
stats.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||||
|
let route_snapshot = route_runtime.snapshot();
|
||||||
|
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let client_reader = make_crypto_reader(server_reader);
|
||||||
|
let client_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: "cutover-direct-user".to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: "127.0.0.1:50002".parse().unwrap(),
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(handle_via_direct(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
success,
|
||||||
|
upstream_manager,
|
||||||
|
stats.clone(),
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
route_runtime.subscribe(),
|
||||||
|
route_snapshot,
|
||||||
|
0xface_cafe,
|
||||||
|
));
|
||||||
|
|
||||||
|
tokio::time::timeout(Duration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_direct() == 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("direct relay must increment route gauge before cutover");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
route_runtime.set_mode(RelayRouteMode::Middle).is_some(),
|
||||||
|
"cutover must advance route generation"
|
||||||
|
);
|
||||||
|
|
||||||
|
let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("direct relay must terminate after cutover")
|
||||||
|
.expect("direct relay task must not panic");
|
||||||
|
assert!(
|
||||||
|
relay_result.is_err(),
|
||||||
|
"cutover should terminate direct relay session"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_direct(),
|
||||||
|
0,
|
||||||
|
"route gauge must be released when direct relay exits on cutover"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_side);
|
||||||
|
tg_accept_task.abort();
|
||||||
|
let _ = tg_accept_task.await;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -942,3 +942,80 @@ async fn middle_relay_abort_midflight_releases_route_gauge() {
|
||||||
|
|
||||||
drop(client_side);
|
drop(client_side);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn middle_relay_cutover_midflight_releases_route_gauge() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
|
||||||
|
let config = Arc::new(ProxyConfig::default());
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
|
||||||
|
let route_snapshot = route_runtime.snapshot();
|
||||||
|
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let crypto_reader = make_crypto_reader(server_reader);
|
||||||
|
let crypto_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: "cutover-middle-user".to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: "127.0.0.1:50003".parse().unwrap(),
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(handle_via_middle_proxy(
|
||||||
|
crypto_reader,
|
||||||
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
me_pool,
|
||||||
|
stats.clone(),
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
"127.0.0.1:443".parse().unwrap(),
|
||||||
|
rng,
|
||||||
|
route_runtime.subscribe(),
|
||||||
|
route_snapshot,
|
||||||
|
0xfeed_beef,
|
||||||
|
));
|
||||||
|
|
||||||
|
tokio::time::timeout(TokioDuration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_me() == 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("middle relay must increment route gauge before cutover");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
route_runtime.set_mode(RelayRouteMode::Direct).is_some(),
|
||||||
|
"cutover must advance route generation"
|
||||||
|
);
|
||||||
|
|
||||||
|
let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("middle relay must terminate after cutover")
|
||||||
|
.expect("middle relay task must not panic");
|
||||||
|
assert!(
|
||||||
|
relay_result.is_err(),
|
||||||
|
"cutover should terminate middle relay session"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"route gauge must be released when middle relay exits on cutover"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_side);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ use super::*;
|
||||||
use std::panic::{self, AssertUnwindSafe};
|
use std::panic::{self, AssertUnwindSafe};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use tokio::sync::Barrier;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn direct_connection_lease_balances_on_drop() {
|
fn direct_connection_lease_balances_on_drop() {
|
||||||
|
|
@ -63,6 +64,156 @@ fn direct_connection_lease_balances_on_panic_unwind() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn middle_connection_lease_balances_on_panic_unwind() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let stats_for_panic = stats.clone();
|
||||||
|
|
||||||
|
let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
|
||||||
|
let _lease = stats_for_panic.acquire_me_connection_lease();
|
||||||
|
panic!("intentional panic to verify middle lease drop path");
|
||||||
|
}));
|
||||||
|
|
||||||
|
assert!(panic_result.is_err(), "panic must propagate from test closure");
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"panic unwind must release middle route gauge"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn concurrent_mixed_route_lease_churn_balances_to_zero() {
|
||||||
|
const TASKS: usize = 48;
|
||||||
|
const ITERATIONS_PER_TASK: usize = 256;
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let barrier = Arc::new(Barrier::new(TASKS));
|
||||||
|
let mut workers = Vec::with_capacity(TASKS);
|
||||||
|
|
||||||
|
for task_idx in 0..TASKS {
|
||||||
|
let stats_for_task = stats.clone();
|
||||||
|
let barrier_for_task = barrier.clone();
|
||||||
|
workers.push(tokio::spawn(async move {
|
||||||
|
barrier_for_task.wait().await;
|
||||||
|
for iter in 0..ITERATIONS_PER_TASK {
|
||||||
|
if (task_idx + iter) % 2 == 0 {
|
||||||
|
let _lease = stats_for_task.acquire_direct_connection_lease();
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
} else {
|
||||||
|
let _lease = stats_for_task.acquire_me_connection_lease();
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for worker in workers {
|
||||||
|
worker
|
||||||
|
.await
|
||||||
|
.expect("lease churn worker must not panic");
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_direct(),
|
||||||
|
0,
|
||||||
|
"direct route gauge must return to zero after concurrent lease churn"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"middle route gauge must return to zero after concurrent lease churn"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() {
|
||||||
|
const TASKS: usize = 64;
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let mut workers = Vec::with_capacity(TASKS);
|
||||||
|
|
||||||
|
for task_idx in 0..TASKS {
|
||||||
|
let stats_for_task = stats.clone();
|
||||||
|
workers.push(tokio::spawn(async move {
|
||||||
|
if task_idx % 2 == 0 {
|
||||||
|
let _lease = stats_for_task.acquire_direct_connection_lease();
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
} else {
|
||||||
|
let _lease = stats_for_task.acquire_me_connection_lease();
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::timeout(Duration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
let total = stats.get_current_connections_direct() + stats.get_current_connections_me();
|
||||||
|
if total == TASKS as u64 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("all storm tasks must acquire route leases before abort");
|
||||||
|
|
||||||
|
for worker in &workers {
|
||||||
|
worker.abort();
|
||||||
|
}
|
||||||
|
for worker in workers {
|
||||||
|
let joined = worker.await;
|
||||||
|
assert!(joined.is_err(), "aborted worker must return join error");
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::timeout(Duration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("all route gauges must drain to zero after abort storm");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn saturating_route_decrements_do_not_underflow_under_race() {
|
||||||
|
const THREADS: usize = 16;
|
||||||
|
const DECREMENTS_PER_THREAD: usize = 4096;
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let mut workers = Vec::with_capacity(THREADS);
|
||||||
|
|
||||||
|
for _ in 0..THREADS {
|
||||||
|
let stats_for_thread = stats.clone();
|
||||||
|
workers.push(std::thread::spawn(move || {
|
||||||
|
for _ in 0..DECREMENTS_PER_THREAD {
|
||||||
|
stats_for_thread.decrement_current_connections_direct();
|
||||||
|
stats_for_thread.decrement_current_connections_me();
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for worker in workers {
|
||||||
|
worker
|
||||||
|
.join()
|
||||||
|
.expect("decrement race worker must not panic");
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_direct(),
|
||||||
|
0,
|
||||||
|
"direct route decrement races must never underflow"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"middle route decrement races must never underflow"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn direct_connection_lease_balances_on_task_abort() {
|
async fn direct_connection_lease_balances_on_task_abort() {
|
||||||
let stats = Arc::new(Stats::new());
|
let stats = Arc::new(Stats::new());
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue