feat: enhance quota user lock management and testing

- Adjusted QUOTA_USER_LOCKS_MAX based on test and non-test configurations to improve flexibility.
- Implemented logic to retain existing locks when the maximum quota is reached, ensuring efficient memory usage.
- Added comprehensive tests for quota user lock functionality, including cache reuse, saturation behavior, and race conditions.
- Enhanced StatsIo struct to manage wake scheduling for read and write operations, preventing unnecessary self-wakes.
- Introduced separate replay checker domains for handshake and TLS to ensure isolation and prevent cross-pollution of keys.
- Added security tests for replay checker to validate domain separation and window clamping behavior.
This commit is contained in:
David Osipov 2026-03-18 23:55:08 +04:00
parent 20e205189c
commit c7cf37898b
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
18 changed files with 1896 additions and 49 deletions

1
.gitignore vendored
View File

@ -21,3 +21,4 @@ target
#.idea/
proxy-secret
coverage-html/

View File

@ -1949,6 +1949,138 @@ fn server_hello_new_session_ticket_count_is_safely_capped() {
);
}
// Verifies that a handshake stamped with a boot-time (near-zero) timestamp
// stays replay-blocked even after the replay cache's nominal retention window
// has elapsed.
// NOTE(review): the final assertion expects the checker to still flag the
// digest after the 40ms window — confirm this matches ReplayChecker's
// retention/eviction semantics for such digests.
#[test]
fn boot_time_handshake_replay_remains_blocked_after_cache_window_expires() {
let secret = b"gap_t01_boot_replay";
let secrets = vec![("user".to_string(), secret.to_vec())];
// Timestamp 1 (seconds) models a handshake captured right after boot.
let handshake = make_valid_tls_handshake(secret, 1);
let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
.expect("boot-time handshake must validate on first use");
// Deliberately short 40ms window so the post-expiry branch is reachable fast.
let checker = crate::stats::ReplayChecker::new(128, std::time::Duration::from_millis(40));
let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];
assert!(
!checker.check_and_add_tls_digest(digest_half),
"first use must not be treated as replay"
);
assert!(
checker.check_and_add_tls_digest(digest_half),
"immediate second use must be detected as replay"
);
// Sleep well past the 40ms window so any time-based eviction would have fired.
std::thread::sleep(std::time::Duration::from_millis(70));
let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
.expect("boot-time handshake must still cryptographically validate after cache expiry");
let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];
// Same input handshake must map to the same replay key on re-validation.
assert_eq!(digest_half, digest_half_after_expiry, "replay key must be stable for same handshake");
assert!(
checker.check_and_add_tls_digest(digest_half_after_expiry),
"after cache window expiry, the same boot-time handshake must still be treated as replay"
);
}
// Adversarial twin of the boot-time replay test above: models an attacker who
// captures a valid boot-time handshake and re-submits it after waiting out the
// replay cache window. The security expectation is that the capture remains
// unusable even then.
#[test]
fn adversarial_boot_time_handshake_should_not_be_replayable_after_cache_expiry() {
let secret = b"gap_t01_boot_replay_adversarial";
let secrets = vec![("user".to_string(), secret.to_vec())];
// Timestamp 1 (seconds) models a handshake captured right after boot.
let handshake = make_valid_tls_handshake(secret, 1);
let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
.expect("boot-time handshake must validate on first use");
// 40ms window keeps the wait-out phase of the attack fast to simulate.
let checker = crate::stats::ReplayChecker::new(128, std::time::Duration::from_millis(40));
let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];
assert!(
!checker.check_and_add_tls_digest(digest_half),
"first use must not be treated as replay"
);
assert!(
checker.check_and_add_tls_digest(digest_half),
"immediate reuse must be rejected as replay"
);
// Attacker waits past the cache window before replaying the capture.
std::thread::sleep(std::time::Duration::from_millis(70));
let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
.expect("boot-time handshake still validates cryptographically after cache expiry");
let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];
assert_eq!(
digest_half, digest_half_after_expiry,
"replay key must remain stable for the same captured handshake"
);
assert!(
checker.check_and_add_tls_digest(digest_half_after_expiry),
"security expectation: a boot-time handshake should remain replay-protected even after cache expiry"
);
}
// Stress variant: repeatedly re-validates the same boot-time handshake across
// 64 cycles while the replay window (25ms) is shorter than the per-cycle sleep
// (30ms). Cycle 0 establishes the fresh/replay pair; every later cycle asserts
// the digest is still reported as a replay despite short-window churn.
#[test]
fn stress_short_replay_window_boot_timestamp_replay_cycles_remain_fail_closed_in_window() {
let secret = b"gap_t01_boot_replay_stress";
let secrets = vec![("user".to_string(), secret.to_vec())];
let handshake = make_valid_tls_handshake(secret, 1);
// Window (25ms) deliberately shorter than the 30ms inter-cycle sleep below.
let checker = crate::stats::ReplayChecker::new(256, std::time::Duration::from_millis(25));
for cycle in 0..64 {
let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
.expect("boot-time handshake must validate");
let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];
if cycle == 0 {
// First cycle: fresh insert, then immediate replay detection.
assert!(
!checker.check_and_add_tls_digest(digest_half),
"cycle 0: first use must be fresh"
);
assert!(
checker.check_and_add_tls_digest(digest_half),
"cycle 0: second use must be replay"
);
} else {
assert!(
checker.check_and_add_tls_digest(digest_half),
"cycle {cycle}: digest must remain replay-protected across short-window churn"
);
}
std::thread::sleep(std::time::Duration::from_millis(30));
}
}
// Light fuzz over handshake timestamps 0..8 using a deterministic xorshift
// PRNG: with a replay-window argument of 2, timestamps below 2 must be
// accepted as boot-time-compatible and everything at or above 2 must be
// rejected. 2048 iterations keep the matrix cheap but well covered.
#[test]
fn light_fuzz_boot_time_timestamp_matrix_with_short_replay_window_obeys_boot_cap() {
let secret = b"gap_t01_boot_replay_fuzz";
let secrets = vec![("user".to_string(), secret.to_vec())];
// Fixed seed makes the fuzz deterministic and reproducible.
let mut s: u64 = 0xA1B2_C3D4_55AA_7733;
for _ in 0..2048 {
// xorshift step: cheap, deterministic pseudo-random stream.
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let ts = (s as u32) % 8;
let handshake = make_valid_tls_handshake(secret, ts);
let accepted = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
.is_some();
// The boot cap equals the replay-window argument (2) passed above.
if ts < 2 {
assert!(accepted, "timestamp {ts} must remain boot-time compatible under 2s cap");
} else {
assert!(
!accepted,
"timestamp {ts} must be rejected when outside replay-window boot cap"
);
}
}
}
#[test]
fn server_hello_application_data_contains_alpn_marker_when_selected() {
let secret = b"alpn_marker_test";

View File

@ -300,7 +300,7 @@ where
handle_bad_client(
reader,
writer,
&mtproto_handshake,
&handshake,
real_peer,
local_addr,
&config,
@ -713,7 +713,7 @@ impl RunningClientHandler {
handle_bad_client(
reader,
writer,
&mtproto_handshake,
&handshake,
peer,
local_addr,
&config,

View File

@ -5,8 +5,8 @@ use crate::crypto::sha256_hmac;
use crate::protocol::constants::ProtoTag;
use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess;
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
@ -303,6 +303,333 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() {
let _ = tg_accept_task.await;
}
// Integration test for the overlap of two termination causes: a 1-byte data
// quota and a runtime route-mode cutover. Whichever fires first, the relay
// must fail closed (DataQuotaExceeded or the generic cutover error) and must
// release both the per-user connection slot and the per-user IP reservation.
#[tokio::test]
async fn integration_route_cutover_and_quota_overlap_fails_closed_and_releases_state() {
// Fake Telegram upstream: accepts one connection, sends two bytes, lingers.
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let tg_addr = tg_listener.local_addr().unwrap();
let tg_accept_task = tokio::spawn(async move {
let (mut stream, _) = tg_listener.accept().await.unwrap();
stream.write_all(&[0x41, 0x42]).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
});
let user = "cutover-quota-overlap-user";
let peer_addr: SocketAddr = "198.51.100.240:50010".parse().unwrap();
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let mut cfg = ProxyConfig::default();
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
// 1-byte quota guarantees quota exhaustion as soon as any data flows.
cfg.access.user_data_quota.insert(user.to_string(), 1);
// Route DC 2 to the fake upstream above.
cfg.dc_overrides
.insert("2".to_string(), vec![tg_addr.to_string()]);
let config = Arc::new(cfg);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
// Relay starts in Direct mode; it is flipped to Middle mid-flight below.
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let (server_side, client_side) = duplex(64 * 1024);
let (server_reader, server_writer) = tokio::io::split(server_side);
let client_reader = make_crypto_reader(server_reader);
let client_writer = make_crypto_writer(server_writer);
// Pre-authenticated session descriptor with all-zero keys (test fixture).
let success = HandshakeSuccess {
user: user.to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Intermediate,
dec_key: [0u8; 32],
dec_iv: 0,
enc_key: [0u8; 32],
enc_iv: 0,
peer: peer_addr,
is_tls: false,
};
let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static(
client_reader,
client_writer,
success,
upstream_manager,
stats.clone(),
config,
buffer_pool,
rng,
None,
route_runtime.clone(),
"127.0.0.1:443".parse().unwrap(),
peer_addr,
ip_tracker.clone(),
));
// Precondition: wait (bounded) until the relay has either activated the
// user state or already terminated; without this the cutover below races
// a not-yet-started session.
let observed_progress = tokio::time::timeout(Duration::from_secs(2), async {
loop {
if stats.get_user_curr_connects(user) >= 1
|| ip_tracker.get_active_ip_count(user).await >= 1
|| relay_task.is_finished()
{
return true;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await
.unwrap_or(false);
assert!(
observed_progress,
"overlap race test precondition must observe activation or bounded early termination"
);
// Trigger the route cutover while the quota breach may also be in flight.
tokio::time::sleep(Duration::from_millis(5)).await;
let _ = route_runtime.set_mode(RelayRouteMode::Middle);
let relay_result = tokio::time::timeout(Duration::from_secs(3), relay_task)
.await
.expect("overlap race relay must terminate")
.expect("overlap race relay task must not panic");
// Either termination cause is acceptable, but both must be fail-closed.
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))
|| matches!(relay_result, Err(ProxyError::Proxy(ref msg)) if msg == crate::proxy::route_mode::ROUTE_SWITCH_ERROR_MSG),
"overlap race must fail closed via quota enforcement or generic cutover termination"
);
assert_eq!(
stats.get_user_curr_connects(user),
0,
"overlap race exit must release user current-connection slot"
);
assert_eq!(
ip_tracker.get_active_ip_count(user).await,
0,
"overlap race exit must release reserved user IP footprint"
);
drop(client_side);
tg_accept_task.abort();
let _ = tg_accept_task.await;
}
// Stress test for the Drop-only release path of user connection reservations:
// acquires 512 reservations for distinct peer IPs, drops each one on a
// separate OS thread (i.e. off the tokio runtime, without an explicit release
// call), and asserts the per-user connection and IP counters converge to zero.
#[tokio::test]
async fn stress_drop_without_release_converges_to_zero_user_and_ip_state() {
let user = "gap-t05-drop-stress-user";
let mut config = crate::config::ProxyConfig::default();
// Cap high enough (4096) that all 512 acquisitions succeed.
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 4096);
let stats = std::sync::Arc::new(crate::stats::Stats::new());
let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new());
let mut reservations = Vec::new();
for idx in 0..512u16 {
// Distinct peer address per reservation (198.51.x.y, unique ports).
let peer = std::net::SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::new(198, 51, (idx >> 8) as u8, (idx & 0xff) as u8)),
30_000 + idx,
);
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await
.expect("reservation acquisition must succeed in stress precondition");
reservations.push(reservation);
}
assert_eq!(stats.get_user_curr_connects(user), 512);
// Drop each reservation on a dedicated OS thread to exercise the
// Drop-impl release path outside the async runtime.
for reservation in reservations {
std::thread::spawn(move || drop(reservation))
.join()
.expect("drop thread must not panic");
}
// Release may be asynchronous; poll with a 2s overall deadline.
tokio::time::timeout(std::time::Duration::from_secs(2), async {
loop {
if stats.get_user_curr_connects(user) == 0
&& ip_tracker.get_active_ip_count(user).await == 0
{
break;
}
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
}
})
.await
.expect("drop-only path must eventually release all user/IP reservations");
}
// With an empty PROXY-protocol trust list, any client that sends a PROXY v1
// header — regardless of source address — must be rejected with
// InvalidProxyProtocol rather than having the header honored or ignored.
#[tokio::test]
async fn proxy_protocol_header_is_rejected_when_trust_list_is_empty() {
let mut cfg = crate::config::ProxyConfig::default();
cfg.general.beobachten = false;
// Empty trust list: no peer is allowed to speak PROXY protocol.
cfg.server.proxy_protocol_trusted_cidrs.clear();
let config = std::sync::Arc::new(cfg);
let stats = std::sync::Arc::new(crate::stats::Stats::new());
let upstream_manager = std::sync::Arc::new(crate::transport::UpstreamManager::new(
vec![crate::config::UpstreamConfig {
upstream_type: crate::config::UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(128, std::time::Duration::from_secs(60)));
let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new());
let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new());
let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct));
let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new());
let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new());
// In-memory duplex stands in for the client TCP connection.
let (server_side, mut client_side) = duplex(2048);
let peer: std::net::SocketAddr = "198.51.100.80:55000".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
true,
));
// Well-formed PROXY v1 header; it must be rejected on trust grounds alone.
let proxy_header = ProxyProtocolV1Builder::new()
.tcp4(
"203.0.113.9:32000".parse().unwrap(),
"192.0.2.8:443".parse().unwrap(),
)
.build();
client_side.write_all(&proxy_header).await.unwrap();
drop(client_side);
let result = tokio::time::timeout(std::time::Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
}
// With a trust list of 10.0.0.0/8, peers in 203.0.113.0/24 must have their
// PROXY v1 headers rejected. Runs 32 back-to-back sessions with varying peer
// addresses to confirm the rejection holds under a small burst, with fresh
// per-iteration state so sessions cannot influence each other.
#[tokio::test]
async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load() {
let mut cfg = crate::config::ProxyConfig::default();
cfg.general.beobachten = false;
// Only 10/8 is trusted; all test peers below fall outside it.
cfg.server.proxy_protocol_trusted_cidrs = vec!["10.0.0.0/8".parse().unwrap()];
let config = std::sync::Arc::new(cfg);
for idx in 0..32u16 {
// Fresh stats/trackers per iteration to isolate each burst session.
let stats = std::sync::Arc::new(crate::stats::Stats::new());
let upstream_manager = std::sync::Arc::new(crate::transport::UpstreamManager::new(
vec![crate::config::UpstreamConfig {
upstream_type: crate::config::UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(64, std::time::Duration::from_secs(60)));
let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new());
let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new());
let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct));
let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new());
let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new());
let (server_side, mut client_side) = duplex(1024);
// Peer varies per iteration within the untrusted 203.0.113/24 range.
let peer = std::net::SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (idx + 1) as u8)),
55_000 + idx,
);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config.clone(),
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
true,
));
let proxy_header = ProxyProtocolV1Builder::new()
.tcp4(
"203.0.113.10:32000".parse().unwrap(),
"192.0.2.8:443".parse().unwrap(),
)
.build();
client_side.write_all(&proxy_header).await.unwrap();
drop(client_side);
let result = tokio::time::timeout(std::time::Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
assert!(
matches!(result, Err(ProxyError::InvalidProxyProtocol)),
"burst idx {idx}: untrusted source must be rejected"
);
}
}
#[tokio::test]
async fn short_tls_probe_is_masked_through_client_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
@ -888,7 +1215,7 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() {
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(8192);
let (server_side, mut client_side) = duplex(131072);
let peer: SocketAddr = "198.51.100.80:55002".parse().unwrap();
let stats_for_assert = stats.clone();
let bad_before = stats_for_assert.get_connects_bad();
@ -947,11 +1274,12 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() {
let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN];
let tls_app_record = wrap_tls_application_data(&invalid_mtproto);
let expected_fallback = client_hello.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; invalid_mtproto.len()];
let mut got = vec![0u8; expected_fallback.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, invalid_mtproto);
assert_eq!(got, expected_fallback);
});
let mut cfg = ProxyConfig::default();
@ -1045,11 +1373,12 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() {
let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN];
let tls_app_record = wrap_tls_application_data(&invalid_mtproto);
let expected_fallback = client_hello.clone();
let mask_accept_task = tokio::spawn(async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut got = vec![0u8; invalid_mtproto.len()];
let mut got = vec![0u8; expected_fallback.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, invalid_mtproto);
assert_eq!(got, expected_fallback);
});
let mut cfg = ProxyConfig::default();

