Redesign Quotas on Atomics

This commit is contained in:
Alexey 2026-03-23 15:53:44 +03:00
parent 0c3c9009a9
commit 6f4356f72a
No known key found for this signature in database
10 changed files with 408 additions and 1043 deletions

View File

@ -32,14 +32,6 @@ pub(crate) struct RuntimeWatches {
pub(crate) detected_ip_v6: Option<IpAddr>,
}
const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60;
/// Spawns the periodic background task that evicts unused per-user quota
/// locks, using the fixed `QUOTA_USER_LOCK_EVICT_INTERVAL_SECS` cadence.
fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> {
    let evict_interval =
        std::time::Duration::from_secs(QUOTA_USER_LOCK_EVICT_INTERVAL_SECS);
    crate::proxy::relay::spawn_quota_user_lock_evictor(evict_interval)
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn spawn_runtime_tasks(
config: &Arc<ProxyConfig>,
@ -77,8 +69,6 @@ pub(crate) async fn spawn_runtime_tasks(
rc_clone.run_periodic_cleanup().await;
});
spawn_quota_lock_maintenance_task();
let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4);
let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6);
debug!(
@ -370,24 +360,3 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc<StartupTracker>) {
.await;
startup_tracker.mark_ready().await;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// TDD guard: the runtime maintenance path must spawn exactly one quota
    /// lock evictor task per call, observed via a test-only spawn counter.
    #[tokio::test]
    async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() {
        // Reset the global counter so this test is independent of run order.
        crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests();
        let handle = spawn_quota_lock_maintenance_task();
        // Brief pause so the spawned task is registered before sampling.
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        assert_eq!(
            crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(),
            1,
            "runtime maintenance path must spawn exactly one quota lock evictor task per call"
        );
        // Abort the background loop so it does not outlive the test.
        handle.abort();
    }
}

View File

@ -1223,7 +1223,7 @@ impl RunningClientHandler {
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
&& stats.get_user_quota_used(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
@ -1282,7 +1282,7 @@ impl RunningClientHandler {
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
&& stats.get_user_quota_used(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),

View File

@ -614,6 +614,15 @@ where
}
};
// Reject known replay digests before expensive cache/domain/ALPN policy work.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => {
@ -669,15 +678,8 @@ where
None
};
// Replay tracking is applied only after full policy validation (including
// ALPN checks) so rejected handshakes cannot poison replay state.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
// Add replay digest only for policy-valid handshakes.
replay_checker.add_tls_digest(digest_half);
let response = if let Some((cached_entry, use_full_cert_payload)) = cached {
emulator::build_emulated_server_hello(

View File

@ -60,7 +60,7 @@ where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut total = 0usize;
let mut ended_by_eof = false;
@ -262,7 +262,11 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
let floor = config.censorship.mask_timing_normalization_floor_ms;
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
if floor == 0 {
return MASK_TIMEOUT;
if ceiling == 0 {
return Duration::from_millis(0);
}
let mut rng = rand::rng();
return Duration::from_millis(rng.random_range(0..=ceiling));
}
if ceiling > floor {
let mut rng = rand::rng();
@ -838,7 +842,7 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usiz
}
// Keep drain path fail-closed under slow-loris stalls.
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut total = 0usize;
loop {

View File

@ -10,7 +10,7 @@ use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::timeout;
use tracing::{debug, info, trace, warn};
@ -23,7 +23,7 @@ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
cutover_stagger_delay,
};
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats};
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
@ -53,20 +53,11 @@ const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
const QUOTA_RESERVE_SPIN_RETRIES: usize = 32;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<AsyncMutex<()>>>> = OnceLock::new();
static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
@ -538,36 +529,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
/// Returns true when `user` has already consumed at least `quota_limit`
/// total octets. A `None` limit means the user is unmetered.
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
    match quota_limit {
        Some(quota) => stats.get_user_total_octets(user) >= quota,
        None => false,
    }
}
/// Returns true when charging `bytes` more octets to `user` would cross the
/// quota (or the quota is already exhausted). `None` limit never exceeds.
#[cfg_attr(not(test), allow(dead_code))]
fn quota_would_be_exceeded_for_user(
    stats: &Stats,
    user: &str,
    quota_limit: Option<u64>,
    bytes: u64,
) -> bool {
    let Some(quota) = quota_limit else {
        return false;
    };
    let used = stats.get_user_total_octets(user);
    if used >= quota {
        return true;
    }
    // Remaining headroom cannot underflow here, but stay saturating for safety.
    bytes > quota.saturating_sub(used)
}
/// Computes the soft quota cap: the hard limit plus the permitted overshoot,
/// saturating at `u64::MAX` instead of wrapping.
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
    limit.checked_add(overshoot).unwrap_or(u64::MAX)
}
fn quota_would_be_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
async fn reserve_user_quota_with_yield(
user_stats: &UserStats,
bytes: u64,
overshoot: u64,
) -> bool {
let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot));
quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes)
limit: u64,
) -> std::result::Result<u64, QuotaReserveError> {
loop {
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
match user_stats.quota_try_reserve(bytes, limit) {
Ok(total) => return Ok(total),
Err(QuotaReserveError::LimitExceeded) => {
return Err(QuotaReserveError::LimitExceeded);
}
Err(QuotaReserveError::Contended) => std::hint::spin_loop(),
}
}
tokio::task::yield_now().await;
}
}
fn classify_me_d2c_flush_reason(
@ -613,29 +596,6 @@ fn observe_me_d2c_flush_event(
}
}
/// Rolls back a previously made ME->C quota reservation after a failed or
/// over-limit write: returns `reserved_bytes` to the user's stats and
/// decrements the per-connection ME->C byte counter.
fn rollback_me2c_quota_reservation(
    stats: &Stats,
    user: &str,
    bytes_me2c: &AtomicU64,
    reserved_bytes: u64,
) {
    // Undo the stats-side reservation first, then the local byte counter.
    stats.sub_user_octets_to(user, reserved_bytes);
    bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed);
}
/// Test-only global mutex used to serialize tests that touch the shared
/// quota-user-lock registry.
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    TEST_LOCK.get_or_init(|| Mutex::new(()))
}

/// Acquires the test guard, recovering the inner guard if a previous test
/// panicked while holding it (poisoned mutex).
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
    quota_user_lock_test_guard()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[cfg(test)]
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@ -649,46 +609,6 @@ pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static,
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Deterministically maps `user` onto one of a fixed number of shared
/// overflow lock stripes. Used when the per-user registry is at capacity;
/// distinct users may share a stripe, trading contention for bounded memory.
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
    let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
        std::iter::repeat_with(|| Arc::new(AsyncMutex::new(())))
            .take(QUOTA_OVERFLOW_LOCK_STRIPES)
            .collect()
    });
    let idx = crc32fast::hash(user.as_bytes()) as usize % stripes.len();
    Arc::clone(&stripes[idx])
}
/// Returns the per-user quota mutex, creating it on first use.
///
/// The registry is bounded by `QUOTA_USER_LOCKS_MAX`; when full, entries no
/// longer referenced by any session (strong count == 1, i.e. only the map's
/// own Arc) are evicted. If still full after eviction, the user falls back
/// to a shared striped overflow lock so memory stays bounded.
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    if let Some(existing) = locks.get(user) {
        return Arc::clone(existing.value());
    }
    if locks.len() >= QUOTA_USER_LOCKS_MAX {
        locks.retain(|_, value| Arc::strong_count(value) > 1);
    }
    if locks.len() >= QUOTA_USER_LOCKS_MAX {
        return quota_overflow_user_lock(user);
    }
    let created = Arc::new(AsyncMutex::new(()));
    // Re-check via the entry API: another task may have inserted this user's
    // lock between the lookup above and this insert.
    match locks.entry(user.to_string()) {
        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
        dashmap::mapref::entry::Entry::Vacant(entry) => {
            entry.insert(Arc::clone(&created));
            created
        }
    }
}
/// Test-only accessor delegating to the shared cross-mode quota lock registry.
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
    crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@ -744,8 +664,7 @@ where
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let cross_mode_quota_lock =
quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@ -872,7 +791,7 @@ where
let stats_clone = stats.clone();
let rng_clone = rng.clone();
let user_clone = user.clone();
let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone();
let quota_user_stats_me_writer = quota_user_stats.clone();
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
let bytes_me2c_clone = bytes_me2c.clone();
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
@ -894,7 +813,7 @@ where
let first_is_downstream_activity =
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response_with_cross_mode_lock(
match process_me_writer_response(
first,
&mut writer,
proto_tag,
@ -902,9 +821,9 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -953,7 +872,7 @@ where
let next_is_downstream_activity =
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response_with_cross_mode_lock(
match process_me_writer_response(
next,
&mut writer,
proto_tag,
@ -961,9 +880,9 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -1015,7 +934,7 @@ where
Ok(Some(next)) => {
let next_is_downstream_activity =
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response_with_cross_mode_lock(
match process_me_writer_response(
next,
&mut writer,
proto_tag,
@ -1023,9 +942,9 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -1079,7 +998,7 @@ where
let extra_is_downstream_activity =
matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response_with_cross_mode_lock(
match process_me_writer_response(
extra,
&mut writer,
proto_tag,
@ -1087,9 +1006,9 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -1259,24 +1178,23 @@ where
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else {
main_result = Err(ProxyError::Proxy(
"cross-mode quota lock missing for quota-limited session"
.to_string(),
));
break;
};
let _cross_mode_quota_guard = cross_mode_lock.lock().await;
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
if let (Some(limit), Some(user_stats)) =
(quota_limit, quota_user_stats.as_deref())
{
if reserve_user_quota_with_yield(
user_stats,
payload.len() as u64,
limit,
)
.await
.is_err()
{
main_result = Err(ProxyError::DataQuotaExceeded {
user: user.clone(),
});
break;
}
stats.add_user_octets_from_handle(user_stats, payload.len() as u64);
} else {
stats.add_user_octets_from(&user, payload.len() as u64);
}
@ -1755,7 +1673,6 @@ enum MeWriterResponseOutcome {
Close,
}
#[cfg(test)]
async fn process_me_writer_response<W>(
response: MeResponse,
client_writer: &mut CryptoWriter<W>,
@ -1764,6 +1681,7 @@ async fn process_me_writer_response<W>(
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_user_stats: Option<&UserStats>,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
bytes_me2c: &AtomicU64,
@ -1771,44 +1689,6 @@ async fn process_me_writer_response<W>(
ack_flush_immediate: bool,
batched: bool,
) -> Result<MeWriterResponseOutcome>
where
W: AsyncWrite + Unpin + Send + 'static,
{
process_me_writer_response_with_cross_mode_lock(
response,
client_writer,
proto_tag,
rng,
frame_buf,
stats,
user,
quota_limit,
quota_soft_overshoot_bytes,
None,
bytes_me2c,
conn_id,
ack_flush_immediate,
batched,
)
.await
}
async fn process_me_writer_response_with_cross_mode_lock<W>(
response: MeResponse,
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
rng: &SecureRandom,
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
cross_mode_quota_lock: Option<&Arc<AsyncMutex<()>>>,
bytes_me2c: &AtomicU64,
conn_id: u64,
ack_flush_immediate: bool,
batched: bool,
) -> Result<MeWriterResponseOutcome>
where
W: AsyncWrite + Unpin + Send + 'static,
{
@ -1820,78 +1700,43 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
let data_len = data.len() as u64;
if let Some(limit) = quota_limit {
let owned_cross_mode_lock;
let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock {
lock
} else {
owned_cross_mode_lock =
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user);
&owned_cross_mode_lock
};
let cross_mode_quota_guard = cross_mode_lock.lock().await;
if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) {
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
if quota_would_be_exceeded_for_user_soft(
stats,
user,
Some(limit),
data_len,
quota_soft_overshoot_bytes,
) {
if reserve_user_quota_with_yield(user_stats, data_len, soft_limit)
.await
.is_err()
{
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
// Reserve quota before awaiting network I/O to avoid same-user HoL stalls.
// If reservation loses a race or write fails, we roll back immediately.
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
stats.add_user_octets_to(user, data_len);
if stats.get_user_total_octets(user) > soft_limit {
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
// Keep cross-mode lock scope explicit and minimal: quota reservation is serialized,
// but socket I/O proceeds without holding same-user cross-mode admission lock.
drop(cross_mode_quota_guard);
let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
{
Ok(mode) => mode,
Err(err) => {
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
return Err(err);
}
};
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
// Do not fail immediately on exact boundary after a successful write.
// Returning an error here can bypass batch flush in the caller and risk
// dropping buffered ciphertext from CryptoWriter. The next frame is
// rejected by the pre-check at function entry.
} else {
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
stats.add_user_octets_to(user, data_len);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
}
let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
{
Ok(mode) => mode,
Err(err) => {
if quota_limit.is_some() {
stats.add_quota_write_fail_bytes_total(data_len);
stats.increment_quota_write_fail_events_total();
}
return Err(err);
}
};
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
if let Some(user_stats) = quota_user_stats {
stats.add_user_octets_to_handle(user_stats, data_len);
} else {
stats.add_user_octets_to(user, data_len);
}
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
Ok(MeWriterResponseOutcome::Continue {
frames: 1,
bytes: data.len(),
@ -2097,10 +1942,6 @@ where
.map_err(ProxyError::Io)
}
#[cfg(test)]
#[path = "tests/middle_relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_idle_policy_security_tests.rs"]
mod idle_policy_security_tests;
@ -2113,30 +1954,10 @@ mod desync_all_full_dedup_security_tests;
#[path = "tests/middle_relay_stub_completion_security_tests.rs"]
mod stub_completion_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"]
mod coverage_high_risk_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"]
mod quota_overflow_lock_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"]
mod length_cast_hardening_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
mod blackhat_campaign_integration_tests;
#[cfg(test)]
#[path = "tests/middle_relay_hol_quota_security_tests.rs"]
mod hol_quota_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"]
mod quota_reservation_adversarial_tests;
#[cfg(test)]
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
mod middle_relay_idle_registry_poison_security_tests;
@ -2156,27 +1977,3 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"]
mod middle_relay_cross_mode_quota_reservation_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"]
mod middle_relay_cross_mode_quota_lock_matrix_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"]
mod middle_relay_cross_mode_lookup_efficiency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"]
mod middle_relay_cross_mode_lock_release_regression_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"]
mod middle_relay_quota_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"]
mod middle_relay_quota_reservation_extreme_security_tests;

View File

@ -64,7 +64,6 @@ pub mod direct_relay;
pub mod handshake;
pub mod masking;
pub mod middle_relay;
pub mod quota_lock_registry;
pub mod relay;
pub mod route_mode;
pub mod session_eviction;

View File

@ -1,88 +0,0 @@
use dashmap::DashMap;
use std::sync::{Arc, OnceLock};
use tokio::sync::Mutex;
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(test)]
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
#[cfg(test)]
static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock<DashMap<String, usize>> = OnceLock::new();
/// Deterministically maps `user` onto one of the fixed cross-mode overflow
/// lock stripes; used once the bounded per-user registry is full. Distinct
/// users may share a stripe.
fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
    let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
        std::iter::repeat_with(|| Arc::new(Mutex::new(())))
            .take(CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES)
            .collect()
    });
    let idx = crc32fast::hash(user.as_bytes()) as usize % stripes.len();
    Arc::clone(&stripes[idx])
}
/// Returns the cross-mode quota lock for `user` (shared across relay modes
/// that account quota for the same user), creating it on first use.
///
/// The registry is bounded by `CROSS_MODE_QUOTA_USER_LOCKS_MAX`; when full,
/// entries held only by the map itself (strong count == 1) are evicted, and
/// if still full the user falls back to a shared striped overflow lock.
pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
    // Test-only lookup accounting (global counter plus per-user map).
    #[cfg(test)]
    {
        CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed);
        let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
        let mut entry = lookups.entry(user.to_string()).or_insert(0);
        *entry += 1;
    }
    let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    if let Some(existing) = locks.get(user) {
        return Arc::clone(existing.value());
    }
    if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
        locks.retain(|_, value| Arc::strong_count(value) > 1);
    }
    if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
        return cross_mode_quota_overflow_user_lock(user);
    }
    let created = Arc::new(Mutex::new(()));
    // Re-check via the entry API: another task may have raced the insert.
    match locks.entry(user.to_string()) {
        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
        dashmap::mapref::entry::Entry::Vacant(entry) => {
            entry.insert(Arc::clone(&created));
            created
        }
    }
}
/// Test-only: clears both the global lookup counter and the per-user map.
#[cfg(test)]
pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() {
    CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed);
    let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
    lookups.clear();
}

/// Test-only: total number of cross-mode lock lookups since the last reset.
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize {
    CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed)
}

/// Test-only: lookup count for a specific user since the last reset.
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize {
    let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
    lookups.get(user).map(|entry| *entry).unwrap_or(0)
}
#[cfg(test)]
#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"]
mod quota_lock_registry_cross_mode_adversarial_tests;

View File

@ -52,18 +52,16 @@
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use crate::error::{ProxyError, Result};
use crate::stats::Stats;
use crate::stats::{Stats, UserStats};
use crate::stream::BufferPool;
use dashmap::DashMap;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
use tokio::sync::Mutex as AsyncMutex;
use tokio::time::{Instant, Sleep};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
// ============= Constants =============
@ -210,16 +208,10 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_lock: Option<Arc<Mutex<()>>>,
cross_mode_quota_lock: Option<Arc<AsyncMutex<()>>>,
user_stats: Arc<UserStats>,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_read_wake_scheduled: bool,
quota_write_wake_scheduled: bool,
quota_read_retry_sleep: Option<Pin<Box<Sleep>>>,
quota_write_retry_sleep: Option<Pin<Box<Sleep>>>,
quota_read_retry_attempt: u8,
quota_write_retry_attempt: u8,
quota_bytes_since_check: u64,
epoch: Instant,
}
@ -235,24 +227,16 @@ impl<S> StatsIo<S> {
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
let quota_lock = quota_limit.map(|_| quota_user_lock(&user));
let cross_mode_quota_lock = quota_limit
.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
let user_stats = stats.get_or_create_user_stats_handle(&user);
Self {
inner,
counters,
stats,
user,
quota_lock,
cross_mode_quota_lock,
user_stats,
quota_limit,
quota_exceeded,
quota_read_wake_scheduled: false,
quota_write_wake_scheduled: false,
quota_read_retry_sleep: None,
quota_write_retry_sleep: None,
quota_read_retry_attempt: 0,
quota_write_retry_attempt: 0,
quota_bytes_since_check: 0,
epoch,
}
}
@ -281,169 +265,24 @@ fn is_quota_io_error(err: &io::Error) -> bool {
.is_some()
}
#[cfg(test)]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
#[cfg(test)]
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64);
const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024;
const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024;
const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024;
const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024;
#[cfg(test)]
static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0);
#[cfg(test)]
static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0);
#[cfg(test)]
pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() {
QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed);
}
#[cfg(test)]
pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 {
QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed)
/// Chooses how many bytes may flow before the next quota re-check: half of
/// the remaining headroom, clamped to the configured min/max interval.
/// Shrinks the check interval as the user approaches the limit.
#[inline]
fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 {
    // Division by a nonzero constant cannot fail, so plain `/ 2` is safe.
    let half_remaining = remaining_before / 2;
    half_remaining.clamp(
        QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES,
        QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES,
    )
}
/// Exponential backoff for quota-lock contention: base interval doubled per
/// attempt (exponent capped at 5, i.e. x32), bounded by the max interval.
#[inline]
fn quota_contention_retry_delay(retry_attempt: u8) -> Duration {
    let capped_attempt = retry_attempt.min(5);
    let factor = 2_u32.pow(u32::from(capped_attempt));
    let delay = QUOTA_CONTENTION_RETRY_INTERVAL.saturating_mul(factor);
    delay.min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL)
}
/// Clears all quota contention backoff state after a successful lock
/// acquisition: drops any pending sleep, unmarks the scheduled wake, and
/// restarts the backoff sequence from attempt zero.
#[inline]
fn reset_quota_retry_scheduler(
    sleep_slot: &mut Option<Pin<Box<Sleep>>>,
    wake_scheduled: &mut bool,
    retry_attempt: &mut u8,
) {
    // Independent &mut writes; order is not significant.
    *sleep_slot = None;
    *retry_attempt = 0;
    *wake_scheduled = false;
}
/// Drives the backoff timer used while the per-user quota lock is contended.
///
/// On the first call after contention, allocates a sleep future with an
/// exponentially growing delay (see `quota_contention_retry_delay`). When
/// that sleep completes, the scheduler state is cleared, the attempt counter
/// advances, and the waker is signalled so the caller re-polls and retries.
fn poll_quota_retry_sleep(
    sleep_slot: &mut Option<Pin<Box<Sleep>>>,
    wake_scheduled: &mut bool,
    retry_attempt: &mut u8,
    cx: &mut Context<'_>,
) {
    if !*wake_scheduled {
        *wake_scheduled = true;
        // Test-only accounting of how many sleep futures were allocated.
        #[cfg(test)]
        QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed);
        *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay(
            *retry_attempt,
        ))));
    }
    // Poll the pending sleep; if it fired, reset the slot and escalate backoff.
    if let Some(sleep) = sleep_slot.as_mut()
        && sleep.as_mut().poll(cx).is_ready()
    {
        *sleep_slot = None;
        *wake_scheduled = false;
        // Saturates at u8::MAX; the delay itself is capped separately.
        *retry_attempt = retry_attempt.saturating_add(1);
        // Wake immediately so the poll function retries the lock now.
        cx.waker().wake_by_ref();
    }
}
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
/// Test-only global mutex used to serialize tests that touch the shared
/// quota-user-lock registry in this module.
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    TEST_LOCK.get_or_init(|| Mutex::new(()))
}

