diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs
index f089442..a4942ba 100644
--- a/src/proxy/middle_relay.rs
+++ b/src/proxy/middle_relay.rs
@@ -26,6 +26,9 @@ enum C2MeCommand {
 const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
 const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
+const C2ME_CHANNEL_CAPACITY: usize = 1024;
+const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
+const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
 static DESYNC_DEDUP: OnceLock>> = OnceLock::new();
 
 struct RelayForensicsState {
@@ -166,6 +169,27 @@ fn report_desync_frame_too_large(
     ))
 }
 
+fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool {
+    has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
+}
+
+async fn enqueue_c2me_command(
+    tx: &mpsc::Sender<C2MeCommand>,
+    cmd: C2MeCommand,
+) -> std::result::Result<(), mpsc::error::SendError<C2MeCommand>> {
+    match tx.try_send(cmd) {
+        Ok(()) => Ok(()),
+        Err(mpsc::error::TrySendError::Closed(cmd)) => Err(mpsc::error::SendError(cmd)),
+        Err(mpsc::error::TrySendError::Full(cmd)) => {
+            // Cooperative yield reduces burst catch-up when the per-conn queue is near saturation.
+            if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
+                tokio::task::yield_now().await;
+            }
+            tx.send(cmd).await
+        }
+    }
+}
+
 pub(crate) async fn handle_via_middle_proxy(
     mut crypto_reader: CryptoReader,
     crypto_writer: CryptoWriter,
@@ -230,9 +254,10 @@ where
     let frame_limit = config.general.max_client_frame;
 
-    let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(1024);
+    let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(C2ME_CHANNEL_CAPACITY);
     let me_pool_c2me = me_pool.clone();
     let c2me_sender = tokio::spawn(async move {
+        let mut sent_since_yield = 0usize;
         while let Some(cmd) = c2me_rx.recv().await {
             match cmd {
                 C2MeCommand::Data { payload, flags } => {
@@ -244,6 +269,11 @@ where
                         &payload,
                         flags,
                     ).await?;
+                    sent_since_yield = sent_since_yield.saturating_add(1);
+                    if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) {
+                        sent_since_yield = 0;
+                        tokio::task::yield_now().await;
+                    }
                 }
                 C2MeCommand::Close => {
                     let _ = me_pool_c2me.send_close(conn_id).await;
                 }
@@ -360,8 +390,7 @@ where
                     flags |= RPC_FLAG_NOT_ENCRYPTED;
                 }
                 // Keep client read loop lightweight: route heavy ME send path via a dedicated task.
-                if c2me_tx
-                    .send(C2MeCommand::Data { payload, flags })
+                if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags })
                     .await
                     .is_err()
                 {
@@ -372,7 +401,7 @@
             Ok(None) => {
                 debug!(conn_id, "Client EOF");
                 client_closed = true;
-                let _ = c2me_tx.send(C2MeCommand::Close).await;
+                let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await;
                 break;
             }
             Err(e) => {
@@ -647,3 +676,84 @@ where
     // ACK should remain low-latency.
     client_writer.flush().await.map_err(ProxyError::Io)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tokio::time::{Duration as TokioDuration, timeout};
+
+    #[test]
+    fn should_yield_sender_only_on_budget_with_backlog() {
+        assert!(!should_yield_c2me_sender(0, true));
+        assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
+        assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
+        assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
+    }
+
+    #[tokio::test]
+    async fn enqueue_c2me_command_uses_try_send_fast_path() {
+        let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
+        enqueue_c2me_command(
+            &tx,
+            C2MeCommand::Data {
+                payload: vec![1, 2, 3],
+                flags: 0,
+            },
+        )
+        .await
+        .unwrap();
+
+        let recv = timeout(TokioDuration::from_millis(50), rx.recv())
+            .await
+            .unwrap()
+            .unwrap();
+        match recv {
+            C2MeCommand::Data { payload, flags } => {
+                assert_eq!(payload, vec![1, 2, 3]);
+                assert_eq!(flags, 0);
+            }
+            C2MeCommand::Close => panic!("unexpected close command"),
+        }
+    }
+
+    #[tokio::test]
+    async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
+        let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
+        tx.send(C2MeCommand::Data {
+            payload: vec![9],
+            flags: 9,
+        })
+        .await
+        .unwrap();
+
+        let tx2 = tx.clone();
+        let producer = tokio::spawn(async move {
+            enqueue_c2me_command(
+                &tx2,
+                C2MeCommand::Data {
+                    payload: vec![7, 7],
+                    flags: 7,
+                },
+            )
+            .await
+            .unwrap();
+        });
+
+        let _ = timeout(TokioDuration::from_millis(100), rx.recv())
+            .await
+            .unwrap();
+        producer.await.unwrap();
+
+        let recv = timeout(TokioDuration::from_millis(100), rx.recv())
+            .await
+            .unwrap()
+            .unwrap();
+        match recv {
+            C2MeCommand::Data { payload, flags } => {
+                assert_eq!(payload, vec![7, 7]);
+                assert_eq!(flags, 7);
+            }
+            C2MeCommand::Close => panic!("unexpected close command"),
+        }
+    }
+}
diff --git a/src/stream/crypto_stream.rs b/src/stream/crypto_stream.rs
index 5303fe5..744b186 100644
--- a/src/stream/crypto_stream.rs
+++ b/src/stream/crypto_stream.rs
@@ -336,22 +336,35 @@ impl PendingCiphertext {
     }
 
     fn remaining_capacity(&self) -> usize {
-        self.max_len.saturating_sub(self.buf.len())
+        self.max_len.saturating_sub(self.pending_len())
+    }
+
+    fn compact_consumed_prefix(&mut self) {
+        if self.pos == 0 {
+            return;
+        }
+
+        if self.pos >= self.buf.len() {
+            self.buf.clear();
+            self.pos = 0;
+            return;
+        }
+
+        let _ = self.buf.split_to(self.pos);
+        self.pos = 0;
     }
 
     fn advance(&mut self, n: usize) {
         self.pos = (self.pos + n).min(self.buf.len());
         if self.pos == self.buf.len() {
-            self.buf.clear();
-            self.pos = 0;
+            self.compact_consumed_prefix();
             return;
         }
         // Compact when a large prefix was consumed.
         if self.pos >= 16 * 1024 {
-            let _ = self.buf.split_to(self.pos);
-            self.pos = 0;
+            self.compact_consumed_prefix();
         }
     }
 
@@ -379,6 +392,11 @@ impl PendingCiphertext {
             ));
         }
 
+        // Reclaim consumed prefix when physical storage is the only limiter.
+        if self.pos > 0 && self.buf.len() + plaintext.len() > self.max_len {
+            self.compact_consumed_prefix();
+        }
+
         let start = self.buf.len();
         self.buf.reserve(plaintext.len());
         self.buf.extend_from_slice(plaintext);
@@ -777,3 +795,70 @@ impl AsyncWrite for PassthroughStream {
         Pin::new(&mut self.inner).poll_shutdown(cx)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn test_ctr() -> AesCtr {
+        AesCtr::new(&[0x11; 32], 0x0102_0304_0506_0708_1112_1314_1516_1718)
+    }
+
+    #[test]
+    fn pending_capacity_reclaims_after_partial_advance_without_compaction_threshold() {
+        let mut pending = PendingCiphertext::new(1024);
+        let mut ctr = test_ctr();
+        let payload = vec![0x41; 900];
+        pending.push_encrypted(&mut ctr, &payload).unwrap();
+
+        // Keep position below compaction threshold to validate logical-capacity accounting.
+        pending.advance(800);
+        assert_eq!(pending.pending_len(), 100);
+        assert_eq!(pending.remaining_capacity(), 924);
+    }
+
+    #[test]
+    fn push_encrypted_respects_pending_limit() {
+        let mut pending = PendingCiphertext::new(64);
+        let mut ctr = test_ctr();
+
+        pending.push_encrypted(&mut ctr, &[0x10; 64]).unwrap();
+        let err = pending.push_encrypted(&mut ctr, &[0x20]).unwrap_err();
+        assert_eq!(err.kind(), ErrorKind::WouldBlock);
+    }
+
+    #[test]
+    fn push_encrypted_compacts_prefix_when_physical_buffer_would_overflow() {
+        let mut pending = PendingCiphertext::new(64);
+        let mut ctr = test_ctr();
+
+        pending.push_encrypted(&mut ctr, &[0x22; 60]).unwrap();
+        pending.advance(30);
+        pending.push_encrypted(&mut ctr, &[0x33; 30]).unwrap();
+
+        assert_eq!(pending.pending_len(), 60);
+        assert!(pending.buf.len() <= 64);
+    }
+
+    #[test]
+    fn pending_ciphertext_preserves_stream_order_across_drain_and_append() {
+        let mut pending = PendingCiphertext::new(128);
+        let mut ctr = test_ctr();
+
+        let first = vec![0xA1; 80];
+        let second = vec![0xB2; 40];
+
+        pending.push_encrypted(&mut ctr, &first).unwrap();
+        pending.advance(50);
+        pending.push_encrypted(&mut ctr, &second).unwrap();
+
+        let mut baseline_ctr = test_ctr();
+        let mut baseline_plain = Vec::with_capacity(first.len() + second.len());
+        baseline_plain.extend_from_slice(&first);
+        baseline_plain.extend_from_slice(&second);
+        baseline_ctr.apply(&mut baseline_plain);
+
+        let expected = &baseline_plain[50..];
+        assert_eq!(pending.pending_slice(), expected);
+    }
+}
diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs
index 8867212..f68b1b9 100644
--- a/src/transport/middle_proxy/send.rs
+++ b/src/transport/middle_proxy/send.rs
@@ -1,8 +1,10 @@
+use std::cmp::Reverse;
 use std::net::SocketAddr;
 use std::sync::Arc;
 use std::sync::atomic::Ordering;
 use std::time::Duration;
 
+use tokio::sync::mpsc::error::TrySendError;
 use tracing::{debug, warn};
 
 use crate::error::{ProxyError, Result};
@@ -43,15 +45,17 @@ impl MePool {
         loop {
             if let Some(current) = self.registry.get_writer(conn_id).await {
-                let send_res = {
-                    current
-                        .tx
-                        .send(WriterCommand::Data(payload.clone()))
-                        .await
-                };
-                match send_res {
+                match current.tx.try_send(WriterCommand::Data(payload.clone())) {
                     Ok(()) => return Ok(()),
-                    Err(_) => {
+                    Err(TrySendError::Full(cmd)) => {
+                        if current.tx.send(cmd).await.is_ok() {
+                            return Ok(());
+                        }
+                        warn!(writer_id = current.writer_id, "ME writer channel closed");
+                        self.remove_writer_and_close_clients(current.writer_id).await;
+                        continue;
+                    }
+                    Err(TrySendError::Closed(_)) => {
                         warn!(writer_id = current.writer_id, "ME writer channel closed");
                         self.remove_writer_and_close_clients(current.writer_id).await;
                         continue;
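
Note on the try_send fast path in the hunk above: tokio's bounded mpsc hands the rejected value back inside `TrySendError::Full` / `TrySendError::Closed`, which is what lets the slow path reuse the command without a second clone of the payload. A minimal standalone sketch of the pattern, with a hypothetical `Command` type standing in for the crate's `WriterCommand`:

    use tokio::sync::mpsc::{self, error::TrySendError};

    enum Command {
        Data(Vec<u8>),
    }

    /// Non-blocking fast path first; awaited send (backpressure) only when the queue is full.
    async fn send_with_fast_path(tx: &mpsc::Sender<Command>, cmd: Command) -> bool {
        match tx.try_send(cmd) {
            Ok(()) => true,
            // Receiver dropped; the command comes back but has nowhere to go.
            Err(TrySendError::Closed(_)) => false,
            // Queue full: move the returned command into the awaited send.
            Err(TrySendError::Full(cmd)) => tx.send(cmd).await.is_ok(),
        }
    }
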
@@ -135,10 +139,11 @@ impl MePool {
             let w = &writers_snapshot[*idx];
             let degraded = w.degraded.load(Ordering::Relaxed);
             let stale = (w.generation < self.current_generation()) as usize;
-            (stale, degraded as usize)
+            (stale, degraded as usize, Reverse(w.tx.capacity()))
         });
 
         let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
+        let mut fallback_blocking_idx: Option<usize> = None;
 
         for offset in 0..candidate_indices.len() {
             let idx = candidate_indices[(start + offset) % candidate_indices.len()];
@@ -146,29 +151,41 @@ impl MePool {
             let w = &writers_snapshot[idx];
             if !self.writer_accepts_new_binding(w) {
                 continue;
             }
-            if w.tx.send(WriterCommand::Data(payload.clone())).await.is_ok() {
-                self.registry
-                    .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone())
-                    .await;
-                if w.generation < self.current_generation() {
-                    self.stats.increment_pool_stale_pick_total();
-                    debug!(
-                        conn_id,
-                        writer_id = w.id,
-                        writer_generation = w.generation,
-                        current_generation = self.current_generation(),
-                        "Selected stale ME writer for fallback bind"
-                    );
-                }
-                return Ok(());
-            } else {
-                warn!(writer_id = w.id, "ME writer channel closed");
-                self.remove_writer_and_close_clients(w.id).await;
-                continue;
+            match w.tx.try_send(WriterCommand::Data(payload.clone())) {
+                Ok(()) => {
+                    self.registry
+                        .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone())
+                        .await;
+                    if w.generation < self.current_generation() {
+                        self.stats.increment_pool_stale_pick_total();
+                        debug!(
+                            conn_id,
+                            writer_id = w.id,
+                            writer_generation = w.generation,
+                            current_generation = self.current_generation(),
+                            "Selected stale ME writer for fallback bind"
+                        );
+                    }
+                    return Ok(());
+                }
+                Err(TrySendError::Full(_)) => {
+                    if fallback_blocking_idx.is_none() {
+                        fallback_blocking_idx = Some(idx);
+                    }
+                }
+                Err(TrySendError::Closed(_)) => {
+                    warn!(writer_id = w.id, "ME writer channel closed");
+                    self.remove_writer_and_close_clients(w.id).await;
+                    continue;
+                }
             }
         }
 
-        let w = writers_snapshot[candidate_indices[start]].clone();
+        let Some(blocking_idx) = fallback_blocking_idx else {
+            continue;
+        };
+
+        let w = writers_snapshot[blocking_idx].clone();
         if !self.writer_accepts_new_binding(&w) {
             continue;
         }
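
Note on the new sort key in the hunk above: `sort_by_key` compares tuples lexicographically, so candidates order as fresh-generation before stale, healthy before degraded, and, within a tier, most free channel slots first (`Reverse` flips `capacity()` so a larger value sorts earlier). A self-contained sketch of that ordering, with a hypothetical `Writer` struct reduced to the three fields the key reads:

    use std::cmp::Reverse;

    struct Writer {
        stale: bool,
        degraded: bool,
        free_slots: usize, // stands in for tx.capacity()
    }

    fn main() {
        let writers = [
            Writer { stale: false, degraded: true, free_slots: 10 },
            Writer { stale: false, degraded: false, free_slots: 2 },
            Writer { stale: true, degraded: false, free_slots: 64 },
            Writer { stale: false, degraded: false, free_slots: 8 },
        ];
        let mut order: Vec<usize> = (0..writers.len()).collect();
        // Lexicographic tuple key: stale last, degraded last, most free slots first.
        order.sort_by_key(|&i| {
            let w = &writers[i];
            (w.stale as usize, w.degraded as usize, Reverse(w.free_slots))
        });
        assert_eq!(order, vec![3, 1, 0, 2]);
    }

`sort_by_key` is stable, so writers that tie on all three fields keep their snapshot order.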