View File

@ -31,6 +31,22 @@ use std::os::unix::fs::OpenOptionsExt;
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
/// Maximum accepted length, in bytes, of a user-supplied scope hint.
const MAX_SCOPE_HINT_LEN: usize = 64;

/// Extracts a validated scope hint from a `scope_`-prefixed user name.
///
/// Returns `Some(hint)` only when the suffix after `scope_` is non-empty,
/// at most [`MAX_SCOPE_HINT_LEN`] bytes long, and made up exclusively of
/// ASCII alphanumerics or `-`. Any other input yields `None`.
fn validated_scope_hint(user: &str) -> Option<&str> {
    let hint = user.strip_prefix("scope_")?;
    let length_ok = !hint.is_empty() && hint.len() <= MAX_SCOPE_HINT_LEN;
    let charset_ok = hint
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'-');
    (length_ok && charset_ok).then_some(hint)
}
#[derive(Clone)]
struct SanitizedUnknownDcLogPath {
@ -185,8 +201,15 @@ where
"Connecting to Telegram DC"
);
let scope_hint = validated_scope_hint(user);
if user.starts_with("scope_") && scope_hint.is_none() {
warn!(
user = %user,
"Ignoring invalid scope hint and falling back to default upstream selection"
);
}
let tg_stream = upstream_manager
.connect(dc_addr, Some(success.dc_idx), user.strip_prefix("scope_").filter(|s| !s.is_empty()))
.connect(dc_addr, Some(success.dc_idx), scope_hint)
.await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
@ -290,10 +313,10 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
if config.general.unknown_dc_file_log_enabled
&& let Some(path) = &config.general.unknown_dc_log_path
&& should_log_unknown_dc(dc_idx)
&& let Ok(handle) = tokio::runtime::Handle::try_current()
{
if let Some(path) = sanitize_unknown_dc_log_path(path) {
if should_log_unknown_dc(dc_idx) {
handle.spawn_blocking(move || {
if unknown_dc_log_path_is_still_safe(&path)
&& let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path)
@ -301,6 +324,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let _ = writeln!(file, "dc_idx={dc_idx}");
}
});
}
} else {
warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path");
}