/// Acquires the test guard, recovering the inner guard if a previous test
/// panicked while holding it (poisoned mutex).
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
    quota_user_lock_test_guard()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Deterministically maps `user` onto one of the fixed overflow lock
/// stripes; used when the bounded per-user registry is at capacity. Distinct
/// users may share a stripe, trading contention for bounded memory.
fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
    let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
        std::iter::repeat_with(|| Arc::new(Mutex::new(())))
            .take(QUOTA_OVERFLOW_LOCK_STRIPES)
            .collect()
    });
    let idx = crc32fast::hash(user.as_bytes()) as usize % stripes.len();
    Arc::clone(&stripes[idx])
}
/// Drops registry entries whose lock is held only by the map itself
/// (Arc strong count == 1), i.e. no live session references them.
/// No-op if the registry has never been initialized.
pub(crate) fn quota_user_lock_evict() {
    let Some(locks) = QUOTA_USER_LOCKS.get() else {
        return;
    };
    locks.retain(|_, lock| Arc::strong_count(lock) > 1);
}
/// Spawns a background task that runs `quota_user_lock_evict` every
/// `interval` (floored at 1 ms to prevent a busy loop). Returns the task
/// handle so callers can abort the loop.
pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> {
    let period = interval.max(Duration::from_millis(1));
    // Test-only accounting of spawned evictors; bumped before the spawn so
    // the count is visible as soon as this function returns.
    #[cfg(test)]
    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed);
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(period).await;
            quota_user_lock_evict();
        }
    })
}
/// Test-only wrapper over `spawn_quota_user_lock_evictor`.
#[cfg(test)]
pub(crate) fn spawn_quota_user_lock_evictor_for_tests(
    interval: Duration,
) -> tokio::task::JoinHandle<()> {
    spawn_quota_user_lock_evictor(interval)
}

