From 4a77335ba94fa824ca9806cb09478902ecfc6a13 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 7 Apr 2026 19:19:40 +0300 Subject: [PATCH] Round-bounded Retries + Bounded Retry-Round Constant Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- .../tests/load_memory_envelope_tests.rs | 6 +- src/metrics.rs | 31 +++-- src/proxy/middle_relay.rs | 7 ++ src/proxy/relay.rs | 35 ++++-- src/tls_front/emulator.rs | 5 +- src/transport/middle_proxy/reader.rs | 109 +++++++++++++----- src/transport/middle_proxy/registry.rs | 22 ++-- 7 files changed, 154 insertions(+), 61 deletions(-) diff --git a/src/config/tests/load_memory_envelope_tests.rs b/src/config/tests/load_memory_envelope_tests.rs index b2d14fb..ea78498 100644 --- a/src/config/tests/load_memory_envelope_tests.rs +++ b/src/config/tests/load_memory_envelope_tests.rs @@ -26,7 +26,8 @@ me_writer_cmd_channel_capacity = 16385 "#, ); - let err = ProxyConfig::load(&path).expect_err("writer command capacity above hard cap must fail"); + let err = + ProxyConfig::load(&path).expect_err("writer command capacity above hard cap must fail"); let msg = err.to_string(); assert!( msg.contains("general.me_writer_cmd_channel_capacity must be within [1, 16384]"), @@ -45,7 +46,8 @@ me_route_channel_capacity = 8193 "#, ); - let err = ProxyConfig::load(&path).expect_err("route channel capacity above hard cap must fail"); + let err = + ProxyConfig::load(&path).expect_err("route channel capacity above hard cap must fail"); let msg = err.to_string(); assert!( msg.contains("general.me_route_channel_capacity must be within [1, 8192]"), diff --git a/src/metrics.rs b/src/metrics.rs index 685d2ef..1b920a8 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -196,7 +196,15 @@ async fn serve_listener( let ip_tracker = ip_tracker.clone(); let config = config_rx_conn.borrow().clone(); async move { - handle(req, &stats, &beobachten, &shared_state, &ip_tracker, &config).await + handle( + req, + &stats, + &beobachten, + &shared_state, + &ip_tracker, + &config, + ) + .await } }); if let Err(e) = http1::Builder::new() @@ -3145,9 +3153,16 @@ mod tests { stats.increment_connects_all(); let req = Request::builder().uri("/metrics").body(()).unwrap(); - let resp = handle(req, &stats, &beobachten, shared_state.as_ref(), &tracker, &config) - .await - .unwrap(); + let resp = handle( + req, + &stats, + &beobachten, + shared_state.as_ref(), + &tracker, + &config, + ) + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = resp.into_body().collect().await.unwrap().to_bytes(); assert!( @@ -3180,8 +3195,8 @@ mod tests { &tracker, &config, ) - .await - .unwrap(); + .await + .unwrap(); assert_eq!(resp_beob.status(), StatusCode::OK); let body_beob = resp_beob.into_body().collect().await.unwrap().to_bytes(); let beob_text = std::str::from_utf8(body_beob.as_ref()).unwrap(); @@ -3197,8 +3212,8 @@ mod tests { &tracker, &config, ) - .await - .unwrap(); + .await + .unwrap(); assert_eq!(resp404.status(), StatusCode::NOT_FOUND); } } diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 665e90e..eb68f83 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -56,6 +56,8 @@ const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; +const QUOTA_RESERVE_BACKOFF_MIN_MS: u64 = 1; +const QUOTA_RESERVE_BACKOFF_MAX_MS: u64 = 16; #[derive(Default)] pub(crate) struct DesyncDedupRotationState { @@ -573,6 +575,7 @@ async fn reserve_user_quota_with_yield( bytes: u64, limit: u64, ) -> std::result::Result { + let mut backoff_ms = QUOTA_RESERVE_BACKOFF_MIN_MS; loop { for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match user_stats.quota_try_reserve(bytes, limit) { @@ -585,6 +588,10 @@ async fn reserve_user_quota_with_yield( } tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + backoff_ms = backoff_ms + .saturating_mul(2) + .min(QUOTA_RESERVE_BACKOFF_MAX_MS); } } diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 9fd5f3d..f612cb1 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -271,6 +271,7 @@ const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024; const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024; const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; const QUOTA_RESERVE_SPIN_RETRIES: usize = 64; +const QUOTA_RESERVE_MAX_ROUNDS: usize = 8; #[inline] fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { @@ -319,6 +320,7 @@ impl AsyncRead for StatsIo { let mut reserved_total = None; let mut reserve_rounds = 0usize; while reserved_total.is_none() { + let mut saw_contention = false; for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match this.user_stats.quota_try_reserve(n_to_charge, limit) { Ok(total) => { @@ -331,15 +333,20 @@ impl AsyncRead for StatsIo { return Poll::Ready(Err(quota_io_error())); } Err(crate::stats::QuotaReserveError::Contended) => { - std::hint::spin_loop(); + saw_contention = true; } } } - reserve_rounds = reserve_rounds.saturating_add(1); - if reserved_total.is_none() && reserve_rounds >= 8 { - this.quota_exceeded.store(true, Ordering::Release); - buf.set_filled(before); - return Poll::Ready(Err(quota_io_error())); + if reserved_total.is_none() { + reserve_rounds = reserve_rounds.saturating_add(1); + if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + this.quota_exceeded.store(true, Ordering::Release); + buf.set_filled(before); + return Poll::Ready(Err(quota_io_error())); + } + if saw_contention { + std::thread::yield_now(); + } } } @@ -407,6 +414,7 @@ impl AsyncWrite for StatsIo { remaining_before = Some(remaining); let desired = remaining.min(buf.len() as u64); + let mut saw_contention = false; for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match this.user_stats.quota_try_reserve(desired, limit) { Ok(_) => { @@ -418,15 +426,20 @@ impl AsyncWrite for StatsIo { break; } Err(crate::stats::QuotaReserveError::Contended) => { - std::hint::spin_loop(); + saw_contention = true; } } } - reserve_rounds = reserve_rounds.saturating_add(1); - if reserved_bytes == 0 && reserve_rounds >= 8 { - this.quota_exceeded.store(true, Ordering::Release); - return Poll::Ready(Err(quota_io_error())); + if reserved_bytes == 0 { + reserve_rounds = reserve_rounds.saturating_add(1); + if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + if saw_contention { + std::thread::yield_now(); + } } } } else { diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index d6845a2..290e203 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -1,6 +1,5 @@ #![allow(clippy::too_many_arguments)] -use crc32fast::Hasher; use crate::crypto::{SecureRandom, sha256_hmac}; use crate::protocol::constants::{ MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, @@ -8,6 +7,7 @@ use crate::protocol::constants::{ }; use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key}; use crate::tls_front::types::{CachedTlsData, ParsedCertificateInfo, TlsProfileSource}; +use crc32fast::Hasher; const MIN_APP_DATA: usize = 64; const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE; @@ -343,7 +343,8 @@ mod tests { }; use super::{ - build_compact_cert_info_payload, build_emulated_server_hello, hash_compact_cert_info_payload, + build_compact_cert_info_payload, build_emulated_server_hello, + hash_compact_cert_info_payload, }; use crate::crypto::SecureRandom; use crate::protocol::constants::{ diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 9eaaa3f..dbfd9d7 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -24,15 +24,27 @@ use super::registry::RouteResult; use super::{ConnRegistry, MeResponse}; const DATA_ROUTE_MAX_ATTEMPTS: usize = 3; +const DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD: u8 = 3; fn should_close_on_route_result_for_data(result: RouteResult) -> bool { - !matches!(result, RouteResult::Routed) + matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed) } fn should_close_on_route_result_for_ack(result: RouteResult) -> bool { matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed) } +fn is_data_route_queue_full(result: RouteResult) -> bool { + matches!( + result, + RouteResult::QueueFullBase | RouteResult::QueueFullHigh + ) +} + +fn should_close_on_queue_full_streak(streak: u8) -> bool { + streak >= DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD +} + async fn route_data_with_retry( reg: &ConnRegistry, conn_id: u64, @@ -85,6 +97,7 @@ pub(crate) async fn reader_loop( ) -> Result<()> { let mut raw = enc_leftover; let mut expected_seq: i32 = 0; + let mut data_route_queue_full_streak = HashMap::::new(); loop { let mut tmp = [0u8; 65_536]; @@ -169,25 +182,39 @@ pub(crate) async fn reader_loop( trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); - let routed = route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await; - if should_close_on_route_result_for_data(routed) { - match routed { - RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), - RouteResult::ChannelClosed => { - stats.increment_me_route_drop_channel_closed() - } - RouteResult::QueueFullBase => { - stats.increment_me_route_drop_queue_full(); - stats.increment_me_route_drop_queue_full_base(); - } - RouteResult::QueueFullHigh => { - stats.increment_me_route_drop_queue_full(); - stats.increment_me_route_drop_queue_full_high(); - } - RouteResult::Routed => {} + let routed = + route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await; + if matches!(routed, RouteResult::Routed) { + data_route_queue_full_streak.remove(&cid); + continue; + } + match routed { + RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), + RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(), + RouteResult::QueueFullBase => { + stats.increment_me_route_drop_queue_full(); + stats.increment_me_route_drop_queue_full_base(); } + RouteResult::QueueFullHigh => { + stats.increment_me_route_drop_queue_full(); + stats.increment_me_route_drop_queue_full_high(); + } + RouteResult::Routed => {} + } + if should_close_on_route_result_for_data(routed) { + data_route_queue_full_streak.remove(&cid); reg.unregister(cid).await; send_close_conn(&tx, cid).await; + continue; + } + if is_data_route_queue_full(routed) { + let streak = data_route_queue_full_streak.entry(cid).or_insert(0); + *streak = streak.saturating_add(1); + if should_close_on_queue_full_streak(*streak) { + data_route_queue_full_streak.remove(&cid); + reg.unregister(cid).await; + send_close_conn(&tx, cid).await; + } } } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); @@ -221,11 +248,13 @@ pub(crate) async fn reader_loop( debug!(cid, "RPC_CLOSE_EXT from ME"); let _ = reg.route_nowait(cid, MeResponse::Close).await; reg.unregister(cid).await; + data_route_queue_full_streak.remove(&cid); } else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); debug!(cid, "RPC_CLOSE_CONN from ME"); let _ = reg.route_nowait(cid, MeResponse::Close).await; reg.unregister(cid).await; + data_route_queue_full_streak.remove(&cid); } else if pt == RPC_PING_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); trace!(ping_id, "RPC_PING -> RPC_PONG"); @@ -292,26 +321,50 @@ mod tests { use crate::transport::middle_proxy::ConnRegistry; use super::{ - MeResponse, RouteResult, route_data_with_retry, should_close_on_route_result_for_ack, + MeResponse, RouteResult, is_data_route_queue_full, route_data_with_retry, + should_close_on_queue_full_streak, should_close_on_route_result_for_ack, should_close_on_route_result_for_data, }; #[test] - fn data_route_failure_always_closes_session() { + fn data_route_only_fatal_results_close_immediately() { assert!(!should_close_on_route_result_for_data(RouteResult::Routed)); + assert!(!should_close_on_route_result_for_data( + RouteResult::QueueFullBase + )); + assert!(!should_close_on_route_result_for_data( + RouteResult::QueueFullHigh + )); assert!(should_close_on_route_result_for_data(RouteResult::NoConn)); - assert!(should_close_on_route_result_for_data(RouteResult::ChannelClosed)); - assert!(should_close_on_route_result_for_data(RouteResult::QueueFullBase)); - assert!(should_close_on_route_result_for_data(RouteResult::QueueFullHigh)); + assert!(should_close_on_route_result_for_data( + RouteResult::ChannelClosed + )); + } + + #[test] + fn data_route_queue_full_uses_starvation_threshold() { + assert!(is_data_route_queue_full(RouteResult::QueueFullBase)); + assert!(is_data_route_queue_full(RouteResult::QueueFullHigh)); + assert!(!is_data_route_queue_full(RouteResult::NoConn)); + assert!(!should_close_on_queue_full_streak(1)); + assert!(!should_close_on_queue_full_streak(2)); + assert!(should_close_on_queue_full_streak(3)); + assert!(should_close_on_queue_full_streak(u8::MAX)); } #[test] fn ack_queue_full_is_soft_dropped_without_forced_close() { assert!(!should_close_on_route_result_for_ack(RouteResult::Routed)); - assert!(!should_close_on_route_result_for_ack(RouteResult::QueueFullBase)); - assert!(!should_close_on_route_result_for_ack(RouteResult::QueueFullHigh)); + assert!(!should_close_on_route_result_for_ack( + RouteResult::QueueFullBase + )); + assert!(!should_close_on_route_result_for_ack( + RouteResult::QueueFullHigh + )); assert!(should_close_on_route_result_for_ack(RouteResult::NoConn)); - assert!(should_close_on_route_result_for_ack(RouteResult::ChannelClosed)); + assert!(should_close_on_route_result_for_ack( + RouteResult::ChannelClosed + )); } #[tokio::test] @@ -319,8 +372,7 @@ mod tests { let reg = ConnRegistry::with_route_channel_capacity(1); let (conn_id, mut rx) = reg.register().await; - let routed = - route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await; + let routed = route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await; assert!(matches!(routed, RouteResult::Routed)); match rx.recv().await { Some(MeResponse::Data { flags, data }) => { @@ -341,8 +393,7 @@ mod tests { RouteResult::Routed )); - let routed = - route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 0).await; + let routed = route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 0).await; assert!(matches!( routed, RouteResult::QueueFullBase | RouteResult::QueueFullHigh diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 17fce47..d8625f2 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -356,13 +356,9 @@ impl ConnRegistry { .entry(writer_id) .or_insert_with(HashSet::new) .insert(conn_id); - self.hot_binding.map.insert( - conn_id, - HotConnBinding { - writer_id, - meta, - }, - ); + self.hot_binding + .map + .insert(conn_id, HotConnBinding { writer_id, meta }); true } @@ -427,8 +423,16 @@ impl ConnRegistry { return None; } - let writer_id = self.hot_binding.map.get(&conn_id).map(|entry| entry.writer_id)?; - let writer = self.writers.map.get(&writer_id).map(|entry| entry.value().clone())?; + let writer_id = self + .hot_binding + .map + .get(&conn_id) + .map(|entry| entry.writer_id)?; + let writer = self + .writers + .map + .get(&writer_id) + .map(|entry| entry.value().clone())?; Some(ConnWriter { writer_id, tx: writer,