View File

@ -94,6 +94,26 @@ fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() {
);
}
// When the configured unknown-DC log path is rejected as unsafe (parent
// traversal), the DC index must NOT be marked as already-logged: a later run
// with a safe path should still be able to log that same DC exactly once.
#[test]
fn unsafe_unknown_dc_log_path_does_not_consume_dedup_slot() {
// Serialize against other tests that share the global unknown-DC cache.
let _guard = unknown_dc_test_lock()
.lock()
.expect("unknown dc test lock must be available");
clear_unknown_dc_log_cache_for_testing();
let dc_idx: i16 = 31_123;
let mut cfg = ProxyConfig::default();
cfg.general.unknown_dc_file_log_enabled = true;
// "../" makes the path fail sanitization while logging stays enabled.
cfg.general.unknown_dc_log_path = Some("../telemt-unknown-dc-unsafe.log".to_string());
let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work");
// should_log_unknown_dc returning true means the dedup slot is still free.
assert!(
should_log_unknown_dc(dc_idx),
"rejected unsafe log path must not consume unknown-dc dedup entry"
);
}
#[test]
fn stress_unknown_dc_log_concurrent_unique_churn_respects_cap() {
let _guard = unknown_dc_test_lock()
@ -158,6 +178,24 @@ fn light_fuzz_unknown_dc_log_mixed_duplicates_never_exceeds_cap() {
);
}
/// Positive cases: `scope_`-prefixed names whose suffix contains only ASCII
/// alphanumerics and dashes (within the length cap) must yield the suffix.
#[test]
fn scope_hint_accepts_ascii_alnum_and_dash_within_limit() {
    let accepted = [("scope_alpha-1", "alpha-1"), ("scope_AZ09", "AZ09")];
    for (input, expected) in accepted {
        assert_eq!(validated_scope_hint(input), Some(expected));
    }
}
/// Negative cases: missing prefix, empty suffix, disallowed characters, and
/// suffixes exceeding `MAX_SCOPE_HINT_LEN` must all be rejected.
#[test]
fn scope_hint_rejects_invalid_or_oversized_values() {
    let rejected = [
        "plain_user",
        "scope_",
        "scope_a/b",
        "scope_bad space",
        "scope_bad.dot",
    ];
    for input in rejected {
        assert_eq!(validated_scope_hint(input), None);
    }
    let oversized = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN + 1));
    assert_eq!(validated_scope_hint(&oversized), None);
}
#[test]
fn unknown_dc_log_path_sanitizer_rejects_parent_traversal_inputs() {
assert!(
@ -1207,3 +1245,80 @@ async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea
tg_accept_task.abort();
let _ = tg_accept_task.await;
}
// Address-family preference matrix for DC overrides when prefer=6:
// (a) mixed v4+v6 overrides → the v6 entry wins;
// (b) v4-only overrides → the v4 entry is used rather than failing;
// (c) no overrides → resolution falls back to the static v6 DC table.
#[test]
fn prefer_v6_override_matrix_prefers_matching_family_then_degrades_safely() {
let dc_idx: i16 = 2;
// Case (a): both families present in the override list.
let mut cfg_a = ProxyConfig::default();
cfg_a.network.prefer = 6;
cfg_a.network.ipv6 = Some(true);
cfg_a.dc_overrides.insert(
dc_idx.to_string(),
vec![
"203.0.113.90:443".to_string(),
"[2001:db8::90]:443".to_string(),
],
);
let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve");
assert!(a.is_ipv6(), "prefer_v6 should choose v6 override when present");
// Case (b): only a v4 override exists; preference must degrade, not fail.
let mut cfg_b = ProxyConfig::default();
cfg_b.network.prefer = 6;
cfg_b.network.ipv6 = Some(true);
cfg_b.dc_overrides
.insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]);
let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve");
assert!(b.is_ipv4(), "when no v6 override exists, v4 override must be used");
// Case (c): no overrides at all → static v6 datacenter table.
let mut cfg_c = ProxyConfig::default();
cfg_c.network.prefer = 6;
cfg_c.network.ipv6 = Some(true);
let c = get_dc_addr_static(dc_idx, &cfg_c).expect("table fallback must resolve");
assert_eq!(
c,
SocketAddr::new(TG_DATACENTERS_V6[(dc_idx as usize) - 1], TG_DATACENTER_PORT),
"without overrides, prefer_v6 path must resolve from static v6 datacenter table"
);
}
// Unparseable entries in a DC override list must be skipped, not abort
// resolution: as long as one valid address remains, it is returned.
#[test]
fn prefer_v6_override_matrix_ignores_invalid_entries_and_keeps_fail_closed_fallback() {
let dc_idx: i16 = 3;
let mut cfg = ProxyConfig::default();
cfg.network.prefer = 6;
cfg.network.ipv6 = Some(true);
// Two malformed entries precede the single valid v4 address.
cfg.dc_overrides.insert(
dc_idx.to_string(),
vec![
"not-an-addr".to_string(),
"also:bad".to_string(),
"203.0.113.55:443".to_string(),
],
);
let addr = get_dc_addr_static(dc_idx, &cfg).expect("at least one valid override must keep resolution alive");
assert_eq!(addr, "203.0.113.55:443".parse::<SocketAddr>().unwrap());
}
// Determinism check across DCs 1..=5: with mixed v4+v6 overrides and
// prefer=6, two consecutive lookups must return the same address and that
// address must be the v6 entry.
#[test]
fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() {
for idx in 1..=5i16 {
let mut cfg = ProxyConfig::default();
cfg.network.prefer = 6;
cfg.network.ipv6 = Some(true);
// One v4 and one v6 override per DC, both derived from the index.
cfg.dc_overrides.insert(
idx.to_string(),
vec![
format!("203.0.113.{}:443", 100 + idx),
format!("[2001:db8::{}]:443", 100 + idx),
],
);
let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve");
let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve");
assert_eq!(first, second, "override resolution must stay deterministic for dc {idx}");
assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred");
}
}

View File

@ -14,7 +14,7 @@ use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace};
use zeroize::Zeroize;
use zeroize::{Zeroize, Zeroizing};
use crate::crypto::{sha256, AesCtr, SecureRandom};
use rand::Rng;
@ -28,6 +28,10 @@ use crate::tls_front::{TlsFrontCache, emulator};
const ACCESS_SECRET_BYTES: usize = 16;
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
#[cfg(test)]
const WARNED_SECRET_MAX_ENTRIES: usize = 64;
#[cfg(not(test))]
const WARNED_SECRET_MAX_ENTRIES: usize = 1_024;
const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60;
#[cfg(test)]
@ -406,7 +410,13 @@ fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Opti
let key = (name.to_string(), reason.to_string());
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
let should_warn = match warned.lock() {
Ok(mut guard) => guard.insert(key),
Ok(mut guard) => {
if !guard.contains(&key) && guard.len() >= WARNED_SECRET_MAX_ENTRIES {
false
} else {
guard.insert(key)
}
}
Err(_) => true,
};
@ -575,6 +585,7 @@ where
}
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer };
@ -736,9 +747,13 @@ where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
let handshake_fingerprint = {
let digest = sha256(&handshake[..8]);
hex::encode(&digest[..4])
};
trace!(
peer = %peer,
handshake_head = %hex::encode(&handshake[..8]),
handshake_fingerprint = %handshake_fingerprint,
"MTProto handshake prefix"
);
@ -760,7 +775,7 @@ where
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
@ -796,7 +811,7 @@ where
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
@ -885,7 +900,7 @@ pub fn generate_tg_nonce(
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
if fast_mode {
let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN);
let mut key_iv = Zeroizing::new(Vec::with_capacity(KEY_LEN + IV_LEN));
key_iv.extend_from_slice(client_enc_key);
key_iv.extend_from_slice(&client_enc_iv.to_be_bytes());
key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce
@ -893,7 +908,7 @@ pub fn generate_tg_nonce(
}
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::<Vec<u8>>());
let mut tg_enc_key = [0u8; 32];
tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
@ -914,7 +929,7 @@ pub fn generate_tg_nonce(
/// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state
pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, AesCtr, AesCtr) {
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::<Vec<u8>>());
let mut enc_key = [0u8; 32];
enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
@ -935,6 +950,8 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
let decryptor = AesCtr::new(&dec_key, dec_iv);
enc_key.zeroize();
dec_key.zeroize();
(result, encryptor, decryptor)
}
@ -950,6 +967,10 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
#[path = "handshake_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"]
mod gap_short_tls_probe_throttle_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.

View File