/// Test-only: resets the evictor spawn counter to zero.
#[cfg(test)]
pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() {
    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed);
}

/// Test-only: number of evictor tasks spawned since the last reset.
#[cfg(test)]
pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 {
    QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed)
}
/// Returns the per-user quota mutex, creating it on first use.
///
/// The registry is bounded by `QUOTA_USER_LOCKS_MAX`; unlike the inline
/// eviction variant, stale entries here are reclaimed by the background
/// evictor task (`spawn_quota_user_lock_evictor`). When the map is full the
/// user falls back to a shared striped overflow lock.
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    if let Some(existing) = locks.get(user) {
        return Arc::clone(existing.value());
    }
    if locks.len() >= QUOTA_USER_LOCKS_MAX {
        return quota_overflow_user_lock(user);
    }
    let created = Arc::new(Mutex::new(()));
    // Re-check via the entry API: another task may have inserted this user's
    // lock between the lookup above and this insert.
    match locks.entry(user.to_string()) {
        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
        dashmap::mapref::entry::Entry::Vacant(entry) => {
            entry.insert(Arc::clone(&created));
            created
        }
    }
}
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
/// Decides whether the quota must be re-checked immediately instead of
/// waiting for the adaptive byte interval: either the user is close to the
/// limit, or a single charge is large enough to risk overshooting.
fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool {
    let near_limit = remaining_before <= QUOTA_NEAR_LIMIT_BYTES;
    let large_charge = charge_bytes >= QUOTA_LARGE_CHARGE_BYTES;
    near_limit || large_charge
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
@ -453,93 +292,60 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
if this.quota_exceeded.load(Ordering::Acquire) {
return Poll::Ready(Err(quota_io_error()));
}
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
poll_quota_retry_sleep(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
cx,
);
return Poll::Pending;
}
let mut remaining_before = None;
if let Some(limit) = this.quota_limit {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
} else {
None
};
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
poll_quota_retry_sleep(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
cx,
);
return Poll::Pending;
}
}
} else {
None
};
reset_quota_retry_scheduler(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
remaining_before = Some(remaining);
}
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
if n > 0 {
let mut reached_quota_boundary = false;
if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let remaining = limit - used;
if (n as u64) > remaining {
// Fail closed: when a single read chunk would cross quota,
// stop relay immediately without accounting beyond the cap.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
reached_quota_boundary = (n as u64) == remaining;
}
let n_to_charge = n as u64;
// C→S: client sent data
this.counters
.c2s_bytes
.fetch_add(n as u64, Ordering::Relaxed);
.fetch_add(n_to_charge, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
this.stats
.add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge);
this.stats
.increment_user_msgs_from_handle(this.user_stats.as_ref());
if reached_quota_boundary {
this.quota_exceeded.store(true, Ordering::Relaxed);
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
this.stats
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
if should_immediate_quota_check(remaining, n_to_charge) {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
} else {
this.quota_bytes_since_check =
this.quota_bytes_since_check.saturating_add(n_to_charge);
let interval = quota_adaptive_interval_bytes(remaining);
if this.quota_bytes_since_check >= interval {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
}
}
}
trace!(user = %this.user, bytes = n, "C->S");
@ -558,87 +364,57 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
if this.quota_exceeded.load(Ordering::Acquire) {
return Poll::Ready(Err(quota_io_error()));
}
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
poll_quota_retry_sleep(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
cx,
);
return Poll::Pending;
}
}
} else {
None
};
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
poll_quota_retry_sleep(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
cx,
);
return Poll::Pending;
}
}
} else {
None
};
reset_quota_retry_scheduler(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
);
let write_buf = if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
let mut remaining_before = None;
if let Some(limit) = this.quota_limit {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
remaining_before = Some(remaining);
}
let remaining = (limit - used) as usize;
if buf.len() > remaining {
// Fail closed: do not emit partial S->C payload when remaining
// quota cannot accommodate the pending write request.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
buf
} else {
buf
};
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
match Pin::new(&mut this.inner).poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
let n_to_charge = n as u64;
// S→C: data written to client
this.counters
.s2c_bytes
.fetch_add(n as u64, Ordering::Relaxed);
.fetch_add(n_to_charge, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
this.stats
.add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge);
this.stats
.increment_user_msgs_to_handle(this.user_stats.as_ref());
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
this.stats
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
if should_immediate_quota_check(remaining, n_to_charge) {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
} else {
this.quota_bytes_since_check =
this.quota_bytes_since_check.saturating_add(n_to_charge);
let interval = quota_adaptive_interval_bytes(remaining);
if this.quota_bytes_since_check >= interval {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
}
}
}
trace!(user = %this.user, bytes = n, "S->C");
@ -732,7 +508,7 @@ where
let now = Instant::now();
let idle = wd_counters.idle_duration(now, epoch);
if wd_quota_exceeded.load(Ordering::Relaxed) {
if wd_quota_exceeded.load(Ordering::Acquire) {
warn!(user = %wd_user, "User data quota reached, closing relay");
return;
}
@ -870,18 +646,10 @@ where
}
}
#[cfg(test)]
#[path = "tests/relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/relay_adversarial_tests.rs"]
mod adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"]
mod relay_quota_lock_pressure_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_boundary_blackhat_tests.rs"]
mod relay_quota_boundary_blackhat_tests;
@ -901,71 +669,3 @@ mod relay_quota_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/relay_watchdog_delta_security_tests.rs"]
mod relay_watchdog_delta_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"]
mod relay_quota_waker_storm_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
mod relay_quota_wake_liveness_regression_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_identity_security_tests.rs"]
mod relay_quota_lock_identity_security_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"]
mod relay_cross_mode_quota_lock_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"]
mod relay_quota_retry_scheduler_tdd_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"]
mod relay_cross_mode_quota_fairness_tdd_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"]
mod relay_cross_mode_pipeline_hol_integration_security_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"]
mod relay_cross_mode_pipeline_latency_benchmark_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_backoff_security_tests.rs"]
mod relay_quota_retry_backoff_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"]
mod relay_quota_retry_backoff_benchmark_security_tests;
#[cfg(test)]
#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"]
mod relay_dual_lock_backoff_regression_security_tests;
#[cfg(test)]
#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"]
mod relay_dual_lock_contention_matrix_security_tests;
#[cfg(test)]
#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"]
mod relay_dual_lock_race_harness_security_tests;
#[cfg(test)]
#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"]
mod relay_dual_lock_alternating_contention_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"]
mod relay_quota_retry_allocation_latency_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"]
mod relay_quota_lock_eviction_lifecycle_tdd_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"]
mod relay_quota_lock_eviction_stress_security_tests;

