ME Pool improvements

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
Author: Alexey, 2026-03-01 03:36:00 +03:00
parent 25ffcf6081
commit 44cdfd4b23
3 changed files with 249 additions and 37 deletions

File 1 of 3

@@ -26,6 +26,9 @@ enum C2MeCommand {
 const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
 const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
+const C2ME_CHANNEL_CAPACITY: usize = 1024;
+const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
+const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
 
 static DESYNC_DEDUP: OnceLock<Mutex<HashMap<u64, Instant>>> = OnceLock::new();
 
 struct RelayForensicsState {
@@ -166,6 +169,27 @@ fn report_desync_frame_too_large(
     ))
 }
 
+fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool {
+    has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
+}
+
+async fn enqueue_c2me_command(
+    tx: &mpsc::Sender<C2MeCommand>,
+    cmd: C2MeCommand,
+) -> std::result::Result<(), mpsc::error::SendError<C2MeCommand>> {
+    match tx.try_send(cmd) {
+        Ok(()) => Ok(()),
+        Err(mpsc::error::TrySendError::Closed(cmd)) => Err(mpsc::error::SendError(cmd)),
+        Err(mpsc::error::TrySendError::Full(cmd)) => {
+            // Cooperative yield reduces burst catch-up when the per-conn queue is near saturation.
+            if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
+                tokio::task::yield_now().await;
+            }
+            tx.send(cmd).await
+        }
+    }
+}
+
 pub(crate) async fn handle_via_middle_proxy<R, W>(
     mut crypto_reader: CryptoReader<R>,
     crypto_writer: CryptoWriter<W>,
@@ -230,9 +254,10 @@ where
     let frame_limit = config.general.max_client_frame;
 
-    let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(1024);
+    let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(C2ME_CHANNEL_CAPACITY);
     let me_pool_c2me = me_pool.clone();
 
     let c2me_sender = tokio::spawn(async move {
+        let mut sent_since_yield = 0usize;
         while let Some(cmd) = c2me_rx.recv().await {
             match cmd {
                 C2MeCommand::Data { payload, flags } => {
@@ -244,6 +269,11 @@ where
                         &payload,
                         flags,
                     ).await?;
+                    sent_since_yield = sent_since_yield.saturating_add(1);
+                    if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) {
+                        sent_since_yield = 0;
+                        tokio::task::yield_now().await;
+                    }
                 }
                 C2MeCommand::Close => {
                     let _ = me_pool_c2me.send_close(conn_id).await;
@@ -360,8 +390,7 @@ where
                     flags |= RPC_FLAG_NOT_ENCRYPTED;
                 }
                 // Keep client read loop lightweight: route heavy ME send path via a dedicated task.
-                if c2me_tx
-                    .send(C2MeCommand::Data { payload, flags })
+                if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags })
                     .await
                     .is_err()
                 {
@@ -372,7 +401,7 @@ where
             Ok(None) => {
                 debug!(conn_id, "Client EOF");
                 client_closed = true;
-                let _ = c2me_tx.send(C2MeCommand::Close).await;
+                let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await;
                 break;
             }
             Err(e) => {
@@ -647,3 +676,84 @@ where
     // ACK should remain low-latency.
     client_writer.flush().await.map_err(ProxyError::Io)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tokio::time::{Duration as TokioDuration, timeout};
+
+    #[test]
+    fn should_yield_sender_only_on_budget_with_backlog() {
+        assert!(!should_yield_c2me_sender(0, true));
+        assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
+        assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
+        assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
+    }
+
+    #[tokio::test]
+    async fn enqueue_c2me_command_uses_try_send_fast_path() {
+        let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
+        enqueue_c2me_command(
+            &tx,
+            C2MeCommand::Data {
+                payload: vec![1, 2, 3],
+                flags: 0,
+            },
+        )
+        .await
+        .unwrap();
+        let recv = timeout(TokioDuration::from_millis(50), rx.recv())
+            .await
+            .unwrap()
+            .unwrap();
+        match recv {
+            C2MeCommand::Data { payload, flags } => {
+                assert_eq!(payload, vec![1, 2, 3]);
+                assert_eq!(flags, 0);
+            }
+            C2MeCommand::Close => panic!("unexpected close command"),
+        }
+    }
+
+    #[tokio::test]
+    async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
+        let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
+        tx.send(C2MeCommand::Data {
+            payload: vec![9],
+            flags: 9,
+        })
+        .await
+        .unwrap();
+        let tx2 = tx.clone();
+        let producer = tokio::spawn(async move {
+            enqueue_c2me_command(
+                &tx2,
+                C2MeCommand::Data {
+                    payload: vec![7, 7],
+                    flags: 7,
+                },
+            )
+            .await
+            .unwrap();
+        });
+        let _ = timeout(TokioDuration::from_millis(100), rx.recv())
+            .await
+            .unwrap();
+        producer.await.unwrap();
+        let recv = timeout(TokioDuration::from_millis(100), rx.recv())
+            .await
+            .unwrap()
+            .unwrap();
+        match recv {
+            C2MeCommand::Data { payload, flags } => {
+                assert_eq!(payload, vec![7, 7]);
+                assert_eq!(flags, 7);
+            }
+            C2MeCommand::Close => panic!("unexpected close command"),
+        }
+    }
+}
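
Note on the sender loop above: the fairness budget bounds how many queued commands the drain task forwards before yielding back to the runtime, and it only yields while backlog remains, since an empty queue parks the task in recv().await anyway. A minimal sketch of the same loop shape, assuming only tokio (the constant and item type here are illustrative stand-ins, not the commit's):

    use tokio::sync::mpsc;

    const FAIRNESS_BUDGET: usize = 32; // illustrative stand-in for C2ME_SENDER_FAIRNESS_BUDGET

    async fn drain(mut rx: mpsc::Receiver<Vec<u8>>) {
        let mut sent_since_yield = 0usize;
        while let Some(item) = rx.recv().await {
            let _ = item; // forward `item` downstream here
            sent_since_yield += 1;
            // Yield only when more items are already queued, so one busy
            // connection cannot monopolize the worker between await points.
            if !rx.is_empty() && sent_since_yield >= FAIRNESS_BUDGET {
                sent_since_yield = 0;
                tokio::task::yield_now().await;
            }
        }
    }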

File 2 of 3

@@ -336,22 +336,35 @@ impl PendingCiphertext {
     }
 
     fn remaining_capacity(&self) -> usize {
-        self.max_len.saturating_sub(self.buf.len())
+        self.max_len.saturating_sub(self.pending_len())
+    }
+
+    fn compact_consumed_prefix(&mut self) {
+        if self.pos == 0 {
+            return;
+        }
+        if self.pos >= self.buf.len() {
+            self.buf.clear();
+            self.pos = 0;
+            return;
+        }
+        let _ = self.buf.split_to(self.pos);
+        self.pos = 0;
     }
 
     fn advance(&mut self, n: usize) {
         self.pos = (self.pos + n).min(self.buf.len());
         if self.pos == self.buf.len() {
-            self.buf.clear();
-            self.pos = 0;
+            self.compact_consumed_prefix();
             return;
         }
         // Compact when a large prefix was consumed.
         if self.pos >= 16 * 1024 {
-            let _ = self.buf.split_to(self.pos);
-            self.pos = 0;
+            self.compact_consumed_prefix();
         }
     }
@@ -379,6 +392,11 @@ impl PendingCiphertext {
             ));
         }
 
+        // Reclaim consumed prefix when physical storage is the only limiter.
+        if self.pos > 0 && self.buf.len() + plaintext.len() > self.max_len {
+            self.compact_consumed_prefix();
+        }
+
         let start = self.buf.len();
         self.buf.reserve(plaintext.len());
         self.buf.extend_from_slice(plaintext);
@@ -777,3 +795,70 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for PassthroughStream<S> {
         Pin::new(&mut self.inner).poll_shutdown(cx)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn test_ctr() -> AesCtr {
+        AesCtr::new(&[0x11; 32], 0x0102_0304_0506_0708_1112_1314_1516_1718)
+    }
+
+    #[test]
+    fn pending_capacity_reclaims_after_partial_advance_without_compaction_threshold() {
+        let mut pending = PendingCiphertext::new(1024);
+        let mut ctr = test_ctr();
+        let payload = vec![0x41; 900];
+        pending.push_encrypted(&mut ctr, &payload).unwrap();
+        // Keep position below compaction threshold to validate logical-capacity accounting.
+        pending.advance(800);
+        assert_eq!(pending.pending_len(), 100);
+        assert_eq!(pending.remaining_capacity(), 924);
+    }
+
+    #[test]
+    fn push_encrypted_respects_pending_limit() {
+        let mut pending = PendingCiphertext::new(64);
+        let mut ctr = test_ctr();
+        pending.push_encrypted(&mut ctr, &[0x10; 64]).unwrap();
+        let err = pending.push_encrypted(&mut ctr, &[0x20]).unwrap_err();
+        assert_eq!(err.kind(), ErrorKind::WouldBlock);
+    }
+
+    #[test]
+    fn push_encrypted_compacts_prefix_when_physical_buffer_would_overflow() {
+        let mut pending = PendingCiphertext::new(64);
+        let mut ctr = test_ctr();
+        pending.push_encrypted(&mut ctr, &[0x22; 60]).unwrap();
+        pending.advance(30);
+        pending.push_encrypted(&mut ctr, &[0x33; 30]).unwrap();
+        assert_eq!(pending.pending_len(), 60);
+        assert!(pending.buf.len() <= 64);
+    }
+
+    #[test]
+    fn pending_ciphertext_preserves_stream_order_across_drain_and_append() {
+        let mut pending = PendingCiphertext::new(128);
+        let mut ctr = test_ctr();
+        let first = vec![0xA1; 80];
+        let second = vec![0xB2; 40];
+        pending.push_encrypted(&mut ctr, &first).unwrap();
+        pending.advance(50);
+        pending.push_encrypted(&mut ctr, &second).unwrap();
+        let mut baseline_ctr = test_ctr();
+        let mut baseline_plain = Vec::with_capacity(first.len() + second.len());
+        baseline_plain.extend_from_slice(&first);
+        baseline_plain.extend_from_slice(&second);
+        baseline_ctr.apply(&mut baseline_plain);
+        let expected = &baseline_plain[50..];
+        assert_eq!(pending.pending_slice(), expected);
+    }
+}
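
Note on the capacity accounting above: remaining_capacity now charges only the logical backlog (bytes past pos) against max_len, while compact_consumed_prefix keeps physical storage bounded by dropping the already-flushed prefix. A reduced sketch of that split, assuming buf is a bytes::BytesMut as the split_to call suggests (field and type names here are illustrative):

    use bytes::BytesMut;

    struct Pending {
        buf: BytesMut, // physical storage: consumed prefix + unsent suffix
        pos: usize,    // bytes of `buf` already flushed downstream
        max_len: usize,
    }

    impl Pending {
        // Logical backlog: only the unsent suffix counts against the limit.
        fn pending_len(&self) -> usize {
            self.buf.len() - self.pos
        }

        fn remaining_capacity(&self) -> usize {
            self.max_len.saturating_sub(self.pending_len())
        }

        // Drop the flushed prefix so physical size tracks logical size again.
        fn compact_consumed_prefix(&mut self) {
            if self.pos > 0 {
                let _ = self.buf.split_to(self.pos);
                self.pos = 0;
            }
        }
    }

    fn main() {
        let mut p = Pending { buf: BytesMut::from(&b"hello world"[..]), pos: 6, max_len: 16 };
        assert_eq!(p.pending_len(), 5);         // "world" is still unsent
        assert_eq!(p.remaining_capacity(), 11); // the flushed prefix no longer counts
        p.compact_consumed_prefix();
        assert_eq!(p.buf.len(), 5);             // physical storage reclaimed
    }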

File 3 of 3

@ -1,8 +1,10 @@
use std::cmp::Reverse;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::Duration; use std::time::Duration;
use tokio::sync::mpsc::error::TrySendError;
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
@@ -43,15 +45,17 @@ impl MePool {
         loop {
             if let Some(current) = self.registry.get_writer(conn_id).await {
-                let send_res = {
-                    current
-                        .tx
-                        .send(WriterCommand::Data(payload.clone()))
-                        .await
-                };
-                match send_res {
+                match current.tx.try_send(WriterCommand::Data(payload.clone())) {
                     Ok(()) => return Ok(()),
-                    Err(_) => {
+                    Err(TrySendError::Full(cmd)) => {
+                        if current.tx.send(cmd).await.is_ok() {
+                            return Ok(());
+                        }
+                        warn!(writer_id = current.writer_id, "ME writer channel closed");
+                        self.remove_writer_and_close_clients(current.writer_id).await;
+                        continue;
+                    }
+                    Err(TrySendError::Closed(_)) => {
                         warn!(writer_id = current.writer_id, "ME writer channel closed");
                         self.remove_writer_and_close_clients(current.writer_id).await;
                         continue;
@@ -135,10 +139,11 @@ impl MePool {
             let w = &writers_snapshot[*idx];
             let degraded = w.degraded.load(Ordering::Relaxed);
             let stale = (w.generation < self.current_generation()) as usize;
-            (stale, degraded as usize)
+            (stale, degraded as usize, Reverse(w.tx.capacity()))
         });
 
         let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
+        let mut fallback_blocking_idx: Option<usize> = None;
 
         for offset in 0..candidate_indices.len() {
             let idx = candidate_indices[(start + offset) % candidate_indices.len()];
@@ -146,7 +151,8 @@ impl MePool {
             if !self.writer_accepts_new_binding(w) {
                 continue;
             }
-            if w.tx.send(WriterCommand::Data(payload.clone())).await.is_ok() {
+            match w.tx.try_send(WriterCommand::Data(payload.clone())) {
+                Ok(()) => {
                     self.registry
                         .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone())
                         .await;
@@ -161,14 +167,25 @@ impl MePool {
                         );
                     }
                     return Ok(());
-            } else {
+                }
+                Err(TrySendError::Full(_)) => {
+                    if fallback_blocking_idx.is_none() {
+                        fallback_blocking_idx = Some(idx);
+                    }
+                }
+                Err(TrySendError::Closed(_)) => {
                     warn!(writer_id = w.id, "ME writer channel closed");
                     self.remove_writer_and_close_clients(w.id).await;
                     continue;
+                }
             }
         }
 
-        let w = writers_snapshot[candidate_indices[start]].clone();
+        let Some(blocking_idx) = fallback_blocking_idx else {
+            continue;
+        };
+        let w = writers_snapshot[blocking_idx].clone();
         if !self.writer_accepts_new_binding(&w) {
             continue;
         }
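
Note on the selection key above: with Reverse(w.tx.capacity()) appended, candidates order fresh-before-stale, healthy-before-degraded, and then most-free-channel-slots-first, so try_send is attempted on the least-loaded writer before any blocking fallback. A reduced sketch of that ordering, with a hypothetical Writer type standing in for the pool's writer handles:

    use std::cmp::Reverse;

    // Hypothetical stand-in for a pooled writer entry.
    struct Writer {
        stale: bool,       // older generation than the pool's current one
        degraded: bool,    // health flag
        free_slots: usize, // what tokio's Sender::capacity() reports: free queue slots
    }

    fn order_candidates(indices: &mut [usize], writers: &[Writer]) {
        indices.sort_by_key(|&idx| {
            let w = &writers[idx];
            // false sorts before true; Reverse puts the largest capacity first.
            (w.stale as usize, w.degraded as usize, Reverse(w.free_slots))
        });
    }

    fn main() {
        let writers = [
            Writer { stale: false, degraded: false, free_slots: 3 },
            Writer { stale: true,  degraded: false, free_slots: 9 },
            Writer { stale: false, degraded: false, free_slots: 7 },
        ];
        let mut idx = [0, 1, 2];
        order_candidates(&mut idx, &writers);
        assert_eq!(idx, [2, 0, 1]); // most free slots wins among fresh, healthy writers
    }

The remembered fallback_blocking_idx then means the pool only blocks on a saturated writer when no candidate in round-robin order had a free slot.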