@ -0,0 +1,50 @@
use super::*;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::time::Duration;
/// Builds a minimal test config with exactly one user ("user") bound to the
/// given hex-encoded secret; time-skew validation is disabled so handshake
/// checks behave deterministically in tests.
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
    let mut config = ProxyConfig::default();
    config.access.users.clear();
    config
        .access
        .users
        .insert("user".to_string(), secret_hex.to_string());
    config.access.ignore_time_skew = true;
    config
}
// A burst of truncated TLS probes (a bare 3-byte record header) from one IP
// must each be classified as BadClient AND must accumulate on that IP's
// auth-probe fail streak, so the pre-auth backoff threshold is reached.
#[tokio::test]
async fn gap_t01_short_tls_probe_burst_is_throttled() {
// Serialize against other tests sharing the global auth-probe state;
// recover the lock even if a prior test panicked while holding it.
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
// 0x16 0x03 0x01 = TLS handshake record header with no payload: too short.
let too_short = vec![0x16, 0x03, 0x01];
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&too_short,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
// The streak must have reached the backoff-start threshold for this IP.
assert!(
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"short TLS probe bursts must increase auth-probe fail streak"
);
}

View File

@ -1345,6 +1345,29 @@ fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() {
);
}
// The invalid-secret warning dedup set must be bounded: inserting more unique
// (user, reason) keys than WARNED_SECRET_MAX_ENTRIES must leave the set at
// exactly the cap, preventing unbounded memory growth from attacker-chosen
// user names.
#[test]
fn invalid_secret_warning_cache_is_bounded() {
// Serialize against other tests touching the global warned-secrets set.
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
// 32 entries beyond the cap to exercise the overflow rejection path.
for idx in 0..(WARNED_SECRET_MAX_ENTRIES + 32) {
let user = format!("warned_user_{idx}");
warn_invalid_secret_once(&user, "invalid_length", ACCESS_SECRET_BYTES, Some(idx));
}
let warned = INVALID_SECRET_WARNED
.get()
.expect("warned set must be initialized");
let guard = warned.lock().expect("warned set lock must be available");
assert_eq!(
guard.len(),
WARNED_SECRET_MAX_ENTRIES,
"invalid-secret warning cache must remain bounded"
);
}
#[tokio::test]
async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() {
let _guard = auth_probe_test_lock()
@ -1921,6 +1944,165 @@ fn auth_probe_eviction_offset_changes_with_time_component() {
);
}
// Over-cap eviction behavior for the auth-probe tracker: with the map filled
// past AUTH_PROBE_TRACK_MAX_ENTRIES, recording a failure for a new IP must
// (1) still track the newcomer, (2) not evict a high fail-streak "sentinel"
// entry, and (3) set the saturation-throttle marker.
#[test]
fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() {
// Serialize against other tests sharing the global auth-probe state.
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let state = DashMap::new();
let now = Instant::now();
// Start 64 entries above the tracking cap.
let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 64;
// Sentinel: old last_seen but a very high fail streak — must survive.
let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250));
state.insert(
sentinel,
AuthProbeState {
fail_streak: 25,
blocked_until: now,
last_seen: now - Duration::from_secs(30),
},
);
// Fill the rest with low-threat (fail_streak=1) entries.
for idx in 0..(initial - 1) {
let ip = IpAddr::V4(Ipv4Addr::new(
10,
20,
((idx >> 8) & 0xff) as u8,
(idx & 0xff) as u8,
));
state.insert(
ip,
AuthProbeState {
fail_streak: 1,
blocked_until: now,
last_seen: now + Duration::from_millis((idx % 1024) as u64),
},
);
}
let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40));
auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1));
assert!(state.get(&newcomer).is_some(), "newcomer must still be tracked under over-cap pressure");
assert!(
state.get(&sentinel).is_some(),
"high fail-streak sentinel must survive round-limited eviction"
);
assert!(
auth_probe_saturation_is_throttled_at_for_testing(now + Duration::from_millis(1)),
"round-limited over-cap path must activate saturation throttle marker"
);
}
// Stress variant of the over-cap eviction test: 512 consecutive newcomer
// failures against an already over-cap map must never evict the high
// fail-streak sentinel, and each newcomer must be tracked at the step it
// is recorded.
#[test]
fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() {
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let state = DashMap::new();
let base_now = Instant::now();
// Sentinel: stale last_seen but very high streak — must never be evicted.
let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200));
state.insert(
sentinel,
AuthProbeState {
fail_streak: 30,
blocked_until: base_now,
last_seen: base_now - Duration::from_secs(60),
},
);
// Fill 80 entries past the cap with low-threat entries.
for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 80) {
let ip = IpAddr::V4(Ipv4Addr::new(
172,
22,
((idx >> 8) & 0xff) as u8,
(idx & 0xff) as u8,
));
state.insert(
ip,
AuthProbeState {
fail_streak: 1,
blocked_until: base_now,
last_seen: base_now + Duration::from_millis((idx % 2048) as u64),
},
);
}
// Churn: a fresh newcomer IP per step, each advancing the clock by 1ms.
for step in 0..512usize {
let newcomer = IpAddr::V4(Ipv4Addr::new(
203,
2,
((step >> 8) & 0xff) as u8,
(step & 0xff) as u8,
));
auth_probe_record_failure_with_state(&state, newcomer, base_now + Duration::from_millis(step as u64 + 1));
assert!(
state.get(&sentinel).is_some(),
"step {step}: high-threat sentinel must not be starved by newcomer churn"
);
assert!(state.get(&newcomer).is_some(), "step {step}: newcomer must be tracked");
}
}
// Light fuzz: 128 rounds with a fresh full map of randomized low-threat
// entries; in every round the newcomer must be tracked and the high
// fail-streak sentinel must survive eviction.
#[test]
fn light_fuzz_auth_probe_overcap_eviction_prefers_less_threatening_entries() {
    // Serialize with other auth-probe tests sharing global state.
    let _guard = auth_probe_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_auth_probe_state_for_testing();
    let now = Instant::now();
    // Deterministic xorshift-style seed; the same step order must reproduce
    // the same IPs and last_seen jitter across runs.
    let mut s: u64 = 0xBADC_0FFE_EE11_2233;
    for round in 0..128usize {
        let state = DashMap::new();
        let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 180));
        state.insert(
            sentinel,
            AuthProbeState {
                fail_streak: 18,
                blocked_until: now,
                last_seen: now - Duration::from_secs(5),
            },
        );
        // Fill exactly to the cap with randomized low-threat entries.
        for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
            s ^= s << 7;
            s ^= s >> 9;
            s ^= s << 8;
            let ip = IpAddr::V4(Ipv4Addr::new(
                10,
                ((idx >> 8) & 0xff) as u8,
                (idx & 0xff) as u8,
                (s & 0xff) as u8,
            ));
            state.insert(
                ip,
                AuthProbeState {
                    fail_streak: 1,
                    blocked_until: now,
                    last_seen: now + Duration::from_millis((s & 1023) as u64),
                },
            );
        }
        let newcomer = IpAddr::V4(Ipv4Addr::new(203, 10, ((round >> 8) & 0xff) as u8, (round & 0xff) as u8));
        auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(round as u64 + 1));
        assert!(state.get(&newcomer).is_some(), "round {round}: newcomer should be tracked");
        assert!(
            state.get(&sentinel).is_some(),
            "round {round}: high fail-streak sentinel should survive mixed low-threat pool"
        );
    }
}
#[test]
fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() {
let mut rng = StdRng::seed_from_u64(0xA11CE5EED);

View File

@ -181,6 +181,7 @@ where
};
if let Some(header) = proxy_header {
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
wait_mask_outcome_budget(outcome_started).await;
return;
}
}
@ -246,6 +247,7 @@ where
let (mask_read, mut mask_write) = stream.into_split();
if let Some(header) = proxy_header {
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
wait_mask_outcome_budget(outcome_started).await;
return;
}
}

View File

