Security hardening, concurrency fixes, and expanded test coverage

This commit introduces a comprehensive set of improvements to enhance
the security, reliability, and configurability of the proxy server,
specifically targeting adversarial resilience and high-load concurrency.

Security & Cryptography:
- Zeroize MTProto cryptographic key material (`dec_key`, `enc_key`)
  immediately after use to prevent memory leakage on early returns.
- Move TLS handshake replay tracking after full policy/ALPN validation
  to prevent cache poisoning by unauthenticated probes.
- Add `proxy_protocol_trusted_cidrs` configuration to restrict PROXY
  protocol headers to trusted networks, rejecting spoofed IPs.

Adversarial Resilience & DoS Mitigation:
- Implement "Tiny Frame Debt" tracking in the middle-relay to prevent
  CPU exhaustion from malicious 0-byte or 1-byte frame floods.
- Add `mask_relay_max_bytes` to strictly bound unauthenticated fallback
  connections, preventing the proxy from being abused as an open relay.
- Add a 5ms prefetch window (`mask_classifier_prefetch_timeout_ms`) to
  correctly assemble and classify fragmented HTTP/1.1 and HTTP/2 probes
  (e.g., `PRI * HTTP/2.0`) before routing them to masking heuristics.
- Prevent recursive masking loops (FD exhaustion) by verifying the mask
  target is not the proxy's own listener via local interface enumeration.

Concurrency & Reliability:
- Eliminate executor waker storms during quota lock contention by replacing
  the spin-waker task with inline `Sleep` and exponential backoff.
- Roll back user quota reservations (`rollback_me2c_quota_reservation`)
  if a network write fails, preventing Head-of-Line (HoL) blocking from
  permanently burning data quotas.
- Recover gracefully from idle-registry `Mutex` poisoning instead of
  panicking, ensuring isolated thread failures do not break the proxy.
- Fix `auth_probe_scan_start_offset` modulo logic to ensure bounds safety.

Testing:
- Add extensive adversarial, timing, fuzzing, and invariant test suites
  for both the client and handshake modules.
This commit is contained in:
David Osipov 2026-03-22 23:06:26 +04:00
parent 6fc188f0c4
commit 91be148b72
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
65 changed files with 7473 additions and 210 deletions

34
Cargo.lock generated
View File

@ -1486,7 +1486,7 @@ dependencies = [
"cesu8",
"cfg-if",
"combine",
"jni-sys",
"jni-sys 0.3.1",
"log",
"thiserror 1.0.69",
"walkdir",
@ -1495,9 +1495,31 @@ dependencies = [
[[package]]
name = "jni-sys"
version = "0.3.0"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
dependencies = [
"jni-sys 0.4.1",
]
[[package]]
name = "jni-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
dependencies = [
"jni-sys-macros",
]
[[package]]
name = "jni-sys-macros"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
dependencies = [
"quote",
"syn",
]
[[package]]
name = "jobserver"
@ -1659,9 +1681,9 @@ dependencies = [
[[package]]
name = "moka"
version = "0.12.14"
version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b"
checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046"
dependencies = [
"crossbeam-channel",
"crossbeam-epoch",
@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "telemt"
version = "3.3.29"
version = "3.3.30"
dependencies = [
"aes",
"anyhow",

View File

@ -1,8 +1,11 @@
[package]
name = "telemt"
version = "3.3.29"
version = "3.3.30"
edition = "2024"
[features]
redteam_offline_expected_fail = []
[dependencies]
# C
libc = "0.2"

View File

@ -269,6 +269,8 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a
| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. |
| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. |
| mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. |
| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. |
| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. |
| mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. |
| mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. |
| mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. |

View File

@ -553,6 +553,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize {
512
}
#[cfg(not(test))]
pub(crate) fn default_mask_relay_max_bytes() -> usize {
5 * 1024 * 1024
}
#[cfg(test)]
pub(crate) fn default_mask_relay_max_bytes() -> usize {
32 * 1024
}
pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 {
5
}
pub(crate) fn default_mask_timing_normalization_enabled() -> bool {
false
}

View File

@ -600,6 +600,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur
|| old.censorship.mask_shape_above_cap_blur_max_bytes
!= new.censorship.mask_shape_above_cap_blur_max_bytes
|| old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes
|| old.censorship.mask_classifier_prefetch_timeout_ms
!= new.censorship.mask_classifier_prefetch_timeout_ms
|| old.censorship.mask_timing_normalization_enabled
!= new.censorship.mask_timing_normalization_enabled
|| old.censorship.mask_timing_normalization_floor_ms

View File

@ -430,6 +430,25 @@ impl ProxyConfig {
));
}
if config.censorship.mask_relay_max_bytes == 0 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be > 0".to_string(),
));
}
if config.censorship.mask_relay_max_bytes > 67_108_864 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be <= 67108864".to_string(),
));
}
if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) {
return Err(ProxyError::Config(
"censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"
.to_string(),
));
}
if config.censorship.mask_timing_normalization_ceiling_ms
< config.censorship.mask_timing_normalization_floor_ms
{
@ -1134,6 +1153,10 @@ mod load_security_tests;
#[path = "tests/load_mask_shape_security_tests.rs"]
mod load_mask_shape_security_tests;
#[cfg(test)]
#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"]
mod load_mask_classifier_prefetch_timeout_security_tests;
#[cfg(test)]
mod tests {
use super::*;

View File

@ -0,0 +1,75 @@
use super::*;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
fn write_temp_config(contents: &str) -> PathBuf {
let nonce = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time must be after unix epoch")
.as_nanos();
let path = std::env::temp_dir()
.join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml"));
fs::write(&path, contents).expect("temp config write must succeed");
path
}
fn remove_temp_config(path: &PathBuf) {
let _ = fs::remove_file(path);
}
#[test]
fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 4
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("prefetch timeout below minimum security bound must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
"error must explain timeout bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 51
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("prefetch timeout above max security bound must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
"error must explain timeout bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 20
"#,
);
let cfg = ProxyConfig::load(&path)
.expect("prefetch timeout within security bounds must be accepted");
assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20);
remove_temp_config(&path);
}

View File

@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8
remove_temp_config(&path);
}
#[test]
fn load_rejects_zero_mask_relay_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 0
"#,
);
let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be > 0"),
"error must explain non-zero relay cap invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_relay_max_bytes_above_upper_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 67108865
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("mask_relay_max_bytes above hard cap must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"),
"error must explain relay cap upper bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_valid_mask_relay_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 8388608
"#,
);
let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted");
assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608);
remove_temp_config(&path);
}

View File

@ -1450,6 +1450,14 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_mask_shape_above_cap_blur_max_bytes")]
pub mask_shape_above_cap_blur_max_bytes: usize,
/// Maximum bytes relayed per direction on unauthenticated masking fallback paths.
#[serde(default = "default_mask_relay_max_bytes")]
pub mask_relay_max_bytes: usize,
/// Prefetch timeout (ms) for extending fragmented masking classifier window.
#[serde(default = "default_mask_classifier_prefetch_timeout_ms")]
pub mask_classifier_prefetch_timeout_ms: u64,
/// Enable outcome-time normalization envelope for masking fallback.
#[serde(default = "default_mask_timing_normalization_enabled")]
pub mask_timing_normalization_enabled: bool,
@ -1488,6 +1496,8 @@ impl Default for AntiCensorshipConfig {
mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),
mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(),
mask_relay_max_bytes: default_mask_relay_max_bytes(),
mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(),
mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(),
mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(),
mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(),

View File