View File

@ -238,10 +238,12 @@ pub struct Stats {
me_inline_recovery_total: AtomicU64,
ip_reservation_rollback_tcp_limit_total: AtomicU64,
ip_reservation_rollback_quota_limit_total: AtomicU64,
quota_write_fail_bytes_total: AtomicU64,
quota_write_fail_events_total: AtomicU64,
telemetry_core_enabled: AtomicBool,
telemetry_user_enabled: AtomicBool,
telemetry_me_level: AtomicU8,
user_stats: DashMap<String, UserStats>,
user_stats: DashMap<String, Arc<UserStats>>,
user_stats_last_cleanup_epoch_secs: AtomicU64,
start_time: parking_lot::RwLock<Option<Instant>>,
}
@ -254,9 +256,51 @@ pub struct UserStats {
pub octets_to_client: AtomicU64,
pub msgs_from_client: AtomicU64,
pub msgs_to_client: AtomicU64,
/// Total bytes charged against per-user quota admission.
///
/// This counter is the single source of truth for quota enforcement and
/// intentionally tracks attempted traffic, not guaranteed delivery.
pub quota_used: AtomicU64,
pub last_seen_epoch_secs: AtomicU64,
}
/// Outcome of a failed single-step quota reservation (`quota_try_reserve`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuotaReserveError {
    /// The requested bytes do not fit in the remaining quota headroom.
    LimitExceeded,
    /// The CAS step lost a race (or failed spuriously); the caller may retry.
    Contended,
}
impl UserStats {
    /// Current number of bytes charged against this user's quota (relaxed load).
    #[inline]
    pub fn quota_used(&self) -> u64 {
        self.quota_used.load(Ordering::Relaxed)
    }

