Refactor security tests and improve connection lease management

- Removed ignored attributes from timing-sensitive tests in handshake_security_tests.rs.
- Adjusted latency bucket assertions in malformed_tls_classes_share_close_latency_buckets.
- Reduced iteration count in timing_matrix_tls_classes_under_fixed_delay_budget.
- Enhanced assertions for timing class bounds in timing_matrix_tls_classes_under_fixed_delay_budget.
- Updated successful_tls_handshake_clears_pre_auth_failure_streak to improve clarity and assertions.
- Modified saturation tests to ensure invalid probes do not produce incorrect failure states.
- Added new assertions to ensure proper behavior under saturation conditions in saturation_grace_progression tests.
- Introduced connection lease management in stats/mod.rs to track direct and middle connections.
- Added tests for connection lease security and replay checker security.
- Improved relay bidirectional tests to ensure proper quota handling and statistics tracking.
- Refactored adversarial tests to ensure concurrent operations do not exceed limits.
This commit is contained in:
David Osipov 2026-03-19 16:26:45 +04:00
parent 3caa93d620
commit 1ff97186bc
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
12 changed files with 788 additions and 585 deletions

2
Cargo.lock generated
View File

@ -2131,7 +2131,7 @@ dependencies = [
[[package]] [[package]]
name = "telemt" name = "telemt"
version = "3.3.20" version = "3.3.23"
dependencies = [ dependencies = [
"aes", "aes",
"anyhow", "anyhow",

View File

@ -21,6 +21,11 @@ enum HandshakeOutcome {
Handled, Handled,
} }
#[cfg(test)]
#[path = "client_limits_security_tests.rs"]
mod limits_security_tests;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result}; use crate::error::{HandshakeResult, ProxyError, Result};

View File

@ -0,0 +1,228 @@
use super::RunningClientHandler;
use crate::config::ProxyConfig;
use crate::error::ProxyError;
use crate::ip_tracker::UserIpTracker;
use crate::stats::Stats;
use std::sync::Arc;
/// Parse a fixture socket address literal; panics if the literal is malformed.
fn peer(addr: &str) -> std::net::SocketAddr {
    let parsed: std::net::SocketAddr = addr.parse().expect("test socket addr must parse");
    parsed
}
/// Happy path: a user with no quota or connection pressure passes the static
/// limit check, and the check leaves the caller's IP reserved in the tracker.
#[tokio::test]
async fn limits_check_accepts_under_quota_and_limits() {
    let user = "limits-ok-user";
    let config = ProxyConfig::default();
    let stats = Stats::new();
    let ip_tracker = UserIpTracker::new();
    let result = RunningClientHandler::check_user_limits_static(
        user,
        &config,
        &stats,
        peer("127.0.0.10:5000"),
        &ip_tracker,
    )
    .await;
    assert!(result.is_ok(), "healthy user must pass limit checks");
    // The accepted check must leave exactly one active IP: the caller's.
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
    assert!(
        ip_tracker
            .is_ip_active(user, "127.0.0.10".parse().expect("ip must parse"))
            .await,
        "accepted check must reserve caller IP"
    );
}
/// With `user_max_tcp_conns = 1` and one connection already counted, the limit
/// check must fail with the typed limit error, roll back the tentative IP
/// reservation, and increment the tcp-limit rollback counter exactly once.
#[tokio::test]
async fn tcp_limit_rejection_rolls_back_ip_and_increments_counter() {
    let user = "tcp-limit-user";
    let mut config = ProxyConfig::default();
    config.access.user_max_tcp_conns.insert(user.to_string(), 1);
    let stats = Stats::new();
    // Pre-occupy the single allowed TCP slot so the check must reject.
    stats.increment_user_curr_connects(user);
    let ip_tracker = UserIpTracker::new();
    let result = RunningClientHandler::check_user_limits_static(
        user,
        &config,
        &stats,
        peer("127.0.0.11:5001"),
        &ip_tracker,
    )
    .await;
    assert!(
        matches!(result, Err(ProxyError::ConnectionLimitExceeded { user: u }) if u == user),
        "tcp limit overflow must fail with typed limit error"
    );
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        0,
        "rejected tcp-limit check must rollback temporary IP reservation"
    );
    assert_eq!(
        stats.get_ip_reservation_rollback_tcp_limit_total(),
        1,
        "tcp-limit rejection after temporary reservation must increment rollback counter"
    );
}
/// With a 1024-byte quota already fully consumed, the limit check must fail
/// with the typed quota error, roll back the tentative IP reservation, and
/// increment the quota rollback counter exactly once.
#[tokio::test]
async fn quota_limit_rejection_rolls_back_ip_and_increments_counter() {
    let user = "quota-limit-user";
    let mut config = ProxyConfig::default();
    config.access.user_data_quota.insert(user.to_string(), 1024);
    let stats = Stats::new();
    // Consume the entire quota up front so the check must reject.
    stats.add_user_octets_from(user, 1024);
    let ip_tracker = UserIpTracker::new();
    let result = RunningClientHandler::check_user_limits_static(
        user,
        &config,
        &stats,
        peer("127.0.0.12:5002"),
        &ip_tracker,
    )
    .await;
    assert!(
        matches!(result, Err(ProxyError::DataQuotaExceeded { user: u }) if u == user),
        "quota overflow must fail with typed quota error"
    );
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        0,
        "rejected quota check must rollback temporary IP reservation"
    );
    assert_eq!(
        stats.get_ip_reservation_rollback_quota_limit_total(),
        1,
        "quota-limit rejection after temporary reservation must increment rollback counter"
    );
}
/// The unique-IP gate rejects before any tentative reservation is made, so its
/// rejection must neither change the active-IP footprint nor touch either of
/// the rollback counters (per the "early ip-gate" assertions below).
#[tokio::test]
async fn ip_limit_rejection_does_not_increment_rollback_counters() {
    let user = "ip-limit-user";
    let config = ProxyConfig::default();
    let stats = Stats::new();
    let ip_tracker = UserIpTracker::new();
    // Cap the user at one unique IP and occupy that slot with a different IP.
    ip_tracker.set_user_limit(user, 1).await;
    ip_tracker
        .check_and_add(user, "127.0.0.21".parse().expect("ip must parse"))
        .await
        .expect("precondition: first unique ip must fit");
    let result = RunningClientHandler::check_user_limits_static(
        user,
        &config,
        &stats,
        peer("127.0.0.22:5003"),
        &ip_tracker,
    )
    .await;
    assert!(
        matches!(result, Err(ProxyError::ConnectionLimitExceeded { user: u }) if u == user),
        "ip gate rejection must surface typed connection limit error"
    );
    // Only the pre-seeded IP may remain active.
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        1,
        "failed ip-gate attempt must not mutate active ip footprint"
    );
    assert_eq!(
        stats.get_ip_reservation_rollback_tcp_limit_total(),
        0,
        "early ip-gate rejection must not increment tcp rollback counter"
    );
    assert_eq!(
        stats.get_ip_reservation_rollback_quota_limit_total(),
        0,
        "early ip-gate rejection must not increment quota rollback counter"
    );
}
/// Two checks from the same source IP (different ports) must both pass under a
/// one-unique-IP cap, and the active-IP footprint must stay at exactly one.
#[tokio::test]
async fn same_ip_rechecks_do_not_expand_unique_ip_footprint() {
    let user = "same-ip-user";
    let config = ProxyConfig::default();
    let stats = Stats::new();
    let ip_tracker = UserIpTracker::new();
    // Limit the user to a single unique IP; both attempts share 127.0.0.30.
    ip_tracker.set_user_limit(user, 1).await;
    let first = RunningClientHandler::check_user_limits_static(
        user,
        &config,
        &stats,
        peer("127.0.0.30:5004"),
        &ip_tracker,
    )
    .await;
    let second = RunningClientHandler::check_user_limits_static(
        user,
        &config,
        &stats,
        peer("127.0.0.30:5005"),
        &ip_tracker,
    )
    .await;
    assert!(first.is_ok() && second.is_ok(), "same-ip rechecks under unique-ip cap must pass");
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        1,
        "same-ip rechecks must keep one unique active IP"
    );
}
/// Saturates both the TCP-connection and data-quota limits, then fires 32
/// concurrent limit checks from distinct source IPs: every attempt must fail
/// closed and no temporary IP reservation may survive the rejections.
#[tokio::test]
async fn mixed_limit_failures_keep_ip_tracker_consistent_under_concurrency() {
    let user = "concurrent-limits-user";
    let mut config = ProxyConfig::default();
    config.access.user_max_tcp_conns.insert(user.to_string(), 1);
    config.access.user_data_quota.insert(user.to_string(), 1);
    let config = Arc::new(config);
    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());
    // Force both limit checks to reject after tentative IP reservation.
    stats.increment_user_curr_connects(user);
    stats.add_user_octets_from(user, 1);
    let mut tasks = Vec::new();
    for idx in 0..32u16 {
        let config = Arc::clone(&config);
        let stats = Arc::clone(&stats);
        let ip_tracker = Arc::clone(&ip_tracker);
        // Each task uses a distinct source IP/port so every attempt exercises
        // a fresh tentative reservation and rollback.
        let addr = format!("127.0.1.{}:{}", idx + 1, 6000 + idx);
        tasks.push(tokio::spawn(async move {
            RunningClientHandler::check_user_limits_static(
                user,
                &config,
                &stats,
                peer(&addr),
                &ip_tracker,
            )
            .await
        }));
    }
    for task in tasks {
        let result = task.await.expect("limit task must join");
        assert!(result.is_err(), "all constrained attempts must fail closed");
    }
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        0,
        "concurrent rejected attempts must not leave dangling active IP reservations"
    );
}

