diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index aec55cd..9eaaa3f 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -23,6 +23,48 @@ use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc}; use super::registry::RouteResult; use super::{ConnRegistry, MeResponse}; +const DATA_ROUTE_MAX_ATTEMPTS: usize = 3; + +fn should_close_on_route_result_for_data(result: RouteResult) -> bool { + !matches!(result, RouteResult::Routed) +} + +fn should_close_on_route_result_for_ack(result: RouteResult) -> bool { + matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed) +} + +async fn route_data_with_retry( + reg: &ConnRegistry, + conn_id: u64, + flags: u32, + data: Bytes, + timeout_ms: u64, +) -> RouteResult { + let mut attempt = 0usize; + loop { + let routed = reg + .route_with_timeout( + conn_id, + MeResponse::Data { + flags, + data: data.clone(), + }, + timeout_ms, + ) + .await; + match routed { + RouteResult::QueueFullBase | RouteResult::QueueFullHigh => { + attempt = attempt.saturating_add(1); + if attempt >= DATA_ROUTE_MAX_ATTEMPTS { + return routed; + } + tokio::task::yield_now().await; + } + _ => return routed, + } + } +} + pub(crate) async fn reader_loop( mut rd: tokio::io::ReadHalf, dk: [u8; 32], @@ -127,10 +169,8 @@ pub(crate) async fn reader_loop( trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); - let routed = reg - .route_with_timeout(cid, MeResponse::Data { flags, data }, route_wait_ms) - .await; - if !matches!(routed, RouteResult::Routed) { + let routed = route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await; + if should_close_on_route_result_for_data(routed) { match routed { RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), RouteResult::ChannelClosed => { @@ -171,8 +211,10 @@ pub(crate) async fn reader_loop( } RouteResult::Routed => {} } - reg.unregister(cid).await; - send_close_conn(&tx, cid).await; + if should_close_on_route_result_for_ack(routed) { + reg.unregister(cid).await; + send_close_conn(&tx, cid).await; + } } } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); @@ -243,6 +285,71 @@ pub(crate) async fn reader_loop( } } +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use crate::transport::middle_proxy::ConnRegistry; + + use super::{ + MeResponse, RouteResult, route_data_with_retry, should_close_on_route_result_for_ack, + should_close_on_route_result_for_data, + }; + + #[test] + fn data_route_failure_always_closes_session() { + assert!(!should_close_on_route_result_for_data(RouteResult::Routed)); + assert!(should_close_on_route_result_for_data(RouteResult::NoConn)); + assert!(should_close_on_route_result_for_data(RouteResult::ChannelClosed)); + assert!(should_close_on_route_result_for_data(RouteResult::QueueFullBase)); + assert!(should_close_on_route_result_for_data(RouteResult::QueueFullHigh)); + } + + #[test] + fn ack_queue_full_is_soft_dropped_without_forced_close() { + assert!(!should_close_on_route_result_for_ack(RouteResult::Routed)); + assert!(!should_close_on_route_result_for_ack(RouteResult::QueueFullBase)); + assert!(!should_close_on_route_result_for_ack(RouteResult::QueueFullHigh)); + assert!(should_close_on_route_result_for_ack(RouteResult::NoConn)); + assert!(should_close_on_route_result_for_ack(RouteResult::ChannelClosed)); + } + + #[tokio::test] + async fn route_data_with_retry_returns_routed_when_channel_has_capacity() { + let reg = ConnRegistry::with_route_channel_capacity(1); + let (conn_id, mut rx) = reg.register().await; + + let routed = + route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await; + assert!(matches!(routed, RouteResult::Routed)); + match rx.recv().await { + Some(MeResponse::Data { flags, data }) => { + assert_eq!(flags, 0); + assert_eq!(data, Bytes::from_static(b"a")); + } + other => panic!("expected routed data response, got {other:?}"), + } + } + + #[tokio::test] + async fn route_data_with_retry_stops_after_bounded_attempts() { + let reg = ConnRegistry::with_route_channel_capacity(1); + let (conn_id, _rx) = reg.register().await; + + assert!(matches!( + reg.route_nowait(conn_id, MeResponse::Ack(1)).await, + RouteResult::Routed + )); + + let routed = + route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 0).await; + assert!(matches!( + routed, + RouteResult::QueueFullBase | RouteResult::QueueFullHigh + )); + } +} + async fn send_close_conn(tx: &mpsc::Sender, conn_id: u64) { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());