Redesign Quotas on Atomics

This commit is contained in:
Alexey
2026-03-23 15:53:44 +03:00
parent 0c3c9009a9
commit 6f4356f72a
10 changed files with 408 additions and 1043 deletions

View File

@@ -10,7 +10,7 @@ use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::timeout;
use tracing::{debug, info, trace, warn};
@@ -23,7 +23,7 @@ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
cutover_stagger_delay,
};
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats};
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
@@ -53,20 +53,11 @@ const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
const QUOTA_RESERVE_SPIN_RETRIES: usize = 32;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<AsyncMutex<()>>>> = OnceLock::new();
static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
@@ -538,36 +529,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
}
#[cfg_attr(not(test), allow(dead_code))]
fn quota_would_be_exceeded_for_user(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let used = stats.get_user_total_octets(user);
used >= quota || bytes > quota.saturating_sub(used)
})
}
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
limit.saturating_add(overshoot)
}
fn quota_would_be_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
async fn reserve_user_quota_with_yield(
user_stats: &UserStats,
bytes: u64,
overshoot: u64,
) -> bool {
let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot));
quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes)
limit: u64,
) -> std::result::Result<u64, QuotaReserveError> {
loop {
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
match user_stats.quota_try_reserve(bytes, limit) {
Ok(total) => return Ok(total),
Err(QuotaReserveError::LimitExceeded) => {
return Err(QuotaReserveError::LimitExceeded);
}
Err(QuotaReserveError::Contended) => std::hint::spin_loop(),
}
}
tokio::task::yield_now().await;
}
}
fn classify_me_d2c_flush_reason(
@@ -613,29 +596,6 @@ fn observe_me_d2c_flush_event(
}
}
fn rollback_me2c_quota_reservation(
stats: &Stats,
user: &str,
bytes_me2c: &AtomicU64,
reserved_bytes: u64,
) {
stats.sub_user_octets_to(user, reserved_bytes);
bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed);
}
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
quota_user_lock_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[cfg(test)]
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@@ -649,46 +609,6 @@ pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static,
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(AsyncMutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user);
}
let created = Arc::new(AsyncMutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@@ -744,8 +664,7 @@ where
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let cross_mode_quota_lock =
quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@@ -872,7 +791,7 @@ where
let stats_clone = stats.clone();
let rng_clone = rng.clone();
let user_clone = user.clone();
let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone();
let quota_user_stats_me_writer = quota_user_stats.clone();
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
let bytes_me2c_clone = bytes_me2c.clone();
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
@@ -894,7 +813,7 @@ where
let first_is_downstream_activity =
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response_with_cross_mode_lock(
match process_me_writer_response(
first,
&mut writer,
proto_tag,
@@ -902,9 +821,9 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -953,7 +872,7 @@ where
let next_is_downstream_activity =
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response_with_cross_mode_lock(
match process_me_writer_response(
next,
&mut writer,
proto_tag,
@@ -961,9 +880,9 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -1015,7 +934,7 @@ where
Ok(Some(next)) => {
let next_is_downstream_activity =
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response_with_cross_mode_lock(
match process_me_writer_response(
next,
&mut writer,
proto_tag,
@@ -1023,9 +942,9 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -1079,7 +998,7 @@ where
let extra_is_downstream_activity =
matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response_with_cross_mode_lock(
match process_me_writer_response(
extra,
&mut writer,
proto_tag,
@@ -1087,9 +1006,9 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -1259,24 +1178,23 @@ where
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else {
main_result = Err(ProxyError::Proxy(
"cross-mode quota lock missing for quota-limited session"
.to_string(),
));
break;
};
let _cross_mode_quota_guard = cross_mode_lock.lock().await;
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
if let (Some(limit), Some(user_stats)) =
(quota_limit, quota_user_stats.as_deref())
{
if reserve_user_quota_with_yield(
user_stats,
payload.len() as u64,
limit,
)
.await
.is_err()
{
main_result = Err(ProxyError::DataQuotaExceeded {
user: user.clone(),
});
break;
}
stats.add_user_octets_from_handle(user_stats, payload.len() as u64);
} else {
stats.add_user_octets_from(&user, payload.len() as u64);
}
@@ -1755,7 +1673,6 @@ enum MeWriterResponseOutcome {
Close,
}
#[cfg(test)]
async fn process_me_writer_response<W>(
response: MeResponse,
client_writer: &mut CryptoWriter<W>,
@@ -1764,6 +1681,7 @@ async fn process_me_writer_response<W>(
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_user_stats: Option<&UserStats>,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
bytes_me2c: &AtomicU64,
@@ -1771,44 +1689,6 @@ async fn process_me_writer_response<W>(
ack_flush_immediate: bool,
batched: bool,
) -> Result<MeWriterResponseOutcome>
where
W: AsyncWrite + Unpin + Send + 'static,
{
process_me_writer_response_with_cross_mode_lock(
response,
client_writer,
proto_tag,
rng,
frame_buf,
stats,
user,
quota_limit,
quota_soft_overshoot_bytes,
None,
bytes_me2c,
conn_id,
ack_flush_immediate,
batched,
)
.await
}
async fn process_me_writer_response_with_cross_mode_lock<W>(
response: MeResponse,
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
rng: &SecureRandom,
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
cross_mode_quota_lock: Option<&Arc<AsyncMutex<()>>>,
bytes_me2c: &AtomicU64,
conn_id: u64,
ack_flush_immediate: bool,
batched: bool,
) -> Result<MeWriterResponseOutcome>
where
W: AsyncWrite + Unpin + Send + 'static,
{
@@ -1820,78 +1700,43 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
let data_len = data.len() as u64;
if let Some(limit) = quota_limit {
let owned_cross_mode_lock;
let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock {
lock
} else {
owned_cross_mode_lock =
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user);
&owned_cross_mode_lock
};
let cross_mode_quota_guard = cross_mode_lock.lock().await;
if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) {
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
if quota_would_be_exceeded_for_user_soft(
stats,
user,
Some(limit),
data_len,
quota_soft_overshoot_bytes,
) {
if reserve_user_quota_with_yield(user_stats, data_len, soft_limit)
.await
.is_err()
{
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
// Reserve quota before awaiting network I/O to avoid same-user HoL stalls.
// If reservation loses a race or write fails, we roll back immediately.
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
stats.add_user_octets_to(user, data_len);
if stats.get_user_total_octets(user) > soft_limit {
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
// Keep cross-mode lock scope explicit and minimal: quota reservation is serialized,
// but socket I/O proceeds without holding same-user cross-mode admission lock.
drop(cross_mode_quota_guard);
let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
{
Ok(mode) => mode,
Err(err) => {
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
return Err(err);
}
};
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
// Do not fail immediately on exact boundary after a successful write.
// Returning an error here can bypass batch flush in the caller and risk
// dropping buffered ciphertext from CryptoWriter. The next frame is
// rejected by the pre-check at function entry.
} else {
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
stats.add_user_octets_to(user, data_len);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
}
let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
{
Ok(mode) => mode,
Err(err) => {
if quota_limit.is_some() {
stats.add_quota_write_fail_bytes_total(data_len);
stats.increment_quota_write_fail_events_total();
}
return Err(err);
}
};
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
if let Some(user_stats) = quota_user_stats {
stats.add_user_octets_to_handle(user_stats, data_len);
} else {
stats.add_user_octets_to(user, data_len);
}
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
Ok(MeWriterResponseOutcome::Continue {
frames: 1,
bytes: data.len(),
@@ -2097,10 +1942,6 @@ where
.map_err(ProxyError::Io)
}
#[cfg(test)]
#[path = "tests/middle_relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_idle_policy_security_tests.rs"]
mod idle_policy_security_tests;
@@ -2113,30 +1954,10 @@ mod desync_all_full_dedup_security_tests;
#[path = "tests/middle_relay_stub_completion_security_tests.rs"]
mod stub_completion_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"]
mod coverage_high_risk_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"]
mod quota_overflow_lock_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"]
mod length_cast_hardening_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
mod blackhat_campaign_integration_tests;
#[cfg(test)]
#[path = "tests/middle_relay_hol_quota_security_tests.rs"]
mod hol_quota_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"]
mod quota_reservation_adversarial_tests;
#[cfg(test)]
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
mod middle_relay_idle_registry_poison_security_tests;
@@ -2156,27 +1977,3 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"]
mod middle_relay_cross_mode_quota_reservation_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"]
mod middle_relay_cross_mode_quota_lock_matrix_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"]
mod middle_relay_cross_mode_lookup_efficiency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"]
mod middle_relay_cross_mode_lock_release_regression_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"]
mod middle_relay_quota_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"]
mod middle_relay_quota_reservation_extreme_security_tests;