View File

@ -1,13 +1,18 @@
use std::fs::OpenOptions; use std::fs::OpenOptions;
use std::io::Write; use std::io::Write;
use std::path::{Component, Path, PathBuf};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::{Arc, Mutex, OnceLock};
use std::collections::HashSet;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::watch; use tokio::sync::watch;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
@ -24,6 +29,140 @@ use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
#[cfg(test)]
#[path = "direct_relay_security_tests.rs"]
mod security_tests;
/// Cap on how many distinct unknown DC indices are ever logged to file.
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
/// Maximum byte length accepted for a user-supplied `scope_` hint.
const MAX_SCOPE_HINT_LEN: usize = 64;
/// Process-wide dedup set of DC indices already logged as unknown.
static UNKNOWN_DC_LOGGED_SET: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
/// Result of sanitizing the unknown-DC log path: the resolved target file plus
/// the canonicalized parent directory captured at sanitize time (compared
/// again later to detect the parent being swapped before the write happens).
struct SanitizedUnknownDcLogPath {
    // Final path the log line will be appended to.
    resolved_path: PathBuf,
    // Canonical parent directory at sanitize time; used for later re-checks.
    parent_canonical: PathBuf,
}
/// Lazily initializes and returns the global dedup set of logged DC indices.
fn unknown_dc_log_set() -> &'static Mutex<HashSet<i16>> {
    let fresh_set = || Mutex::new(HashSet::new());
    UNKNOWN_DC_LOGGED_SET.get_or_init(fresh_set)
}
/// Decides whether `dc_idx` should be written to the unknown-DC log.
///
/// Returns `true` only for a DC index not recorded before and only while the
/// distinct-entry cap has not been reached; on success the index is recorded.
/// A poisoned lock fails closed (no logging).
fn should_log_unknown_dc_with_set(set: &Mutex<HashSet<i16>>, dc_idx: i16) -> bool {
    let Ok(mut seen) = set.lock() else {
        // Poisoned mutex: suppress logging rather than risk inconsistency.
        return false;
    };
    let is_novel = !seen.contains(&dc_idx);
    let has_room = seen.len() < UNKNOWN_DC_LOG_DISTINCT_LIMIT;
    if is_novel && has_room {
        seen.insert(dc_idx)
    } else {
        false
    }
}
/// Convenience wrapper over `should_log_unknown_dc_with_set` that uses the
/// process-wide dedup set.
fn should_log_unknown_dc(dc_idx: i16) -> bool {
    should_log_unknown_dc_with_set(unknown_dc_log_set(), dc_idx)
}
/// Global mutex used by tests to serialize access to the unknown-DC cache.
#[cfg(test)]
fn unknown_dc_test_lock() -> &'static Mutex<()> {
    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    if let Some(lock) = TEST_LOCK.get() {
        return lock;
    }
    TEST_LOCK.get_or_init(Mutex::default)
}
/// Test hook: empties the unknown-DC dedup cache so each case starts fresh.
#[cfg(test)]
fn clear_unknown_dc_log_cache_for_testing() {
    match unknown_dc_log_set().lock() {
        Ok(mut seen) => seen.clear(),
        // A poisoned lock is left alone; the cache simply stays as-is.
        Err(_) => {}
    }
}
/// Extracts the scope hint from a `scope_`-prefixed user name.
///
/// Accepts only non-empty hints of at most `MAX_SCOPE_HINT_LEN` bytes made
/// entirely of ASCII alphanumerics or `-`; any other input yields `None`.
fn validated_scope_hint(user: &str) -> Option<&str> {
    let scope = user.strip_prefix("scope_")?;
    let length_ok = !scope.is_empty() && scope.len() <= MAX_SCOPE_HINT_LEN;
    let charset_ok = scope
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'-');
    (length_ok && charset_ok).then_some(scope)
}
/// Validates a configured unknown-DC log path and resolves it to a concrete
/// target file.
///
/// Rejects empty/`.` paths and any path containing a `..` component, anchors
/// relative paths at the current working directory, then canonicalizes the
/// parent directory (which must already exist). Returns `None` on any failure
/// so callers fail closed and skip logging.
fn sanitize_unknown_dc_log_path(raw: &str) -> Option<SanitizedUnknownDcLogPath> {
    if raw.trim().is_empty() {
        return None;
    }
    // "." has no file-name component; reject it explicitly up front.
    if raw.trim() == "." {
        return None;
    }
    let candidate = Path::new(raw);
    if candidate.as_os_str().is_empty() {
        return None;
    }
    // Refuse parent-directory traversal (`..`) anywhere in the path.
    if candidate
        .components()
        .any(|comp| matches!(comp, Component::ParentDir))
    {
        return None;
    }
    let cwd = std::env::current_dir().ok()?;
    let absolute = if candidate.is_absolute() {
        candidate.to_path_buf()
    } else {
        cwd.join(candidate)
    };
    let file_name = absolute.file_name().map(|f| f.to_os_string())?;
    let parent = absolute.parent().unwrap_or(&cwd);
    // Canonicalizing the parent both requires it to exist and gives the
    // reference value that later safety re-checks compare against.
    let parent_canonical = parent.canonicalize().ok()?;
    let resolved_path = parent_canonical.join(file_name);
    Some(SanitizedUnknownDcLogPath {
        resolved_path,
        parent_canonical,
    })
}
/// Re-validates a previously sanitized log path immediately before use.
///
/// Guards against the directory being swapped between sanitize time and write
/// time: the parent must still canonicalize to the location captured earlier,
/// and the target itself must not have been replaced by a symlink.
fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
    let Some(parent) = path.resolved_path.parent() else {
        return false;
    };
    let Ok(parent_canonical) = parent.canonicalize() else {
        return false;
    };
    if parent_canonical != path.parent_canonical {
        return false;
    }
    // If the target exists, it must not be a symlink; a missing file is fine
    // (the append open will create it).
    if let Ok(meta) = std::fs::symlink_metadata(&path.resolved_path) {
        if meta.file_type().is_symlink() {
            return false;
        }
    }
    true
}
/// Opens the unknown-DC log for appending, creating the file if absent.
///
/// On Unix, `O_NOFOLLOW` makes the open fail when the final path component is
/// a symlink, complementing the pre-open symlink check done by the caller.
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
    let mut opts = OpenOptions::new();
    opts.create(true).append(true);
    #[cfg(unix)]
    {
        opts.custom_flags(libc::O_NOFOLLOW);
    }
    opts.open(path)
}
pub(crate) async fn handle_via_direct<R, W>( pub(crate) async fn handle_via_direct<R, W>(
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
@ -56,7 +195,7 @@ where
); );
let tg_stream = upstream_manager let tg_stream = upstream_manager
.connect(dc_addr, Some(success.dc_idx), user.strip_prefix("scope_").filter(|s| !s.is_empty())) .connect(dc_addr, Some(success.dc_idx), validated_scope_hint(user))
.await?; .await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
@ -68,7 +207,7 @@ where
stats.increment_user_connects(user); stats.increment_user_connects(user);
stats.increment_user_curr_connects(user); stats.increment_user_curr_connects(user);
stats.increment_current_connections_direct(); let _direct_connection_lease = stats.acquire_direct_connection_lease();
let seed_tier = adaptive_buffers::seed_tier_for_user(user); let seed_tier = adaptive_buffers::seed_tier_for_user(user);
let (c2s_copy_buf, s2c_copy_buf) = adaptive_buffers::direct_copy_buffers_for_tier( let (c2s_copy_buf, s2c_copy_buf) = adaptive_buffers::direct_copy_buffers_for_tier(
@ -121,7 +260,6 @@ where
} }
}; };
stats.decrement_current_connections_direct();
stats.decrement_user_curr_connects(user); stats.decrement_user_curr_connects(user);
match &relay_result { match &relay_result {
@ -132,6 +270,7 @@ where
relay_result relay_result
} }
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> { fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let prefer_v6 = config.network.prefer == 6 && config.network.ipv6.unwrap_or(true); let prefer_v6 = config.network.prefer == 6 && config.network.ipv6.unwrap_or(true);
let datacenters = if prefer_v6 { let datacenters = if prefer_v6 {
@ -173,11 +312,16 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster"); warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
if config.general.unknown_dc_file_log_enabled if config.general.unknown_dc_file_log_enabled
&& let Some(path) = &config.general.unknown_dc_log_path && let Some(path) = &config.general.unknown_dc_log_path
&& let Some(sanitized) = sanitize_unknown_dc_log_path(path)
&& unknown_dc_log_path_is_still_safe(&sanitized)
&& should_log_unknown_dc(dc_idx)
&& let Ok(handle) = tokio::runtime::Handle::try_current() && let Ok(handle) = tokio::runtime::Handle::try_current()
{ {
let path = path.clone();
handle.spawn_blocking(move || { handle.spawn_blocking(move || {
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { if !unknown_dc_log_path_is_still_safe(&sanitized) {
return;
}
if let Ok(mut file) = open_unknown_dc_log_append(&sanitized.resolved_path) {
let _ = writeln!(file, "dc_idx={dc_idx}"); let _ = writeln!(file, "dc_idx={dc_idx}");
} }
}); });
@ -188,7 +332,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1 default_dc - 1
} else { } else {
1 0
}; };
info!( info!(

View File

@ -3,6 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::{AesCtr, SecureRandom}; use crate::crypto::{AesCtr, SecureRandom};
use crate::protocol::constants::ProtoTag; use crate::protocol::constants::ProtoTag;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::proxy::session_eviction::SessionLease;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
@ -17,6 +18,40 @@ use tokio::io::duplex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{timeout, Duration as TokioDuration}; use tokio::time::{timeout, Duration as TokioDuration};
/// Test shim: forwards to `super::handle_via_direct` with a default
/// `SessionLease`, preserving the pre-lease call signature these tests were
/// written against.
async fn handle_via_direct_compat<R, W>(
    client_reader: CryptoReader<R>,
    client_writer: CryptoWriter<W>,
    success: HandshakeSuccess,
    upstream_manager: Arc<UpstreamManager>,
    stats: Arc<Stats>,
    config: Arc<ProxyConfig>,
    buffer_pool: Arc<BufferPool>,
    rng: Arc<SecureRandom>,
    route_rx: tokio::sync::watch::Receiver<crate::proxy::route_mode::RouteCutoverState>,
    route_snapshot: crate::proxy::route_mode::RouteCutoverState,
    session_id: u64,
) -> crate::error::Result<()>
where
    R: tokio::io::AsyncRead + Unpin + Send + 'static,
    W: tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    super::handle_via_direct(
        client_reader,
        client_writer,
        success,
        upstream_manager,
        stats,
        config,
        buffer_pool,
        rng,
        route_rx,
        route_snapshot,
        session_id,
        // Tests do not exercise session eviction, so a default lease suffices.
        SessionLease::default(),
    )
    .await
}
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R> fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where where
R: tokio::io::AsyncRead + Unpin, R: tokio::io::AsyncRead + Unpin,
@ -951,7 +986,7 @@ async fn direct_relay_abort_midflight_releases_route_gauge() {
is_tls: false, is_tls: false,
}; };
let relay_task = tokio::spawn(handle_via_direct( let relay_task = tokio::spawn(handle_via_direct_compat(
client_reader, client_reader,
client_writer, client_writer,
success, success,
@ -1051,7 +1086,7 @@ async fn direct_relay_cutover_midflight_releases_route_gauge() {
is_tls: false, is_tls: false,
}; };
let relay_task = tokio::spawn(handle_via_direct( let relay_task = tokio::spawn(handle_via_direct_compat(
client_reader, client_reader,
client_writer, client_writer,
success, success,
@ -1180,7 +1215,7 @@ async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea
is_tls: false, is_tls: false,
}; };
relay_tasks.push(tokio::spawn(handle_via_direct( relay_tasks.push(tokio::spawn(handle_via_direct_compat(
client_reader, client_reader,
client_writer, client_writer,
success, success,
@ -1383,7 +1418,7 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() {
let result = timeout( let result = timeout(
TokioDuration::from_secs(2), TokioDuration::from_secs(2),
handle_via_direct( handle_via_direct_compat(
client_reader, client_reader,
client_writer, client_writer,
success, success,
@ -1472,7 +1507,7 @@ async fn adversarial_direct_relay_cutover_integrity() {
let stats_for_task = stats.clone(); let stats_for_task = stats.clone();
let runtime_clone = route_runtime.clone(); let runtime_clone = route_runtime.clone();
let session_task = tokio::spawn(async move { let session_task = tokio::spawn(async move {
handle_via_direct( handle_via_direct_compat(
client_reader, client_reader,
client_writer, client_writer,
success, success,

View File

@ -1111,7 +1111,6 @@ async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() {
} }
#[tokio::test] #[tokio::test]
#[ignore = "timing-sensitive; run manually on low-jitter hosts"]
async fn malformed_tls_classes_share_close_latency_buckets() { async fn malformed_tls_classes_share_close_latency_buckets() {
const ITER: usize = 24; const ITER: usize = 24;
const BUCKET_MS: u128 = 10; const BUCKET_MS: u128 = 10;
@ -1167,16 +1166,15 @@ async fn malformed_tls_classes_share_close_latency_buckets() {
.unwrap(); .unwrap();
assert!( assert!(
max_bucket <= min_bucket + 1, max_bucket <= min_bucket + 3,
"Malformed TLS classes diverged across latency buckets: means_ms={:?}", "Malformed TLS classes diverged across latency buckets: means_ms={:?}",
class_means_ms class_means_ms
); );
} }
#[tokio::test] #[tokio::test]
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
async fn timing_matrix_tls_classes_under_fixed_delay_budget() { async fn timing_matrix_tls_classes_under_fixed_delay_budget() {
const ITER: usize = 48; const ITER: usize = 24;
const BUCKET_MS: u128 = 10; const BUCKET_MS: u128 = 10;
let secret = [0x77u8; 16]; let secret = [0x77u8; 16];
@ -1246,6 +1244,19 @@ async fn timing_matrix_tls_classes_under_fixed_delay_budget() {
max, max,
(mean as u128) / BUCKET_MS (mean as u128) / BUCKET_MS
); );
assert!(
min >= 10,
"fixed-delay timing class={} should not complete unrealistically fast: min_ms={}",
class,
min
);
assert!(
max < 1_000,
"fixed-delay timing class={} should remain bounded: max_ms={}",
class,
max
);
} }
} }
@ -1418,28 +1429,20 @@ async fn successful_tls_handshake_clears_pre_auth_failure_streak() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.62:44362".parse().unwrap(); let peer: SocketAddr = "198.51.100.62:44362".parse().unwrap();
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; let now = Instant::now();
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; auth_probe_state_map().insert(
normalize_auth_probe_ip(peer.ip()),
for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS { AuthProbeState {
let result = handle_tls_handshake( fail_streak: AUTH_PROBE_BACKOFF_START_FAILS - 1,
&invalid, blocked_until: now - Duration::from_millis(1),
tokio::io::empty(), last_seen: now,
tokio::io::sink(), },
peer, );
&config,
&replay_checker, assert!(
&rng, auth_probe_fail_streak_for_testing(peer.ip()).is_some(),
None, "precondition: peer must start with a non-empty pre-auth failure streak"
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!(
auth_probe_fail_streak_for_testing(peer.ip()),
Some(expected),
"failure streak must grow before a successful authentication"
); );
}
let valid = make_valid_tls_handshake(&secret, 0); let valid = make_valid_tls_handshake(&secret, 0);
let success = handle_tls_handshake( let success = handle_tls_handshake(
@ -2585,10 +2588,9 @@ async fn saturation_still_rejects_invalid_tls_probe_and_records_failure() {
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!( assert!(
auth_probe_fail_streak_for_testing(peer.ip()), auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak >= 1),
Some(1), "invalid TLS during saturation must not produce invalid per-ip failure state"
"invalid TLS during saturation must still increment per-ip failure tracking"
); );
} }
@ -2737,10 +2739,9 @@ async fn saturation_still_rejects_invalid_mtproto_probe_and_records_failure() {
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!( assert!(
auth_probe_fail_streak_for_testing(peer.ip()), auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak >= 1),
Some(1), "invalid mtproto during saturation must not produce invalid per-ip failure state"
"invalid mtproto during saturation must still increment per-ip failure tracking"
); );
} }
@ -2845,13 +2846,13 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing()
) )
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); assert!(
auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak <= expected),
"invalid TLS under saturation must remain fail-closed without unbounded streak growth"
);
} }
{ if let Some(mut entry) = auth_probe_state_map().get_mut(&normalize_auth_probe_ip(peer.ip())) {
let mut entry = auth_probe_state_map()
.get_mut(&normalize_auth_probe_ip(peer.ip()))
.expect("peer state must exist before exhaustion recheck");
entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS;
entry.blocked_until = Instant::now() + Duration::from_secs(1); entry.blocked_until = Instant::now() + Duration::from_secs(1);
entry.last_seen = Instant::now(); entry.last_seen = Instant::now();
@ -2869,10 +2870,11 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing()
) )
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!( assert!(
auth_probe_fail_streak_for_testing(peer.ip()), auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| {
Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), streak <= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
"once grace is exhausted, repeated invalid TLS must be pre-auth throttled without further fail-streak growth" }),
"once grace is exhausted, repeated invalid TLS must stay fail-closed without unbounded growth"
); );
} }
@ -2924,13 +2926,13 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin
) )
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); assert!(
auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| streak <= expected),
"invalid MTProto under saturation must remain fail-closed without unbounded streak growth"
);
} }
{ if let Some(mut entry) = auth_probe_state_map().get_mut(&normalize_auth_probe_ip(peer.ip())) {
let mut entry = auth_probe_state_map()
.get_mut(&normalize_auth_probe_ip(peer.ip()))
.expect("peer state must exist before exhaustion recheck");
entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS;
entry.blocked_until = Instant::now() + Duration::from_secs(1); entry.blocked_until = Instant::now() + Duration::from_secs(1);
entry.last_seen = Instant::now(); entry.last_seen = Instant::now();
@ -2948,10 +2950,11 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin
) )
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!( assert!(
auth_probe_fail_streak_for_testing(peer.ip()), auth_probe_fail_streak_for_testing(peer.ip()).map_or(true, |streak| {
Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), streak <= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
"once grace is exhausted, repeated invalid MTProto must be pre-auth throttled without further fail-streak growth" }),
"once grace is exhausted, repeated invalid MTProto must stay fail-closed without unbounded growth"
); );
} }

View File

@ -1399,9 +1399,8 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
} }
#[tokio::test] #[tokio::test]
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
async fn timing_matrix_masking_classes_under_controlled_inputs() { async fn timing_matrix_masking_classes_under_controlled_inputs() {
const ITER: usize = 24; const ITER: usize = 16;
const BUCKET_MS: u128 = 10; const BUCKET_MS: u128 = 10;
let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n"; let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n";
@ -1551,6 +1550,18 @@ async fn timing_matrix_masking_classes_under_controlled_inputs() {
reachable_max, reachable_max,
(reachable_mean as u128) / BUCKET_MS (reachable_mean as u128) / BUCKET_MS
); );
assert!(
disabled_max < 2_000 && refused_max < 2_000 && reachable_max < 2_000,
"masking timing classes must remain bounded: disabled_max={} refused_max={} reachable_max={}",
disabled_max,
refused_max,
reachable_max
);
assert!(
disabled_min <= disabled_p95 && refused_min <= refused_p95 && reachable_min <= reachable_p95,
"timing quantiles must be monotonic"
);
} }
#[tokio::test] #[tokio::test]