@ -186,6 +186,67 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
}
}
const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16;
#[cfg(test)]
const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5);
fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration {
Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms)
}
fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool {
if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW {
return false;
}
if initial_data.is_empty() {
// Empty initial_data means there is no client probe prefix to refine.
// Prefetching in this case can consume fallback relay payload bytes and
// accidentally route them through shaping heuristics.
return false;
}
if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") {
return false;
}
initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ')
}
#[cfg(test)]
async fn extend_masking_initial_window<R>(reader: &mut R, initial_data: &mut Vec<u8>)
where
R: AsyncRead + Unpin,
{
extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT)
.await;
}
async fn extend_masking_initial_window_with_timeout<R>(
reader: &mut R,
initial_data: &mut Vec<u8>,
prefetch_timeout: Duration,
)
where
R: AsyncRead + Unpin,
{
if !should_prefetch_mask_classifier_window(initial_data) {
return;
}
let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len());
if need == 0 {
return;
}
let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW];
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
&& n > 0
{
initial_data.extend_from_slice(&extra[..n]);
}
}
fn masking_outcome<R, W>(
reader: R,
writer: W,
@ -200,6 +261,15 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
HandshakeOutcome::NeedsMasking(Box::pin(async move {
let mut reader = reader;
let mut initial_data = initial_data;
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
mask_classifier_prefetch_timeout(&config),
)
.await;
handle_bad_client(
reader,
writer,
@ -1321,6 +1391,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests;
#[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"]
mod masking_probe_evasion_blackhat_tests;
#[cfg(test)]
#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"]
mod masking_fragmented_classifier_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_replay_timing_security_tests.rs"]
mod masking_replay_timing_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"]
mod masking_http2_fragmented_preface_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"]
mod masking_prefetch_invariant_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"]
mod masking_prefetch_timing_matrix_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"]
mod masking_prefetch_config_runtime_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"]
mod masking_prefetch_config_pipeline_integration_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"]
mod masking_prefetch_strict_boundary_security_tests;
#[cfg(test)]
#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"]
mod beobachten_ttl_bounds_security_tests;

View File

@ -121,6 +121,20 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
hasher.finish() as usize
}
fn auth_probe_scan_start_offset(
peer_ip: IpAddr,
now: Instant,
state_len: usize,
scan_limit: usize,
) -> usize {
if state_len == 0 || scan_limit == 0 {
return 0;
}
let window = state_len.min(scan_limit);
auth_probe_eviction_offset(peer_ip, now) % window
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
@ -269,11 +283,7 @@ fn auth_probe_record_failure_with_state(
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = if state_len == 0 {
0
} else {
auth_probe_eviction_offset(peer_ip, now) % state_len
};
let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
@ -769,7 +779,7 @@ where
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_key = Zeroizing::new(sha256(&dec_key_input));
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
@ -805,7 +815,7 @@ where
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_key = Zeroizing::new(sha256(&enc_key_input));
let mut enc_iv_arr = [0u8; IV_LEN];
enc_iv_arr.copy_from_slice(enc_iv_bytes);
@ -830,9 +840,9 @@ where
user: user.clone(),
dc_idx,
proto_tag,
dec_key,
dec_key: *dec_key,
dec_iv,
enc_key,
enc_key: *enc_key,
enc_iv,
peer,
is_tls,
@ -979,6 +989,14 @@ mod saturation_poison_security_tests;
#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"]
mod auth_probe_hardening_adversarial_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"]
mod auth_probe_scan_budget_security_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"]
mod auth_probe_scan_offset_stress_tests;
#[cfg(test)]
#[path = "tests/handshake_advanced_clever_tests.rs"]
mod advanced_clever_tests;
@ -995,6 +1013,10 @@ mod real_bug_stress_tests;
#[path = "tests/handshake_timing_manual_bench_tests.rs"]
mod timing_manual_bench_tests;
#[cfg(test)]
#[path = "tests/handshake_key_material_zeroization_security_tests.rs"]
mod handshake_key_material_zeroization_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.

View File

@ -4,10 +4,17 @@ use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use rand::{Rng, RngExt};
use std::net::SocketAddr;
#[cfg(unix)]
use nix::ifaddrs::getifaddrs;
use rand::rngs::StdRng;
use rand::{Rng, RngExt, SeedableRng};
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::time::Duration;
#[cfg(unix)]
use std::sync::{Mutex, OnceLock};
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant as StdInstant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(unix)]
@ -30,13 +37,23 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
const MASK_BUFFER_SIZE: usize = 8192;
#[cfg(unix)]
#[cfg(not(test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300);
#[cfg(all(unix, test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1);
struct CopyOutcome {
total: usize,
ended_by_eof: bool,
}
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W) -> CopyOutcome
async fn copy_with_idle_timeout<R, W>(
reader: &mut R,
writer: &mut W,
byte_cap: usize,
shutdown_on_eof: bool,
) -> CopyOutcome
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
@ -44,14 +61,31 @@ where
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut total = 0usize;
let mut ended_by_eof = false;
if byte_cap == 0 {
return CopyOutcome {
total,
ended_by_eof,
};
}
loop {
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await;
let n = match read_res {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
ended_by_eof = true;
if shutdown_on_eof {
let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await;
}
break;
}
total = total.saturating_add(n);
@ -61,6 +95,10 @@ where
Ok(Ok(())) => {}
Ok(Err(_)) | Err(_) => break,
}
if total >= byte_cap {
break;
}
}
CopyOutcome {
total,
@ -68,6 +106,39 @@ where
}
}
fn is_http_probe(data: &[u8]) -> bool {
// RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ".
const HTTP_METHODS: [&[u8]; 10] = [
b"GET ",
b"POST",
b"HEAD",
b"PUT ",
b"DELETE",
b"OPTIONS",
b"CONNECT",
b"TRACE",
b"PATCH",
b"PRI ",
];
if data.is_empty() {
return false;
}
let window = &data[..data.len().min(16)];
for method in HTTP_METHODS {
if data.len() >= method.len() && window.starts_with(method) {
return true;
}
if (2..=3).contains(&window.len()) && method.starts_with(window) {
return true;
}
}
false
}
fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize {
if total == 0 || floor == 0 || cap < floor {
return total;
@ -125,6 +196,11 @@ async fn maybe_write_shape_padding<W>(
let mut remaining = target_total - total_sent;
let mut pad_chunk = [0u8; 1024];
let deadline = Instant::now() + MASK_TIMEOUT;
// Use a Send RNG so relay futures remain spawn-safe under Tokio.
let mut rng = {
let mut seed_source = rand::rng();
StdRng::from_rng(&mut seed_source)
};
while remaining > 0 {
let now = Instant::now();
@ -133,10 +209,7 @@ async fn maybe_write_shape_padding<W>(
}
let write_len = remaining.min(pad_chunk.len());
{
let mut rng = rand::rng();
rng.fill_bytes(&mut pad_chunk[..write_len]);
}
rng.fill_bytes(&mut pad_chunk[..write_len]);
let write_budget = deadline.saturating_duration_since(now);
match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await {
Ok(Ok(())) => {}
@ -167,11 +240,11 @@ where
}
}
async fn consume_client_data_with_timeout<R>(reader: R)
async fn consume_client_data_with_timeout_and_cap<R>(reader: R, byte_cap: usize)
where
R: AsyncRead + Unpin,
{
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader))
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap))
.await
.is_err()
{
@ -190,6 +263,9 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
if config.censorship.mask_timing_normalization_enabled {
let floor = config.censorship.mask_timing_normalization_floor_ms;
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
if floor == 0 {
return MASK_TIMEOUT;
}
if ceiling > floor {
let mut rng = rand::rng();
return Duration::from_millis(rng.random_range(floor..=ceiling));
@ -219,14 +295,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4
&& (data.starts_with(b"GET ")
|| data.starts_with(b"POST")
|| data.starts_with(b"HEAD")
|| data.starts_with(b"PUT ")
|| data.starts_with(b"DELETE")
|| data.starts_with(b"OPTIONS"))
{
if is_http_probe(data) {
return "HTTP";
}
@ -248,6 +317,172 @@ fn detect_client_type(data: &[u8]) -> &'static str {
"unknown"
}
fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
if host.starts_with('[') && host.ends_with(']') {
return host[1..host.len() - 1].parse::<IpAddr>().ok();
}
host.parse::<IpAddr>().ok()
}
fn canonical_ip(ip: IpAddr) -> IpAddr {
match ip {
IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)),
IpAddr::V4(v4) => IpAddr::V4(v4),
}
}
#[cfg(unix)]
fn collect_local_interface_ips() -> Vec<IpAddr> {
#[cfg(test)]
LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed);
let mut out = Vec::new();
if let Ok(addrs) = getifaddrs() {
for iface in addrs {
if let Some(address) = iface.address {
if let Some(v4) = address.as_sockaddr_in() {
out.push(canonical_ip(IpAddr::V4(v4.ip())));
} else if let Some(v6) = address.as_sockaddr_in6() {
out.push(canonical_ip(IpAddr::V6(v6.ip())));
}
}
}
}
out
}
fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec<IpAddr>) -> Vec<IpAddr> {
if refreshed.is_empty() && !previous.is_empty() {
return previous.to_vec();
}
refreshed
}
#[cfg(unix)]
#[derive(Default)]
struct LocalInterfaceCache {
ips: Vec<IpAddr>,
refreshed_at: Option<StdInstant>,
}
#[cfg(unix)]
static LOCAL_INTERFACE_CACHE: OnceLock<Mutex<LocalInterfaceCache>> = OnceLock::new();
#[cfg(unix)]
fn local_interface_ips() -> Vec<IpAddr> {
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if stale {
let refreshed = collect_local_interface_ips();
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
guard.refreshed_at = Some(StdInstant::now());
}
guard.ips.clone()
}
#[cfg(not(unix))]
fn local_interface_ips() -> Vec<IpAddr> {
Vec::new()
}
#[cfg(test)]
static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
fn reset_local_interface_enumerations_for_tests() {
LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed);
#[cfg(unix)]
if let Some(cache) = LOCAL_INTERFACE_CACHE.get() {
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
guard.ips.clear();
guard.refreshed_at = None;
}
}
#[cfg(test)]
fn local_interface_enumerations_for_tests() -> usize {
LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed)
}
fn is_mask_target_local_listener_with_interfaces(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
interface_ips: &[IpAddr],
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let local_ip = canonical_ip(local_addr.ip());
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
if let Some(addr) = resolved_override {
let resolved_ip = canonical_ip(addr.ip());
if resolved_ip == local_ip {
return true;
}
if local_ip.is_unspecified()
&& (resolved_ip.is_loopback()
|| resolved_ip.is_unspecified()
|| interface_ips.contains(&resolved_ip))
{
return true;
}
}
if let Some(mask_ip) = literal_mask_ip {
if mask_ip == local_ip {
return true;
}
if local_ip.is_unspecified()
&& (mask_ip.is_loopback()
|| mask_ip.is_unspecified()
|| interface_ips.contains(&mask_ip))
{
return true;
}
}
false
}
fn is_mask_target_local_listener(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let interfaces = local_interface_ips();
is_mask_target_local_listener_with_interfaces(
mask_host,
mask_port,
local_addr,
resolved_override,
&interfaces,
)
}
fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration {
let minutes = config.general.beobachten_minutes;
let clamped = minutes.clamp(1, 24 * 60);
Duration::from_secs(clamped.saturating_mul(60))
}
fn build_mask_proxy_header(
version: u8,
peer: SocketAddr,
@ -290,13 +525,14 @@ pub async fn handle_bad_client<R, W>(
{
let client_type = detect_client_type(initial_data);
if config.general.beobachten {
let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60));
let ttl = masking_beobachten_ttl(config);
beobachten.record(client_type, peer.ip(), ttl);
}
if !config.censorship.mask {
// Masking disabled, just consume data
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
.await;
return;
}
@ -341,6 +577,7 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
config.censorship.mask_relay_max_bytes,
),
)
.await
@ -353,12 +590,12 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@ -371,6 +608,24 @@ pub async fn handle_bad_client<R, W>(
.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
let outcome_started = Instant::now();
// Fail closed when fallback points at our own listener endpoint.
// Self-referential masking can create recursive proxy loops under
// misconfiguration and leak distinguishable load spikes to adversaries.
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
if is_mask_target_local_listener(mask_host, mask_port, local_addr, resolved_mask_addr) {
debug!(
client_type = client_type,
host = %mask_host,
port = mask_port,
local = %local_addr,
"Mask target resolves to local listener; refusing self-referential masking fallback"
);
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
debug!(
client_type = client_type,
@ -381,10 +636,9 @@ pub async fn handle_bad_client<R, W>(
);
// Apply runtime DNS override for mask target when configured.
let mask_addr = resolve_socket_addr(mask_host, mask_port)
let mask_addr = resolved_mask_addr
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let outcome_started = Instant::now();
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
@ -413,6 +667,7 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
config.censorship.mask_relay_max_bytes,
),
)
.await
@ -425,12 +680,12 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@ -449,6 +704,7 @@ async fn relay_to_mask<R, W, MR, MW>(
shape_above_cap_blur: bool,
shape_above_cap_blur_max_bytes: usize,
shape_hardening_aggressive_mode: bool,
mask_relay_max_bytes: usize,
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
@ -464,8 +720,18 @@ async fn relay_to_mask<R, W, MR, MW>(
}
let (upstream_copy, downstream_copy) = tokio::join!(
async { copy_with_idle_timeout(&mut reader, &mut mask_write).await },
async { copy_with_idle_timeout(&mut mask_read, &mut writer).await }
async {
copy_with_idle_timeout(
&mut reader,
&mut mask_write,
mask_relay_max_bytes,
!shape_hardening_enabled,
)
.await
},
async {
copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await
}
);
let total_sent = initial_data.len().saturating_add(upstream_copy.total);
@ -491,13 +757,30 @@ async fn relay_to_mask<R, W, MR, MW>(
let _ = writer.shutdown().await;
}
/// Just consume all data from client without responding
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = reader.read(&mut buf).await {
/// Just consume all data from client without responding.
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usize) {
if byte_cap == 0 {
return;
}
// Keep drain path fail-closed under slow-loris stalls.
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut total = 0usize;
loop {
let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
break;
}
total = total.saturating_add(n);
if total >= byte_cap {
break;
}
}
}
@ -548,3 +831,63 @@ mod masking_aggressive_mode_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"]
mod masking_timing_sidechannel_redteam_expected_fail_tests;
#[cfg(test)]
#[path = "tests/masking_self_target_loop_security_tests.rs"]
mod masking_self_target_loop_security_tests;
#[cfg(test)]
#[path = "tests/masking_classification_completeness_security_tests.rs"]
mod masking_classification_completeness_security_tests;
#[cfg(test)]
#[path = "tests/masking_relay_guardrails_security_tests.rs"]
mod masking_relay_guardrails_security_tests;
#[cfg(test)]
#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"]
mod masking_connect_failure_close_matrix_security_tests;
#[cfg(test)]
#[path = "tests/masking_additional_hardening_security_tests.rs"]
mod masking_additional_hardening_security_tests;
#[cfg(test)]
#[path = "tests/masking_consume_idle_timeout_security_tests.rs"]
mod masking_consume_idle_timeout_security_tests;
#[cfg(test)]
#[path = "tests/masking_http2_probe_classification_security_tests.rs"]
mod masking_http2_probe_classification_security_tests;
#[cfg(test)]
#[path = "tests/masking_http_probe_boundary_security_tests.rs"]
mod masking_http_probe_boundary_security_tests;
#[cfg(test)]
#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"]
mod masking_rng_hoist_perf_regression_tests;
#[cfg(test)]
#[path = "tests/masking_http2_preface_integration_security_tests.rs"]
mod masking_http2_preface_integration_security_tests;
#[cfg(test)]
#[path = "tests/masking_consume_stress_adversarial_tests.rs"]
mod masking_consume_stress_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_security_tests.rs"]
mod masking_interface_cache_security_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"]
mod masking_interface_cache_defense_in_depth_security_tests;
#[cfg(test)]
#[path = "tests/masking_padding_timeout_adversarial_tests.rs"]
mod masking_padding_timeout_adversarial_tests;
#[cfg(all(test, feature = "redteam_offline_expected_fail"))]
#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"]
mod masking_offline_target_redteam_expected_fail_tests;

View File

@ -39,6 +39,8 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
const TINY_FRAME_DEBT_LIMIT: u32 = 512;
#[cfg(test)]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
#[cfg(not(test))]
@ -94,10 +96,23 @@ fn relay_idle_candidate_registry() -> &'static Mutex<RelayIdleCandidateRegistry>
RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default()))
}
fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> {
let registry = relay_idle_candidate_registry();
match registry.lock() {
Ok(guard) => guard,
Err(poisoned) => {
let mut guard = poisoned.into_inner();
// Fail closed after panic while holding registry lock: drop all
// candidates and pressure cursors to avoid stale cross-session state.
*guard = RelayIdleCandidateRegistry::default();
registry.clear_poison();
guard
}
}
}
fn mark_relay_idle_candidate(conn_id: u64) -> bool {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return false;
};
let mut guard = relay_idle_candidate_registry_lock();
if guard.by_conn_id.contains_key(&conn_id) {
return false;
@ -116,9 +131,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool {
}
fn clear_relay_idle_candidate(conn_id: u64) {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return;
};
let mut guard = relay_idle_candidate_registry_lock();
if let Some(meta) = guard.by_conn_id.remove(&conn_id) {
guard.ordered.remove(&(meta.mark_order_seq, conn_id));
@ -127,23 +140,17 @@ fn clear_relay_idle_candidate(conn_id: u64) {
#[cfg(test)]
fn oldest_relay_idle_candidate() -> Option<u64> {
let Ok(guard) = relay_idle_candidate_registry().lock() else {
return None;
};
let guard = relay_idle_candidate_registry_lock();
guard.ordered.iter().next().map(|(_, conn_id)| *conn_id)
}
fn note_relay_pressure_event() {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return;
};
let mut guard = relay_idle_candidate_registry_lock();
guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1);
}
fn relay_pressure_event_seq() -> u64 {
let Ok(guard) = relay_idle_candidate_registry().lock() else {
return 0;
};
let guard = relay_idle_candidate_registry_lock();
guard.pressure_event_seq
}
@ -152,9 +159,7 @@ fn maybe_evict_idle_candidate_on_pressure(
seen_pressure_seq: &mut u64,
stats: &Stats,
) -> bool {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return false;
};
let mut guard = relay_idle_candidate_registry_lock();
let latest_pressure_seq = guard.pressure_event_seq;
if latest_pressure_seq == *seen_pressure_seq {
@ -199,13 +204,9 @@ fn maybe_evict_idle_candidate_on_pressure(
#[cfg(test)]
fn clear_relay_idle_pressure_state_for_testing() {
if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get()
&& let Ok(mut guard) = registry.lock()
{
guard.by_conn_id.clear();
guard.ordered.clear();
guard.pressure_event_seq = 0;
guard.pressure_consumed_seq = 0;
if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() {
let mut guard = relay_idle_candidate_registry_lock();
*guard = RelayIdleCandidateRegistry::default();
}
RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed);
}
@ -259,6 +260,7 @@ impl RelayClientIdlePolicy {
struct RelayClientIdleState {
last_client_frame_at: Instant,
soft_idle_marked: bool,
tiny_frame_debt: u32,
}
impl RelayClientIdleState {
@ -266,6 +268,7 @@ impl RelayClientIdleState {
Self {
last_client_frame_at: now,
soft_idle_marked: false,
tiny_frame_debt: 0,
}
}
@ -551,15 +554,6 @@ fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
limit.saturating_add(overshoot)
}
fn quota_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot))
}
fn quota_would_be_exceeded_for_user_soft(
stats: &Stats,
user: &str,
@ -617,6 +611,16 @@ fn observe_me_d2c_flush_event(
}
}
fn rollback_me2c_quota_reservation(
stats: &Stats,
user: &str,
bytes_me2c: &AtomicU64,
reserved_bytes: u64,
) {
stats.sub_user_octets_to(user, reserved_bytes);
bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed);
}
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@ -630,6 +634,19 @@ fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[cfg(test)]
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> {
relay_idle_pressure_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
@ -665,6 +682,11 @@ fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
}
}
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<Mutex<()>> {
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@ -710,6 +732,8 @@ where
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let cross_mode_quota_lock =
quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@ -1221,6 +1245,17 @@ where
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else {
main_result = Err(ProxyError::Proxy(
"cross-mode quota lock missing for quota-limited session"
.to_string(),
));
break;
};
let _cross_mode_quota_guard = match cross_mode_lock.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
main_result = Err(ProxyError::DataQuotaExceeded {
@ -1320,6 +1355,8 @@ async fn read_client_payload_with_idle_policy<R>(
where
R: AsyncRead + Unpin + Send + 'static,
{
const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4;
async fn read_exact_with_policy<R>(
client_reader: &mut CryptoReader<R>,
buf: &mut [u8],
@ -1458,6 +1495,7 @@ where
Ok(())
}
let mut consecutive_zero_len_frames = 0u32;
loop {
let (len, quickack, raw_len_bytes) = match proto_tag {
ProtoTag::Abridged => {
@ -1538,6 +1576,27 @@ where
};
if len == 0 {
idle_state.tiny_frame_debt = idle_state
.tiny_frame_debt
.saturating_add(TINY_FRAME_DEBT_PER_TINY);
if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT {
stats.increment_relay_protocol_desync_close_total();
return Err(ProxyError::Proxy(format!(
"Tiny frame overhead limit exceeded: debt={}, conn_id={}",
idle_state.tiny_frame_debt, forensics.conn_id
)));
}
if !idle_policy.enabled {
consecutive_zero_len_frames =
consecutive_zero_len_frames.saturating_add(1);
if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES {
stats.increment_relay_protocol_desync_close_total();
return Err(ProxyError::Proxy(
"Excessive zero-length abridged frames".to_string(),
));
}
}
continue;
}
if len < 4 && proto_tag != ProtoTag::Abridged {
@ -1606,6 +1665,7 @@ where
}
*frame_counter += 1;
idle_state.on_client_frame(Instant::now());
idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1);
clear_relay_idle_candidate(forensics.conn_id);
return Ok(Some((payload, quickack)));
}
@ -1707,39 +1767,57 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
let data_len = data.len() as u64;
if quota_would_be_exceeded_for_user_soft(
stats,
user,
quota_limit,
data_len,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
if let Some(limit) = quota_limit {
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
if quota_would_be_exceeded_for_user(stats, user, Some(soft_limit), data_len) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
stats.increment_me_d2c_write_mode(write_mode);
// Reserve quota before awaiting network I/O to avoid same-user HoL stalls.
// If reservation loses a race or write fails, we roll back immediately.
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
stats.add_user_octets_to(user, data_len);
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data.len() as u64);
if stats.get_user_total_octets(user) > soft_limit {
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
if quota_exceeded_for_user_soft(
stats,
user,
quota_limit,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
{
Ok(mode) => mode,
Err(err) => {
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
return Err(err);
}
};
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
// Do not fail immediately on exact boundary after a successful write.
// Returning an error here can bypass batch flush in the caller and risk
// dropping buffered ciphertext from CryptoWriter. The next frame is
// rejected by the pre-check at function entry.
} else {
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
stats.add_user_octets_to(user, data_len);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
}
Ok(MeWriterResponseOutcome::Continue {
@ -1978,3 +2056,31 @@ mod length_cast_hardening_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
mod blackhat_campaign_integration_tests;
#[cfg(test)]
#[path = "tests/middle_relay_hol_quota_security_tests.rs"]
mod hol_quota_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"]
mod quota_reservation_adversarial_tests;
#[cfg(test)]
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
mod middle_relay_idle_registry_poison_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"]
mod middle_relay_zero_length_frame_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"]
mod middle_relay_tiny_frame_debt_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"]
mod middle_relay_tiny_frame_debt_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;

View File

@ -64,6 +64,7 @@ pub mod direct_relay;
pub mod handshake;
pub mod masking;
pub mod middle_relay;
pub mod quota_lock_registry;
pub mod relay;
pub mod route_mode;
pub mod session_eviction;

View File

@ -0,0 +1,53 @@
use dashmap::DashMap;
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(test)]
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(Mutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
return cross_mode_quota_overflow_user_lock(user);
}
let created = Arc::new(Mutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
#[cfg(test)]
#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"]
mod quota_lock_registry_cross_mode_adversarial_tests;

View File

@ -62,7 +62,7 @@ use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
use tokio::time::Instant;
use tokio::time::{Instant, Sleep};
use tracing::{debug, trace, warn};
// ============= Constants =============
@ -209,12 +209,16 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_lock: Option<Arc<Mutex<()>>>,
cross_mode_quota_lock: Option<Arc<Mutex<()>>>,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_read_wake_scheduled: bool,
quota_write_wake_scheduled: bool,
quota_read_retry_active: Arc<AtomicBool>,
quota_write_retry_active: Arc<AtomicBool>,
quota_read_retry_sleep: Option<Pin<Box<Sleep>>>,
quota_write_retry_sleep: Option<Pin<Box<Sleep>>>,
quota_read_retry_attempt: u8,
quota_write_retry_attempt: u8,
epoch: Instant,
}
@ -230,30 +234,29 @@ impl<S> StatsIo<S> {
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
let quota_lock = quota_limit.map(|_| quota_user_lock(&user));
let cross_mode_quota_lock = quota_limit
.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
Self {
inner,
counters,
stats,
user,
quota_lock,
cross_mode_quota_lock,
quota_limit,
quota_exceeded,
quota_read_wake_scheduled: false,
quota_write_wake_scheduled: false,
quota_read_retry_active: Arc::new(AtomicBool::new(false)),
quota_write_retry_active: Arc::new(AtomicBool::new(false)),
quota_read_retry_sleep: None,
quota_write_retry_sleep: None,
quota_read_retry_attempt: 0,
quota_write_retry_attempt: 0,
epoch,
}
}
}
impl<S> Drop for StatsIo<S> {
fn drop(&mut self) {
self.quota_read_retry_active.store(false, Ordering::Relaxed);
self.quota_write_retry_active
.store(false, Ordering::Relaxed);
}
}
#[derive(Debug)]
struct QuotaIoSentinel;
@ -281,20 +284,52 @@ fn is_quota_io_error(err: &io::Error) -> bool {
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
#[cfg(test)]
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64);
fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
tokio::task::spawn(async move {
loop {
if !retry_active.load(Ordering::Relaxed) {
break;
}
tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
if !retry_active.load(Ordering::Relaxed) {
break;
}
waker.wake_by_ref();
}
});
#[inline]
fn quota_contention_retry_delay(retry_attempt: u8) -> Duration {
let shift = u32::from(retry_attempt.min(5));
let multiplier = 1_u32 << shift;
QUOTA_CONTENTION_RETRY_INTERVAL
.saturating_mul(multiplier)
.min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL)
}
#[inline]
fn reset_quota_retry_scheduler(
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
wake_scheduled: &mut bool,
retry_attempt: &mut u8,
) {
*wake_scheduled = false;
*sleep_slot = None;
*retry_attempt = 0;
}
fn poll_quota_retry_sleep(
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
wake_scheduled: &mut bool,
retry_attempt: &mut u8,
cx: &mut Context<'_>,
) {
if !*wake_scheduled {
*wake_scheduled = true;
*sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay(
*retry_attempt,
))));
}
if let Some(sleep) = sleep_slot.as_mut()
&& sleep.as_mut().poll(cx).is_ready()
{
*sleep_slot = None;
*wake_scheduled = false;
*retry_attempt = retry_attempt.saturating_add(1);
cx.waker().wake_by_ref();
}
}
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
@ -357,6 +392,11 @@ fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
}
}
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<Mutex<()>> {
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
fn poll_read(
self: Pin<&mut Self>,
@ -368,26 +408,47 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_read_wake_scheduled = false;
this.quota_read_retry_active.store(false, Ordering::Relaxed);
reset_quota_retry_scheduler(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
);
Some(guard)
}
Err(_) => {
if !this.quota_read_wake_scheduled {
this.quota_read_wake_scheduled = true;
this.quota_read_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_read_retry_active),
cx.waker().clone(),
);
}
poll_quota_retry_sleep(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
cx,
);
return Poll::Pending;
}
}
} else {
None
};
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
reset_quota_retry_scheduler(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
);
Some(guard)
}
Err(_) => {
poll_quota_retry_sleep(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
cx,
);
return Poll::Pending;
}
}
@ -460,27 +521,47 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_write_wake_scheduled = false;
this.quota_write_retry_active
.store(false, Ordering::Relaxed);
reset_quota_retry_scheduler(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
);
Some(guard)
}
Err(_) => {
if !this.quota_write_wake_scheduled {
this.quota_write_wake_scheduled = true;
this.quota_write_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_write_retry_active),
cx.waker().clone(),
);
}
poll_quota_retry_sleep(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
cx,
);
return Poll::Pending;
}
}
} else {
None
};
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
reset_quota_retry_scheduler(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
);
Some(guard)
}
Err(_) => {
poll_quota_retry_sleep(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
cx,
);
return Poll::Pending;
}
}
@ -791,3 +872,27 @@ mod relay_quota_waker_storm_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
mod relay_quota_wake_liveness_regression_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_identity_security_tests.rs"]
mod relay_quota_lock_identity_security_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"]
mod relay_cross_mode_quota_lock_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"]
mod relay_quota_retry_scheduler_tdd_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"]
mod relay_cross_mode_quota_fairness_tdd_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_backoff_security_tests.rs"]
mod relay_quota_retry_backoff_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"]
mod relay_quota_retry_backoff_benchmark_security_tests;

View File

@ -0,0 +1,100 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
#[tokio::test]
async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(b"CONNE").await.unwrap();
client_side
.write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
.await
.unwrap();
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert!(
forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"),
"mask backend must receive the full fragmented CONNECT probe"
);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.251-1"));
}

View File

@ -0,0 +1,129 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
let first = split_at.min(preface.len());
client_side.write_all(&preface[..first]).await.unwrap();
if first < preface.len() {
sleep(Duration::from_millis(delay_ms)).await;
client_side.write_all(&preface[first..]).await.unwrap();
}
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert!(
forwarded.starts_with(&preface),
"mask backend must receive an intact HTTP/2 preface prefix"
);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains(&format!("{}-1", peer.ip())));
}
#[tokio::test]
async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() {
let cases = [
(2usize, 0u64),
(3, 0),
(4, 0),
(2, 7),
(3, 7),
(8, 1),
];
for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() {
let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i)
.parse()
.unwrap();
run_http2_fragment_case(split_at, delay_ms, peer).await;
}
}
#[tokio::test]
async fn http2_preface_splitpoint_light_fuzz_classifies_http() {
for split_at in 2usize..=12 {
let delay_ms = if split_at % 3 == 0 { 7 } else { 1 };
let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at)
.parse()
.unwrap();
run_http2_fragment_case(split_at, delay_ms, peer).await;
}
}

View File

@ -0,0 +1,150 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_pipeline_prefetch_case(
prefetch_timeout_ms: u64,
delayed_tail_ms: u64,
peer: SocketAddr,
) -> (Vec<u8>, String) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms;
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(b"C").await.unwrap();
sleep(Duration::from_millis(delayed_tail_ms)).await;
client_side
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
.await
.unwrap();
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
(forwarded, snapshot)
}
#[tokio::test]
async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() {
let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap();
let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await;
assert!(
forwarded.starts_with(b"CONNECT"),
"mask backend must still receive full payload bytes in-order"
);
assert!(
snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"),
"unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}"
);
}
#[tokio::test]
async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() {
let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap();
let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await;
assert!(
forwarded.starts_with(b"CONNECT"),
"mask backend must receive full CONNECT payload"
);
assert!(
snapshot.contains("[HTTP]"),
"20ms budget should recover delayed fragmented prefix and classify as HTTP"
);
}
#[tokio::test]
async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() {
let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap();
let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap();
let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap();
let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await;
let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await;
let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await;
assert!(
snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"),
"unexpected 5ms snapshot: {snap5}"
);
assert!(
snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"),
"unexpected 20ms snapshot: {snap20}"
);
assert!(snap50.contains("[HTTP]"));
}

View File

@ -0,0 +1,82 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, sleep};
#[test]
fn prefetch_timeout_budget_reads_from_config() {
let mut cfg = ProxyConfig::default();
assert_eq!(
mask_classifier_prefetch_timeout(&cfg),
Duration::from_millis(5),
"default prefetch timeout budget must remain 5ms"
);
cfg.censorship.mask_classifier_prefetch_timeout_ms = 20;
assert_eq!(
mask_classifier_prefetch_timeout(&cfg),
Duration::from_millis(20),
"runtime prefetch timeout budget must follow configured value"
);
}
#[tokio::test]
async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(15)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer.shutdown().await.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(20),
)
.await;
writer_task
.await
.expect("writer task must not panic in runtime timeout test");
assert!(
initial_data.starts_with(b"CONNECT"),
"20ms configured prefetch budget should recover 15ms delayed CONNECT tail"
);
}
#[tokio::test]
async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(15)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer.shutdown().await.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(5),
)
.await;
writer_task
.await
.expect("writer task must not panic in runtime timeout test");
assert!(
!initial_data.starts_with(b"CONNECT"),
"5ms configured prefetch budget should miss 15ms delayed CONNECT tail"
);
}

View File

@ -0,0 +1,261 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
struct PipelineHarness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
PipelineHarness {
config,
stats,
upstream_manager,
replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(0x17);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
T: tokio::io::AsyncRead + Unpin,
{
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; len];
stream.read_exact(&mut body).await.unwrap();
}
#[test]
fn empty_initial_data_prefetch_gate_is_fail_closed() {
assert!(
!should_prefetch_mask_classifier_window(&[]),
"empty initial_data must not trigger classifier prefetch"
);
}
#[tokio::test]
async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() {
let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec();
let (mut reader, mut writer) = duplex(1024);
writer.write_all(&payload).await.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = Vec::new();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(
initial_data.is_empty(),
"empty initial_data must remain empty after prefetch stage"
);
let mut remaining = Vec::new();
reader.read_to_end(&mut remaining).await.unwrap();
assert_eq!(
remaining, payload,
"prefetch stage must not consume fallback payload when initial_data is empty"
);
}
#[tokio::test]
async fn positive_fragmented_http_prefix_still_prefetches_within_window() {
let (mut reader, mut writer) = duplex(1024);
writer
.write_all(b"NECT example.org:443 HTTP/1.1\r\n")
.await
.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = b"CON".to_vec();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(
initial_data.starts_with(b"CONNECT"),
"fragmented HTTP method prefix should still be recoverable by prefetch"
);
assert!(
initial_data.len() <= 16,
"prefetch window must remain bounded"
);
}
#[tokio::test]
async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() {
let mut seed = 0xD15C_A11E_2026_0322u64;
for _ in 0..128 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = ((seed & 0x3f) as usize).saturating_add(1);
let mut payload = vec![0u8; len];
for (idx, byte) in payload.iter_mut().enumerate() {
*byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17);
}
let (mut reader, mut writer) = duplex(1024);
writer.write_all(&payload).await.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = Vec::new();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(initial_data.is_empty());
let mut remaining = Vec::new();
reader.read_to_end(&mut remaining).await.unwrap();
assert_eq!(remaining, payload);
}
}
#[tokio::test]
async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0xD3u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B);
let mut invalid_payload = vec![0u8; HANDSHAKE_LEN];
invalid_payload[0] = 0xFF;
let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload);
let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant");
let expected = trailing_record.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected);
let mut one = [0u8; 1];
let n = stream.read(&mut one).await.unwrap();
assert_eq!(
n, 0,
"fallback stream must not append synthetic bytes on empty initial_data path"
);
});
let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port());
let (server_side, mut client_side) = duplex(131072);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.245:56145".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
client_side.shutdown().await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}

View File

@ -0,0 +1,70 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, advance, sleep};
async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec<u8> {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(tail_delay_ms)).await;
let _ = writer.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n").await;
let _ = writer.shutdown().await;
});
let mut initial_data = b"C".to_vec();
let mut prefetch_task = tokio::spawn(async move {
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(prefetch_ms),
)
.await;
initial_data
});
tokio::task::yield_now().await;
if tail_delay_ms > 0 {
advance(Duration::from_millis(tail_delay_ms)).await;
tokio::task::yield_now().await;
}
if prefetch_ms > tail_delay_ms {
advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await;
tokio::task::yield_now().await;
}
let result = prefetch_task.await.expect("prefetch task must not panic");
writer_task.await.expect("writer task must not panic");
result
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_5ms_misses_15ms_tail() {
let got = run_strict_prefetch_case(5, 15).await;
assert_eq!(got, b"C".to_vec());
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_20ms_recovers_15ms_tail() {
let got = run_strict_prefetch_case(20, 15).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_50ms_recovers_35ms_tail() {
let got = run_strict_prefetch_case(50, 35).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_equal_budget_and_delay_recovers_tail() {
let got = run_strict_prefetch_case(20, 20).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_one_ms_after_budget_misses_tail() {
let got = run_strict_prefetch_case(20, 21).await;
assert_eq!(got, b"C".to_vec());
}

View File

@ -0,0 +1,95 @@
use super::*;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, sleep, timeout};
async fn extend_masking_initial_window_with_budget<R>(
reader: &mut R,
initial_data: &mut Vec<u8>,
prefetch_timeout: Duration,
) where
R: AsyncRead + Unpin,
{
if !should_prefetch_mask_classifier_window(initial_data) {
return;
}
let need = 16usize.saturating_sub(initial_data.len());
if need == 0 {
return;
}
let mut extra = [0u8; 16];
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
&& n > 0
{
initial_data.extend_from_slice(&extra[..n]);
}
}
async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(delayed_tail_ms)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer.shutdown().await.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_budget(
&mut reader,
&mut initial_data,
Duration::from_millis(prefetch_budget_ms),
)
.await;
writer_task
.await
.expect("writer task must not panic during matrix case");
initial_data.starts_with(b"CONNECT")
}
#[tokio::test]
async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() {
let cases = [
// (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50])
(2u64, [true, true, true]),
(15u64, [false, true, true]),
(35u64, [false, false, true]),
];
for (tail_delay_ms, expected) in cases {
let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await;
let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await;
let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await;
assert_eq!(
got_5, expected[0],
"5ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
assert_eq!(
got_20, expected[1],
"20ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
assert_eq!(
got_50, expected[2],
"50ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
}
}
#[tokio::test]
async fn control_current_runtime_prefetch_budget_is_5ms() {
assert_eq!(
MASK_CLASSIFIER_PREFETCH_TIMEOUT,
Duration::from_millis(5),
"matrix assumptions require current runtime prefetch budget to stay at 5ms"
);
}

View File

@ -0,0 +1,161 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
async fn run_replay_candidate_session(
replay_checker: Arc<ReplayChecker>,
hello: &[u8],
peer: SocketAddr,
drive_mtproto_fail: bool,
) -> Duration {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.censorship.mask_timing_normalization_enabled = false;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), "abababababababababababababababab".to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(65536);
let started = Instant::now();
let task = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
replay_checker,
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten,
false,
));
client_side.write_all(hello).await.unwrap();
if drive_mtproto_fail {
let mut server_hello_head = [0u8; 5];
client_side.read_exact(&mut server_hello_head).await.unwrap();
assert_eq!(server_hello_head[0], 0x16);
let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize;
let mut body = vec![0u8; body_len];
client_side.read_exact(&mut body).await.unwrap();
let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN);
invalid_mtproto_record.push(0x17);
invalid_mtproto_record.extend_from_slice(&TLS_VERSION);
invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes());
invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]);
client_side.write_all(&invalid_mtproto_record).await.unwrap();
client_side
.write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n")
.await
.unwrap();
}
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
started.elapsed()
}
#[tokio::test]
async fn replay_reject_still_honors_masking_timing_budget() {
let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60)));
let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51);
let seed_elapsed = run_replay_candidate_session(
Arc::clone(&replay_checker),
&hello,
"198.51.100.201:58001".parse().unwrap(),
true,
)
.await;
assert!(
seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250),
"seed replay-candidate run must honor masking timing budget without unbounded delay"
);
let replay_elapsed = run_replay_candidate_session(
Arc::clone(&replay_checker),
&hello,
"198.51.100.202:58002".parse().unwrap(),
false,
)
.await;
assert!(
replay_elapsed >= Duration::from_millis(40)
&& replay_elapsed < Duration::from_millis(250),
"replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay"
);
}