@ -317,6 +317,254 @@ async fn backend_reachable_fast_response_waits_mask_outcome_budget() {
accept_task.await.unwrap();
}
// Anti-probing timing test: when the masking backend drops the connection so
// the PROXY-protocol header write fails, handle_bad_client must still burn
// the coarse outcome budget instead of returning immediately (an instant
// return would be a distinguishable timing signature for probers).
#[tokio::test]
async fn proxy_header_write_error_on_tcp_path_still_honors_coarse_outcome_budget() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let probe = b"GET /proxy-hdr-err HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
    // Backend accepts and immediately drops: forces the header write to fail.
    let accept_task = tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        drop(stream);
    });
    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_host = Some("127.0.0.1".to_string());
    config.censorship.mask_port = backend_addr.port();
    config.censorship.mask_unix_sock = None;
    config.censorship.mask_proxy_protocol = 1;
    let peer: SocketAddr = "203.0.113.88:42430".parse().unwrap();
    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
    // Drop the writer half so the client reader yields EOF right away.
    let (client_reader_side, client_reader) = duplex(256);
    drop(client_reader_side);
    let (_client_visible_reader, client_visible_writer) = duplex(512);
    let beobachten = BeobachtenStore::new();
    let started = Instant::now();
    let task = tokio::spawn(async move {
        handle_bad_client(
            client_reader,
            client_visible_writer,
            &probe,
            peer,
            local_addr,
            &config,
            &beobachten,
        )
        .await;
    });
    // The handler must NOT complete within 35ms — the budget keeps it alive.
    timeout(Duration::from_millis(35), task)
        .await
        .expect_err("proxy-header write error path should remain inside coarse masking budget window");
    assert!(
        started.elapsed() >= Duration::from_millis(35),
        "proxy-header write error path should avoid immediate-return timing signature"
    );
    accept_task.await.unwrap();
}
// Unix-socket twin of the TCP test above: a failed PROXY-header write on the
// unix masking path must still consume the coarse outcome budget rather than
// returning immediately.
#[cfg(unix)]
#[tokio::test]
async fn proxy_header_write_error_on_unix_path_still_honors_coarse_outcome_budget() {
    // Unique per-process/random socket path so parallel test runs don't collide.
    let sock_path = format!(
        "/tmp/telemt-mask-unix-hdr-err-{}-{}.sock",
        std::process::id(),
        rand::random::<u64>()
    );
    let _ = std::fs::remove_file(&sock_path);
    let listener = UnixListener::bind(&sock_path).unwrap();
    let probe = b"GET /unix-hdr-err HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
    // Backend accepts and drops to force the header write failure.
    let accept_task = tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        drop(stream);
    });
    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_unix_sock = Some(sock_path.clone());
    config.censorship.mask_proxy_protocol = 1;
    let peer: SocketAddr = "203.0.113.89:42431".parse().unwrap();
    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
    // Immediate EOF from the client side.
    let (client_reader_side, client_reader) = duplex(256);
    drop(client_reader_side);
    let (_client_visible_reader, client_visible_writer) = duplex(512);
    let beobachten = BeobachtenStore::new();
    let started = Instant::now();
    let task = tokio::spawn(async move {
        handle_bad_client(
            client_reader,
            client_visible_writer,
            &probe,
            peer,
            local_addr,
            &config,
            &beobachten,
        )
        .await;
    });
    // Handler must outlive the 35ms probe window.
    timeout(Duration::from_millis(35), task)
        .await
        .expect_err("unix proxy-header write error path should remain inside coarse masking budget window");
    assert!(
        started.elapsed() >= Duration::from_millis(35),
        "unix proxy-header write error path should avoid immediate-return timing signature"
    );
    accept_task.await.unwrap();
    let _ = std::fs::remove_file(sock_path);
}
// Verifies wire ordering on the unix masking path with PROXY protocol v1:
// the backend must first receive a "PROXY ...\r\n" text header, then the raw
// probe bytes, and the backend's reply must be relayed back to the client.
#[cfg(unix)]
#[tokio::test]
async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() {
    let sock_path = format!(
        "/tmp/telemt-mask-unix-v1-{}-{}.sock",
        std::process::id(),
        rand::random::<u64>()
    );
    let _ = std::fs::remove_file(&sock_path);
    let listener = UnixListener::bind(&sock_path).unwrap();
    let probe = b"GET /unix-v1 HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
    let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
    // Fake backend: read the v1 header line, assert its shape, then expect
    // the probe verbatim and answer with a canned reply.
    let accept_task = tokio::spawn({
        let probe = probe.clone();
        let backend_reply = backend_reply.clone();
        async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            let mut header_line = Vec::new();
            reader.read_until(b'\n', &mut header_line).await.unwrap();
            let header_text = String::from_utf8(header_line).unwrap();
            assert!(header_text.starts_with("PROXY "), "must start with PROXY prefix");
            assert!(header_text.ends_with("\r\n"), "v1 header must end with CRLF");
            let mut received_probe = vec![0u8; probe.len()];
            reader.read_exact(&mut received_probe).await.unwrap();
            assert_eq!(received_probe, probe);
            let mut stream = reader.into_inner();
            stream.write_all(&backend_reply).await.unwrap();
        }
    });
    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_unix_sock = Some(sock_path.clone());
    config.censorship.mask_proxy_protocol = 1;
    let peer: SocketAddr = "203.0.113.51:51010".parse().unwrap();
    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
    let (client_reader, _client_writer) = duplex(256);
    let (mut client_visible_reader, client_visible_writer) = duplex(2048);
    let beobachten = BeobachtenStore::new();
    handle_bad_client(
        client_reader,
        client_visible_writer,
        &probe,
        peer,
        local_addr,
        &config,
        &beobachten,
    )
    .await;
    // The backend's reply must arrive unchanged on the client-visible side.
    let mut observed = vec![0u8; backend_reply.len()];
    client_visible_reader.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, backend_reply);
    accept_task.await.unwrap();
    let _ = std::fs::remove_file(sock_path);
}
// Binary PROXY protocol v2 variant: the backend must first see the 12-byte
// v2 signature, the 4-byte fixed header (whose last two bytes carry the
// big-endian address-block length), the address block, and only then the
// probe; the reply must flow back to the client.
#[cfg(unix)]
#[tokio::test]
async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() {
    let sock_path = format!(
        "/tmp/telemt-mask-unix-v2-{}-{}.sock",
        std::process::id(),
        rand::random::<u64>()
    );
    let _ = std::fs::remove_file(&sock_path);
    let listener = UnixListener::bind(&sock_path).unwrap();
    let probe = b"GET /unix-v2 HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
    let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec();
    // Fake backend parsing the v2 framing byte-by-byte.
    let accept_task = tokio::spawn({
        let probe = probe.clone();
        let backend_reply = backend_reply.clone();
        async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut sig = [0u8; 12];
            stream.read_exact(&mut sig).await.unwrap();
            assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n", "v2 signature must match spec");
            let mut fixed = [0u8; 4];
            stream.read_exact(&mut fixed).await.unwrap();
            // Bytes 2..4 of the fixed header: address block length (BE).
            let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize;
            let mut addr_block = vec![0u8; addr_len];
            stream.read_exact(&mut addr_block).await.unwrap();
            let mut received_probe = vec![0u8; probe.len()];
            stream.read_exact(&mut received_probe).await.unwrap();
            assert_eq!(received_probe, probe);
            stream.write_all(&backend_reply).await.unwrap();
        }
    });
    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_unix_sock = Some(sock_path.clone());
    config.censorship.mask_proxy_protocol = 2;
    let peer: SocketAddr = "203.0.113.52:51011".parse().unwrap();
    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
    let (client_reader, _client_writer) = duplex(256);
    let (mut client_visible_reader, client_visible_writer) = duplex(2048);
    let beobachten = BeobachtenStore::new();
    handle_bad_client(
        client_reader,
        client_visible_writer,
        &probe,
        peer,
        local_addr,
        &config,
        &beobachten,
    )
    .await;
    let mut observed = vec![0u8; backend_reply.len()];
    client_visible_reader.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, backend_reply);
    accept_task.await.unwrap();
    let _ = std::fs::remove_file(sock_path);
}
#[tokio::test]
async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() {
let mut config = ProxyConfig::default();

View File

@ -44,6 +44,10 @@ const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
// Hard cap on cached per-user quota locks. Deliberately tiny under test so
// the saturation/eviction paths are reachable with small fixtures; sized for
// realistic user churn in production builds.
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
@ -336,6 +340,14 @@ fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return Arc::new(AsyncMutex::new(()));
}
let created = Arc::new(AsyncMutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
@ -405,7 +417,7 @@ where
);
let (conn_id, me_rx) = me_pool.registry().register().await;
let trace_id = conn_id;
let trace_id = session_id;
let bytes_me2c = Arc::new(AtomicU64::new(0));
let mut forensics = RelayForensicsState {
trace_id,

View File

@ -15,7 +15,9 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::thread;
use tokio::sync::Barrier;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, timeout};
@ -233,6 +235,219 @@ fn desync_dedup_cache_is_bounded() {
);
}
// Two lookups for the same user name must yield the identical cached Arc,
// otherwise per-user quota serialization would silently stop working.
#[test]
fn quota_user_lock_cache_reuses_entry_for_same_user() {
    let first = quota_user_lock("quota-user-a");
    let second = quota_user_lock("quota-user-a");
    assert!(
        Arc::ptr_eq(&first, &second),
        "same user must reuse same quota lock"
    );
}
// The quota-lock cache must never exceed QUOTA_USER_LOCKS_MAX entries even
// when every request uses a brand-new user name (worst-case churn).
#[test]
fn quota_user_lock_cache_is_bounded_under_unique_churn() {
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Request well past the cap; dropping each lock immediately leaves all
    // entries evictable.
    for idx in 0..(QUOTA_USER_LOCKS_MAX + 128) {
        drop(quota_user_lock(&format!("quota-user-{idx}")));
    }
    assert!(
        map.len() <= QUOTA_USER_LOCKS_MAX,
        "quota lock cache must stay within configured bound"
    );
}
// When the cache is full AND every cached lock is still retained (strong
// count > 1), a new user must receive an ephemeral, uncached lock so the
// cache never exceeds its hard limit.
#[test]
fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Hold every Arc so retain() cannot evict anything.
    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        let user = format!("quota-held-user-{idx}");
        retained.push(quota_user_lock(&user));
    }
    assert_eq!(
        map.len(),
        QUOTA_USER_LOCKS_MAX,
        "precondition: cache should be full before overflow acquisition"
    );
    let overflow_a = quota_user_lock("quota-overflow-user");
    let overflow_b = quota_user_lock("quota-overflow-user");
    assert_eq!(
        map.len(),
        QUOTA_USER_LOCKS_MAX,
        "overflow acquisition must not grow cache past hard limit"
    );
    assert!(
        map.get("quota-overflow-user").is_none(),
        "overflow path should not cache new user lock when map is saturated and all entries are retained"
    );
    // Ephemeral locks are fresh per call, so the two Arcs must be distinct.
    assert!(
        !Arc::ptr_eq(&overflow_a, &overflow_b),
        "overflow user lock should be ephemeral under saturation to preserve bounded cache size"
    );
    drop(retained);
}
// Even when the lock cache is saturated (forcing ephemeral per-call locks),
// two racers for the same user with a 1-byte quota must still resolve to at
// most one forwarded byte: quota enforcement must not rely on the cached
// lock alone.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_quota_race_under_lock_cache_saturation_still_allows_only_one_winner() {
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Saturate the cache with held locks so the racing user gets ephemeral ones.
    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        let user = format!("quota-saturated-user-{idx}");
        retained.push(quota_user_lock(&user));
    }
    assert_eq!(
        map.len(),
        QUOTA_USER_LOCKS_MAX,
        "precondition: cache must be saturated for overflow-user race test"
    );
    let stats = Stats::new();
    let bytes_me2c = AtomicU64::new(0);
    let user = "gap-t04-saturated-lock-race-user";
    // Barrier aligns both attempts so they contend for the quota together.
    let barrier = Arc::new(Barrier::new(2));
    let one = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x55, 9101, barrier.clone());
    let two = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x66, 9102, barrier);
    let (r1, r2) = tokio::join!(one, two);
    assert!(
        matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }))
            && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })),
        "both racers must resolve cleanly without unexpected errors"
    );
    assert!(
        matches!(r1, Err(ProxyError::DataQuotaExceeded { .. }))
            || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })),
        "at least one racer must be quota-rejected even when lock cache is saturated"
    );
    assert_eq!(
        stats.get_user_total_octets(user),
        1,
        "saturated lock cache must not permit double-success quota overshoot"
    );
    drop(retained);
}
// 128-round stress version of the saturated-cache race: every round uses a
// fresh user and must account exactly one forwarded byte.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_success() {
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Keep the cache saturated for the whole run by holding every lock.
    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        let user = format!("quota-saturated-stress-holder-{idx}");
        retained.push(quota_user_lock(&user));
    }
    let stats = Stats::new();
    let bytes_me2c = AtomicU64::new(0);
    for round in 0..128u64 {
        let user = format!("gap-t04-saturated-race-round-{round}");
        let barrier = Arc::new(Barrier::new(2));
        let one = run_quota_race_attempt(
            &stats,
            &bytes_me2c,
            &user,
            0x71,
            12_000 + round,
            barrier.clone(),
        );
        let two = run_quota_race_attempt(
            &stats,
            &bytes_me2c,
            &user,
            0x72,
            13_000 + round,
            barrier,
        );
        let (r1, r2) = tokio::join!(one, two);
        assert!(
            matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }))
                && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })),
            "round {round}: racers must resolve cleanly"
        );
        assert!(
            matches!(r1, Err(ProxyError::DataQuotaExceeded { .. }))
                || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })),
            "round {round}: at least one racer must be quota-rejected"
        );
        assert_eq!(
            stats.get_user_total_octets(&user),
            1,
            "round {round}: saturated cache must still enforce exactly one forwarded byte"
        );
    }
    drop(retained);
}
// A forensics record built with distinct trace and connection ids must keep
// them distinct: trace correlation is meant to be independent of connection
// identity.
#[test]
fn adversarial_forensics_trace_id_should_not_alias_conn_id() {
    let started_at = Instant::now();
    let expected_trace_id = 0x1122_3344_5566_7788;
    let expected_conn_id = 0x8877_6655_4433_2211;
    let forensics = RelayForensicsState {
        trace_id: expected_trace_id,
        conn_id: expected_conn_id,
        user: "trace-user".to_string(),
        peer: "198.51.100.17:443".parse().unwrap(),
        peer_hash: 0x8877_6655_4433_2211,
        started_at,
        bytes_c2me: 0,
        bytes_me2c: Arc::new(AtomicU64::new(0)),
        desync_all_full: false,
    };
    assert_ne!(
        forensics.trace_id, forensics.conn_id,
        "security expectation: trace correlation should be independent of connection identity"
    );
    assert_eq!(forensics.trace_id, expected_trace_id);
    assert_eq!(forensics.conn_id, expected_conn_id);
}
// Encodes an abridged-mode client ACK through the crypto writer, then
// decrypts the 4 bytes that hit the wire with an identical AES-CTR stream
// and checks the confirm value came out big-endian.
#[tokio::test]
async fn abridged_ack_uses_big_endian_confirm_bytes_after_decryption() {
    let (mut writer_side, reader_side) = duplex(8);
    // Zero key/IV keep the cipher stream reproducible for the check below.
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(reader_side, AesCtr::new(&key, iv), 8 * 1024);
    write_client_ack(&mut writer, ProtoTag::Abridged, 0x11_22_33_44)
        .await
        .expect("ack write must succeed");
    let mut observed = [0u8; 4];
    writer_side
        .read_exact(&mut observed)
        .await
        .expect("ack bytes must be readable");
    // Fresh cipher with the same key/IV reproduces the keystream, so this
    // decrypts exactly what the writer encrypted.
    let mut decryptor = AesCtr::new(&key, iv);
    let decrypted = decryptor.decrypt(&observed);
    assert_eq!(
        decrypted,
        0x11_22_33_44u32.to_be_bytes(),
        "abridged ACK should encode confirm bytes in big-endian order"
    );
}
#[test]
fn desync_dedup_full_cache_churn_stays_suppressed() {
let _guard = desync_dedup_test_lock()
@ -1707,6 +1922,150 @@ async fn middle_relay_cutover_midflight_releases_route_gauge() {
drop(client_side);
}
// Test helper: drives one process_me_writer_response call carrying a single
// payload byte for `user` under a 1-byte quota (Some(1)). The barrier makes
// two concurrent attempts enter the quota check together so races are
// deterministic rather than scheduling-dependent.
//
// Returns the outcome so callers can assert Ok vs DataQuotaExceeded.
async fn run_quota_race_attempt(
    stats: &Stats,
    bytes_me2c: &AtomicU64,
    user: &str,
    payload: u8,
    conn_id: u64,
    barrier: Arc<Barrier>,
) -> Result<MeWriterResponseOutcome> {
    let (writer_side, _reader_side) = duplex(1024);
    let mut writer = make_crypto_writer(writer_side);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    // Rendezvous point: both racers proceed only once both have arrived.
    barrier.wait().await;
    process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from(vec![payload]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &rng,
        &mut frame_buf,
        stats,
        user,
        Some(1),
        bytes_me2c,
        conn_id,
        false,
        false,
    )
    .await
}
// Feeds the abridged reader an extended-length header encoding the maximum
// value (0x7f ff ff ff); the read must fail closed — no panic, no partial
// frame counted — since it exceeds the 4096-byte limit.
#[tokio::test]
async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read() {
    // Shares the desync-dedup global, so serialize with its other tests.
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");
    let (reader, mut writer) = duplex(256);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;
    // 0x7f marker selects the 3-byte extended length; 0xffffff maxes it out.
    let plaintext = vec![0x7f, 0xff, 0xff, 0xff];
    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();
    let result = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Abridged,
        4096,
        TokioDuration::from_secs(1),
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await;
    assert!(result.is_err(), "oversized abridged length must fail closed");
    assert_eq!(frame_counter, 0, "oversized frame must not be counted as accepted");
}
// Baseline (non-saturated) same-user quota race: with a 1-byte quota and a
// barrier-aligned start, exactly one of the two attempts may forward its
// byte; the other must fail closed with DataQuotaExceeded.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn deterministic_quota_race_exactly_one_succeeds_and_one_is_rejected() {
    let stats = Stats::new();
    let bytes_me2c = AtomicU64::new(0);
    let user = "gap-t04-race-user";
    let barrier = Arc::new(Barrier::new(2));
    let f1 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x11, 5001, barrier.clone());
    let f2 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x22, 5002, barrier);
    let (r1, r2) = tokio::join!(f1, f2);
    assert!(
        matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })),
        "first racer must either finish or fail closed on quota"
    );
    assert!(
        matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })),
        "second racer must either finish or fail closed on quota"
    );
    assert!(
        matches!(r1, Err(ProxyError::DataQuotaExceeded { .. }))
            || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })),
        "at least one racer must be quota-rejected"
    );
    assert_eq!(
        stats.get_user_total_octets(user),
        1,
        "same-user race must forward/account exactly one payload byte"
    );
}
// 128-round burst version of the baseline quota race: a fresh user per round
// must always end the round with exactly one accounted byte.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_quota_race_bursts_never_allow_double_success_per_round() {
    let stats = Stats::new();
    let bytes_me2c = AtomicU64::new(0);
    for round in 0..128u64 {
        let user = format!("gap-t04-race-burst-{round}");
        let barrier = Arc::new(Barrier::new(2));
        let one = run_quota_race_attempt(
            &stats,
            &bytes_me2c,
            &user,
            0x33,
            6000 + round,
            barrier.clone(),
        );
        let two = run_quota_race_attempt(
            &stats,
            &bytes_me2c,
            &user,
            0x44,
            7000 + round,
            barrier,
        );
        let (r1, r2) = tokio::join!(one, two);
        assert!(
            matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }))
                && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })),
            "round {round}: racers must resolve cleanly without unexpected errors"
        );
        assert!(
            matches!(r1, Err(ProxyError::DataQuotaExceeded { .. }))
                || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })),
            "round {round}: at least one racer must be quota-rejected"
        );
        assert_eq!(
            stats.get_user_total_octets(&user),
            1,
            "round {round}: same-user total octets must remain exactly 1 (single forwarded winner)"
        );
    }
}
#[tokio::test]
async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() {
let session_count = 6usize;

View File

@ -208,6 +208,8 @@ struct StatsIo<S> {
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_read_wake_scheduled: bool,
quota_write_wake_scheduled: bool,
epoch: Instant,
}
@ -230,6 +232,8 @@ impl<S> StatsIo<S> {
user,
quota_limit,
quota_exceeded,
quota_read_wake_scheduled: false,
quota_write_wake_scheduled: false,
epoch,
}
}
@ -293,9 +297,19 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Ok(guard) => {
this.quota_read_wake_scheduled = false;
Some(guard)
}
Err(_) => {
cx.waker().wake_by_ref();
if !this.quota_read_wake_scheduled {
this.quota_read_wake_scheduled = true;
let waker = cx.waker().clone();
tokio::task::spawn(async move {
tokio::task::yield_now().await;
waker.wake();
});
}
return Poll::Pending;
}
}
@ -356,9 +370,19 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Ok(guard) => {
this.quota_write_wake_scheduled = false;
Some(guard)
}
Err(_) => {
cx.waker().wake_by_ref();
if !this.quota_write_wake_scheduled {
this.quota_write_wake_scheduled = true;
let waker = cx.waker().clone();
tokio::task::spawn(async move {
tokio::task::yield_now().await;
waker.wake();
});
}
return Poll::Pending;
}
}