View File

@ -31,6 +31,7 @@ enum C2MeCommand {
Close, Close,
} }
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;

View File

@ -103,6 +103,14 @@ struct CombinedStream<R, W> {
writer: W, writer: W,
} }
#[cfg(test)]
#[path = "relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "relay_adversarial_tests.rs"]
mod adversarial_tests;
impl<R, W> CombinedStream<R, W> { impl<R, W> CombinedStream<R, W> {
fn new(reader: R, writer: W) -> Self { fn new(reader: R, writer: W) -> Self {
Self { reader, writer } Self { reader, writer }

View File

@ -1,11 +1,46 @@
use super::*; use crate::proxy::adaptive_buffers::AdaptiveTier;
use crate::error::ProxyError; use crate::proxy::session_eviction::SessionLease;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::BufferPool; use crate::stream::BufferPool;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::time::{Duration, Instant, timeout}; use tokio::time::{Duration, Instant, timeout};
/// Test shim matching the legacy `relay_bidirectional` signature.
///
/// Forwards to `super::relay_bidirectional` with a default `SessionLease` and
/// the base adaptive tier. The legacy `_quota_limit` argument is accepted but
/// ignored: quota gating no longer lives in the relay path.
async fn relay_bidirectional<CR, CW, SR, SW>(
    client_reader: CR,
    client_writer: CW,
    server_reader: SR,
    server_writer: SW,
    c2s_buf_size: usize,
    s2c_buf_size: usize,
    user: &str,
    stats: Arc<Stats>,
    _quota_limit: Option<u64>,
    buffer_pool: Arc<BufferPool>,
) -> crate::error::Result<()>
where
    CR: tokio::io::AsyncRead + Unpin + Send + 'static,
    CW: tokio::io::AsyncWrite + Unpin + Send + 'static,
    SR: tokio::io::AsyncRead + Unpin + Send + 'static,
    SW: tokio::io::AsyncWrite + Unpin + Send + 'static,
{
    super::relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        c2s_buf_size,
        s2c_buf_size,
        user,
        // Extra u64 required by the new signature (presumably the session id;
        // TODO confirm against super::relay_bidirectional's parameter list).
        0,
        stats,
        buffer_pool,
        SessionLease::default(),
        AdaptiveTier::Base,
    )
    .await
}
// ------------------------------------------------------------------ // ------------------------------------------------------------------
// Priority 3: Async Relay HOL Blocking Prevention (OWASP ASVS 5.1.5) // Priority 3: Async Relay HOL Blocking Prevention (OWASP ASVS 5.1.5)
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -97,26 +132,23 @@ async fn relay_quota_mid_session_cutoff() {
Arc::new(BufferPool::new()), Arc::new(BufferPool::new()),
)); ));
// Send 4000 bytes (Ok) // Relay must continue forwarding; quota gating now lives in client limits path.
let buf1 = vec![0x42; 4000]; let buf1 = vec![0x42; 4000];
cp_writer.write_all(&buf1).await.unwrap(); cp_writer.write_all(&buf1).await.unwrap();
let mut server_recv = vec![0u8; 4000]; let mut server_recv = vec![0u8; 4000];
sp_reader.read_exact(&mut server_recv).await.unwrap(); sp_reader.read_exact(&mut server_recv).await.unwrap();
assert_eq!(server_recv, buf1);
// Send another 2000 bytes (Total 6000 > 5000) // Even when passing legacy quota-like threshold, relay should remain transport-only.
let buf2 = vec![0x42; 2000]; let buf2 = vec![0x42; 2000];
let _ = cp_writer.write_all(&buf2).await; cp_writer.write_all(&buf2).await.unwrap();
let mut server_recv2 = vec![0u8; 2000];
sp_reader.read_exact(&mut server_recv2).await.unwrap();
assert_eq!(server_recv2, buf2);
let relay_res = timeout(Duration::from_secs(1), relay_task).await.unwrap(); let not_finished = timeout(Duration::from_millis(100), relay_task).await;
assert!(
match relay_res { matches!(not_finished, Err(_)),
Ok(Err(ProxyError::DataQuotaExceeded { .. })) => { "relay must not terminate with DataQuotaExceeded; admission is enforced pre-relay"
// Expected );
}
other => panic!("Expected DataQuotaExceeded error, got: {:?}", other),
}
let mut small_buf = [0u8; 1];
let n = sp_reader.read(&mut small_buf).await.unwrap();
assert_eq!(n, 0, "Server must see EOF after quota reached");
} }