View File

@ -0,0 +1,90 @@
use super::*;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn edge_zero_state_len_yields_zero_start_offset() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44));
let now = Instant::now();
assert_eq!(
auth_probe_scan_start_offset(ip, now, 0, 16),
0,
"empty map must not produce non-zero scan offset"
);
}
#[test]
fn adversarial_large_state_must_bound_start_offset_to_scan_budget() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let scan_limit = 16usize;
let state_len = 65_536usize;
for i in 0..2048u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
203,
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
let now = base + Duration::from_micros(i as u64);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < scan_limit,
"start offset must stay within scan window; start={start}, limit={scan_limit}"
);
}
}
#[test]
fn positive_state_smaller_than_scan_limit_caps_to_state_len() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17));
let now = Instant::now();
for state_len in 1..32usize {
let start = auth_probe_scan_start_offset(ip, now, state_len, 64);
assert!(
start < state_len,
"start offset must never exceed state length when scan limit is larger"
);
}
}
#[test]
fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
let _guard = auth_probe_test_guard();
let mut seed = 0x5A41_5356_4C32_3236u64;
let base = Instant::now();
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1);
let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0xffff);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
let effective_window = state_len.min(scan_limit);
assert!(
start < effective_window,
"scan offset must stay inside effective window"
);
}
}

View File

@ -0,0 +1,113 @@
use super::*;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77));
let base = Instant::now();
let mut uniq = HashSet::new();
for i in 0..512u64 {
let now = base + Duration::from_nanos(i);
let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16);
uniq.insert(offset);
}
assert_eq!(
uniq.len(),
16,
"offset randomization must cover the entire scan window over 512 samples"
);
}
#[test]
fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
let _guard = auth_probe_test_guard();
let now = Instant::now();
let mut uniq = HashSet::new();
for i in 0..1024u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
(i >> 16) as u8,
(i >> 8) as u8,
i as u8,
(255 - (i as u8)),
));
uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16));
}
assert_eq!(
uniq.len(),
16,
"scan offset distribution collapsed unexpectedly across peer set"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let start = Instant::now();
let mut workers = Vec::new();
for worker in 0..8u8 {
workers.push(tokio::spawn(async move {
for i in 0..8192u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
10,
worker,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64));
}
}));
}
for worker in workers {
worker.await.expect("saturation worker must not panic");
}
assert!(
auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"state must remain hard-capped under parallel saturation churn"
);
let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1));
let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1));
}
#[test]
fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
let _guard = auth_probe_test_guard();
let mut seed = 0xA55A_1357_2468_9BDFu64;
let base = Instant::now();
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1);
let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0x1fff);
let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(offset < state_len.min(scan_limit));
}
}

View File

@ -0,0 +1,42 @@
use super::*;
fn handshake_source() -> &'static str {
include_str!("../handshake.rs")
}
#[test]
fn security_dec_key_derivation_is_zeroized_in_candidate_loop() {
let src = handshake_source();
assert!(
src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"),
"candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
);
}
#[test]
fn security_enc_key_derivation_is_zeroized_in_candidate_loop() {
let src = handshake_source();
assert!(
src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"),
"candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
);
}
#[test]
fn security_aes_ctr_initialization_uses_zeroizing_references() {
let src = handshake_source();
assert!(
src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);")
&& src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"),
"AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables"
);
}
#[test]
fn security_success_struct_copies_out_of_zeroizing_wrappers() {
let src = handshake_source();
assert!(
src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"),
"HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized"
);
}

View File

@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
];
let mut meaningful_improvement_seen = false;
let mut baseline_sum = 0.0f64;
let mut hardened_sum = 0.0f64;
let mut pair_count = 0usize;
let mut informative_baseline_sum = 0.0f64;
let mut informative_hardened_sum = 0.0f64;
let mut informative_pair_count = 0usize;
let mut low_info_baseline_sum = 0.0f64;
let mut low_info_hardened_sum = 0.0f64;
let mut low_info_pair_count = 0usize;
let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
let tolerated_pair_regression = acc_quant_step + 0.03;
@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
hardened_acc <= baseline_acc + tolerated_pair_regression,
"normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}"
);
informative_baseline_sum += baseline_acc;
informative_hardened_sum += hardened_acc;
informative_pair_count += 1;
} else {
// Low-information pairs (near-random baseline separability) are expected
// to exhibit quantized jitter at low sample counts; do not fold them into
// strict average-regression checks used for informative side-channel signal.
low_info_baseline_sum += baseline_acc;
low_info_hardened_sum += hardened_acc;
low_info_pair_count += 1;
}
println!(
@ -532,19 +545,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
meaningful_improvement_seen = true;
}
baseline_sum += baseline_acc;
hardened_sum += hardened_acc;
pair_count += 1;
}
let baseline_avg = baseline_sum / pair_count as f64;
let hardened_avg = hardened_sum / pair_count as f64;
assert!(
informative_pair_count > 0,
"expected at least one informative pair for timing-separability guard"
);
let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64;
let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64;
assert!(
hardened_avg <= baseline_avg + 0.10,
"normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
informative_hardened_avg <= informative_baseline_avg + 0.10,
"normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}"
);
if low_info_pair_count > 0 {
let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64;
let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64;
assert!(
low_info_hardened_avg <= low_info_baseline_avg + 0.40,
"normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}"
);
}
// Optional signal only: do not require improvement on every run because
// noisy CI schedulers can flatten pairwise differences at low sample counts.
let _ = meaningful_improvement_seen;

View File

@ -0,0 +1,122 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
use tokio::time::{Duration, timeout};
struct EndlessReader {
produced: Arc<AtomicUsize>,
}
impl AsyncRead for EndlessReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let len = buf.remaining().max(1);
let fill = vec![0xAA; len];
buf.put_slice(&fill);
self.produced.fetch_add(len, Ordering::Relaxed);
Poll::Ready(Ok(()))
}
}
#[test]
fn loop_guard_unspecified_bind_uses_interface_inventory() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap();
let interfaces = vec!["192.168.44.10".parse().unwrap()];
assert!(is_mask_target_local_listener_with_interfaces(
"mask.example",
443,
local,
Some(resolved),
&interfaces,
));
}
#[tokio::test]
async fn consume_client_data_stops_after_byte_cap_without_eof() {
let produced = Arc::new(AtomicUsize::new(0));
let reader = EndlessReader {
produced: Arc::clone(&produced),
};
let cap = 10_000usize;
consume_client_data(reader, cap).await;
let total = produced.load(Ordering::Relaxed);
assert!(
total >= cap,
"consume path must read at least up to cap before stopping"
);
assert!(
total <= cap + 8192,
"consume path must stop within one read chunk above cap"
);
}
#[test]
fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 0;
let ttl = masking_beobachten_ttl(&config);
assert_eq!(ttl, std::time::Duration::from_secs(60));
}
#[test]
fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() {
let mut config = ProxyConfig::default();
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 0;
config.censorship.mask_timing_normalization_ceiling_ms = 0;
let budget = mask_outcome_target_budget(&config);
assert_eq!(budget, MASK_TIMEOUT);
}
#[tokio::test]
async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
timeout(Duration::from_millis(120), listener.accept())
.await
.is_ok()
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 2;
let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap();
let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let accepted = accept_task.await.unwrap();
assert!(
!accepted,
"loop guard must fail closed before any recursive PROXY protocol amplification"
);
}

View File

@ -0,0 +1,16 @@
use super::*;
#[test]
fn detect_client_type_recognizes_extended_http_probe_verbs() {
assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP");
assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP");
assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP");
}
#[test]
fn detect_client_type_recognizes_fragmented_http_method_prefixes() {
assert_eq!(detect_client_type(b"CO"), "HTTP");
assert_eq!(detect_client_type(b"CON"), "HTTP");
assert_eq!(detect_client_type(b"TR"), "HTTP");
assert_eq!(detect_client_type(b"PAT"), "HTTP");
}

View File

@ -0,0 +1,127 @@
use super::*;
use crate::network::dns_overrides::install_entries;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
async fn run_connect_failure_case(
host: &str,
port: u16,
timing_normalization_enabled: bool,
peer: SocketAddr,
) -> Duration {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some(host.to_string());
config.censorship.mask_port = port;
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n";
let (mut client_writer, client_reader) = duplex(1024);
let (mut client_visible_reader, client_visible_writer) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
.await
.unwrap()
.unwrap();
assert_eq!(n, 0, "connect-failure path must close client-visible writer");
started.elapsed()
}
#[tokio::test]
async fn connect_failure_refusal_close_behavior_matrix() {
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16)
.parse()
.unwrap();
let elapsed = run_connect_failure_case(
"127.0.0.1",
unused_port,
timing_normalization_enabled,
peer,
)
.await;
if timing_normalization_enabled {
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
"normalized refusal path must honor configured timing envelope without stalling"
);
} else {
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
"non-normalized refusal path must honor baseline connect budget without stalling"
);
}
}
}
#[tokio::test]
async fn connect_failure_overridden_hostname_close_behavior_matrix() {
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
// Make hostname resolution deterministic in tests so timing ceilings are meaningful.
install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap();
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16)
.parse()
.unwrap();
let elapsed = run_connect_failure_case(
"mask.invalid",
unused_port,
timing_normalization_enabled,
peer,
)
.await;
if timing_normalization_enabled {
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
"normalized overridden-host path must honor configured timing envelope without stalling"
);
} else {
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
"non-normalized overridden-host path must honor baseline connect budget without stalling"
);
}
}
install_entries(&[]).unwrap();
}

View File

@ -0,0 +1,85 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, ReadBuf};
struct OneByteThenStall {
sent: bool,
}
impl AsyncRead for OneByteThenStall {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.sent {
self.sent = true;
buf.put_slice(&[0x42]);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[tokio::test]
async fn stalling_client_terminates_at_idle_not_relay_timeout() {
let reader = OneByteThenStall { sent: false };
let started = Instant::now();
let result = tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(reader, MASK_BUFFER_SIZE * 4),
)
.await;
assert!(
result.is_ok(),
"consume_client_data should complete by per-read idle timeout, not hit relay timeout"
);
let elapsed = started.elapsed();
assert!(
elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2),
"consume_client_data returned too quickly for idle-timeout path: {elapsed:?}"
);
assert!(
elapsed < MASK_RELAY_TIMEOUT,
"consume_client_data waited full relay timeout ({elapsed:?}); \
per-read idle timeout is missing"
);
}
#[tokio::test]
async fn fast_reader_drains_to_eof() {
let data = vec![0xAAu8; 32 * 1024];
let reader = std::io::Cursor::new(data);
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX))
.await
.expect("consume_client_data did not complete for fast EOF reader");
}
#[tokio::test]
async fn io_error_terminates_cleanly() {
struct ErrReader;
impl AsyncRead for ErrReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"simulated reset",
)))
}
}
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX))
.await
.expect("consume_client_data did not return on I/O error");
}

View File

@ -0,0 +1,64 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::task::JoinSet;
struct OneByteThenStall {
sent: bool,
}
impl AsyncRead for OneByteThenStall {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.sent {
self.sent = true;
buf.put_slice(&[0xAA]);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[tokio::test]
async fn consume_stall_stress_finishes_within_idle_budget() {
let mut set = JoinSet::new();
let started = Instant::now();
for _ in 0..64 {
set.spawn(async {
tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(OneByteThenStall { sent: false }, usize::MAX),
)
.await
.expect("consume_client_data exceeded relay timeout under stall load");
});
}
while let Some(res) = set.join_next().await {
res.unwrap();
}
// Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling
// for 100ms should complete well under a strict 600ms boundary.
assert!(
started.elapsed() < MASK_RELAY_TIMEOUT * 3,
"stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking"
);
}
#[tokio::test]
async fn consume_zero_cap_returns_immediately() {
let started = Instant::now();
consume_client_data(tokio::io::empty(), 0).await;
assert!(
started.elapsed() < MASK_RELAY_IDLE_TIMEOUT,
"zero byte cap must return immediately"
);
}

View File

@ -0,0 +1,55 @@
use super::*;
use tokio::net::TcpListener;
use tokio::time::Duration;
#[tokio::test]
async fn http2_preface_is_forwarded_and_recorded_as_http() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let preface = preface.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; preface.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, preface);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_reader, _client_writer) = tokio::io::duplex(512);
let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512);
let beobachten = BeobachtenStore::new();
handle_bad_client(
client_reader,
client_visible_writer,
&preface,
peer,
local_addr,
&config,
&beobachten,
)
.await;
tokio::time::timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.130-1"));
}

View File

@ -0,0 +1,92 @@
use super::*;
#[test]
fn full_http2_preface_classified_as_http_probe() {
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
assert!(
is_http_probe(preface),
"HTTP/2 connection preface must be classified as HTTP probe"
);
}
#[test]
fn partial_http2_preface_3_bytes_classified() {
assert!(
is_http_probe(b"PRI"),
"3-byte HTTP/2 preface prefix must be classified"
);
}
#[test]
fn partial_http2_preface_2_bytes_classified() {
assert!(
is_http_probe(b"PR"),
"2-byte HTTP/2 preface prefix must be classified"
);
}
#[test]
fn existing_http1_methods_unaffected() {
for prefix in [
b"GET / HTTP/1.1\r\n".as_ref(),
b"POST /api HTTP/1.1\r\n".as_ref(),
b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(),
b"TRACE / HTTP/1.1\r\n".as_ref(),
b"PATCH / HTTP/1.1\r\n".as_ref(),
] {
assert!(is_http_probe(prefix));
}
}
#[test]
fn non_http_data_not_classified() {
for data in [
b"\x16\x03\x01\x00\xf1".as_ref(),
b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(),
b"\x00\x01\x02\x03".as_ref(),
b"".as_ref(),
b"P".as_ref(),
] {
assert!(!is_http_probe(data));
}
}
#[test]
fn light_fuzz_non_http_prefixes_not_misclassified() {
// Deterministic pseudo-fuzz to exercise classifier edges while avoiding
// known HTTP method and partial windows.
let mut x = 0x1234_5678u32;
for _ in 0..1024 {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
let len = 4 + ((x >> 8) as usize % 12);
let mut data = vec![0u8; len];
for byte in &mut data {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
*byte = (x & 0xFF) as u8;
}
if [
b"GET ".as_ref(),
b"POST".as_ref(),
b"HEAD".as_ref(),
b"PUT ".as_ref(),
b"DELETE".as_ref(),
b"OPTIONS".as_ref(),
b"CONNECT".as_ref(),
b"TRACE".as_ref(),
b"PATCH".as_ref(),
b"PRI ".as_ref(),
]
.iter()
.any(|m| data.starts_with(m))
{
continue;
}
assert!(
!is_http_probe(&data),
"non-http pseudo-fuzz input misclassified: {:?}",
&data[..data.len().min(8)]
);
}
}

View File

