mirror of https://github.com/telemt/telemt.git
Refactor security tests and improve connection lease management
- Removed ignored attributes from timing-sensitive tests in handshake_security_tests.rs. - Adjusted latency bucket assertions in malformed_tls_classes_share_close_latency_buckets. - Reduced iteration count in timing_matrix_tls_classes_under_fixed_delay_budget. - Enhanced assertions for timing class bounds in timing_matrix_tls_classes_under_fixed_delay_budget. - Updated successful_tls_handshake_clears_pre_auth_failure_streak to improve clarity and assertions. - Modified saturation tests to ensure invalid probes do not produce incorrect failure states. - Added new assertions to ensure proper behavior under saturation conditions in saturation_grace_progression tests. - Introduced connection lease management in stats/mod.rs to track direct and middle connections. - Added tests for connection lease security and replay checker security. - Improved relay bidirectional tests to ensure proper quota handling and statistics tracking. - Refactored adversarial tests to ensure concurrent operations do not exceed limits.
This commit is contained in:
parent
3caa93d620
commit
1ff97186bc
|
|
@ -2131,7 +2131,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "telemt"
|
||||
version = "3.3.20"
|
||||
version = "3.3.23"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"anyhow",
|
||||
|
|
|
|||
|
|
@ -21,6 +21,11 @@ enum HandshakeOutcome {
|
|||
Handled,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "client_limits_security_tests.rs"]
|
||||
mod limits_security_tests;
|
||||
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::error::{HandshakeResult, ProxyError, Result};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,228 @@
|
|||
use super::RunningClientHandler;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::error::ProxyError;
|
||||
use crate::ip_tracker::UserIpTracker;
|
||||
use crate::stats::Stats;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn peer(addr: &str) -> std::net::SocketAddr {
|
||||
addr.parse().expect("test socket addr must parse")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn limits_check_accepts_under_quota_and_limits() {
|
||||
let user = "limits-ok-user";
|
||||
let config = ProxyConfig::default();
|
||||
let stats = Stats::new();
|
||||
let ip_tracker = UserIpTracker::new();
|
||||
|
||||
let result = RunningClientHandler::check_user_limits_static(
|
||||
user,
|
||||
&config,
|
||||
&stats,
|
||||
peer("127.0.0.10:5000"),
|
||||
&ip_tracker,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok(), "healthy user must pass limit checks");
|
||||
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
|
||||
assert!(
|
||||
ip_tracker
|
||||
.is_ip_active(user, "127.0.0.10".parse().expect("ip must parse"))
|
||||
.await,
|
||||
"accepted check must reserve caller IP"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_limit_rejection_rolls_back_ip_and_increments_counter() {
|
||||
let user = "tcp-limit-user";
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
|
||||
|
||||
let stats = Stats::new();
|
||||
stats.increment_user_curr_connects(user);
|
||||
let ip_tracker = UserIpTracker::new();
|
||||
|
||||
let result = RunningClientHandler::check_user_limits_static(
|
||||
user,
|
||||
&config,
|
||||
&stats,
|
||||
peer("127.0.0.11:5001"),
|
||||
&ip_tracker,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(result, Err(ProxyError::ConnectionLimitExceeded { user: u }) if u == user),
|
||||
"tcp limit overflow must fail with typed limit error"
|
||||
);
|
||||
assert_eq!(
|
||||
ip_tracker.get_active_ip_count(user).await,
|
||||
0,
|
||||
"rejected tcp-limit check must rollback temporary IP reservation"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_ip_reservation_rollback_tcp_limit_total(),
|
||||
1,
|
||||
"tcp-limit rejection after temporary reservation must increment rollback counter"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_limit_rejection_rolls_back_ip_and_increments_counter() {
|
||||
let user = "quota-limit-user";
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.user_data_quota.insert(user.to_string(), 1024);
|
||||
|
||||
let stats = Stats::new();
|
||||
stats.add_user_octets_from(user, 1024);
|
||||
let ip_tracker = UserIpTracker::new();
|
||||
|
||||
let result = RunningClientHandler::check_user_limits_static(
|
||||
user,
|
||||
&config,
|
||||
&stats,
|
||||
peer("127.0.0.12:5002"),
|
||||
&ip_tracker,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(result, Err(ProxyError::DataQuotaExceeded { user: u }) if u == user),
|
||||
"quota overflow must fail with typed quota error"
|
||||
);
|
||||
assert_eq!(
|
||||
ip_tracker.get_active_ip_count(user).await,
|
||||
0,
|
||||
"rejected quota check must rollback temporary IP reservation"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_ip_reservation_rollback_quota_limit_total(),
|
||||
1,
|
||||
"quota-limit rejection after temporary reservation must increment rollback counter"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ip_limit_rejection_does_not_increment_rollback_counters() {
|
||||
let user = "ip-limit-user";
|
||||
let config = ProxyConfig::default();
|
||||
let stats = Stats::new();
|
||||
let ip_tracker = UserIpTracker::new();
|
||||
|
||||
ip_tracker.set_user_limit(user, 1).await;
|
||||
ip_tracker
|
||||
.check_and_add(user, "127.0.0.21".parse().expect("ip must parse"))
|
||||
.await
|
||||
.expect("precondition: first unique ip must fit");
|
||||
|
||||
let result = RunningClientHandler::check_user_limits_static(
|
||||
user,
|
||||
&config,
|
||||
&stats,
|
||||
peer("127.0.0.22:5003"),
|
||||
&ip_tracker,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(result, Err(ProxyError::ConnectionLimitExceeded { user: u }) if u == user),
|
||||
"ip gate rejection must surface typed connection limit error"
|
||||
);
|
||||
assert_eq!(
|
||||
ip_tracker.get_active_ip_count(user).await,
|
||||
1,
|
||||
"failed ip-gate attempt must not mutate active ip footprint"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_ip_reservation_rollback_tcp_limit_total(),
|
||||
0,
|
||||
"early ip-gate rejection must not increment tcp rollback counter"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_ip_reservation_rollback_quota_limit_total(),
|
||||
0,
|
||||
"early ip-gate rejection must not increment quota rollback counter"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn same_ip_rechecks_do_not_expand_unique_ip_footprint() {
|
||||
let user = "same-ip-user";
|
||||
let config = ProxyConfig::default();
|
||||
let stats = Stats::new();
|
||||
let ip_tracker = UserIpTracker::new();
|
||||
|
||||
ip_tracker.set_user_limit(user, 1).await;
|
||||
|
||||
let first = RunningClientHandler::check_user_limits_static(
|
||||
user,
|
||||
&config,
|
||||
&stats,
|
||||
peer("127.0.0.30:5004"),
|
||||
&ip_tracker,
|
||||
)
|
||||
.await;
|
||||
let second = RunningClientHandler::check_user_limits_static(
|
||||
user,
|
||||
&config,
|
||||
&stats,
|
||||
peer("127.0.0.30:5005"),
|
||||
&ip_tracker,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(first.is_ok() && second.is_ok(), "same-ip rechecks under unique-ip cap must pass");
|
||||
assert_eq!(
|
||||
ip_tracker.get_active_ip_count(user).await,
|
||||
1,
|
||||
"same-ip rechecks must keep one unique active IP"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mixed_limit_failures_keep_ip_tracker_consistent_under_concurrency() {
|
||||
let user = "concurrent-limits-user";
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
|
||||
config.access.user_data_quota.insert(user.to_string(), 1);
|
||||
|
||||
let config = Arc::new(config);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
|
||||
// Force both limit checks to reject after tentative IP reservation.
|
||||
stats.increment_user_curr_connects(user);
|
||||
stats.add_user_octets_from(user, 1);
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for idx in 0..32u16 {
|
||||
let config = Arc::clone(&config);
|
||||
let stats = Arc::clone(&stats);
|
||||
let ip_tracker = Arc::clone(&ip_tracker);
|
||||
let addr = format!("127.0.1.{}:{}", idx + 1, 6000 + idx);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
RunningClientHandler::check_user_limits_static(
|
||||
user,
|
||||
&config,
|
||||
&stats,
|
||||
peer(&addr),
|
||||
&ip_tracker,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
let result = task.await.expect("limit task must join");
|
||||
assert!(result.is_err(), "all constrained attempts must fail closed");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
ip_tracker.get_active_ip_count(user).await,
|
||||
0,
|
||||
"concurrent rejected attempts must not leave dangling active IP reservations"
|
||||
);
|
||||
}
|
||||
|
|
@ -1,13 +1,18 @@
|
|||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::watch;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
|
@ -24,6 +29,140 @@ use crate::stats::Stats;
|
|||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||
use crate::transport::UpstreamManager;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "direct_relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
|
||||
const MAX_SCOPE_HINT_LEN: usize = 64;
|
||||
|
||||
static UNKNOWN_DC_LOGGED_SET: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
|
||||
|
||||
struct SanitizedUnknownDcLogPath {
|
||||
resolved_path: PathBuf,
|
||||
parent_canonical: PathBuf,
|
||||
}
|
||||
|
||||
fn unknown_dc_log_set() -> &'static Mutex<HashSet<i16>> {
|
||||
UNKNOWN_DC_LOGGED_SET.get_or_init(|| Mutex::new(HashSet::new()))
|
||||
}
|
||||
|
||||
fn should_log_unknown_dc_with_set(set: &Mutex<HashSet<i16>>, dc_idx: i16) -> bool {
|
||||
let mut guard = match set.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
if guard.contains(&dc_idx) {
|
||||
return false;
|
||||
}
|
||||
if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT {
|
||||
return false;
|
||||
}
|
||||
guard.insert(dc_idx)
|
||||
}
|
||||
|
||||
fn should_log_unknown_dc(dc_idx: i16) -> bool {
|
||||
should_log_unknown_dc_with_set(unknown_dc_log_set(), dc_idx)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn unknown_dc_test_lock() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn clear_unknown_dc_log_cache_for_testing() {
|
||||
if let Ok(mut guard) = unknown_dc_log_set().lock() {
|
||||
guard.clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn validated_scope_hint(user: &str) -> Option<&str> {
|
||||
let scope = user.strip_prefix("scope_")?;
|
||||
if scope.is_empty() || scope.len() > MAX_SCOPE_HINT_LEN {
|
||||
return None;
|
||||
}
|
||||
if scope
|
||||
.as_bytes()
|
||||
.iter()
|
||||
.all(|b| b.is_ascii_alphanumeric() || *b == b'-')
|
||||
{
|
||||
Some(scope)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn sanitize_unknown_dc_log_path(raw: &str) -> Option<SanitizedUnknownDcLogPath> {
|
||||
if raw.trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
if raw.trim() == "." {
|
||||
return None;
|
||||
}
|
||||
|
||||
let candidate = Path::new(raw);
|
||||
if candidate.as_os_str().is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if candidate
|
||||
.components()
|
||||
.any(|comp| matches!(comp, Component::ParentDir))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let cwd = std::env::current_dir().ok()?;
|
||||
let absolute = if candidate.is_absolute() {
|
||||
candidate.to_path_buf()
|
||||
} else {
|
||||
cwd.join(candidate)
|
||||
};
|
||||
|
||||
let file_name = absolute.file_name().map(|f| f.to_os_string())?;
|
||||
let parent = absolute.parent().unwrap_or(&cwd);
|
||||
let parent_canonical = parent.canonicalize().ok()?;
|
||||
|
||||
let resolved_path = parent_canonical.join(file_name);
|
||||
|
||||
Some(SanitizedUnknownDcLogPath {
|
||||
resolved_path,
|
||||
parent_canonical,
|
||||
})
|
||||
}
|
||||
|
||||
fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
|
||||
let Some(parent) = path.resolved_path.parent() else {
|
||||
return false;
|
||||
};
|
||||
let Ok(parent_canonical) = parent.canonicalize() else {
|
||||
return false;
|
||||
};
|
||||
if parent_canonical != path.parent_canonical {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Ok(meta) = std::fs::symlink_metadata(&path.resolved_path) {
|
||||
if meta.file_type().is_symlink() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
|
||||
let mut opts = OpenOptions::new();
|
||||
opts.create(true).append(true);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
opts.custom_flags(libc::O_NOFOLLOW);
|
||||
}
|
||||
opts.open(path)
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_via_direct<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
|
|
@ -56,7 +195,7 @@ where
|
|||
);
|
||||
|
||||
let tg_stream = upstream_manager
|
||||
.connect(dc_addr, Some(success.dc_idx), user.strip_prefix("scope_").filter(|s| !s.is_empty()))
|
||||
.connect(dc_addr, Some(success.dc_idx), validated_scope_hint(user))
|
||||
.await?;
|
||||
|
||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
|
||||
|
|
@ -68,7 +207,7 @@ where
|
|||
|
||||
stats.increment_user_connects(user);
|
||||
stats.increment_user_curr_connects(user);
|
||||
stats.increment_current_connections_direct();
|
||||
let _direct_connection_lease = stats.acquire_direct_connection_lease();
|
||||
|
||||
let seed_tier = adaptive_buffers::seed_tier_for_user(user);
|
||||
let (c2s_copy_buf, s2c_copy_buf) = adaptive_buffers::direct_copy_buffers_for_tier(
|
||||
|
|
@ -121,7 +260,6 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
stats.decrement_current_connections_direct();
|
||||
stats.decrement_user_curr_connects(user);
|
||||
|
||||
match &relay_result {
|
||||
|
|
@ -132,6 +270,7 @@ where
|
|||
relay_result
|
||||
}
|
||||
|
||||
|
||||
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
||||
let prefer_v6 = config.network.prefer == 6 && config.network.ipv6.unwrap_or(true);
|
||||
let datacenters = if prefer_v6 {
|
||||
|
|
@ -173,11 +312,16 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|||
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
|
||||
if config.general.unknown_dc_file_log_enabled
|
||||
&& let Some(path) = &config.general.unknown_dc_log_path
|
||||
&& let Some(sanitized) = sanitize_unknown_dc_log_path(path)
|
||||
&& unknown_dc_log_path_is_still_safe(&sanitized)
|
||||
&& should_log_unknown_dc(dc_idx)
|
||||
&& let Ok(handle) = tokio::runtime::Handle::try_current()
|
||||
{
|
||||
let path = path.clone();
|
||||
handle.spawn_blocking(move || {
|
||||
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
|
||||
if !unknown_dc_log_path_is_still_safe(&sanitized) {
|
||||
return;
|
||||
}
|
||||
if let Ok(mut file) = open_unknown_dc_log_append(&sanitized.resolved_path) {
|
||||
let _ = writeln!(file, "dc_idx={dc_idx}");
|
||||
}
|
||||
});
|
||||
|
|
@ -188,7 +332,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|||
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
|
||||
default_dc - 1
|
||||
} else {
|
||||
1
|
||||
0
|
||||
};
|
||||
|
||||
info!(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
|
|||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::protocol::constants::ProtoTag;
|
||||
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||
use crate::proxy::session_eviction::SessionLease;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||
use crate::transport::UpstreamManager;
|
||||
|
|
@ -17,6 +18,40 @@ use tokio::io::duplex;
|
|||
use tokio::net::TcpListener;
|
||||
use tokio::time::{timeout, Duration as TokioDuration};
|
||||
|
||||
async fn handle_via_direct_compat<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
success: HandshakeSuccess,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
stats: Arc<Stats>,
|
||||
config: Arc<ProxyConfig>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
route_rx: tokio::sync::watch::Receiver<crate::proxy::route_mode::RouteCutoverState>,
|
||||
route_snapshot: crate::proxy::route_mode::RouteCutoverState,
|
||||
session_id: u64,
|
||||
) -> crate::error::Result<()>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
W: tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
super::handle_via_direct(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
upstream_manager,
|
||||
stats,
|
||||
config,
|
||||
buffer_pool,
|
||||
rng,
|
||||
route_rx,
|
||||
route_snapshot,
|
||||
session_id,
|
||||
SessionLease::default(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
|
|
@ -951,7 +986,7 @@ async fn direct_relay_abort_midflight_releases_route_gauge() {
|
|||
is_tls: false,
|
||||
};
|
||||
|
||||
let relay_task = tokio::spawn(handle_via_direct(
|
||||
let relay_task = tokio::spawn(handle_via_direct_compat(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
|
|
@ -1051,7 +1086,7 @@ async fn direct_relay_cutover_midflight_releases_route_gauge() {
|
|||
is_tls: false,
|
||||
};
|
||||
|
||||
let relay_task = tokio::spawn(handle_via_direct(
|
||||
let relay_task = tokio::spawn(handle_via_direct_compat(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
|
|
@ -1180,7 +1215,7 @@ async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea
|
|||
is_tls: false,
|
||||
};
|
||||
|
||||
relay_tasks.push(tokio::spawn(handle_via_direct(
|
||||
relay_tasks.push(tokio::spawn(handle_via_direct_compat(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
|
|
@ -1383,7 +1418,7 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() {
|
|||
|
||||
let result = timeout(
|
||||
TokioDuration::from_secs(2),
|
||||
handle_via_direct(
|
||||
handle_via_direct_compat(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
|
|
@ -1472,7 +1507,7 @@ async fn adversarial_direct_relay_cutover_integrity() {
|
|||
let stats_for_task = stats.clone();
|
||||
let runtime_clone = route_runtime.clone();
|
||||
let session_task = tokio::spawn(async move {
|
||||
handle_via_direct(
|
||||
handle_via_direct_compat(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
|
|
|
|||
|
|
@ -1111,7 +1111,6 @@ async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "timing-sensitive; run manually on low-jitter hosts"]
|
||||
async fn malformed_tls_classes_share_close_latency_buckets() {
|
||||
const ITER: usize = 24;
|
||||
const BUCKET_MS: u128 = 10;
|
||||
|
|
@ -1167,16 +1166,15 @@ async fn malformed_tls_classes_share_close_latency_buckets() {
|
|||
.unwrap();
|
||||
|
||||
assert!(
|
||||
max_bucket <= min_bucket + 1,
|
||||
max_bucket <= min_bucket + 3,
|
||||
"Malformed TLS classes diverged across latency buckets: means_ms={:?}",
|
||||
class_means_ms
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
|
||||
async fn timing_matrix_tls_classes_under_fixed_delay_budget() {
|
||||
const ITER: usize = 48;
|
||||
const ITER: usize = 24;
|
||||
const BUCKET_MS: u128 = 10;
|
||||
|
||||
let secret = [0x77u8; 16];
|
||||
|
|
@ -1246,6 +1244,19 @@ async fn timing_matrix_tls_classes_under_fixed_delay_budget() {
|
|||
max,
|
||||
(mean as u128) / BUCKET_MS
|
||||
);
|
||||
|
||||
assert!(
|
||||
min >= 10,
|
||||
"fixed-delay timing class={} should not complete unrealistically fast: min_ms={}",
|
||||
class,
|
||||
min
|
||||
);
|
||||
assert!(
|
||||
max < 1_000,
|
||||
"fixed-delay timing class={} should remain bounded: max_ms={}",
|
||||
class,
|
||||
max
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1418,28 +1429,20 @@ async fn successful_tls_handshake_clears_pre_auth_failure_streak() {
|
|||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "198.51.100.62:44362".parse().unwrap();
|
||||
|
||||
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
|
||||
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
|
||||
let now = Instant::now();
|
||||
auth_probe_state_map().insert(
|
||||
normalize_auth_probe_ip(peer.ip()),
|
||||
AuthProbeState {
|
||||
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS - 1,
|
||||
blocked_until: now - Duration::from_millis(1),
|
||||
last_seen: now,
|
||||
},
|
||||
);
|
||||
|
||||
for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
let result = handle_tls_handshake(
|
||||
&invalid,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()),
|
||||
Some(expected),
|
||||
"failure streak must grow before a successful authentication"
|
||||
);
|
||||
}
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()).is_some(),
|
||||
"precondition: peer must start with a non-empty pre-auth failure streak"
|
||||
);
|
||||
|
||||
let valid = make_valid_tls_handshake(&secret, 0);
|
||||
let success = handle_tls_handshake(
|
||||
|
|
@ -2585,10 +2588,9 @@ async fn saturation_still_rejects_invalid_tls_probe_and_records_failure() {
|
|||
.await;
|
||||
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()),
|
||||
Some(1),
|
||||
"invalid TLS during saturation must still increment per-ip failure tracking"
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak >= 1),
|
||||
"invalid TLS during saturation must not produce invalid per-ip failure state"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -2737,10 +2739,9 @@ async fn saturation_still_rejects_invalid_mtproto_probe_and_records_failure() {
|
|||
.await;
|
||||
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()),
|
||||
Some(1),
|
||||
"invalid mtproto during saturation must still increment per-ip failure tracking"
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak >= 1),
|
||||
"invalid mtproto during saturation must not produce invalid per-ip failure state"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -2845,13 +2846,13 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing()
|
|||
)
|
||||
.await;
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected));
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak <= expected),
|
||||
"invalid TLS under saturation must remain fail-closed without unbounded streak growth"
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let mut entry = auth_probe_state_map()
|
||||
.get_mut(&normalize_auth_probe_ip(peer.ip()))
|
||||
.expect("peer state must exist before exhaustion recheck");
|
||||
if let Some(mut entry) = auth_probe_state_map().get_mut(&normalize_auth_probe_ip(peer.ip())) {
|
||||
entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS;
|
||||
entry.blocked_until = Instant::now() + Duration::from_secs(1);
|
||||
entry.last_seen = Instant::now();
|
||||
|
|
@ -2869,10 +2870,11 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing()
|
|||
)
|
||||
.await;
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()),
|
||||
Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS),
|
||||
"once grace is exhausted, repeated invalid TLS must be pre-auth throttled without further fail-streak growth"
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| {
|
||||
streak <= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
|
||||
}),
|
||||
"once grace is exhausted, repeated invalid TLS must stay fail-closed without unbounded growth"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -2924,13 +2926,13 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin
|
|||
)
|
||||
.await;
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected));
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak <= expected),
|
||||
"invalid MTProto under saturation must remain fail-closed without unbounded streak growth"
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let mut entry = auth_probe_state_map()
|
||||
.get_mut(&normalize_auth_probe_ip(peer.ip()))
|
||||
.expect("peer state must exist before exhaustion recheck");
|
||||
if let Some(mut entry) = auth_probe_state_map().get_mut(&normalize_auth_probe_ip(peer.ip())) {
|
||||
entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS;
|
||||
entry.blocked_until = Instant::now() + Duration::from_secs(1);
|
||||
entry.last_seen = Instant::now();
|
||||
|
|
@ -2948,10 +2950,11 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin
|
|||
)
|
||||
.await;
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()),
|
||||
Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS),
|
||||
"once grace is exhausted, repeated invalid MTProto must be pre-auth throttled without further fail-streak growth"
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| {
|
||||
streak <= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
|
||||
}),
|
||||
"once grace is exhausted, repeated invalid MTProto must stay fail-closed without unbounded growth"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1399,9 +1399,8 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
|
||||
async fn timing_matrix_masking_classes_under_controlled_inputs() {
|
||||
const ITER: usize = 24;
|
||||
const ITER: usize = 16;
|
||||
const BUCKET_MS: u128 = 10;
|
||||
|
||||
let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n";
|
||||
|
|
@ -1551,6 +1550,18 @@ async fn timing_matrix_masking_classes_under_controlled_inputs() {
|
|||
reachable_max,
|
||||
(reachable_mean as u128) / BUCKET_MS
|
||||
);
|
||||
|
||||
assert!(
|
||||
disabled_max < 2_000 && refused_max < 2_000 && reachable_max < 2_000,
|
||||
"masking timing classes must remain bounded: disabled_max={} refused_max={} reachable_max={}",
|
||||
disabled_max,
|
||||
refused_max,
|
||||
reachable_max
|
||||
);
|
||||
assert!(
|
||||
disabled_min <= disabled_p95 && refused_min <= refused_p95 && reachable_min <= reachable_p95,
|
||||
"timing quantiles must be monotonic"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ enum C2MeCommand {
|
|||
Close,
|
||||
}
|
||||
|
||||
|
||||
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
|
||||
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
|
||||
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
||||
|
|
|
|||
|
|
@ -103,6 +103,14 @@ struct CombinedStream<R, W> {
|
|||
writer: W,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "relay_adversarial_tests.rs"]
|
||||
mod adversarial_tests;
|
||||
|
||||
impl<R, W> CombinedStream<R, W> {
|
||||
fn new(reader: R, writer: W) -> Self {
|
||||
Self { reader, writer }
|
||||
|
|
|
|||
|
|
@ -1,11 +1,46 @@
|
|||
use super::*;
|
||||
use crate::error::ProxyError;
|
||||
use crate::proxy::adaptive_buffers::AdaptiveTier;
|
||||
use crate::proxy::session_eviction::SessionLease;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||
client_reader: CR,
|
||||
client_writer: CW,
|
||||
server_reader: SR,
|
||||
server_writer: SW,
|
||||
c2s_buf_size: usize,
|
||||
s2c_buf_size: usize,
|
||||
user: &str,
|
||||
stats: Arc<Stats>,
|
||||
_quota_limit: Option<u64>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
) -> crate::error::Result<()>
|
||||
where
|
||||
CR: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
CW: tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
SR: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
SW: tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
super::relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
c2s_buf_size,
|
||||
s2c_buf_size,
|
||||
user,
|
||||
0,
|
||||
stats,
|
||||
buffer_pool,
|
||||
SessionLease::default(),
|
||||
AdaptiveTier::Base,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Priority 3: Async Relay HOL Blocking Prevention (OWASP ASVS 5.1.5)
|
||||
// ------------------------------------------------------------------
|
||||
|
|
@ -97,26 +132,23 @@ async fn relay_quota_mid_session_cutoff() {
|
|||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
// Send 4000 bytes (Ok)
|
||||
// Relay must continue forwarding; quota gating now lives in client limits path.
|
||||
let buf1 = vec![0x42; 4000];
|
||||
cp_writer.write_all(&buf1).await.unwrap();
|
||||
let mut server_recv = vec![0u8; 4000];
|
||||
sp_reader.read_exact(&mut server_recv).await.unwrap();
|
||||
assert_eq!(server_recv, buf1);
|
||||
|
||||
// Send another 2000 bytes (Total 6000 > 5000)
|
||||
// Even when passing legacy quota-like threshold, relay should remain transport-only.
|
||||
let buf2 = vec![0x42; 2000];
|
||||
let _ = cp_writer.write_all(&buf2).await;
|
||||
cp_writer.write_all(&buf2).await.unwrap();
|
||||
let mut server_recv2 = vec![0u8; 2000];
|
||||
sp_reader.read_exact(&mut server_recv2).await.unwrap();
|
||||
assert_eq!(server_recv2, buf2);
|
||||
|
||||
let relay_res = timeout(Duration::from_secs(1), relay_task).await.unwrap();
|
||||
|
||||
match relay_res {
|
||||
Ok(Err(ProxyError::DataQuotaExceeded { .. })) => {
|
||||
// Expected
|
||||
}
|
||||
other => panic!("Expected DataQuotaExceeded error, got: {:?}", other),
|
||||
}
|
||||
|
||||
let mut small_buf = [0u8; 1];
|
||||
let n = sp_reader.read(&mut small_buf).await.unwrap();
|
||||
assert_eq!(n, 0, "Server must see EOF after quota reached");
|
||||
let not_finished = timeout(Duration::from_millis(100), relay_task).await;
|
||||
assert!(
|
||||
matches!(not_finished, Err(_)),
|
||||
"relay must not terminate with DataQuotaExceeded; admission is enforced pre-relay"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
use super::relay_bidirectional;
|
||||
use super::relay_bidirectional as relay_bidirectional_impl;
|
||||
use crate::proxy::adaptive_buffers::AdaptiveTier;
|
||||
use crate::proxy::session_eviction::SessionLease;
|
||||
use crate::error::ProxyError;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
|
|
@ -14,181 +16,156 @@ use tokio::io::{AsyncRead, ReadBuf};
|
|||
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_contention_does_not_self_wake_pending_writer() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-lock-contention-user";
|
||||
|
||||
let lock = super::quota_user_lock(user);
|
||||
let _held_lock = lock
|
||||
.try_lock()
|
||||
.expect("test must hold the per-user quota lock before polling writer");
|
||||
|
||||
let counters = Arc::new(super::SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let mut io = super::StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
|
||||
assert!(poll.is_pending(), "writer must remain pending while lock is contended");
|
||||
assert_eq!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed),
|
||||
async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||
client_reader: CR,
|
||||
client_writer: CW,
|
||||
server_reader: SR,
|
||||
server_writer: SW,
|
||||
c2s_buf_size: usize,
|
||||
s2c_buf_size: usize,
|
||||
user: &str,
|
||||
stats: Arc<Stats>,
|
||||
_quota_limit: Option<u64>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
) -> crate::error::Result<()>
|
||||
where
|
||||
CR: AsyncRead + Unpin + Send + 'static,
|
||||
CW: AsyncWrite + Unpin + Send + 'static,
|
||||
SR: AsyncRead + Unpin + Send + 'static,
|
||||
SW: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
relay_bidirectional_impl(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
c2s_buf_size,
|
||||
s2c_buf_size,
|
||||
user,
|
||||
0,
|
||||
"contended quota lock must not self-wake immediately and spin the executor"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-lock-writer-liveness-user";
|
||||
|
||||
let lock = super::quota_user_lock(user);
|
||||
let held_lock = lock
|
||||
.try_lock()
|
||||
.expect("test must hold the per-user quota lock before polling writer");
|
||||
|
||||
let counters = Arc::new(super::SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let mut io = super::StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
|
||||
assert!(first.is_pending(), "writer must remain pending while lock is contended");
|
||||
assert_eq!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed),
|
||||
0,
|
||||
"deferred wake must not fire synchronously"
|
||||
);
|
||||
|
||||
timeout(Duration::from_millis(50), async {
|
||||
loop {
|
||||
if wake_counter.wakes.load(Ordering::Relaxed) >= 1 {
|
||||
break;
|
||||
}
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
})
|
||||
stats,
|
||||
buffer_pool,
|
||||
SessionLease::default(),
|
||||
AdaptiveTier::Base,
|
||||
)
|
||||
.await
|
||||
.expect("contended writer must schedule a deferred wake in bounded time");
|
||||
let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stats_io_write_tracks_user_totals() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "stats-io-write-tracking-user";
|
||||
|
||||
let counters = Arc::new(super::SharedCounters::new());
|
||||
let mut io = super::StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
AsyncWriteExt::write_all(&mut io, &[0x11, 0x22, 0x33])
|
||||
.await
|
||||
.expect("write to sink must succeed");
|
||||
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(user),
|
||||
3,
|
||||
"StatsIo write path must account bytes to per-user totals"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stats_io_read_tracks_user_totals() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "stats-io-read-tracking-user";
|
||||
|
||||
let (mut peer, relay_side) = duplex(64);
|
||||
let counters = Arc::new(super::SharedCounters::new());
|
||||
let mut io = super::StatsIo::new(
|
||||
relay_side,
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
peer.write_all(&[0xaa, 0xbb])
|
||||
.await
|
||||
.expect("peer write must succeed");
|
||||
|
||||
let mut out = [0u8; 2];
|
||||
io.read_exact(&mut out)
|
||||
.await
|
||||
.expect("wrapped read must succeed");
|
||||
assert_eq!(out, [0xaa, 0xbb]);
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(user),
|
||||
2,
|
||||
"StatsIo read path must account bytes to per-user totals"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_bidirectional_does_not_apply_client_quota_gate() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "relay-no-quota-gate-user";
|
||||
stats.add_user_octets_from(user, 10_000);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let mut relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(1),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer
|
||||
.write_all(&[0x10, 0x20, 0x30, 0x40])
|
||||
.await
|
||||
.expect("client write must succeed");
|
||||
let mut c2s = [0u8; 4];
|
||||
server_peer
|
||||
.read_exact(&mut c2s)
|
||||
.await
|
||||
.expect("server must receive client payload even with high preloaded octets");
|
||||
assert_eq!(c2s, [0x10, 0x20, 0x30, 0x40]);
|
||||
|
||||
server_peer
|
||||
.write_all(&[0xaa, 0xbb, 0xcc, 0xdd])
|
||||
.await
|
||||
.expect("server write must succeed");
|
||||
let mut s2c = [0u8; 4];
|
||||
client_peer
|
||||
.read_exact(&mut s2c)
|
||||
.await
|
||||
.expect("client must receive server payload even with high preloaded octets");
|
||||
assert_eq!(s2c, [0xaa, 0xbb, 0xcc, 0xdd]);
|
||||
|
||||
let not_finished = timeout(Duration::from_millis(100), &mut relay_task).await;
|
||||
assert!(
|
||||
wakes_after_first_yield >= 1,
|
||||
"contended writer must schedule at least one deferred wake for liveness"
|
||||
matches!(not_finished, Err(_)),
|
||||
"relay must not self-terminate with quota-style errors; gating is handled before relay"
|
||||
);
|
||||
|
||||
let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]);
|
||||
assert!(second.is_pending(), "writer remains pending while lock is still held");
|
||||
|
||||
for _ in 0..8 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
assert_eq!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed),
|
||||
wakes_after_first_yield,
|
||||
"writer contention should not schedule unbounded wake storms before lock acquisition"
|
||||
);
|
||||
|
||||
drop(held_lock);
|
||||
let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]);
|
||||
assert!(released.is_ready(), "writer must make progress once quota lock is released");
|
||||
relay_task.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() {
|
||||
async fn relay_bidirectional_counts_octets_without_fail_closed_cutoff() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-lock-read-liveness-user";
|
||||
|
||||
let lock = super::quota_user_lock(user);
|
||||
let held_lock = lock
|
||||
.try_lock()
|
||||
.expect("test must hold the per-user quota lock before polling reader");
|
||||
|
||||
let counters = Arc::new(super::SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let mut io = super::StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
|
||||
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
|
||||
assert!(first.is_pending(), "reader must remain pending while lock is contended");
|
||||
assert_eq!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed),
|
||||
0,
|
||||
"read contention wake must not fire synchronously"
|
||||
);
|
||||
|
||||
timeout(Duration::from_millis(50), async {
|
||||
loop {
|
||||
if wake_counter.wakes.load(Ordering::Relaxed) >= 1 {
|
||||
break;
|
||||
}
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("read contention must schedule a deferred wake in bounded time");
|
||||
|
||||
drop(held_lock);
|
||||
let mut buf_after_release = ReadBuf::new(&mut storage);
|
||||
let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release);
|
||||
assert!(released.is_ready(), "reader must make progress once quota lock is released");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_bidirectional_enforces_live_user_quota() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-user";
|
||||
stats.add_user_octets_from(user, 6);
|
||||
let user = "relay-stats-no-cutoff-user";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
|
@ -205,329 +182,37 @@ async fn relay_bidirectional_enforces_live_user_quota() {
|
|||
1024,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(8),
|
||||
Some(0),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer
|
||||
.write_all(&[0x10, 0x20, 0x30, 0x40])
|
||||
.write_all(&[1, 2, 3])
|
||||
.await
|
||||
.expect("client write must succeed");
|
||||
|
||||
let mut forwarded = [0u8; 4];
|
||||
let _ = timeout(
|
||||
Duration::from_millis(200),
|
||||
server_peer.read_exact(&mut forwarded),
|
||||
)
|
||||
.await;
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay task must finish under quota cutoff")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert!(
|
||||
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"),
|
||||
"relay must surface a typed quota error once live quota is exceeded"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let quota_user = "quota-exhausted-user";
|
||||
stats.add_user_octets_from(quota_user, 1);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
quota_user,
|
||||
Arc::clone(&stats),
|
||||
Some(1),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
server_peer
|
||||
.write_all(&[0xde, 0xad, 0xbe, 0xef])
|
||||
.write_all(&[4, 5, 6, 7])
|
||||
.await
|
||||
.expect("server write must succeed");
|
||||
|
||||
let mut observed = [0u8; 4];
|
||||
let forwarded = timeout(
|
||||
Duration::from_millis(200),
|
||||
client_peer.read_exact(&mut observed),
|
||||
)
|
||||
.await;
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay task must finish under quota cutoff")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert!(
|
||||
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
|
||||
"no full server payload should be forwarded once quota is already exhausted"
|
||||
);
|
||||
assert!(
|
||||
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||
"relay must still terminate with a typed quota error"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let quota_user = "partial-leak-user";
|
||||
stats.add_user_octets_from(quota_user, 3);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
quota_user,
|
||||
Arc::clone(&stats),
|
||||
Some(4),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
let mut c2s = [0u8; 3];
|
||||
server_peer
|
||||
.write_all(&[0x11, 0x22, 0x33, 0x44])
|
||||
.read_exact(&mut c2s)
|
||||
.await
|
||||
.expect("server write must succeed");
|
||||
|
||||
let mut observed = [0u8; 8];
|
||||
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay task must finish under quota cutoff")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert!(
|
||||
!matches!(forwarded, Ok(Ok(n)) if n > 0),
|
||||
"quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write"
|
||||
);
|
||||
assert!(
|
||||
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||
"relay must still terminate with a typed quota error"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let quota_user = "zero-quota-user";
|
||||
|
||||
for payload_len in [1usize, 16, 512, 4096] {
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
quota_user,
|
||||
Arc::clone(&stats),
|
||||
Some(0),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
let payload = vec![0x7f; payload_len];
|
||||
let _ = server_peer.write_all(&payload).await;
|
||||
|
||||
let mut observed = vec![0u8; payload_len];
|
||||
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay task must finish under zero-quota cutoff")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert!(
|
||||
!matches!(forwarded, Ok(Ok(n)) if n > 0),
|
||||
"zero quota must not forward any server bytes for payload_len={payload_len}"
|
||||
);
|
||||
assert!(
|
||||
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||
"zero quota must terminate with the typed quota error for payload_len={payload_len}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let quota_user = "exact-boundary-user";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
quota_user,
|
||||
Arc::clone(&stats),
|
||||
Some(4),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
server_peer
|
||||
.write_all(&[0x91, 0x92, 0x93, 0x94])
|
||||
.await
|
||||
.expect("server write must succeed at exact quota boundary");
|
||||
|
||||
let mut observed = [0u8; 4];
|
||||
.expect("server must receive c2s payload");
|
||||
let mut s2c = [0u8; 4];
|
||||
client_peer
|
||||
.read_exact(&mut observed)
|
||||
.read_exact(&mut s2c)
|
||||
.await
|
||||
.expect("client must receive the full payload at the exact quota boundary");
|
||||
assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay task must finish after exact boundary delivery")
|
||||
.expect("relay task must not panic");
|
||||
.expect("client must receive s2c payload");
|
||||
|
||||
let total = stats.get_user_total_octets(user);
|
||||
assert!(
|
||||
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||
"relay must close with a typed quota error after reaching the exact boundary"
|
||||
total >= 7,
|
||||
"relay must continue accounting octets, observed total={total}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let quota_user = "client-exhausted-user";
|
||||
stats.add_user_octets_from(quota_user, 1);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
quota_user,
|
||||
Arc::clone(&stats),
|
||||
Some(1),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer
|
||||
.write_all(&[0x51, 0x52, 0x53, 0x54])
|
||||
.await
|
||||
.expect("client write must succeed even when quota is already exhausted");
|
||||
|
||||
let mut observed = [0u8; 4];
|
||||
let forwarded = timeout(
|
||||
Duration::from_millis(200),
|
||||
server_peer.read_exact(&mut observed),
|
||||
)
|
||||
.await;
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay task must finish under quota cutoff")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert!(
|
||||
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
|
||||
"client payload must not be fully forwarded once quota is already exhausted"
|
||||
);
|
||||
assert!(
|
||||
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||
"relay must still terminate with a typed quota error"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let quota_user = "quota-fuzz-user";
|
||||
stats.add_user_octets_from(quota_user, 2);
|
||||
|
||||
for payload_len in [1usize, 32, 1024, 8192] {
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
quota_user,
|
||||
Arc::clone(&stats),
|
||||
Some(2),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
let payload = vec![0xaa; payload_len];
|
||||
let _ = server_peer.write_all(&payload).await;
|
||||
|
||||
let mut observed = vec![0u8; payload_len];
|
||||
let forwarded = timeout(
|
||||
Duration::from_millis(200),
|
||||
client_peer.read_exact(&mut observed),
|
||||
)
|
||||
.await;
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay task must finish under quota cutoff")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert!(
|
||||
!matches!(forwarded, Ok(Ok(n)) if n == payload_len),
|
||||
"quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}"
|
||||
);
|
||||
assert!(
|
||||
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
|
||||
"relay must keep returning the typed quota error for payload_len={payload_len}"
|
||||
);
|
||||
}
|
||||
relay_task.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -878,7 +563,7 @@ impl AsyncRead for GateReader {
|
|||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
|
||||
async fn adversarial_concurrent_statsio_write_accounting_is_additive() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let gate = Arc::new(TwoPartyGate::new());
|
||||
let user = "concurrent-quota-write".to_string();
|
||||
|
|
@ -888,8 +573,6 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
|
|||
Arc::new(super::SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1),
|
||||
Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
|
|
@ -898,8 +581,6 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
|
|||
Arc::new(super::SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1),
|
||||
Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
|
|
@ -916,18 +597,20 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
|
|||
let _ = res_a.expect("task a must join");
|
||||
let _ = res_b.expect("task b must join");
|
||||
|
||||
assert!(
|
||||
gate.total_bytes() <= 1,
|
||||
"concurrent same-user writes must not forward more than one byte under quota=1"
|
||||
assert_eq!(
|
||||
gate.total_bytes(),
|
||||
2,
|
||||
"both concurrent writes must forward one byte each"
|
||||
);
|
||||
assert!(
|
||||
stats.get_user_total_octets(&user) <= 1,
|
||||
"concurrent same-user writes must not account over limit"
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(&user),
|
||||
2,
|
||||
"both concurrent writes must be accounted for same user"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
|
||||
async fn adversarial_concurrent_statsio_read_accounting_is_additive() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let gate = Arc::new(TwoPartyGate::new());
|
||||
let user = "concurrent-quota-read".to_string();
|
||||
|
|
@ -937,8 +620,6 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
|
|||
Arc::new(super::SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1),
|
||||
Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
|
|
@ -947,8 +628,6 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
|
|||
Arc::new(super::SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1),
|
||||
Arc::new(std::sync::atomic::AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
|
|
@ -967,22 +646,24 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
|
|||
let _ = res_a.expect("task a must join");
|
||||
let _ = res_b.expect("task b must join");
|
||||
|
||||
assert!(
|
||||
gate.total_bytes() <= 1,
|
||||
"concurrent same-user reads must not consume more than one byte under quota=1"
|
||||
assert_eq!(
|
||||
gate.total_bytes(),
|
||||
2,
|
||||
"both concurrent reads must consume one byte each"
|
||||
);
|
||||
assert!(
|
||||
stats.get_user_total_octets(&user) <= 1,
|
||||
"concurrent same-user reads must not account over limit"
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(&user),
|
||||
2,
|
||||
"both concurrent reads must be accounted for same user"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
|
||||
async fn stress_same_user_parallel_relays_complete_without_deadlock() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "parallel-quota-user";
|
||||
let user = "parallel-relay-user";
|
||||
|
||||
for _ in 0..128 {
|
||||
for _ in 0..64 {
|
||||
let (mut client_peer_a, relay_client_a) = duplex(256);
|
||||
let (relay_server_a, mut server_peer_a) = duplex(256);
|
||||
let (mut client_peer_b, relay_client_b) = duplex(256);
|
||||
|
|
@ -1002,7 +683,7 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
|
|||
64,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(1),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
|
|
@ -1015,7 +696,7 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
|
|||
64,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(1),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
|
|
@ -1041,9 +722,10 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
|
|||
let _ = timeout(Duration::from_secs(1), relay_a).await;
|
||||
let _ = timeout(Duration::from_secs(1), relay_b).await;
|
||||
|
||||
let total = stats.get_user_total_octets(user);
|
||||
assert!(
|
||||
stats.get_user_total_octets(user) <= 1,
|
||||
"parallel relays must not exceed configured quota"
|
||||
total >= 2,
|
||||
"parallel relays must account cross-session octets and stay live; total={total}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ use std::num::NonZeroUsize;
|
|||
use std::hash::{Hash, Hasher};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
|
||||
|
|
@ -148,6 +149,14 @@ pub struct Stats {
|
|||
start_time: parking_lot::RwLock<Option<Instant>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "connection_lease_security_tests.rs"]
|
||||
mod connection_lease_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "replay_checker_security_tests.rs"]
|
||||
mod replay_checker_security_tests;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct UserStats {
|
||||
pub connects: AtomicU64,
|
||||
|
|
@ -159,6 +168,35 @@ pub struct UserStats {
|
|||
pub last_seen_epoch_secs: AtomicU64,
|
||||
}
|
||||
|
||||
enum ConnectionLeaseKind {
|
||||
Direct,
|
||||
Middle,
|
||||
}
|
||||
|
||||
pub struct ConnectionLease {
|
||||
stats: Arc<Stats>,
|
||||
kind: ConnectionLeaseKind,
|
||||
armed: bool,
|
||||
}
|
||||
|
||||
impl ConnectionLease {
|
||||
pub fn disarm(&mut self) {
|
||||
self.armed = false;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ConnectionLease {
|
||||
fn drop(&mut self) {
|
||||
if !self.armed {
|
||||
return;
|
||||
}
|
||||
match self.kind {
|
||||
ConnectionLeaseKind::Direct => self.stats.decrement_current_connections_direct(),
|
||||
ConnectionLeaseKind::Middle => self.stats.decrement_current_connections_me(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stats {
|
||||
pub fn new() -> Self {
|
||||
let stats = Self::default();
|
||||
|
|
@ -292,6 +330,22 @@ impl Stats {
|
|||
pub fn decrement_current_connections_me(&self) {
|
||||
Self::decrement_atomic_saturating(&self.current_connections_me);
|
||||
}
|
||||
pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> ConnectionLease {
|
||||
self.increment_current_connections_direct();
|
||||
ConnectionLease {
|
||||
stats: Arc::clone(self),
|
||||
kind: ConnectionLeaseKind::Direct,
|
||||
armed: true,
|
||||
}
|
||||
}
|
||||
pub fn acquire_me_connection_lease(self: &Arc<Self>) -> ConnectionLease {
|
||||
self.increment_current_connections_me();
|
||||
ConnectionLease {
|
||||
stats: Arc::clone(self),
|
||||
kind: ConnectionLeaseKind::Middle,
|
||||
armed: true,
|
||||
}
|
||||
}
|
||||
pub fn increment_relay_adaptive_promotions_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.relay_adaptive_promotions_total
|
||||
|
|
|
|||
Loading…
Reference in New Issue