View File

@ -14,6 +14,176 @@ use tokio::io::{AsyncRead, ReadBuf};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
/// Test waker that counts wake-ups; both by-value and by-ref wakes bump the
/// same relaxed counter so tests can assert on scheduling behavior.
#[derive(Default)]
struct WakeCounter {
    wakes: AtomicUsize,
}
impl WakeCounter {
    // Shared increment used by both Wake entry points.
    fn record_wake(&self) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }
}
impl std::task::Wake for WakeCounter {
    fn wake(self: Arc<Self>) {
        self.record_wake();
    }
    fn wake_by_ref(self: &Arc<Self>) {
        self.record_wake();
    }
}
// While the per-user quota lock is held elsewhere, a poll_write on StatsIo
// must return Pending WITHOUT synchronously waking its own waker — an
// immediate self-wake would spin the executor.
#[tokio::test]
async fn quota_lock_contention_does_not_self_wake_pending_writer() {
    let stats = Arc::new(Stats::new());
    let user = "quota-lock-contention-user";
    // Hold the user's quota lock for the whole test to force contention.
    let lock = super::quota_user_lock(user);
    let _held_lock = lock
        .try_lock()
        .expect("test must hold the per-user quota lock before polling writer");
    let counters = Arc::new(super::SharedCounters::new());
    let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let mut io = super::StatsIo::new(
        tokio::io::sink(),
        counters,
        Arc::clone(&stats),
        user.to_string(),
        Some(1024),
        quota_exceeded,
        tokio::time::Instant::now(),
    );
    // Counting waker observes how often the poll path wakes us.
    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let mut cx = Context::from_waker(&waker);
    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
    assert!(poll.is_pending(), "writer must remain pending while lock is contended");
    assert_eq!(
        wake_counter.wakes.load(Ordering::Relaxed),
        0,
        "contended quota lock must not self-wake immediately and spin the executor"
    );
}
// Liveness counterpart: a contended writer must schedule exactly one
// DEFERRED wake (fires after a yield, not synchronously), must not pile up
// additional wakes on repeated polls, and must make progress once the lock
// is released.
#[tokio::test]
async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() {
    let stats = Arc::new(Stats::new());
    let user = "quota-lock-writer-liveness-user";
    // Hold the quota lock (released near the end) to force contention.
    let lock = super::quota_user_lock(user);
    let held_lock = lock
        .try_lock()
        .expect("test must hold the per-user quota lock before polling writer");
    let counters = Arc::new(super::SharedCounters::new());
    let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let mut io = super::StatsIo::new(
        tokio::io::sink(),
        counters,
        Arc::clone(&stats),
        user.to_string(),
        Some(1024),
        quota_exceeded,
        tokio::time::Instant::now(),
    );
    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let mut cx = Context::from_waker(&waker);
    let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
    assert!(first.is_pending(), "writer must remain pending while lock is contended");
    // The wake must be deferred: nothing has fired at poll-return time.
    assert_eq!(
        wake_counter.wakes.load(Ordering::Relaxed),
        0,
        "deferred wake must not fire synchronously"
    );
    // But it must fire soon after yielding control to the executor.
    timeout(Duration::from_millis(50), async {
        loop {
            if wake_counter.wakes.load(Ordering::Relaxed) >= 1 {
                break;
            }
            tokio::task::yield_now().await;
        }
    })
    .await
    .expect("contended writer must schedule a deferred wake in bounded time");
    let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed);
    assert!(
        wakes_after_first_yield >= 1,
        "contended writer must schedule at least one deferred wake for liveness"
    );
    // Re-polling while still contended must NOT schedule another wake
    // (no wake storms).
    let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]);
    assert!(second.is_pending(), "writer remains pending while lock is still held");
    for _ in 0..8 {
        tokio::task::yield_now().await;
    }
    assert_eq!(
        wake_counter.wakes.load(Ordering::Relaxed),
        wakes_after_first_yield,
        "writer contention should not schedule unbounded wake storms before lock acquisition"
    );
    // Releasing the lock unblocks the very next poll.
    drop(held_lock);
    let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]);
    assert!(released.is_ready(), "writer must make progress once quota lock is released");
}
// Read-side mirror of the writer liveness test: a contended poll_read must
// stay Pending, schedule a deferred (not synchronous) wake, and become Ready
// once the quota lock is released.
#[tokio::test]
async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() {
    let stats = Arc::new(Stats::new());
    let user = "quota-lock-read-liveness-user";
    let lock = super::quota_user_lock(user);
    let held_lock = lock
        .try_lock()
        .expect("test must hold the per-user quota lock before polling reader");
    let counters = Arc::new(super::SharedCounters::new());
    let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let mut io = super::StatsIo::new(
        tokio::io::empty(),
        counters,
        Arc::clone(&stats),
        user.to_string(),
        Some(1024),
        quota_exceeded,
        tokio::time::Instant::now(),
    );
    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let mut cx = Context::from_waker(&waker);
    let mut storage = [0u8; 1];
    let mut buf = ReadBuf::new(&mut storage);
    let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
    assert!(first.is_pending(), "reader must remain pending while lock is contended");
    // No synchronous self-wake allowed.
    assert_eq!(
        wake_counter.wakes.load(Ordering::Relaxed),
        0,
        "read contention wake must not fire synchronously"
    );
    // The deferred wake must land within a bounded number of yields.
    timeout(Duration::from_millis(50), async {
        loop {
            if wake_counter.wakes.load(Ordering::Relaxed) >= 1 {
                break;
            }
            tokio::task::yield_now().await;
        }
    })
    .await
    .expect("read contention must schedule a deferred wake in bounded time");
    drop(held_lock);
    let mut buf_after_release = ReadBuf::new(&mut storage);
    let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release);
    assert!(released.is_ready(), "reader must make progress once quota lock is released");
}
#[tokio::test]
async fn relay_bidirectional_enforces_live_user_quota() {
let stats = Arc::new(Stats::new());

View File

@ -338,3 +338,69 @@ fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() {
);
}
}
// Sequential session ids for one fixed generation must hit every 1ms bucket
// of the [1000, 2000) delay window at least once — an empty bucket would let
// cutover traffic cluster (herd) on the remaining delays.
#[test]
fn cutover_stagger_delay_distribution_has_no_empty_buckets_under_sequential_sessions() {
    let generation = 4242u64;
    let mut hits = [0usize; 1000];
    for session_id in 0..250_000u64 {
        let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize;
        // Delays are expected in [1000, 2000); offset into the bucket array.
        hits[delay_ms - 1000] += 1;
    }
    let empty_buckets = hits.iter().filter(|&&count| count == 0).count();
    assert_eq!(
        empty_buckets, 0,
        "all 1000 delay buckets must be exercised to avoid cutover herd clustering"
    );
}
// Fuzzes random (session_id, generation) pairs and bounds the bucket skew:
// the most-hit 1ms delay bucket may be at most 3x the least-hit one.
#[test]
fn light_fuzz_cutover_stagger_delay_distribution_stays_reasonably_uniform() {
    let mut buckets = [0usize; 1000];
    let mut seed: u64 = 0x1BAD_B002_CAFE_F00D;
    // Same xorshift-style scrambler used by the other fuzz tests here.
    let mut next = |s: &mut u64| {
        *s ^= *s << 7;
        *s ^= *s >> 9;
        *s ^= *s << 8;
        *s
    };
    for _ in 0..300_000usize {
        let session_id = next(&mut seed);
        let generation = next(&mut seed);
        let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize;
        buckets[delay_ms - 1000] += 1;
    }
    let min = *buckets.iter().min().unwrap_or(&0);
    let max = *buckets.iter().max().unwrap_or(&0);
    assert!(min > 0, "fuzzed distribution must not leave empty buckets");
    assert!(
        max <= min.saturating_mul(3),
        "bucket skew is too high for anti-herd staggering (max={max}, min={min})"
    );
}
// Sweeps a spread of generation values (including extremes) and checks the
// delay distribution never collapses: max bucket count stays within 4x the
// minimum for each generation.
#[test]
fn stress_cutover_stagger_delay_distribution_remains_stable_across_generations() {
    let generations = [0u64, 1, 7, 31, 255, 1024, u32::MAX as u64, u64::MAX - 1];
    for generation in generations {
        let mut buckets = [0usize; 1000];
        for session_id in 0..100_000u64 {
            // XOR scrambles the id so adjacent sessions don't trivially map
            // to adjacent buckets.
            let delay = cutover_stagger_delay(session_id ^ 0x9E37_79B9, generation);
            buckets[delay.as_millis() as usize - 1000] += 1;
        }
        let min = *buckets.iter().min().unwrap_or(&0);
        let max = *buckets.iter().max().unwrap_or(&0);
        assert!(
            max <= min.saturating_mul(4).max(1),
            "generation={generation}: distribution collapsed (max={max}, min={min})"
        );
    }
}