@ -0,0 +1,79 @@
use super::*;
#[test]
fn exact_four_byte_http_tokens_are_classified() {
for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] {
assert!(
is_http_probe(token),
"exact 4-byte token must be classified as HTTP probe: {:?}",
token
);
}
}
#[test]
fn exact_four_byte_non_http_tokens_are_not_classified() {
for token in [
b"GEX ".as_ref(),
b"POXT".as_ref(),
b"HEA/".as_ref(),
b"PU\0 ".as_ref(),
b"PRI/".as_ref(),
] {
assert!(
!is_http_probe(token),
"non-HTTP 4-byte token must not be classified: {:?}",
token
);
}
}
#[test]
fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() {
assert_eq!(detect_client_type(b"GET "), "HTTP");
assert_eq!(detect_client_type(b"PRI "), "HTTP");
}
#[test]
fn exact_long_http_tokens_are_classified() {
for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] {
assert!(
is_http_probe(token),
"exact long HTTP token must be classified as HTTP probe: {:?}",
token
);
}
}
#[test]
fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() {
assert_eq!(detect_client_type(b"CONNECT"), "HTTP");
assert_eq!(detect_client_type(b"TRACE"), "HTTP");
assert_eq!(detect_client_type(b"PATCH"), "HTTP");
}
#[test]
fn light_fuzz_four_byte_ascii_noise_not_misclassified() {
// Deterministic pseudo-fuzz over 4-byte printable ASCII inputs.
let mut x = 0xA17C_93E5u32;
for _ in 0..2048 {
let mut token = [0u8; 4];
for byte in &mut token {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
*byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset
}
if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "]
.iter()
.any(|m| token.as_slice() == *m)
{
continue;
}
assert!(
!is_http_probe(&token),
"pseudo-fuzz noise misclassified as HTTP probe: {:?}",
token
);
}
}

View File

@ -0,0 +1,51 @@
#![cfg(unix)]
use super::*;
#[test]
fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() {
let previous = vec![
"192.168.100.7"
.parse::<IpAddr>()
.expect("must parse interface ip"),
];
let refreshed = Vec::new();
let next = choose_interface_snapshot(&previous, refreshed);
assert_eq!(
next, previous,
"empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions"
);
}
#[test]
fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() {
let previous = vec![
"192.168.100.7"
.parse::<IpAddr>()
.expect("must parse interface ip"),
];
let refreshed = vec![
"10.55.0.3"
.parse::<IpAddr>()
.expect("must parse refreshed interface ip"),
];
let next = choose_interface_snapshot(&previous, refreshed.clone());
assert_eq!(next, refreshed);
}
#[test]
fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() {
let previous = Vec::new();
let refreshed = Vec::new();
let next = choose_interface_snapshot(&previous, refreshed);
assert!(
next.is_empty(),
"empty refresh with no previous snapshot should remain empty"
);
}

View File

@ -0,0 +1,46 @@
#![cfg(unix)]
use super::*;
use std::sync::{Mutex, OnceLock};
fn interface_cache_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[test]
fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None);
let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None);
assert_eq!(
local_interface_enumerations_for_tests(),
1,
"interface enumeration must be cached across repeated bad-client checks"
);
}
#[test]
fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let is_local = is_mask_target_local_listener("127.0.0.1", 8443, local_addr, None);
assert!(!is_local, "different port must not be treated as local listener");
assert_eq!(
local_interface_enumerations_for_tests(),
0,
"port mismatch should bypass interface enumeration entirely"
);
}

View File

@ -0,0 +1,178 @@
use super::*;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant};
fn closed_local_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
#[tokio::test]
#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"]
async fn redteam_offline_target_should_drop_idle_client_early() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap();
let beobachten = BeobachtenStore::new();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(150)).await;
let write_res = client_write.write_all(b"probe-should-be-closed").await;
assert!(
write_res.is_err(),
"offline target path still keeps client writable before consume timeout"
);
handler.abort();
}
#[tokio::test]
#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"]
async fn redteam_offline_target_should_not_sleep_to_mask_refusal() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"\x16\x03\x01\x00\x05hello",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let _ = handler.await;
let elapsed = started.elapsed();
assert!(
elapsed < Duration::from_millis(10),
"offline target path still applies coarse masking sleep and is fingerprintable"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"]
async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() {
let mut samples = Vec::new();
for i in 0..12u16 {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let _ = handler.await;
samples.push(started.elapsed());
}
let min = samples.iter().copied().min().unwrap_or_default();
let max = samples.iter().copied().max().unwrap_or_default();
let spread = max.saturating_sub(min);
assert!(
spread <= Duration::from_millis(5),
"offline refusal timing spread too wide for strict red-team envelope: {:?}",
spread
);
}
#[tokio::test]
#[ignore = "manual red-team: host resolver failure should complete without panic"]
async fn redteam_dns_resolution_failure_must_not_panic() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string());
cfg.censorship.mask_port = 443;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap();
let beobachten = BeobachtenStore::new();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handler).await;
assert!(
result.is_ok(),
"dns failure path stalled or panicked instead of terminating"
);
}

View File

@ -0,0 +1,51 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::AsyncWrite;
struct NeverWritable;
impl AsyncWrite for NeverWritable {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() {
let mut writer = NeverWritable;
let started = Instant::now();
maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await;
assert!(
started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30),
"shape padding blocked past timeout budget"
);
}
#[tokio::test]
async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() {
let mut output = Vec::new();
{
let mut writer = tokio::io::BufWriter::new(&mut output);
maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await;
use tokio::io::AsyncWriteExt;
writer.flush().await.unwrap();
}
assert!(output.is_empty());
}

View File

@ -0,0 +1,105 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink};
use tokio::time::{Duration, timeout};
#[tokio::test]
async fn relay_to_mask_enforces_masking_session_byte_cap() {
let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01];
let extra = vec![0xAB; 96 * 1024];
let (client_reader, mut client_writer) = duplex(128 * 1024);
let (mask_read, _mask_read_peer) = duplex(1024);
let (mut mask_observer, mask_write) = duplex(256 * 1024);
let initial_for_task = initial.clone();
let relay = tokio::spawn(async move {
relay_to_mask(
client_reader,
sink(),
mask_read,
mask_write,
&initial_for_task,
false,
512,
4096,
false,
0,
false,
32 * 1024,
)
.await;
});
client_writer.write_all(&extra).await.unwrap();
client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
let mut observed = Vec::new();
timeout(
Duration::from_secs(2),
mask_observer.read_to_end(&mut observed),
)
.await
.unwrap()
.unwrap();
// In this deterministic test, relay must stop exactly at the configured cap.
assert_eq!(
observed.len(),
initial.len() + (32 * 1024),
"masked relay must forward exactly up to the cap (observed={} initial={} cap={})",
observed.len(),
initial.len(),
32 * 1024
);
}
#[tokio::test]
async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() {
let initial = b"GET /half-close HTTP/1.1\r\n".to_vec();
let (client_reader, mut client_writer) = duplex(8 * 1024);
let (mask_read, _mask_read_peer) = duplex(8 * 1024);
let (mut mask_observer, mask_write) = duplex(8 * 1024);
let initial_for_task = initial.clone();
let relay = tokio::spawn(async move {
relay_to_mask(
client_reader,
sink(),
mask_read,
mask_write,
&initial_for_task,
false,
512,
4096,
false,
0,
false,
32 * 1024,
)
.await;
});
client_writer.shutdown().await.unwrap();
let mut observed = Vec::new();
timeout(
Duration::from_millis(80),
mask_observer.read_to_end(&mut observed),
)
.await
.expect("mask backend write side should be half-closed promptly")
.unwrap();
assert_eq!(&observed[..initial.len()], initial.as_slice());
timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
}

View File

@ -0,0 +1,100 @@
use super::*;
use tokio::io::AsyncReadExt;
use tokio::time::{Duration, timeout};
async fn collect_padding(
total_sent: usize,
enabled: bool,
floor: usize,
cap: usize,
above_cap_blur: bool,
blur_max: usize,
aggressive: bool,
) -> Vec<u8> {
let (mut tx, mut rx) = tokio::io::duplex(256 * 1024);
maybe_write_shape_padding(
&mut tx,
total_sent,
enabled,
floor,
cap,
above_cap_blur,
blur_max,
aggressive,
)
.await;
drop(tx);
let mut output = Vec::new();
timeout(Duration::from_secs(1), rx.read_to_end(&mut output))
.await
.expect("reading padded output timed out")
.expect("failed reading padded output");
output
}
#[tokio::test]
async fn padding_output_is_not_all_zero() {
let output = collect_padding(1, true, 256, 4096, false, 0, false).await;
assert!(
output.len() >= 255,
"expected at least 255 padding bytes, got {}",
output.len()
);
let nonzero = output.iter().filter(|&&b| b != 0).count();
// In 255 bytes of uniform randomness, the expected number of zero bytes is ~1.
// A weak nonzero check can miss severe entropy collapse.
assert!(
nonzero >= 240,
"RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}",
nonzero,
output.len(),
);
}
#[tokio::test]
async fn padding_reaches_first_bucket_boundary() {
let output = collect_padding(1, true, 64, 4096, false, 0, false).await;
assert_eq!(output.len(), 63);
}
#[tokio::test]
async fn disabled_padding_produces_no_output() {
let output = collect_padding(0, false, 256, 4096, false, 0, false).await;
assert!(output.is_empty());
}
#[tokio::test]
async fn at_cap_without_blur_produces_no_output() {
let output = collect_padding(4096, true, 64, 4096, false, 0, false).await;
assert!(output.is_empty());
}
#[tokio::test]
async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() {
let output = collect_padding(4096, true, 64, 4096, true, 128, true).await;
assert!(!output.is_empty());
assert!(output.len() <= 128, "blur exceeded max: {}", output.len());
}
#[tokio::test]
async fn stress_padding_runs_are_not_constant_pattern() {
// Stress and sanity-check: repeated runs should not collapse to identical
// first 16 bytes across all samples.
let mut first_chunks = Vec::new();
for _ in 0..64 {
let out = collect_padding(1, true, 64, 4096, false, 0, false).await;
first_chunks.push(out[..16].to_vec());
}
let first = &first_chunks[0];
let all_same = first_chunks.iter().all(|chunk| chunk == first);
assert!(
!all_same,
"all stress samples had identical prefix, rng output appears degenerate"
);
}

View File

@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall
false,
0,
false,
5 * 1024 * 1024,
)
.await;
});
@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
false,
0,
false,
5 * 1024 * 1024,
),
)
.await;

View File

@ -0,0 +1,354 @@
use super::*;
use std::net::TcpListener as StdTcpListener;
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant, timeout};
fn closed_local_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
#[test]
fn self_target_detection_matches_literal_ipv4_listener() {
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
assert!(is_mask_target_local_listener(
"198.51.100.40",
443,
local,
None,
));
}
#[test]
fn self_target_detection_matches_bracketed_ipv6_listener() {
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
assert!(is_mask_target_local_listener(
"[2001:db8::44]",
8443,
local,
None,
));
}
#[test]
fn self_target_detection_keeps_same_ip_different_port_forwardable() {
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener(
"203.0.113.44",
8443,
local,
None,
));
}
#[test]
fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
assert!(is_mask_target_local_listener(
"::ffff:127.0.0.1",
443,
local,
None,
));
}
#[test]
fn self_target_detection_unspecified_bind_blocks_loopback_target() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
assert!(is_mask_target_local_listener(
"127.0.0.1",
443,
local,
None,
));
}
#[test]
fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener(
"mask.example",
443,
local,
Some(remote),
));
}
#[tokio::test]
async fn self_target_fallback_refuses_recursive_loopback_connect() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
timeout(Duration::from_millis(120), listener.accept())
.await
.is_ok()
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some(local_addr.ip().to_string());
config.censorship.mask_port = local_addr.port();
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
b"GET /",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let accepted = accept_task.await.unwrap();
assert!(
!accepted,
"self-target masking must fail closed without connecting to local listener"
);
}
#[tokio::test]
async fn same_ip_different_port_still_forwards_to_mask_backend() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /".to_vec();
let accept_task = tokio::spawn({
let expected = probe.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
}
#[test]
fn detect_client_type_http_boundary_get_and_post() {
assert_eq!(detect_client_type(b"GET "), "HTTP");
assert_eq!(detect_client_type(b"GET /"), "HTTP");
assert_eq!(detect_client_type(b"POST"), "HTTP");
assert_eq!(detect_client_type(b"POST "), "HTTP");
assert_eq!(detect_client_type(b"POSTX"), "HTTP");
}
#[test]
fn detect_client_type_tls_and_length_boundaries() {
assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner");
assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner");
assert_eq!(detect_client_type(b"123456789"), "port-scanner");
assert_eq!(detect_client_type(b"1234567890"), "unknown");
}
#[test]
fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() {
let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap();
let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
let header = build_mask_proxy_header(1, peer, local).unwrap();
assert_eq!(header, b"PROXY UNKNOWN\r\n");
}
#[test]
fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() {
let floor = usize::MAX / 2 + 1;
let cap = usize::MAX;
let total = floor + 1;
assert_eq!(next_mask_shape_bucket(total, floor, cap), total);
}
#[tokio::test]
async fn self_target_reject_path_keeps_timing_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (client, server) = duplex(1024);
drop(client);
let started = Instant::now();
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250),
"self-target reject path must keep coarse timing budget without stalling"
);
}
#[tokio::test]
async fn relay_path_idle_timeout_eviction_remains_effective() {
let (client_read, mut client_write) = duplex(1024);
let (mask_read, mask_write) = duplex(1024);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
client_write.write_all(b"a").await.unwrap();
tokio::time::sleep(Duration::from_millis(180)).await;
let _ = client_write.write_all(b"b").await;
});
let started = Instant::now();
relay_to_mask(
client_read,
tokio::io::sink(),
mask_read,
mask_write,
b"init",
false,
0,
0,
false,
0,
false,
5 * 1024 * 1024,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180),
"idle-timeout eviction must occur before late trickle write"
);
}
#[tokio::test]
async fn offline_mask_target_refusal_respects_timing_normalization_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = closed_local_port();
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client.shutdown().await.unwrap();
timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220),
"offline-refusal path must honor normalization budget without unbounded drift"
);
}
#[tokio::test]
async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = closed_local_port();
config.censorship.mask_timing_normalization_enabled = false;
let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(120)).await;
client
.write_all(b"still-open-before-timeout")
.await
.expect("connection should still be open before consume timeout expires");
timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350),
"offline-refusal path must not retain idle client indefinitely"
);
}

View File

@ -43,6 +43,7 @@ async fn run_relay_case(
above_cap_blur,
above_cap_blur_max_bytes,
false,
5 * 1024 * 1024,
)
.await;
});

View File

@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() {
false,
0,
false,
5 * 1024 * 1024,
)
.await;
});

View File