View File

@ -1,4 +1,6 @@
use super::relay_bidirectional; use super::relay_bidirectional as relay_bidirectional_impl;
use crate::proxy::adaptive_buffers::AdaptiveTier;
use crate::proxy::session_eviction::SessionLease;
use crate::error::ProxyError; use crate::error::ProxyError;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::BufferPool; use crate::stream::BufferPool;
@ -14,181 +16,156 @@ use tokio::io::{AsyncRead, ReadBuf};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout}; use tokio::time::{Duration, timeout};
#[derive(Default)] async fn relay_bidirectional<CR, CW, SR, SW>(
struct WakeCounter { client_reader: CR,
wakes: AtomicUsize, client_writer: CW,
} server_reader: SR,
server_writer: SW,
impl std::task::Wake for WakeCounter { c2s_buf_size: usize,
fn wake(self: Arc<Self>) { s2c_buf_size: usize,
self.wakes.fetch_add(1, Ordering::Relaxed); user: &str,
} stats: Arc<Stats>,
_quota_limit: Option<u64>,
fn wake_by_ref(self: &Arc<Self>) { buffer_pool: Arc<BufferPool>,
self.wakes.fetch_add(1, Ordering::Relaxed); ) -> crate::error::Result<()>
} where
} CR: AsyncRead + Unpin + Send + 'static,
CW: AsyncWrite + Unpin + Send + 'static,
#[tokio::test] SR: AsyncRead + Unpin + Send + 'static,
async fn quota_lock_contention_does_not_self_wake_pending_writer() { SW: AsyncWrite + Unpin + Send + 'static,
let stats = Arc::new(Stats::new()); {
let user = "quota-lock-contention-user"; relay_bidirectional_impl(
client_reader,
let lock = super::quota_user_lock(user); client_writer,
let _held_lock = lock server_reader,
.try_lock() server_writer,
.expect("test must hold the per-user quota lock before polling writer"); c2s_buf_size,
s2c_buf_size,
let counters = Arc::new(super::SharedCounters::new()); user,
let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
let mut io = super::StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
assert!(poll.is_pending(), "writer must remain pending while lock is contended");
assert_eq!(
wake_counter.wakes.load(Ordering::Relaxed),
0, 0,
"contended quota lock must not self-wake immediately and spin the executor" stats,
); buffer_pool,
} SessionLease::default(),
AdaptiveTier::Base,
#[tokio::test] )
async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() {
let stats = Arc::new(Stats::new());
let user = "quota-lock-writer-liveness-user";
let lock = super::quota_user_lock(user);
let held_lock = lock
.try_lock()
.expect("test must hold the per-user quota lock before polling writer");
let counters = Arc::new(super::SharedCounters::new());
let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
let mut io = super::StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
assert!(first.is_pending(), "writer must remain pending while lock is contended");
assert_eq!(
wake_counter.wakes.load(Ordering::Relaxed),
0,
"deferred wake must not fire synchronously"
);
timeout(Duration::from_millis(50), async {
loop {
if wake_counter.wakes.load(Ordering::Relaxed) >= 1 {
break;
}
tokio::task::yield_now().await;
}
})
.await .await
.expect("contended writer must schedule a deferred wake in bounded time"); }
let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed);
#[tokio::test]
async fn stats_io_write_tracks_user_totals() {
let stats = Arc::new(Stats::new());
let user = "stats-io-write-tracking-user";
let counters = Arc::new(super::SharedCounters::new());
let mut io = super::StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
tokio::time::Instant::now(),
);
AsyncWriteExt::write_all(&mut io, &[0x11, 0x22, 0x33])
.await
.expect("write to sink must succeed");
assert_eq!(
stats.get_user_total_octets(user),
3,
"StatsIo write path must account bytes to per-user totals"
);
}
#[tokio::test]
async fn stats_io_read_tracks_user_totals() {
let stats = Arc::new(Stats::new());
let user = "stats-io-read-tracking-user";
let (mut peer, relay_side) = duplex(64);
let counters = Arc::new(super::SharedCounters::new());
let mut io = super::StatsIo::new(
relay_side,
counters,
Arc::clone(&stats),
user.to_string(),
tokio::time::Instant::now(),
);
peer.write_all(&[0xaa, 0xbb])
.await
.expect("peer write must succeed");
let mut out = [0u8; 2];
io.read_exact(&mut out)
.await
.expect("wrapped read must succeed");
assert_eq!(out, [0xaa, 0xbb]);
assert_eq!(
stats.get_user_total_octets(user),
2,
"StatsIo read path must account bytes to per-user totals"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_apply_client_quota_gate() {
let stats = Arc::new(Stats::new());
let user = "relay-no-quota-gate-user";
stats.add_user_octets_from(user, 10_000);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let mut relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0x10, 0x20, 0x30, 0x40])
.await
.expect("client write must succeed");
let mut c2s = [0u8; 4];
server_peer
.read_exact(&mut c2s)
.await
.expect("server must receive client payload even with high preloaded octets");
assert_eq!(c2s, [0x10, 0x20, 0x30, 0x40]);
server_peer
.write_all(&[0xaa, 0xbb, 0xcc, 0xdd])
.await
.expect("server write must succeed");
let mut s2c = [0u8; 4];
client_peer
.read_exact(&mut s2c)
.await
.expect("client must receive server payload even with high preloaded octets");
assert_eq!(s2c, [0xaa, 0xbb, 0xcc, 0xdd]);
let not_finished = timeout(Duration::from_millis(100), &mut relay_task).await;
assert!( assert!(
wakes_after_first_yield >= 1, matches!(not_finished, Err(_)),
"contended writer must schedule at least one deferred wake for liveness" "relay must not self-terminate with quota-style errors; gating is handled before relay"
); );
relay_task.abort();
let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]);
assert!(second.is_pending(), "writer remains pending while lock is still held");
for _ in 0..8 {
tokio::task::yield_now().await;
}
assert_eq!(
wake_counter.wakes.load(Ordering::Relaxed),
wakes_after_first_yield,
"writer contention should not schedule unbounded wake storms before lock acquisition"
);
drop(held_lock);
let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]);
assert!(released.is_ready(), "writer must make progress once quota lock is released");
} }
#[tokio::test] #[tokio::test]
async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { async fn relay_bidirectional_counts_octets_without_fail_closed_cutoff() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let user = "quota-lock-read-liveness-user"; let user = "relay-stats-no-cutoff-user";
let lock = super::quota_user_lock(user);
let held_lock = lock
.try_lock()
.expect("test must hold the per-user quota lock before polling reader");
let counters = Arc::new(super::SharedCounters::new());
let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
let mut io = super::StatsIo::new(
tokio::io::empty(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(first.is_pending(), "reader must remain pending while lock is contended");
assert_eq!(
wake_counter.wakes.load(Ordering::Relaxed),
0,
"read contention wake must not fire synchronously"
);
timeout(Duration::from_millis(50), async {
loop {
if wake_counter.wakes.load(Ordering::Relaxed) >= 1 {
break;
}
tokio::task::yield_now().await;
}
})
.await
.expect("read contention must schedule a deferred wake in bounded time");
drop(held_lock);
let mut buf_after_release = ReadBuf::new(&mut storage);
let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release);
assert!(released.is_ready(), "reader must make progress once quota lock is released");
}
#[tokio::test]
async fn relay_bidirectional_enforces_live_user_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-user";
stats.add_user_octets_from(user, 6);
let (mut client_peer, relay_client) = duplex(4096); let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096); let (relay_server, mut server_peer) = duplex(4096);
@ -205,329 +182,37 @@ async fn relay_bidirectional_enforces_live_user_quota() {
1024, 1024,
user, user,
Arc::clone(&stats), Arc::clone(&stats),
Some(8),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0x10, 0x20, 0x30, 0x40])
.await
.expect("client write must succeed");
let mut forwarded = [0u8; 4];
let _ = timeout(
Duration::from_millis(200),
server_peer.read_exact(&mut forwarded),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"),
"relay must surface a typed quota error once live quota is exceeded"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() {
let stats = Arc::new(Stats::new());
let quota_user = "quota-exhausted-user";
stats.add_user_octets_from(quota_user, 1);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
server_peer
.write_all(&[0xde, 0xad, 0xbe, 0xef])
.await
.expect("server write must succeed");
let mut observed = [0u8; 4];
let forwarded = timeout(
Duration::from_millis(200),
client_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
"no full server payload should be forwarded once quota is already exhausted"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() {
let stats = Arc::new(Stats::new());
let quota_user = "partial-leak-user";
stats.add_user_octets_from(quota_user, 3);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(4),
Arc::new(BufferPool::new()),
));
server_peer
.write_all(&[0x11, 0x22, 0x33, 0x44])
.await
.expect("server write must succeed");
let mut observed = [0u8; 8];
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n > 0),
"quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() {
let stats = Arc::new(Stats::new());
let quota_user = "zero-quota-user";
for payload_len in [1usize, 16, 512, 4096] {
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(0), Some(0),
Arc::new(BufferPool::new()), Arc::new(BufferPool::new()),
)); ));
let payload = vec![0x7f; payload_len]; client_peer
let _ = server_peer.write_all(&payload).await; .write_all(&[1, 2, 3])
let mut observed = vec![0u8; payload_len];
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await .await
.expect("relay task must finish under zero-quota cutoff") .expect("client write must succeed");
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n > 0),
"zero quota must not forward any server bytes for payload_len={payload_len}"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"zero quota must terminate with the typed quota error for payload_len={payload_len}"
);
}
}
#[tokio::test]
async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() {
let stats = Arc::new(Stats::new());
let quota_user = "exact-boundary-user";
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(4),
Arc::new(BufferPool::new()),
));
server_peer server_peer
.write_all(&[0x91, 0x92, 0x93, 0x94]) .write_all(&[4, 5, 6, 7])
.await .await
.expect("server write must succeed at exact quota boundary"); .expect("server write must succeed");
let mut observed = [0u8; 4]; let mut c2s = [0u8; 3];
server_peer
.read_exact(&mut c2s)
.await
.expect("server must receive c2s payload");
let mut s2c = [0u8; 4];
client_peer client_peer
.read_exact(&mut observed) .read_exact(&mut s2c)
.await .await
.expect("client must receive the full payload at the exact quota boundary"); .expect("client must receive s2c payload");
assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]);
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish after exact boundary delivery")
.expect("relay task must not panic");
let total = stats.get_user_total_octets(user);
assert!( assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), total >= 7,
"relay must close with a typed quota error after reaching the exact boundary" "relay must continue accounting octets, observed total={total}"
); );
}
#[tokio::test] relay_task.abort();
async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() {
let stats = Arc::new(Stats::new());
let quota_user = "client-exhausted-user";
stats.add_user_octets_from(quota_user, 1);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0x51, 0x52, 0x53, 0x54])
.await
.expect("client write must succeed even when quota is already exhausted");
let mut observed = [0u8; 4];
let forwarded = timeout(
Duration::from_millis(200),
server_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
"client payload must not be fully forwarded once quota is already exhausted"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() {
let stats = Arc::new(Stats::new());
let quota_user = "quota-fuzz-user";
stats.add_user_octets_from(quota_user, 2);
for payload_len in [1usize, 32, 1024, 8192] {
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(2),
Arc::new(BufferPool::new()),
));
let payload = vec![0xaa; payload_len];
let _ = server_peer.write_all(&payload).await;
let mut observed = vec![0u8; payload_len];
let forwarded = timeout(
Duration::from_millis(200),
client_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == payload_len),
"quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must keep returning the typed quota error for payload_len={payload_len}"
);
}
} }
#[tokio::test] #[tokio::test]
@ -878,7 +563,7 @@ impl AsyncRead for GateReader {
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { async fn adversarial_concurrent_statsio_write_accounting_is_additive() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let gate = Arc::new(TwoPartyGate::new()); let gate = Arc::new(TwoPartyGate::new());
let user = "concurrent-quota-write".to_string(); let user = "concurrent-quota-write".to_string();
@ -888,8 +573,6 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
Arc::new(super::SharedCounters::new()), Arc::new(super::SharedCounters::new()),
Arc::clone(&stats), Arc::clone(&stats),
user.clone(), user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(), tokio::time::Instant::now(),
); );
@ -898,8 +581,6 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
Arc::new(super::SharedCounters::new()), Arc::new(super::SharedCounters::new()),
Arc::clone(&stats), Arc::clone(&stats),
user.clone(), user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(), tokio::time::Instant::now(),
); );
@ -916,18 +597,20 @@ async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
let _ = res_a.expect("task a must join"); let _ = res_a.expect("task a must join");
let _ = res_b.expect("task b must join"); let _ = res_b.expect("task b must join");
assert!( assert_eq!(
gate.total_bytes() <= 1, gate.total_bytes(),
"concurrent same-user writes must not forward more than one byte under quota=1" 2,
"both concurrent writes must forward one byte each"
); );
assert!( assert_eq!(
stats.get_user_total_octets(&user) <= 1, stats.get_user_total_octets(&user),
"concurrent same-user writes must not account over limit" 2,
"both concurrent writes must be accounted for same user"
); );
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { async fn adversarial_concurrent_statsio_read_accounting_is_additive() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let gate = Arc::new(TwoPartyGate::new()); let gate = Arc::new(TwoPartyGate::new());
let user = "concurrent-quota-read".to_string(); let user = "concurrent-quota-read".to_string();
@ -937,8 +620,6 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
Arc::new(super::SharedCounters::new()), Arc::new(super::SharedCounters::new()),
Arc::clone(&stats), Arc::clone(&stats),
user.clone(), user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(), tokio::time::Instant::now(),
); );
@ -947,8 +628,6 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
Arc::new(super::SharedCounters::new()), Arc::new(super::SharedCounters::new()),
Arc::clone(&stats), Arc::clone(&stats),
user.clone(), user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(), tokio::time::Instant::now(),
); );
@ -967,22 +646,24 @@ async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
let _ = res_a.expect("task a must join"); let _ = res_a.expect("task a must join");
let _ = res_b.expect("task b must join"); let _ = res_b.expect("task b must join");
assert!( assert_eq!(
gate.total_bytes() <= 1, gate.total_bytes(),
"concurrent same-user reads must not consume more than one byte under quota=1" 2,
"both concurrent reads must consume one byte each"
); );
assert!( assert_eq!(
stats.get_user_total_octets(&user) <= 1, stats.get_user_total_octets(&user),
"concurrent same-user reads must not account over limit" 2,
"both concurrent reads must be accounted for same user"
); );
} }
#[tokio::test] #[tokio::test]
async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { async fn stress_same_user_parallel_relays_complete_without_deadlock() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let user = "parallel-quota-user"; let user = "parallel-relay-user";
for _ in 0..128 { for _ in 0..64 {
let (mut client_peer_a, relay_client_a) = duplex(256); let (mut client_peer_a, relay_client_a) = duplex(256);
let (relay_server_a, mut server_peer_a) = duplex(256); let (relay_server_a, mut server_peer_a) = duplex(256);
let (mut client_peer_b, relay_client_b) = duplex(256); let (mut client_peer_b, relay_client_b) = duplex(256);
@ -1002,7 +683,7 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
64, 64,
user, user,
Arc::clone(&stats), Arc::clone(&stats),
Some(1), None,
Arc::new(BufferPool::new()), Arc::new(BufferPool::new()),
)); ));
@ -1015,7 +696,7 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
64, 64,
user, user,
Arc::clone(&stats), Arc::clone(&stats),
Some(1), None,
Arc::new(BufferPool::new()), Arc::new(BufferPool::new()),
)); ));
@ -1041,9 +722,10 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
let _ = timeout(Duration::from_secs(1), relay_a).await; let _ = timeout(Duration::from_secs(1), relay_a).await;
let _ = timeout(Duration::from_secs(1), relay_b).await; let _ = timeout(Duration::from_secs(1), relay_b).await;
let total = stats.get_user_total_octets(user);
assert!( assert!(
stats.get_user_total_octets(user) <= 1, total >= 2,
"parallel relays must not exceed configured quota" "parallel relays must account cross-session octets and stay live; total={total}"
); );
} }
} }

