From ba1d9be5d4795ea660e643179788fb76cc14b9d0 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sun, 10 May 2026 13:22:54 +0300 Subject: [PATCH] Hardened Relays and API Security paths --- src/proxy/middle_relay.rs | 35 ++++++--- src/proxy/relay.rs | 145 ++++++++++++++++++++++++++------------ 2 files changed, 125 insertions(+), 55 deletions(-) diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index e4b4fe6..1865c84 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -65,6 +65,8 @@ const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; const QUOTA_RESERVE_BACKOFF_MIN_MS: u64 = 1; const QUOTA_RESERVE_BACKOFF_MAX_MS: u64 = 16; +const QUOTA_RESERVE_MAX_BACKOFF_ROUNDS: usize = 16; +const ME_CHILD_JOIN_TIMEOUT: Duration = Duration::from_secs(2); #[derive(Default)] pub(crate) struct DesyncDedupRotationState { @@ -624,6 +626,7 @@ async fn reserve_user_quota_with_yield( limit: u64, ) -> std::result::Result { let mut backoff_ms = QUOTA_RESERVE_BACKOFF_MIN_MS; + let mut backoff_rounds = 0usize; loop { for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match user_stats.quota_try_reserve(bytes, limit) { @@ -637,6 +640,10 @@ async fn reserve_user_quota_with_yield( tokio::task::yield_now().await; tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + backoff_rounds = backoff_rounds.saturating_add(1); + if backoff_rounds >= QUOTA_RESERVE_MAX_BACKOFF_ROUNDS { + return Err(QuotaReserveError::Contended); + } backoff_ms = backoff_ms .saturating_mul(2) .min(QUOTA_RESERVE_BACKOFF_MAX_MS); @@ -1169,7 +1176,7 @@ where let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget)); let (c2me_tx, mut c2me_rx) = mpsc::channel::(c2me_channel_capacity); let me_pool_c2me = me_pool.clone(); - let c2me_sender = tokio::spawn(async move { + let mut c2me_sender = tokio::spawn(async move { let mut sent_since_yield = 0usize; while let Some(cmd) = c2me_rx.recv().await { match cmd { @@ -1214,7 +1221,7 @@ where let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); - let me_writer = tokio::spawn(async move { + let mut me_writer = tokio::spawn(async move { let mut writer = crypto_writer; let mut frame_buf = Vec::with_capacity(16 * 1024); let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes; @@ -1729,14 +1736,26 @@ where } drop(c2me_tx); - let c2me_result = c2me_sender - .await - .unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME sender join error: {e}")))); + let c2me_result = match timeout(ME_CHILD_JOIN_TIMEOUT, &mut c2me_sender).await { + Ok(joined) => { + joined.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME sender join error: {e}")))) + } + Err(_) => { + c2me_sender.abort(); + Err(ProxyError::Proxy("ME sender join timeout".into())) + } + }; let _ = stop_tx.send(()); - let mut writer_result = me_writer - .await - .unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME writer join error: {e}")))); + let mut writer_result = match timeout(ME_CHILD_JOIN_TIMEOUT, &mut me_writer).await { + Ok(joined) => { + joined.unwrap_or_else(|e| Err(ProxyError::Proxy(format!("ME writer join error: {e}")))) + } + Err(_) => { + me_writer.abort(); + Err(ProxyError::Proxy("ME writer join timeout".into())) + } + }; // When client closes, but ME channel stopped as unregistered - it isnt error if client_closed diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 117d158..4d8b827 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -215,6 +215,7 @@ struct StatsIo { c2s_rate_debt_bytes: u64, c2s_wait: RateWaitState, s2c_wait: RateWaitState, + quota_wait: RateWaitState, quota_limit: Option, quota_exceeded: Arc, quota_bytes_since_check: u64, @@ -275,6 +276,7 @@ impl StatsIo { c2s_rate_debt_bytes: 0, c2s_wait: RateWaitState::default(), s2c_wait: RateWaitState::default(), + quota_wait: RateWaitState::default(), quota_limit, quota_exceeded, quota_bytes_since_check: 0, @@ -353,6 +355,11 @@ impl StatsIo { Poll::Ready(()) } + + fn arm_quota_wait(&mut self, cx: &mut Context<'_>) -> Poll<()> { + Self::arm_wait(&mut self.quota_wait, false, false); + Self::poll_wait(&mut self.quota_wait, cx, None, RateDirection::Up) + } } #[derive(Debug)] @@ -430,8 +437,13 @@ impl AsyncRead for StatsIo { if this.settle_c2s_rate_debt(cx).is_pending() { return Poll::Pending; } + if buf.remaining() == 0 { + return Pin::new(&mut this.inner).poll_read(cx, buf); + } let mut remaining_before = None; + let mut reserved_read_bytes = 0u64; + let mut read_limit = buf.remaining(); if let Some(limit) = this.quota_limit { let used_before = this.user_stats.quota_used(); let remaining = limit.saturating_sub(used_before); @@ -440,50 +452,77 @@ impl AsyncRead for StatsIo { return Poll::Ready(Err(quota_io_error())); } remaining_before = Some(remaining); + read_limit = read_limit.min(remaining as usize); + if read_limit == 0 { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + + let desired = read_limit as u64; + let mut reserve_rounds = 0usize; + while reserved_read_bytes == 0 { + for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { + match this.user_stats.quota_try_reserve(desired, limit) { + Ok(_) => { + reserved_read_bytes = desired; + break; + } + Err(crate::stats::QuotaReserveError::LimitExceeded) => { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + Err(crate::stats::QuotaReserveError::Contended) => {} + } + } + + if reserved_read_bytes == 0 { + reserve_rounds = reserve_rounds.saturating_add(1); + if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + if this.arm_quota_wait(cx).is_pending() { + return Poll::Pending; + } + reserve_rounds = 0; + } + } + } } - let before = buf.filled().len(); + let limited_read = read_limit < buf.remaining(); + let read_result = if limited_read { + let mut limited_buf = ReadBuf::new(buf.initialize_unfilled_to(read_limit)); + match Pin::new(&mut this.inner).poll_read(cx, &mut limited_buf) { + Poll::Ready(Ok(())) => { + let n = limited_buf.filled().len(); + buf.advance(n); + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } + } else { + let before = buf.filled().len(); + match Pin::new(&mut this.inner).poll_read(cx, buf) { + Poll::Ready(Ok(())) => { + let n = buf.filled().len() - before; + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } + }; - match Pin::new(&mut this.inner).poll_read(cx, buf) { - Poll::Ready(Ok(())) => { - let n = buf.filled().len() - before; + match read_result { + Poll::Ready(Ok(n)) => { + if reserved_read_bytes > n as u64 { + refund_reserved_quota_bytes( + this.user_stats.as_ref(), + reserved_read_bytes - n as u64, + ); + } if n > 0 { let n_to_charge = n as u64; - if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { - let mut reserved_total = None; - let mut reserve_rounds = 0usize; - while reserved_total.is_none() { - let mut saw_contention = false; - for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { - match this.user_stats.quota_try_reserve(n_to_charge, limit) { - Ok(total) => { - reserved_total = Some(total); - break; - } - Err(crate::stats::QuotaReserveError::LimitExceeded) => { - this.quota_exceeded.store(true, Ordering::Release); - buf.set_filled(before); - return Poll::Ready(Err(quota_io_error())); - } - Err(crate::stats::QuotaReserveError::Contended) => { - saw_contention = true; - } - } - } - if reserved_total.is_none() { - reserve_rounds = reserve_rounds.saturating_add(1); - if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { - this.quota_exceeded.store(true, Ordering::Release); - buf.set_filled(before); - return Poll::Ready(Err(quota_io_error())); - } - if saw_contention { - std::thread::yield_now(); - } - } - } - + if let Some(remaining) = remaining_before { if should_immediate_quota_check(remaining, n_to_charge) { this.quota_bytes_since_check = 0; } else { @@ -495,9 +534,11 @@ impl AsyncRead for StatsIo { } } - if reserved_total.unwrap_or(0) >= limit { - this.quota_exceeded.store(true, Ordering::Release); - } + } + if let Some(limit) = this.quota_limit + && this.user_stats.quota_used() >= limit + { + this.quota_exceeded.store(true, Ordering::Release); } // C→S: client sent data @@ -521,7 +562,18 @@ impl AsyncRead for StatsIo { } Poll::Ready(Ok(())) } - other => other, + Poll::Pending => { + if reserved_read_bytes > 0 { + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_read_bytes); + } + Poll::Pending + } + Poll::Ready(Err(err)) => { + if reserved_read_bytes > 0 { + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_read_bytes); + } + Poll::Ready(Err(err)) + } } } } @@ -614,11 +666,10 @@ impl AsyncWrite for StatsIo { if let Some(lease) = this.traffic_lease.as_ref() { lease.refund(RateDirection::Down, shaper_reserved_bytes); } - this.quota_exceeded.store(true, Ordering::Release); - return Poll::Ready(Err(quota_io_error())); - } - if saw_contention { - std::thread::yield_now(); + let _ = this.arm_quota_wait(cx); + return Poll::Pending; + } else if saw_contention { + std::hint::spin_loop(); } } }