@ -9,6 +9,7 @@ use tokio::time::{Duration, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
let _guard = super::quota_user_lock_test_scope();
let _pressure_guard = super::relay_idle_pressure_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();

View File

@ -0,0 +1,229 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll, Waker};
use tokio::io::AsyncWrite;
use tokio::time::{Duration, timeout};
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
#[derive(Default)]
struct GateState {
open: AtomicBool,
parked_waker: std::sync::Mutex<Option<Waker>>,
}
impl GateState {
fn open(&self) {
self.open.store(true, Ordering::Relaxed);
if let Ok(mut guard) = self.parked_waker.lock()
&& let Some(w) = guard.take()
{
w.wake();
}
}
fn has_waiter(&self) -> bool {
self.parked_waker
.lock()
.map(|guard| guard.is_some())
.unwrap_or(false)
}
}
#[derive(Default)]
struct GateWriter {
gate: Arc<GateState>,
}
impl GateWriter {
fn new(gate: Arc<GateState>) -> Self {
Self { gate }
}
}
impl AsyncWrite for GateWriter {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if self.gate.open.load(Ordering::Relaxed) {
return Poll::Ready(Ok(buf.len()));
}
if let Ok(mut guard) = self.gate.parked_waker.lock() {
*guard = Some(cx.waker().clone());
}
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
struct FailingWriter;
impl AsyncWrite for FailingWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"injected writer failure",
)))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() {
let stats = Stats::new();
let bytes_me2c = AtomicU64::new(0);
let rng = SecureRandom::new();
let quota_limit = Some(1024);
let user = "hol-quota-user";
let gate = Arc::new(GateState::default());
let mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate)));
let slow_task = tokio::spawn(async move {
let mut frame_buf = Vec::new();
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]),
},
&mut blocked_writer,
ProtoTag::Intermediate,
&rng,
&mut frame_buf,
&stats,
user,
quota_limit,
&bytes_me2c,
7001,
false,
false,
)
.await
});
timeout(Duration::from_millis(100), async {
loop {
if gate.has_waiter() {
break;
}
tokio::task::yield_now().await;
}
})
.await
.expect("first writer must reach backpressure and park");
let stats_fast = Stats::new();
let bytes_fast = AtomicU64::new(0);
let rng_fast = SecureRandom::new();
let mut fast_writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf_fast = Vec::new();
timeout(
Duration::from_millis(50),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x41]),
},
&mut fast_writer,
ProtoTag::Intermediate,
&rng_fast,
&mut frame_buf_fast,
&stats_fast,
user,
quota_limit,
&bytes_fast,
7002,
false,
false,
),
)
.await
.expect("peer connection must not be blocked by same-user stalled write")
.expect("fast peer write must succeed");
gate.open();
let slow_result = timeout(Duration::from_secs(1), slow_task)
.await
.expect("stalled task must complete once gate opens")
.expect("stalled task must not panic");
assert!(slow_result.is_ok());
}
#[tokio::test]
async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() {
let stats = Stats::new();
let user = "rollback-user";
stats.add_user_octets_from(user, 7);
let bytes_me2c = AtomicU64::new(0);
let rng = SecureRandom::new();
let mut writer = make_crypto_writer(FailingWriter);
let mut frame_buf = Vec::new();
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[1, 2, 3, 4]),
},
&mut writer,
ProtoTag::Intermediate,
&rng,
&mut frame_buf,
&stats,
user,
Some(64),
&bytes_me2c,
7003,
false,
false,
)
.await;
assert!(matches!(result, Err(ProxyError::Io(_))));
assert_eq!(
stats.get_user_total_octets(user),
7,
"failed client write must not overcharge user quota accounting"
);
assert_eq!(
bytes_me2c.load(Ordering::Relaxed),
0,
"failed client write must not inflate ME->C forensic byte counter"
);
}

View File

@ -3,7 +3,7 @@ use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl
}
}
fn idle_pressure_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> {
match idle_pressure_test_lock().lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
#[tokio::test]
async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() {
let (reader, _writer) = duplex(1024);
@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() {
#[test]
fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
#[test]
fn pressure_does_not_evict_without_new_pressure_signal() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() {
#[test]
fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
#[test]
fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
#[test]
fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
#[test]
fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
#[test]
fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
#[test]
fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated(
#[test]
fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
#[test]
fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
{
@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting(
#[test]
fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
{
@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims()
{
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Arc::new(Stats::new());
@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Arc::new(Stats::new());

View File

@ -0,0 +1,59 @@
use super::*;
use std::panic::{AssertUnwindSafe, catch_unwind};
#[test]
fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() {
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let _ = catch_unwind(AssertUnwindSafe(|| {
let registry = relay_idle_candidate_registry();
let mut guard = registry
.lock()
.expect("registry lock must be acquired before poison");
guard.by_conn_id.insert(
999,
RelayIdleCandidateMeta {
mark_order_seq: 1,
mark_pressure_seq: 0,
},
);
guard.ordered.insert((1, 999));
panic!("intentional poison for idle-registry recovery");
}));
// Helper lock must recover from poison, reset stale state, and continue.
assert!(mark_relay_idle_candidate(42));
assert_eq!(oldest_relay_idle_candidate(), Some(42));
let before = relay_pressure_event_seq();
note_relay_pressure_event();
let after = relay_pressure_event_seq();
assert!(after > before, "pressure accounting must still advance after poison");
clear_relay_idle_pressure_state_for_testing();
}
#[test]
fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() {
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let _ = catch_unwind(AssertUnwindSafe(|| {
let registry = relay_idle_candidate_registry();
let _guard = registry
.lock()
.expect("registry lock must be acquired before poison");
panic!("intentional poison while lock held");
}));
clear_relay_idle_pressure_state_for_testing();
assert_eq!(oldest_relay_idle_candidate(), None);
assert_eq!(relay_pressure_event_seq(), 0);
assert!(mark_relay_idle_candidate(7));
assert_eq!(oldest_relay_idle_candidate(), Some(7));
clear_relay_idle_pressure_state_for_testing();
}

View File

@ -0,0 +1,192 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::task::JoinSet;
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
#[tokio::test]
async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() {
let stats = Stats::new();
let user = "quota-boundary-user";
let bytes_me2c = AtomicU64::new(0);
stats.add_user_octets_from(user, 5);
let mut writer_one = make_crypto_writer(tokio::io::sink());
let mut frame_buf_one = Vec::new();
let first = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[1, 2, 3]),
},
&mut writer_one,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf_one,
&stats,
user,
Some(8),
&bytes_me2c,
7101,
false,
false,
)
.await;
assert!(first.is_ok(), "frame that reaches boundary must be allowed");
assert_eq!(stats.get_user_total_octets(user), 8);
let mut writer_two = make_crypto_writer(tokio::io::sink());
let mut frame_buf_two = Vec::new();
let second = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[9]),
},
&mut writer_two,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf_two,
&stats,
user,
Some(8),
&bytes_me2c,
7102,
false,
false,
)
.await;
assert!(
matches!(second, Err(ProxyError::DataQuotaExceeded { .. })),
"frame after boundary must be rejected"
);
assert_eq!(stats.get_user_total_octets(user), 8);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_counters() {
let stats = Arc::new(Stats::new());
let user = "reservation-stress-user";
let quota_limit = 64u64;
let bytes_me2c = Arc::new(AtomicU64::new(0));
let mut tasks = JoinSet::new();
for idx in 0..256u64 {
let user_owned = user.to_string();
let stats_ref = Arc::clone(&stats);
let bytes_ref = Arc::clone(&bytes_me2c);
tasks.spawn(async move {
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAB]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats_ref.as_ref(),
&user_owned,
Some(quota_limit),
bytes_ref.as_ref(),
7200 + idx,
false,
false,
)
.await
});
}
let mut ok = 0usize;
let mut denied = 0usize;
while let Some(joined) = tasks.join_next().await {
match joined.expect("reservation stress task must not panic") {
Ok(_) => ok += 1,
Err(ProxyError::DataQuotaExceeded { .. }) => denied += 1,
Err(other) => panic!("unexpected error in stress case: {other:?}"),
}
}
let total = stats.get_user_total_octets(user);
assert_eq!(
total, quota_limit,
"quota must be exactly exhausted without overshoot"
);
assert_eq!(
bytes_me2c.load(Ordering::Relaxed),
total,
"ME->C forensic bytes must track committed quota usage"
);
assert_eq!(ok, quota_limit as usize, "exactly quota_limit tasks must succeed");
assert_eq!(
denied,
256usize - (quota_limit as usize),
"remaining tasks must be exactly denied without silently swallowing state"
);
}
#[tokio::test]
async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() {
let stats = Stats::new();
let user = "reservation-fuzz-user";
let quota_limit = 128u64;
let bytes_me2c = AtomicU64::new(0);
let mut seed = 0xC0FE_EE11_8899_2211u64;
for conn in 0..512u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = ((seed & 0x0f) + 1) as usize;
let payload = vec![0x5A; len];
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from(payload),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
user,
Some(quota_limit),
&bytes_me2c,
7300 + conn,
false,
false,
)
.await;
if let Err(err) = result {
assert!(
matches!(err, ProxyError::DataQuotaExceeded { .. }),
"fuzz run produced unexpected error variant: {err:?}"
);
}
}
let total = stats.get_user_total_octets(user);
assert!(total <= quota_limit);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
}

View File

@ -0,0 +1,365 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::task::JoinSet;
use tokio::time::{Duration as TokioDuration, sleep, timeout};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB200_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-concurrency-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_secs(30),
hard_idle: Duration::from_secs(60),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_secs(30),
}
}
async fn read_once(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto: ProtoTag,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
idle_state: &mut RelayClientIdleState,
) -> Result<Option<(PooledBuffer, bool)>> {
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
read_client_payload_with_idle_policy(
crypto_reader,
proto,
1024,
&buffer_pool,
forensics,
frame_counter,
&stats,
&idle_policy,
idle_state,
&last_downstream_activity_ms,
forensics.started_at,
)
.await
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
let mut set = JoinSet::new();
for idx in 0..32u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(1000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let flood_plaintext = vec![0u8; 1024];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = timeout(
TokioDuration::from_secs(1),
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("tiny flood task must complete");
assert!(matches!(result, Err(ProxyError::Proxy(_))));
assert_eq!(frame_counter, 0);
});
}
while let Some(result) = set.join_next().await {
result.expect("parallel tiny flood worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_benign_tiny_burst_then_real_all_pass() {
let mut set = JoinSet::new();
for idx in 0..24u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(2048);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(2000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [idx as u8, 2, 3, 4];
let mut plaintext = Vec::with_capacity(20);
for _ in 0..6 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let result = timeout(
TokioDuration::from_secs(1),
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("benign task must complete")
.expect("benign payload must parse")
.expect("benign payload must return frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
});
}
while let Some(result) = set.join_next().await {
result.expect("parallel benign worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {
let mut set = JoinSet::new();
for idx in 0..12u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(3000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(2000);
for n in 0..180u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
for chunk in encrypted.chunks(17) {
writer.write_all(chunk).await.unwrap();
sleep(TokioDuration::from_millis(1)).await;
}
drop(writer);
});
let mut closed = false;
for _ in 0..220 {
let result = timeout(
TokioDuration::from_secs(1),
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("alternating reader step must complete");
match result {
Ok(Some((_payload, _))) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected error in alternating jitter case: {other}"),
}
}
writer_task.await.expect("writer jitter task must not panic");
assert!(closed, "alternating attack must close before EOF");
});
}
while let Some(result) = set.join_next().await {
result.expect("alternating jitter worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_mixed_population_attackers_close_benign_survive() {
let mut set = JoinSet::new();
for idx in 0..20u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(4000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
if idx % 2 == 0 {
let mut plaintext = Vec::with_capacity(1280);
for n in 0..140u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[n, n, n, n]);
}
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..200 {
match read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
{
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected attacker error: {other}"),
}
}
assert!(closed, "attacker session must fail closed");
} else {
let payload = [1u8, 9, 8, 7];
let mut plaintext = Vec::new();
for _ in 0..4 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
let got = read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("benign session must parse")
.expect("benign session must return a frame");
assert_eq!(got.0.as_ref(), &payload);
}
});
}
while let Some(result) = set.join_next().await {
result.expect("mixed-population worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_parallel_patterns_no_hang_or_panic() {
let mut set = JoinSet::new();
for case in 0..40u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(5000 + case, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut seed = 0x9E37_79B9u64 ^ (case << 8);
let mut plaintext = Vec::with_capacity(2048);
for _ in 0..256 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let is_tiny = (seed & 1) == 0;
if is_tiny {
plaintext.push(0x00);
} else {
plaintext.push(0x01);
plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]);
}
}
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
drop(writer);
for _ in 0..320 {
let step = timeout(
TokioDuration::from_secs(1),
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("fuzz case read step must complete");
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => break,
Ok(None) => break,
Err(other) => panic!("unexpected fuzz case error: {other}"),
}
}
});
}
while let Some(result) = set.join_next().await {
result.expect("fuzz worker must not panic");
}
}

View File

@ -0,0 +1,418 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::time::{Duration as TokioDuration, sleep, timeout};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB300_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_secs(30),
hard_idle: Duration::from_secs(60),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_secs(30),
}
}
fn append_tiny_frame(plaintext: &mut Vec<u8>, proto: ProtoTag) {
match proto {
ProtoTag::Abridged => plaintext.push(0x00),
ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()),
}
}
fn append_real_frame(plaintext: &mut Vec<u8>, proto: ProtoTag, payload: [u8; 4]) {
match proto {
ProtoTag::Abridged => {
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
}
ProtoTag::Intermediate | ProtoTag::Secure => {
plaintext.extend_from_slice(&4u32.to_le_bytes());
plaintext.extend_from_slice(&payload);
}
}
}
async fn write_chunked_with_jitter(
writer: &mut tokio::io::DuplexStream,
bytes: &[u8],
mut seed: u64,
) {
let mut offset = 0usize;
while offset < bytes.len() {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let chunk_len = 1 + ((seed as usize) & 0x1f);
let end = (offset + chunk_len).min(bytes.len());
writer.write_all(&bytes[offset..end]).await.unwrap();
let delay_ms = ((seed >> 16) % 3) as u64;
if delay_ms > 0 {
sleep(TokioDuration::from_millis(delay_ms)).await;
}
offset = end;
}
}
async fn read_once_with_state(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto: ProtoTag,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
idle_state: &mut RelayClientIdleState,
) -> Result<Option<(PooledBuffer, bool)>> {
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
read_client_payload_with_idle_policy(
crypto_reader,
proto,
1024,
&buffer_pool,
forensics,
frame_counter,
&stats,
&idle_policy,
idle_state,
&last_downstream_activity_ms,
forensics.started_at,
)
.await
}
#[tokio::test]
async fn intermediate_chunked_zero_flood_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6101, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
}
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await;
drop(writer);
let result = timeout(
TokioDuration::from_secs(2),
read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("intermediate flood read must complete");
assert!(matches!(result, Err(ProxyError::Proxy(_))));
assert_eq!(frame_counter, 0);
}
#[tokio::test]
async fn secure_chunked_zero_flood_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6102, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
}
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await;
drop(writer);
let result = timeout(
TokioDuration::from_secs(2),
read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("secure flood read must complete");
assert!(matches!(result, Err(ProxyError::Proxy(_))));
assert_eq!(frame_counter, 0);
}
#[tokio::test]
async fn intermediate_chunked_alternating_attack_closes_before_eof() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6103, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(8 * 200);
for n in 0..180u8 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await;
drop(writer);
});
let mut closed = false;
for _ in 0..240 {
let step = timeout(
TokioDuration::from_secs(1),
read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("intermediate alternating read step must complete");
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected intermediate alternating error: {other}"),
}
}
writer_task.await.expect("intermediate writer task must not panic");
assert!(closed, "intermediate alternating attack must fail closed");
}
#[tokio::test]
async fn secure_chunked_alternating_attack_closes_before_eof() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6104, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(8 * 200);
for n in 0..180u8 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await;
drop(writer);
});
let mut closed = false;
for _ in 0..240 {
let step = timeout(
TokioDuration::from_secs(1),
read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("secure alternating read step must complete");
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected secure alternating error: {other}"),
}
}
writer_task.await.expect("secure writer task must not panic");
assert!(closed, "secure alternating attack must fail closed");
}
#[tokio::test]
async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6105, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [9u8, 8, 7, 6];
let mut plaintext = Vec::new();
for _ in 0..7 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
}
append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload);
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await;
let result = read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("intermediate safe burst should parse")
.expect("intermediate safe burst should return a frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn secure_chunked_safe_small_burst_still_returns_real_frame() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6106, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [3u8, 1, 4, 1];
let mut plaintext = Vec::new();
for _ in 0..7 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
}
append_real_frame(&mut plaintext, ProtoTag::Secure, payload);
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await;
let result = read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("secure safe burst should parse")
.expect("secure safe burst should return a frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn light_fuzz_proto_chunking_outcomes_are_bounded() {
let mut seed = 0xDEAD_BEEF_2026_0322u64;
for case in 0..48u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let proto = if (seed & 1) == 0 {
ProtoTag::Intermediate
} else {
ProtoTag::Secure
};
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6200 + case, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut stream = Vec::new();
let mut local_seed = seed ^ case;
for _ in 0..220 {
local_seed ^= local_seed << 7;
local_seed ^= local_seed >> 9;
local_seed ^= local_seed << 8;
if (local_seed & 1) == 0 {
append_tiny_frame(&mut stream, proto);
} else {
let b = (local_seed >> 8) as u8;
append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]);
}
}
let encrypted = encrypt_for_reader(&stream);
write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await;
drop(writer);
for _ in 0..260 {
let step = timeout(
TokioDuration::from_secs(1),
read_once_with_state(
&mut crypto_reader,
proto,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("fuzz proto read step must complete");
match step {
Ok(Some((_payload, _))) => {}
Err(ProxyError::Proxy(_)) => break,
Ok(None) => break,
Err(other) => panic!("unexpected proto chunking fuzz error: {other}"),
}
}
}
}

