mirror of https://github.com/telemt/telemt.git
Add comprehensive security tests for quota management and relay functionality
- Introduced `relay_dual_lock_race_harness_security_tests.rs` to validate user liveness during lock hold and release cycles.
- Added `relay_quota_extended_attack_surface_security_tests.rs` to cover various quota scenarios including positive, negative, edge cases, and adversarial conditions.
- Implemented `relay_quota_lock_eviction_lifecycle_tdd_tests.rs` to ensure proper eviction of stale entries and lifecycle management of quota locks.
- Created `relay_quota_lock_eviction_stress_security_tests.rs` to stress test the eviction mechanism under high churn conditions.
- Enhanced `relay_quota_lock_pressure_adversarial_tests.rs` to verify reclaiming of unreferenced entries after explicit eviction.
- Developed `relay_quota_retry_allocation_latency_security_tests.rs` to benchmark and validate latency and allocation behavior under contention.
parent 91be148b72
commit 6f17d4d231
@@ -1454,9 +1454,9 @@ dependencies = [
 
 [[package]]
 name = "iri-string"
-version = "0.7.10"
+version = "0.7.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
+checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
 dependencies = [
  "memchr",
  "serde",

@@ -32,6 +32,14 @@ pub(crate) struct RuntimeWatches {
     pub(crate) detected_ip_v6: Option<IpAddr>,
 }
 
+const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60;
+
+fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> {
+    crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs(
+        QUOTA_USER_LOCK_EVICT_INTERVAL_SECS,
+    ))
+}
+
 #[allow(clippy::too_many_arguments)]
 pub(crate) async fn spawn_runtime_tasks(
     config: &Arc<ProxyConfig>,

@@ -69,6 +77,8 @@ pub(crate) async fn spawn_runtime_tasks(
         rc_clone.run_periodic_cleanup().await;
     });
 
+    spawn_quota_lock_maintenance_task();
+
     let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4);
     let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6);
     debug!(

@@ -360,3 +370,24 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc<StartupTracker>) {
         .await;
     startup_tracker.mark_ready().await;
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() {
+        crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests();
+
+        let handle = spawn_quota_lock_maintenance_task();
+        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
+
+        assert_eq!(
+            crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(),
+            1,
+            "runtime maintenance path must spawn exactly one quota lock evictor task per call"
+        );
+
+        handle.abort();
+    }
+}

@@ -131,8 +131,7 @@ fn auth_probe_scan_start_offset(
         return 0;
     }
 
-    let window = state_len.min(scan_limit);
-    auth_probe_eviction_offset(peer_ip, now) % window
+    auth_probe_eviction_offset(peer_ip, now) % state_len
 }
 
 fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {

@@ -997,6 +996,10 @@ mod auth_probe_scan_budget_security_tests;
 #[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"]
 mod auth_probe_scan_offset_stress_tests;
 
+#[cfg(test)]
+#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"]
+mod auth_probe_eviction_bias_security_tests;
+
 #[cfg(test)]
 #[path = "tests/handshake_advanced_clever_tests.rs"]
 mod advanced_clever_tests;

@@ -19,6 +19,8 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use tokio::net::TcpStream;
 #[cfg(unix)]
 use tokio::net::UnixStream;
+#[cfg(unix)]
+use tokio::sync::Mutex as AsyncMutex;
 use tokio::time::{Instant, timeout};
 use tracing::debug;
 

@@ -95,10 +97,6 @@ where
             Ok(Ok(())) => {}
             Ok(Err(_)) | Err(_) => break,
         }
-
-        if total >= byte_cap {
-            break;
-        }
     }
     CopyOutcome {
         total,

@@ -370,6 +368,9 @@ struct LocalInterfaceCache {
 static LOCAL_INTERFACE_CACHE: OnceLock<Mutex<LocalInterfaceCache>> = OnceLock::new();
 
+#[cfg(unix)]
+static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock<AsyncMutex<()>> = OnceLock::new();
+
 #[cfg(all(unix, test))]
 fn local_interface_ips() -> Vec<IpAddr> {
     let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
     let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());

@@ -386,11 +387,59 @@ fn local_interface_ips() -> Vec<IpAddr> {
     guard.ips.clone()
 }
 
-#[cfg(not(unix))]
+#[cfg(unix)]
+async fn local_interface_ips_async() -> Vec<IpAddr> {
+    let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
+
+    {
+        let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+        let stale = guard
+            .refreshed_at
+            .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+        if !stale {
+            return guard.ips.clone();
+        }
+    }
+
+    let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
+    let _refresh_guard = refresh_lock.lock().await;
+
+    {
+        let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+        let stale = guard
+            .refreshed_at
+            .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+        if !stale {
+            return guard.ips.clone();
+        }
+    }
+
+    let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips)
+        .await
+        .unwrap_or_default();
+
+    let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+    let stale = guard
+        .refreshed_at
+        .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+    if stale {
+        guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
+        guard.refreshed_at = Some(StdInstant::now());
+    }
+
+    guard.ips.clone()
+}
+
+#[cfg(all(not(unix), test))]
 fn local_interface_ips() -> Vec<IpAddr> {
     Vec::new()
 }
 
+#[cfg(not(unix))]
+async fn local_interface_ips_async() -> Vec<IpAddr> {
+    Vec::new()
+}
+
 #[cfg(test)]
 static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0);
 

@@ -457,6 +506,7 @@ fn is_mask_target_local_listener_with_interfaces(
     false
 }
 
+#[cfg(test)]
 fn is_mask_target_local_listener(
     mask_host: &str,
     mask_port: u16,

@@ -477,6 +527,26 @@ fn is_mask_target_local_listener(
     )
 }
 
+async fn is_mask_target_local_listener_async(
+    mask_host: &str,
+    mask_port: u16,
+    local_addr: SocketAddr,
+    resolved_override: Option<SocketAddr>,
+) -> bool {
+    if mask_port != local_addr.port() {
+        return false;
+    }
+
+    let interfaces = local_interface_ips_async().await;
+    is_mask_target_local_listener_with_interfaces(
+        mask_host,
+        mask_port,
+        local_addr,
+        resolved_override,
+        &interfaces,
+    )
+}
+
 fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration {
     let minutes = config.general.beobachten_minutes;
     let clamped = minutes.clamp(1, 24 * 60);

@@ -608,13 +678,15 @@ pub async fn handle_bad_client<R, W>(
         .as_deref()
         .unwrap_or(&config.censorship.tls_domain);
     let mask_port = config.censorship.mask_port;
-    let outcome_started = Instant::now();
 
     // Fail closed when fallback points at our own listener endpoint.
     // Self-referential masking can create recursive proxy loops under
     // misconfiguration and leak distinguishable load spikes to adversaries.
     let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
-    if is_mask_target_local_listener(mask_host, mask_port, local_addr, resolved_mask_addr) {
+    if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
+        .await
+    {
+        let outcome_started = Instant::now();
         debug!(
             client_type = client_type,
             host = %mask_host,

@@ -627,6 +699,8 @@ pub async fn handle_bad_client<R, W>(
         return;
     }
 
+    let outcome_started = Instant::now();
+
     debug!(
         client_type = client_type,
         host = %mask_host,

@@ -768,7 +842,13 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usize
     let mut total = 0usize;
 
     loop {
-        let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await {
+        let remaining_budget = byte_cap.saturating_sub(total);
+        if remaining_budget == 0 {
+            break;
+        }
+
+        let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
+        let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await {
             Ok(Ok(n)) => n,
             Ok(Err(_)) | Err(_) => break,
         };

@@ -804,6 +884,10 @@ mod masking_shape_above_cap_blur_security_tests;
 #[path = "tests/masking_timing_normalization_security_tests.rs"]
 mod masking_timing_normalization_security_tests;
 
+#[cfg(test)]
+#[path = "tests/masking_timing_budget_coupling_security_tests.rs"]
+mod masking_timing_budget_coupling_security_tests;
+
 #[cfg(test)]
 #[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"]
 mod masking_ab_envelope_blur_integration_security_tests;

@@ -884,6 +968,18 @@ mod masking_interface_cache_security_tests;
 #[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"]
 mod masking_interface_cache_defense_in_depth_security_tests;
 
+#[cfg(test)]
+#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"]
+mod masking_interface_cache_concurrency_security_tests;
+
+#[cfg(test)]
+#[path = "tests/masking_production_cap_regression_security_tests.rs"]
+mod masking_production_cap_regression_security_tests;
+
+#[cfg(test)]
+#[path = "tests/masking_extended_attack_surface_security_tests.rs"]
+mod masking_extended_attack_surface_security_tests;
+
 #[cfg(test)]
 #[path = "tests/masking_padding_timeout_adversarial_tests.rs"]
 mod masking_padding_timeout_adversarial_tests;

@@ -1,5 +1,7 @@
 use std::collections::hash_map::RandomState;
 use std::collections::{BTreeSet, HashMap};
+#[cfg(test)]
+use std::future::Future;
 use std::hash::{BuildHasher, Hash};
 use std::net::{IpAddr, SocketAddr};
 use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

@@ -45,6 +47,8 @@ const TINY_FRAME_DEBT_LIMIT: u32 = 512;
 const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
 #[cfg(not(test))]
 const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
+#[cfg(test)]
+const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1);
 const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
 const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
 const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;

@@ -561,11 +565,8 @@ fn quota_would_be_exceeded_for_user_soft(
     bytes: u64,
     overshoot: u64,
 ) -> bool {
-    quota_limit.is_some_and(|quota| {
-        let cap = quota_soft_cap(quota, overshoot);
-        let used = stats.get_user_total_octets(user);
-        used >= cap || bytes > cap.saturating_sub(used)
-    })
+    let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot));
+    quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes)
 }
 
 fn classify_me_d2c_flush_reason(

@@ -683,7 +684,7 @@ fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
 }
 
 #[cfg(test)]
-pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<Mutex<()>> {
+pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
     crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
 }
 

@@ -712,6 +713,16 @@ async fn enqueue_c2me_command(
     }
 }
 
+#[cfg(test)]
+async fn run_relay_test_step_timeout<F, T>(context: &'static str, fut: F) -> T
+where
+    F: Future<Output = T>,
+{
+    timeout(RELAY_TEST_STEP_TIMEOUT, fut)
+        .await
+        .unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs()))
+}
+
 pub(crate) async fn handle_via_middle_proxy<R, W>(
     mut crypto_reader: CryptoReader<R>,
     crypto_writer: CryptoWriter<W>,

@@ -860,6 +871,7 @@ where
     let stats_clone = stats.clone();
     let rng_clone = rng.clone();
    let user_clone = user.clone();
+    let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone();
     let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
     let bytes_me2c_clone = bytes_me2c.clone();
     let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);

@@ -881,7 +893,7 @@ where
 
         let first_is_downstream_activity =
             matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
-        match process_me_writer_response(
+        match process_me_writer_response_with_cross_mode_lock(
             first,
             &mut writer,
             proto_tag,

@@ -891,6 +903,7 @@ where
             &user_clone,
             quota_limit,
             d2c_flush_policy.quota_soft_overshoot_bytes,
+            cross_mode_quota_lock_me_writer.as_ref(),
             bytes_me2c_clone.as_ref(),
             conn_id,
             d2c_flush_policy.ack_flush_immediate,

@@ -939,7 +952,7 @@ where
 
                 let next_is_downstream_activity =
                     matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
-                match process_me_writer_response(
+                match process_me_writer_response_with_cross_mode_lock(
                     next,
                     &mut writer,
                     proto_tag,

@@ -949,6 +962,7 @@ where
                     &user_clone,
                     quota_limit,
                     d2c_flush_policy.quota_soft_overshoot_bytes,
+                    cross_mode_quota_lock_me_writer.as_ref(),
                     bytes_me2c_clone.as_ref(),
                     conn_id,
                     d2c_flush_policy.ack_flush_immediate,

@@ -1000,7 +1014,7 @@ where
             Ok(Some(next)) => {
                 let next_is_downstream_activity =
                     matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
-                match process_me_writer_response(
+                match process_me_writer_response_with_cross_mode_lock(
                     next,
                     &mut writer,
                     proto_tag,

@@ -1010,6 +1024,7 @@ where
                     &user_clone,
                     quota_limit,
                     d2c_flush_policy.quota_soft_overshoot_bytes,
+                    cross_mode_quota_lock_me_writer.as_ref(),
                     bytes_me2c_clone.as_ref(),
                     conn_id,
                     d2c_flush_policy.ack_flush_immediate,

@@ -1063,7 +1078,7 @@ where
 
                 let extra_is_downstream_activity =
                     matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));
-                match process_me_writer_response(
+                match process_me_writer_response_with_cross_mode_lock(
                     extra,
                     &mut writer,
                     proto_tag,

@@ -1073,6 +1088,7 @@ where
                     &user_clone,
                     quota_limit,
                     d2c_flush_policy.quota_soft_overshoot_bytes,
+                    cross_mode_quota_lock_me_writer.as_ref(),
                     bytes_me2c_clone.as_ref(),
                     conn_id,
                     d2c_flush_policy.ack_flush_immediate,

@@ -1252,10 +1268,7 @@ where
                 ));
                 break;
             };
-            let _cross_mode_quota_guard = match cross_mode_lock.lock() {
-                Ok(guard) => guard,
-                Err(poisoned) => poisoned.into_inner(),
-            };
+            let _cross_mode_quota_guard = cross_mode_lock.lock().await;
             stats.add_user_octets_from(&user, payload.len() as u64);
             if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
                 main_result = Err(ProxyError::DataQuotaExceeded {

@@ -1741,6 +1754,7 @@ enum MeWriterResponseOutcome {
     Close,
 }
 
+#[cfg(test)]
 async fn process_me_writer_response<W>(
     response: MeResponse,
     client_writer: &mut CryptoWriter<W>,

@@ -1756,6 +1770,44 @@ async fn process_me_writer_response<W>(
     ack_flush_immediate: bool,
     batched: bool,
 ) -> Result<MeWriterResponseOutcome>
 where
     W: AsyncWrite + Unpin + Send + 'static,
 {
+    process_me_writer_response_with_cross_mode_lock(
+        response,
+        client_writer,
+        proto_tag,
+        rng,
+        frame_buf,
+        stats,
+        user,
+        quota_limit,
+        quota_soft_overshoot_bytes,
+        None,
+        bytes_me2c,
+        conn_id,
+        ack_flush_immediate,
+        batched,
+    )
+    .await
+}
+
+async fn process_me_writer_response_with_cross_mode_lock<W>(
+    response: MeResponse,
+    client_writer: &mut CryptoWriter<W>,
+    proto_tag: ProtoTag,
+    rng: &SecureRandom,
+    frame_buf: &mut Vec<u8>,
+    stats: &Stats,
+    user: &str,
+    quota_limit: Option<u64>,
+    quota_soft_overshoot_bytes: u64,
+    cross_mode_quota_lock: Option<&Arc<AsyncMutex<()>>>,
+    bytes_me2c: &AtomicU64,
+    conn_id: u64,
+    ack_flush_immediate: bool,
+    batched: bool,
+) -> Result<MeWriterResponseOutcome>
+where
+    W: AsyncWrite + Unpin + Send + 'static,
+{

@@ -1768,8 +1820,23 @@ where
             }
             let data_len = data.len() as u64;
             if let Some(limit) = quota_limit {
-                let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
-                if quota_would_be_exceeded_for_user(stats, user, Some(soft_limit), data_len) {
+                let owned_cross_mode_lock;
+                let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock {
+                    lock
+                } else {
+                    owned_cross_mode_lock =
+                        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user);
+                    &owned_cross_mode_lock
+                };
+                let cross_mode_quota_guard = cross_mode_lock.lock().await;
+
+                if quota_would_be_exceeded_for_user_soft(
+                    stats,
+                    user,
+                    Some(limit),
+                    data_len,
+                    quota_soft_overshoot_bytes,
+                ) {
                     stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
                     return Err(ProxyError::DataQuotaExceeded {
                         user: user.to_string(),

@@ -1789,6 +1856,10 @@ where
                 });
             }
 
+            // Keep cross-mode lock scope explicit and minimal: quota reservation is serialized,
+            // but socket I/O proceeds without holding same-user cross-mode admission lock.
+            drop(cross_mode_quota_guard);
+
             let write_mode =
                 match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
                     .await

@@ -2084,3 +2155,27 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests;
 #[cfg(test)]
 #[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
 mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
+
+#[cfg(test)]
+#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"]
+mod middle_relay_cross_mode_quota_reservation_security_tests;
+
+#[cfg(test)]
+#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"]
+mod middle_relay_cross_mode_quota_lock_matrix_security_tests;
+
+#[cfg(test)]
+#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"]
+mod middle_relay_cross_mode_lookup_efficiency_security_tests;
+
+#[cfg(test)]
+#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"]
+mod middle_relay_cross_mode_lock_release_regression_security_tests;
+
+#[cfg(test)]
+#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"]
+mod middle_relay_quota_extended_attack_surface_security_tests;
+
+#[cfg(test)]
+#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"]
+mod middle_relay_quota_reservation_extreme_security_tests;

@@ -1,5 +1,9 @@
 use dashmap::DashMap;
-use std::sync::{Arc, Mutex, OnceLock};
+use std::sync::{Arc, OnceLock};
+use tokio::sync::Mutex;
 
+#[cfg(test)]
+use std::sync::atomic::{AtomicUsize, Ordering};
+
 #[cfg(test)]
 const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64;

@@ -13,6 +17,11 @@ const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
 static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
 static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
 
+#[cfg(test)]
+static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0);
+#[cfg(test)]
+static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock<DashMap<String, usize>> = OnceLock::new();
+
 fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
     let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
         (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES)

@@ -25,6 +34,14 @@ fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
 }
 
 pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
+    #[cfg(test)]
+    {
+        CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed);
+        let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
+        let mut entry = lookups.entry(user.to_string()).or_insert(0);
+        *entry += 1;
+    }
+
     let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
     if let Some(existing) = locks.get(user) {
         return Arc::clone(existing.value());

@@ -48,6 +65,24 @@ pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
     }
 }
 
+#[cfg(test)]
+pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() {
+    CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed);
+    let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
+    lookups.clear();
+}
+
+#[cfg(test)]
+pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize {
+    CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed)
+}
+
+#[cfg(test)]
+pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize {
+    let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
+    lookups.get(user).map(|entry| *entry).unwrap_or(0)
+}
+
 #[cfg(test)]
 #[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"]
 mod quota_lock_registry_cross_mode_adversarial_tests;

@@ -62,6 +62,7 @@ use std::sync::{Arc, Mutex, OnceLock};
 use std::task::{Context, Poll};
 use std::time::Duration;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
+use tokio::sync::Mutex as AsyncMutex;
 use tokio::time::{Instant, Sleep};
 use tracing::{debug, trace, warn};
 

@@ -210,7 +211,7 @@ struct StatsIo<S> {
     stats: Arc<Stats>,
     user: String,
     quota_lock: Option<Arc<Mutex<()>>>,
-    cross_mode_quota_lock: Option<Arc<Mutex<()>>>,
+    cross_mode_quota_lock: Option<Arc<AsyncMutex<()>>>,
     quota_limit: Option<u64>,
     quota_exceeded: Arc<AtomicBool>,
     quota_read_wake_scheduled: bool,

@@ -289,6 +290,21 @@ const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16);
 #[cfg(not(test))]
 const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64);
 
+#[cfg(test)]
+static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0);
+#[cfg(test)]
+static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0);
+
+#[cfg(test)]
+pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() {
+    QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed);
+}
+
+#[cfg(test)]
+pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 {
+    QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed)
+}
+
 #[inline]
 fn quota_contention_retry_delay(retry_attempt: u8) -> Duration {
     let shift = u32::from(retry_attempt.min(5));

@@ -317,6 +333,8 @@ fn poll_quota_retry_sleep(
 ) {
     if !*wake_scheduled {
         *wake_scheduled = true;
+        #[cfg(test)]
+        QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed);
         *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay(
             *retry_attempt,
         ))));

@@ -368,16 +386,47 @@ fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
     Arc::clone(&stripes[hash % stripes.len()])
 }
 
+pub(crate) fn quota_user_lock_evict() {
+    if let Some(locks) = QUOTA_USER_LOCKS.get() {
+        locks.retain(|_, value| Arc::strong_count(value) > 1);
+    }
+}
+
+pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> {
+    let interval = interval.max(Duration::from_millis(1));
+    #[cfg(test)]
+    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed);
+    tokio::spawn(async move {
+        loop {
+            tokio::time::sleep(interval).await;
+            quota_user_lock_evict();
+        }
+    })
+}
+
+#[cfg(test)]
+pub(crate) fn spawn_quota_user_lock_evictor_for_tests(
+    interval: Duration,
+) -> tokio::task::JoinHandle<()> {
+    spawn_quota_user_lock_evictor(interval)
+}
+
+#[cfg(test)]
+pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() {
+    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed);
+}
+
+#[cfg(test)]
+pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 {
+    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed)
+}
+
 fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
     let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
     if let Some(existing) = locks.get(user) {
         return Arc::clone(existing.value());
     }
 
     if locks.len() >= QUOTA_USER_LOCKS_MAX {
         locks.retain(|_, value| Arc::strong_count(value) > 1);
     }
 
     if locks.len() >= QUOTA_USER_LOCKS_MAX {
         return quota_overflow_user_lock(user);
     }

@@ -393,7 +442,7 @@ fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
 }
 
 #[cfg(test)]
-pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<Mutex<()>> {
+pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
     crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
 }
 

@@ -410,14 +459,7 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
 
         let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
             match lock.try_lock() {
-                Ok(guard) => {
-                    reset_quota_retry_scheduler(
-                        &mut this.quota_read_retry_sleep,
-                        &mut this.quota_read_wake_scheduled,
-                        &mut this.quota_read_retry_attempt,
-                    );
-                    Some(guard)
-                }
+                Ok(guard) => Some(guard),
                 Err(_) => {
                     poll_quota_retry_sleep(
                         &mut this.quota_read_retry_sleep,

@@ -434,14 +476,7 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
 
         let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
             match lock.try_lock() {
-                Ok(guard) => {
-                    reset_quota_retry_scheduler(
-                        &mut this.quota_read_retry_sleep,
-                        &mut this.quota_read_wake_scheduled,
-                        &mut this.quota_read_retry_attempt,
-                    );
-                    Some(guard)
-                }
+                Ok(guard) => Some(guard),
                 Err(_) => {
                     poll_quota_retry_sleep(
                         &mut this.quota_read_retry_sleep,

@@ -456,6 +491,12 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
             None
         };
 
+        reset_quota_retry_scheduler(
+            &mut this.quota_read_retry_sleep,
+            &mut this.quota_read_wake_scheduled,
+            &mut this.quota_read_retry_attempt,
+        );
+
         if let Some(limit) = this.quota_limit
             && this.stats.get_user_total_octets(&this.user) >= limit
         {

@@ -523,14 +564,7 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
 
         let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
             match lock.try_lock() {
-                Ok(guard) => {
-                    reset_quota_retry_scheduler(
-                        &mut this.quota_write_retry_sleep,
-                        &mut this.quota_write_wake_scheduled,
-                        &mut this.quota_write_retry_attempt,
-                    );
-                    Some(guard)
-                }
+                Ok(guard) => Some(guard),
                 Err(_) => {
                     poll_quota_retry_sleep(
                         &mut this.quota_write_retry_sleep,

@@ -547,14 +581,7 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
 
         let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
             match lock.try_lock() {
-                Ok(guard) => {
-                    reset_quota_retry_scheduler(
-                        &mut this.quota_write_retry_sleep,
-                        &mut this.quota_write_wake_scheduled,
-                        &mut this.quota_write_retry_attempt,
-                    );
-                    Some(guard)
-                }
+                Ok(guard) => Some(guard),
                 Err(_) => {
                     poll_quota_retry_sleep(
                         &mut this.quota_write_retry_sleep,

@@ -569,6 +596,12 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
             None
         };
 
+        reset_quota_retry_scheduler(
+            &mut this.quota_write_retry_sleep,
+            &mut this.quota_write_wake_scheduled,
+            &mut this.quota_write_retry_attempt,
+        );
+
         let write_buf = if let Some(limit) = this.quota_limit {
             let used = this.stats.get_user_total_octets(&this.user);
             if used >= limit {

@@ -861,6 +894,10 @@ mod relay_quota_model_adversarial_tests;
 #[path = "tests/relay_quota_overflow_regression_tests.rs"]
 mod relay_quota_overflow_regression_tests;
 
+#[cfg(test)]
+#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"]
+mod relay_quota_extended_attack_surface_security_tests;
+
 #[cfg(test)]
 #[path = "tests/relay_watchdog_delta_security_tests.rs"]
 mod relay_watchdog_delta_security_tests;

@@ -889,6 +926,14 @@ mod relay_quota_retry_scheduler_tdd_tests;
 #[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"]
 mod relay_cross_mode_quota_fairness_tdd_tests;
 
+#[cfg(test)]
+#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"]
+mod relay_cross_mode_pipeline_hol_integration_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"]
+mod relay_cross_mode_pipeline_latency_benchmark_security_tests;
+
 #[cfg(test)]
 #[path = "tests/relay_quota_retry_backoff_security_tests.rs"]
 mod relay_quota_retry_backoff_security_tests;

@@ -896,3 +941,31 @@ mod relay_quota_retry_backoff_security_tests;
 #[cfg(test)]
 #[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"]
 mod relay_quota_retry_backoff_benchmark_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"]
+mod relay_dual_lock_backoff_regression_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"]
+mod relay_dual_lock_contention_matrix_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"]
+mod relay_dual_lock_race_harness_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"]
+mod relay_dual_lock_alternating_contention_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"]
+mod relay_quota_retry_allocation_latency_security_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"]
+mod relay_quota_lock_eviction_lifecycle_tdd_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"]
+mod relay_quota_lock_eviction_stress_security_tests;

@@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess;
 use crate::stream::{CryptoReader, CryptoWriter};
 use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
 use rand::rngs::StdRng;
-use rand::RngCore;
+use rand::Rng;
 use rand::SeedableRng;
 use std::net::Ipv4Addr;
 use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

@@ -0,0 +1,93 @@
+use super::*;
+use std::collections::HashSet;
+use std::net::{IpAddr, Ipv4Addr};
+use std::time::{Duration, Instant};
+
+fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
+    auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+#[test]
+fn adversarial_large_state_offsets_escape_first_scan_window() {
+    let _guard = auth_probe_test_guard();
+    let base = Instant::now();
+    let state_len = 65_536usize;
+    let scan_limit = 1_024usize;
+
+    let mut saw_offset_outside_first_window = false;
+    for i in 0..8_192u64 {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            ((i >> 16) & 0xff) as u8,
+            ((i >> 8) & 0xff) as u8,
+            (i & 0xff) as u8,
+            ((i.wrapping_mul(131)) & 0xff) as u8,
+        ));
+        let now = base + Duration::from_nanos(i);
+        let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
+        if start >= scan_limit {
+            saw_offset_outside_first_window = true;
+            break;
+        }
+    }
+
+    assert!(
+        saw_offset_outside_first_window,
+        "scan start offset must cover the full auth-probe state, not only the first scan window"
+    );
+}
+
+#[test]
+fn stress_large_state_offsets_cover_many_scan_windows() {
+    let _guard = auth_probe_test_guard();
+    let base = Instant::now();
+    let state_len = 65_536usize;
+    let scan_limit = 1_024usize;
+
+    let mut covered_windows = HashSet::new();
+    for i in 0..16_384u64 {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            ((i >> 16) & 0xff) as u8,
+            ((i >> 8) & 0xff) as u8,
+            (i & 0xff) as u8,
+            ((i.wrapping_mul(17)) & 0xff) as u8,
+        ));
+        let now = base + Duration::from_micros(i);
+        let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
+        covered_windows.insert(start / scan_limit);
+    }
+
+    assert!(
+        covered_windows.len() >= 16,
+        "eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}",
+        covered_windows.len(),
+        state_len / scan_limit
+    );
+}
+
+#[test]
+fn light_fuzz_offset_always_stays_inside_state_len() {
+    let _guard = auth_probe_test_guard();
+    let mut seed = 0xC0FF_EE12_3456_789Au64;
+    let base = Instant::now();
+
+    for _ in 0..8_192usize {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            (seed >> 24) as u8,
+            (seed >> 16) as u8,
+            (seed >> 8) as u8,
+            seed as u8,
+        ));
+        let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1);
+        let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1);
+        let now = base + Duration::from_nanos(seed & 0x0fff);
+        let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
+
+        assert!(start < state_len, "scan offset must stay inside state length");
+    }
+}

@@ -22,12 +22,13 @@ fn edge_zero_state_len_yields_zero_start_offset() {
 }
 
 #[test]
-fn adversarial_large_state_must_bound_start_offset_to_scan_budget() {
+fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() {
     let _guard = auth_probe_test_guard();
     let base = Instant::now();
     let scan_limit = 16usize;
     let state_len = 65_536usize;
 
+    let mut saw_offset_outside_window = false;
     for i in 0..2048u32 {
         let ip = IpAddr::V4(Ipv4Addr::new(
             203,

@@ -38,10 +39,19 @@ fn adversarial_large_state_must_bound_start_offset_to_scan_budget() {
         let now = base + Duration::from_micros(i as u64);
         let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
         assert!(
-            start < scan_limit,
-            "start offset must stay within scan window; start={start}, limit={scan_limit}"
+            start < state_len,
+            "start offset must stay within state length; start={start}, len={state_len}"
         );
+        if start >= scan_limit {
+            saw_offset_outside_window = true;
+            break;
+        }
     }
+
+    assert!(
+        saw_offset_outside_window,
+        "large-state eviction must sample beyond the first scan window"
+    );
 }
 
 #[test]

@@ -80,11 +90,10 @@ fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
         let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1);
         let now = base + Duration::from_nanos(seed & 0xffff);
         let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
-        let effective_window = state_len.min(scan_limit);
 
         assert!(
-            start < effective_window,
-            "scan offset must stay inside effective window"
+            start < state_len,
+            "scan offset must stay inside state length"
         );
     }
 }

@@ -22,10 +22,10 @@ fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
         uniq.insert(offset);
     }
 
-    assert_eq!(
-        uniq.len(),
-        16,
-        "offset randomization must cover the entire scan window over 512 samples"
+    assert!(
+        uniq.len() >= 256,
+        "offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})",
+        uniq.len()
     );
 }
 

@@ -45,10 +45,10 @@ fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
         uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16));
     }
 
-    assert_eq!(
-        uniq.len(),
-        16,
-        "scan offset distribution collapsed unexpectedly across peer set"
+    assert!(
+        uniq.len() >= 512,
+        "scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})",
+        uniq.len()
     );
 }
 

@@ -108,6 +108,9 @@ fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
         let now = base + Duration::from_nanos(seed & 0x1fff);
 
         let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
-        assert!(offset < state_len.min(scan_limit));
+        assert!(
+            offset < state_len,
+            "scan offset must always remain inside state length"
+        );
     }
 }

@@ -1,7 +1,7 @@
 use super::*;
 use crate::crypto::{sha256, sha256_hmac, AesCtr};
 use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
-use rand::{RngExt, SeedableRng};
+use rand::{Rng, SeedableRng};
 use rand::rngs::StdRng;
 use std::collections::HashSet;
 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

@@ -0,0 +1,217 @@
+use super::*;
+use tokio::io::{AsyncWriteExt, duplex};
+use tokio::time::{Duration, Instant, timeout};
+
+fn make_self_target_config(
+    timing_normalization_enabled: bool,
+    floor_ms: u64,
+    ceiling_ms: u64,
+    beobachten_enabled: bool,
+) -> ProxyConfig {
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = beobachten_enabled;
+    config.general.beobachten_minutes = 5;
+    config.censorship.mask = true;
+    config.censorship.mask_unix_sock = None;
+    config.censorship.mask_host = Some("127.0.0.1".to_string());
+    config.censorship.mask_port = 443;
+    config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
+    config.censorship.mask_timing_normalization_floor_ms = floor_ms;
+    config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms;
+    config
+}
+
+async fn run_self_target_refusal(
+    config: ProxyConfig,
+    peer: SocketAddr,
+    initial: &'static [u8],
+) -> Duration {
+    let beobachten = BeobachtenStore::new();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
+
+    let (mut client, server) = duplex(1024);
+    let started = Instant::now();
+    let task = tokio::spawn(async move {
+        handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten)
+            .await;
+    });
+
+    client
+        .shutdown()
+        .await
+        .expect("client shutdown must succeed");
+
+    timeout(Duration::from_secs(3), task)
+        .await
+        .expect("self-target refusal must complete in bounded time")
+        .expect("self-target refusal task must not panic");
+
+    started.elapsed()
+}
+
+#[tokio::test]
+async fn positive_self_target_refusal_honors_normalization_floor() {
+    let config = make_self_target_config(true, 120, 120, false);
+    let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer");
+
+    let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
+
+    assert!(
+        elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260),
+        "normalized self-target refusal must stay within expected envelope"
+    );
+}
+
+#[tokio::test]
+async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() {
+    let config = make_self_target_config(false, 240, 240, false);
+    let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer");
+
+    let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
+
+    assert!(
+        elapsed < Duration::from_millis(180),
+        "non-normalized path must not inherit normalization floor delays"
+    );
+}
+
+#[tokio::test]
+async fn edge_ceiling_below_floor_uses_floor_fail_closed() {
+    let config = make_self_target_config(true, 140, 80, false);
+    let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer");
+
+    let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
+
+    assert!(
+        elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280),
+        "ceiling<floor must clamp to floor to preserve deterministic normalization"
+    );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn adversarial_blackhat_parallel_probes_remain_bounded_and_uniform() {
+    let workers = 24usize;
+    let mut tasks = Vec::with_capacity(workers);
+
+    for idx in 0..workers {
+        tasks.push(tokio::spawn(async move {
+            let cfg = make_self_target_config(true, 110, 140, false);
+            let peer: SocketAddr = format!("203.0.113.50:{}", 54100 + idx as u16)
+                .parse()
+                .expect("valid peer");
+            run_self_target_refusal(cfg, peer, b"GET /x HTTP/1.1\r\n\r\n").await
+        }));
+    }
+
+    let mut min = Duration::from_secs(60);
+    let mut max = Duration::from_millis(0);
+    for task in tasks {
+        let elapsed = task.await.expect("probe task must not panic");
+        if elapsed < min {
+            min = elapsed;
+        }
+        if elapsed > max {
+            max = elapsed;
+        }
+        assert!(
+            elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320),
+            "parallel probe latency must stay bounded under normalization"
+        );
+    }
+
+    assert!(
+        max.saturating_sub(min) <= Duration::from_millis(130),
+        "normalization should limit path variance across adversarial parallel probes"
+    );
+}
+
+#[tokio::test]
+async fn integration_beobachten_records_probe_classification_on_refusal() {
+    let config = make_self_target_config(false, 0, 0, true);
+    let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer");
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
+    let beobachten = BeobachtenStore::new();
+
+    let (mut client, server) = duplex(1024);
+    let task = tokio::spawn(async move {
+        handle_bad_client(
+            server,
+            tokio::io::sink(),
+            b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n",
+            peer,
+            local_addr,
+            &config,
+            &beobachten,
+        )
+        .await;
+
+        beobachten.snapshot_text(Duration::from_secs(60))
+    });
+
+    client
+        .shutdown()
+        .await
+        .expect("client shutdown must succeed");
+
+    let snapshot = timeout(Duration::from_secs(3), task)
+        .await
+        .expect("integration task must complete")
+        .expect("integration task must not panic");
+
+    assert!(snapshot.contains("[HTTP]"));
+    assert!(snapshot.contains("198.51.100.71-1"));
+}
+
+#[tokio::test]
+async fn light_fuzz_timing_configuration_matrix_is_bounded() {
+    let mut seed = 0xA17E_55AA_2026_0323u64;
+
+    for case in 0..48u64 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let enabled = (seed & 1) == 0;
+        let floor = (seed >> 8) % 180;
+        let ceiling = (seed >> 24) % 180;
+        let config = make_self_target_config(enabled, floor, ceiling, false);
+        let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16))
+            .parse()
+            .expect("valid peer");
+
+        let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await;
+
+        assert!(
+            elapsed < Duration::from_millis(420),
+            "fuzz case must stay bounded and never hang"
+        );
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() {
+    let workers = 64usize;
+    let mut tasks = Vec::with_capacity(workers);
+
+    for idx in 0..workers {
+        tasks.push(tokio::spawn(async move {
+            let config = make_self_target_config(false, 0, 0, false);
+            let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16)
+                .parse()
+                .expect("valid peer");
+            run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await
+        }));
+    }
+
+    timeout(Duration::from_secs(5), async {
+        for task in tasks {
+            let elapsed = task.await.expect("stress task must not panic");
+            assert!(
+                elapsed < Duration::from_millis(260),
+                "stress refusal must remain bounded without normalization"
+            );
+        }
+    })
+    .await
+    .expect("high-fanout refusal workload must complete without deadlock");
+}

@@ -0,0 +1,41 @@
+#![cfg(unix)]
+
+use super::*;
+use std::sync::{Mutex, OnceLock};
+use tokio::sync::Barrier;
+
+fn interface_cache_test_lock() -> &'static Mutex<()> {
+    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+    LOCK.get_or_init(|| Mutex::new(()))
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
+    let _guard = interface_cache_test_lock()
+        .lock()
+        .unwrap_or_else(|poison| poison.into_inner());
+    reset_local_interface_enumerations_for_tests();
+
+    let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
+    let workers = 32usize;
+    let barrier = std::sync::Arc::new(Barrier::new(workers));
+    let mut tasks = Vec::with_capacity(workers);
+
+    for _ in 0..workers {
+        let barrier = std::sync::Arc::clone(&barrier);
+        tasks.push(tokio::spawn(async move {
+            barrier.wait().await;
+            is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
+        }));
+    }
+
+    for task in tasks {
+        let _ = task.await.expect("parallel cache task must not panic");
+    }
+
+    assert_eq!(
+        local_interface_enumerations_for_tests(),
+        1,
+        "parallel cold misses must coalesce into a single interface enumeration"
+    );
+}

@@ -8,8 +8,8 @@ fn interface_cache_test_lock() -> &'static Mutex<()> {
     LOCK.get_or_init(|| Mutex::new(()))
 }
 
-#[test]
-fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() {
+#[tokio::test]
+async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() {
     let _guard = interface_cache_test_lock()
         .lock()
         .unwrap_or_else(|poison| poison.into_inner());

@@ -17,8 +17,8 @@ fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within
 
     let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
 
-    let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None);
-    let _ = is_mask_target_local_listener("127.0.0.1", 443, local_addr, None);
+    let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
+    let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
 
     assert_eq!(
         local_interface_enumerations_for_tests(),

@@ -27,15 +27,15 @@ fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within
     );
 }
 
-#[test]
-fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
+#[tokio::test]
+async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
     let _guard = interface_cache_test_lock()
         .lock()
         .unwrap_or_else(|poison| poison.into_inner());
     reset_local_interface_enumerations_for_tests();
 
     let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
-    let is_local = is_mask_target_local_listener("127.0.0.1", 8443, local_addr, None);
+    let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
 
     assert!(!is_local, "different port must not be treated as local listener");
     assert_eq!(

@ -0,0 +1,289 @@
|
|||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
const PROD_CAP_BYTES: usize = 5 * 1024 * 1024;
|
||||
|
||||
struct FinitePatternReader {
|
||||
remaining: usize,
|
||||
chunk: usize,
|
||||
read_calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl FinitePatternReader {
|
||||
fn new(total: usize, chunk: usize, read_calls: Arc<AtomicUsize>) -> Self {
|
||||
Self {
|
||||
remaining: total,
|
||||
chunk,
|
||||
read_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for FinitePatternReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
self.read_calls.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if self.remaining == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let take = self.remaining.min(self.chunk).min(buf.remaining());
|
||||
if take == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let fill = vec![0x5Au8; take];
|
||||
buf.put_slice(&fill);
|
||||
self.remaining -= take;
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct CountingWriter {
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl AsyncWrite for CountingWriter {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.written = self.written.saturating_add(buf.len());
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct NeverReadyReader;
|
||||
|
||||
impl AsyncRead for NeverReadyReader {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
struct BudgetProbeReader {
|
||||
remaining: usize,
|
||||
total_read: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl BudgetProbeReader {
|
||||
fn new(total: usize, total_read: Arc<AtomicUsize>) -> Self {
|
||||
Self {
|
||||
remaining: total,
|
||||
total_read,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for BudgetProbeReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
if self.remaining == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let take = self.remaining.min(buf.remaining());
|
||||
if take == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let fill = vec![0xA5u8; take];
|
||||
buf.put_slice(&fill);
|
||||
self.remaining -= take;
|
||||
self.total_read.fetch_add(take, Ordering::Relaxed);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
async fn positive_copy_with_production_cap_stops_exactly_at_budget() {
    let read_calls = Arc::new(AtomicUsize::new(0));
    let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls);
    let mut writer = CountingWriter::default();

    let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;

    assert_eq!(
        outcome.total, PROD_CAP_BYTES,
        "copy path must stop at explicit production cap"
    );
    assert_eq!(writer.written, PROD_CAP_BYTES);
    assert!(
        !outcome.ended_by_eof,
        "byte-cap stop must not be misclassified as EOF"
    );
}

#[tokio::test]
async fn negative_consume_with_zero_cap_performs_no_reads() {
    let read_calls = Arc::new(AtomicUsize::new(0));
    let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls));

    consume_client_data_with_timeout_and_cap(reader, 0).await;

    assert_eq!(
        read_calls.load(Ordering::Relaxed),
        0,
        "zero cap must return before reading attacker-controlled bytes"
    );
}

#[tokio::test]
async fn edge_copy_below_cap_reports_eof_without_overread() {
    let read_calls = Arc::new(AtomicUsize::new(0));
    let payload = 73 * 1024;
    let mut reader = FinitePatternReader::new(payload, 3072, read_calls);
    let mut writer = CountingWriter::default();

    let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;

    assert_eq!(outcome.total, payload);
    assert_eq!(writer.written, payload);
    assert!(
        outcome.ended_by_eof,
        "finite upstream below cap must terminate via EOF path"
    );
}

#[tokio::test]
async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() {
    let started = Instant::now();

    consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await;

    assert!(
        started.elapsed() < Duration::from_millis(350),
        "never-ready reader must be bounded by idle/relay timeout protections"
    );
}

#[tokio::test]
async fn integration_consume_path_honors_production_cap_for_large_payload() {
    let read_calls = Arc::new(AtomicUsize::new(0));
    let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls);

    let bounded = timeout(
        Duration::from_millis(350),
        consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES),
    )
    .await;

    assert!(
        bounded.is_ok(),
        "consume path with production cap must finish within bounded time"
    );
}

#[tokio::test]
async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() {
    let byte_cap = 5usize;
    let total_read = Arc::new(AtomicUsize::new(0));
    let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read));

    consume_client_data_with_timeout_and_cap(reader, byte_cap).await;

    assert!(
        total_read.load(Ordering::Relaxed) <= byte_cap,
        "consume path must not read more than configured byte cap"
    );
}
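
// The fuzz loop below uses a small xorshift step to derive deterministic but
// varied (cap, payload, chunk) triples from a fixed seed, so any failing case
// is reproducible without pulling in an external RNG dependency.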
#[tokio::test]
async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() {
    let mut seed = 0x1234_5678_9ABC_DEF0u64;

    for _case in 0..96u32 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let cap = ((seed & 0x3ffff) as usize).saturating_add(1);
        let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1);
        let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1);

        let read_calls = Arc::new(AtomicUsize::new(0));
        let mut reader = FinitePatternReader::new(payload, chunk, read_calls);
        let mut writer = CountingWriter::default();

        let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await;
        let expected = payload.min(cap);

        assert_eq!(
            outcome.total, expected,
            "copy total must match min(payload, cap) under fuzzed inputs"
        );
        assert_eq!(writer.written, expected);
        if payload <= cap {
            assert!(outcome.ended_by_eof);
        } else {
            assert!(!outcome.ended_by_eof);
        }
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() {
    let workers = 8usize;
    let mut tasks = Vec::with_capacity(workers);

    for idx in 0..workers {
        tasks.push(tokio::spawn(async move {
            let read_calls = Arc::new(AtomicUsize::new(0));
            let mut reader = FinitePatternReader::new(
                PROD_CAP_BYTES + (idx + 1) * 4096,
                4096 + (idx * 257),
                read_calls,
            );
            let mut writer = CountingWriter::default();
            copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await
        }));
    }

    timeout(Duration::from_secs(3), async {
        for task in tasks {
            let outcome = task.await.expect("stress task must not panic");
            assert_eq!(
                outcome.total, PROD_CAP_BYTES,
                "stress copy task must stay within production cap"
            );
            assert!(
                !outcome.ended_by_eof,
                "stress task should end due to cap, not EOF"
            );
        }
    })
    .await
    .expect("stress suite must complete in bounded time");
}
@ -12,71 +12,77 @@ fn closed_local_port() -> u16 {
 port
}

-#[test]
-fn self_target_detection_matches_literal_ipv4_listener() {
+#[tokio::test]
+async fn self_target_detection_matches_literal_ipv4_listener() {
     let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
-    assert!(is_mask_target_local_listener(
+    assert!(is_mask_target_local_listener_async(
         "198.51.100.40",
         443,
         local,
         None,
-    ));
+    )
+    .await);
 }

-#[test]
-fn self_target_detection_matches_bracketed_ipv6_listener() {
+#[tokio::test]
+async fn self_target_detection_matches_bracketed_ipv6_listener() {
     let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
-    assert!(is_mask_target_local_listener(
+    assert!(is_mask_target_local_listener_async(
         "[2001:db8::44]",
         8443,
         local,
         None,
-    ));
+    )
+    .await);
 }

-#[test]
-fn self_target_detection_keeps_same_ip_different_port_forwardable() {
+#[tokio::test]
+async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
     let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
-    assert!(!is_mask_target_local_listener(
+    assert!(!is_mask_target_local_listener_async(
         "203.0.113.44",
         8443,
         local,
         None,
-    ));
+    )
+    .await);
 }

-#[test]
-fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
+#[tokio::test]
+async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
     let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
-    assert!(is_mask_target_local_listener(
+    assert!(is_mask_target_local_listener_async(
         "::ffff:127.0.0.1",
         443,
         local,
         None,
-    ));
+    )
+    .await);
 }

-#[test]
-fn self_target_detection_unspecified_bind_blocks_loopback_target() {
+#[tokio::test]
+async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
     let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
-    assert!(is_mask_target_local_listener(
+    assert!(is_mask_target_local_listener_async(
         "127.0.0.1",
         443,
         local,
         None,
-    ));
+    )
+    .await);
 }

-#[test]
-fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
+#[tokio::test]
+async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
     let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
     let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
-    assert!(!is_mask_target_local_listener(
+    assert!(!is_mask_target_local_listener_async(
         "mask.example",
         443,
         local,
         Some(remote),
-    ));
+    )
+    .await);
 }

 #[tokio::test]
@ -0,0 +1,55 @@
#![cfg(unix)]

use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() {
    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_unix_sock = None;
    config.censorship.mask_host = Some("127.0.0.1".to_string());
    config.censorship.mask_port = 443;
    config.censorship.mask_timing_normalization_enabled = true;
    config.censorship.mask_timing_normalization_floor_ms = 120;
    config.censorship.mask_timing_normalization_ceiling_ms = 120;

    let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer");
    let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
    let beobachten = BeobachtenStore::new();

    let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
    let held_refresh_guard = refresh_lock.lock().await;

    let (mut client, server) = duplex(1024);
    let started = Instant::now();
    let task = tokio::spawn(async move {
        handle_bad_client(
            server,
            tokio::io::sink(),
            b"GET / HTTP/1.1\r\n\r\n",
            peer,
            local_addr,
            &config,
            &beobachten,
        )
        .await;
    });

    tokio::time::sleep(Duration::from_millis(80)).await;
    drop(held_refresh_guard);
    client.shutdown().await.expect("client shutdown must succeed");

    timeout(Duration::from_secs(2), task)
        .await
        .expect("task must finish in bounded time")
        .expect("task must not panic");
    let elapsed = started.elapsed();

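    // Expected window: ~80ms waiting on the held interface-refresh lock plus
    // the 120ms normalization floor is about 200ms end to end. The 180-350ms
    // bounds below therefore only hold if the floor timer starts after the
    // delayed lookup, which is the property this test pins down.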
    assert!(
        elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350),
        "timing normalization floor must start after pre-outcome self-target checks"
    );
}
@ -645,6 +645,75 @@ fn quota_exceeded_boundary_is_inclusive() {
    assert!(!quota_exceeded_for_user(&stats, user, Some(51)));
}

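// Parity invariant: checking a soft limit with an overshoot allowance must be
// equivalent to checking the plain helper against quota_soft_cap(quota,
// overshoot), i.e. the saturating sum quota + overshoot, across the whole
// (used, quota, overshoot, bytes) matrix below.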
#[test]
fn quota_soft_helper_matches_capped_generic_helper_matrix() {
    let stats = Stats::new();
    let user = "quota-soft-parity";

    for used in [0u64, 1, 7, 63, 127, 255] {
        stats.sub_user_octets_to(user, stats.get_user_total_octets(user));
        stats.add_user_octets_to(user, used);

        for quota in [8u64, 64, 128, 256] {
            for overshoot in [0u64, 1, 5, 32] {
                for bytes in [0u64, 1, 2, 7, 31, 64] {
                    let soft = quota_would_be_exceeded_for_user_soft(
                        &stats,
                        user,
                        Some(quota),
                        bytes,
                        overshoot,
                    );
                    let capped = quota_would_be_exceeded_for_user(
                        &stats,
                        user,
                        Some(quota_soft_cap(quota, overshoot)),
                        bytes,
                    );
                    assert_eq!(
                        soft, capped,
                        "soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}"
                    );
                }
            }
        }
    }
}

#[test]
fn quota_soft_helper_none_limit_never_rejects() {
    let stats = Stats::new();
    let user = "quota-soft-none";
    stats.add_user_octets_to(user, u64::MAX);

    assert!(!quota_would_be_exceeded_for_user_soft(
        &stats,
        user,
        None,
        u64::MAX,
        u64::MAX,
    ));
}

#[test]
fn quota_soft_cap_saturates_and_stays_fail_closed() {
    let stats = Stats::new();
    let user = "quota-soft-saturating";
    let quota = u64::MAX - 2;
    let overshoot = 100;

    assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX);

    stats.add_user_octets_to(user, u64::MAX - 1);
    assert!(quota_would_be_exceeded_for_user_soft(
        &stats,
        user,
        Some(quota),
        2,
        overshoot,
    ));
}

#[tokio::test]
async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
    let (tx, mut rx) = mpsc::channel::<C2MeCommand>(4);
@ -0,0 +1,295 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::io::AsyncWrite;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio::time::{Duration, timeout};

fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

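// BlockingWrite is a one-shot gate: the first poll_write records that it was
// entered, parks the task by stashing the waker, and stays Pending until the
// test flips `released` and wakes it. This freezes a writer mid-write so the
// tests can observe which locks are (and are not) held across that await.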
#[derive(Default)]
struct BlockingWriteState {
    write_entered: AtomicBool,
    released: AtomicBool,
    write_waker: Mutex<Option<Waker>>,
    write_entered_notify: Notify,
}

struct BlockingWrite {
    state: Arc<BlockingWriteState>,
}

impl BlockingWrite {
    fn new(state: Arc<BlockingWriteState>) -> Self {
        Self { state }
    }
}

impl AsyncWrite for BlockingWrite {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        self.state.write_entered.store(true, Ordering::Release);
        self.state.write_entered_notify.notify_waiters();

        if self.state.released.load(Ordering::Acquire) {
            return Poll::Ready(Ok(buf.len()));
        }

        if let Ok(mut slot) = self.state.write_waker.lock() {
            *slot = Some(cx.waker().clone());
        }

        Poll::Pending
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
    for _ in 0..8 {
        if state.write_entered.load(Ordering::Acquire) {
            return;
        }
        let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
    }

    panic!("blocking writer did not enter poll_write in bounded time");
}

fn release_blocking_write(state: &Arc<BlockingWriteState>) {
    state.released.store(true, Ordering::Release);
    if let Ok(mut slot) = state.write_waker.lock()
        && let Some(waker) = slot.take()
    {
        waker.wake();
    }
}
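
// `released` is stored with Release ordering before the stashed waker is taken
// and woken, pairing with the Acquire load in poll_write, so a woken writer is
// guaranteed to observe the release and complete its write.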

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() {
    let stats = Arc::new(Stats::new());
    let user = format!("middle-cross-release-regression-{}", std::process::id());
    let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user));
    let bytes_me2c = Arc::new(AtomicU64::new(0));
    let writer_state = Arc::new(BlockingWriteState::default());

    let first = {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let cross_mode_lock = Arc::clone(&cross_mode_lock);
        let bytes_me2c = Arc::clone(&bytes_me2c);
        let writer_state = Arc::clone(&writer_state);
        tokio::spawn(async move {
            let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
            let mut frame_buf = Vec::new();
            process_me_writer_response_with_cross_mode_lock(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats.as_ref(),
                &user,
                Some(4),
                0,
                Some(&cross_mode_lock),
                bytes_me2c.as_ref(),
                41_000,
                false,
                false,
            )
            .await
        })
    };

    wait_until_blocking_write_entered(&writer_state).await;

    let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
        .await
        .expect("cross-mode lock must be released while first write is pending");
    drop(guard);

    let second = {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let cross_mode_lock = Arc::clone(&cross_mode_lock);
        let bytes_me2c = Arc::clone(&bytes_me2c);
        tokio::spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            timeout(
                Duration::from_millis(150),
                process_me_writer_response_with_cross_mode_lock(
                    MeResponse::Data {
                        flags: 0,
                        data: Bytes::from_static(&[0xEE]),
                    },
                    &mut writer,
                    ProtoTag::Intermediate,
                    &SecureRandom::new(),
                    &mut frame_buf,
                    stats.as_ref(),
                    &user,
                    Some(4),
                    0,
                    Some(&cross_mode_lock),
                    bytes_me2c.as_ref(),
                    41_001,
                    false,
                    false,
                ),
            )
            .await
        })
    };

    let second_result = second
        .await
        .expect("second task must not panic")
        .expect("second write must not block on cross-mode lock");
    assert!(
        matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })),
        "second write must fail closed due to first write reservation"
    );

    release_blocking_write(&writer_state);

    let first_result = timeout(Duration::from_millis(300), first)
        .await
        .expect("first task timed out")
        .expect("first task must not panic");
    assert!(first_result.is_ok());

    assert_eq!(stats.get_user_total_octets(&user), 4);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() {
    let stats = Arc::new(Stats::new());
    let user = format!("middle-cross-release-stress-{}", std::process::id());
    let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user));
    let bytes_me2c = Arc::new(AtomicU64::new(0));
    let writer_state = Arc::new(BlockingWriteState::default());

    let first = {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let cross_mode_lock = Arc::clone(&cross_mode_lock);
        let bytes_me2c = Arc::clone(&bytes_me2c);
        let writer_state = Arc::clone(&writer_state);
        tokio::spawn(async move {
            let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
            let mut frame_buf = Vec::new();
            process_me_writer_response_with_cross_mode_lock(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0x01, 0x02]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats.as_ref(),
                &user,
                Some(3),
                0,
                Some(&cross_mode_lock),
                bytes_me2c.as_ref(),
                41_100,
                false,
                false,
            )
            .await
        })
    };

    wait_until_blocking_write_entered(&writer_state).await;

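    // The first (still-blocked) write has already reserved 2 of the 3-byte
    // quota, so exactly one of the 48 one-byte waiters below can win the
    // remaining byte; every other waiter must fail closed immediately instead
    // of queueing behind the pending write.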
    let mut set = JoinSet::new();
    for idx in 0..48u64 {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let cross_mode_lock = Arc::clone(&cross_mode_lock);
        let bytes_me2c = Arc::clone(&bytes_me2c);
        set.spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            timeout(
                Duration::from_millis(200),
                process_me_writer_response_with_cross_mode_lock(
                    MeResponse::Data {
                        flags: 0,
                        data: Bytes::from_static(&[0x10]),
                    },
                    &mut writer,
                    ProtoTag::Intermediate,
                    &SecureRandom::new(),
                    &mut frame_buf,
                    stats.as_ref(),
                    &user,
                    Some(3),
                    0,
                    Some(&cross_mode_lock),
                    bytes_me2c.as_ref(),
                    41_200 + idx,
                    false,
                    false,
                ),
            )
            .await
        });
    }

    let mut ok = 0usize;
    let mut quota_exceeded = 0usize;
    while let Some(done) = set.join_next().await {
        let timed = done.expect("waiter task must not panic");
        let result = timed.expect("waiter must not block behind pending first write");
        match result {
            Ok(_) => ok += 1,
            Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1,
            Err(other) => panic!("unexpected error in waiter: {other:?}"),
        }
    }

    assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota");
    assert_eq!(quota_exceeded, 47);

    release_blocking_write(&writer_state);

    let first_result = timeout(Duration::from_millis(300), first)
        .await
        .expect("first task timed out")
        .expect("first task must not panic");
    assert!(first_result.is_ok());

    assert_eq!(stats.get_user_total_octets(&user), 3);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3);
}
@ -0,0 +1,116 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};

fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

fn lookup_counter_test_lock() -> &'static Mutex<()> {
    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    LOCK.get_or_init(|| Mutex::new(()))
}
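
// The registry lookup counter is process-global, so tests that reset and then
// read it must not interleave; this static mutex serializes them, and
// recovering a poisoned guard keeps later tests usable if an earlier one
// panicked while holding it.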

#[tokio::test]
async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Stats::new();
    let user = format!("middle-cross-mode-lookup-{}", std::process::id());
    let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);

    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    for idx in 0..8u64 {
        let outcome = process_me_writer_response_with_cross_mode_lock(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0xAB]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(1024),
            0,
            Some(&cross_mode_lock),
            &bytes_me2c,
            20_000 + idx,
            false,
            false,
        )
        .await;

        assert!(outcome.is_ok());
    }

    assert_eq!(
        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
        0,
        "prefetched lock path must not re-query lock registry per frame"
    );
    assert_eq!(stats.get_user_total_octets(&user), 8);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8);
}

#[tokio::test]
async fn control_without_prefetched_lock_still_uses_registry_lookup_path() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Stats::new();
    let user = format!("middle-cross-mode-lookup-control-{}", std::process::id());

    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let outcome = process_me_writer_response_with_cross_mode_lock(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[0xCD]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        &user,
        Some(1024),
        0,
        None,
        &bytes_me2c,
        20_100,
        false,
        false,
    )
    .await;

    assert!(outcome.is_ok());
    assert_eq!(
        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
        1,
        "fallback path without prefetched lock should perform a registry lookup"
    );
}
@ -0,0 +1,376 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::time::{Duration, timeout};

fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

#[tokio::test]
async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() {
    let stats = Stats::new();
    let user = format!("middle-cross-matrix-positive-{}", std::process::id());
    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let result = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[1, 2, 3, 4]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        &user,
        Some(128),
        0,
        &bytes_me2c,
        10_001,
        false,
        false,
    )
    .await;

    assert!(result.is_ok());
    assert_eq!(stats.get_user_total_octets(&user), 4);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
}

#[tokio::test]
async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() {
    let stats = Stats::new();
    let user = format!("middle-cross-matrix-negative-{}", std::process::id());
    let held = cross_mode_quota_user_lock_for_tests(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold lock before ME->C call");

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let blocked = timeout(
        Duration::from_millis(25),
        process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0x41]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(256),
            0,
            &bytes_me2c,
            10_002,
            false,
            false,
        ),
    )
    .await;

    assert!(blocked.is_err());
    drop(held_guard);
}

#[tokio::test]
async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() {
    let stats = Stats::new();
    let user = format!("middle-cross-matrix-edge-none-{}", std::process::id());
    let held = cross_mode_quota_user_lock_for_tests(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold lock while quota is disabled");

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let outcome = timeout(
        Duration::from_millis(80),
        process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0x11, 0x22]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            None,
            0,
            &bytes_me2c,
            10_003,
            false,
            false,
        ),
    )
    .await
    .expect("quota-none path must not wait on cross-mode lock");

    assert!(outcome.is_ok());
    drop(held_guard);
}

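// 256 parallel one-byte writes race against a 64-byte limit below; the
// hard-cap invariant is that exactly 64 of them may succeed and that the
// stats total and the forensics byte counter agree with the limit afterwards.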
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() {
    let stats = Arc::new(Stats::new());
    let user = format!("middle-cross-matrix-adversarial-{}", std::process::id());
    let limit = 64u64;
    let bytes_me2c = Arc::new(AtomicU64::new(0));
    let mut tasks = Vec::new();

    for idx in 0..256u64 {
        let stats = Arc::clone(&stats);
        let bytes_me2c = Arc::clone(&bytes_me2c);
        let user = user.clone();
        tasks.push(tokio::spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            process_me_writer_response(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0xEE]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats.as_ref(),
                &user,
                Some(limit),
                0,
                bytes_me2c.as_ref(),
                11_000 + idx,
                false,
                false,
            )
            .await
        }));
    }

    let mut ok = 0usize;
    for task in tasks {
        match task.await.expect("task must not panic") {
            Ok(_) => ok += 1,
            Err(ProxyError::DataQuotaExceeded { .. }) => {}
            Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"),
        }
    }

    assert_eq!(ok, limit as usize);
    assert_eq!(stats.get_user_total_octets(&user), limit);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() {
    let user = format!("middle-cross-matrix-integration-{}", std::process::id());
    let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user);
    let middle_lock = cross_mode_quota_user_lock_for_tests(&user);
    assert!(
        Arc::ptr_eq(&relay_lock, &middle_lock),
        "relay and middle-relay must share the same cross-mode lock identity"
    );

    let held_guard = relay_lock
        .try_lock()
        .expect("test must hold shared cross-mode lock");

    let stats = Stats::new();
    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let middle_blocked = timeout(
        Duration::from_millis(25),
        process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0x92]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(1024),
            0,
            &bytes_me2c,
            12_001,
            false,
            false,
        ),
    )
    .await;
    assert!(middle_blocked.is_err());

    drop(held_guard);

    let middle_ready = timeout(
        Duration::from_millis(250),
        process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0x94]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(1024),
            0,
            &bytes_me2c,
            12_002,
            false,
            false,
        ),
    )
    .await
    .expect("middle path must complete after release");

    assert!(middle_ready.is_ok());
}

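// About a quarter of the fuzz rounds below (seed & 0x03 == 0) run with the
// user's cross-mode lock held; those rounds must both time out and leave the
// accounting untouched, while unheld rounds must complete and keep the byte
// counter equal to the stats total.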
#[tokio::test]
async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() {
    let stats = Stats::new();
    let user = format!("middle-cross-matrix-fuzz-{}", std::process::id());
    let bytes_me2c = AtomicU64::new(0);
    let mut seed = 0xC0DE_1234_55AA_9988u64;

    for case in 0..96u32 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let hold = (seed & 0x03) == 0;
        let mut held_lock = None;
        let maybe_guard = if hold {
            held_lock = Some(cross_mode_quota_user_lock_for_tests(&user));
            Some(
                held_lock
                    .as_ref()
                    .expect("held lock should be present")
                    .try_lock()
                    .expect("cross-mode lock should be acquirable in fuzz round"),
            )
        } else {
            None
        };

        let payload_len = ((seed >> 8) as usize % 8) + 1;
        let payload = vec![(seed & 0xff) as u8; payload_len];
        let before = stats.get_user_total_octets(&user);
        let mut writer = make_crypto_writer(tokio::io::sink());
        let mut frame_buf = Vec::new();

        let timed = timeout(
            Duration::from_millis(20),
            process_me_writer_response(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from(payload),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                &stats,
                &user,
                Some(1024),
                0,
                &bytes_me2c,
                13_000 + case as u64,
                false,
                false,
            ),
        )
        .await;

        if hold {
            assert!(timed.is_err(), "held-lock fuzz round must block within timeout");
            assert_eq!(stats.get_user_total_octets(&user), before);
        } else {
            let done = timed.expect("unheld fuzz round must complete in time");
            assert!(done.is_ok());
        }

        drop(maybe_guard);
        drop(held_lock);
        assert_eq!(bytes_me2c.load(Ordering::Relaxed), stats.get_user_total_octets(&user));
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() {
    let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id());
    let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id());

    let held = cross_mode_quota_user_lock_for_tests(&held_user);
    let held_guard = held
        .try_lock()
        .expect("test must hold lock for blocked user");

    let mut tasks = Vec::new();
    for idx in 0..64u64 {
        let user = free_user.clone();
        tasks.push(tokio::spawn(async move {
            let stats = Stats::new();
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            let bytes_me2c = AtomicU64::new(0);
            process_me_writer_response(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0xA0]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                &stats,
                &user,
                Some(1),
                0,
                &bytes_me2c,
                14_000 + idx,
                false,
                false,
            )
            .await
        }));
    }

    timeout(Duration::from_secs(2), async {
        for task in tasks {
            let done = task.await.expect("free-user task must not panic");
            assert!(done.is_ok());
        }
    })
    .await
    .expect("free-user tasks should complete without waiting for held user's lock");

    drop(held_guard);
}
@ -0,0 +1,254 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::io::AsyncWrite;
use tokio::sync::Notify;
use tokio::time::{Duration, timeout};

fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

#[derive(Default)]
struct BlockingWriteState {
    write_entered: AtomicBool,
    released: AtomicBool,
    write_waker: Mutex<Option<Waker>>,
    write_entered_notify: Notify,
}

struct BlockingWrite {
    state: Arc<BlockingWriteState>,
}

impl BlockingWrite {
    fn new(state: Arc<BlockingWriteState>) -> Self {
        Self { state }
    }
}

impl AsyncWrite for BlockingWrite {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        self.state.write_entered.store(true, Ordering::Release);
        self.state.write_entered_notify.notify_waiters();

        if self.state.released.load(Ordering::Acquire) {
            return Poll::Ready(Ok(buf.len()));
        }

        if let Ok(mut slot) = self.state.write_waker.lock() {
            *slot = Some(cx.waker().clone());
        }
        Poll::Pending
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
    for _ in 0..8 {
        if state.write_entered.load(Ordering::Acquire) {
            return;
        }
        let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
    }

    panic!("blocking writer did not enter poll_write in bounded time");
}

fn release_blocking_write(state: &Arc<BlockingWriteState>) {
    state.released.store(true, Ordering::Release);
    if let Ok(mut slot) = state.write_waker.lock()
        && let Some(waker) = slot.take()
    {
        waker.wake();
    }
}

#[tokio::test]
async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() {
    let stats = Stats::new();
    let user = format!("middle-me2c-cross-mode-held-{}", std::process::id());
    let held = cross_mode_quota_user_lock_for_tests(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold shared cross-mode lock before ME->C write path");

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let blocked = timeout(
        Duration::from_millis(25),
        process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0x41]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(1024),
            0,
            &bytes_me2c,
            9901,
            false,
            false,
        ),
    )
    .await;

    assert!(
        blocked.is_err(),
        "ME->C quota reservation path must be serialized by held shared cross-mode lock"
    );

    drop(held_guard);

    let released = timeout(
        Duration::from_millis(250),
        process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0x42]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(1024),
            0,
            &bytes_me2c,
            9902,
            false,
            false,
        ),
    )
    .await
    .expect("ME->C write must complete after cross-mode lock release");

    assert!(released.is_ok());
}

#[tokio::test]
async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() {
    let stats = Stats::new();
    let user = format!("middle-me2c-cross-mode-free-{}", std::process::id());
    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let outcome = timeout(
        Duration::from_millis(250),
        process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0x55, 0x66]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(1024),
            0,
            &bytes_me2c,
            9903,
            false,
            false,
        ),
    )
    .await
    .expect("uncontended ME->C path should not stall");

    assert!(outcome.is_ok());
    assert_eq!(stats.get_user_total_octets(&user), 2);
    assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2);
}

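// Regression target: the quota reservation critical section must end before
// the ME->C write itself is awaited. If the lock were held across the write, a
// stalled client could pin the user's cross-mode lock indefinitely; the test
// freezes the write mid-flight and proves the lock is already free again.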
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() {
    let stats = Arc::new(Stats::new());
    let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id());
    let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user);
    let bytes_me2c = Arc::new(AtomicU64::new(0));
    let writer_state = Arc::new(BlockingWriteState::default());

    let worker = {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let cross_mode_lock = Arc::clone(&cross_mode_lock);
        let bytes_me2c = Arc::clone(&bytes_me2c);
        let writer_state = Arc::clone(&writer_state);
        tokio::spawn(async move {
            let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
            let mut frame_buf = Vec::new();
            let rng = SecureRandom::new();
            process_me_writer_response_with_cross_mode_lock(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &rng,
                &mut frame_buf,
                stats.as_ref(),
                &user,
                Some(1024),
                0,
                Some(&cross_mode_lock),
                bytes_me2c.as_ref(),
                9910,
                false,
                false,
            )
            .await
        })
    };

    wait_until_blocking_write_entered(&writer_state).await;

    let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
        .await
        .expect("cross-mode lock must be free while ME->C write is pending");
    drop(acquired_guard);

    release_blocking_write(&writer_state);

    let result = timeout(Duration::from_millis(300), worker)
        .await
        .expect("ME->C worker timed out after releasing blocking writer")
        .expect("ME->C worker must not panic");

    assert!(result.is_ok());
    assert_eq!(stats.get_user_total_octets(&user), 4);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
}
@ -128,6 +128,7 @@ async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection()
        &stats,
        user,
        quota_limit,
        0,
        &bytes_me2c,
        7001,
        false,
@ -167,6 +168,7 @@ async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection()
        &stats_fast,
        user,
        quota_limit,
        0,
        &bytes_fast,
        7002,
        false,
@ -208,6 +210,7 @@ async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_byt
        &stats,
        user,
        Some(64),
        0,
        &bytes_me2c,
        7003,
        false,
@ -0,0 +1,372 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use tokio::sync::Mutex as AsyncMutex;
use tokio::task::JoinSet;
use tokio::time::{Duration, timeout};

fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

fn lookup_test_lock() -> &'static Mutex<()> {
    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    LOCK.get_or_init(|| Mutex::new(()))
}

#[tokio::test]
async fn positive_me2c_quota_counts_bytes_exactly_once() {
    let _guard = lookup_test_lock().lock().unwrap();
    let stats = Stats::new();
    let user = format!("quota-middle-ext-positive-{}", std::process::id());
    let lock = Arc::new(AsyncMutex::new(()));

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let result = process_me_writer_response_with_cross_mode_lock(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[1, 2, 3, 4, 5]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        &user,
        Some(64),
        0,
        Some(&lock),
        &bytes_me2c,
        70_001,
        false,
        false,
    )
    .await;

    assert!(result.is_ok());
    assert_eq!(stats.get_user_total_octets(&user), 5);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5);
}

#[tokio::test]
async fn negative_held_crossmode_lock_blocks_me2c_write() {
    let _guard = lookup_test_lock().lock().unwrap();
    let stats = Stats::new();
    let user = format!("quota-middle-ext-negative-{}", std::process::id());

    let lock = Arc::new(AsyncMutex::new(()));
    let _held = lock.try_lock().expect("lock must be held");

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let blocked = timeout(
        Duration::from_millis(25),
        process_me_writer_response_with_cross_mode_lock(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0xFE]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(16),
            0,
            Some(&lock),
            &bytes_me2c,
            70_101,
            false,
            false,
        ),
    )
    .await;

    assert!(blocked.is_err());
    assert_eq!(stats.get_user_total_octets(&user), 0);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}

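// Fail-closed edge case: even a zero-byte frame against a zero quota is
// rejected as DataQuotaExceeded rather than treated as a harmless no-op, so an
// exhausted user cannot keep the write path alive with empty frames.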
#[tokio::test]
async fn edge_zero_quota_zero_payload_is_fail_closed() {
    let _guard = lookup_test_lock().lock().unwrap();
    let stats = Stats::new();
    let user = format!("quota-middle-ext-edge-{}", std::process::id());

    let lock = Arc::new(AsyncMutex::new(()));
    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let result = process_me_writer_response_with_cross_mode_lock(
        MeResponse::Data {
            flags: 0,
            data: Bytes::new(),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        &user,
        Some(0),
        0,
        Some(&lock),
        &bytes_me2c,
        70_201,
        false,
        false,
    )
    .await;

    assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
    assert_eq!(stats.get_user_total_octets(&user), 0);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() {
    let _guard = lookup_test_lock().lock().unwrap();
    let stats = Arc::new(Stats::new());
    let user = format!("quota-middle-ext-blackhat-{}", std::process::id());
    let quota = 64u64;
    let lock = Arc::new(AsyncMutex::new(()));
    let bytes_me2c = Arc::new(AtomicU64::new(0));

    let mut set = JoinSet::new();
    for i in 0..256u64 {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let lock = Arc::clone(&lock);
        let bytes_me2c = Arc::clone(&bytes_me2c);

        set.spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize];

            process_me_writer_response_with_cross_mode_lock(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from(payload),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats.as_ref(),
                &user,
                Some(quota),
                0,
                Some(&lock),
                bytes_me2c.as_ref(),
                70_301 + i,
                false,
                false,
            )
            .await
        });
    }

    let mut succeeded = 0usize;
    while let Some(done) = set.join_next().await {
        match done.expect("task must not panic") {
            Ok(_) => succeeded += 1,
            Err(ProxyError::DataQuotaExceeded { .. }) => {}
            Err(other) => panic!("unexpected error {other:?}"),
        }
    }

    assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed));
    assert!(stats.get_user_total_octets(&user) <= quota);
    assert!(succeeded <= quota as usize);
}

#[tokio::test]
async fn integration_shared_prefetched_lock_blocks_then_releases_writer() {
    let stats = Stats::new();
    let user = format!("quota-middle-ext-integration-{}", std::process::id());
    let lock = Arc::new(AsyncMutex::new(()));
    let held = lock
        .try_lock()
        .expect("integration test must hold prefetched lock first");

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let blocked = timeout(
        Duration::from_millis(25),
        process_me_writer_response_with_cross_mode_lock(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0xA1]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(8),
            0,
            Some(&lock),
            &bytes_me2c,
            70_360,
            false,
            false,
        ),
    )
    .await;
    assert!(blocked.is_err());

    drop(held);

    let after_release = timeout(
        Duration::from_millis(150),
        process_me_writer_response_with_cross_mode_lock(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0xA2]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(8),
            0,
            Some(&lock),
            &bytes_me2c,
            70_361,
            false,
            false,
        ),
    )
    .await
    .expect("writer should progress once the shared lock is released");

    assert!(after_release.is_ok());
}

#[tokio::test]
async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() {
    let _guard = lookup_test_lock().lock().unwrap();
    let stats = Stats::new();
    let user = format!("quota-middle-ext-fuzz-{}", std::process::id());
    let mut seed = 0xCAFE_BABE_1234u64;
    let bytes_me2c = AtomicU64::new(0);

    for case in 0..48u32 {
        seed ^= seed << 5;
        seed ^= seed >> 12;
        seed ^= seed << 13;
        let hold = (seed & 0x1) == 0;

        let lock = Arc::new(AsyncMutex::new(()));
        let maybe_guard = if hold {
            Some(lock.try_lock().unwrap())
        } else {
            None
        };

        let mut writer = make_crypto_writer(tokio::io::sink());
        let mut frame_buf = Vec::new();

        let result = timeout(
            Duration::from_millis(30),
            process_me_writer_response_with_cross_mode_lock(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                &stats,
                &user,
                Some(128),
                0,
                Some(&lock),
                &bytes_me2c,
                70_401 + case as u64,
                false,
                false,
            ),
        )
        .await;

        if hold {
            assert!(result.is_err());
        } else {
            assert!(result.unwrap().is_ok());
        }

        drop(maybe_guard);
    }
}

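// Liveness check: while one user's lock stays held for the entire test, 48
// writes for unrelated users (each with its own lock) must all finish inside
// the 2s budget, proving the cross-mode locks are per-user rather than global.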
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() {
    let _guard = lookup_test_lock().lock().unwrap();
    let held = Arc::new(AsyncMutex::new(()));
    let _held_guard = held.try_lock().unwrap();

    let mut set = JoinSet::new();
    for i in 0..48u64 {
        set.spawn(async move {
            let stats = Stats::new();
            let user = format!("quota-middle-ext-stress-free-{i}");
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            let bytes_me2c = AtomicU64::new(0);
            let free_lock = Arc::new(AsyncMutex::new(()));

            process_me_writer_response_with_cross_mode_lock(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0xEE]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                &stats,
                &user,
                Some(1),
                0,
                Some(&free_lock),
                &bytes_me2c,
                70_500 + i,
                false,
                false,
            )
            .await
        });
    }

    timeout(Duration::from_secs(2), async {
        while let Some(task) = set.join_next().await {
            task.unwrap().unwrap();
        }
    })
    .await
    .unwrap();
}
@ -5,6 +5,8 @@ use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use tokio::io::AsyncWrite;
use tokio::task::JoinSet;

fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
@ -16,6 +18,77 @@ where
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

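// Two fault-injection writers drive the rollback paths below: FailingWriter
// fails its very first poll_write, while FailAfterBudgetWriter accepts (and
// counts) bytes until its budget is spent and only then errors, exercising
// rollback after a partial short write.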
struct FailingWriter;
|
||||
|
||||
impl AsyncWrite for FailingWriter {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<std::result::Result<usize, std::io::Error>> {
|
||||
Poll::Ready(Err(std::io::Error::other("forced writer failure")))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), std::io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), std::io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct FailAfterBudgetWriter {
|
||||
remaining: usize,
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl FailAfterBudgetWriter {
|
||||
fn new(remaining: usize) -> Self {
|
||||
Self {
|
||||
remaining,
|
||||
written: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for FailAfterBudgetWriter {
|
||||
fn poll_write(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::result::Result<usize, std::io::Error>> {
|
||||
if self.remaining == 0 {
|
||||
return Poll::Ready(Err(std::io::Error::other("forced short-write exhaustion")));
|
||||
}
|
||||
|
||||
let n = self.remaining.min(buf.len());
|
||||
self.remaining -= n;
|
||||
self.written += n;
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), std::io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::result::Result<(), std::io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
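Reviewer note: FailAfterBudgetWriter accepts up to `remaining` bytes (possibly as short writes) and then errors, which is how the rollback tests below force an I/O failure mid-frame. A hedged usage sketch (the `demo` helper is illustrative and not part of the patch):

// Illustrative only: a 3-byte budget accepts 3 bytes, then fails.
use tokio::io::AsyncWriteExt;

async fn demo() -> std::io::Result<()> {
    let mut w = FailAfterBudgetWriter::new(3);
    w.write_all(&[1, 2, 3]).await?; // consumes the whole budget
    assert!(w.write_all(&[4]).await.is_err()); // budget exhausted => forced error
    Ok(())
}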
#[tokio::test]
async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() {
    let stats = Stats::new();

@@ -38,6 +111,7 @@ async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() {
        &stats,
        user,
        Some(8),
        0,
        &bytes_me2c,
        7101,
        false,

@@ -62,6 +136,7 @@ async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() {
        &stats,
        user,
        Some(8),
        0,
        &bytes_me2c,
        7102,
        false,

@@ -105,6 +180,7 @@ async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_count
            stats_ref.as_ref(),
            &user_owned,
            Some(quota_limit),
            0,
            bytes_ref.as_ref(),
            7200 + idx,
            false,

@@ -171,6 +247,7 @@ async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency()
            &stats,
            user,
            Some(quota_limit),
            0,
            &bytes_me2c,
            7300 + conn,
            false,
@@ -190,3 +267,800 @@ async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency()
    assert!(total <= quota_limit);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
}

#[tokio::test]
async fn positive_soft_overshoot_allows_burst_inside_soft_cap_then_blocks() {
    let stats = Stats::new();
    let user = "soft-cap-boundary-user";
    let bytes_me2c = AtomicU64::new(0);
    let quota_limit = 10u64;
    let overshoot = 3u64;

    stats.add_user_octets_from(user, 10);

    let mut writer_one = make_crypto_writer(tokio::io::sink());
    let mut frame_buf_one = Vec::new();
    let first = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[1, 2, 3]),
        },
        &mut writer_one,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf_one,
        &stats,
        user,
        Some(quota_limit),
        overshoot,
        &bytes_me2c,
        7401,
        false,
        false,
    )
    .await;
    assert!(first.is_ok(), "soft-cap buffer should allow reaching limit+overshoot");
    assert_eq!(stats.get_user_total_octets(user), 13);

    let mut writer_two = make_crypto_writer(tokio::io::sink());
    let mut frame_buf_two = Vec::new();
    let second = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[9]),
        },
        &mut writer_two,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf_two,
        &stats,
        user,
        Some(quota_limit),
        overshoot,
        &bytes_me2c,
        7402,
        false,
        false,
    )
    .await;
    assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. })));
    assert_eq!(stats.get_user_total_octets(user), 13);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3);
}
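Reviewer note: the boundary arithmetic these soft-cap tests pin down is `used + payload <= limit + overshoot`. A minimal sketch of that admission predicate, assuming this is the intended semantics (illustrative helper, not the patch's actual implementation):

// Illustrative only: fail-closed soft-cap admission check.
fn soft_cap_allows(used: u64, payload: u64, limit: u64, overshoot: u64) -> bool {
    used.saturating_add(payload) <= limit.saturating_add(overshoot)
}

fn main() {
    // Mirrors the test above: 10 octets used, limit 10, overshoot 3.
    assert!(soft_cap_allows(10, 3, 10, 3)); // burst to 13 is allowed
    assert!(!soft_cap_allows(13, 1, 10, 3)); // one more byte is blocked
}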
#[tokio::test]
async fn negative_soft_overshoot_rejects_when_payload_exceeds_remaining_soft_budget() {
    let stats = Stats::new();
    let user = "soft-cap-remaining-user";
    let bytes_me2c = AtomicU64::new(0);
    let quota_limit = 10u64;
    let overshoot = 4u64;

    stats.add_user_octets_from(user, 12);

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let result = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[1, 2, 3]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(quota_limit),
        overshoot,
        &bytes_me2c,
        7501,
        false,
        false,
    )
    .await;

    assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
    assert_eq!(stats.get_user_total_octets(user), 12);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}

#[tokio::test]
async fn negative_write_failure_rolls_back_reservation_under_soft_cap_mode() {
    let stats = Stats::new();
    let user = "soft-cap-rollback-user";
    let bytes_me2c = AtomicU64::new(0);
    let mut writer = make_crypto_writer(FailingWriter);
    let mut frame_buf = Vec::new();

    stats.add_user_octets_from(user, 9);

    let result = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[1, 2, 3]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(10),
        8,
        &bytes_me2c,
        7601,
        false,
        false,
    )
    .await;

    assert!(matches!(result, Err(ProxyError::Io(_))));
    assert_eq!(stats.get_user_total_octets(user), 9);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_parallel_soft_cap_stress_never_exceeds_soft_limit() {
    let stats = Arc::new(Stats::new());
    let user = "soft-cap-stress-user";
    let quota_limit = 40u64;
    let overshoot = 5u64;
    let soft_limit = quota_limit + overshoot;
    let bytes_me2c = Arc::new(AtomicU64::new(0));
    let mut tasks = JoinSet::new();

    for idx in 0..256u64 {
        let user_owned = user.to_string();
        let stats_ref = Arc::clone(&stats);
        let bytes_ref = Arc::clone(&bytes_me2c);
        tasks.spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            process_me_writer_response(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0x42]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats_ref.as_ref(),
                &user_owned,
                Some(quota_limit),
                overshoot,
                bytes_ref.as_ref(),
                7700 + idx,
                false,
                false,
            )
            .await
        });
    }

    while let Some(joined) = tasks.join_next().await {
        match joined.expect("soft-cap stress task must not panic") {
            Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {}
            Err(other) => panic!("unexpected error in soft-cap stress case: {other:?}"),
        }
    }

    let total = stats.get_user_total_octets(user);
    assert!(total <= soft_limit, "soft-cap stress must never overshoot soft limit");
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
}

#[tokio::test]
async fn light_fuzz_soft_cap_matrix_keeps_counters_and_limits_consistent() {
    let stats = Stats::new();
    let user = "soft-cap-fuzz-user";
    let bytes_me2c = AtomicU64::new(0);
    let mut seed = 0x9E37_79B9_7F4A_7C15u64;

    for conn in 0..1024u64 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let quota_limit = 32 + (seed & 0x3f);
        let overshoot = seed.rotate_left(13) & 0x0f;
        let len = ((seed >> 3) & 0x07) + 1;
        let payload = vec![0xA5; len as usize];
        let before = stats.get_user_total_octets(user);

        let mut writer = make_crypto_writer(tokio::io::sink());
        let mut frame_buf = Vec::new();
        let result = process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from(payload),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            user,
            Some(quota_limit),
            overshoot,
            &bytes_me2c,
            7800 + conn,
            false,
            false,
        )
        .await;

        if let Err(ref err) = result {
            assert!(
                matches!(err, ProxyError::DataQuotaExceeded { .. }),
                "soft-cap fuzz produced unexpected error variant: {err:?}"
            );
        }

        let after = stats.get_user_total_octets(user);
        let soft_limit = quota_limit.saturating_add(overshoot);
        match result {
            Ok(_) => {
                assert_eq!(after, before.saturating_add(len));
                assert!(after <= soft_limit, "accepted write must stay within active soft cap");
            }
            Err(_) => {
                assert_eq!(after, before, "rejected write must not mutate quota state");
            }
        }
        assert_eq!(
            bytes_me2c.load(Ordering::Relaxed),
            after,
            "soft-cap fuzz must keep counters synchronized"
        );
    }
}
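Reviewer note: every fuzz test in this patch derives its cases from the same inline xorshift step, which keeps runs deterministic and reproducible from a literal seed. Factored out, the generator would look like this (sketch only; the patch deliberately inlines it):

// Illustrative only: the xorshift step the fuzz tests inline.
fn xorshift_step(mut seed: u64) -> u64 {
    seed ^= seed << 7;
    seed ^= seed >> 9;
    seed ^= seed << 8;
    seed
}

fn main() {
    // Same literal seed as light_fuzz_soft_cap_matrix...; the stream is stable across runs.
    let mut seed = 0x9E37_79B9_7F4A_7C15u64;
    seed = xorshift_step(seed);
    let quota_limit = 32 + (seed & 0x3f); // bounded parameters derived from the bits
    assert!((32..=95).contains(&quota_limit));
}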
#[tokio::test]
async fn positive_no_quota_limit_accumulates_data_octets_exactly() {
    let stats = Stats::new();
    let user = "no-quota-user";
    let bytes_me2c = AtomicU64::new(0);
    let mut expected = 0u64;

    for (idx, len) in [1usize, 2, 3, 5, 8, 13, 21].iter().copied().enumerate() {
        let mut writer = make_crypto_writer(tokio::io::sink());
        let mut frame_buf = Vec::new();
        let payload = vec![0x41; len];
        let result = process_me_writer_response(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from(payload),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            user,
            None,
            0,
            &bytes_me2c,
            7900 + idx as u64,
            false,
            false,
        )
        .await;

        assert!(result.is_ok());
        expected += len as u64;
    }

    assert_eq!(stats.get_user_total_octets(user), expected);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), expected);
}

#[tokio::test]
async fn negative_zero_quota_rejects_non_empty_payload() {
    let stats = Stats::new();
    let user = "zero-quota-user";
    let bytes_me2c = AtomicU64::new(0);

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let result = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[0xAA]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(0),
        0,
        &bytes_me2c,
        8001,
        false,
        false,
    )
    .await;

    assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
    assert_eq!(stats.get_user_total_octets(user), 0);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}

#[tokio::test]
async fn edge_zero_length_payload_with_zero_quota_is_fail_closed() {
    let stats = Stats::new();
    let user = "zero-len-zero-quota-user";
    let bytes_me2c = AtomicU64::new(0);

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let result = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::new(),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(0),
        0,
        &bytes_me2c,
        8002,
        false,
        false,
    )
    .await;

    assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
    assert_eq!(stats.get_user_total_octets(user), 0);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}

#[tokio::test]
async fn positive_ack_response_does_not_touch_quota_counters() {
    let stats = Stats::new();
    let user = "ack-accounting-user";
    let bytes_me2c = AtomicU64::new(11);
    stats.add_user_octets_to(user, 23);

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let result = process_me_writer_response(
        MeResponse::Ack(0x33445566),
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(24),
        0,
        &bytes_me2c,
        8003,
        true,
        true,
    )
    .await;

    assert!(result.is_ok());
    assert_eq!(stats.get_user_total_octets(user), 23);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 11);
}

#[tokio::test]
async fn edge_close_response_is_accounting_noop() {
    let stats = Stats::new();
    let user = "close-accounting-user";
    let bytes_me2c = AtomicU64::new(19);
    stats.add_user_octets_to(user, 31);

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let result = process_me_writer_response(
        MeResponse::Close,
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(40),
        3,
        &bytes_me2c,
        8004,
        false,
        true,
    )
    .await;

    assert!(result.is_ok());
    assert_eq!(stats.get_user_total_octets(user), 31);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 19);
}

#[tokio::test]
async fn negative_preloaded_above_soft_cap_rejects_even_single_byte() {
    let stats = Stats::new();
    let user = "preloaded-over-soft-cap-user";
    let bytes_me2c = AtomicU64::new(0);
    let quota_limit = 20u64;
    let overshoot = 2u64;
    stats.add_user_octets_to(user, quota_limit + overshoot + 1);

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let result = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[1]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(quota_limit),
        overshoot,
        &bytes_me2c,
        8005,
        false,
        false,
    )
    .await;

    assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
    assert_eq!(stats.get_user_total_octets(user), quota_limit + overshoot + 1);
}

#[tokio::test]
async fn adversarial_fail_writer_path_never_desynchronizes_quota_accounting() {
    let stats = Stats::new();
    let user = "partial-write-rollback-user";
    let bytes_me2c = AtomicU64::new(0);
    let mut writer = make_crypto_writer(FailAfterBudgetWriter::new(7));
    let mut frame_buf = Vec::new();
    let payload_len = 16 * 1024u64;

    let result = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from(vec![0x42; 16 * 1024]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(payload_len),
        0,
        &bytes_me2c,
        8006,
        false,
        false,
    )
    .await;

    let total_after = stats.get_user_total_octets(user);
    let forensic_after = bytes_me2c.load(Ordering::Relaxed);
    assert_eq!(forensic_after, total_after);
    assert!(
        total_after == 0 || total_after == payload_len,
        "writer failure path must either roll back fully or commit exactly one payload"
    );

    // Regardless of whether I/O failure surfaced immediately or was deferred,
    // accounting must remain fail-closed and prevent silent overshoot.
    let mut writer_two = make_crypto_writer(tokio::io::sink());
    let mut frame_buf_two = Vec::new();
    let second = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[0x99]),
        },
        &mut writer_two,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf_two,
        &stats,
        user,
        Some(payload_len),
        0,
        &bytes_me2c,
        8007,
        false,
        false,
    )
    .await;

    if total_after == payload_len {
        assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. })));
    } else {
        assert!(second.is_ok());
    }
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_oversized_frames_fail_closed_without_counter_leak() {
    let stats = Arc::new(Stats::new());
    let user = "parallel-fail-rollback-user";
    let bytes_me2c = Arc::new(AtomicU64::new(0));
    let mut tasks = JoinSet::new();

    for idx in 0..256u64 {
        let user_owned = user.to_string();
        let stats_ref = Arc::clone(&stats);
        let bytes_ref = Arc::clone(&bytes_me2c);
        tasks.spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            process_me_writer_response(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from(vec![0xEE; 12 * 1024]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats_ref.as_ref(),
                &user_owned,
                Some(512),
                0,
                bytes_ref.as_ref(),
                8100 + idx,
                false,
                false,
            )
            .await
        });
    }

    while let Some(joined) = tasks.join_next().await {
        let result = joined.expect("parallel fail writer task must not panic");
        assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
    }

    assert_eq!(stats.get_user_total_octets(user), 0);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}

#[tokio::test]
async fn integration_mixed_data_ack_close_sequence_preserves_data_only_accounting() {
    let stats = Stats::new();
    let user = "mixed-sequence-user";
    let bytes_me2c = AtomicU64::new(0);

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();

    let data_one = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[1, 2, 3]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(32),
        0,
        &bytes_me2c,
        8201,
        false,
        false,
    )
    .await;
    assert!(data_one.is_ok());

    let ack = process_me_writer_response(
        MeResponse::Ack(0x0102_0304),
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(32),
        0,
        &bytes_me2c,
        8202,
        true,
        true,
    )
    .await;
    assert!(ack.is_ok());

    let data_two = process_me_writer_response(
        MeResponse::Data {
            flags: 0,
            data: Bytes::from_static(&[4, 5]),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(32),
        0,
        &bytes_me2c,
        8203,
        false,
        true,
    )
    .await;
    assert!(data_two.is_ok());

    let close = process_me_writer_response(
        MeResponse::Close,
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        user,
        Some(32),
        0,
        &bytes_me2c,
        8204,
        false,
        true,
    )
    .await;
    assert!(close.is_ok());

    assert_eq!(stats.get_user_total_octets(user), 5);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_multi_user_quota_isolation_no_cross_user_leakage() {
    let stats = Arc::new(Stats::new());
    let user_a = "quota-isolation-a";
    let user_b = "quota-isolation-b";
    let limit_a = 50u64;
    let limit_b = 80u64;
    let bytes_a = Arc::new(AtomicU64::new(0));
    let bytes_b = Arc::new(AtomicU64::new(0));

    let mut tasks = JoinSet::new();
    for idx in 0..200u64 {
        let stats_ref = Arc::clone(&stats);
        let bytes_ref = Arc::clone(&bytes_a);
        tasks.spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            process_me_writer_response(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0xA1]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats_ref.as_ref(),
                user_a,
                Some(limit_a),
                0,
                bytes_ref.as_ref(),
                8300 + idx,
                false,
                false,
            )
            .await
        });
    }

    for idx in 0..220u64 {
        let stats_ref = Arc::clone(&stats);
        let bytes_ref = Arc::clone(&bytes_b);
        tasks.spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            process_me_writer_response(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0xB2]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats_ref.as_ref(),
                user_b,
                Some(limit_b),
                0,
                bytes_ref.as_ref(),
                8500 + idx,
                false,
                false,
            )
            .await
        });
    }

    while let Some(joined) = tasks.join_next().await {
        let result = joined.expect("quota isolation task must not panic");
        assert!(result.is_ok() || matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
    }

    assert_eq!(stats.get_user_total_octets(user_a), limit_a);
    assert_eq!(stats.get_user_total_octets(user_b), limit_b);
    assert_eq!(bytes_a.load(Ordering::Relaxed), limit_a);
    assert_eq!(bytes_b.load(Ordering::Relaxed), limit_b);
}

#[tokio::test]
async fn light_fuzz_mixed_me_responses_preserve_quota_and_counter_invariants() {
    let stats = Stats::new();
    let user = "mixed-fuzz-user";
    let bytes_me2c = AtomicU64::new(0);
    let quota_limit = 96u64;
    let mut seed = 0xDEAD_BEEF_2026_0323u64;

    for idx in 0..2048u64 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let choice = (seed & 0x03) as u8;
        let response = if choice == 0 {
            MeResponse::Ack((seed >> 8) as u32)
        } else if choice == 1 {
            MeResponse::Close
        } else {
            let len = ((seed >> 16) & 0x07) as usize;
            let mut payload = vec![0u8; len];
            payload.fill((seed & 0xff) as u8);
            MeResponse::Data {
                flags: 0,
                data: Bytes::from(payload),
            }
        };

        let mut writer = make_crypto_writer(tokio::io::sink());
        let mut frame_buf = Vec::new();
        let result = process_me_writer_response(
            response,
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            user,
            Some(quota_limit),
            0,
            &bytes_me2c,
            8800 + idx,
            (idx & 1) == 0,
            (idx & 2) == 0,
        )
        .await;

        if let Err(err) = result {
            assert!(
                matches!(err, ProxyError::DataQuotaExceeded { .. }),
                "mixed fuzz produced unexpected error variant: {err:?}"
            );
        }

        let total = stats.get_user_total_octets(user);
        assert!(
            total <= quota_limit,
            "mixed fuzz must keep usage at or below quota limit"
        );
        assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
    }
}
@@ -0,0 +1,399 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use tokio::sync::Mutex as AsyncMutex;
use tokio::task::JoinSet;
use tokio::time::{Duration, timeout};

fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

fn lookup_counter_test_lock() -> &'static Mutex<()> {
    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    LOCK.get_or_init(|| Mutex::new(()))
}

#[tokio::test]
async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Stats::new();
    let user = format!("quota-extreme-positive-{}", std::process::id());
    let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    for idx in 0..12u64 {
        let payload = vec![0x5A; ((idx % 4) + 1) as usize];
        let result = process_me_writer_response_with_cross_mode_lock(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from(payload),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(512),
            0,
            Some(&lock),
            &bytes_me2c,
            31_000 + idx,
            false,
            false,
        )
        .await;

        assert!(result.is_ok());
    }

    assert_eq!(
        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
        0,
        "prefetched lock path must avoid hot-path registry lookups"
    );
    assert_eq!(
        stats.get_user_total_octets(&user),
        bytes_me2c.load(Ordering::Relaxed),
        "forensics and quota accounting must remain synchronized"
    );
}
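Reviewer note: these tests serialize on a process-wide lock and recover it with `unwrap_or_else(|poison| poison.into_inner())`, so one panicking test cannot poison the lock for the rest of the suite. The pattern in isolation (sketch only, std library only; `test_lock` is an illustrative name):

// Illustrative only: a poison-tolerant global test lock.
use std::sync::{Mutex, MutexGuard, OnceLock};

fn test_lock() -> MutexGuard<'static, ()> {
    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    LOCK.get_or_init(|| Mutex::new(()))
        .lock()
        .unwrap_or_else(|poison| poison.into_inner()) // keep going after a panicked holder
}

fn main() {
    let _guard = test_lock(); // held for the duration of the serialized section
}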
#[tokio::test]
async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Stats::new();
    let user = format!("quota-extreme-negative-{}", std::process::id());
    let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let held_guard = lock
        .try_lock()
        .expect("test must hold lock before calling ME->C writer");

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let blocked = timeout(
        Duration::from_millis(25),
        process_me_writer_response_with_cross_mode_lock(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[1, 2, 3]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(64),
            0,
            Some(&lock),
            &bytes_me2c,
            31_100,
            false,
            false,
        ),
    )
    .await;

    assert!(blocked.is_err());
    assert_eq!(stats.get_user_total_octets(&user), 0);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);

    drop(held_guard);
}

#[tokio::test]
async fn edge_zero_quota_and_zero_payload_is_fail_closed() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Stats::new();
    let user = format!("quota-extreme-edge-{}", std::process::id());
    let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    let result = process_me_writer_response_with_cross_mode_lock(
        MeResponse::Data {
            flags: 0,
            data: Bytes::new(),
        },
        &mut writer,
        ProtoTag::Intermediate,
        &SecureRandom::new(),
        &mut frame_buf,
        &stats,
        &user,
        Some(0),
        0,
        Some(&lock),
        &bytes_me2c,
        31_200,
        false,
        false,
    )
    .await;

    assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
    assert_eq!(stats.get_user_total_octets(&user), 0);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Arc::new(Stats::new());
    let user = format!("quota-extreme-blackhat-{}", std::process::id());
    let quota = 80u64;
    let overshoot = 7u64;
    let soft_limit = quota + overshoot;
    let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
    let bytes_me2c = Arc::new(AtomicU64::new(0));

    let mut set = JoinSet::new();
    for idx in 0..256u64 {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let lock = Arc::clone(&lock);
        let bytes_me2c = Arc::clone(&bytes_me2c);

        set.spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            let len = ((idx % 5) + 1) as usize;
            let payload = vec![0xAA; len];

            process_me_writer_response_with_cross_mode_lock(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from(payload),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats.as_ref(),
                &user,
                Some(quota),
                overshoot,
                Some(&lock),
                bytes_me2c.as_ref(),
                31_300 + idx,
                false,
                false,
            )
            .await
        });
    }

    while let Some(done) = set.join_next().await {
        match done.expect("task must not panic") {
            Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {}
            Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"),
        }
    }

    let total = stats.get_user_total_octets(&user);
    assert!(
        total <= soft_limit,
        "parallel adversarial race must stay under soft cap"
    );
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
}
#[tokio::test]
async fn integration_without_prefetched_lock_uses_registry_lookup_path() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Stats::new();
    let user = format!("quota-extreme-integration-{}", std::process::id());
    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();

    let mut writer = make_crypto_writer(tokio::io::sink());
    let mut frame_buf = Vec::new();
    let bytes_me2c = AtomicU64::new(0);

    for idx in 0..3u64 {
        let result = process_me_writer_response_with_cross_mode_lock(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from_static(&[0x41]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(16),
            0,
            None,
            &bytes_me2c,
            31_400 + idx,
            false,
            false,
        )
        .await;

        assert!(result.is_ok());
    }

    assert_eq!(
        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
        3,
        "control path should perform one lock-registry lookup per call"
    );
}

#[tokio::test]
async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Stats::new();
    let user = format!("quota-extreme-fuzz-{}", std::process::id());
    let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let bytes_me2c = AtomicU64::new(0);
    let mut seed = 0xA11C_55EE_2026_0323u64;

    for idx in 0..512u64 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let quota = 24 + (seed & 0x3f);
        let overshoot = (seed >> 13) & 0x0f;
        let len = ((seed >> 19) & 0x07) + 1;

        let mut writer = make_crypto_writer(tokio::io::sink());
        let mut frame_buf = Vec::new();
        let before = stats.get_user_total_octets(&user);

        let result = process_me_writer_response_with_cross_mode_lock(
            MeResponse::Data {
                flags: 0,
                data: Bytes::from(vec![0x11; len as usize]),
            },
            &mut writer,
            ProtoTag::Intermediate,
            &SecureRandom::new(),
            &mut frame_buf,
            &stats,
            &user,
            Some(quota),
            overshoot,
            Some(&lock),
            &bytes_me2c,
            31_500 + idx,
            false,
            false,
        )
        .await;

        let after = stats.get_user_total_octets(&user);
        if result.is_ok() {
            assert!(after >= before);
        } else {
            assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
            assert_eq!(after, before);
        }
        assert_eq!(bytes_me2c.load(Ordering::Relaxed), after);
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() {
    let _guard = lookup_counter_test_lock()
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());

    let stats = Arc::new(Stats::new());
    let user = format!("quota-extreme-stress-{}", std::process::id());
    let quota = 96u64;
    let lock: Arc<AsyncMutex<()>> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let bytes_me2c = Arc::new(AtomicU64::new(0));

    crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();

    let mut set = JoinSet::new();
    for idx in 0..384u64 {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let lock = Arc::clone(&lock);
        let bytes_me2c = Arc::clone(&bytes_me2c);

        set.spawn(async move {
            let mut writer = make_crypto_writer(tokio::io::sink());
            let mut frame_buf = Vec::new();
            process_me_writer_response_with_cross_mode_lock(
                MeResponse::Data {
                    flags: 0,
                    data: Bytes::from_static(&[0xFF]),
                },
                &mut writer,
                ProtoTag::Intermediate,
                &SecureRandom::new(),
                &mut frame_buf,
                stats.as_ref(),
                &user,
                Some(quota),
                0,
                Some(&lock),
                bytes_me2c.as_ref(),
                31_600 + idx,
                false,
                false,
            )
            .await
        });
    }

    let mut success = 0usize;
    while let Some(done) = set.join_next().await {
        match done.expect("task must not panic") {
            Ok(_) => success += 1,
            Err(ProxyError::DataQuotaExceeded { .. }) => {}
            Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"),
        }
    }

    assert_eq!(success, quota as usize);
    assert_eq!(stats.get_user_total_octets(&user), quota);
    assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota);
    assert_eq!(
        crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
        0,
        "stress prefetched path must not use lock registry lookups"
    );
}
@@ -7,7 +7,7 @@ use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::task::JoinSet;
use tokio::time::{Duration as TokioDuration, sleep, timeout};
use tokio::time::{Duration as TokioDuration, sleep};

fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where

@@ -42,10 +42,10 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
    RelayClientIdlePolicy {
        enabled: true,
        soft_idle: Duration::from_secs(30),
        hard_idle: Duration::from_secs(60),
        soft_idle: Duration::from_millis(50),
        hard_idle: Duration::from_millis(120),
        grace_after_downstream_activity: Duration::from_secs(0),
        legacy_frame_read_timeout: Duration::from_secs(30),
        legacy_frame_read_timeout: Duration::from_millis(50),
    }
}

@@ -94,8 +94,8 @@ async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
    writer.write_all(&flood_encrypted).await.unwrap();
    drop(writer);

    let result = timeout(
        TokioDuration::from_secs(1),
    let result = run_relay_test_step_timeout(
        "tiny flood task",
        read_once(
            &mut crypto_reader,
            ProtoTag::Abridged,

@@ -104,8 +104,7 @@ async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
            &mut idle_state,
        ),
    )
    .await
    .expect("tiny flood task must complete");
    .await;

    assert!(matches!(result, Err(ProxyError::Proxy(_))));
    assert_eq!(frame_counter, 0);

@@ -140,8 +139,8 @@ async fn stress_parallel_benign_tiny_burst_then_real_all_pass() {
    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let result = timeout(
        TokioDuration::from_secs(1),
    let result = run_relay_test_step_timeout(
        "benign tiny burst read",
        read_once(
            &mut crypto_reader,
            ProtoTag::Abridged,

@@ -151,7 +150,6 @@ async fn stress_parallel_benign_tiny_burst_then_real_all_pass() {
        ),
    )
    .await
    .expect("benign task must complete")
    .expect("benign payload must parse")
    .expect("benign payload must return frame");

@@ -196,8 +194,8 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {

    let mut closed = false;
    for _ in 0..220 {
        let result = timeout(
            TokioDuration::from_secs(1),
        let result = run_relay_test_step_timeout(
            "alternating jitter read step",
            read_once(
                &mut crypto_reader,
                ProtoTag::Abridged,

@@ -206,8 +204,7 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {
                &mut idle_state,
            ),
        )
        .await
        .expect("alternating reader step must complete");
        .await;

        match result {
            Ok(Some((_payload, _))) => {}

@@ -336,8 +333,8 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() {
    drop(writer);

    for _ in 0..320 {
        let step = timeout(
            TokioDuration::from_secs(1),
        let step = run_relay_test_step_timeout(
            "fuzz case read step",
            read_once(
                &mut crypto_reader,
                ProtoTag::Abridged,

@@ -346,8 +343,7 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() {
                &mut idle_state,
            ),
        )
        .await
        .expect("fuzz case read step must complete");
        .await;

        match step {
            Ok(Some(_)) => {}
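Reviewer note: the refactor above routes every bounded read step through run_relay_test_step_timeout, whose definition is not visible in these hunks. One plausible shape, inferred only from the call sites and from the TimedOut arm of is_fail_closed_outcome below — a hypothetical reconstruction, not the patch's actual code:

// Hypothetical reconstruction from call sites; the real helper lives elsewhere in the patch.
use std::future::Future;
use tokio::time::{Duration, timeout};

async fn run_relay_test_step_timeout<T>(
    step: &'static str,
    fut: impl Future<Output = Result<T>>,
) -> Result<T> {
    match timeout(Duration::from_secs(1), fut).await {
        Ok(result) => result,
        // Surface a timed-out step as a TimedOut I/O error so callers can
        // treat it as a fail-closed outcome instead of panicking.
        Err(_) => Err(ProxyError::Io(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            step,
        ))),
    }
}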
@@ -6,7 +6,7 @@ use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::time::{Duration as TokioDuration, sleep, timeout};
use tokio::time::{Duration as TokioDuration, sleep};

fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where

@@ -41,10 +41,10 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
    RelayClientIdlePolicy {
        enabled: true,
        soft_idle: Duration::from_secs(30),
        hard_idle: Duration::from_secs(60),
        soft_idle: Duration::from_millis(50),
        hard_idle: Duration::from_millis(120),
        grace_after_downstream_activity: Duration::from_secs(0),
        legacy_frame_read_timeout: Duration::from_secs(30),
        legacy_frame_read_timeout: Duration::from_millis(50),
    }
}

@@ -117,6 +117,11 @@ async fn read_once_with_state(
    .await
}

fn is_fail_closed_outcome(result: &Result<Option<(PooledBuffer, bool)>>) -> bool {
    matches!(result, Err(ProxyError::Proxy(_)))
        || matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut)
}

#[tokio::test]
async fn intermediate_chunked_zero_flood_fail_closed() {
    let (reader, mut writer) = duplex(4096);

@@ -134,8 +139,8 @@ async fn intermediate_chunked_zero_flood_fail_closed() {
    write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await;
    drop(writer);

    let result = timeout(
        TokioDuration::from_secs(2),
    let result = run_relay_test_step_timeout(
        "intermediate flood read",
        read_once_with_state(
            &mut crypto_reader,
            ProtoTag::Intermediate,

@@ -144,10 +149,12 @@ async fn intermediate_chunked_zero_flood_fail_closed() {
            &mut idle_state,
        ),
    )
    .await
    .expect("intermediate flood read must complete");
    .await;

    assert!(matches!(result, Err(ProxyError::Proxy(_))));
    assert!(
        is_fail_closed_outcome(&result),
        "zero-length flood must fail closed via debt guard or idle timeout"
    );
    assert_eq!(frame_counter, 0);
}

@@ -168,8 +175,8 @@ async fn secure_chunked_zero_flood_fail_closed() {
    write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await;
    drop(writer);

    let result = timeout(
        TokioDuration::from_secs(2),
    let result = run_relay_test_step_timeout(
        "secure flood read",
        read_once_with_state(
            &mut crypto_reader,
            ProtoTag::Secure,

@@ -178,10 +185,12 @@ async fn secure_chunked_zero_flood_fail_closed() {
            &mut idle_state,
        ),
    )
    .await
    .expect("secure flood read must complete");
    .await;

    assert!(matches!(result, Err(ProxyError::Proxy(_))));
    assert!(
        is_fail_closed_outcome(&result),
        "secure zero-length flood must fail closed via debt guard or idle timeout"
    );
    assert_eq!(frame_counter, 0);
}

@@ -208,8 +217,8 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() {

    let mut closed = false;
    for _ in 0..240 {
        let step = timeout(
            TokioDuration::from_secs(1),
        let step = run_relay_test_step_timeout(
            "intermediate alternating read step",
            read_once_with_state(
                &mut crypto_reader,
                ProtoTag::Intermediate,

@@ -218,8 +227,7 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() {
                &mut idle_state,
            ),
        )
        .await
        .expect("intermediate alternating read step must complete");
        .await;

        match step {
            Ok(Some(_)) => {}

@@ -259,8 +267,8 @@ async fn secure_chunked_alternating_attack_closes_before_eof() {

    let mut closed = false;
    for _ in 0..240 {
        let step = timeout(
            TokioDuration::from_secs(1),
        let step = run_relay_test_step_timeout(
            "secure alternating read step",
            read_once_with_state(
                &mut crypto_reader,
                ProtoTag::Secure,

@@ -269,8 +277,7 @@ async fn secure_chunked_alternating_attack_closes_before_eof() {
                &mut idle_state,
            ),
        )
        .await
        .expect("secure alternating read step must complete");
        .await;

        match step {
            Ok(Some(_)) => {}

@@ -394,8 +401,8 @@ async fn light_fuzz_proto_chunking_outcomes_are_bounded() {
    drop(writer);

    for _ in 0..260 {
        let step = timeout(
            TokioDuration::from_secs(1),
        let step = run_relay_test_step_timeout(
            "fuzz proto read step",
            read_once_with_state(
                &mut crypto_reader,
                proto,

@@ -404,12 +411,12 @@ async fn light_fuzz_proto_chunking_outcomes_are_bounded() {
                &mut idle_state,
            ),
        )
        .await
        .expect("fuzz proto read step must complete");
        .await;

        match step {
            Ok(Some((_payload, _))) => {}
            Err(ProxyError::Proxy(_)) => break,
            Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break,
            Ok(None) => break,
            Err(other) => panic!("unexpected proto chunking fuzz error: {other}"),
        }
@@ -40,13 +40,44 @@ fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
    RelayClientIdlePolicy {
        enabled: true,
        soft_idle: Duration::from_secs(30),
        hard_idle: Duration::from_secs(60),
        soft_idle: Duration::from_millis(50),
        hard_idle: Duration::from_millis(120),
        grace_after_downstream_activity: Duration::from_secs(0),
        legacy_frame_read_timeout: Duration::from_secs(30),
        legacy_frame_read_timeout: Duration::from_millis(50),
    }
}

async fn read_bounded(
    crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
    proto_tag: ProtoTag,
    buffer_pool: &Arc<BufferPool>,
    forensics: &RelayForensicsState,
    frame_counter: &mut u64,
    stats: &Stats,
    idle_policy: &RelayClientIdlePolicy,
    idle_state: &mut RelayClientIdleState,
    last_downstream_activity_ms: &AtomicU64,
    session_started_at: Instant,
) -> Result<Option<(PooledBuffer, bool)>> {
    run_relay_test_step_timeout(
        "tiny-frame debt read step",
        read_client_payload_with_idle_policy(
            crypto_reader,
            proto_tag,
            1024,
            buffer_pool,
            forensics,
            frame_counter,
            stats,
            idle_policy,
            idle_state,
            last_downstream_activity_ms,
            session_started_at,
        ),
    )
    .await
}

fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option<usize>, u32, usize) {
    let mut debt = 0u32;
    let mut reals = 0usize;

@@ -246,10 +277,9 @@ async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() {
    writer.write_all(&flood_encrypted).await.unwrap();
    drop(writer);

    let result = read_client_payload_with_idle_policy(
    let result = read_bounded(
        &mut crypto_reader,
        ProtoTag::Intermediate,
        1024,
        &buffer_pool,
        &forensics,
        &mut frame_counter,

@@ -282,10 +312,9 @@ async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() {
    writer.write_all(&flood_encrypted).await.unwrap();
    drop(writer);

    let result = read_client_payload_with_idle_policy(
    let result = read_bounded(
        &mut crypto_reader,
        ProtoTag::Secure,
        1024,
        &buffer_pool,
        &forensics,
        &mut frame_counter,

@@ -325,10 +354,9 @@ async fn intermediate_alternating_zero_and_real_eventually_closes() {

    let mut closed = false;
    for _ in 0..220 {
        let result = read_client_payload_with_idle_policy(
        let result = read_bounded(
            &mut crypto_reader,
            ProtoTag::Intermediate,
            1024,
            &buffer_pool,
            &forensics,
            &mut frame_counter,

@@ -377,10 +405,9 @@ async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() {
    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let first = read_client_payload_with_idle_policy(
    let first = read_bounded(
        &mut crypto_reader,
        ProtoTag::Abridged,
        1024,
        &buffer_pool,
        &forensics,
        &mut frame_counter,

@@ -420,10 +447,9 @@ async fn idle_policy_enabled_zero_length_flood_is_fail_closed() {
        .expect("zero-length flood bytes must be writable");
    drop(writer);

    let result = read_client_payload_with_idle_policy(
    let result = read_bounded(
        &mut crypto_reader,
        ProtoTag::Abridged,
        1024,
        &buffer_pool,
        &forensics,
        &mut frame_counter,

@@ -470,10 +496,9 @@ async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() {

    let mut saw_proxy_close = false;
    for _ in 0..300 {
        let result = read_client_payload_with_idle_policy(
        let result = read_bounded(
            &mut crypto_reader,
            ProtoTag::Abridged,
            1024,
            &buffer_pool,
            &forensics,
            &mut frame_counter,

@@ -527,10 +552,9 @@ async fn enabled_idle_policy_valid_nonzero_frame_still_passes() {
        .await
        .expect("nonzero frame must be writable");

    let result = read_client_payload_with_idle_policy(
    let result = read_bounded(
        &mut crypto_reader,
        ProtoTag::Abridged,
        1024,
        &buffer_pool,
        &forensics,
        &mut frame_counter,

@@ -548,3 +572,227 @@ async fn enabled_idle_policy_valid_nonzero_frame_still_passes() {
    assert!(!result.1);
    assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn abridged_quickack_tiny_flood_is_fail_closed() {
    let (reader, mut writer) = duplex(4096);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let session_started_at = Instant::now();
    let forensics = make_forensics(21, session_started_at);
    let mut frame_counter = 0u64;
    let mut idle_state = RelayClientIdleState::new(session_started_at);
    let idle_policy = make_enabled_idle_policy();
    let last_downstream_activity_ms = AtomicU64::new(0);

    let flood_plaintext = vec![0x80u8; 256];
    let flood_encrypted = encrypt_for_reader(&flood_plaintext);
    writer.write_all(&flood_encrypted).await.unwrap();
    drop(writer);

    let result = read_bounded(
        &mut crypto_reader,
        ProtoTag::Abridged,
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
        &idle_policy,
        &mut idle_state,
        &last_downstream_activity_ms,
        session_started_at,
    )
    .await;

    assert!(
        matches!(result, Err(ProxyError::Proxy(_))),
        "quickack-marked zero-length flood must fail closed"
    );
}

#[tokio::test]
async fn abridged_extended_zero_len_flood_is_fail_closed() {
    let (reader, mut writer) = duplex(4096);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let session_started_at = Instant::now();
    let forensics = make_forensics(22, session_started_at);
    let mut frame_counter = 0u64;
    let mut idle_state = RelayClientIdleState::new(session_started_at);
    let idle_policy = make_enabled_idle_policy();
    let last_downstream_activity_ms = AtomicU64::new(0);

    let mut flood_plaintext = Vec::with_capacity(4 * 256);
    for _ in 0..256 {
        flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]);
    }
    let flood_encrypted = encrypt_for_reader(&flood_plaintext);
    writer.write_all(&flood_encrypted).await.unwrap();
    drop(writer);

    let result = read_bounded(
        &mut crypto_reader,
        ProtoTag::Abridged,
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
        &idle_policy,
        &mut idle_state,
        &last_downstream_activity_ms,
        session_started_at,
    )
    .await;

    assert!(
        matches!(result, Err(ProxyError::Proxy(_))),
        "extended zero-length abridged flood must fail closed"
    );
}

#[tokio::test]
async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() {
    let mut plaintext = Vec::with_capacity(9 * 300);
    for idx in 0..300usize {
        plaintext.push(0x00);
        for _ in 0..8 {
            let b = idx as u8;
            plaintext.push(0x01);
            plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]);
        }
    }

    // Keep the test single-task and deterministic: make duplex capacity larger than the
    // generated ciphertext so write_all cannot block waiting for a concurrent reader.
    let duplex_capacity = plaintext.len().saturating_add(1024);
    let (reader, mut writer) = duplex(duplex_capacity);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let session_started_at = Instant::now();
    let forensics = make_forensics(23, session_started_at);
    let mut frame_counter = 0u64;
    let mut idle_state = RelayClientIdleState::new(session_started_at);
    let idle_policy = make_enabled_idle_policy();
    let last_downstream_activity_ms = AtomicU64::new(0);

    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();
    drop(writer);

    let mut closed = false;
    for _ in 0..3000 {
        match read_bounded(
            &mut crypto_reader,
            ProtoTag::Abridged,
            &buffer_pool,
            &forensics,
            &mut frame_counter,
            &stats,
            &idle_policy,
            &mut idle_state,
            &last_downstream_activity_ms,
            session_started_at,
        )
        .await
        {
            Ok(Some(_)) => {}
            Ok(None) => break,
            Err(ProxyError::Proxy(_)) => {
                closed = true;
                break;
            }
            Err(other) => panic!("unexpected error in 1:8 wire test: {other}"),
        }
    }

    assert!(
        !closed,
        "wire-level 1:8 tiny-to-real pattern should not trigger debt close"
    );
}

#[tokio::test]
async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() {
    let mut seed = 0xD1CE_BAAD_2026_0322u64;

    for case_idx in 0..32u64 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let events = 300 + ((seed as usize) & 0xff);
        let mut pattern = Vec::with_capacity(events);
        let mut local = seed;
        for _ in 0..events {
            local ^= local << 7;
            local ^= local >> 9;
            local ^= local << 8;
            pattern.push((local & 0x03) == 0);
        }

        let mut plaintext = Vec::with_capacity(events * 6);
        for (idx, tiny) in pattern.iter().copied().enumerate() {
            if tiny {
                plaintext.push(0x00);
            } else {
                let b = (idx as u8) ^ (case_idx as u8);
                plaintext.push(0x01);
                plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]);
            }
        }

        let (reader, mut writer) = duplex(16 * 1024);
        let mut crypto_reader = make_crypto_reader(reader);
        let buffer_pool = Arc::new(BufferPool::new());
        let stats = Stats::new();
        let session_started_at = Instant::now();
        let forensics = make_forensics(500 + case_idx, session_started_at);
        let mut frame_counter = 0u64;
        let mut idle_state = RelayClientIdleState::new(session_started_at);
        let idle_policy = make_enabled_idle_policy();
        let last_downstream_activity_ms = AtomicU64::new(0);

        writer
            .write_all(&encrypt_for_reader(&plaintext))
            .await
            .unwrap();
        drop(writer);

        let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
        let mut observed_close = false;

        for _ in 0..(events + 8) {
            match read_bounded(
                &mut crypto_reader,
                ProtoTag::Abridged,
                &buffer_pool,
                &forensics,
                &mut frame_counter,
                &stats,
                &idle_policy,
                &mut idle_state,
                &last_downstream_activity_ms,
                session_started_at,
            )
            .await
            {
                Ok(Some(_)) => {}
                Ok(None) => break,
                Err(ProxyError::Proxy(_)) => {
                    observed_close = true;
                    break;
                }
                Err(other) => panic!("unexpected fuzz error: {other}"),
            }
        }

        assert_eq!(
            observed_close,
            expected_close.is_some(),
            "wire parser behavior must match debt model for case {case_idx}"
        );
    }
}

@@ -0,0 +1,267 @@
|
|||
use super::relay_bidirectional;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
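
// Every test below takes this guard first. `quota_user_lock_test_scope` returns an
// RAII scope around the shared quota user-lock state, so locks that a test holds
// on purpose cannot leak contention into tests running around it.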

#[tokio::test]
async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() {
    let _guard = quota_test_guard();

    let user = format!("relay-pipeline-stall-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold shared cross-mode lock");

    let stats = Arc::new(Stats::new());
    let (mut client_peer, relay_client) = duplex(1024);
    let (relay_server, mut server_peer) = duplex(1024);
    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay_user = user.clone();
    let relay_stats = Arc::clone(&stats);
    let relay_task = tokio::spawn(async move {
        relay_bidirectional(
            client_reader,
            client_writer,
            server_reader,
            server_writer,
            256,
            256,
            &relay_user,
            relay_stats,
            Some(1024),
            Arc::new(BufferPool::new()),
        )
        .await
    });

    server_peer
        .write_all(&[0xA1])
        .await
        .expect("server write should enqueue while relay is stalled");

    let mut one = [0u8; 1];
    let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await;
    assert!(
        blocked_read.is_err(),
        "same-user relay must remain blocked while cross-mode lock is held"
    );

    drop(held_guard);

    timeout(Duration::from_millis(400), client_peer.read_exact(&mut one))
        .await
        .expect("blocked relay must resume after cross-mode lock release")
        .expect("resumed relay must deliver queued byte");
    assert_eq!(one, [0xA1]);

    drop(client_peer);
    drop(server_peer);

    let relay_result = timeout(Duration::from_secs(1), relay_task)
        .await
        .expect("relay task must complete")
        .expect("relay task must not panic");
    assert!(relay_result.is_ok());
}
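
// Quota locks are keyed per user, so holding one user's cross-mode lock must stall
// only that user's pipeline. The next test runs a second, unrelated user side by
// side and requires it to keep delivering bytes while the first user is blocked.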

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() {
    let _guard = quota_test_guard();

    let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id());
    let free_user = format!("relay-pipeline-free-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user);
    let held_guard = held
        .try_lock()
        .expect("test must hold blocked user's shared cross-mode lock");

    let stats_blocked = Arc::new(Stats::new());
    let stats_free = Arc::new(Stats::new());

    let (mut blocked_client, blocked_relay_client) = duplex(1024);
    let (blocked_relay_server, mut blocked_server) = duplex(1024);
    let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client);
    let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server);

    let (mut free_client, free_relay_client) = duplex(1024);
    let (free_relay_server, mut free_server) = duplex(1024);
    let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client);
    let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server);

    let blocked_task = {
        let user = blocked_user.clone();
        let stats = Arc::clone(&stats_blocked);
        tokio::spawn(async move {
            relay_bidirectional(
                blocked_client_reader,
                blocked_client_writer,
                blocked_server_reader,
                blocked_server_writer,
                256,
                256,
                &user,
                stats,
                Some(1024),
                Arc::new(BufferPool::new()),
            )
            .await
        })
    };

    let free_task = {
        let user = free_user.clone();
        let stats = Arc::clone(&stats_free);
        tokio::spawn(async move {
            relay_bidirectional(
                free_client_reader,
                free_client_writer,
                free_server_reader,
                free_server_writer,
                256,
                256,
                &user,
                stats,
                Some(1024),
                Arc::new(BufferPool::new()),
            )
            .await
        })
    };

    blocked_server
        .write_all(&[0xB1])
        .await
        .expect("blocked user server write should queue");
    free_server
        .write_all(&[0xC1])
        .await
        .expect("free user server write should queue");

    let mut blocked_buf = [0u8; 1];
    let mut free_buf = [0u8; 1];

    let blocked_stalled = timeout(
        Duration::from_millis(40),
        blocked_client.read_exact(&mut blocked_buf),
    )
    .await;
    assert!(
        blocked_stalled.is_err(),
        "blocked user must remain stalled while its lock is held"
    );

    timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf))
        .await
        .expect("free user must make progress while other user is blocked")
        .expect("free user read must succeed");
    assert_eq!(free_buf, [0xC1]);

    drop(held_guard);

    timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf))
        .await
        .expect("blocked user must resume after release")
        .expect("blocked user resumed read must succeed");
    assert_eq!(blocked_buf, [0xB1]);

    drop(blocked_client);
    drop(blocked_server);
    drop(free_client);
    drop(free_server);

    assert!(
        timeout(Duration::from_secs(1), blocked_task)
            .await
            .expect("blocked relay task must complete")
            .expect("blocked relay task must not panic")
            .is_ok()
    );
    assert!(
        timeout(Duration::from_secs(1), free_task)
            .await
            .expect("free relay task must complete")
            .expect("free relay task must not panic")
            .is_ok()
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() {
    let _guard = quota_test_guard();

    let mut seed = 0x5EED_C0DE_2026_0323u64;
    for round in 0..24u32 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let hold_ms = 2 + (seed % 10);
        let user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id());
        let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
        let held_guard = held
            .try_lock()
            .expect("test must hold lock during fuzz round");

        let stats = Arc::new(Stats::new());
        let (mut client_peer, relay_client) = duplex(1024);
        let (relay_server, mut server_peer) = duplex(1024);
        let (client_reader, client_writer) = tokio::io::split(relay_client);
        let (server_reader, server_writer) = tokio::io::split(relay_server);

        let relay_user = user.clone();
        let relay_stats = Arc::clone(&stats);
        let relay_task = tokio::spawn(async move {
            relay_bidirectional(
                client_reader,
                client_writer,
                server_reader,
                server_writer,
                256,
                256,
                &relay_user,
                relay_stats,
                Some(1024),
                Arc::new(BufferPool::new()),
            )
            .await
        });

        server_peer
            .write_all(&[0xD1])
            .await
            .expect("server write should queue in fuzz round");

        let mut one = [0u8; 1];
        let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await;
        assert!(stalled.is_err(), "held phase must stall same-user relay");

        tokio::time::sleep(Duration::from_millis(hold_ms)).await;
        drop(held_guard);

        timeout(Duration::from_millis(400), client_peer.read_exact(&mut one))
            .await
            .expect("released phase must resume same-user relay")
            .expect("released phase read must succeed");
        assert_eq!(one, [0xD1]);

        drop(client_peer);
        drop(server_peer);

        assert!(
            timeout(Duration::from_secs(1), relay_task)
                .await
                .expect("fuzz relay task must complete")
                .expect("fuzz relay task must not panic")
                .is_ok()
        );
    }
}

@@ -0,0 +1,213 @@
use super::relay_bidirectional;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::sync::{Barrier, watch};
use tokio::time::{Duration, Instant, timeout};

fn quota_test_guard() -> impl Drop {
    super::quota_user_lock_test_scope()
}

fn percentile_index(len: usize, percentile: usize) -> usize {
    ((len * percentile) / 100).min(len.saturating_sub(1))
}
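
// Nearest-rank percentile over an already-sorted sample vector: index len*p/100,
// clamped so it stays in bounds even for small sample sets (p95 of 64 samples
// resolves to index 60).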

#[tokio::test]
async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() {
    let _guard = quota_test_guard();

    let rounds = 64usize;
    let user = format!("relay-pipeline-latency-single-{}", std::process::id());
    let mut samples_ms = Vec::with_capacity(rounds);

    for round in 0..rounds {
        let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
        let held_guard = held
            .try_lock()
            .expect("test must hold shared cross-mode lock before round");

        let stats = Arc::new(Stats::new());
        let (mut client_peer, relay_client) = duplex(1024);
        let (relay_server, mut server_peer) = duplex(1024);
        let (client_reader, client_writer) = tokio::io::split(relay_client);
        let (server_reader, server_writer) = tokio::io::split(relay_server);

        let relay_user = user.clone();
        let relay_stats = Arc::clone(&stats);
        let relay_task = tokio::spawn(async move {
            relay_bidirectional(
                client_reader,
                client_writer,
                server_reader,
                server_writer,
                256,
                256,
                &relay_user,
                relay_stats,
                Some(2048),
                Arc::new(BufferPool::new()),
            )
            .await
        });

        server_peer
            .write_all(&[(round as u8) ^ 0xA5])
            .await
            .expect("server write should queue before release");

        let release_at = Instant::now();
        drop(held_guard);

        let mut one = [0u8; 1];
        timeout(Duration::from_millis(450), client_peer.read_exact(&mut one))
            .await
            .expect("client must receive queued byte after release")
            .expect("queued byte read must succeed");
        samples_ms.push(release_at.elapsed().as_millis() as u64);

        drop(client_peer);
        drop(server_peer);

        let relay_result = timeout(Duration::from_secs(1), relay_task)
            .await
            .expect("relay task must complete")
            .expect("relay task must not panic");
        assert!(relay_result.is_ok());
    }

    samples_ms.sort_unstable();
    let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)];
    let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)];

    assert!(
        p50_ms <= 45,
        "single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}"
    );
    assert!(
        p95_ms <= 130,
        "single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
    );
}
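
// The fanout benchmark below coordinates 128 relay pipelines with three pieces:
// a Barrier so every waiter has queued its byte before release, a shared
// Mutex<Option<Instant>> that publishes the release timestamp, and a watch channel
// that broadcasts the release signal. Each waiter measures its own
// release-to-delivery latency and the test asserts p50/p95/max bounds over the
// sorted samples.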

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() {
    let _guard = quota_test_guard();

    let waiters = 128usize;
    let user = format!("relay-pipeline-latency-fanout-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold shared lock before fanout release benchmark");

    let ready_barrier = Arc::new(Barrier::new(waiters + 1));
    let release_at = Arc::new(Mutex::new(None::<Instant>));
    let (release_tx, release_rx) = watch::channel(false);
    let mut tasks = Vec::with_capacity(waiters);

    for idx in 0..waiters {
        let user = user.clone();
        let barrier = Arc::clone(&ready_barrier);
        let release_at = Arc::clone(&release_at);
        let mut release_rx = release_rx.clone();

        tasks.push(tokio::spawn(async move {
            let stats = Arc::new(Stats::new());
            let (mut client_peer, relay_client) = duplex(512);
            let (relay_server, mut server_peer) = duplex(512);
            let (client_reader, client_writer) = tokio::io::split(relay_client);
            let (server_reader, server_writer) = tokio::io::split(relay_server);

            let relay_user = user;
            let relay_stats = Arc::clone(&stats);
            let relay_task = tokio::spawn(async move {
                relay_bidirectional(
                    client_reader,
                    client_writer,
                    server_reader,
                    server_writer,
                    256,
                    256,
                    &relay_user,
                    relay_stats,
                    Some(2048),
                    Arc::new(BufferPool::new()),
                )
                .await
            });

            server_peer
                .write_all(&[(idx as u8) ^ 0x5A])
                .await
                .expect("fanout server write should queue before release");

            barrier.wait().await;
            release_rx
                .changed()
                .await
                .expect("release signal should remain available");

            let started = {
                let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner());
                guard.expect("release timestamp must be populated before signal")
            };

            let mut one = [0u8; 1];
            timeout(Duration::from_millis(900), client_peer.read_exact(&mut one))
                .await
                .expect("fanout waiter must receive queued byte after release")
                .expect("fanout waiter read must succeed");

            drop(client_peer);
            drop(server_peer);

            let relay_result = timeout(Duration::from_secs(2), relay_task)
                .await
                .expect("fanout relay task must complete")
                .expect("fanout relay task must not panic");
            assert!(relay_result.is_ok());

            started.elapsed().as_millis() as u64
        }));
    }

    ready_barrier.wait().await;
    {
        let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner());
        *guard = Some(Instant::now());
    }
    drop(held_guard);
    release_tx
        .send(true)
        .expect("release broadcast must succeed");

    let mut samples_ms = Vec::with_capacity(waiters);
    timeout(Duration::from_secs(8), async {
        for task in tasks {
            let elapsed = task.await.expect("fanout waiter must not panic");
            samples_ms.push(elapsed);
        }
    })
    .await
    .expect("fanout benchmark must complete in bounded time");

    samples_ms.sort_unstable();
    let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)];
    let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)];
    let max_ms = *samples_ms.last().unwrap_or(&0);

    assert!(
        p50_ms <= 120,
        "fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
    );
    assert!(
        p95_ms <= 260,
        "fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
    );
    assert!(
        max_ms <= 700,
        "fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
    );
}

@@ -3,8 +3,9 @@ use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};

#[derive(Default)]

@@ -26,6 +27,13 @@ fn quota_test_guard() -> impl Drop {
    super::quota_user_lock_test_scope()
}

fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
    (wake_counter, Context::from_waker(leaked_waker))
}
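
// `Context` borrows its `Waker`, so a helper that returns a ready-made context
// needs a waker that outlives the call: `Box::leak` turns the boxed waker into a
// `&'static Waker`. Each call leaks one small allocation, which is acceptable in
// test code and keeps every call site down to a single line.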

#[tokio::test]
async fn positive_cross_mode_uncontended_writer_progresses() {
    let _guard = quota_test_guard();

@@ -223,3 +231,374 @@ async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() {
        assert!(write_done.is_ok());
    }
}
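
// The added tests drive StatsIo at the poll level instead of through `.await`:
// a hand-built Context with a counting waker makes "the write parked itself"
// directly observable as a Pending poll, which a future-level test cannot
// distinguish from ordinary scheduling delay.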

#[tokio::test]
async fn integration_middle_lock_blocks_relay_reader_for_same_user() {
    let _guard = quota_test_guard();

    let user = format!("cross-mode-middle-reader-block-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let _held_guard = held
        .try_lock()
        .expect("test must hold middle-relay shared lock");

    let mut io = StatsIo::new(
        tokio::io::empty(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(1024),
        Arc::new(AtomicBool::new(false)),
        tokio::time::Instant::now(),
    );

    let (_wake_counter, mut cx) = build_context();
    let mut one = [0u8; 1];
    let mut buf = ReadBuf::new(&mut one);
    let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
    assert!(poll.is_pending());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn integration_middle_lock_release_unblocks_relay_reader() {
    let _guard = quota_test_guard();

    let user = format!("cross-mode-middle-reader-release-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold middle-relay shared lock");

    let task = tokio::spawn({
        let user = user.clone();
        async move {
            let mut io = StatsIo::new(
                tokio::io::empty(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                user,
                Some(1024),
                Arc::new(AtomicBool::new(false)),
                tokio::time::Instant::now(),
            );
            let mut one = [0u8; 1];
            io.read(&mut one).await
        }
    });

    tokio::time::sleep(Duration::from_millis(5)).await;
    drop(held_guard);

    let done = timeout(Duration::from_millis(300), task)
        .await
        .expect("reader task must complete after release")
        .expect("reader task must not panic");
    assert!(done.is_ok());
}

#[tokio::test]
async fn business_different_user_middle_lock_does_not_block_relay_writer() {
    let _guard = quota_test_guard();

    let held_user = format!("cross-mode-middle-held-{}", std::process::id());
    let active_user = format!("cross-mode-middle-active-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user);
    let _held_guard = held
        .try_lock()
        .expect("test must hold middle-relay lock for other user");

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        active_user,
        Some(1024),
        Arc::new(AtomicBool::new(false)),
        tokio::time::Instant::now(),
    );

    let (_wake_counter, mut cx) = build_context();
    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]);
    assert!(matches!(poll, Poll::Ready(Ok(1))));
}

#[tokio::test]
async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() {
    let _guard = quota_test_guard();

    let user = format!("cross-mode-none-limit-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let _held_guard = held
        .try_lock()
        .expect("test must hold lock while quota is disabled");

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        None,
        Arc::new(AtomicBool::new(false)),
        tokio::time::Instant::now(),
    );

    let (_wake_counter, mut cx) = build_context();
    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]);
    assert!(matches!(poll, Poll::Ready(Ok(2))));
}

#[tokio::test]
async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() {
    let _guard = quota_test_guard();

    let user = format!("cross-mode-pre-exceeded-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let _held_guard = held
        .try_lock()
        .expect("test must hold shared lock before poll");

    let quota_exceeded = Arc::new(AtomicBool::new(true));
    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(1024),
        Arc::clone(&quota_exceeded),
        tokio::time::Instant::now(),
    );

    let (_wake_counter, mut cx) = build_context();
    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]);
    assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e)));
}

#[tokio::test]
async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() {
    let _guard = quota_test_guard();

    let user = format!("cross-mode-repoll-held-{}", std::process::id());
    let stats = Arc::new(Stats::new());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let _held_guard = held
        .try_lock()
        .expect("test must hold lock for repoll sequence");

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::clone(&stats),
        user.clone(),
        Some(1024),
        Arc::new(AtomicBool::new(false)),
        tokio::time::Instant::now(),
    );

    let (_wake_counter, mut cx) = build_context();
    for _ in 0..8 {
        let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]);
        assert!(poll.is_pending());
    }

    assert_eq!(stats.get_user_total_octets(&user), 0);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_same_user_mixed_read_write_waiters_resume_after_release() {
    let _guard = quota_test_guard();

    let user = format!("cross-mode-mixed-resume-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold lock before spawning mixed waiters");

    let mut tasks = Vec::new();
    for i in 0..12usize {
        let user = user.clone();
        tasks.push(tokio::spawn(async move {
            if i % 2 == 0 {
                let mut io = StatsIo::new(
                    tokio::io::empty(),
                    Arc::new(SharedCounters::new()),
                    Arc::new(Stats::new()),
                    user,
                    Some(1024),
                    Arc::new(AtomicBool::new(false)),
                    tokio::time::Instant::now(),
                );
                let mut b = [0u8; 1];
                io.read(&mut b).await.map(|_| ())
            } else {
                let mut io = StatsIo::new(
                    tokio::io::sink(),
                    Arc::new(SharedCounters::new()),
                    Arc::new(Stats::new()),
                    user,
                    Some(1024),
                    Arc::new(AtomicBool::new(false)),
                    tokio::time::Instant::now(),
                );
                io.write_all(&[0x66]).await
            }
        }));
    }

    tokio::time::sleep(Duration::from_millis(8)).await;
    drop(held_guard);

    timeout(Duration::from_secs(1), async {
        for task in tasks {
            let result = task.await.expect("mixed waiter task must not panic");
            assert!(result.is_ok());
        }
    })
    .await
    .expect("all mixed waiters must finish after release");
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() {
    let _guard = quota_test_guard();

    let blocked_user = format!("cross-mode-blocked-{}", std::process::id());
    let free_user = format!("cross-mode-free-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user);
    let held_guard = held
        .try_lock()
        .expect("test must hold blocked user lock");

    let blocked_task = tokio::spawn({
        let blocked_user = blocked_user.clone();
        async move {
            let mut io = StatsIo::new(
                tokio::io::sink(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                blocked_user,
                Some(1024),
                Arc::new(AtomicBool::new(false)),
                tokio::time::Instant::now(),
            );
            io.write_all(&[0x77]).await
        }
    });

    let free_task = tokio::spawn({
        let free_user = free_user.clone();
        async move {
            let mut io = StatsIo::new(
                tokio::io::sink(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                free_user,
                Some(1024),
                Arc::new(AtomicBool::new(false)),
                tokio::time::Instant::now(),
            );
            io.write_all(&[0x78]).await
        }
    });

    let free_done = timeout(Duration::from_millis(250), free_task)
        .await
        .expect("free user must not be blocked")
        .expect("free user task must not panic");
    assert!(free_done.is_ok());

    drop(held_guard);
    let blocked_done = timeout(Duration::from_secs(1), blocked_task)
        .await
        .expect("blocked user must resume after release")
        .expect("blocked user task must not panic");
    assert!(blocked_done.is_ok());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() {
    let _guard = quota_test_guard();

    let user = format!("cross-mode-fanout-{}", std::process::id());
    let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold lock before fanout");

    let waiters = 48usize;
    let gate = Arc::new(Barrier::new(waiters + 1));
    let mut tasks = Vec::new();
    for _ in 0..waiters {
        let user = user.clone();
        let gate = Arc::clone(&gate);
        tasks.push(tokio::spawn(async move {
            let mut io = StatsIo::new(
                tokio::io::sink(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                user,
                Some(1024),
                Arc::new(AtomicBool::new(false)),
                tokio::time::Instant::now(),
            );
            gate.wait().await;
            io.write_all(&[0x79]).await
        }));
    }

    gate.wait().await;
    tokio::time::sleep(Duration::from_millis(10)).await;
    drop(held_guard);

    timeout(Duration::from_secs(2), async {
        for task in tasks {
            let result = task.await.expect("fanout task must not panic");
            assert!(result.is_ok());
        }
    })
    .await
    .expect("fanout waiters must complete after release");
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() {
    let _guard = quota_test_guard();

    let mut seed = 0xA11C_EE55_2026_0323u64;
    for round in 0..20u32 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let hold_ms = 2 + (seed % 10);
        let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id());
        let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
        let held_guard = held
            .try_lock()
            .expect("test must hold lock in fuzz round");

        let writer = tokio::spawn({
            let user = user.clone();
            async move {
                let mut io = StatsIo::new(
                    tokio::io::sink(),
                    Arc::new(SharedCounters::new()),
                    Arc::new(Stats::new()),
                    user,
                    Some(1024),
                    Arc::new(AtomicBool::new(false)),
                    tokio::time::Instant::now(),
                );
                io.write_all(&[0x7A]).await
            }
        });

        tokio::time::sleep(Duration::from_millis(hold_ms)).await;
        drop(held_guard);

        let done = timeout(Duration::from_millis(400), writer)
            .await
            .expect("writer must complete after lock release")
            .expect("writer task must not panic");
        assert!(done.is_ok());
    }
}

@@ -0,0 +1,340 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::AsyncWriteExt;
use tokio::time::{Duration, Instant, timeout};

#[derive(Default)]
struct WakeCounter {
    wakes: AtomicUsize,
}

impl std::task::Wake for WakeCounter {
    fn wake(self: Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }
}
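
// `std::task::Wake` turns an `Arc<WakeCounter>` directly into a `Waker`; both wake
// paths bump a relaxed counter, so tests can assert that a parked poll eventually
// gets woken and that the retry path stays wake-rate-limited.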

fn quota_test_guard() -> impl Drop {
    super::quota_user_lock_test_scope()
}

#[tokio::test]
async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() {
    let _guard = quota_test_guard();

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        format!("dual-lock-alt-positive-{}", std::process::id()),
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let write = io.write_all(&[0xAA, 0xBB]).await;
    assert!(write.is_ok(), "uncontended write must complete");
    assert_eq!(
        io.quota_write_retry_attempt, 0,
        "uncontended write must not advance retry backoff"
    );
}
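
// The adversarial case below alternates which of the two locks is held (the
// relay-local quota lock vs the shared cross-mode lock), flipping every few
// milliseconds. A correct dual-lock writer must stay pending while either lock is
// held, keep ramping its retry backoff, and still respect the wake-rate limit.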

#[tokio::test]
async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-alt-adversarial-{}", std::process::id());
    let local_lock = quota_user_lock(&user);
    let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);

    let mut local_guard = Some(
        local_lock
            .try_lock()
            .expect("test must hold local quota lock initially"),
    );
    let mut cross_guard = None;

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let mut cx = Context::from_waker(&waker);

    let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
    assert!(first.is_pending(), "held local lock must block first poll");

    let mut observed_wakes = 0usize;
    for idx in 0..18usize {
        tokio::time::sleep(Duration::from_millis(6)).await;

        if idx % 2 == 0 {
            drop(local_guard.take());
            cross_guard = Some(
                cross_mode_lock
                    .try_lock()
                    .expect("cross-mode lock should be acquirable while local lock released"),
            );
        } else {
            drop(cross_guard.take());
            local_guard = Some(
                local_lock
                    .try_lock()
                    .expect("local lock should be acquirable while cross lock released"),
            );
        }

        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
        if wakes > observed_wakes {
            observed_wakes = wakes;
            let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
            assert!(
                pending.is_pending(),
                "alternating contention must keep write pending while one lock is held"
            );
        }
    }

    assert!(
        io.quota_write_retry_attempt >= 2,
        "alternating contention must still ramp retry backoff; got {}",
        io.quota_write_retry_attempt
    );
    assert!(
        wake_counter.wakes.load(Ordering::Relaxed) <= 32,
        "alternating contention must stay wake-rate-limited"
    );

    drop(local_guard);
    drop(cross_guard);
    let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]);
    assert!(ready.is_ready(), "writer must resume after both locks released");
}

#[tokio::test]
async fn edge_retry_scheduler_resets_after_alternating_contention_clears() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-alt-edge-reset-{}", std::process::id());
    let local_lock = quota_user_lock(&user);
    let local_guard = local_lock
        .try_lock()
        .expect("test must hold local lock for edge scenario");

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let mut cx = Context::from_waker(&waker);

    let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]);
    assert!(first.is_pending());
    tokio::time::sleep(Duration::from_millis(15)).await;
    if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
        let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]);
        assert!(next.is_pending());
    }

    drop(local_guard);

    let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]);
    assert!(ready.is_ready());
    assert_eq!(
        io.quota_write_retry_attempt, 0,
        "successful dual-lock acquisition must reset retry scheduler"
    );
    assert!(!io.quota_write_wake_scheduled);
    assert!(io.quota_write_retry_sleep.is_none());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-alt-integration-{}", std::process::id());
    let local_lock = quota_user_lock(&user);
    let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);

    let mut waiters = Vec::new();
    for _ in 0..16usize {
        let user = user.clone();
        waiters.push(tokio::spawn(async move {
            let mut io = StatsIo::new(
                tokio::io::sink(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                user,
                Some(2048),
                Arc::new(AtomicBool::new(false)),
                Instant::now(),
            );
            timeout(Duration::from_secs(2), io.write_all(&[0x31])).await
        }));
    }

    let mut local_guard = Some(
        local_lock
            .try_lock()
            .expect("integration toggle must acquire local lock first"),
    );
    let mut cross_guard = None;

    for idx in 0..24usize {
        tokio::time::sleep(Duration::from_millis(4)).await;
        if idx % 2 == 0 {
            drop(local_guard.take());
            cross_guard = cross_mode_lock.try_lock().ok();
        } else {
            drop(cross_guard.take());
            local_guard = local_lock.try_lock().ok();
        }
    }

    drop(local_guard);
    drop(cross_guard);

    for waiter in waiters {
        let done = waiter.await.expect("waiter task must not panic");
        assert!(
            done.is_ok(),
            "waiter must finish once alternating contention window ends"
        );
        assert!(done.expect("waiter timeout must not fire").is_ok());
    }
}

#[tokio::test]
async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-alt-fuzz-{}", std::process::id());
    let local_lock = quota_user_lock(&user);
    let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let mut seed = 0xD00D_BAAD_F00D_2026u64;

    for _round in 0..64u32 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let hold_mode = (seed % 3) as u8;
        let local_guard = if hold_mode == 0 {
            Some(
                local_lock
                    .try_lock()
                    .expect("fuzz local lock should be acquirable"),
            )
        } else {
            None
        };
        let cross_guard = if hold_mode == 1 {
            Some(
                cross_mode_lock
                    .try_lock()
                    .expect("fuzz cross lock should be acquirable"),
            )
        } else {
            None
        };

        let mut io = StatsIo::new(
            tokio::io::sink(),
            Arc::new(SharedCounters::new()),
            Arc::new(Stats::new()),
            user.clone(),
            Some(1024),
            Arc::new(AtomicBool::new(false)),
            Instant::now(),
        );

        let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await;
        if hold_mode == 2 {
            assert!(write.is_ok(), "unheld fuzz round must make progress");
            assert!(write.expect("unheld round timeout").is_ok());
        } else {
            assert!(
                write.is_err(),
                "held-lock fuzz round must remain pending inside bounded window"
            );
        }

        drop(local_guard);
        drop(cross_guard);
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_fanout_alternating_contention_recovers_without_hanging() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-alt-stress-{}", std::process::id());
    let local_lock = quota_user_lock(&user);
    let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);

    let mut waiters = Vec::new();
    for _ in 0..48usize {
        let user = user.clone();
        waiters.push(tokio::spawn(async move {
            let mut io = StatsIo::new(
                tokio::io::sink(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                user,
                Some(4096),
                Arc::new(AtomicBool::new(false)),
                Instant::now(),
            );
            timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await
        }));
    }

    let mut local_guard = Some(
        local_lock
            .try_lock()
            .expect("stress toggle must acquire local lock first"),
    );
    let mut cross_guard = None;
    for idx in 0..40usize {
        tokio::time::sleep(Duration::from_millis(3)).await;
        if idx % 2 == 0 {
            drop(local_guard.take());
            cross_guard = cross_mode_lock.try_lock().ok();
        } else {
            drop(cross_guard.take());
            local_guard = local_lock.try_lock().ok();
        }
    }

    drop(local_guard);
    drop(cross_guard);

    for waiter in waiters {
        let done = waiter.await.expect("stress waiter task must not panic");
        assert!(done.is_ok(), "stress waiter timed out under alternating contention");
        assert!(done.expect("stress waiter timeout should not fire").is_ok());
    }
}

@@ -0,0 +1,74 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::time::{Duration, Instant};

#[derive(Default)]
struct WakeCounter {
    wakes: AtomicUsize,
}

impl std::task::Wake for WakeCounter {
    fn wake(self: Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }
}

fn quota_test_guard() -> impl Drop {
    super::quota_user_lock_test_scope()
}
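
// One focused scenario: with only the cross-mode lock held, every wake-driven
// re-poll must return Pending and `quota_write_retry_attempt` must keep ramping,
// showing the backoff scheduler reacts to second-lock contention, not just to the
// relay-local lock.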

#[tokio::test]
async fn adversarial_cross_mode_only_contention_backoff_attempt_must_ramp() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-backoff-{}", std::process::id());
    let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let held_cross_mode_guard = cross_mode_lock
        .try_lock()
        .expect("test must hold cross-mode lock before polling");

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let mut cx = Context::from_waker(&waker);

    let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
    assert!(first.is_pending(), "held cross-mode lock must block writer");

    let started = Instant::now();
    let mut last_wakes = 0usize;
    while started.elapsed() < Duration::from_millis(120) {
        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
        if wakes > last_wakes {
            last_wakes = wakes;
            let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
            assert!(next.is_pending(), "writer must remain blocked while lock is held");
        }
        tokio::time::sleep(Duration::from_millis(1)).await;
    }

    assert!(
        io.quota_write_retry_attempt >= 2,
        "retry attempt must ramp under sustained second-lock contention; got {}",
        io.quota_write_retry_attempt
    );

    drop(held_cross_mode_guard);
}

@@ -0,0 +1,325 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::time::{Duration, Instant, timeout};

#[derive(Default)]
struct WakeCounter {
    wakes: AtomicUsize,
}

impl std::task::Wake for WakeCounter {
    fn wake(self: Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }
}

fn quota_test_guard() -> impl Drop {
    super::quota_user_lock_test_scope()
}

fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
    (wake_counter, Context::from_waker(leaked_waker))
}

#[tokio::test]
async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() {
    let _guard = quota_test_guard();

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        format!("dual-lock-positive-{}", std::process::id()),
        Some(4096),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let (_wake_counter, mut cx) = build_context();
    let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]);
    assert!(poll.is_ready());
    assert_eq!(io.quota_write_retry_attempt, 0);
    assert!(!io.quota_write_wake_scheduled);
    assert!(io.quota_write_retry_sleep.is_none());
}

#[tokio::test]
async fn negative_local_lock_contention_read_retry_attempt_ramps() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-local-contention-{}", std::process::id());
    let held = quota_user_lock(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold local quota lock before polling");

    let mut io = StatsIo::new(
        tokio::io::empty(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let (wake_counter, mut cx) = build_context();
    let mut one = [0u8; 1];
    let mut buf = ReadBuf::new(&mut one);
    let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
    assert!(first.is_pending());

    let started = Instant::now();
    let mut observed = 0usize;
    while started.elapsed() < Duration::from_millis(120) {
        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
        if wakes > observed {
            observed = wakes;
            let mut step_buf = ReadBuf::new(&mut one);
            let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf);
            assert!(next.is_pending());
        }
        tokio::time::sleep(Duration::from_millis(1)).await;
    }

    assert!(
        io.quota_read_retry_attempt >= 2,
        "retry attempt must ramp under sustained local-lock contention; got {}",
        io.quota_read_retry_attempt
    );

    drop(held_guard);
}

#[tokio::test]
async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-reset-{}", std::process::id());
    let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let held_guard = cross_mode
        .try_lock()
        .expect("test must hold cross-mode lock before polling");

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let (wake_counter, mut cx) = build_context();
    let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]);
    assert!(first.is_pending());

    tokio::time::sleep(Duration::from_millis(20)).await;
    if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
        let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
        assert!(next.is_pending());
    }

    drop(held_guard);
    let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
    assert!(ready.is_ready());
    assert_eq!(io.quota_write_retry_attempt, 0);
    assert!(!io.quota_write_wake_scheduled);
    assert!(io.quota_write_retry_sleep.is_none());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-adversarial-{}", std::process::id());
    let stats = Arc::new(Stats::new());
    let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold cross-mode lock before launching waiters");

    let mut tasks = Vec::new();
    for _ in 0..24usize {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        tasks.push(tokio::spawn(async move {
            let mut io = StatsIo::new(
                tokio::io::sink(),
                Arc::new(SharedCounters::new()),
                stats,
                user,
                Some(1024),
                Arc::new(AtomicBool::new(false)),
                Instant::now(),
            );
            timeout(Duration::from_millis(40), io.write_all(&[0x33])).await
        }));
    }

    for task in tasks {
        let timed = task.await.expect("waiter task must not panic");
        assert!(timed.is_err(), "held cross-mode lock must keep waiter pending");
    }

    assert_eq!(stats.get_user_total_octets(&user), 0);
    drop(held_guard);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_waiters_resume_after_cross_mode_release() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-integration-{}", std::process::id());
    let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold cross-mode lock before starting waiter");

    let task = tokio::spawn({
        let user = user.clone();
        async move {
            let mut io = StatsIo::new(
                tokio::io::sink(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                user,
                Some(1024),
                Arc::new(AtomicBool::new(false)),
                Instant::now(),
            );
            io.write_all(&[0x44]).await
        }
    });

    tokio::time::sleep(Duration::from_millis(10)).await;
    drop(held_guard);

    let done = timeout(Duration::from_secs(1), task)
        .await
        .expect("waiter task must complete after release")
        .expect("waiter task must not panic");
    assert!(done.is_ok());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-fuzz-{}", std::process::id());
    let stats = Arc::new(Stats::new());
    let mut seed = 0xA55A_55AA_C3D2_E1F0u64;

    for _round in 0..48u32 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let hold_mode = (seed % 3) as u8;
        let mut local_lock = None;
        let mut cross_lock = None;
        let mut local_guard = None;
        let mut cross_guard = None;

        if hold_mode == 0 {
            local_lock = Some(quota_user_lock(&user));
            local_guard = Some(
                local_lock
                    .as_ref()
                    .expect("local lock should be present")
                    .try_lock()
                    .expect("local lock should be acquirable in fuzz round"),
            );
        } else if hold_mode == 1 {
            cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
                &user,
            ));
            cross_guard = Some(
                cross_lock
                    .as_ref()
                    .expect("cross lock should be present")
                    .try_lock()
                    .expect("cross lock should be acquirable in fuzz round"),
            );
        }

        let mut io = StatsIo::new(
            tokio::io::sink(),
            Arc::new(SharedCounters::new()),
            Arc::clone(&stats),
            user.clone(),
            Some(4096),
            Arc::new(AtomicBool::new(false)),
            Instant::now(),
        );

        let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await;
        if hold_mode == 2 {
            assert!(write.is_ok(), "unheld round must make progress");
        } else {
            assert!(write.is_err(), "held-lock round must stay blocked within timeout");
        }

        drop(local_guard);
        drop(cross_guard);
        drop(local_lock);
        drop(cross_lock);
    }

    assert!(stats.get_user_total_octets(&user) <= 4096);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_fanout_waiters_complete_after_release_without_panics() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-stress-{}", std::process::id());
    let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
    let held_guard = held
        .try_lock()
        .expect("test must hold cross-mode lock before stress fanout");

    let waiters = 64usize;
    let mut tasks = Vec::new();
    for _ in 0..waiters {
        let user = user.clone();
        tasks.push(tokio::spawn(async move {
            let mut io = StatsIo::new(
                tokio::io::empty(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                user,
                Some(1024),
                Arc::new(AtomicBool::new(false)),
                Instant::now(),
            );
            let mut one = [0u8; 1];
            io.read(&mut one).await
        }));
    }

    tokio::time::sleep(Duration::from_millis(12)).await;
    drop(held_guard);

    timeout(Duration::from_secs(2), async {
        for task in tasks {
            let result = task.await.expect("stress waiter task must not panic");
            assert!(result.is_ok());
        }
    })
    .await
    .expect("all stress waiters must complete after release");
}

@@ -0,0 +1,128 @@
use super::*;
use crate::stats::Stats;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use tokio::io::AsyncWriteExt;
use tokio::time::{Duration, timeout};

fn quota_test_guard() -> impl Drop {
    super::quota_user_lock_test_scope()
}

fn make_stats_io(user: String) -> StatsIo<tokio::io::Sink> {
    StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(4096),
        Arc::new(AtomicBool::new(false)),
        tokio::time::Instant::now(),
    )
}
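
// Each fuzz round below builds a fresh StatsIo over `tokio::io::sink()` with a
// 4096-byte quota and the exceeded flag cleared, so the only thing that can make a
// write pend is the externally held cross-mode lock.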
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-race-fuzz-{}", std::process::id());
|
||||
let mut seed = 0xD1CE_BAAD_5EED_1234u64;
|
||||
|
||||
for round in 0..1024u32 {
|
||||
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let hold = (seed & 1) == 0;
        let hold_ms = (seed % 3) as u64;

        let maybe_lock = if hold {
            Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
                &user,
            ))
        } else {
            None
        };

        let maybe_guard = maybe_lock.as_ref().map(|lock| {
            lock.try_lock()
                .expect("cross-mode lock must be acquirable in fuzz round")
        });

        if hold {
            let mut blocked_io = make_stats_io(user.clone());
            let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await;
            assert!(
                blocked.is_err(),
                "held round must block waiter before lock release (round={round})"
            );

            if hold_ms > 0 {
                tokio::time::sleep(Duration::from_millis(hold_ms)).await;
            }
        } else {
            let mut free_io = make_stats_io(user.clone());
            let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await;
            assert!(
                free.is_ok(),
                "unheld round must complete promptly (round={round})"
            );
            assert!(free.expect("unheld round should complete").is_ok());
        }

        drop(maybe_guard);

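        // Liveness probe: whatever this round did, a fresh writer for the same user must
        // complete once every guard has been released.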
        let done = timeout(Duration::from_millis(350), async {
            let user = user.clone();
            let mut io = make_stats_io(user);
            io.write_all(&[0xA6]).await
        })
        .await
        .expect("post-release write must complete in bounded time");
        assert!(done.is_ok());
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() {
    let _guard = quota_test_guard();

    let user = format!("dual-lock-race-stress-{}", std::process::id());
    let mut seed = 0xC0FF_EE77_4444_9999u64;

    for round in 0..256u32 {
        seed ^= seed << 7;
        seed ^= seed >> 9;
        seed ^= seed << 8;

        let hold_ms = (seed % 4) as u64;
        let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
        let guard = lock
            .try_lock()
            .expect("cross-mode lock must be acquirable at round start");

        let mut waiters = Vec::new();
        for _ in 0..3usize {
            let user = user.clone();
            waiters.push(tokio::spawn(async move {
                let mut io = make_stats_io(user);
                io.write_all(&[0x55]).await
            }));
        }

        tokio::time::sleep(Duration::from_millis(hold_ms)).await;
        drop(guard);

        timeout(Duration::from_secs(1), async {
            for waiter in waiters {
                let done = waiter.await.expect("waiter task must not panic");
                assert!(
                    done.is_ok(),
                    "waiter must complete after release (round={round})"
                );
            }
        })
        .await
        .expect("all waiters must complete in bounded time after release");
    }
}

@ -0,0 +1,332 @@
use super::relay_bidirectional;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::time::{Duration, timeout};

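// Drains whatever bytes arrive within `budget` and returns the total; EOF, a read
// error, or an exhausted budget all end the drain. Used by the tests below to show
// that nothing leaks through once the quota is spent.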
async fn read_available<R: tokio::io::AsyncRead + Unpin>(reader: &mut R, budget: Duration) -> usize {
    let start = tokio::time::Instant::now();
    let mut total = 0usize;
    let mut buf = [0u8; 128];

    loop {
        let elapsed = start.elapsed();
        if elapsed >= budget {
            break;
        }
        let remaining = budget.saturating_sub(elapsed);
        match timeout(remaining, reader.read(&mut buf)).await {
            Ok(Ok(0)) => break,
            Ok(Ok(n)) => total = total.saturating_add(n),
            Ok(Err(_)) | Err(_) => break,
        }
    }

    total
}

#[tokio::test]
async fn positive_quota_path_forwards_both_directions_within_limit() {
    let stats = Arc::new(Stats::new());
    let user = "quota-extended-positive-user";

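    // Topology: client_peer <-> relay <-> server_peer; each duplex half is split into a
    // reader/writer pair so relay_bidirectional can pump both directions.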
    let (mut client_peer, relay_client) = duplex(2048);
    let (relay_server, mut server_peer) = duplex(2048);
    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        256,
        256,
        user,
        Arc::clone(&stats),
        Some(16),
        Arc::new(BufferPool::new()),
    ));

    client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap();
    server_peer.read_exact(&mut [0u8; 4]).await.unwrap();

    server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap();
    client_peer.read_exact(&mut [0u8; 4]).await.unwrap();

    drop(client_peer);
    drop(server_peer);

    let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
    assert!(relay_result.is_ok());
    assert!(stats.get_user_total_octets(user) <= 16);
}

#[tokio::test]
async fn negative_preloaded_quota_forbids_any_forwarding() {
    let stats = Arc::new(Stats::new());
    let user = "quota-extended-negative-user";
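    // Preload usage to the full 8-byte quota before the relay starts, leaving zero
    // forwarding headroom in either direction.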
    stats.add_user_octets_from(user, 8);

    let (mut client_peer, relay_client) = duplex(1024);
    let (relay_server, mut server_peer) = duplex(1024);
    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        128,
        128,
        user,
        Arc::clone(&stats),
        Some(8),
        Arc::new(BufferPool::new()),
    ));

    client_peer.write_all(&[0xAA]).await.unwrap();
    server_peer.write_all(&[0xBB]).await.unwrap();

    assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0);
    assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0);

    let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
    assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
    assert!(stats.get_user_total_octets(user) <= 8);
}

#[tokio::test]
async fn edge_quota_one_ensures_at_most_one_byte_across_directions() {
    let stats = Arc::new(Stats::new());
    let user = "quota-extended-edge-user";

    let (mut client_peer, relay_client) = duplex(1024);
    let (relay_server, mut server_peer) = duplex(1024);
    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        128,
        128,
        user,
        Arc::clone(&stats),
        Some(1),
        Arc::new(BufferPool::new()),
    ));

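    // Fire one byte in each direction concurrently; with a quota of one, at most a
    // single byte may be delivered across both directions combined.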
    let _ = tokio::join!(
        client_peer.write_all(&[0xFE]),
        server_peer.write_all(&[0xEF]),
    );

    let mut buf = [0u8; 1];
    let delivered_s2c = match timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await {
        Ok(Ok(n)) => n,
        _ => 0,
    };
    let delivered_c2s = match timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await {
        Ok(Ok(n)) => n,
        _ => 0,
    };

    assert!(delivered_s2c + delivered_c2s <= 1);

    let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
    assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
}

#[tokio::test]
async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() {
    let stats = Arc::new(Stats::new());
    let user = "quota-extended-blackhat-user";
    let quota = 24u64;

    let (mut client_peer, relay_client) = duplex(4096);
    let (relay_server, mut server_peer) = duplex(4096);
    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        256,
        256,
        user,
        Arc::clone(&stats),
        Some(quota),
        Arc::new(BufferPool::new()),
    ));

    let mut total_forwarded = 0usize;

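    // Alternate single-byte writes between the two directions with 1-3 ms of jitter,
    // counting only bytes that actually emerge on the far side.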
    for i in 0..256usize {
        if relay.is_finished() {
            break;
        }
        if (i & 1) == 0 {
            let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await;
            let mut one = [0u8; 1];
            if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
                total_forwarded += n;
            }
        } else {
            let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await;
            let mut one = [0u8; 1];
            if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
                total_forwarded += n;
            }
        }

        tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await;
    }

    let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap();
    assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
    assert!(total_forwarded <= quota as usize);
    assert!(stats.get_user_total_octets(user) <= quota);
}

#[tokio::test]
async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() {
    let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE);

    for case in 0..32u64 {
        let stats = Arc::new(Stats::new());
        let user = format!("quota-extended-fuzz-{case}");
        let quota = rng.random_range(1u64..=35u64);

        let (mut client_peer, relay_client) = duplex(4096);
        let (relay_server, mut server_peer) = duplex(4096);
        let (client_reader, client_writer) = tokio::io::split(relay_client);
        let (server_reader, server_writer) = tokio::io::split(relay_server);

        let relay_user = user.clone();
        let relay_stats = Arc::clone(&stats);
        let relay = tokio::spawn(async move {
            relay_bidirectional(
                client_reader,
                client_writer,
                server_reader,
                server_writer,
                256,
                256,
                &relay_user,
                Arc::clone(&relay_stats),
                Some(quota),
                Arc::new(BufferPool::new()),
            )
            .await
        });

        let mut total_forwarded = 0usize;

        for _ in 0..96usize {
            if relay.is_finished() {
                break;
            }

            if rng.random::<bool>() {
                let _ = client_peer.write_all(&[rng.random::<u8>()]).await;
                let mut one = [0u8; 1];
                if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await {
                    total_forwarded += n;
                }
            } else {
                let _ = server_peer.write_all(&[rng.random::<u8>()]).await;
                let mut one = [0u8; 1];
                if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await {
                    total_forwarded += n;
                }
            }
        }

        drop(client_peer);
        drop(server_peer);

        let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
        assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
        assert!(total_forwarded <= quota as usize);
        assert!(stats.get_user_total_octets(&user) <= quota);
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_relays_for_one_user_obey_global_quota() {
    let stats = Arc::new(Stats::new());
    let user = "quota-extended-stress-user".to_string();
    let quota = 64u64;

    let mut tasks = Vec::new();

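    // Four relays share one user's Stats, so the per-user quota must bound the sum of
    // bytes delivered across all of them, not each relay individually.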
    for worker in 0..4u8 {
        let stats = Arc::clone(&stats);
        let user = user.clone();

        tasks.push(tokio::spawn(async move {
            let (mut client_peer, relay_client) = duplex(2048);
            let (relay_server, mut server_peer) = duplex(2048);
            let (client_reader, client_writer) = tokio::io::split(relay_client);
            let (server_reader, server_writer) = tokio::io::split(relay_server);

            let relay_user = user.clone();
            let relay_stats = Arc::clone(&stats);
            let relay = tokio::spawn(async move {
                relay_bidirectional(
                    client_reader,
                    client_writer,
                    server_reader,
                    server_writer,
                    128,
                    128,
                    &relay_user,
                    Arc::clone(&relay_stats),
                    Some(quota),
                    Arc::new(BufferPool::new()),
                )
                .await
            });

            let mut total = 0usize;
            for step in 0..64u8 {
                if relay.is_finished() {
                    break;
                }
                if (step as usize + worker as usize) % 2 == 0 {
                    let _ = client_peer.write_all(&[(step ^ 0x5A)]).await;
                    let mut one = [0u8; 1];
                    if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
                        total += n;
                    }
                } else {
                    let _ = server_peer.write_all(&[(step ^ 0xA5)]).await;
                    let mut one = [0u8; 1];
                    if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
                        total += n;
                    }
                }
                tokio::time::sleep(Duration::from_millis(1)).await;
            }

            drop(client_peer);
            drop(server_peer);

            let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
            assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
            total
        }));
    }

    let mut delivered = 0usize;
    for task in tasks {
        delivered += task.await.unwrap();
    }

    assert!(stats.get_user_total_octets(&user) <= quota);
    assert!(delivered <= quota as usize);
}

@ -0,0 +1,79 @@
use super::*;
use dashmap::DashMap;
use std::sync::Arc;
use tokio::time::{Duration, timeout};

#[test]
fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() {
    let _guard = quota_user_lock_test_scope();
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    let held_user = format!("quota-evict-held-{}", std::process::id());
    let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id());
    let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id());

    let held = quota_user_lock(&held_user);
    let stale_a = quota_user_lock(&stale_a_user);
    let stale_b = quota_user_lock(&stale_b_user);

    assert!(map.get(&held_user).is_some());
    assert!(map.get(&stale_a_user).is_some());
    assert!(map.get(&stale_b_user).is_some());

    drop(stale_a);
    drop(stale_b);

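    // Eviction must reclaim only entries whose Arc has no holder outside the map;
    // `held` is still referenced by this test and must survive the pass.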
    quota_user_lock_evict();

    assert!(
        map.get(&held_user).is_some(),
        "held entry must survive eviction"
    );
    assert!(
        map.get(&stale_a_user).is_none(),
        "unheld stale entry must be reclaimed"
    );
    assert!(
        map.get(&stale_b_user).is_none(),
        "unheld stale entry must be reclaimed"
    );

    drop(held);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() {
    let _guard = quota_user_lock_test_scope();
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    let held_user = format!("quota-evict-loop-held-{}", std::process::id());
    let stale_user = format!("quota-evict-loop-stale-{}", std::process::id());

    let held = quota_user_lock(&held_user);
    let stale = quota_user_lock(&stale_user);

    assert_eq!(map.len(), 2);
    drop(stale);

    let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));

    timeout(Duration::from_millis(200), async {
        loop {
            if map.get(&stale_user).is_none() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    })
    .await
    .expect("periodic quota lock evictor must reclaim stale entry");

    evictor.abort();

    assert!(map.get(&held_user).is_some());
    assert!(map.get(&stale_user).is_none());

    drop(held);
}

@ -0,0 +1,153 @@
use super::*;
use dashmap::DashMap;
use std::sync::Arc;
use tokio::task::JoinSet;
use tokio::time::{Duration, timeout};

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() {
    let _guard = quota_user_lock_test_scope();
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));

    let mut tasks = JoinSet::new();
    for worker in 0..24u32 {
        tasks.spawn(async move {
            for round in 0..320u32 {
                let user = format!(
                    "quota-evict-stress-user-{}-{}-{}",
                    std::process::id(),
                    worker,
                    round
                );
                let lock = quota_user_lock(&user);
                if round % 19 == 0 {
                    tokio::task::yield_now().await;
                }
                drop(lock);
            }
        });
    }

    while let Some(done) = tasks.join_next().await {
        done.expect("stress worker must not panic");
    }

    quota_user_lock_evict();
    tokio::time::sleep(Duration::from_millis(20)).await;

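    // Workers never evict inline; the background evictor plus one explicit pass must
    // be enough to keep the map within its cap.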
    assert!(
        map.len() <= QUOTA_USER_LOCKS_MAX,
        "quota lock map must remain bounded after churn + eviction"
    );

    let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id());
    let sanity_lock = quota_user_lock(&sanity_user);
    assert!(
        map.get(&sanity_user).is_some(),
        "sanity user should be cacheable after eviction reclaimed stale entries"
    );

    drop(sanity_lock);
    evictor.abort();
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() {
    let _guard = quota_user_lock_test_scope();
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    let held_user = format!("quota-evict-held-survive-{}", std::process::id());
    let held = quota_user_lock(&held_user);

    let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3));

    for idx in 0..512u32 {
        let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx);
        let temp = quota_user_lock(&user);
        drop(temp);
        if idx % 32 == 0 {
            tokio::task::yield_now().await;
        }
    }

    let reacquired = quota_user_lock(&held_user);
    assert!(
        Arc::ptr_eq(&held, &reacquired),
        "held user lock identity must remain stable across repeated evictions"
    );
    assert!(
        map.get(&held_user).is_some(),
        "held user entry must not be reclaimed while externally referenced"
    );

    drop(reacquired);
    drop(held);

    timeout(Duration::from_millis(300), async {
        loop {
            if map.get(&held_user).is_none() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    })
    .await
    .expect("released held lock must be reclaimed by periodic evictor");

    evictor.abort();
}

#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() {
    let _guard = quota_user_lock_test_scope();
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    let prefix = format!("quota-evict-saturated-{}", std::process::id());
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
    }

    assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);

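    // With the cache saturated by retained guards, a new user cannot be cached and is
    // served from the overflow path until eviction frees capacity.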
    let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id());
    let overflow_before = quota_user_lock(&overflow_user);
    assert!(
        map.get(&overflow_user).is_none(),
        "saturated map must initially route new user to overflow stripe"
    );

    drop(retained);

    let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4));

    timeout(Duration::from_millis(400), async {
        loop {
            if map.len() < QUOTA_USER_LOCKS_MAX {
                break;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    })
    .await
    .expect("periodic evictor must reclaim stale saturated entries");

    let overflow_after = quota_user_lock(&overflow_user);
    assert!(
        map.get(&overflow_user).is_some(),
        "after eviction, overflow user should become cacheable again"
    );
    assert!(
        Arc::strong_count(&overflow_after) >= 2,
        "cacheable lock should be held by map and caller"
    );

    drop(overflow_before);
    drop(overflow_after);
    evictor.abort();
}

@ -127,7 +127,7 @@ fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
}

#[test]
fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() {
    let _guard = super::quota_user_lock_test_scope();
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

@ -142,6 +142,8 @@ fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {

    drop(retained);

    quota_user_lock_evict();

    let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
    let overflow = quota_user_lock(&overflow_user);

@ -0,0 +1,249 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::AsyncWriteExt;
use tokio::time::{Duration, Instant, timeout};

#[derive(Default)]
struct WakeCounter {
    wakes: AtomicUsize,
}

impl std::task::Wake for WakeCounter {
    fn wake(self: Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::Relaxed);
    }
}

fn quota_test_guard() -> impl Drop {
    super::quota_user_lock_test_scope()
}

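// Builds a 'static Context by leaking one boxed Waker per call; the leak is deliberate
// and harmless in a test binary, and it lets the Context outlive this helper.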
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
    let wake_counter = Arc::new(WakeCounter::default());
    let waker = Waker::from(Arc::clone(&wake_counter));
    let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
    (wake_counter, Context::from_waker(leaked_waker))
}

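// Returns the address of the currently scheduled retry timer (0 when none), so the
// tests can assert that a repoll reused the same Sleep allocation.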
fn sleep_slot_ptr(slot: &Option<Pin<Box<tokio::time::Sleep>>>) -> usize {
    slot.as_ref()
        .map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize)
        .unwrap_or(0)
}

#[tokio::test]
async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() {
    let _guard = quota_test_guard();

    let user = format!("retry-alloc-single-pending-{}", std::process::id());
    let lock = quota_user_lock(&user);
    let held_guard = lock
        .try_lock()
        .expect("test must hold local lock to force retry scheduling");

    reset_quota_retry_sleep_allocs_for_tests();

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let (_wake_counter, mut cx) = build_context();

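    // Poll the AsyncWrite by hand rather than awaiting, so the test can observe the
    // Pending results and count timer allocations between consecutive polls.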
    let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]);
    assert!(first.is_pending());
    let allocs_after_first = quota_retry_sleep_allocs_for_tests();
    let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep);

    let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]);
    assert!(second.is_pending());
    let allocs_after_second = quota_retry_sleep_allocs_for_tests();
    let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep);

    assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer");
    assert_eq!(
        allocs_after_second, 1,
        "repoll while the same timer is pending must not allocate again"
    );
    assert_eq!(
        ptr_after_first, ptr_after_second,
        "repoll while pending should retain the same timer allocation"
    );

    drop(held_guard);
}

#[tokio::test]
async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() {
    let _guard = quota_test_guard();

    let user = format!("retry-alloc-per-cycle-{}", std::process::id());
    let lock = quota_user_lock(&user);
    let held_guard = lock
        .try_lock()
        .expect("test must hold local lock to keep write path pending");

    reset_quota_retry_sleep_allocs_for_tests();

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let (wake_counter, mut cx) = build_context();

    let mut polls = 0u64;
    let mut observed_wakes = 0usize;
    let started = Instant::now();
    while started.elapsed() < Duration::from_millis(70) {
        let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]);
        polls = polls.saturating_add(1);
        assert!(poll.is_pending());

        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
        if wakes > observed_wakes {
            observed_wakes = wakes;
        }

        tokio::time::sleep(Duration::from_millis(1)).await;
    }

    let allocs = quota_retry_sleep_allocs_for_tests();
    assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers");
    assert!(
        allocs < polls,
        "timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})"
    );

    drop(held_guard);
}

#[tokio::test]
async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() {
    let _guard = quota_test_guard();

    let user = format!("retry-latency-envelope-{}", std::process::id());
    let lock = quota_user_lock(&user);
    let held_guard = lock
        .try_lock()
        .expect("test must hold local lock for sustained contention");

    reset_quota_retry_sleep_allocs_for_tests();

    let mut io = StatsIo::new(
        tokio::io::sink(),
        Arc::new(SharedCounters::new()),
        Arc::new(Stats::new()),
        user,
        Some(2048),
        Arc::new(AtomicBool::new(false)),
        Instant::now(),
    );

    let (wake_counter, mut cx) = build_context();

    let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]);
    assert!(first.is_pending());

    let started = Instant::now();
    let mut last_wakes = 0usize;
    let mut wake_instants = Vec::new();

    while started.elapsed() < Duration::from_millis(120) {
        let wakes = wake_counter.wakes.load(Ordering::Relaxed);
        if wakes > last_wakes {
            last_wakes = wakes;
            wake_instants.push(Instant::now());
            let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]);
            assert!(pending.is_pending());
        }
        tokio::time::sleep(Duration::from_millis(1)).await;
    }

    let mut max_gap = Duration::from_millis(0);
    for idx in 1..wake_instants.len() {
        let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]);
        if gap > max_gap {
            max_gap = gap;
        }
    }

    assert!(
        max_gap <= Duration::from_millis(35),
        "retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}"
    );
    assert!(
        quota_retry_sleep_allocs_for_tests() <= 16,
        "allocation cycles must remain bounded during a short contention window"
    );

    drop(held_guard);
}

#[tokio::test]
async fn micro_benchmark_release_to_completion_latency_stays_bounded() {
    let _guard = quota_test_guard();

    let rounds = 96usize;
    let mut samples_ms = Vec::with_capacity(rounds);

    for round in 0..rounds {
        let user = format!("retry-release-latency-{}-{round}", std::process::id());
        let lock = quota_user_lock(&user);
        let held_guard = lock
            .try_lock()
            .expect("test must hold local lock before spawning blocked writer");

        let writer = tokio::spawn(async move {
            let mut io = StatsIo::new(
                tokio::io::sink(),
                Arc::new(SharedCounters::new()),
                Arc::new(Stats::new()),
                user,
                Some(2048),
                Arc::new(AtomicBool::new(false)),
                Instant::now(),
            );
            io.write_all(&[0xD1]).await
        });

        tokio::time::sleep(Duration::from_millis(2)).await;
        let release_at = Instant::now();
        drop(held_guard);

        let done = timeout(Duration::from_millis(120), writer)
            .await
            .expect("blocked writer must complete after release")
            .expect("writer task must not panic");
        assert!(done.is_ok());

        samples_ms.push(release_at.elapsed().as_millis() as u64);
    }

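    // p95 over the sorted samples; the index is clamped so small sample sets stay in bounds.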
    samples_ms.sort_unstable();
    let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1));
    let p95_ms = samples_ms[p95_idx];

    assert!(
        p95_ms <= 40,
        "contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
    );
}