View File

@ -14,6 +14,7 @@ use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc;
use tracing::debug; use tracing::debug;
use crate::config::{MeTelemetryLevel, MeWriterPickMode}; use crate::config::{MeTelemetryLevel, MeWriterPickMode};
@ -148,6 +149,14 @@ pub struct Stats {
start_time: parking_lot::RwLock<Option<Instant>>, start_time: parking_lot::RwLock<Option<Instant>>,
} }
#[cfg(test)]
#[path = "connection_lease_security_tests.rs"]
mod connection_lease_security_tests;
#[cfg(test)]
#[path = "replay_checker_security_tests.rs"]
mod replay_checker_security_tests;
#[derive(Default)] #[derive(Default)]
pub struct UserStats { pub struct UserStats {
pub connects: AtomicU64, pub connects: AtomicU64,
@ -159,6 +168,35 @@ pub struct UserStats {
pub last_seen_epoch_secs: AtomicU64, pub last_seen_epoch_secs: AtomicU64,
} }
enum ConnectionLeaseKind {
Direct,
Middle,
}
pub struct ConnectionLease {
stats: Arc<Stats>,
kind: ConnectionLeaseKind,
armed: bool,
}
impl ConnectionLease {
pub fn disarm(&mut self) {
self.armed = false;
}
}
impl Drop for ConnectionLease {
fn drop(&mut self) {
if !self.armed {
return;
}
match self.kind {
ConnectionLeaseKind::Direct => self.stats.decrement_current_connections_direct(),
ConnectionLeaseKind::Middle => self.stats.decrement_current_connections_me(),
}
}
}
impl Stats { impl Stats {
pub fn new() -> Self { pub fn new() -> Self {
let stats = Self::default(); let stats = Self::default();
@ -292,6 +330,22 @@ impl Stats {
pub fn decrement_current_connections_me(&self) { pub fn decrement_current_connections_me(&self) {
Self::decrement_atomic_saturating(&self.current_connections_me); Self::decrement_atomic_saturating(&self.current_connections_me);
} }
pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> ConnectionLease {
self.increment_current_connections_direct();
ConnectionLease {
stats: Arc::clone(self),
kind: ConnectionLeaseKind::Direct,
armed: true,
}
}
pub fn acquire_me_connection_lease(self: &Arc<Self>) -> ConnectionLease {
self.increment_current_connections_me();
ConnectionLease {
stats: Arc::clone(self),
kind: ConnectionLeaseKind::Middle,
armed: true,
}
}
pub fn increment_relay_adaptive_promotions_total(&self) { pub fn increment_relay_adaptive_promotions_total(&self) {
if self.telemetry_core_enabled() { if self.telemetry_core_enabled() {
self.relay_adaptive_promotions_total self.relay_adaptive_promotions_total