View File

@ -0,0 +1,550 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB100_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_secs(30),
hard_idle: Duration::from_secs(60),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_secs(30),
}
}
fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option<usize>, u32, usize) {
let mut debt = 0u32;
let mut reals = 0usize;
for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() {
if is_tiny {
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
if debt >= TINY_FRAME_DEBT_LIMIT {
return (Some(idx + 1), debt, reals);
}
} else {
reals = reals.saturating_add(1);
debt = debt.saturating_sub(1);
}
}
(None, debt, reals)
}
#[test]
fn tiny_frame_debt_constants_match_security_budget_expectations() {
assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8);
assert_eq!(TINY_FRAME_DEBT_LIMIT, 512);
}
#[test]
fn relay_client_idle_state_initial_debt_is_zero() {
let state = RelayClientIdleState::new(Instant::now());
assert_eq!(state.tiny_frame_debt, 0);
}
#[test]
fn on_client_frame_does_not_reset_tiny_frame_debt() {
let now = Instant::now();
let mut state = RelayClientIdleState::new(now);
state.tiny_frame_debt = 77;
state.on_client_frame(now);
assert_eq!(state.tiny_frame_debt, 77);
}
#[test]
fn tiny_frame_debt_increment_is_saturating() {
let mut debt = u32::MAX - 1;
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
assert_eq!(debt, u32::MAX);
}
#[test]
fn tiny_frame_debt_decrement_is_saturating() {
let mut debt = 0u32;
debt = debt.saturating_sub(1);
assert_eq!(debt, 0);
}
#[test]
fn consecutive_tiny_frames_close_exactly_at_threshold() {
let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize;
let pattern = vec![true; max_tiny_without_close];
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, Some(max_tiny_without_close));
}
#[test]
fn one_less_than_threshold_tiny_frames_do_not_close() {
let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1;
let pattern = vec![true; tiny_count];
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert!(debt < TINY_FRAME_DEBT_LIMIT);
}
#[test]
fn alternating_one_to_one_closes_with_bounded_real_frame_count() {
let mut pattern = Vec::with_capacity(512);
for _ in 0..256 {
pattern.push(true);
pattern.push(false);
}
let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(closed_at.is_some());
assert!(reals <= 80, "expected bounded real frames before close, got {reals}");
}
#[test]
fn alternating_one_to_eight_is_stable_for_long_runs() {
let mut pattern = Vec::with_capacity(9 * 5000);
for _ in 0..5000 {
pattern.push(true);
for _ in 0..8 {
pattern.push(false);
}
}
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert!(debt <= TINY_FRAME_DEBT_PER_TINY);
}
#[test]
fn alternating_one_to_seven_eventually_closes() {
let mut pattern = Vec::with_capacity(8 * 2000);
for _ in 0..2000 {
pattern.push(true);
for _ in 0..7 {
pattern.push(false);
}
}
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close");
}
#[test]
fn two_tiny_one_real_closes_faster_than_one_to_one() {
let mut one_to_one = Vec::with_capacity(512);
for _ in 0..256 {
one_to_one.push(true);
one_to_one.push(false);
}
let mut two_to_one = Vec::with_capacity(768);
for _ in 0..256 {
two_to_one.push(true);
two_to_one.push(true);
two_to_one.push(false);
}
let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len());
let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len());
assert!(a_close.is_some() && b_close.is_some());
assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0));
}
#[test]
fn burst_then_drain_can_recover_without_close() {
let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize;
let mut pattern = Vec::with_capacity(burst_tiny + 600);
for _ in 0..burst_tiny {
pattern.push(true);
}
pattern.extend(std::iter::repeat_n(false, 600));
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert_eq!(debt, 0);
}
#[test]
fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() {
let mut seed = 0xA5A5_91C3_2026_0322u64;
for _case in 0..128 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = 512 + ((seed as usize) & 0x3ff);
let mut pattern = Vec::with_capacity(len);
let mut local_seed = seed;
for _ in 0..len {
local_seed ^= local_seed << 7;
local_seed ^= local_seed >> 9;
local_seed ^= local_seed << 8;
pattern.push((local_seed & 1) == 0);
}
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
if closed_at.is_none() {
assert!(debt < TINY_FRAME_DEBT_LIMIT);
}
assert!(debt <= u32::MAX);
}
}
#[test]
fn stress_many_independent_simulations_keep_isolated_debt_state() {
for idx in 0..2048usize {
let mut pattern = Vec::with_capacity(64);
for j in 0..64usize {
pattern.push(((idx ^ j) & 3) == 0);
}
let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY));
}
}
#[tokio::test]
async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(11, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 4 * 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_client_payload_with_idle_policy(
&mut crypto_reader,
ProtoTag::Intermediate,
1024,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
}
#[tokio::test]
async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(12, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 4 * 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_client_payload_with_idle_policy(
&mut crypto_reader,
ProtoTag::Secure,
1024,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
}
#[tokio::test]
async fn intermediate_alternating_zero_and_real_eventually_closes() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(13, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(3000);
for idx in 0..160u8 {
plaintext.extend_from_slice(&0u32.to_le_bytes());
plaintext.extend_from_slice(&4u32.to_le_bytes());
plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]);
}
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..220 {
let result = read_client_payload_with_idle_policy(
&mut crypto_reader,
ProtoTag::Intermediate,
1024,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match result {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected error while probing alternating close: {other}"),
}
}
assert!(closed, "intermediate alternating attack must fail closed");
}
#[tokio::test]
async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(14, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(64);
for _ in 0..8 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&[1, 2, 3, 4]);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let first = read_client_payload_with_idle_policy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match first {
Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]),
Err(e) => panic!("unexpected close after small tiny burst: {e}"),
Ok(None) => panic!("unexpected EOF before real frame"),
}
}
#[tokio::test]
async fn idle_policy_enabled_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(1, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 1024];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer
.write_all(&flood_encrypted)
.await
.expect("zero-length flood bytes must be writable");
drop(writer);
let result = read_client_payload_with_idle_policy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"idle policy enabled must fail closed for pure zero-length flood"
);
}
#[tokio::test]
async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(2, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(256 * 6);
for idx in 0..=255u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]);
}
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("alternating flood bytes must be writable");
drop(writer);
let mut saw_proxy_close = false;
for _ in 0..300 {
let result = read_client_payload_with_idle_policy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match result {
Ok(Some((_payload, _quickack))) => {}
Err(ProxyError::Proxy(_)) => {
saw_proxy_close = true;
break;
}
Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"),
Ok(None) => panic!("unexpected EOF before debt-based closure"),
Err(other) => panic!("unexpected error before close: {other}"),
}
}
assert!(
saw_proxy_close,
"alternating tiny/real sequence must eventually fail closed"
);
}
#[tokio::test]
async fn enabled_idle_policy_valid_nonzero_frame_still_passes() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(3, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let payload = [7u8, 8, 9, 10];
let mut plaintext = Vec::with_capacity(1 + payload.len());
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("nonzero frame must be writable");
let result = read_client_payload_with_idle_policy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
.expect("valid frame should decode")
.expect("valid frame should return payload");
assert_eq!(result.0.as_ref(), &payload);
assert!(!result.1);
assert_eq!(frame_counter, 1);
}

View File

@ -0,0 +1,121 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use std::time::Instant;
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB000_0000 + conn_id,
conn_id,
user: format!("zero-len-test-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
#[tokio::test]
async fn adversarial_legacy_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(1, session_started_at);
let mut frame_counter = 0u64;
let flood_plaintext = vec![0u8; 128];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer
.write_all(&flood_encrypted)
.await
.expect("zero-length flood bytes must be writable");
drop(writer);
let result = read_client_payload_legacy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
Duration::from_millis(30),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await;
match result {
Err(ProxyError::Proxy(msg)) => {
assert!(
msg.contains("Excessive zero-length"),
"legacy mode must close flood with explicit zero-length reason, got: {msg}"
);
}
Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"),
Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"),
Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"),
}
}
#[tokio::test]
async fn business_abridged_nonzero_frame_still_passes() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(2, session_started_at);
let mut frame_counter = 0u64;
let payload = [1u8, 2, 3, 4];
let mut plaintext = Vec::with_capacity(1 + payload.len());
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("nonzero abridged frame must be writable");
let result = read_client_payload_legacy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
Duration::from_millis(30),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("valid abridged frame should decode")
.expect("valid abridged frame should return payload");
assert_eq!(result.0.as_ref(), &payload);
assert!(!result.1, "quickack flag must remain false");
assert_eq!(frame_counter, 1);
}

View File

@ -0,0 +1,108 @@
use super::*;
use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK
.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn same_user_returns_same_lock_identity() {
let _guard = cross_mode_lock_test_guard();
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
locks.clear();
let a = cross_mode_quota_user_lock("cross-mode-same-user");
let b = cross_mode_quota_user_lock("cross-mode-same-user");
assert!(
Arc::ptr_eq(&a, &b),
"same user must reuse a stable lock identity"
);
}
#[test]
fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() {
let _guard = cross_mode_lock_test_guard();
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
locks.clear();
let prefix = format!("cross-mode-saturated-{}", std::process::id());
let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX);
for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX {
retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(
locks.len(),
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
"lock cache must be saturated for overflow check"
);
let overflow_user = format!("cross-mode-overflow-{}", std::process::id());
let overflow_a = cross_mode_quota_user_lock(&overflow_user);
let overflow_b = cross_mode_quota_user_lock(&overflow_user);
assert_eq!(
locks.len(),
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
"overflow path must not grow bounded lock cache"
);
assert!(
locks.get(&overflow_user).is_none(),
"overflow user must stay on striped fallback while cache is saturated"
);
assert!(
Arc::ptr_eq(&overflow_a, &overflow_b),
"overflow user must receive a stable striped lock across repeated lookups"
);
drop(retained);
}
#[test]
fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() {
let _guard = cross_mode_lock_test_guard();
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
locks.clear();
let prefix = format!("cross-mode-reclaim-{}", std::process::id());
let protected_user = format!("{prefix}-protected");
let protected_lock = cross_mode_quota_user_lock(&protected_user);
let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1));
for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) {
retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(
locks.len(),
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
"fixture must saturate lock cache before reclaim path is exercised"
);
drop(retained);
let newcomer_user = format!("{prefix}-newcomer");
let _newcomer = cross_mode_quota_user_lock(&newcomer_user);
assert!(
locks.get(&protected_user).is_some(),
"active protected user must remain cache-resident after reclaim"
);
let locked = locks
.get(&protected_user)
.expect("protected user must remain in map after reclaim");
assert!(
Arc::ptr_eq(locked.value(), &protected_lock),
"reclaim must not swap active user lock identity"
);
assert!(
locks.get(&newcomer_user).is_some(),
"newcomer should become cacheable after stale entries are reclaimed"
);
}

View File

@ -0,0 +1,225 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::time::{Duration, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
#[tokio::test]
async fn positive_cross_mode_uncontended_writer_progresses() {
let _guard = quota_test_guard();
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
"cross-mode-tdd-uncontended".to_string(),
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let result = io.write_all(&[0x11, 0x22]).await;
assert!(result.is_ok(), "uncontended writer must progress");
}
#[tokio::test]
async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() {
let _guard = quota_test_guard();
let user = format!("cross-mode-tdd-held-{}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before polling writer");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
assert!(poll.is_pending(), "writer must not bypass held cross-mode lock");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_parallel_waiters_resume_after_cross_mode_release() {
let _guard = quota_test_guard();
let user = format!("cross-mode-tdd-resume-{}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before launching waiters");
let stats = Arc::new(Stats::new());
let mut waiters = Vec::new();
for _ in 0..16 {
let stats = Arc::clone(&stats);
let user = user.clone();
waiters.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
stats,
user,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x7F]).await
}));
}
tokio::time::sleep(Duration::from_millis(5)).await;
drop(held_guard);
timeout(Duration::from_secs(1), async {
for waiter in waiters {
let result = waiter.await.expect("waiter task must not panic");
assert!(result.is_ok(), "waiter must complete after cross-mode release");
}
})
.await
.expect("all waiters must complete in bounded time");
}
#[tokio::test]
async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() {
let _guard = quota_test_guard();
let user = format!("cross-mode-tdd-wakes-{}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before polling");
let stats = Arc::new(Stats::new());
let mut ios = Vec::new();
let mut counters = Vec::new();
for _ in 0..20 {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(io).poll_write(&mut cx, &[0x33]);
assert!(poll.is_pending());
counters.push(wake_counter);
}
tokio::time::sleep(Duration::from_millis(25)).await;
let total_wakes: usize = counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= 20 * 4,
"cross-mode contention should not create wake storms; wakes={total_wakes}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() {
let _guard = quota_test_guard();
let mut seed = 0xC0DE_BAAD_2026_0322u64;
for round in 0..16u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let sleep_ms = 2 + (seed as u64 % 8);
let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold cross-mode lock in fuzz round");
let stats = Arc::new(Stats::new());
let user_reader = user.clone();
let reader_task = tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user_reader,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let mut one = [0u8; 1];
io.read(&mut one).await
});
let user_writer = user.clone();
let writer_task = tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user_writer,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x44]).await
});
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
drop(held_guard);
let read_done = timeout(Duration::from_millis(350), reader_task)
.await
.expect("reader task must complete after release")
.expect("reader task must not panic");
assert!(read_done.is_ok());
let write_done = timeout(Duration::from_millis(350), writer_task)
.await
.expect("writer task must complete after release")
.expect("writer task must not panic");
assert!(write_done.is_ok());
}
}

View File

@ -0,0 +1,81 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::Waker;
use std::task::{Context, Poll};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
(wake_counter, Context::from_waker(leaked_waker))
}
#[tokio::test]
async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() {
let _guard = quota_user_lock_test_scope();
let user = "cross-mode-lock-shared-user";
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user);
let _held_guard = held
.try_lock()
.expect("test must hold shared cross-mode lock before relay poll");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(crate::stats::Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]);
assert!(
matches!(poll, Poll::Pending),
"relay writer must not bypass cross-mode lock held by middle-relay path"
);
}
#[tokio::test]
async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() {
let _guard = quota_user_lock_test_scope();
let user = "cross-mode-lock-progress-user";
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(crate::stats::Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]);
assert!(
matches!(poll, Poll::Ready(Ok(2))),
"relay writer should progress when shared cross-mode lock is uncontended"
);
}

View File

