diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index 066c853..d553eb9 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -32,14 +32,6 @@ pub(crate) struct RuntimeWatches { pub(crate) detected_ip_v6: Option, } -const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60; - -fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> { - crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs( - QUOTA_USER_LOCK_EVICT_INTERVAL_SECS, - )) -} - #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, @@ -77,8 +69,6 @@ pub(crate) async fn spawn_runtime_tasks( rc_clone.run_periodic_cleanup().await; }); - spawn_quota_lock_maintenance_task(); - let detected_ip_v4: Option = probe.detected_ipv4.map(IpAddr::V4); let detected_ip_v6: Option = probe.detected_ipv6.map(IpAddr::V6); debug!( @@ -370,24 +360,3 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc) { .await; startup_tracker.mark_ready().await; } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() { - crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests(); - - let handle = spawn_quota_lock_maintenance_task(); - tokio::time::sleep(std::time::Duration::from_millis(5)).await; - - assert_eq!( - crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(), - 1, - "runtime maintenance path must spawn exactly one quota lock evictor task per call" - ); - - handle.abort(); - } -} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index a804a2c..1567caf 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1223,7 +1223,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1282,7 +1282,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 96994c7..55a8a21 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -614,6 +614,15 @@ where } }; + // Reject known replay digests before expensive cache/domain/ALPN policy work. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => { @@ -669,15 +678,8 @@ where None }; - // Replay tracking is applied only after full policy validation (including - // ALPN checks) so rejected handshakes cannot poison replay state. - let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_and_add_tls_digest(digest_half) { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } + // Add replay digest only for policy-valid handshakes. + replay_checker.add_tls_digest(digest_half); let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 841749c..241a48f 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -60,7 +60,7 @@ where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, { - let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; let mut ended_by_eof = false; @@ -262,7 +262,11 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { let floor = config.censorship.mask_timing_normalization_floor_ms; let ceiling = config.censorship.mask_timing_normalization_ceiling_ms; if floor == 0 { - return MASK_TIMEOUT; + if ceiling == 0 { + return Duration::from_millis(0); + } + let mut rng = rand::rng(); + return Duration::from_millis(rng.random_range(0..=ceiling)); } if ceiling > floor { let mut rng = rand::rng(); @@ -838,7 +842,7 @@ async fn consume_client_data(mut reader: R, byte_cap: usiz } // Keep drain path fail-closed under slow-loris stalls. - let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; loop { diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 14ea001..2a84353 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -10,7 +10,7 @@ use std::time::{Duration, Instant}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::timeout; use tracing::{debug, info, trace, warn}; @@ -23,7 +23,7 @@ use crate::proxy::route_mode::{ ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats}; +use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; @@ -53,20 +53,11 @@ const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; +const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); static DESYNC_HASHER: OnceLock = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); static DESYNC_DEDUP_EVER_SATURATED: OnceLock = OnceLock::new(); -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new(); static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); @@ -538,36 +529,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET } -fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option) -> bool { - quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota) -} - -#[cfg_attr(not(test), allow(dead_code))] -fn quota_would_be_exceeded_for_user( - stats: &Stats, - user: &str, - quota_limit: Option, - bytes: u64, -) -> bool { - quota_limit.is_some_and(|quota| { - let used = stats.get_user_total_octets(user); - used >= quota || bytes > quota.saturating_sub(used) - }) -} - fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { limit.saturating_add(overshoot) } -fn quota_would_be_exceeded_for_user_soft( - stats: &Stats, - user: &str, - quota_limit: Option, +async fn reserve_user_quota_with_yield( + user_stats: &UserStats, bytes: u64, - overshoot: u64, -) -> bool { - let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot)); - quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes) + limit: u64, +) -> std::result::Result { + loop { + for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { + match user_stats.quota_try_reserve(bytes, limit) { + Ok(total) => return Ok(total), + Err(QuotaReserveError::LimitExceeded) => { + return Err(QuotaReserveError::LimitExceeded); + } + Err(QuotaReserveError::Contended) => std::hint::spin_loop(), + } + } + + tokio::task::yield_now().await; + } } fn classify_me_d2c_flush_reason( @@ -613,29 +596,6 @@ fn observe_me_d2c_flush_event( } } -fn rollback_me2c_quota_reservation( - stats: &Stats, - user: &str, - bytes_me2c: &AtomicU64, - reserved_bytes: u64, -) { - stats.sub_user_octets_to(user, reserved_bytes); - bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed); -} - -#[cfg(test)] -fn quota_user_lock_test_guard() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -#[cfg(test)] -fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { - quota_user_lock_test_guard() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - #[cfg(test)] fn relay_idle_pressure_test_guard() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); @@ -649,46 +609,6 @@ pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, .unwrap_or_else(|poisoned| poisoned.into_inner()) } -fn quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(AsyncMutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -fn quota_user_lock(user: &str) -> Arc> { - let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - return quota_overflow_user_lock(user); - } - - let created = Arc::new(AsyncMutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) -} - async fn enqueue_c2me_command( tx: &mpsc::Sender, cmd: C2MeCommand, @@ -744,8 +664,7 @@ where { let user = success.user.clone(); let quota_limit = config.access.user_data_quota.get(&user).copied(); - let cross_mode_quota_lock = - quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); let peer = success.peer; let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); @@ -872,7 +791,7 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); - let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone(); + let quota_user_stats_me_writer = quota_user_stats.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); @@ -894,7 +813,7 @@ where let first_is_downstream_activity = matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( first, &mut writer, proto_tag, @@ -902,9 +821,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -953,7 +872,7 @@ where let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( next, &mut writer, proto_tag, @@ -961,9 +880,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1015,7 +934,7 @@ where Ok(Some(next)) => { let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( next, &mut writer, proto_tag, @@ -1023,9 +942,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1079,7 +998,7 @@ where let extra_is_downstream_activity = matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( extra, &mut writer, proto_tag, @@ -1087,9 +1006,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1259,24 +1178,23 @@ where forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); - if let Some(limit) = quota_limit { - let quota_lock = quota_user_lock(&user); - let _quota_guard = quota_lock.lock().await; - let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else { - main_result = Err(ProxyError::Proxy( - "cross-mode quota lock missing for quota-limited session" - .to_string(), - )); - break; - }; - let _cross_mode_quota_guard = cross_mode_lock.lock().await; - stats.add_user_octets_from(&user, payload.len() as u64); - if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { + if let (Some(limit), Some(user_stats)) = + (quota_limit, quota_user_stats.as_deref()) + { + if reserve_user_quota_with_yield( + user_stats, + payload.len() as u64, + limit, + ) + .await + .is_err() + { main_result = Err(ProxyError::DataQuotaExceeded { user: user.clone(), }); break; } + stats.add_user_octets_from_handle(user_stats, payload.len() as u64); } else { stats.add_user_octets_from(&user, payload.len() as u64); } @@ -1755,7 +1673,6 @@ enum MeWriterResponseOutcome { Close, } -#[cfg(test)] async fn process_me_writer_response( response: MeResponse, client_writer: &mut CryptoWriter, @@ -1764,6 +1681,7 @@ async fn process_me_writer_response( frame_buf: &mut Vec, stats: &Stats, user: &str, + quota_user_stats: Option<&UserStats>, quota_limit: Option, quota_soft_overshoot_bytes: u64, bytes_me2c: &AtomicU64, @@ -1771,44 +1689,6 @@ async fn process_me_writer_response( ack_flush_immediate: bool, batched: bool, ) -> Result -where - W: AsyncWrite + Unpin + Send + 'static, -{ - process_me_writer_response_with_cross_mode_lock( - response, - client_writer, - proto_tag, - rng, - frame_buf, - stats, - user, - quota_limit, - quota_soft_overshoot_bytes, - None, - bytes_me2c, - conn_id, - ack_flush_immediate, - batched, - ) - .await -} - -async fn process_me_writer_response_with_cross_mode_lock( - response: MeResponse, - client_writer: &mut CryptoWriter, - proto_tag: ProtoTag, - rng: &SecureRandom, - frame_buf: &mut Vec, - stats: &Stats, - user: &str, - quota_limit: Option, - quota_soft_overshoot_bytes: u64, - cross_mode_quota_lock: Option<&Arc>>, - bytes_me2c: &AtomicU64, - conn_id: u64, - ack_flush_immediate: bool, - batched: bool, -) -> Result where W: AsyncWrite + Unpin + Send + 'static, { @@ -1820,78 +1700,43 @@ where trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } let data_len = data.len() as u64; - if let Some(limit) = quota_limit { - let owned_cross_mode_lock; - let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock { - lock - } else { - owned_cross_mode_lock = - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user); - &owned_cross_mode_lock - }; - let cross_mode_quota_guard = cross_mode_lock.lock().await; + if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) { let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); - if quota_would_be_exceeded_for_user_soft( - stats, - user, - Some(limit), - data_len, - quota_soft_overshoot_bytes, - ) { + if reserve_user_quota_with_yield(user_stats, data_len, soft_limit) + .await + .is_err() + { stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), }); } - - // Reserve quota before awaiting network I/O to avoid same-user HoL stalls. - // If reservation loses a race or write fails, we roll back immediately. - bytes_me2c.fetch_add(data_len, Ordering::Relaxed); - stats.add_user_octets_to(user, data_len); - - if stats.get_user_total_octets(user) > soft_limit { - rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); - } - - // Keep cross-mode lock scope explicit and minimal: quota reservation is serialized, - // but socket I/O proceeds without holding same-user cross-mode admission lock. - drop(cross_mode_quota_guard); - - let write_mode = - match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await - { - Ok(mode) => mode, - Err(err) => { - rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); - return Err(err); - } - }; - - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data_len); - stats.increment_me_d2c_write_mode(write_mode); - - // Do not fail immediately on exact boundary after a successful write. - // Returning an error here can bypass batch flush in the caller and risk - // dropping buffered ciphertext from CryptoWriter. The next frame is - // rejected by the pre-check at function entry. - } else { - let write_mode = - write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await?; - - bytes_me2c.fetch_add(data_len, Ordering::Relaxed); - stats.add_user_octets_to(user, data_len); - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data_len); - stats.increment_me_d2c_write_mode(write_mode); } + let write_mode = + match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await + { + Ok(mode) => mode, + Err(err) => { + if quota_limit.is_some() { + stats.add_quota_write_fail_bytes_total(data_len); + stats.increment_quota_write_fail_events_total(); + } + return Err(err); + } + }; + + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + if let Some(user_stats) = quota_user_stats { + stats.add_user_octets_to_handle(user_stats, data_len); + } else { + stats.add_user_octets_to(user, data_len); + } + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); + Ok(MeWriterResponseOutcome::Continue { frames: 1, bytes: data.len(), @@ -2097,10 +1942,6 @@ where .map_err(ProxyError::Io) } -#[cfg(test)] -#[path = "tests/middle_relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/middle_relay_idle_policy_security_tests.rs"] mod idle_policy_security_tests; @@ -2113,30 +1954,10 @@ mod desync_all_full_dedup_security_tests; #[path = "tests/middle_relay_stub_completion_security_tests.rs"] mod stub_completion_security_tests; -#[cfg(test)] -#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"] -mod coverage_high_risk_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"] -mod quota_overflow_lock_security_tests; - #[cfg(test)] #[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"] mod length_cast_hardening_security_tests; -#[cfg(test)] -#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] -mod blackhat_campaign_integration_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_hol_quota_security_tests.rs"] -mod hol_quota_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"] -mod quota_reservation_adversarial_tests; - #[cfg(test)] #[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"] mod middle_relay_idle_registry_poison_security_tests; @@ -2156,27 +1977,3 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests; #[cfg(test)] #[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"] -mod middle_relay_cross_mode_quota_reservation_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"] -mod middle_relay_cross_mode_quota_lock_matrix_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"] -mod middle_relay_cross_mode_lookup_efficiency_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"] -mod middle_relay_cross_mode_lock_release_regression_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"] -mod middle_relay_quota_extended_attack_surface_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"] -mod middle_relay_quota_reservation_extreme_security_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 519f1b3..eebc188 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -64,7 +64,6 @@ pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; -pub mod quota_lock_registry; pub mod relay; pub mod route_mode; pub mod session_eviction; diff --git a/src/proxy/quota_lock_registry.rs b/src/proxy/quota_lock_registry.rs deleted file mode 100644 index 7798b09..0000000 --- a/src/proxy/quota_lock_registry.rs +++ /dev/null @@ -1,88 +0,0 @@ -use dashmap::DashMap; -use std::sync::{Arc, OnceLock}; -use tokio::sync::Mutex; - -#[cfg(test)] -use std::sync::atomic::{AtomicUsize, Ordering}; - -#[cfg(test)] -const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; - -static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); - -#[cfg(test)] -static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0); -#[cfg(test)] -static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock> = OnceLock::new(); - -fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(Mutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { - #[cfg(test)] - { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed); - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - let mut entry = lookups.entry(user.to_string()).or_insert(0); - *entry += 1; - } - - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - - if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { - return cross_mode_quota_overflow_user_lock(user); - } - - let created = Arc::new(Mutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed); - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - lookups.clear(); -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed) -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize { - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - lookups.get(user).map(|entry| *entry).unwrap_or(0) -} - -#[cfg(test)] -#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"] -mod quota_lock_registry_cross_mode_adversarial_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 55f1385..cc8b088 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -52,18 +52,16 @@ //! - `SharedCounters` (atomics) let the watchdog read stats without locking use crate::error::{ProxyError, Result}; -use crate::stats::Stats; +use crate::stats::{Stats, UserStats}; use crate::stream::BufferPool; -use dashmap::DashMap; use std::io; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; -use tokio::sync::Mutex as AsyncMutex; -use tokio::time::{Instant, Sleep}; +use tokio::time::Instant; use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -210,16 +208,10 @@ struct StatsIo { counters: Arc, stats: Arc, user: String, - quota_lock: Option>>, - cross_mode_quota_lock: Option>>, + user_stats: Arc, quota_limit: Option, quota_exceeded: Arc, - quota_read_wake_scheduled: bool, - quota_write_wake_scheduled: bool, - quota_read_retry_sleep: Option>>, - quota_write_retry_sleep: Option>>, - quota_read_retry_attempt: u8, - quota_write_retry_attempt: u8, + quota_bytes_since_check: u64, epoch: Instant, } @@ -235,24 +227,16 @@ impl StatsIo { ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); - let quota_lock = quota_limit.map(|_| quota_user_lock(&user)); - let cross_mode_quota_lock = quota_limit - .map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let user_stats = stats.get_or_create_user_stats_handle(&user); Self { inner, counters, stats, user, - quota_lock, - cross_mode_quota_lock, + user_stats, quota_limit, quota_exceeded, - quota_read_wake_scheduled: false, - quota_write_wake_scheduled: false, - quota_read_retry_sleep: None, - quota_write_retry_sleep: None, - quota_read_retry_attempt: 0, - quota_write_retry_attempt: 0, + quota_bytes_since_check: 0, epoch, } } @@ -281,169 +265,24 @@ fn is_quota_io_error(err: &io::Error) -> bool { .is_some() } -#[cfg(test)] -const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1); -#[cfg(not(test))] -const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2); -#[cfg(test)] -const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16); -#[cfg(not(test))] -const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64); +const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024; +const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; -#[cfg(test)] -static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0); -#[cfg(test)] -static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0); - -#[cfg(test)] -pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() { - QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed); -} - -#[cfg(test)] -pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 { - QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed) +#[inline] +fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { + remaining_before + .saturating_div(2) + .clamp( + QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES, + QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES, + ) } #[inline] -fn quota_contention_retry_delay(retry_attempt: u8) -> Duration { - let shift = u32::from(retry_attempt.min(5)); - let multiplier = 1_u32 << shift; - QUOTA_CONTENTION_RETRY_INTERVAL - .saturating_mul(multiplier) - .min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL) -} - -#[inline] -fn reset_quota_retry_scheduler( - sleep_slot: &mut Option>>, - wake_scheduled: &mut bool, - retry_attempt: &mut u8, -) { - *wake_scheduled = false; - *sleep_slot = None; - *retry_attempt = 0; -} - -fn poll_quota_retry_sleep( - sleep_slot: &mut Option>>, - wake_scheduled: &mut bool, - retry_attempt: &mut u8, - cx: &mut Context<'_>, -) { - if !*wake_scheduled { - *wake_scheduled = true; - #[cfg(test)] - QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed); - *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay( - *retry_attempt, - )))); - } - - if let Some(sleep) = sleep_slot.as_mut() - && sleep.as_mut().poll(cx).is_ready() - { - *sleep_slot = None; - *wake_scheduled = false; - *retry_attempt = retry_attempt.saturating_add(1); - cx.waker().wake_by_ref(); - } -} - -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); - -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; - -#[cfg(test)] -fn quota_user_lock_test_guard() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -#[cfg(test)] -fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { - quota_user_lock_test_guard() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -fn quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(Mutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -pub(crate) fn quota_user_lock_evict() { - if let Some(locks) = QUOTA_USER_LOCKS.get() { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } -} - -pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> { - let interval = interval.max(Duration::from_millis(1)); - #[cfg(test)] - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed); - tokio::spawn(async move { - loop { - tokio::time::sleep(interval).await; - quota_user_lock_evict(); - } - }) -} - -#[cfg(test)] -pub(crate) fn spawn_quota_user_lock_evictor_for_tests( - interval: Duration, -) -> tokio::task::JoinHandle<()> { - spawn_quota_user_lock_evictor(interval) -} - -#[cfg(test)] -pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() { - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed); -} - -#[cfg(test)] -pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 { - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed) -} - -fn quota_user_lock(user: &str) -> Arc> { - let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - return quota_overflow_user_lock(user); - } - - let created = Arc::new(Mutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) +fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool { + remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES } impl AsyncRead for StatsIo { @@ -453,93 +292,60 @@ impl AsyncRead for StatsIo { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Relaxed) { + if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } - let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - cx, - ); - return Poll::Pending; - } + let mut remaining_before = None; + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); } - } else { - None - }; - - let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - reset_quota_retry_scheduler( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - ); - - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + remaining_before = Some(remaining); } + let before = buf.filled().len(); match Pin::new(&mut this.inner).poll_read(cx, buf) { Poll::Ready(Ok(())) => { let n = buf.filled().len() - before; if n > 0 { - let mut reached_quota_boundary = false; - if let Some(limit) = this.quota_limit { - let used = this.stats.get_user_total_octets(&this.user); - if used >= limit { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - - let remaining = limit - used; - if (n as u64) > remaining { - // Fail closed: when a single read chunk would cross quota, - // stop relay immediately without accounting beyond the cap. - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - - reached_quota_boundary = (n as u64) == remaining; - } + let n_to_charge = n as u64; // C→S: client sent data this.counters .c2s_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_from(&this.user, n as u64); - this.stats.increment_user_msgs_from(&this.user); + this.stats + .add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_from_handle(this.user_stats.as_ref()); - if reached_quota_boundary { - this.quota_exceeded.store(true, Ordering::Relaxed); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "C->S"); @@ -558,87 +364,57 @@ impl AsyncWrite for StatsIo { buf: &[u8], ) -> Poll> { let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Relaxed) { + if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } - let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - reset_quota_retry_scheduler( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - ); - - let write_buf = if let Some(limit) = this.quota_limit { - let used = this.stats.get_user_total_octets(&this.user); - if used >= limit { - this.quota_exceeded.store(true, Ordering::Relaxed); + let mut remaining_before = None; + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } + remaining_before = Some(remaining); + } - let remaining = (limit - used) as usize; - if buf.len() > remaining { - // Fail closed: do not emit partial S->C payload when remaining - // quota cannot accommodate the pending write request. - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - buf - } else { - buf - }; - - match Pin::new(&mut this.inner).poll_write(cx, write_buf) { + match Pin::new(&mut this.inner).poll_write(cx, buf) { Poll::Ready(Ok(n)) => { if n > 0 { + let n_to_charge = n as u64; + // S→C: data written to client this.counters .s2c_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_to(&this.user, n as u64); - this.stats.increment_user_msgs_to(&this.user); + this.stats + .add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_to_handle(this.user_stats.as_ref()); - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "S->C"); @@ -732,7 +508,7 @@ where let now = Instant::now(); let idle = wd_counters.idle_duration(now, epoch); - if wd_quota_exceeded.load(Ordering::Relaxed) { + if wd_quota_exceeded.load(Ordering::Acquire) { warn!(user = %wd_user, "User data quota reached, closing relay"); return; } @@ -870,18 +646,10 @@ where } } -#[cfg(test)] -#[path = "tests/relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/relay_adversarial_tests.rs"] mod adversarial_tests; -#[cfg(test)] -#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"] -mod relay_quota_lock_pressure_adversarial_tests; - #[cfg(test)] #[path = "tests/relay_quota_boundary_blackhat_tests.rs"] mod relay_quota_boundary_blackhat_tests; @@ -901,71 +669,3 @@ mod relay_quota_extended_attack_surface_security_tests; #[cfg(test)] #[path = "tests/relay_watchdog_delta_security_tests.rs"] mod relay_watchdog_delta_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"] -mod relay_quota_waker_storm_adversarial_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] -mod relay_quota_wake_liveness_regression_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_identity_security_tests.rs"] -mod relay_quota_lock_identity_security_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"] -mod relay_cross_mode_quota_lock_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"] -mod relay_quota_retry_scheduler_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"] -mod relay_cross_mode_quota_fairness_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"] -mod relay_cross_mode_pipeline_hol_integration_security_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"] -mod relay_cross_mode_pipeline_latency_benchmark_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_backoff_security_tests.rs"] -mod relay_quota_retry_backoff_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"] -mod relay_quota_retry_backoff_benchmark_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"] -mod relay_dual_lock_backoff_regression_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"] -mod relay_dual_lock_contention_matrix_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"] -mod relay_dual_lock_race_harness_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"] -mod relay_dual_lock_alternating_contention_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"] -mod relay_quota_retry_allocation_latency_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"] -mod relay_quota_lock_eviction_lifecycle_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"] -mod relay_quota_lock_eviction_stress_security_tests; diff --git a/src/stats/mod.rs b/src/stats/mod.rs index dc455a1..7d8aef3 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -238,10 +238,12 @@ pub struct Stats { me_inline_recovery_total: AtomicU64, ip_reservation_rollback_tcp_limit_total: AtomicU64, ip_reservation_rollback_quota_limit_total: AtomicU64, + quota_write_fail_bytes_total: AtomicU64, + quota_write_fail_events_total: AtomicU64, telemetry_core_enabled: AtomicBool, telemetry_user_enabled: AtomicBool, telemetry_me_level: AtomicU8, - user_stats: DashMap, + user_stats: DashMap>, user_stats_last_cleanup_epoch_secs: AtomicU64, start_time: parking_lot::RwLock>, } @@ -254,9 +256,51 @@ pub struct UserStats { pub octets_to_client: AtomicU64, pub msgs_from_client: AtomicU64, pub msgs_to_client: AtomicU64, + /// Total bytes charged against per-user quota admission. + /// + /// This counter is the single source of truth for quota enforcement and + /// intentionally tracks attempted traffic, not guaranteed delivery. + pub quota_used: AtomicU64, pub last_seen_epoch_secs: AtomicU64, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QuotaReserveError { + LimitExceeded, + Contended, +} + +impl UserStats { + #[inline] + pub fn quota_used(&self) -> u64 { + self.quota_used.load(Ordering::Relaxed) + } + + /// Attempts one CAS reservation step against the quota counter. + /// + /// Callers control retry/yield policy. This primitive intentionally does + /// not block or sleep so both sync poll paths and async paths can wrap it + /// with their own contention strategy. + #[inline] + pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result { + let current = self.quota_used.load(Ordering::Relaxed); + if bytes > limit.saturating_sub(current) { + return Err(QuotaReserveError::LimitExceeded); + } + + let next = current.saturating_add(bytes); + match self.quota_used.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => Ok(next), + Err(_) => Err(QuotaReserveError::Contended), + } + } +} + impl Stats { pub fn new() -> Self { let stats = Self::default(); @@ -316,6 +360,70 @@ impl Stats { .store(Self::now_epoch_secs(), Ordering::Relaxed); } + pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc { + self.maybe_cleanup_user_stats(); + if let Some(existing) = self.user_stats.get(user) { + let handle = Arc::clone(existing.value()); + Self::touch_user_stats(handle.as_ref()); + return handle; + } + + let entry = self.user_stats.entry(user.to_string()).or_default(); + if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 { + Self::touch_user_stats(entry.value().as_ref()); + } + Arc::clone(entry.value()) + } + + #[inline] + pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + } + + /// Charges already committed bytes in a post-I/O path. + /// + /// This helper is intentionally separate from `quota_try_reserve` to avoid + /// mixing reserve and post-charge on a single I/O event. + #[inline] + pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 { + Self::touch_user_stats(user_stats); + user_stats + .quota_used + .fetch_add(bytes, Ordering::Relaxed) + .saturating_add(bytes) + } + fn maybe_cleanup_user_stats(&self) { const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; @@ -1114,6 +1222,18 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) { + if self.telemetry_core_enabled() { + self.quota_write_fail_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_quota_write_fail_events_total(&self) { + if self.telemetry_core_enabled() { + self.quota_write_fail_events_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_endpoint_quarantine_total(&self) { if self.telemetry_me_allows_normal() { self.me_endpoint_quarantine_total @@ -1764,19 +1884,19 @@ impl Stats { self.ip_reservation_rollback_quota_limit_total .load(Ordering::Relaxed) } + pub fn get_quota_write_fail_bytes_total(&self) -> u64 { + self.quota_write_fail_bytes_total.load(Ordering::Relaxed) + } + pub fn get_quota_write_fail_events_total(&self) -> u64 { + self.quota_write_fail_events_total.load(Ordering::Relaxed) + } pub fn increment_user_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.connects.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); stats.connects.fetch_add(1, Ordering::Relaxed); } @@ -1784,14 +1904,8 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.curr_connects.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); } @@ -1800,9 +1914,8 @@ impl Stats { return true; } - self.maybe_cleanup_user_stats(); - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); @@ -1827,7 +1940,7 @@ impl Stats { pub fn decrement_user_curr_connects(&self, user: &str) { self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); + Self::touch_user_stats(stats.value().as_ref()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); loop { @@ -1858,86 +1971,32 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_from_handle(stats.as_ref(), bytes); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); - } - - pub fn sub_user_octets_to(&self, user: &str, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - self.maybe_cleanup_user_stats(); - let Some(stats) = self.user_stats.get(user) else { - return; - }; - - Self::touch_user_stats(stats.value()); - let counter = &stats.octets_to_client; - let mut current = counter.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(bytes); - match counter.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(actual) => current = actual, - } - } + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_to_handle(stats.as_ref(), bytes); } pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_from_handle(stats.as_ref()); } pub fn increment_user_msgs_to(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_to_handle(stats.as_ref()); } pub fn get_user_total_octets(&self, user: &str) -> u64 { @@ -1950,6 +2009,13 @@ impl Stats { .unwrap_or(0) } + pub fn get_user_quota_used(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| s.quota_used.load(Ordering::Relaxed)) + .unwrap_or(0) + } + pub fn get_handshake_timeouts(&self) -> u64 { self.handshake_timeouts.load(Ordering::Relaxed) } @@ -2015,7 +2081,7 @@ impl Stats { .load(Ordering::Relaxed) } - pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> { + pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc> { self.user_stats.iter() } @@ -2163,6 +2229,22 @@ impl ReplayChecker { found } + fn check_only_internal( + &self, + data: &[u8], + shards: &[Mutex], + window: Duration, + ) -> bool { + self.checks.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = shards[idx].lock(); + let found = shard.check(data, Instant::now(), window); + if found { + self.hits.fetch_add(1, Ordering::Relaxed); + } + found + } + fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { self.additions.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); @@ -2186,7 +2268,7 @@ impl ReplayChecker { self.add_only(data, &self.handshake_shards, self.window) } pub fn check_tls_digest(&self, data: &[u8]) -> bool { - self.check_and_add_tls_digest(data) + self.check_only_internal(data, &self.tls_shards, self.tls_window) } pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data, &self.tls_shards, self.tls_window) @@ -2289,6 +2371,7 @@ impl ReplayStats { mod tests { use super::*; use crate::config::MeTelemetryLevel; + use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; #[test] @@ -2457,6 +2540,60 @@ mod tests { } assert_eq!(checker.stats().total_entries, 500); } + + #[test] + fn test_quota_reserve_under_contention_hits_limit_exactly() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let limit = 8_192u64; + let mut workers = Vec::new(); + + for _ in 0..8 { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(1, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + Err(QuotaReserveError::LimitExceeded) => { + break; + } + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + limit, + "successful reservations must stop exactly at limit" + ); + assert_eq!(user_stats.quota_used(), limit); + } + + #[test] + fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() { + let stats = Stats::new(); + let user = "quota-authoritative-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + + stats.add_user_octets_to_handle(&user_stats, 5); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 0); + + stats.quota_charge_post_write(&user_stats, 7); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 7); + } } #[cfg(test)] @@ -2466,7 +2603,3 @@ mod connection_lease_security_tests; #[cfg(test)] #[path = "tests/replay_checker_security_tests.rs"] mod replay_checker_security_tests; - -#[cfg(test)] -#[path = "tests/user_octets_sub_security_tests.rs"] -mod user_octets_sub_security_tests; diff --git a/src/stats/tests/user_octets_sub_security_tests.rs b/src/stats/tests/user_octets_sub_security_tests.rs deleted file mode 100644 index d4e7580..0000000 --- a/src/stats/tests/user_octets_sub_security_tests.rs +++ /dev/null @@ -1,151 +0,0 @@ -use super::*; -use std::sync::Arc; -use std::thread; - -#[test] -fn sub_user_octets_to_underflow_saturates_at_zero() { - let stats = Stats::new(); - let user = "sub-underflow-user"; - - stats.add_user_octets_to(user, 3); - stats.sub_user_octets_to(user, 100); - - assert_eq!(stats.get_user_total_octets(user), 0); -} - -#[test] -fn sub_user_octets_to_does_not_affect_octets_from_client() { - let stats = Stats::new(); - let user = "sub-isolation-user"; - - stats.add_user_octets_from(user, 17); - stats.add_user_octets_to(user, 5); - stats.sub_user_octets_to(user, 3); - - assert_eq!(stats.get_user_total_octets(user), 19); -} - -#[test] -fn light_fuzz_add_sub_model_matches_saturating_reference() { - let stats = Stats::new(); - let user = "sub-fuzz-user"; - let mut seed = 0x91D2_4CB8_EE77_1101u64; - let mut model_to = 0u64; - - for _ in 0..8192 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let amt = ((seed >> 8) & 0x3f) + 1; - if (seed & 1) == 0 { - stats.add_user_octets_to(user, amt); - model_to = model_to.saturating_add(amt); - } else { - stats.sub_user_octets_to(user, amt); - model_to = model_to.saturating_sub(amt); - } - } - - assert_eq!(stats.get_user_total_octets(user), model_to); -} - -#[test] -fn stress_parallel_add_sub_never_underflows_or_panics() { - let stats = Arc::new(Stats::new()); - let user = "sub-stress-user"; - // Pre-fund with a large offset so subtractions never saturate at zero. - // This guarantees commutative updates, making the final state deterministic. - let base_offset = 10_000_000u64; - stats.add_user_octets_to(user, base_offset); - - let mut workers = Vec::new(); - - for tid in 0..16u64 { - let stats_for_thread = Arc::clone(&stats); - workers.push(thread::spawn(move || { - let mut seed = 0xD00D_1000_0000_0000u64 ^ tid; - let mut net_delta = 0i64; - for _ in 0..4096 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let amt = ((seed >> 8) & 0x1f) + 1; - - if (seed & 1) == 0 { - stats_for_thread.add_user_octets_to(user, amt); - net_delta += amt as i64; - } else { - stats_for_thread.sub_user_octets_to(user, amt); - net_delta -= amt as i64; - } - } - - net_delta - })); - } - - let mut expected_net_delta = 0i64; - for worker in workers { - expected_net_delta += worker - .join() - .expect("sub-user stress worker must not panic"); - } - - let expected_total = (base_offset as i64 + expected_net_delta) as u64; - let total = stats.get_user_total_octets(user); - assert_eq!( - total, expected_total, - "concurrent add/sub lost updates or suffered ABA races" - ); -} - -#[test] -fn sub_user_octets_to_missing_user_is_noop() { - let stats = Stats::new(); - stats.sub_user_octets_to("missing-user", 1024); - assert_eq!(stats.get_user_total_octets("missing-user"), 0); -} - -#[test] -fn stress_parallel_per_user_models_remain_exact() { - let stats = Arc::new(Stats::new()); - let mut workers = Vec::new(); - - for tid in 0..16u64 { - let stats_for_thread = Arc::clone(&stats); - workers.push(thread::spawn(move || { - let user = format!("sub-per-user-{tid}"); - let mut seed = 0xFACE_0000_0000_0000u64 ^ tid; - let mut model = 0u64; - - for _ in 0..4096 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let amt = ((seed >> 8) & 0x3f) + 1; - - if (seed & 1) == 0 { - stats_for_thread.add_user_octets_to(&user, amt); - model = model.saturating_add(amt); - } else { - stats_for_thread.sub_user_octets_to(&user, amt); - model = model.saturating_sub(amt); - } - } - - (user, model) - })); - } - - for worker in workers { - let (user, model) = worker - .join() - .expect("per-user subtract stress worker must not panic"); - assert_eq!( - stats.get_user_total_octets(&user), - model, - "per-user parallel model diverged" - ); - } -} \ No newline at end of file