diff --git a/src/config/tests/load_memory_envelope_tests.rs b/src/config/tests/load_memory_envelope_tests.rs index 1c201cc..ea78498 100644 --- a/src/config/tests/load_memory_envelope_tests.rs +++ b/src/config/tests/load_memory_envelope_tests.rs @@ -17,20 +17,6 @@ fn remove_temp_config(path: &PathBuf) { let _ = fs::remove_file(path); } -#[test] -fn defaults_enable_byte_bounded_route_fairness() { - let cfg = ProxyConfig::default(); - - assert!( - cfg.general.me_route_fairshare_enabled, - "D2C route fairness must be enabled by default to bound queued bytes" - ); - assert!( - cfg.general.me_route_backpressure_enabled, - "D2C route backpressure must be enabled by default to shed under sustained pressure" - ); -} - #[test] fn load_rejects_writer_cmd_capacity_above_upper_bound() { let path = write_temp_config( diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index b0ddb8f..e4b4fe6 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -2329,7 +2329,7 @@ where W: AsyncWrite + Unpin + Send + 'static, { match response { - MeResponse::Data { flags, data } => { + MeResponse::Data { flags, data, .. 
} => { if batched { trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)"); } else { diff --git a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs index 7c176bc..18bd583 100644 --- a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs +++ b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs @@ -70,6 +70,7 @@ async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { MeResponse::Data { flags: 0, data: payload.clone(), + route_permit: None, }, &mut writer, ProtoTag::Intermediate, @@ -139,6 +140,7 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() { MeResponse::Data { flags: 0, data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]), + route_permit: None, }, &mut writer, ProtoTag::Intermediate, diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 992fec3..3f46a80 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -46,6 +46,7 @@ mod send_adversarial_tests; mod wire; use bytes::Bytes; +use tokio::sync::OwnedSemaphorePermit; #[allow(unused_imports)] pub use config_updater::{ @@ -68,9 +69,32 @@ pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream}; pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots}; pub use wire::proto_flags_for_tag; +/// Holds D2C queued-byte capacity until a routed payload is consumed or dropped. +pub struct RouteBytePermit { + _permit: OwnedSemaphorePermit, +} + +impl std::fmt::Debug for RouteBytePermit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RouteBytePermit").finish_non_exhaustive() + } +} + +impl RouteBytePermit { + pub(crate) fn new(permit: OwnedSemaphorePermit) -> Self { + Self { _permit: permit } + } +} + +/// Response routed from middle proxy readers to client relay tasks. 
#[derive(Debug)] pub enum MeResponse { - Data { flags: u32, data: Bytes }, + /// Downstream payload with its queued-byte reservation. + Data { + flags: u32, + data: Bytes, + route_permit: Option<RouteBytePermit>, + }, Ack(u32), Close, } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 97fa329..2dae1f1 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -84,6 +84,7 @@ async fn route_data_with_retry( MeResponse::Data { flags, data: data.clone(), + route_permit: None, }, timeout_ms, ) @@ -639,7 +640,7 @@ mod tests { let routed = route_data_with_retry(&reg, conn_id, 0, Bytes::from_static(b"a"), 20).await; assert!(matches!(routed, RouteResult::Routed)); match rx.recv().await { - Some(MeResponse::Data { flags, data }) => { + Some(MeResponse::Data { flags, data, .. }) => { assert_eq!(flags, 0); assert_eq!(data, Bytes::from_static(b"a")); } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 0c7a0a9..ee2598d 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -1,18 +1,22 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; +use std::sync::Arc; use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use dashmap::DashMap; use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::{Mutex, mpsc}; +use tokio::sync::{Mutex, Semaphore, mpsc}; -use super::MeResponse; +use super::{MeResponse, RouteBytePermit}; use super::codec::WriterCommand; const ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS: u64 = 25; const ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS: u64 = 120; const ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT: u8 = 80; +const ROUTE_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024; +const ROUTE_QUEUED_PERMITS_PER_SLOT: usize = 4; +const ROUTE_QUEUED_MAX_FRAME_PERMITS: usize = 1024; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteResult { @@ -53,6 +57,7 @@ pub(super) 
struct WriterActivitySnapshot { struct RoutingTable { map: DashMap<u64, mpsc::Sender<MeResponse>>, + byte_budget: DashMap<u64, Arc<Semaphore>>, } struct WriterTable { @@ -105,6 +110,7 @@ pub struct ConnRegistry { route_backpressure_base_timeout_ms: AtomicU64, route_backpressure_high_timeout_ms: AtomicU64, route_backpressure_high_watermark_pct: AtomicU8, + route_byte_permits_per_conn: usize, } impl ConnRegistry { @@ -116,10 +122,20 @@ impl ConnRegistry { } pub fn with_route_channel_capacity(route_channel_capacity: usize) -> Self { + let route_channel_capacity = route_channel_capacity.max(1); + Self::with_route_limits( + route_channel_capacity, + Self::route_byte_permit_budget(route_channel_capacity), + ) + } + + fn with_route_limits(route_channel_capacity: usize, route_byte_permits_per_conn: usize) -> Self { let start = rand::random::<u64>() | 1; + let route_channel_capacity = route_channel_capacity.max(1); Self { routing: RoutingTable { map: DashMap::new(), + byte_budget: DashMap::new(), }, writers: WriterTable { map: DashMap::new(), @@ -131,15 +147,30 @@ impl ConnRegistry { inner: Mutex::new(BindingInner::new()), }, next_id: AtomicU64::new(start), - route_channel_capacity: route_channel_capacity.max(1), + route_channel_capacity, route_backpressure_base_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS), route_backpressure_high_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS), route_backpressure_high_watermark_pct: AtomicU8::new( ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT, ), + route_byte_permits_per_conn: route_byte_permits_per_conn.max(1), } } + fn route_data_permits(data_len: usize) -> u32 { + data_len + .max(1) + .div_ceil(ROUTE_QUEUED_BYTE_PERMIT_UNIT) + .min(u32::MAX as usize) as u32 + } + + fn route_byte_permit_budget(route_channel_capacity: usize) -> usize { + route_channel_capacity + .saturating_mul(ROUTE_QUEUED_PERMITS_PER_SLOT) + .max(ROUTE_QUEUED_MAX_FRAME_PERMITS) + .max(1) + } + pub fn route_channel_capacity(&self) -> usize { self.route_channel_capacity } @@ -149,6 +180,14 @@ impl 
ConnRegistry { Self::with_route_channel_capacity(4096) } + #[cfg(test)] + fn with_route_byte_permits_for_tests( + route_channel_capacity: usize, + route_byte_permits_per_conn: usize, + ) -> Self { + Self::with_route_limits(route_channel_capacity, route_byte_permits_per_conn) + } + pub fn update_route_backpressure_policy( &self, base_timeout_ms: u64, @@ -170,6 +209,9 @@ impl ConnRegistry { let id = self.next_id.fetch_add(1, Ordering::Relaxed); let (tx, rx) = mpsc::channel(self.route_channel_capacity); self.routing.map.insert(id, tx); + self.routing + .byte_budget + .insert(id, Arc::new(Semaphore::new(self.route_byte_permits_per_conn))); (id, rx) } @@ -186,6 +228,7 @@ impl ConnRegistry { /// Unregister connection, returning associated writer_id if any. pub async fn unregister(&self, id: u64) -> Option<u64> { self.routing.map.remove(&id); + self.routing.byte_budget.remove(&id); self.hot_binding.map.remove(&id); let mut binding = self.binding.inner.lock().await; binding.meta.remove(&id); @@ -206,6 +249,65 @@ impl ConnRegistry { None } + async fn attach_route_byte_permit( + &self, + id: u64, + resp: MeResponse, + timeout_ms: Option<u64>, + ) -> std::result::Result<MeResponse, RouteResult> { + let MeResponse::Data { + flags, + data, + route_permit, + } = resp + else { + return Ok(resp); + }; + + if route_permit.is_some() { + return Ok(MeResponse::Data { + flags, + data, + route_permit, + }); + } + + let Some(semaphore) = self + .routing + .byte_budget + .get(&id) + .map(|entry| entry.value().clone()) + else { + return Err(RouteResult::NoConn); + }; + let permits = Self::route_data_permits(data.len()); + let permit = match timeout_ms { + Some(0) => semaphore + .try_acquire_many_owned(permits) + .map_err(|_| RouteResult::QueueFullHigh)?, + Some(timeout_ms) => { + let acquire = semaphore.acquire_many_owned(permits); + match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire) + .await + { + Ok(Ok(permit)) => permit, + Ok(Err(_)) => return Err(RouteResult::ChannelClosed), + Err(_) => return 
Err(RouteResult::QueueFullHigh), + } + } + None => semaphore + .acquire_many_owned(permits) + .await + .map_err(|_| RouteResult::ChannelClosed)?, + }; + + Ok(MeResponse::Data { + flags, + data, + route_permit: Some(RouteBytePermit::new(permit)), + }) + } + #[allow(dead_code)] pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); @@ -214,15 +316,23 @@ impl ConnRegistry { return RouteResult::NoConn; }; + let base_timeout_ms = self + .route_backpressure_base_timeout_ms + .load(Ordering::Relaxed) + .max(1); + let resp = match self + .attach_route_byte_permit(id, resp, Some(base_timeout_ms)) + .await + { + Ok(resp) => resp, + Err(result) => return result, + }; + match tx.try_send(resp) { Ok(()) => RouteResult::Routed, Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, Err(TrySendError::Full(resp)) => { // Absorb short bursts without dropping/closing the session immediately. - let base_timeout_ms = self - .route_backpressure_base_timeout_ms - .load(Ordering::Relaxed) - .max(1); let high_timeout_ms = self .route_backpressure_high_timeout_ms .load(Ordering::Relaxed) @@ -266,6 +376,10 @@ impl ConnRegistry { let Some(tx) = tx else { return RouteResult::NoConn; }; + let resp = match self.attach_route_byte_permit(id, resp, Some(0)).await { + Ok(resp) => resp, + Err(result) => return result, + }; match tx.try_send(resp) { Ok(()) => RouteResult::Routed, @@ -289,6 +403,13 @@ impl ConnRegistry { let Some(tx) = tx else { return RouteResult::NoConn; }; + let resp = match self + .attach_route_byte_permit(id, resp, Some(timeout_ms)) + .await + { + Ok(resp) => resp, + Err(result) => return result, + }; match tx.try_send(resp) { Ok(()) => RouteResult::Routed, @@ -541,8 +662,10 @@ impl ConnRegistry { mod tests { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - use super::ConnMeta; - use super::ConnRegistry; + use bytes::Bytes; + + use super::{ConnMeta, ConnRegistry, RouteResult}; + use 
crate::transport::middle_proxy::MeResponse; #[tokio::test] async fn writer_activity_snapshot_tracks_writer_and_dc_load() { @@ -608,6 +731,55 @@ mod tests { assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1)); } + #[tokio::test] + async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() { + let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1); + let (conn_id, mut rx) = registry.register().await; + let routed = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA]), + route_permit: None, + }, + ) + .await; + assert!(matches!(routed, RouteResult::Routed)); + + let blocked = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xBB]), + route_permit: None, + }, + ) + .await; + assert!( + matches!(blocked, RouteResult::QueueFullHigh), + "byte budget must reject data before count capacity is exhausted" + ); + + drop(rx.recv().await); + + let routed_after_drain = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xCC]), + route_permit: None, + }, + ) + .await; + assert!( + matches!(routed_after_drain, RouteResult::Routed), + "receiving queued data must release byte permits" + ); + } + #[tokio::test] async fn bind_writer_rebinds_conn_atomically() { let registry = ConnRegistry::new();