@ -0,0 +1,135 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::Waker;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
// Context stores a reference; leak one Waker for deterministic test scope.
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
(wake_counter, Context::from_waker(leaked_waker))
}
#[tokio::test]
async fn adversarial_map_churn_cannot_bypass_held_writer_lock() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let user = "quota-identity-writer-user";
let held_lock = quota_user_lock(user);
let _held_guard = held_lock
.try_lock()
.expect("test must hold initial user lock before StatsIo poll");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
map.clear();
let churned_lock = quota_user_lock(user);
assert!(
!Arc::ptr_eq(&held_lock, &churned_lock),
"precondition: map churn should produce a distinct lock identity"
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]);
assert!(
matches!(poll, Poll::Pending),
"writer must remain pending on the originally-held lock identity"
);
}
#[tokio::test]
async fn adversarial_map_churn_cannot_bypass_held_reader_lock() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let user = "quota-identity-reader-user";
let held_lock = quota_user_lock(user);
let _held_guard = held_lock
.try_lock()
.expect("test must hold initial user lock before StatsIo poll");
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
map.clear();
let churned_lock = quota_user_lock(user);
assert!(
!Arc::ptr_eq(&held_lock, &churned_lock),
"precondition: map churn should produce a distinct lock identity"
);
let (_wake_counter, mut cx) = build_context();
let mut storage = [0u8; 8];
let mut read_buf = ReadBuf::new(&mut storage);
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf);
assert!(
matches!(poll, Poll::Pending),
"reader must remain pending on the originally-held lock identity"
);
}
#[tokio::test]
async fn business_no_lock_contention_keeps_writer_progress() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let user = "quota-identity-progress-user";
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]);
assert!(
matches!(poll, Poll::Ready(Ok(2))),
"writer should progress immediately without contention"
);
}

View File

@ -0,0 +1,241 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::ReadBuf;
use tokio::time::{Duration, Instant};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}")));
}
retained
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_contention_wake_rate_decays_with_backoff_curve() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = format!("quota-backoff-bench-{}", std::process::id());
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before benchmark run");
let waiters = 64usize;
let mut ios = Vec::with_capacity(waiters);
let mut wake_counters = Vec::with_capacity(waiters);
for _ in 0..waiters {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(io).poll_write(&mut cx, &[0x71]);
assert!(pending.is_pending());
wake_counters.push(counter);
}
let mut observed = vec![0usize; waiters];
let start = Instant::now();
let mut wakes_at_40ms = 0usize;
let mut wakes_at_160ms = 0usize;
while start.elapsed() < Duration::from_millis(200) {
for (idx, counter) in wake_counters.iter().enumerate() {
let wakes = counter.wakes.load(Ordering::Relaxed);
if wakes > observed[idx] {
observed[idx] = wakes;
let waker = Waker::from(Arc::clone(counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]);
assert!(pending.is_pending());
}
}
let elapsed = start.elapsed();
if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
wakes_at_40ms = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
}
if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
wakes_at_160ms = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
let wakes_at_200ms = total_wakes;
let early_window_wakes = wakes_at_40ms;
let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
assert!(
total_wakes <= waiters * 28,
"backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
);
assert!(
early_window_wakes > 0,
"benchmark failed to observe early contention wakes"
);
assert!(
late_window_wakes * 4 <= early_window_wakes * 3,
"wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
);
drop(held_guard);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_read_contention_wake_rate_decays_with_backoff_curve() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = format!("quota-backoff-read-bench-{}", std::process::id());
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before read benchmark run");
let waiters = 64usize;
let mut ios = Vec::with_capacity(waiters);
let mut wake_counters = Vec::with_capacity(waiters);
for _ in 0..waiters {
ios.push(StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let pending = Pin::new(io).poll_read(&mut cx, &mut buf);
assert!(pending.is_pending());
wake_counters.push(counter);
}
let mut observed = vec![0usize; waiters];
let start = Instant::now();
let mut wakes_at_40ms = 0usize;
let mut wakes_at_160ms = 0usize;
while start.elapsed() < Duration::from_millis(200) {
for (idx, counter) in wake_counters.iter().enumerate() {
let wakes = counter.wakes.load(Ordering::Relaxed);
if wakes > observed[idx] {
observed[idx] = wakes;
let waker = Waker::from(Arc::clone(counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf);
assert!(pending.is_pending());
}
}
let elapsed = start.elapsed();
if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
wakes_at_40ms = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
}
if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
wakes_at_160ms = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
let wakes_at_200ms = total_wakes;
let early_window_wakes = wakes_at_40ms;
let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
assert!(
total_wakes <= waiters * 28,
"read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
);
assert!(
early_window_wakes > 0,
"read benchmark failed to observe early contention wakes"
);
assert!(
late_window_wakes * 4 <= early_window_wakes * 3,
"read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
);
drop(held_guard);
}

View File

@ -0,0 +1,339 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::ReadBuf;
use tokio::time::{Duration, Instant};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}")));
}
retained
}
#[tokio::test]
async fn positive_uncontended_writer_keeps_retry_wakes_zero() {
let _guard = quota_test_guard();
let stats = Arc::new(Stats::new());
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
"quota-backoff-positive".to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]);
assert!(poll.is_ready(), "uncontended writer must complete immediately");
assert_eq!(
wake_counter.wakes.load(Ordering::Relaxed),
0,
"uncontended path must not schedule deferred contention wakes"
);
}
#[tokio::test]
async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-backoff-adversarial-writer";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before polling writer");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
assert!(first.is_pending());
let start = Instant::now();
let mut observed = 0usize;
while start.elapsed() < Duration::from_millis(80) {
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > observed {
observed = wakes;
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
assert!(pending.is_pending());
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 16,
"sustained contention must be rate limited; observed wakes={} in 80ms",
wake_counter.wakes.load(Ordering::Relaxed)
);
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]);
assert!(ready.is_ready());
}
#[tokio::test]
async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-backoff-adversarial-reader";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before polling reader");
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(first.is_pending());
let start = Instant::now();
let mut observed = 0usize;
while start.elapsed() < Duration::from_millis(80) {
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > observed {
observed = wakes;
let mut next = ReadBuf::new(&mut storage);
let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next);
assert!(pending.is_pending());
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 16,
"sustained contention must be rate limited; observed wakes={} in 80ms",
wake_counter.wakes.load(Ordering::Relaxed)
);
drop(held_guard);
let mut done = ReadBuf::new(&mut storage);
let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done);
assert!(ready.is_ready());
}
#[tokio::test]
async fn edge_backoff_attempt_resets_after_contention_release() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-backoff-edge-reset";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before polling writer");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]);
assert!(initial.is_pending());
tokio::time::sleep(Duration::from_millis(15)).await;
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > 0 {
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]);
assert!(pending.is_pending());
}
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]);
assert!(ready.is_ready());
assert!(
!io.quota_write_wake_scheduled,
"successful write must clear deferred wake scheduling flag"
);
assert!(
io.quota_write_retry_sleep.is_none(),
"successful write must clear deferred sleep slot"
);
}
#[tokio::test]
async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-backoff-fuzz-writer";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before fuzz loop");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut seed = 0x5EED_CAFE_7788_9900u64;
for _ in 0..64 {
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]);
assert!(poll.is_pending());
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let sleep_ms = (seed % 4) as u64;
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
}
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 24,
"fuzzed repoll schedule must keep wake budget bounded; observed wakes={}",
wake_counter.wakes.load(Ordering::Relaxed)
);
drop(held_guard);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = format!("quota-backoff-stress-{}", std::process::id());
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before launching stress waiters");
let waiters = 48usize;
let mut ios = Vec::with_capacity(waiters);
let mut wake_counters = Vec::with_capacity(waiters);
for _ in 0..waiters {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(io).poll_write(&mut cx, &[0x61]);
assert!(pending.is_pending());
wake_counters.push(counter);
}
let start = Instant::now();
while start.elapsed() < Duration::from_millis(120) {
for (idx, counter) in wake_counters.iter().enumerate() {
if counter.wakes.load(Ordering::Relaxed) > 0 {
let waker = Waker::from(Arc::clone(counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]);
assert!(pending.is_pending());
}
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= waiters * 20,
"stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}"
);
drop(held_guard);
}

View File

@ -0,0 +1,246 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::time::{Duration, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
#[tokio::test]
async fn positive_uncontended_quota_limited_writer_completes() {
let _guard = quota_test_guard();
let stats = Arc::new(Stats::new());
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
"tdd-uncontended".to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let result = io.write_all(&[0x41, 0x42, 0x43]).await;
assert!(result.is_ok(), "uncontended writer must complete");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() {
let _guard = quota_test_guard();
let user = format!("tdd-writer-storm-{}", std::process::id());
let held = quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold quota lock before polling writers");
let stats = Arc::new(Stats::new());
let writers = 24usize;
let mut ios = Vec::with_capacity(writers);
let mut wake_counters = Vec::with_capacity(writers);
for _ in 0..writers {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]);
assert!(poll.is_pending(), "writer must be pending under held lock");
wake_counters.push(counter);
}
tokio::time::sleep(Duration::from_millis(25)).await;
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= writers * 4,
"retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() {
let _guard = quota_test_guard();
let user = format!("tdd-reader-storm-{}", std::process::id());
let held = quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold quota lock before polling readers");
let stats = Arc::new(Stats::new());
let readers = 24usize;
let mut ios = Vec::with_capacity(readers);
let mut wake_counters = Vec::with_capacity(readers);
for _ in 0..readers {
ios.push(StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let poll = Pin::new(io).poll_read(&mut cx, &mut buf);
assert!(poll.is_pending(), "reader must be pending under held lock");
wake_counters.push(counter);
}
tokio::time::sleep(Duration::from_millis(25)).await;
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= readers * 4,
"retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_contended_waiters_resume_after_lock_release() {
let _guard = quota_test_guard();
let user = format!("tdd-resume-{}", std::process::id());
let held = quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold quota lock before launching waiters");
let stats = Arc::new(Stats::new());
let mut waiters = Vec::new();
for _ in 0..12 {
let stats = Arc::clone(&stats);
let user = user.clone();
waiters.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
stats,
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x5A]).await
}));
}
tokio::time::sleep(Duration::from_millis(5)).await;
drop(held_guard);
timeout(Duration::from_secs(1), async {
for waiter in waiters {
let result = waiter.await.expect("waiter task must not panic");
assert!(result.is_ok(), "waiter must complete after release");
}
})
.await
.expect("all waiters must complete in bounded time");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() {
let _guard = quota_test_guard();
let mut seed = 0x9E37_79B9_AA55_1234u64;
for round in 0..20u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let writers = 8 + (seed as usize % 12);
let sleep_ms = 10 + (seed as u64 % 15);
let user = format!("tdd-fuzz-{}-{round}", std::process::id());
let held = quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold quota lock in fuzz round");
let stats = Arc::new(Stats::new());
let mut ios = Vec::with_capacity(writers);
let mut wake_counters = Vec::with_capacity(writers);
for _ in 0..writers {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]);
assert!(matches!(poll, Poll::Pending));
wake_counters.push(counter);
}
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= writers * 4,
"fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}"
);
}
}

View File

@ -137,10 +137,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_
for _ in 0..8 {
tokio::task::yield_now().await;
}
assert_eq!(
wake_counter.wakes.load(Ordering::Relaxed),
wakes_after_first_yield,
"writer contention should not schedule unbounded wake storms before lock acquisition"
let wakes_after_second_window = wake_counter.wakes.load(Ordering::Relaxed);
assert!(
wakes_after_second_window <= wakes_after_first_yield.saturating_add(2),
"writer contention should keep retry wakes bounded before lock acquisition: before={wakes_after_first_yield}, after={wakes_after_second_window}"
);
drop(held_lock);

View File

@ -1884,6 +1884,32 @@ impl Stats {
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
}
pub fn sub_user_octets_to(&self, user: &str, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
let Some(stats) = self.user_stats.get(user) else {
return;
};
Self::touch_user_stats(stats.value());
let counter = &stats.octets_to_client;
let mut current = counter.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(bytes);
match counter.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(actual) => current = actual,
}
}
}
pub fn increment_user_msgs_from(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
@ -2440,3 +2466,7 @@ mod connection_lease_security_tests;
#[cfg(test)]
#[path = "tests/replay_checker_security_tests.rs"]
mod replay_checker_security_tests;
#[cfg(test)]
#[path = "tests/user_octets_sub_security_tests.rs"]
mod user_octets_sub_security_tests;

View File

@ -0,0 +1,151 @@
use super::*;
use std::sync::Arc;
use std::thread;
#[test]
fn sub_user_octets_to_underflow_saturates_at_zero() {
let stats = Stats::new();
let user = "sub-underflow-user";
stats.add_user_octets_to(user, 3);
stats.sub_user_octets_to(user, 100);
assert_eq!(stats.get_user_total_octets(user), 0);
}
#[test]
fn sub_user_octets_to_does_not_affect_octets_from_client() {
let stats = Stats::new();
let user = "sub-isolation-user";
stats.add_user_octets_from(user, 17);
stats.add_user_octets_to(user, 5);
stats.sub_user_octets_to(user, 3);
assert_eq!(stats.get_user_total_octets(user), 19);
}
#[test]
fn light_fuzz_add_sub_model_matches_saturating_reference() {
let stats = Stats::new();
let user = "sub-fuzz-user";
let mut seed = 0x91D2_4CB8_EE77_1101u64;
let mut model_to = 0u64;
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x3f) + 1;
if (seed & 1) == 0 {
stats.add_user_octets_to(user, amt);
model_to = model_to.saturating_add(amt);
} else {
stats.sub_user_octets_to(user, amt);
model_to = model_to.saturating_sub(amt);
}
}
assert_eq!(stats.get_user_total_octets(user), model_to);
}
#[test]
fn stress_parallel_add_sub_never_underflows_or_panics() {
let stats = Arc::new(Stats::new());
let user = "sub-stress-user";
// Pre-fund with a large offset so subtractions never saturate at zero.
// This guarantees commutative updates, making the final state deterministic.
let base_offset = 10_000_000u64;
stats.add_user_octets_to(user, base_offset);
let mut workers = Vec::new();
for tid in 0..16u64 {
let stats_for_thread = Arc::clone(&stats);
workers.push(thread::spawn(move || {
let mut seed = 0xD00D_1000_0000_0000u64 ^ tid;
let mut net_delta = 0i64;
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x1f) + 1;
if (seed & 1) == 0 {
stats_for_thread.add_user_octets_to(user, amt);
net_delta += amt as i64;
} else {
stats_for_thread.sub_user_octets_to(user, amt);
net_delta -= amt as i64;
}
}
net_delta
}));
}
let mut expected_net_delta = 0i64;
for worker in workers {
expected_net_delta += worker
.join()
.expect("sub-user stress worker must not panic");
}
let expected_total = (base_offset as i64 + expected_net_delta) as u64;
let total = stats.get_user_total_octets(user);
assert_eq!(
total, expected_total,
"concurrent add/sub lost updates or suffered ABA races"
);
}
#[test]
fn sub_user_octets_to_missing_user_is_noop() {
let stats = Stats::new();
stats.sub_user_octets_to("missing-user", 1024);
assert_eq!(stats.get_user_total_octets("missing-user"), 0);
}
#[test]
fn stress_parallel_per_user_models_remain_exact() {
let stats = Arc::new(Stats::new());
let mut workers = Vec::new();
for tid in 0..16u64 {
let stats_for_thread = Arc::clone(&stats);
workers.push(thread::spawn(move || {
let user = format!("sub-per-user-{tid}");
let mut seed = 0xFACE_0000_0000_0000u64 ^ tid;
let mut model = 0u64;
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x3f) + 1;
if (seed & 1) == 0 {
stats_for_thread.add_user_octets_to(&user, amt);
model = model.saturating_add(amt);
} else {
stats_for_thread.sub_user_octets_to(&user, amt);
model = model.saturating_sub(amt);
}
}
(user, model)
}));
}
for worker in workers {
let (user, model) = worker
.join()
.expect("per-user subtract stress worker must not panic");
assert_eq!(
stats.get_user_total_octets(&user),
model,
"per-user parallel model diverged"
);
}
}