View File

@ -1508,9 +1508,11 @@ impl Stats {
// ============= Replay Checker =============
pub struct ReplayChecker {
shards: Vec<Mutex<ReplayShard>>,
handshake_shards: Vec<Mutex<ReplayShard>>,
tls_shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize,
window: Duration,
tls_window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
@ -1587,19 +1589,24 @@ impl ReplayShard {
impl ReplayChecker {
pub fn new(total_capacity: usize, window: Duration) -> Self {
const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120);
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
let mut handshake_shards = Vec::with_capacity(num_shards);
let mut tls_shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(ReplayShard::new(cap)));
handshake_shards.push(Mutex::new(ReplayShard::new(cap)));
tls_shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self {
shards,
handshake_shards,
tls_shards,
shard_mask: num_shards - 1,
window,
tls_window: window.max(MIN_TLS_REPLAY_WINDOW),
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
@ -1613,46 +1620,60 @@ impl ReplayChecker {
(hasher.finish() as usize) & self.shard_mask
}
fn check_and_add_internal(&self, data: &[u8]) -> bool {
fn check_and_add_internal(
&self,
data: &[u8],
shards: &[Mutex<ReplayShard>],
window: Duration,
) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
let mut shard = shards[idx].lock();
let now = Instant::now();
let found = shard.check(data, now, self.window);
let found = shard.check(data, now, window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
} else {
shard.add(data, now, self.window);
shard.add(data, now, window);
self.additions.fetch_add(1, Ordering::Relaxed);
}
found
}
// Unconditionally records `data` in the given shard domain without checking
// for a prior occurrence (used by the compat add_* helpers).
// NOTE(review): diff rendering — the old single-domain signature and the
// `self.shards`/`self.window` lines above their replacements are stale.
fn add_only(&self, data: &[u8]) {
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window);
let mut shard = shards[idx].lock();
shard.add(data, Instant::now(), window);
}
// Atomic replay check in the handshake domain: true == seen before (replay);
// a miss records the payload under the configured base window.
// NOTE(review): diff rendering — old single-domain call sites appear above
// their domain-routed replacements throughout this span.
pub fn check_and_add_handshake(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
self.check_and_add_internal(data, &self.handshake_shards, self.window)
}
// Atomic replay check in the TLS-digest domain. Uses `tls_window`, which is
// clamped in new() to a secure minimum regardless of the configured window.
pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
self.check_and_add_internal(data, &self.tls_shards, self.tls_window)
}
// Compatibility helpers (non-atomic split operations) — prefer check_and_add_*.
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) }
// Record-only in the handshake domain (no replay check performed).
pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) }
pub fn add_handshake(&self, data: &[u8]) {
self.add_only(data, &self.handshake_shards, self.window)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) }
// Record-only in the TLS domain, using the clamped TLS window.
pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) }
pub fn add_tls_digest(&self, data: &[u8]) {
self.add_only(data, &self.tls_shards, self.tls_window)
}
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.shards {
for shard in &self.handshake_shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
for shard in &self.tls_shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
@ -1665,7 +1686,7 @@ impl ReplayChecker {
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.shards.len(),
num_shards: self.handshake_shards.len() + self.tls_shards.len(),
window_secs: self.window.as_secs(),
}
}
@ -1683,13 +1704,20 @@ impl ReplayChecker {
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.shards {
for shard_mutex in &self.handshake_shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
for shard_mutex in &self.tls_shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.tls_window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
self.cleanups.fetch_add(1, Ordering::Relaxed);
@ -1815,7 +1843,7 @@ mod tests {
fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(10_000, Duration::from_secs(60));
for i in 0..500u32 {
checker.add_only(&i.to_le_bytes());
checker.add_handshake(&i.to_le_bytes());
}
for i in 0..500u32 {
assert!(checker.check_handshake(&i.to_le_bytes()));
@ -1827,3 +1855,7 @@ mod tests {
#[cfg(test)]
#[path = "connection_lease_security_tests.rs"]
mod connection_lease_security_tests;
#[cfg(test)]
#[path = "replay_checker_security_tests.rs"]
mod replay_checker_security_tests;

View File

@ -0,0 +1,80 @@
use super::*;
use std::time::Duration;
#[test]
fn replay_checker_keeps_tls_and_handshake_domains_isolated_for_same_key() {
    // The same byte string must be tracked independently per protocol domain:
    // seeding one domain must not cause a replay-hit in the other.
    let checker = ReplayChecker::new(128, Duration::from_millis(20));
    let key = b"same-key-domain-separation";

    // First sighting in each domain is a miss (and records the key).
    let handshake_first_seen = checker.check_and_add_handshake(key);
    assert!(!handshake_first_seen, "first handshake use should be fresh");
    let tls_first_seen = checker.check_and_add_tls_digest(key);
    assert!(!tls_first_seen, "same bytes in TLS domain should still be fresh");

    // Second sighting in each domain is a replay-hit, independently.
    let handshake_second_seen = checker.check_and_add_handshake(key);
    assert!(handshake_second_seen, "second handshake use should be replay-hit");
    let tls_second_seen = checker.check_and_add_tls_digest(key);
    assert!(tls_second_seen, "second TLS use should be replay-hit independently");
}
#[test]
fn replay_checker_tls_window_is_clamped_beyond_small_handshake_window() {
    // A tiny configured window applies to the handshake domain as-is, while
    // the TLS domain's window is clamped up to a secure minimum in new().
    let checker = ReplayChecker::new(128, Duration::from_millis(20));
    let handshake_key = b"short-window-handshake";
    let tls_key = b"short-window-tls";

    // Seed both domains; both first uses are fresh.
    assert!(!checker.check_and_add_handshake(handshake_key));
    assert!(!checker.check_and_add_tls_digest(tls_key));

    // Sleep well past the 20 ms configured window but far below the clamped
    // TLS minimum.
    std::thread::sleep(Duration::from_millis(80));

    let handshake_replayed = checker.check_and_add_handshake(handshake_key);
    assert!(
        !handshake_replayed,
        "handshake key should expire under short configured window"
    );
    let tls_replayed = checker.check_and_add_tls_digest(tls_key);
    assert!(
        tls_replayed,
        "TLS key should still be replay-hit because TLS window is clamped to a secure minimum"
    );
}
#[test]
fn replay_checker_compat_add_paths_do_not_cross_pollute_domains() {
    // The legacy add_* helpers must write only into their own domain.
    let checker = ReplayChecker::new(128, Duration::from_secs(1));
    let key = b"compat-domain-separation";
    let tls_key = b"compat-domain-separation-tls";

    checker.add_handshake(key);
    assert!(
        checker.check_and_add_handshake(key),
        "handshake add helper must populate handshake domain"
    );
    assert!(
        !checker.check_and_add_tls_digest(key),
        "handshake add helper must not pollute TLS domain"
    );

    // BUGFIX: use a fresh key here. The miss-path of check_and_add_tls_digest
    // above already recorded `key` in the TLS domain, so reusing `key` would
    // make the assertion below pass even if add_tls_digest were a no-op.
    checker.add_tls_digest(tls_key);
    assert!(
        checker.check_and_add_tls_digest(tls_key),
        "TLS add helper must populate TLS domain"
    );
    // Symmetric isolation: the TLS-only insertion must not leak into the
    // handshake domain either.
    assert!(
        !checker.check_and_add_handshake(tls_key),
        "TLS add helper must not pollute handshake domain"
    );
}
#[test]
fn replay_checker_stats_reflect_dual_shard_domains() {
    // num_shards sums both domains: 64 handshake shards + 64 TLS shards.
    let checker = ReplayChecker::new(128, Duration::from_secs(1));
    let reported = checker.stats().num_shards;
    assert_eq!(
        reported, 128,
        "stats should expose both shard domains (handshake + TLS)"
    );
}