    /// Attempts one CAS reservation step against the quota counter.
    ///
    /// Performs exactly one `compare_exchange_weak`; callers own the
    /// retry/yield policy. The primitive never blocks or sleeps, so both sync
    /// poll paths and async paths can wrap it with their own contention
    /// strategy. Returns the new used total on success.
    #[inline]
    pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result<u64, QuotaReserveError> {
        let observed = self.quota_used.load(Ordering::Relaxed);
        // Headroom check first: reject before touching the counter.
        if limit.saturating_sub(observed) < bytes {
            return Err(QuotaReserveError::LimitExceeded);
        }
        let target = observed.saturating_add(bytes);
        self.quota_used
            .compare_exchange_weak(observed, target, Ordering::Relaxed, Ordering::Relaxed)
            .map(|_| target)
            .map_err(|_| QuotaReserveError::Contended)
    }
}
impl Stats {
pub fn new() -> Self {
let stats = Self::default();
@ -316,6 +360,70 @@ impl Stats {
.store(Self::now_epoch_secs(), Ordering::Relaxed);
}
/// Returns a shared handle to the per-user stats entry, creating it if absent.
///
/// Runs the opportunistic idle-entry cleanup first, then resolves the handle.
pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc<UserStats> {
    self.maybe_cleanup_user_stats();
    // Fast path: clone the existing handle and refresh its last-seen stamp.
    if let Some(existing) = self.user_stats.get(user) {
        let handle = Arc::clone(existing.value());
        Self::touch_user_stats(handle.as_ref());
        return handle;
    }
    // Slow path: insert-if-absent. A concurrent creator may win between the
    // get() above and this entry(); or_default() then yields the winner's value.
    let entry = self.user_stats.entry(user.to_string()).or_default();
    // Only stamp entries that were never touched (last_seen == 0), i.e. the
    // ones this call just created; racing winners stamp their own.
    if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 {
        Self::touch_user_stats(entry.value().as_ref());
    }
    Arc::clone(entry.value())
}
/// Adds client-sent octets to telemetry via a pre-resolved stats handle.
/// No-op when user telemetry is disabled.
#[inline]
pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) {
    if self.telemetry_user_enabled() {
        Self::touch_user_stats(user_stats);
        user_stats
            .octets_from_client
            .fetch_add(bytes, Ordering::Relaxed);
    }
}
/// Adds octets written to the client to telemetry via a pre-resolved stats
/// handle. No-op when user telemetry is disabled.
#[inline]
pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) {
    if self.telemetry_user_enabled() {
        Self::touch_user_stats(user_stats);
        user_stats
            .octets_to_client
            .fetch_add(bytes, Ordering::Relaxed);
    }
}
/// Bumps the client-sent message counter via a pre-resolved stats handle.
/// No-op when user telemetry is disabled.
#[inline]
pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) {
    if self.telemetry_user_enabled() {
        Self::touch_user_stats(user_stats);
        user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
    }
}
/// Bumps the message-to-client counter via a pre-resolved stats handle.
/// No-op when user telemetry is disabled.
#[inline]
pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) {
    if self.telemetry_user_enabled() {
        Self::touch_user_stats(user_stats);
        user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
    }
}
/// Charges already committed bytes in a post-I/O path and returns the new
/// quota-used total.
///
/// This helper is intentionally separate from `quota_try_reserve` to avoid
/// mixing reserve and post-charge on a single I/O event.
///
/// The stored counter saturates at `u64::MAX` instead of wrapping: a
/// wrapping `fetch_add` would silently reset `quota_used` toward zero on
/// overflow and fail open, while quota enforcement must fail closed. This
/// also keeps the stored value consistent with the saturating total this
/// function returns and with `quota_try_reserve`'s saturating arithmetic.
#[inline]
pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 {
    Self::touch_user_stats(user_stats);
    let mut current = user_stats.quota_used.load(Ordering::Relaxed);
    loop {
        let next = current.saturating_add(bytes);
        match user_stats.quota_used.compare_exchange_weak(
            current,
            next,
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(_) => return next,
            // Lost the race (or spurious failure): retry from the fresh value.
            Err(actual) => current = actual,
        }
    }
}
fn maybe_cleanup_user_stats(&self) {
const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60;
const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60;
@ -1114,6 +1222,18 @@ impl Stats {
.fetch_add(1, Ordering::Relaxed);
}
}
/// Adds to the quota write-failure byte counter. Counted only when core
/// telemetry is enabled.
pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) {
    if !self.telemetry_core_enabled() {
        return;
    }
    self.quota_write_fail_bytes_total
        .fetch_add(bytes, Ordering::Relaxed);
}
/// Bumps the quota write-failure event counter. Counted only when core
/// telemetry is enabled.
pub fn increment_quota_write_fail_events_total(&self) {
    if !self.telemetry_core_enabled() {
        return;
    }
    self.quota_write_fail_events_total
        .fetch_add(1, Ordering::Relaxed);
}
pub fn increment_me_endpoint_quarantine_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_endpoint_quarantine_total
@ -1764,19 +1884,19 @@ impl Stats {
self.ip_reservation_rollback_quota_limit_total
.load(Ordering::Relaxed)
}
/// Snapshot of the quota write-failure byte counter (relaxed load).
pub fn get_quota_write_fail_bytes_total(&self) -> u64 {
    self.quota_write_fail_bytes_total.load(Ordering::Relaxed)
}
/// Snapshot of the quota write-failure event counter (relaxed load).
pub fn get_quota_write_fail_events_total(&self) -> u64 {
    self.quota_write_fail_events_total.load(Ordering::Relaxed)
}
/// Bumps the lifetime connect counter for `user`.
///
/// No-op when user telemetry is disabled. This span previously contained the
/// old get/entry fast-path interleaved with the new handle-based body (diff
/// residue, syntactically invalid); only the handle-based version is kept.
pub fn increment_user_connects(&self, user: &str) {
    if !self.telemetry_user_enabled() {
        return;
    }
    // Handle resolution also runs the periodic idle-entry cleanup and
    // refreshes the user's last-seen stamp.
    let stats = self.get_or_create_user_stats_handle(user);
    Self::touch_user_stats(stats.as_ref());
    stats.connects.fetch_add(1, Ordering::Relaxed);
}
@ -1784,14 +1904,8 @@ impl Stats {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref());
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
}
@ -1800,9 +1914,8 @@ impl Stats {
return true;
}
self.maybe_cleanup_user_stats();
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref());
let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed);
@ -1827,7 +1940,7 @@ impl Stats {
pub fn decrement_user_curr_connects(&self, user: &str) {
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
Self::touch_user_stats(stats.value().as_ref());
let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed);
loop {
@ -1858,86 +1971,32 @@ impl Stats {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.add_user_octets_from_handle(stats.as_ref(), bytes);
}
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
}
pub fn sub_user_octets_to(&self, user: &str, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
let Some(stats) = self.user_stats.get(user) else {
return;
};
Self::touch_user_stats(stats.value());
let counter = &stats.octets_to_client;
let mut current = counter.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(bytes);
match counter.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(actual) => current = actual,
}
}
let stats = self.get_or_create_user_stats_handle(user);
self.add_user_octets_to_handle(stats.as_ref(), bytes);
}
pub fn increment_user_msgs_from(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.increment_user_msgs_from_handle(stats.as_ref());
}
pub fn increment_user_msgs_to(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.increment_user_msgs_to_handle(stats.as_ref());
}
pub fn get_user_total_octets(&self, user: &str) -> u64 {
@ -1950,6 +2009,13 @@ impl Stats {
.unwrap_or(0)
}
/// Bytes currently charged against `user`'s quota; 0 for an unknown user.
pub fn get_user_quota_used(&self, user: &str) -> u64 {
    match self.user_stats.get(user) {
        Some(entry) => entry.quota_used.load(Ordering::Relaxed),
        None => 0,
    }
}
pub fn get_handshake_timeouts(&self) -> u64 {
self.handshake_timeouts.load(Ordering::Relaxed)
}
@ -2015,7 +2081,7 @@ impl Stats {
.load(Ordering::Relaxed)
}
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> {
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc<UserStats>> {
self.user_stats.iter()
}
@ -2163,6 +2229,22 @@ impl ReplayChecker {
found
}
/// Probes the given shard set for `data` within `window`, updating the
/// check/hit counters. Unlike the combined paths, no companion add is
/// performed here; insertion is handled by the separate add-only path.
fn check_only_internal(
    &self,
    data: &[u8],
    shards: &[Mutex<ReplayShard>],
    window: Duration,
) -> bool {
    // Every probe counts as a check, hit or miss.
    self.checks.fetch_add(1, Ordering::Relaxed);
    let shard_idx = self.get_shard_idx(data);
    let hit = {
        let mut guard = shards[shard_idx].lock();
        guard.check(data, Instant::now(), window)
    };
    if hit {
        self.hits.fetch_add(1, Ordering::Relaxed);
    }
    hit
}
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
@ -2186,7 +2268,7 @@ impl ReplayChecker {
self.add_only(data, &self.handshake_shards, self.window)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_tls_digest(data)
self.check_only_internal(data, &self.tls_shards, self.tls_window)
}
pub fn add_tls_digest(&self, data: &[u8]) {
self.add_only(data, &self.tls_shards, self.tls_window)
@ -2289,6 +2371,7 @@ impl ReplayStats {
mod tests {
use super::*;
use crate::config::MeTelemetryLevel;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
#[test]
@ -2457,6 +2540,60 @@ mod tests {
}
assert_eq!(checker.stats().total_entries, 500);
}
#[test]
fn test_quota_reserve_under_contention_hits_limit_exactly() {
    // 8 threads hammer single-byte CAS reservations against one counter;
    // the total number of successful reservations must equal the limit
    // exactly — no over-admission and no lost updates.
    let user_stats = Arc::new(UserStats::default());
    let successes = Arc::new(AtomicU64::new(0));
    let limit = 8_192u64;
    let mut workers = Vec::new();
    for _ in 0..8 {
        let user_stats = user_stats.clone();
        let successes = successes.clone();
        workers.push(std::thread::spawn(move || {
            loop {
                match user_stats.quota_try_reserve(1, limit) {
                    Ok(_) => {
                        successes.fetch_add(1, Ordering::Relaxed);
                    }
                    // Lost a CAS race: back off one spin and try again.
                    Err(QuotaReserveError::Contended) => {
                        std::hint::spin_loop();
                    }
                    // Counter exhausted: this worker is done.
                    Err(QuotaReserveError::LimitExceeded) => {
                        break;
                    }
                }
            }
        }));
    }
    for worker in workers {
        worker.join().expect("worker thread must finish");
    }
    assert_eq!(
        successes.load(Ordering::Relaxed),
        limit,
        "successful reservations must stop exactly at limit"
    );
    // The counter itself must land exactly on the limit as well.
    assert_eq!(user_stats.quota_used(), limit);
}
#[test]
fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() {
    // quota_used must move only via quota charging, never via octet
    // telemetry — and vice versa. Each counter is checked after touching
    // the other to prove independence.
    let stats = Stats::new();
    let user = "quota-authoritative-user";
    let user_stats = stats.get_or_create_user_stats_handle(user);

    // Octet telemetry alone: totals move, quota stays at zero.
    stats.add_user_octets_to_handle(&user_stats, 5);
    assert_eq!(stats.get_user_total_octets(user), 5);
    assert_eq!(stats.get_user_quota_used(user), 0);

    // Quota charge alone: quota moves, octet totals untouched.
    stats.quota_charge_post_write(&user_stats, 7);
    assert_eq!(stats.get_user_total_octets(user), 5);
    assert_eq!(stats.get_user_quota_used(user), 7);
}
}
#[cfg(test)]
@ -2466,7 +2603,3 @@ mod connection_lease_security_tests;
#[cfg(test)]
#[path = "tests/replay_checker_security_tests.rs"]
mod replay_checker_security_tests;
#[cfg(test)]
#[path = "tests/user_octets_sub_security_tests.rs"]
mod user_octets_sub_security_tests;

View File

@ -1,151 +0,0 @@
use super::*;
use std::sync::Arc;
use std::thread;
#[test]
fn sub_user_octets_to_underflow_saturates_at_zero() {
let stats = Stats::new();
let user = "sub-underflow-user";
stats.add_user_octets_to(user, 3);
stats.sub_user_octets_to(user, 100);
assert_eq!(stats.get_user_total_octets(user), 0);
}
#[test]
fn sub_user_octets_to_does_not_affect_octets_from_client() {
let stats = Stats::new();
let user = "sub-isolation-user";
stats.add_user_octets_from(user, 17);
stats.add_user_octets_to(user, 5);
stats.sub_user_octets_to(user, 3);
assert_eq!(stats.get_user_total_octets(user), 19);
}
#[test]
fn light_fuzz_add_sub_model_matches_saturating_reference() {
let stats = Stats::new();
let user = "sub-fuzz-user";
let mut seed = 0x91D2_4CB8_EE77_1101u64;
let mut model_to = 0u64;
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x3f) + 1;
if (seed & 1) == 0 {
stats.add_user_octets_to(user, amt);
model_to = model_to.saturating_add(amt);
} else {
stats.sub_user_octets_to(user, amt);
model_to = model_to.saturating_sub(amt);
}
}
assert_eq!(stats.get_user_total_octets(user), model_to);
}
#[test]
fn stress_parallel_add_sub_never_underflows_or_panics() {
let stats = Arc::new(Stats::new());
let user = "sub-stress-user";
// Pre-fund with a large offset so subtractions never saturate at zero.
// This guarantees commutative updates, making the final state deterministic.
let base_offset = 10_000_000u64;
stats.add_user_octets_to(user, base_offset);
let mut workers = Vec::new();
for tid in 0..16u64 {
let stats_for_thread = Arc::clone(&stats);
workers.push(thread::spawn(move || {
let mut seed = 0xD00D_1000_0000_0000u64 ^ tid;
let mut net_delta = 0i64;
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x1f) + 1;
if (seed & 1) == 0 {
stats_for_thread.add_user_octets_to(user, amt);
net_delta += amt as i64;
} else {
stats_for_thread.sub_user_octets_to(user, amt);
net_delta -= amt as i64;
}
}
net_delta
}));
}
let mut expected_net_delta = 0i64;
for worker in workers {
expected_net_delta += worker
.join()
.expect("sub-user stress worker must not panic");
}
let expected_total = (base_offset as i64 + expected_net_delta) as u64;
let total = stats.get_user_total_octets(user);
assert_eq!(
total, expected_total,
"concurrent add/sub lost updates or suffered ABA races"
);
}
#[test]
fn sub_user_octets_to_missing_user_is_noop() {
let stats = Stats::new();
stats.sub_user_octets_to("missing-user", 1024);
assert_eq!(stats.get_user_total_octets("missing-user"), 0);
}
#[test]
fn stress_parallel_per_user_models_remain_exact() {
let stats = Arc::new(Stats::new());
let mut workers = Vec::new();
for tid in 0..16u64 {
let stats_for_thread = Arc::clone(&stats);
workers.push(thread::spawn(move || {
let user = format!("sub-per-user-{tid}");
let mut seed = 0xFACE_0000_0000_0000u64 ^ tid;
let mut model = 0u64;
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x3f) + 1;
if (seed & 1) == 0 {
stats_for_thread.add_user_octets_to(&user, amt);
model = model.saturating_add(amt);
} else {
stats_for_thread.sub_user_octets_to(&user, amt);
model = model.saturating_sub(amt);
}
}
(user, model)
}));
}
for worker in workers {
let (user, model) = worker
.join()
.expect("per-user subtract stress worker must not panic");
assert_eq!(
stats.get_user_total_octets(&user),
model,
"per-user parallel model diverged"
);
}
}