mirror of https://github.com/telemt/telemt.git
Redesign Quotas on Atomics
This commit is contained in:
parent
0c3c9009a9
commit
6f4356f72a
|
|
@ -32,14 +32,6 @@ pub(crate) struct RuntimeWatches {
|
|||
pub(crate) detected_ip_v6: Option<IpAddr>,
|
||||
}
|
||||
|
||||
const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60;
|
||||
|
||||
fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> {
|
||||
crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs(
|
||||
QUOTA_USER_LOCK_EVICT_INTERVAL_SECS,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn spawn_runtime_tasks(
|
||||
config: &Arc<ProxyConfig>,
|
||||
|
|
@ -77,8 +69,6 @@ pub(crate) async fn spawn_runtime_tasks(
|
|||
rc_clone.run_periodic_cleanup().await;
|
||||
});
|
||||
|
||||
spawn_quota_lock_maintenance_task();
|
||||
|
||||
let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4);
|
||||
let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6);
|
||||
debug!(
|
||||
|
|
@ -370,24 +360,3 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc<StartupTracker>) {
|
|||
.await;
|
||||
startup_tracker.mark_ready().await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() {
|
||||
crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests();
|
||||
|
||||
let handle = spawn_quota_lock_maintenance_task();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(),
|
||||
1,
|
||||
"runtime maintenance path must spawn exactly one quota lock evictor task per call"
|
||||
);
|
||||
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1223,7 +1223,7 @@ impl RunningClientHandler {
|
|||
}
|
||||
|
||||
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||
&& stats.get_user_total_octets(user) >= *quota
|
||||
&& stats.get_user_quota_used(user) >= *quota
|
||||
{
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
|
|
@ -1282,7 +1282,7 @@ impl RunningClientHandler {
|
|||
}
|
||||
|
||||
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||
&& stats.get_user_total_octets(user) >= *quota
|
||||
&& stats.get_user_quota_used(user) >= *quota
|
||||
{
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
|
|
|
|||
|
|
@ -614,6 +614,15 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
// Reject known replay digests before expensive cache/domain/ALPN policy work.
|
||||
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||
if replay_checker.check_tls_digest(digest_half) {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||
Some((_, s)) => s,
|
||||
None => {
|
||||
|
|
@ -669,15 +678,8 @@ where
|
|||
None
|
||||
};
|
||||
|
||||
// Replay tracking is applied only after full policy validation (including
|
||||
// ALPN checks) so rejected handshakes cannot poison replay state.
|
||||
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||
if replay_checker.check_and_add_tls_digest(digest_half) {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
// Add replay digest only for policy-valid handshakes.
|
||||
replay_checker.add_tls_digest(digest_half);
|
||||
|
||||
let response = if let Some((cached_entry, use_full_cert_payload)) = cached {
|
||||
emulator::build_emulated_server_hello(
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ where
|
|||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = [0u8; MASK_BUFFER_SIZE];
|
||||
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
|
||||
let mut total = 0usize;
|
||||
let mut ended_by_eof = false;
|
||||
|
||||
|
|
@ -262,7 +262,11 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
|
|||
let floor = config.censorship.mask_timing_normalization_floor_ms;
|
||||
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
|
||||
if floor == 0 {
|
||||
return MASK_TIMEOUT;
|
||||
if ceiling == 0 {
|
||||
return Duration::from_millis(0);
|
||||
}
|
||||
let mut rng = rand::rng();
|
||||
return Duration::from_millis(rng.random_range(0..=ceiling));
|
||||
}
|
||||
if ceiling > floor {
|
||||
let mut rng = rand::rng();
|
||||
|
|
@ -838,7 +842,7 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usiz
|
|||
}
|
||||
|
||||
// Keep drain path fail-closed under slow-loris stalls.
|
||||
let mut buf = [0u8; MASK_BUFFER_SIZE];
|
||||
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
|
||||
let mut total = 0usize;
|
||||
|
||||
loop {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use std::time::{Duration, Instant};
|
|||
|
||||
use dashmap::DashMap;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch};
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ use crate::proxy::route_mode::{
|
|||
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
|
||||
cutover_stagger_delay,
|
||||
};
|
||||
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats};
|
||||
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats};
|
||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
||||
|
||||
|
|
@ -53,20 +53,11 @@ const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
|||
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
||||
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
|
||||
const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
|
||||
#[cfg(test)]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
#[cfg(test)]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
|
||||
const QUOTA_RESERVE_SPIN_RETRIES: usize = 32;
|
||||
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
||||
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
|
||||
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
|
||||
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
|
||||
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
|
||||
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<AsyncMutex<()>>>> = OnceLock::new();
|
||||
static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
|
||||
static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
|
|
@ -538,36 +529,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
|
|||
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
|
||||
}
|
||||
|
||||
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
|
||||
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
fn quota_would_be_exceeded_for_user(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
bytes: u64,
|
||||
) -> bool {
|
||||
quota_limit.is_some_and(|quota| {
|
||||
let used = stats.get_user_total_octets(user);
|
||||
used >= quota || bytes > quota.saturating_sub(used)
|
||||
})
|
||||
}
|
||||
|
||||
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
|
||||
limit.saturating_add(overshoot)
|
||||
}
|
||||
|
||||
fn quota_would_be_exceeded_for_user_soft(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
async fn reserve_user_quota_with_yield(
|
||||
user_stats: &UserStats,
|
||||
bytes: u64,
|
||||
overshoot: u64,
|
||||
) -> bool {
|
||||
let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot));
|
||||
quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes)
|
||||
limit: u64,
|
||||
) -> std::result::Result<u64, QuotaReserveError> {
|
||||
loop {
|
||||
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
|
||||
match user_stats.quota_try_reserve(bytes, limit) {
|
||||
Ok(total) => return Ok(total),
|
||||
Err(QuotaReserveError::LimitExceeded) => {
|
||||
return Err(QuotaReserveError::LimitExceeded);
|
||||
}
|
||||
Err(QuotaReserveError::Contended) => std::hint::spin_loop(),
|
||||
}
|
||||
}
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
fn classify_me_d2c_flush_reason(
|
||||
|
|
@ -613,29 +596,6 @@ fn observe_me_d2c_flush_event(
|
|||
}
|
||||
}
|
||||
|
||||
fn rollback_me2c_quota_reservation(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
bytes_me2c: &AtomicU64,
|
||||
reserved_bytes: u64,
|
||||
) {
|
||||
stats.sub_user_octets_to(user, reserved_bytes);
|
||||
bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
|
||||
quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
|
@ -649,46 +609,6 @@ pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static,
|
|||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
.map(|_| Arc::new(AsyncMutex::new(())))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let hash = crc32fast::hash(user.as_bytes()) as usize;
|
||||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
return quota_overflow_user_lock(user);
|
||||
}
|
||||
|
||||
let created = Arc::new(AsyncMutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
|
||||
}
|
||||
|
||||
async fn enqueue_c2me_command(
|
||||
tx: &mpsc::Sender<C2MeCommand>,
|
||||
cmd: C2MeCommand,
|
||||
|
|
@ -744,8 +664,7 @@ where
|
|||
{
|
||||
let user = success.user.clone();
|
||||
let quota_limit = config.access.user_data_quota.get(&user).copied();
|
||||
let cross_mode_quota_lock =
|
||||
quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
|
||||
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
|
||||
let peer = success.peer;
|
||||
let proto_tag = success.proto_tag;
|
||||
let pool_generation = me_pool.current_generation();
|
||||
|
|
@ -872,7 +791,7 @@ where
|
|||
let stats_clone = stats.clone();
|
||||
let rng_clone = rng.clone();
|
||||
let user_clone = user.clone();
|
||||
let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone();
|
||||
let quota_user_stats_me_writer = quota_user_stats.clone();
|
||||
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
|
||||
let bytes_me2c_clone = bytes_me2c.clone();
|
||||
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
|
||||
|
|
@ -894,7 +813,7 @@ where
|
|||
|
||||
let first_is_downstream_activity =
|
||||
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
match process_me_writer_response(
|
||||
first,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -902,9 +821,9 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_user_stats_me_writer.as_deref(),
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -953,7 +872,7 @@ where
|
|||
|
||||
let next_is_downstream_activity =
|
||||
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
match process_me_writer_response(
|
||||
next,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -961,9 +880,9 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_user_stats_me_writer.as_deref(),
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -1015,7 +934,7 @@ where
|
|||
Ok(Some(next)) => {
|
||||
let next_is_downstream_activity =
|
||||
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
match process_me_writer_response(
|
||||
next,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -1023,9 +942,9 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_user_stats_me_writer.as_deref(),
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -1079,7 +998,7 @@ where
|
|||
|
||||
let extra_is_downstream_activity =
|
||||
matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
match process_me_writer_response(
|
||||
extra,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -1087,9 +1006,9 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_user_stats_me_writer.as_deref(),
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -1259,24 +1178,23 @@ where
|
|||
forensics.bytes_c2me = forensics
|
||||
.bytes_c2me
|
||||
.saturating_add(payload.len() as u64);
|
||||
if let Some(limit) = quota_limit {
|
||||
let quota_lock = quota_user_lock(&user);
|
||||
let _quota_guard = quota_lock.lock().await;
|
||||
let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else {
|
||||
main_result = Err(ProxyError::Proxy(
|
||||
"cross-mode quota lock missing for quota-limited session"
|
||||
.to_string(),
|
||||
));
|
||||
break;
|
||||
};
|
||||
let _cross_mode_quota_guard = cross_mode_lock.lock().await;
|
||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
|
||||
if let (Some(limit), Some(user_stats)) =
|
||||
(quota_limit, quota_user_stats.as_deref())
|
||||
{
|
||||
if reserve_user_quota_with_yield(
|
||||
user_stats,
|
||||
payload.len() as u64,
|
||||
limit,
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
main_result = Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.clone(),
|
||||
});
|
||||
break;
|
||||
}
|
||||
stats.add_user_octets_from_handle(user_stats, payload.len() as u64);
|
||||
} else {
|
||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||
}
|
||||
|
|
@ -1755,7 +1673,6 @@ enum MeWriterResponseOutcome {
|
|||
Close,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn process_me_writer_response<W>(
|
||||
response: MeResponse,
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
|
|
@ -1764,6 +1681,7 @@ async fn process_me_writer_response<W>(
|
|||
frame_buf: &mut Vec<u8>,
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_user_stats: Option<&UserStats>,
|
||||
quota_limit: Option<u64>,
|
||||
quota_soft_overshoot_bytes: u64,
|
||||
bytes_me2c: &AtomicU64,
|
||||
|
|
@ -1771,44 +1689,6 @@ async fn process_me_writer_response<W>(
|
|||
ack_flush_immediate: bool,
|
||||
batched: bool,
|
||||
) -> Result<MeWriterResponseOutcome>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
response,
|
||||
client_writer,
|
||||
proto_tag,
|
||||
rng,
|
||||
frame_buf,
|
||||
stats,
|
||||
user,
|
||||
quota_limit,
|
||||
quota_soft_overshoot_bytes,
|
||||
None,
|
||||
bytes_me2c,
|
||||
conn_id,
|
||||
ack_flush_immediate,
|
||||
batched,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn process_me_writer_response_with_cross_mode_lock<W>(
|
||||
response: MeResponse,
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
proto_tag: ProtoTag,
|
||||
rng: &SecureRandom,
|
||||
frame_buf: &mut Vec<u8>,
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
quota_soft_overshoot_bytes: u64,
|
||||
cross_mode_quota_lock: Option<&Arc<AsyncMutex<()>>>,
|
||||
bytes_me2c: &AtomicU64,
|
||||
conn_id: u64,
|
||||
ack_flush_immediate: bool,
|
||||
batched: bool,
|
||||
) -> Result<MeWriterResponseOutcome>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
|
|
@ -1820,77 +1700,42 @@ where
|
|||
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
|
||||
}
|
||||
let data_len = data.len() as u64;
|
||||
if let Some(limit) = quota_limit {
|
||||
let owned_cross_mode_lock;
|
||||
let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock {
|
||||
lock
|
||||
} else {
|
||||
owned_cross_mode_lock =
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user);
|
||||
&owned_cross_mode_lock
|
||||
};
|
||||
let cross_mode_quota_guard = cross_mode_lock.lock().await;
|
||||
if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) {
|
||||
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
|
||||
if quota_would_be_exceeded_for_user_soft(
|
||||
stats,
|
||||
user,
|
||||
Some(limit),
|
||||
data_len,
|
||||
quota_soft_overshoot_bytes,
|
||||
) {
|
||||
if reserve_user_quota_with_yield(user_stats, data_len, soft_limit)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Reserve quota before awaiting network I/O to avoid same-user HoL stalls.
|
||||
// If reservation loses a race or write fails, we roll back immediately.
|
||||
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data_len);
|
||||
|
||||
if stats.get_user_total_octets(user) > soft_limit {
|
||||
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
|
||||
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Keep cross-mode lock scope explicit and minimal: quota reservation is serialized,
|
||||
// but socket I/O proceeds without holding same-user cross-mode admission lock.
|
||||
drop(cross_mode_quota_guard);
|
||||
|
||||
let write_mode =
|
||||
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await
|
||||
{
|
||||
Ok(mode) => mode,
|
||||
Err(err) => {
|
||||
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
|
||||
if quota_limit.is_some() {
|
||||
stats.add_quota_write_fail_bytes_total(data_len);
|
||||
stats.increment_quota_write_fail_events_total();
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
stats.increment_me_d2c_data_frames_total();
|
||||
stats.add_me_d2c_payload_bytes_total(data_len);
|
||||
stats.increment_me_d2c_write_mode(write_mode);
|
||||
|
||||
// Do not fail immediately on exact boundary after a successful write.
|
||||
// Returning an error here can bypass batch flush in the caller and risk
|
||||
// dropping buffered ciphertext from CryptoWriter. The next frame is
|
||||
// rejected by the pre-check at function entry.
|
||||
} else {
|
||||
let write_mode =
|
||||
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await?;
|
||||
|
||||
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
|
||||
if let Some(user_stats) = quota_user_stats {
|
||||
stats.add_user_octets_to_handle(user_stats, data_len);
|
||||
} else {
|
||||
stats.add_user_octets_to(user, data_len);
|
||||
}
|
||||
stats.increment_me_d2c_data_frames_total();
|
||||
stats.add_me_d2c_payload_bytes_total(data_len);
|
||||
stats.increment_me_d2c_write_mode(write_mode);
|
||||
}
|
||||
|
||||
Ok(MeWriterResponseOutcome::Continue {
|
||||
frames: 1,
|
||||
|
|
@ -2097,10 +1942,6 @@ where
|
|||
.map_err(ProxyError::Io)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_idle_policy_security_tests.rs"]
|
||||
mod idle_policy_security_tests;
|
||||
|
|
@ -2113,30 +1954,10 @@ mod desync_all_full_dedup_security_tests;
|
|||
#[path = "tests/middle_relay_stub_completion_security_tests.rs"]
|
||||
mod stub_completion_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"]
|
||||
mod coverage_high_risk_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"]
|
||||
mod quota_overflow_lock_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"]
|
||||
mod length_cast_hardening_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
|
||||
mod blackhat_campaign_integration_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_hol_quota_security_tests.rs"]
|
||||
mod hol_quota_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"]
|
||||
mod quota_reservation_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
|
||||
mod middle_relay_idle_registry_poison_security_tests;
|
||||
|
|
@ -2156,27 +1977,3 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests;
|
|||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
|
||||
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_quota_reservation_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_quota_lock_matrix_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_lookup_efficiency_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_lock_release_regression_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"]
|
||||
mod middle_relay_quota_extended_attack_surface_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"]
|
||||
mod middle_relay_quota_reservation_extreme_security_tests;
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ pub mod direct_relay;
|
|||
pub mod handshake;
|
||||
pub mod masking;
|
||||
pub mod middle_relay;
|
||||
pub mod quota_lock_registry;
|
||||
pub mod relay;
|
||||
pub mod route_mode;
|
||||
pub mod session_eviction;
|
||||
|
|
|
|||
|
|
@ -1,88 +0,0 @@
|
|||
use dashmap::DashMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[cfg(test)]
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
#[cfg(test)]
|
||||
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
#[cfg(test)]
|
||||
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
|
||||
#[cfg(not(test))]
|
||||
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
|
||||
|
||||
static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||
static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0);
|
||||
#[cfg(test)]
|
||||
static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock<DashMap<String, usize>> = OnceLock::new();
|
||||
|
||||
fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
.map(|_| Arc::new(Mutex::new(())))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let hash = crc32fast::hash(user.as_bytes()) as usize;
|
||||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
#[cfg(test)]
|
||||
{
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed);
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
let mut entry = lookups.entry(user.to_string()).or_insert(0);
|
||||
*entry += 1;
|
||||
}
|
||||
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
|
||||
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
|
||||
return cross_mode_quota_overflow_user_lock(user);
|
||||
}
|
||||
|
||||
let created = Arc::new(Mutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() {
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed);
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
lookups.clear();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize {
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize {
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
lookups.get(user).map(|entry| *entry).unwrap_or(0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"]
|
||||
mod quota_lock_registry_cross_mode_adversarial_tests;
|
||||
|
|
@ -52,18 +52,16 @@
|
|||
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
|
||||
|
||||
use crate::error::{ProxyError, Result};
|
||||
use crate::stats::Stats;
|
||||
use crate::stats::{Stats, UserStats};
|
||||
use crate::stream::BufferPool;
|
||||
use dashmap::DashMap;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::time::{Instant, Sleep};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
// ============= Constants =============
|
||||
|
|
@ -210,16 +208,10 @@ struct StatsIo<S> {
|
|||
counters: Arc<SharedCounters>,
|
||||
stats: Arc<Stats>,
|
||||
user: String,
|
||||
quota_lock: Option<Arc<Mutex<()>>>,
|
||||
cross_mode_quota_lock: Option<Arc<AsyncMutex<()>>>,
|
||||
user_stats: Arc<UserStats>,
|
||||
quota_limit: Option<u64>,
|
||||
quota_exceeded: Arc<AtomicBool>,
|
||||
quota_read_wake_scheduled: bool,
|
||||
quota_write_wake_scheduled: bool,
|
||||
quota_read_retry_sleep: Option<Pin<Box<Sleep>>>,
|
||||
quota_write_retry_sleep: Option<Pin<Box<Sleep>>>,
|
||||
quota_read_retry_attempt: u8,
|
||||
quota_write_retry_attempt: u8,
|
||||
quota_bytes_since_check: u64,
|
||||
epoch: Instant,
|
||||
}
|
||||
|
||||
|
|
@ -235,24 +227,16 @@ impl<S> StatsIo<S> {
|
|||
) -> Self {
|
||||
// Mark initial activity so the watchdog doesn't fire before data flows
|
||||
counters.touch(Instant::now(), epoch);
|
||||
let quota_lock = quota_limit.map(|_| quota_user_lock(&user));
|
||||
let cross_mode_quota_lock = quota_limit
|
||||
.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
|
||||
let user_stats = stats.get_or_create_user_stats_handle(&user);
|
||||
Self {
|
||||
inner,
|
||||
counters,
|
||||
stats,
|
||||
user,
|
||||
quota_lock,
|
||||
cross_mode_quota_lock,
|
||||
user_stats,
|
||||
quota_limit,
|
||||
quota_exceeded,
|
||||
quota_read_wake_scheduled: false,
|
||||
quota_write_wake_scheduled: false,
|
||||
quota_read_retry_sleep: None,
|
||||
quota_write_retry_sleep: None,
|
||||
quota_read_retry_attempt: 0,
|
||||
quota_write_retry_attempt: 0,
|
||||
quota_bytes_since_check: 0,
|
||||
epoch,
|
||||
}
|
||||
}
|
||||
|
|
@ -281,169 +265,24 @@ fn is_quota_io_error(err: &io::Error) -> bool {
|
|||
.is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
|
||||
#[cfg(test)]
|
||||
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16);
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64);
|
||||
const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024;
|
||||
const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024;
|
||||
const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024;
|
||||
const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024;
|
||||
|
||||
#[cfg(test)]
|
||||
static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0);
|
||||
#[cfg(test)]
|
||||
static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() {
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 {
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed)
|
||||
#[inline]
|
||||
fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 {
|
||||
remaining_before
|
||||
.saturating_div(2)
|
||||
.clamp(
|
||||
QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES,
|
||||
QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn quota_contention_retry_delay(retry_attempt: u8) -> Duration {
|
||||
let shift = u32::from(retry_attempt.min(5));
|
||||
let multiplier = 1_u32 << shift;
|
||||
QUOTA_CONTENTION_RETRY_INTERVAL
|
||||
.saturating_mul(multiplier)
|
||||
.min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn reset_quota_retry_scheduler(
|
||||
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
|
||||
wake_scheduled: &mut bool,
|
||||
retry_attempt: &mut u8,
|
||||
) {
|
||||
*wake_scheduled = false;
|
||||
*sleep_slot = None;
|
||||
*retry_attempt = 0;
|
||||
}
|
||||
|
||||
fn poll_quota_retry_sleep(
|
||||
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
|
||||
wake_scheduled: &mut bool,
|
||||
retry_attempt: &mut u8,
|
||||
cx: &mut Context<'_>,
|
||||
) {
|
||||
if !*wake_scheduled {
|
||||
*wake_scheduled = true;
|
||||
#[cfg(test)]
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed);
|
||||
*sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay(
|
||||
*retry_attempt,
|
||||
))));
|
||||
}
|
||||
|
||||
if let Some(sleep) = sleep_slot.as_mut()
|
||||
&& sleep.as_mut().poll(cx).is_ready()
|
||||
{
|
||||
*sleep_slot = None;
|
||||
*wake_scheduled = false;
|
||||
*retry_attempt = retry_attempt.saturating_add(1);
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
}
|
||||
|
||||
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
#[cfg(test)]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
|
||||
quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
.map(|_| Arc::new(Mutex::new(())))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let hash = crc32fast::hash(user.as_bytes()) as usize;
|
||||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
pub(crate) fn quota_user_lock_evict() {
|
||||
if let Some(locks) = QUOTA_USER_LOCKS.get() {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> {
|
||||
let interval = interval.max(Duration::from_millis(1));
|
||||
#[cfg(test)]
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
quota_user_lock_evict();
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn spawn_quota_user_lock_evictor_for_tests(
|
||||
interval: Duration,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
spawn_quota_user_lock_evictor(interval)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() {
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 {
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
return quota_overflow_user_lock(user);
|
||||
}
|
||||
|
||||
let created = Arc::new(Mutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
|
||||
fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool {
|
||||
remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||
|
|
@ -453,93 +292,60 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
|||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
if this.quota_exceeded.load(Ordering::Relaxed) {
|
||||
if this.quota_exceeded.load(Ordering::Acquire) {
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
reset_quota_retry_scheduler(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
);
|
||||
|
||||
if let Some(limit) = this.quota_limit
|
||||
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||
{
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
let mut remaining_before = None;
|
||||
if let Some(limit) = this.quota_limit {
|
||||
let used_before = this.user_stats.quota_used();
|
||||
let remaining = limit.saturating_sub(used_before);
|
||||
if remaining == 0 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
remaining_before = Some(remaining);
|
||||
}
|
||||
|
||||
let before = buf.filled().len();
|
||||
|
||||
match Pin::new(&mut this.inner).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let n = buf.filled().len() - before;
|
||||
if n > 0 {
|
||||
let mut reached_quota_boundary = false;
|
||||
if let Some(limit) = this.quota_limit {
|
||||
let used = this.stats.get_user_total_octets(&this.user);
|
||||
if used >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let remaining = limit - used;
|
||||
if (n as u64) > remaining {
|
||||
// Fail closed: when a single read chunk would cross quota,
|
||||
// stop relay immediately without accounting beyond the cap.
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
reached_quota_boundary = (n as u64) == remaining;
|
||||
}
|
||||
let n_to_charge = n as u64;
|
||||
|
||||
// C→S: client sent data
|
||||
this.counters
|
||||
.c2s_bytes
|
||||
.fetch_add(n as u64, Ordering::Relaxed);
|
||||
.fetch_add(n_to_charge, Ordering::Relaxed);
|
||||
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
|
||||
this.counters.touch(Instant::now(), this.epoch);
|
||||
|
||||
this.stats.add_user_octets_from(&this.user, n as u64);
|
||||
this.stats.increment_user_msgs_from(&this.user);
|
||||
this.stats
|
||||
.add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge);
|
||||
this.stats
|
||||
.increment_user_msgs_from_handle(this.user_stats.as_ref());
|
||||
|
||||
if reached_quota_boundary {
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
|
||||
this.stats
|
||||
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
|
||||
if should_immediate_quota_check(remaining, n_to_charge) {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
} else {
|
||||
this.quota_bytes_since_check =
|
||||
this.quota_bytes_since_check.saturating_add(n_to_charge);
|
||||
let interval = quota_adaptive_interval_bytes(remaining);
|
||||
if this.quota_bytes_since_check >= interval {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!(user = %this.user, bytes = n, "C->S");
|
||||
|
|
@ -558,87 +364,57 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
|||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
if this.quota_exceeded.load(Ordering::Relaxed) {
|
||||
if this.quota_exceeded.load(Ordering::Acquire) {
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
reset_quota_retry_scheduler(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
);
|
||||
|
||||
let write_buf = if let Some(limit) = this.quota_limit {
|
||||
let used = this.stats.get_user_total_octets(&this.user);
|
||||
if used >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
let mut remaining_before = None;
|
||||
if let Some(limit) = this.quota_limit {
|
||||
let used_before = this.user_stats.quota_used();
|
||||
let remaining = limit.saturating_sub(used_before);
|
||||
if remaining == 0 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let remaining = (limit - used) as usize;
|
||||
if buf.len() > remaining {
|
||||
// Fail closed: do not emit partial S->C payload when remaining
|
||||
// quota cannot accommodate the pending write request.
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
remaining_before = Some(remaining);
|
||||
}
|
||||
buf
|
||||
} else {
|
||||
buf
|
||||
};
|
||||
|
||||
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, buf) {
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n > 0 {
|
||||
let n_to_charge = n as u64;
|
||||
|
||||
// S→C: data written to client
|
||||
this.counters
|
||||
.s2c_bytes
|
||||
.fetch_add(n as u64, Ordering::Relaxed);
|
||||
.fetch_add(n_to_charge, Ordering::Relaxed);
|
||||
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
|
||||
this.counters.touch(Instant::now(), this.epoch);
|
||||
|
||||
this.stats.add_user_octets_to(&this.user, n as u64);
|
||||
this.stats.increment_user_msgs_to(&this.user);
|
||||
this.stats
|
||||
.add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge);
|
||||
this.stats
|
||||
.increment_user_msgs_to_handle(this.user_stats.as_ref());
|
||||
|
||||
if let Some(limit) = this.quota_limit
|
||||
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||
{
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
|
||||
this.stats
|
||||
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
|
||||
if should_immediate_quota_check(remaining, n_to_charge) {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
} else {
|
||||
this.quota_bytes_since_check =
|
||||
this.quota_bytes_since_check.saturating_add(n_to_charge);
|
||||
let interval = quota_adaptive_interval_bytes(remaining);
|
||||
if this.quota_bytes_since_check >= interval {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!(user = %this.user, bytes = n, "S->C");
|
||||
|
|
@ -732,7 +508,7 @@ where
|
|||
let now = Instant::now();
|
||||
let idle = wd_counters.idle_duration(now, epoch);
|
||||
|
||||
if wd_quota_exceeded.load(Ordering::Relaxed) {
|
||||
if wd_quota_exceeded.load(Ordering::Acquire) {
|
||||
warn!(user = %wd_user, "User data quota reached, closing relay");
|
||||
return;
|
||||
}
|
||||
|
|
@ -870,18 +646,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_adversarial_tests.rs"]
|
||||
mod adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"]
|
||||
mod relay_quota_lock_pressure_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_boundary_blackhat_tests.rs"]
|
||||
mod relay_quota_boundary_blackhat_tests;
|
||||
|
|
@ -901,71 +669,3 @@ mod relay_quota_extended_attack_surface_security_tests;
|
|||
#[cfg(test)]
|
||||
#[path = "tests/relay_watchdog_delta_security_tests.rs"]
|
||||
mod relay_watchdog_delta_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"]
|
||||
mod relay_quota_waker_storm_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
|
||||
mod relay_quota_wake_liveness_regression_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_identity_security_tests.rs"]
|
||||
mod relay_quota_lock_identity_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"]
|
||||
mod relay_cross_mode_quota_lock_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"]
|
||||
mod relay_quota_retry_scheduler_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"]
|
||||
mod relay_cross_mode_quota_fairness_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"]
|
||||
mod relay_cross_mode_pipeline_hol_integration_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"]
|
||||
mod relay_cross_mode_pipeline_latency_benchmark_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_backoff_security_tests.rs"]
|
||||
mod relay_quota_retry_backoff_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"]
|
||||
mod relay_quota_retry_backoff_benchmark_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"]
|
||||
mod relay_dual_lock_backoff_regression_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"]
|
||||
mod relay_dual_lock_contention_matrix_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"]
|
||||
mod relay_dual_lock_race_harness_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"]
|
||||
mod relay_dual_lock_alternating_contention_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"]
|
||||
mod relay_quota_retry_allocation_latency_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"]
|
||||
mod relay_quota_lock_eviction_lifecycle_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"]
|
||||
mod relay_quota_lock_eviction_stress_security_tests;
|
||||
|
|
|
|||
311
src/stats/mod.rs
311
src/stats/mod.rs
|
|
@ -238,10 +238,12 @@ pub struct Stats {
|
|||
me_inline_recovery_total: AtomicU64,
|
||||
ip_reservation_rollback_tcp_limit_total: AtomicU64,
|
||||
ip_reservation_rollback_quota_limit_total: AtomicU64,
|
||||
quota_write_fail_bytes_total: AtomicU64,
|
||||
quota_write_fail_events_total: AtomicU64,
|
||||
telemetry_core_enabled: AtomicBool,
|
||||
telemetry_user_enabled: AtomicBool,
|
||||
telemetry_me_level: AtomicU8,
|
||||
user_stats: DashMap<String, UserStats>,
|
||||
user_stats: DashMap<String, Arc<UserStats>>,
|
||||
user_stats_last_cleanup_epoch_secs: AtomicU64,
|
||||
start_time: parking_lot::RwLock<Option<Instant>>,
|
||||
}
|
||||
|
|
@ -254,9 +256,51 @@ pub struct UserStats {
|
|||
pub octets_to_client: AtomicU64,
|
||||
pub msgs_from_client: AtomicU64,
|
||||
pub msgs_to_client: AtomicU64,
|
||||
/// Total bytes charged against per-user quota admission.
|
||||
///
|
||||
/// This counter is the single source of truth for quota enforcement and
|
||||
/// intentionally tracks attempted traffic, not guaranteed delivery.
|
||||
pub quota_used: AtomicU64,
|
||||
pub last_seen_epoch_secs: AtomicU64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum QuotaReserveError {
|
||||
LimitExceeded,
|
||||
Contended,
|
||||
}
|
||||
|
||||
impl UserStats {
|
||||
#[inline]
|
||||
pub fn quota_used(&self) -> u64 {
|
||||
self.quota_used.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Attempts one CAS reservation step against the quota counter.
|
||||
///
|
||||
/// Callers control retry/yield policy. This primitive intentionally does
|
||||
/// not block or sleep so both sync poll paths and async paths can wrap it
|
||||
/// with their own contention strategy.
|
||||
#[inline]
|
||||
pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result<u64, QuotaReserveError> {
|
||||
let current = self.quota_used.load(Ordering::Relaxed);
|
||||
if bytes > limit.saturating_sub(current) {
|
||||
return Err(QuotaReserveError::LimitExceeded);
|
||||
}
|
||||
|
||||
let next = current.saturating_add(bytes);
|
||||
match self.quota_used.compare_exchange_weak(
|
||||
current,
|
||||
next,
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => Ok(next),
|
||||
Err(_) => Err(QuotaReserveError::Contended),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stats {
|
||||
pub fn new() -> Self {
|
||||
let stats = Self::default();
|
||||
|
|
@ -316,6 +360,70 @@ impl Stats {
|
|||
.store(Self::now_epoch_secs(), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc<UserStats> {
|
||||
self.maybe_cleanup_user_stats();
|
||||
if let Some(existing) = self.user_stats.get(user) {
|
||||
let handle = Arc::clone(existing.value());
|
||||
Self::touch_user_stats(handle.as_ref());
|
||||
return handle;
|
||||
}
|
||||
|
||||
let entry = self.user_stats.entry(user.to_string()).or_default();
|
||||
if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 {
|
||||
Self::touch_user_stats(entry.value().as_ref());
|
||||
}
|
||||
Arc::clone(entry.value())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
Self::touch_user_stats(user_stats);
|
||||
user_stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
Self::touch_user_stats(user_stats);
|
||||
user_stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
Self::touch_user_stats(user_stats);
|
||||
user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
Self::touch_user_stats(user_stats);
|
||||
user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Charges already committed bytes in a post-I/O path.
|
||||
///
|
||||
/// This helper is intentionally separate from `quota_try_reserve` to avoid
|
||||
/// mixing reserve and post-charge on a single I/O event.
|
||||
#[inline]
|
||||
pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 {
|
||||
Self::touch_user_stats(user_stats);
|
||||
user_stats
|
||||
.quota_used
|
||||
.fetch_add(bytes, Ordering::Relaxed)
|
||||
.saturating_add(bytes)
|
||||
}
|
||||
|
||||
fn maybe_cleanup_user_stats(&self) {
|
||||
const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60;
|
||||
const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60;
|
||||
|
|
@ -1114,6 +1222,18 @@ impl Stats {
|
|||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.quota_write_fail_bytes_total
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
pub fn increment_quota_write_fail_events_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.quota_write_fail_events_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
pub fn increment_me_endpoint_quarantine_total(&self) {
|
||||
if self.telemetry_me_allows_normal() {
|
||||
self.me_endpoint_quarantine_total
|
||||
|
|
@ -1764,19 +1884,19 @@ impl Stats {
|
|||
self.ip_reservation_rollback_quota_limit_total
|
||||
.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_quota_write_fail_bytes_total(&self) -> u64 {
|
||||
self.quota_write_fail_bytes_total.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_quota_write_fail_events_total(&self) -> u64 {
|
||||
self.quota_write_fail_events_total.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn increment_user_connects(&self, user: &str) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
self.maybe_cleanup_user_stats();
|
||||
if let Some(stats) = self.user_stats.get(user) {
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.connects.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||
Self::touch_user_stats(stats.value());
|
||||
let stats = self.get_or_create_user_stats_handle(user);
|
||||
Self::touch_user_stats(stats.as_ref());
|
||||
stats.connects.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
|
|
@ -1784,14 +1904,8 @@ impl Stats {
|
|||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
self.maybe_cleanup_user_stats();
|
||||
if let Some(stats) = self.user_stats.get(user) {
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||
Self::touch_user_stats(stats.value());
|
||||
let stats = self.get_or_create_user_stats_handle(user);
|
||||
Self::touch_user_stats(stats.as_ref());
|
||||
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
|
|
@ -1800,9 +1914,8 @@ impl Stats {
|
|||
return true;
|
||||
}
|
||||
|
||||
self.maybe_cleanup_user_stats();
|
||||
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||
Self::touch_user_stats(stats.value());
|
||||
let stats = self.get_or_create_user_stats_handle(user);
|
||||
Self::touch_user_stats(stats.as_ref());
|
||||
|
||||
let counter = &stats.curr_connects;
|
||||
let mut current = counter.load(Ordering::Relaxed);
|
||||
|
|
@ -1827,7 +1940,7 @@ impl Stats {
|
|||
pub fn decrement_user_curr_connects(&self, user: &str) {
|
||||
self.maybe_cleanup_user_stats();
|
||||
if let Some(stats) = self.user_stats.get(user) {
|
||||
Self::touch_user_stats(stats.value());
|
||||
Self::touch_user_stats(stats.value().as_ref());
|
||||
let counter = &stats.curr_connects;
|
||||
let mut current = counter.load(Ordering::Relaxed);
|
||||
loop {
|
||||
|
|
@ -1858,86 +1971,32 @@ impl Stats {
|
|||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
self.maybe_cleanup_user_stats();
|
||||
if let Some(stats) = self.user_stats.get(user) {
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
let stats = self.get_or_create_user_stats_handle(user);
|
||||
self.add_user_octets_from_handle(stats.as_ref(), bytes);
|
||||
}
|
||||
|
||||
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
self.maybe_cleanup_user_stats();
|
||||
if let Some(stats) = self.user_stats.get(user) {
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn sub_user_octets_to(&self, user: &str, bytes: u64) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
self.maybe_cleanup_user_stats();
|
||||
let Some(stats) = self.user_stats.get(user) else {
|
||||
return;
|
||||
};
|
||||
|
||||
Self::touch_user_stats(stats.value());
|
||||
let counter = &stats.octets_to_client;
|
||||
let mut current = counter.load(Ordering::Relaxed);
|
||||
loop {
|
||||
let next = current.saturating_sub(bytes);
|
||||
match counter.compare_exchange_weak(
|
||||
current,
|
||||
next,
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => break,
|
||||
Err(actual) => current = actual,
|
||||
}
|
||||
}
|
||||
let stats = self.get_or_create_user_stats_handle(user);
|
||||
self.add_user_octets_to_handle(stats.as_ref(), bytes);
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_from(&self, user: &str) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
self.maybe_cleanup_user_stats();
|
||||
if let Some(stats) = self.user_stats.get(user) {
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
|
||||
let stats = self.get_or_create_user_stats_handle(user);
|
||||
self.increment_user_msgs_from_handle(stats.as_ref());
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_to(&self, user: &str) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
self.maybe_cleanup_user_stats();
|
||||
if let Some(stats) = self.user_stats.get(user) {
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||
Self::touch_user_stats(stats.value());
|
||||
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
|
||||
let stats = self.get_or_create_user_stats_handle(user);
|
||||
self.increment_user_msgs_to_handle(stats.as_ref());
|
||||
}
|
||||
|
||||
pub fn get_user_total_octets(&self, user: &str) -> u64 {
|
||||
|
|
@ -1950,6 +2009,13 @@ impl Stats {
|
|||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn get_user_quota_used(&self, user: &str) -> u64 {
|
||||
self.user_stats
|
||||
.get(user)
|
||||
.map(|s| s.quota_used.load(Ordering::Relaxed))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn get_handshake_timeouts(&self) -> u64 {
|
||||
self.handshake_timeouts.load(Ordering::Relaxed)
|
||||
}
|
||||
|
|
@ -2015,7 +2081,7 @@ impl Stats {
|
|||
.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> {
|
||||
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc<UserStats>> {
|
||||
self.user_stats.iter()
|
||||
}
|
||||
|
||||
|
|
@ -2163,6 +2229,22 @@ impl ReplayChecker {
|
|||
found
|
||||
}
|
||||
|
||||
fn check_only_internal(
|
||||
&self,
|
||||
data: &[u8],
|
||||
shards: &[Mutex<ReplayShard>],
|
||||
window: Duration,
|
||||
) -> bool {
|
||||
self.checks.fetch_add(1, Ordering::Relaxed);
|
||||
let idx = self.get_shard_idx(data);
|
||||
let mut shard = shards[idx].lock();
|
||||
let found = shard.check(data, Instant::now(), window);
|
||||
if found {
|
||||
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
found
|
||||
}
|
||||
|
||||
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
|
||||
self.additions.fetch_add(1, Ordering::Relaxed);
|
||||
let idx = self.get_shard_idx(data);
|
||||
|
|
@ -2186,7 +2268,7 @@ impl ReplayChecker {
|
|||
self.add_only(data, &self.handshake_shards, self.window)
|
||||
}
|
||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
|
||||
self.check_and_add_tls_digest(data)
|
||||
self.check_only_internal(data, &self.tls_shards, self.tls_window)
|
||||
}
|
||||
pub fn add_tls_digest(&self, data: &[u8]) {
|
||||
self.add_only(data, &self.tls_shards, self.tls_window)
|
||||
|
|
@ -2289,6 +2371,7 @@ impl ReplayStats {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::MeTelemetryLevel;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
|
|
@ -2457,6 +2540,60 @@ mod tests {
|
|||
}
|
||||
assert_eq!(checker.stats().total_entries, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quota_reserve_under_contention_hits_limit_exactly() {
|
||||
let user_stats = Arc::new(UserStats::default());
|
||||
let successes = Arc::new(AtomicU64::new(0));
|
||||
let limit = 8_192u64;
|
||||
let mut workers = Vec::new();
|
||||
|
||||
for _ in 0..8 {
|
||||
let user_stats = user_stats.clone();
|
||||
let successes = successes.clone();
|
||||
workers.push(std::thread::spawn(move || {
|
||||
loop {
|
||||
match user_stats.quota_try_reserve(1, limit) {
|
||||
Ok(_) => {
|
||||
successes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
Err(QuotaReserveError::Contended) => {
|
||||
std::hint::spin_loop();
|
||||
}
|
||||
Err(QuotaReserveError::LimitExceeded) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for worker in workers {
|
||||
worker.join().expect("worker thread must finish");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
successes.load(Ordering::Relaxed),
|
||||
limit,
|
||||
"successful reservations must stop exactly at limit"
|
||||
);
|
||||
assert_eq!(user_stats.quota_used(), limit);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-authoritative-user";
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
|
||||
stats.add_user_octets_to_handle(&user_stats, 5);
|
||||
assert_eq!(stats.get_user_total_octets(user), 5);
|
||||
assert_eq!(stats.get_user_quota_used(user), 0);
|
||||
|
||||
stats.quota_charge_post_write(&user_stats, 7);
|
||||
assert_eq!(stats.get_user_total_octets(user), 5);
|
||||
assert_eq!(stats.get_user_quota_used(user), 7);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -2466,7 +2603,3 @@ mod connection_lease_security_tests;
|
|||
#[cfg(test)]
|
||||
#[path = "tests/replay_checker_security_tests.rs"]
|
||||
mod replay_checker_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/user_octets_sub_security_tests.rs"]
|
||||
mod user_octets_sub_security_tests;
|
||||
|
|
|
|||
|
|
@ -1,151 +0,0 @@
|
|||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
#[test]
|
||||
fn sub_user_octets_to_underflow_saturates_at_zero() {
|
||||
let stats = Stats::new();
|
||||
let user = "sub-underflow-user";
|
||||
|
||||
stats.add_user_octets_to(user, 3);
|
||||
stats.sub_user_octets_to(user, 100);
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(user), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_user_octets_to_does_not_affect_octets_from_client() {
|
||||
let stats = Stats::new();
|
||||
let user = "sub-isolation-user";
|
||||
|
||||
stats.add_user_octets_from(user, 17);
|
||||
stats.add_user_octets_to(user, 5);
|
||||
stats.sub_user_octets_to(user, 3);
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(user), 19);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_add_sub_model_matches_saturating_reference() {
|
||||
let stats = Stats::new();
|
||||
let user = "sub-fuzz-user";
|
||||
let mut seed = 0x91D2_4CB8_EE77_1101u64;
|
||||
let mut model_to = 0u64;
|
||||
|
||||
for _ in 0..8192 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let amt = ((seed >> 8) & 0x3f) + 1;
|
||||
if (seed & 1) == 0 {
|
||||
stats.add_user_octets_to(user, amt);
|
||||
model_to = model_to.saturating_add(amt);
|
||||
} else {
|
||||
stats.sub_user_octets_to(user, amt);
|
||||
model_to = model_to.saturating_sub(amt);
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(user), model_to);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_parallel_add_sub_never_underflows_or_panics() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "sub-stress-user";
|
||||
// Pre-fund with a large offset so subtractions never saturate at zero.
|
||||
// This guarantees commutative updates, making the final state deterministic.
|
||||
let base_offset = 10_000_000u64;
|
||||
stats.add_user_octets_to(user, base_offset);
|
||||
|
||||
let mut workers = Vec::new();
|
||||
|
||||
for tid in 0..16u64 {
|
||||
let stats_for_thread = Arc::clone(&stats);
|
||||
workers.push(thread::spawn(move || {
|
||||
let mut seed = 0xD00D_1000_0000_0000u64 ^ tid;
|
||||
let mut net_delta = 0i64;
|
||||
for _ in 0..4096 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let amt = ((seed >> 8) & 0x1f) + 1;
|
||||
|
||||
if (seed & 1) == 0 {
|
||||
stats_for_thread.add_user_octets_to(user, amt);
|
||||
net_delta += amt as i64;
|
||||
} else {
|
||||
stats_for_thread.sub_user_octets_to(user, amt);
|
||||
net_delta -= amt as i64;
|
||||
}
|
||||
}
|
||||
|
||||
net_delta
|
||||
}));
|
||||
}
|
||||
|
||||
let mut expected_net_delta = 0i64;
|
||||
for worker in workers {
|
||||
expected_net_delta += worker
|
||||
.join()
|
||||
.expect("sub-user stress worker must not panic");
|
||||
}
|
||||
|
||||
let expected_total = (base_offset as i64 + expected_net_delta) as u64;
|
||||
let total = stats.get_user_total_octets(user);
|
||||
assert_eq!(
|
||||
total, expected_total,
|
||||
"concurrent add/sub lost updates or suffered ABA races"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_user_octets_to_missing_user_is_noop() {
|
||||
let stats = Stats::new();
|
||||
stats.sub_user_octets_to("missing-user", 1024);
|
||||
assert_eq!(stats.get_user_total_octets("missing-user"), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_parallel_per_user_models_remain_exact() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut workers = Vec::new();
|
||||
|
||||
for tid in 0..16u64 {
|
||||
let stats_for_thread = Arc::clone(&stats);
|
||||
workers.push(thread::spawn(move || {
|
||||
let user = format!("sub-per-user-{tid}");
|
||||
let mut seed = 0xFACE_0000_0000_0000u64 ^ tid;
|
||||
let mut model = 0u64;
|
||||
|
||||
for _ in 0..4096 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let amt = ((seed >> 8) & 0x3f) + 1;
|
||||
|
||||
if (seed & 1) == 0 {
|
||||
stats_for_thread.add_user_octets_to(&user, amt);
|
||||
model = model.saturating_add(amt);
|
||||
} else {
|
||||
stats_for_thread.sub_user_octets_to(&user, amt);
|
||||
model = model.saturating_sub(amt);
|
||||
}
|
||||
}
|
||||
|
||||
(user, model)
|
||||
}));
|
||||
}
|
||||
|
||||
for worker in workers {
|
||||
let (user, model) = worker
|
||||
.join()
|
||||
.expect("per-user subtract stress worker must not panic");
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(&user),
|
||||
model,
|
||||
"per-user parallel model diverged"
|
||||
);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue