From 873618ce53b3b0f3800d1f53d37dfeb28ecd2492 Mon Sep 17 00:00:00 2001 From: mammuthus Date: Thu, 2 Apr 2026 18:02:07 +0000 Subject: [PATCH 1/6] metrics: export telemt_build_info version metric --- src/metrics.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/metrics.rs b/src/metrics.rs index 3a88a5b..eba5f35 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -234,6 +234,14 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp let me_allows_normal = telemetry.me_level.allows_normal(); let me_allows_debug = telemetry.me_level.allows_debug(); + let _ = writeln!(out, "# HELP telemt_build_info Build information for the running telemt binary"); + let _ = writeln!(out, "# TYPE telemt_build_info gauge"); + let _ = writeln!( + out, + "telemt_build_info{{version=\"{}\"}} 1", + env!("CARGO_PKG_VERSION") + ); + let _ = writeln!(out, "# HELP telemt_uptime_seconds Proxy uptime"); let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge"); let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs()); @@ -2679,6 +2687,10 @@ mod tests { let output = render_metrics(&stats, &config, &tracker).await; + assert!(output.contains(&format!( + "telemt_build_info{{version=\"{}\"}} 1", + env!("CARGO_PKG_VERSION") + ))); assert!(output.contains("telemt_connections_total 2")); assert!(output.contains("telemt_connections_bad_total 1")); assert!(output.contains("telemt_handshake_timeouts_total 1")); @@ -2768,6 +2780,7 @@ mod tests { let tracker = UserIpTracker::new(); let config = ProxyConfig::default(); let output = render_metrics(&stats, &config, &tracker).await; + assert!(output.contains("# TYPE telemt_build_info gauge")); assert!(output.contains("# TYPE telemt_uptime_seconds gauge")); assert!(output.contains("# TYPE telemt_connections_total counter")); assert!(output.contains("# TYPE telemt_connections_bad_total counter")); @@ -2822,6 +2835,10 @@ mod tests { .unwrap() .contains("telemt_connections_total 3") ); + 
assert!(std::str::from_utf8(body.as_ref()).unwrap().contains(&format!( + "telemt_build_info{{version=\"{}\"}} 1", + env!("CARGO_PKG_VERSION") + ))); config.general.beobachten = true; config.general.beobachten_minutes = 10; From 9b64d2ee177fa206579ae6204ae9e3c47006c86f Mon Sep 17 00:00:00 2001 From: mammuthus Date: Fri, 3 Apr 2026 07:49:21 +0000 Subject: [PATCH 2/6] style(metrics): apply rustfmt for build_info additions --- src/metrics.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/metrics.rs b/src/metrics.rs index eba5f35..56b6558 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -234,7 +234,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp let me_allows_normal = telemetry.me_level.allows_normal(); let me_allows_debug = telemetry.me_level.allows_debug(); - let _ = writeln!(out, "# HELP telemt_build_info Build information for the running telemt binary"); + let _ = writeln!( + out, + "# HELP telemt_build_info Build information for the running telemt binary" + ); let _ = writeln!(out, "# TYPE telemt_build_info gauge"); let _ = writeln!( out, @@ -2835,10 +2838,14 @@ mod tests { .unwrap() .contains("telemt_connections_total 3") ); - assert!(std::str::from_utf8(body.as_ref()).unwrap().contains(&format!( - "telemt_build_info{{version=\"{}\"}} 1", - env!("CARGO_PKG_VERSION") - ))); + assert!( + std::str::from_utf8(body.as_ref()) + .unwrap() + .contains(&format!( + "telemt_build_info{{version=\"{}\"}} 1", + env!("CARGO_PKG_VERSION") + )) + ); config.general.beobachten = true; config.general.beobachten_minutes = 10; From bc3ad02a20d2c6130bf4bfdc82fd9e2c89f75a76 Mon Sep 17 00:00:00 2001 From: Ivan <84094482+JetJava@users.noreply.github.com> Date: Tue, 7 Apr 2026 02:10:08 +0400 Subject: [PATCH 3/6] tls_front/emulator: hash compact cert info payload before TLS emulation --- src/tls_front/emulator.rs | 55 ++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 7 deletions(-) 
diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index 80f2b1b..d6845a2 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -1,5 +1,6 @@ #![allow(clippy::too_many_arguments)] +use crc32fast::Hasher; use crate::crypto::{SecureRandom, sha256_hmac}; use crate::protocol::constants::{ MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, @@ -98,6 +99,31 @@ fn build_compact_cert_info_payload(cert_info: &ParsedCertificateInfo) -> Option< Some(payload) } +fn hash_compact_cert_info_payload(cert_payload: Vec<u8>) -> Option<Vec<u8>> { + if cert_payload.is_empty() { + return None; + } + + let mut hashed = Vec::with_capacity(cert_payload.len()); + let mut seed_hasher = Hasher::new(); + seed_hasher.update(&cert_payload); + let mut state = seed_hasher.finalize(); + + while hashed.len() < cert_payload.len() { + let mut hasher = Hasher::new(); + hasher.update(&state.to_le_bytes()); + hasher.update(&cert_payload); + state = hasher.finalize(); + + let block = state.to_le_bytes(); + let remaining = cert_payload.len() - hashed.len(); + let copy_len = remaining.min(block.len()); + hashed.extend_from_slice(&block[..copy_len]); + } + + Some(hashed) +} + /// Build a ServerHello + CCS + ApplicationData sequence using cached TLS metadata. 
pub fn build_emulated_server_hello( secret: &[u8], @@ -190,7 +216,8 @@ pub fn build_emulated_server_hello( let compact_payload = cached .cert_info .as_ref() - .and_then(build_compact_cert_info_payload); + .and_then(build_compact_cert_info_payload) + .and_then(hash_compact_cert_info_payload); let selected_payload: Option<&[u8]> = if use_full_cert_payload { cached .cert_payload @@ -221,7 +248,6 @@ pub fn build_emulated_server_hello( marker.extend_from_slice(proto); marker }); - let mut payload_offset = 0usize; for (idx, size) in sizes.into_iter().enumerate() { let mut rec = Vec::with_capacity(5 + size); rec.push(TLS_RECORD_APPLICATION); @@ -231,11 +257,10 @@ pub fn build_emulated_server_hello( if let Some(payload) = selected_payload { if size > 17 { let body_len = size - 17; - let remaining = payload.len().saturating_sub(payload_offset); + let remaining = payload.len(); let copy_len = remaining.min(body_len); if copy_len > 0 { - rec.extend_from_slice(&payload[payload_offset..payload_offset + copy_len]); - payload_offset += copy_len; + rec.extend_from_slice(&payload[..copy_len]); } if body_len > copy_len { rec.extend_from_slice(&rng.bytes(body_len - copy_len)); @@ -317,7 +342,9 @@ mod tests { CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource, }; - use super::build_emulated_server_hello; + use super::{ + build_compact_cert_info_payload, build_emulated_server_hello, hash_compact_cert_info_payload, + }; use crate::crypto::SecureRandom; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, @@ -432,7 +459,21 @@ mod tests { ); let payload = first_app_data_payload(&response); - assert!(payload.starts_with(b"CN=example.com")); + let expected_hashed_payload = build_compact_cert_info_payload( + cached + .cert_info + .as_ref() + .expect("test fixture must provide certificate info"), + ) + .and_then(hash_compact_cert_info_payload) + .expect("compact certificate info payload must be present 
for this test"); + let copied_prefix_len = expected_hashed_payload + .len() + .min(payload.len().saturating_sub(17)); + assert_eq!( + &payload[..copied_prefix_len], + &expected_hashed_payload[..copied_prefix_len] + ); } #[test] From 3b717c75dae2a6e009260c422272102c83e39c8f Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:18:47 +0300 Subject: [PATCH 4/6] Memory Hard-bounds + Handshake Budget in Metrics + No mutable in hotpath ConnRegistry Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/load.rs | 32 +++++ .../tests/load_memory_envelope_tests.rs | 115 ++++++++++++++++ src/maestro/mod.rs | 1 + src/maestro/runtime_tasks.rs | 4 + src/metrics.rs | 128 ++++++++++++++++-- src/transport/middle_proxy/registry.rs | 83 +++++++----- 6 files changed, 317 insertions(+), 46 deletions(-) create mode 100644 src/config/tests/load_memory_envelope_tests.rs diff --git a/src/config/load.rs b/src/config/load.rs index 32f877d..f9e230c 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -17,6 +17,11 @@ use super::defaults::*; use super::types::*; const ACCESS_SECRET_BYTES: usize = 16; +const MAX_ME_WRITER_CMD_CHANNEL_CAPACITY: usize = 16_384; +const MAX_ME_ROUTE_CHANNEL_CAPACITY: usize = 8_192; +const MAX_ME_C2ME_CHANNEL_CAPACITY: usize = 8_192; +const MIN_MAX_CLIENT_FRAME_BYTES: usize = 4 * 1024; +const MAX_MAX_CLIENT_FRAME_BYTES: usize = 16 * 1024 * 1024; #[derive(Debug, Clone)] pub(crate) struct LoadedConfig { @@ -626,18 +631,41 @@ impl ProxyConfig { "general.me_writer_cmd_channel_capacity must be > 0".to_string(), )); } + if config.general.me_writer_cmd_channel_capacity > MAX_ME_WRITER_CMD_CHANNEL_CAPACITY { + return Err(ProxyError::Config(format!( + "general.me_writer_cmd_channel_capacity must be within [1, {MAX_ME_WRITER_CMD_CHANNEL_CAPACITY}]" + ))); + } if config.general.me_route_channel_capacity == 0 { return Err(ProxyError::Config( "general.me_route_channel_capacity must be > 
0".to_string(), )); } + if config.general.me_route_channel_capacity > MAX_ME_ROUTE_CHANNEL_CAPACITY { + return Err(ProxyError::Config(format!( + "general.me_route_channel_capacity must be within [1, {MAX_ME_ROUTE_CHANNEL_CAPACITY}]" + ))); + } if config.general.me_c2me_channel_capacity == 0 { return Err(ProxyError::Config( "general.me_c2me_channel_capacity must be > 0".to_string(), )); } + if config.general.me_c2me_channel_capacity > MAX_ME_C2ME_CHANNEL_CAPACITY { + return Err(ProxyError::Config(format!( + "general.me_c2me_channel_capacity must be within [1, {MAX_ME_C2ME_CHANNEL_CAPACITY}]" + ))); + } + + if !(MIN_MAX_CLIENT_FRAME_BYTES..=MAX_MAX_CLIENT_FRAME_BYTES) + .contains(&config.general.max_client_frame) + { + return Err(ProxyError::Config(format!( + "general.max_client_frame must be within [{MIN_MAX_CLIENT_FRAME_BYTES}, {MAX_MAX_CLIENT_FRAME_BYTES}]" + ))); + } if config.general.me_c2me_send_timeout_ms > 60_000 { return Err(ProxyError::Config( @@ -1346,6 +1374,10 @@ mod load_mask_shape_security_tests; #[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"] mod load_mask_classifier_prefetch_timeout_security_tests; +#[cfg(test)] +#[path = "tests/load_memory_envelope_tests.rs"] +mod load_memory_envelope_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/src/config/tests/load_memory_envelope_tests.rs b/src/config/tests/load_memory_envelope_tests.rs new file mode 100644 index 0000000..b2d14fb --- /dev/null +++ b/src/config/tests/load_memory_envelope_tests.rs @@ -0,0 +1,115 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-load-memory-envelope-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn 
remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_writer_cmd_capacity_above_upper_bound() { + let path = write_temp_config( + r#" +[general] +me_writer_cmd_channel_capacity = 16385 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("writer command capacity above hard cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("general.me_writer_cmd_channel_capacity must be within [1, 16384]"), + "error must explain writer command capacity hard cap, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_route_channel_capacity_above_upper_bound() { + let path = write_temp_config( + r#" +[general] +me_route_channel_capacity = 8193 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("route channel capacity above hard cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("general.me_route_channel_capacity must be within [1, 8192]"), + "error must explain route channel hard cap, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_c2me_channel_capacity_above_upper_bound() { + let path = write_temp_config( + r#" +[general] +me_c2me_channel_capacity = 8193 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("c2me channel capacity above hard cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("general.me_c2me_channel_capacity must be within [1, 8192]"), + "error must explain c2me channel hard cap, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_max_client_frame_above_upper_bound() { + let path = write_temp_config( + r#" +[general] +max_client_frame = 16777217 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("max_client_frame above hard cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("general.max_client_frame must be within [4096, 16777216]"), + "error must explain max_client_frame hard cap, got: {msg}" + ); + + 
remove_temp_config(&path); +} + +#[test] +fn load_accepts_memory_limits_at_hard_upper_bounds() { + let path = write_temp_config( + r#" +[general] +me_writer_cmd_channel_capacity = 16384 +me_route_channel_capacity = 8192 +me_c2me_channel_capacity = 8192 +max_client_frame = 16777216 +"#, + ); + + let cfg = ProxyConfig::load(&path).expect("hard upper bound values must be accepted"); + assert_eq!(cfg.general.me_writer_cmd_channel_capacity, 16384); + assert_eq!(cfg.general.me_route_channel_capacity, 8192); + assert_eq!(cfg.general.me_c2me_channel_capacity, 8192); + assert_eq!(cfg.general.max_client_frame, 16 * 1024 * 1024); + + remove_temp_config(&path); +} diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index eed8d2e..00b3b2d 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -786,6 +786,7 @@ async fn run_inner( &startup_tracker, stats.clone(), beobachten.clone(), + shared_state.clone(), ip_tracker.clone(), config_rx.clone(), ) diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index b8b10da..5b3f2e0 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -13,6 +13,7 @@ use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::metrics; use crate::network::probe::NetworkProbe; +use crate::proxy::shared_state::ProxySharedState; use crate::startup::{ COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, StartupTracker, @@ -287,6 +288,7 @@ pub(crate) async fn spawn_metrics_if_configured( startup_tracker: &Arc, stats: Arc, beobachten: Arc, + shared_state: Arc, ip_tracker: Arc, config_rx: watch::Receiver>, ) { @@ -320,6 +322,7 @@ pub(crate) async fn spawn_metrics_if_configured( .await; let stats = stats.clone(); let beobachten = beobachten.clone(); + let shared_state = shared_state.clone(); let config_rx_metrics = config_rx.clone(); let ip_tracker_metrics = ip_tracker.clone(); let whitelist = config.server.metrics_whitelist.clone(); @@ -331,6 +334,7 @@ pub(crate) 
async fn spawn_metrics_if_configured( listen_backlog, stats, beobachten, + shared_state, ip_tracker_metrics, config_rx_metrics, whitelist, diff --git a/src/metrics.rs b/src/metrics.rs index 5cb1e77..7130e28 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -15,6 +15,7 @@ use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::ip_tracker::UserIpTracker; +use crate::proxy::shared_state::ProxySharedState; use crate::stats::Stats; use crate::stats::beobachten::BeobachtenStore; use crate::transport::{ListenOptions, create_listener}; @@ -25,6 +26,7 @@ pub async fn serve( listen_backlog: u32, stats: Arc, beobachten: Arc, + shared_state: Arc, ip_tracker: Arc, config_rx: tokio::sync::watch::Receiver>, whitelist: Vec, @@ -45,7 +47,13 @@ pub async fn serve( Ok(listener) => { info!("Metrics endpoint: http://{}/metrics and /beobachten", addr); serve_listener( - listener, stats, beobachten, ip_tracker, config_rx, whitelist, + listener, + stats, + beobachten, + shared_state, + ip_tracker, + config_rx, + whitelist, ) .await; } @@ -94,13 +102,20 @@ pub async fn serve( } (Some(listener), None) | (None, Some(listener)) => { serve_listener( - listener, stats, beobachten, ip_tracker, config_rx, whitelist, + listener, + stats, + beobachten, + shared_state, + ip_tracker, + config_rx, + whitelist, ) .await; } (Some(listener4), Some(listener6)) => { let stats_v6 = stats.clone(); let beobachten_v6 = beobachten.clone(); + let shared_state_v6 = shared_state.clone(); let ip_tracker_v6 = ip_tracker.clone(); let config_rx_v6 = config_rx.clone(); let whitelist_v6 = whitelist.clone(); @@ -109,6 +124,7 @@ pub async fn serve( listener6, stats_v6, beobachten_v6, + shared_state_v6, ip_tracker_v6, config_rx_v6, whitelist_v6, @@ -116,7 +132,13 @@ pub async fn serve( .await; }); serve_listener( - listener4, stats, beobachten, ip_tracker, config_rx, whitelist, + listener4, + stats, + beobachten, + shared_state, + ip_tracker, + config_rx, + whitelist, ) .await; } @@ -142,6 +164,7 @@ 
async fn serve_listener( listener: TcpListener, stats: Arc, beobachten: Arc, + shared_state: Arc, ip_tracker: Arc, config_rx: tokio::sync::watch::Receiver>, whitelist: Arc>, @@ -162,15 +185,19 @@ async fn serve_listener( let stats = stats.clone(); let beobachten = beobachten.clone(); + let shared_state = shared_state.clone(); let ip_tracker = ip_tracker.clone(); let config_rx_conn = config_rx.clone(); tokio::spawn(async move { let svc = service_fn(move |req| { let stats = stats.clone(); let beobachten = beobachten.clone(); + let shared_state = shared_state.clone(); let ip_tracker = ip_tracker.clone(); let config = config_rx_conn.borrow().clone(); - async move { handle(req, &stats, &beobachten, &ip_tracker, &config).await } + async move { + handle(req, &stats, &beobachten, &shared_state, &ip_tracker, &config).await + } }); if let Err(e) = http1::Builder::new() .serve_connection(hyper_util::rt::TokioIo::new(stream), svc) @@ -186,11 +213,12 @@ async fn handle( req: Request, stats: &Stats, beobachten: &BeobachtenStore, + shared_state: &ProxySharedState, ip_tracker: &UserIpTracker, config: &ProxyConfig, ) -> Result>, Infallible> { if req.uri().path() == "/metrics" { - let body = render_metrics(stats, config, ip_tracker).await; + let body = render_metrics(stats, shared_state, config, ip_tracker).await; let resp = Response::builder() .status(StatusCode::OK) .header("content-type", "text/plain; version=0.0.4; charset=utf-8") @@ -225,7 +253,12 @@ fn render_beobachten(beobachten: &BeobachtenStore, config: &ProxyConfig) -> Stri beobachten.snapshot_text(ttl) } -async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIpTracker) -> String { +async fn render_metrics( + stats: &Stats, + shared_state: &ProxySharedState, + config: &ProxyConfig, + ip_tracker: &UserIpTracker, +) -> String { use std::fmt::Write; let mut out = String::with_capacity(4096); let telemetry = stats.telemetry_policy(); @@ -359,6 +392,42 @@ async fn render_metrics(stats: &Stats, config: 
&ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!( + out, + "# HELP telemt_auth_expensive_checks_total Expensive authentication candidate checks executed during handshake validation" + ); + let _ = writeln!(out, "# TYPE telemt_auth_expensive_checks_total counter"); + let _ = writeln!( + out, + "telemt_auth_expensive_checks_total {}", + if core_enabled { + shared_state + .handshake + .auth_expensive_checks_total + .load(std::sync::atomic::Ordering::Relaxed) + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_auth_budget_exhausted_total Handshake validations that hit authentication candidate budget limits" + ); + let _ = writeln!(out, "# TYPE telemt_auth_budget_exhausted_total counter"); + let _ = writeln!( + out, + "telemt_auth_budget_exhausted_total {}", + if core_enabled { + shared_state + .handshake + .auth_budget_exhausted_total + .load(std::sync::atomic::Ordering::Relaxed) + } else { + 0 + } + ); + let _ = writeln!( out, "# HELP telemt_accept_permit_timeout_total Accepted connections dropped due to permit wait timeout" @@ -2847,6 +2916,7 @@ mod tests { #[tokio::test] async fn test_render_metrics_format() { let stats = Arc::new(Stats::new()); + let shared_state = ProxySharedState::new(); let tracker = UserIpTracker::new(); let mut config = ProxyConfig::default(); config @@ -2858,6 +2928,14 @@ mod tests { stats.increment_connects_all(); stats.increment_connects_bad(); stats.increment_handshake_timeouts(); + shared_state + .handshake + .auth_expensive_checks_total + .fetch_add(9, std::sync::atomic::Ordering::Relaxed); + shared_state + .handshake + .auth_budget_exhausted_total + .fetch_add(2, std::sync::atomic::Ordering::Relaxed); stats.increment_upstream_connect_attempt_total(); stats.increment_upstream_connect_attempt_total(); stats.increment_upstream_connect_success_total(); @@ -2901,11 +2979,13 @@ mod tests { .await .unwrap(); - let output = render_metrics(&stats, &config, &tracker).await; + let output = render_metrics(&stats, 
shared_state.as_ref(), &config, &tracker).await; assert!(output.contains("telemt_connections_total 2")); assert!(output.contains("telemt_connections_bad_total 1")); assert!(output.contains("telemt_handshake_timeouts_total 1")); + assert!(output.contains("telemt_auth_expensive_checks_total 9")); + assert!(output.contains("telemt_auth_budget_exhausted_total 2")); assert!(output.contains("telemt_upstream_connect_attempt_total 2")); assert!(output.contains("telemt_upstream_connect_success_total 1")); assert!(output.contains("telemt_upstream_connect_fail_total 1")); @@ -2960,12 +3040,15 @@ mod tests { #[tokio::test] async fn test_render_empty_stats() { let stats = Stats::new(); + let shared_state = ProxySharedState::new(); let tracker = UserIpTracker::new(); let config = ProxyConfig::default(); - let output = render_metrics(&stats, &config, &tracker).await; + let output = render_metrics(&stats, &shared_state, &config, &tracker).await; assert!(output.contains("telemt_connections_total 0")); assert!(output.contains("telemt_connections_bad_total 0")); assert!(output.contains("telemt_handshake_timeouts_total 0")); + assert!(output.contains("telemt_auth_expensive_checks_total 0")); + assert!(output.contains("telemt_auth_budget_exhausted_total 0")); assert!(output.contains("telemt_user_unique_ips_current{user=")); assert!(output.contains("telemt_user_unique_ips_recent_window{user=")); } @@ -2973,6 +3056,7 @@ mod tests { #[tokio::test] async fn test_render_uses_global_each_unique_ip_limit() { let stats = Stats::new(); + let shared_state = ProxySharedState::new(); stats.increment_user_connects("alice"); stats.increment_user_curr_connects("alice"); let tracker = UserIpTracker::new(); @@ -2983,7 +3067,7 @@ mod tests { let mut config = ProxyConfig::default(); config.access.user_max_unique_ips_global_each = 2; - let output = render_metrics(&stats, &config, &tracker).await; + let output = render_metrics(&stats, &shared_state, &config, &tracker).await; 
assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 2")); assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.500000")); @@ -2992,13 +3076,16 @@ mod tests { #[tokio::test] async fn test_render_has_type_annotations() { let stats = Stats::new(); + let shared_state = ProxySharedState::new(); let tracker = UserIpTracker::new(); let config = ProxyConfig::default(); - let output = render_metrics(&stats, &config, &tracker).await; + let output = render_metrics(&stats, &shared_state, &config, &tracker).await; assert!(output.contains("# TYPE telemt_uptime_seconds gauge")); assert!(output.contains("# TYPE telemt_connections_total counter")); assert!(output.contains("# TYPE telemt_connections_bad_total counter")); assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter")); + assert!(output.contains("# TYPE telemt_auth_expensive_checks_total counter")); + assert!(output.contains("# TYPE telemt_auth_budget_exhausted_total counter")); assert!(output.contains("# TYPE telemt_upstream_connect_attempt_total counter")); assert!(output.contains("# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter")); assert!(output.contains("# TYPE telemt_me_idle_close_by_peer_total counter")); @@ -3035,6 +3122,7 @@ mod tests { async fn test_endpoint_integration() { let stats = Arc::new(Stats::new()); let beobachten = Arc::new(BeobachtenStore::new()); + let shared_state = ProxySharedState::new(); let tracker = UserIpTracker::new(); let mut config = ProxyConfig::default(); stats.increment_connects_all(); @@ -3042,7 +3130,7 @@ mod tests { stats.increment_connects_all(); let req = Request::builder().uri("/metrics").body(()).unwrap(); - let resp = handle(req, &stats, &beobachten, &tracker, &config) + let resp = handle(req, &stats, &beobachten, shared_state.as_ref(), &tracker, &config) .await .unwrap(); assert_eq!(resp.status(), StatusCode::OK); @@ -3061,7 +3149,14 @@ mod tests { Duration::from_secs(600), ); let req_beob = 
Request::builder().uri("/beobachten").body(()).unwrap(); - let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config) + let resp_beob = handle( + req_beob, + &stats, + &beobachten, + shared_state.as_ref(), + &tracker, + &config, + ) .await .unwrap(); assert_eq!(resp_beob.status(), StatusCode::OK); @@ -3071,7 +3166,14 @@ mod tests { assert!(beob_text.contains("203.0.113.10-1")); let req404 = Request::builder().uri("/other").body(()).unwrap(); - let resp404 = handle(req404, &stats, &beobachten, &tracker, &config) + let resp404 = handle( + req404, + &stats, + &beobachten, + shared_state.as_ref(), + &tracker, + &config, + ) .await .unwrap(); assert_eq!(resp404.status(), StatusCode::NOT_FOUND); diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index ff4a68b..17fce47 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -55,6 +55,20 @@ struct RoutingTable { map: DashMap>, } +struct WriterTable { + map: DashMap>, +} + +#[derive(Clone)] +struct HotConnBinding { + writer_id: u64, + meta: ConnMeta, +} + +struct HotBindingTable { + map: DashMap, +} + struct BindingState { inner: Mutex, } @@ -83,6 +97,8 @@ impl BindingInner { pub struct ConnRegistry { routing: RoutingTable, + writers: WriterTable, + hot_binding: HotBindingTable, binding: BindingState, next_id: AtomicU64, route_channel_capacity: usize, @@ -105,6 +121,12 @@ impl ConnRegistry { routing: RoutingTable { map: DashMap::new(), }, + writers: WriterTable { + map: DashMap::new(), + }, + hot_binding: HotBindingTable { + map: DashMap::new(), + }, binding: BindingState { inner: Mutex::new(BindingInner::new()), }, @@ -149,16 +171,18 @@ impl ConnRegistry { pub async fn register_writer(&self, writer_id: u64, tx: mpsc::Sender) { let mut binding = self.binding.inner.lock().await; - binding.writers.insert(writer_id, tx); + binding.writers.insert(writer_id, tx.clone()); binding .conns_for_writer .entry(writer_id) 
.or_insert_with(HashSet::new); + self.writers.map.insert(writer_id, tx); } /// Unregister connection, returning associated writer_id if any. pub async fn unregister(&self, id: u64) -> Option { self.routing.map.remove(&id); + self.hot_binding.map.remove(&id); let mut binding = self.binding.inner.lock().await; binding.meta.remove(&id); if let Some(writer_id) = binding.writer_for_conn.remove(&id) { @@ -325,13 +349,20 @@ impl ConnRegistry { } binding.meta.insert(conn_id, meta.clone()); - binding.last_meta_for_writer.insert(writer_id, meta); + binding.last_meta_for_writer.insert(writer_id, meta.clone()); binding.writer_idle_since_epoch_secs.remove(&writer_id); binding .conns_for_writer .entry(writer_id) .or_insert_with(HashSet::new) .insert(conn_id); + self.hot_binding.map.insert( + conn_id, + HotConnBinding { + writer_id, + meta, + }, + ); true } @@ -392,39 +423,12 @@ impl ConnRegistry { } pub async fn get_writer(&self, conn_id: u64) -> Option { - let mut binding = self.binding.inner.lock().await; - // ROUTING IS THE SOURCE OF TRUTH: - // stale bindings are ignored and lazily cleaned when routing no longer - // contains the connection. 
if !self.routing.map.contains_key(&conn_id) { - binding.meta.remove(&conn_id); - if let Some(stale_writer_id) = binding.writer_for_conn.remove(&conn_id) - && let Some(conns) = binding.conns_for_writer.get_mut(&stale_writer_id) - { - conns.remove(&conn_id); - if conns.is_empty() { - binding - .writer_idle_since_epoch_secs - .insert(stale_writer_id, Self::now_epoch_secs()); - } - } return None; } - let writer_id = binding.writer_for_conn.get(&conn_id).copied()?; - let Some(writer) = binding.writers.get(&writer_id).cloned() else { - binding.writer_for_conn.remove(&conn_id); - binding.meta.remove(&conn_id); - if let Some(conns) = binding.conns_for_writer.get_mut(&writer_id) { - conns.remove(&conn_id); - if conns.is_empty() { - binding - .writer_idle_since_epoch_secs - .insert(writer_id, Self::now_epoch_secs()); - } - } - return None; - }; + let writer_id = self.hot_binding.map.get(&conn_id).map(|entry| entry.writer_id)?; + let writer = self.writers.map.get(&writer_id).map(|entry| entry.value().clone())?; Some(ConnWriter { writer_id, tx: writer, @@ -439,6 +443,7 @@ impl ConnRegistry { pub async fn writer_lost(&self, writer_id: u64) -> Vec { let mut binding = self.binding.inner.lock().await; binding.writers.remove(&writer_id); + self.writers.map.remove(&writer_id); binding.last_meta_for_writer.remove(&writer_id); binding.writer_idle_since_epoch_secs.remove(&writer_id); let conns = binding @@ -454,6 +459,15 @@ impl ConnRegistry { continue; } binding.writer_for_conn.remove(&conn_id); + let remove_hot = self + .hot_binding + .map + .get(&conn_id) + .map(|hot| hot.writer_id == writer_id) + .unwrap_or(false); + if remove_hot { + self.hot_binding.map.remove(&conn_id); + } if let Some(m) = binding.meta.get(&conn_id) { out.push(BoundConn { conn_id, @@ -466,8 +480,10 @@ impl ConnRegistry { #[allow(dead_code)] pub async fn get_meta(&self, conn_id: u64) -> Option { - let binding = self.binding.inner.lock().await; - binding.meta.get(&conn_id).cloned() + self.hot_binding + .map + 
.get(&conn_id) + .map(|entry| entry.meta.clone()) } pub async fn is_writer_empty(&self, writer_id: u64) -> bool { @@ -491,6 +507,7 @@ impl ConnRegistry { } binding.writers.remove(&writer_id); + self.writers.map.remove(&writer_id); binding.last_meta_for_writer.remove(&writer_id); binding.writer_idle_since_epoch_secs.remove(&writer_id); binding.conns_for_writer.remove(&writer_id); From e8cf97095fc83704ed73f0f37834c14de3062619 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:41:59 +0300 Subject: [PATCH 5/6] QueueFall Bounded Retry on Data-route Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/transport/middle_proxy/reader.rs | 119 +++++++++++++++++++++++++-- 1 file changed, 113 insertions(+), 6 deletions(-) diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index aec55cd..9eaaa3f 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -23,6 +23,48 @@ use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc}; use super::registry::RouteResult; use super::{ConnRegistry, MeResponse}; +const DATA_ROUTE_MAX_ATTEMPTS: usize = 3; + +fn should_close_on_route_result_for_data(result: RouteResult) -> bool { + !matches!(result, RouteResult::Routed) +} + +fn should_close_on_route_result_for_ack(result: RouteResult) -> bool { + matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed) +} + +async fn route_data_with_retry( + reg: &ConnRegistry, + conn_id: u64, + flags: u32, + data: Bytes, + timeout_ms: u64, +) -> RouteResult { + let mut attempt = 0usize; + loop { + let routed = reg + .route_with_timeout( + conn_id, + MeResponse::Data { + flags, + data: data.clone(), + }, + timeout_ms, + ) + .await; + match routed { + RouteResult::QueueFullBase | RouteResult::QueueFullHigh => { + attempt = attempt.saturating_add(1); + if attempt >= DATA_ROUTE_MAX_ATTEMPTS { + return routed; + } + 
tokio::task::yield_now().await; + } + _ => return routed, + } + } +} + pub(crate) async fn reader_loop( mut rd: tokio::io::ReadHalf, dk: [u8; 32], @@ -127,10 +169,8 @@ pub(crate) async fn reader_loop( trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); - let routed = reg - .route_with_timeout(cid, MeResponse::Data { flags, data }, route_wait_ms) - .await; - if !matches!(routed, RouteResult::Routed) { + let routed = route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await; + if should_close_on_route_result_for_data(routed) { match routed { RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), RouteResult::ChannelClosed => { @@ -171,8 +211,10 @@ pub(crate) async fn reader_loop( } RouteResult::Routed => {} } - reg.unregister(cid).await; - send_close_conn(&tx, cid).await; + if should_close_on_route_result_for_ack(routed) { + reg.unregister(cid).await; + send_close_conn(&tx, cid).await; + } } } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); @@ -243,6 +285,71 @@ pub(crate) async fn reader_loop( } } +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use crate::transport::middle_proxy::ConnRegistry; + + use super::{ + MeResponse, RouteResult, route_data_with_retry, should_close_on_route_result_for_ack, + should_close_on_route_result_for_data, + }; + + #[test] + fn data_route_failure_always_closes_session() { + assert!(!should_close_on_route_result_for_data(RouteResult::Routed)); + assert!(should_close_on_route_result_for_data(RouteResult::NoConn)); + assert!(should_close_on_route_result_for_data(RouteResult::ChannelClosed)); + assert!(should_close_on_route_result_for_data(RouteResult::QueueFullBase)); + assert!(should_close_on_route_result_for_data(RouteResult::QueueFullHigh)); + } + + #[test] + fn ack_queue_full_is_soft_dropped_without_forced_close() { + 
assert!(!should_close_on_route_result_for_ack(RouteResult::Routed)); + assert!(!should_close_on_route_result_for_ack(RouteResult::QueueFullBase)); + assert!(!should_close_on_route_result_for_ack(RouteResult::QueueFullHigh)); + assert!(should_close_on_route_result_for_ack(RouteResult::NoConn)); + assert!(should_close_on_route_result_for_ack(RouteResult::ChannelClosed)); + } + + #[tokio::test] + async fn route_data_with_retry_returns_routed_when_channel_has_capacity() { + let reg = ConnRegistry::with_route_channel_capacity(1); + let (conn_id, mut rx) = reg.register().await; + + let routed = + route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await; + assert!(matches!(routed, RouteResult::Routed)); + match rx.recv().await { + Some(MeResponse::Data { flags, data }) => { + assert_eq!(flags, 0); + assert_eq!(data, Bytes::from_static(b"a")); + } + other => panic!("expected routed data response, got {other:?}"), + } + } + + #[tokio::test] + async fn route_data_with_retry_stops_after_bounded_attempts() { + let reg = ConnRegistry::with_route_channel_capacity(1); + let (conn_id, _rx) = reg.register().await; + + assert!(matches!( + reg.route_nowait(conn_id, MeResponse::Ack(1)).await, + RouteResult::Routed + )); + + let routed = + route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 0).await; + assert!(matches!( + routed, + RouteResult::QueueFullBase | RouteResult::QueueFullHigh + )); + } +} + async fn send_close_conn(tx: &mpsc::Sender, conn_id: u64) { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); From 4a77335ba94fa824ca9806cb09478902ecfc6a13 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 7 Apr 2026 19:19:40 +0300 Subject: [PATCH 6/6] Round-bounded Retries + Bounded Retry-Round Constant Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- .../tests/load_memory_envelope_tests.rs | 6 +- src/metrics.rs | 31 +++-- src/proxy/middle_relay.rs | 
7 ++ src/proxy/relay.rs | 35 ++++-- src/tls_front/emulator.rs | 5 +- src/transport/middle_proxy/reader.rs | 109 +++++++++++++----- src/transport/middle_proxy/registry.rs | 22 ++-- 7 files changed, 154 insertions(+), 61 deletions(-) diff --git a/src/config/tests/load_memory_envelope_tests.rs b/src/config/tests/load_memory_envelope_tests.rs index b2d14fb..ea78498 100644 --- a/src/config/tests/load_memory_envelope_tests.rs +++ b/src/config/tests/load_memory_envelope_tests.rs @@ -26,7 +26,8 @@ me_writer_cmd_channel_capacity = 16385 "#, ); - let err = ProxyConfig::load(&path).expect_err("writer command capacity above hard cap must fail"); + let err = + ProxyConfig::load(&path).expect_err("writer command capacity above hard cap must fail"); let msg = err.to_string(); assert!( msg.contains("general.me_writer_cmd_channel_capacity must be within [1, 16384]"), @@ -45,7 +46,8 @@ me_route_channel_capacity = 8193 "#, ); - let err = ProxyConfig::load(&path).expect_err("route channel capacity above hard cap must fail"); + let err = + ProxyConfig::load(&path).expect_err("route channel capacity above hard cap must fail"); let msg = err.to_string(); assert!( msg.contains("general.me_route_channel_capacity must be within [1, 8192]"), diff --git a/src/metrics.rs b/src/metrics.rs index 685d2ef..1b920a8 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -196,7 +196,15 @@ async fn serve_listener( let ip_tracker = ip_tracker.clone(); let config = config_rx_conn.borrow().clone(); async move { - handle(req, &stats, &beobachten, &shared_state, &ip_tracker, &config).await + handle( + req, + &stats, + &beobachten, + &shared_state, + &ip_tracker, + &config, + ) + .await } }); if let Err(e) = http1::Builder::new() @@ -3145,9 +3153,16 @@ mod tests { stats.increment_connects_all(); let req = Request::builder().uri("/metrics").body(()).unwrap(); - let resp = handle(req, &stats, &beobachten, shared_state.as_ref(), &tracker, &config) - .await - .unwrap(); + let resp = handle( + req, + &stats, + 
&beobachten, + shared_state.as_ref(), + &tracker, + &config, + ) + .await + .unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = resp.into_body().collect().await.unwrap().to_bytes(); assert!( @@ -3180,8 +3195,8 @@ mod tests { &tracker, &config, ) - .await - .unwrap(); + .await + .unwrap(); assert_eq!(resp_beob.status(), StatusCode::OK); let body_beob = resp_beob.into_body().collect().await.unwrap().to_bytes(); let beob_text = std::str::from_utf8(body_beob.as_ref()).unwrap(); @@ -3197,8 +3212,8 @@ mod tests { &tracker, &config, ) - .await - .unwrap(); + .await + .unwrap(); assert_eq!(resp404.status(), StatusCode::NOT_FOUND); } } diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 665e90e..eb68f83 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -56,6 +56,8 @@ const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; +const QUOTA_RESERVE_BACKOFF_MIN_MS: u64 = 1; +const QUOTA_RESERVE_BACKOFF_MAX_MS: u64 = 16; #[derive(Default)] pub(crate) struct DesyncDedupRotationState { @@ -573,6 +575,7 @@ async fn reserve_user_quota_with_yield( bytes: u64, limit: u64, ) -> std::result::Result { + let mut backoff_ms = QUOTA_RESERVE_BACKOFF_MIN_MS; loop { for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match user_stats.quota_try_reserve(bytes, limit) { @@ -585,6 +588,10 @@ async fn reserve_user_quota_with_yield( } tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + backoff_ms = backoff_ms + .saturating_mul(2) + .min(QUOTA_RESERVE_BACKOFF_MAX_MS); } } diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 9fd5f3d..f612cb1 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -271,6 +271,7 @@ const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024; const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024; const 
QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; const QUOTA_RESERVE_SPIN_RETRIES: usize = 64; +const QUOTA_RESERVE_MAX_ROUNDS: usize = 8; #[inline] fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { @@ -319,6 +320,7 @@ impl AsyncRead for StatsIo { let mut reserved_total = None; let mut reserve_rounds = 0usize; while reserved_total.is_none() { + let mut saw_contention = false; for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match this.user_stats.quota_try_reserve(n_to_charge, limit) { Ok(total) => { @@ -331,15 +333,20 @@ impl AsyncRead for StatsIo { return Poll::Ready(Err(quota_io_error())); } Err(crate::stats::QuotaReserveError::Contended) => { - std::hint::spin_loop(); + saw_contention = true; } } } - reserve_rounds = reserve_rounds.saturating_add(1); - if reserved_total.is_none() && reserve_rounds >= 8 { - this.quota_exceeded.store(true, Ordering::Release); - buf.set_filled(before); - return Poll::Ready(Err(quota_io_error())); + if reserved_total.is_none() { + reserve_rounds = reserve_rounds.saturating_add(1); + if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + this.quota_exceeded.store(true, Ordering::Release); + buf.set_filled(before); + return Poll::Ready(Err(quota_io_error())); + } + if saw_contention { + std::thread::yield_now(); + } } } @@ -407,6 +414,7 @@ impl AsyncWrite for StatsIo { remaining_before = Some(remaining); let desired = remaining.min(buf.len() as u64); + let mut saw_contention = false; for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match this.user_stats.quota_try_reserve(desired, limit) { Ok(_) => { @@ -418,15 +426,20 @@ impl AsyncWrite for StatsIo { break; } Err(crate::stats::QuotaReserveError::Contended) => { - std::hint::spin_loop(); + saw_contention = true; } } } - reserve_rounds = reserve_rounds.saturating_add(1); - if reserved_bytes == 0 && reserve_rounds >= 8 { - this.quota_exceeded.store(true, Ordering::Release); - return Poll::Ready(Err(quota_io_error())); + if reserved_bytes == 0 { + reserve_rounds = 
reserve_rounds.saturating_add(1); + if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); + } + if saw_contention { + std::thread::yield_now(); + } } } } else { diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index d6845a2..290e203 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -1,6 +1,5 @@ #![allow(clippy::too_many_arguments)] -use crc32fast::Hasher; use crate::crypto::{SecureRandom, sha256_hmac}; use crate::protocol::constants::{ MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, @@ -8,6 +7,7 @@ use crate::protocol::constants::{ }; use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key}; use crate::tls_front::types::{CachedTlsData, ParsedCertificateInfo, TlsProfileSource}; +use crc32fast::Hasher; const MIN_APP_DATA: usize = 64; const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE; @@ -343,7 +343,8 @@ mod tests { }; use super::{ - build_compact_cert_info_payload, build_emulated_server_hello, hash_compact_cert_info_payload, + build_compact_cert_info_payload, build_emulated_server_hello, + hash_compact_cert_info_payload, }; use crate::crypto::SecureRandom; use crate::protocol::constants::{ diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 9eaaa3f..dbfd9d7 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -24,15 +24,27 @@ use super::registry::RouteResult; use super::{ConnRegistry, MeResponse}; const DATA_ROUTE_MAX_ATTEMPTS: usize = 3; +const DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD: u8 = 3; fn should_close_on_route_result_for_data(result: RouteResult) -> bool { - !matches!(result, RouteResult::Routed) + matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed) } fn should_close_on_route_result_for_ack(result: RouteResult) -> bool { matches!(result, RouteResult::NoConn | 
RouteResult::ChannelClosed) } +fn is_data_route_queue_full(result: RouteResult) -> bool { + matches!( + result, + RouteResult::QueueFullBase | RouteResult::QueueFullHigh + ) +} + +fn should_close_on_queue_full_streak(streak: u8) -> bool { + streak >= DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD +} + async fn route_data_with_retry( reg: &ConnRegistry, conn_id: u64, @@ -85,6 +97,7 @@ pub(crate) async fn reader_loop( ) -> Result<()> { let mut raw = enc_leftover; let mut expected_seq: i32 = 0; + let mut data_route_queue_full_streak = HashMap::::new(); loop { let mut tmp = [0u8; 65_536]; @@ -169,25 +182,39 @@ pub(crate) async fn reader_loop( trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); - let routed = route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await; - if should_close_on_route_result_for_data(routed) { - match routed { - RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), - RouteResult::ChannelClosed => { - stats.increment_me_route_drop_channel_closed() - } - RouteResult::QueueFullBase => { - stats.increment_me_route_drop_queue_full(); - stats.increment_me_route_drop_queue_full_base(); - } - RouteResult::QueueFullHigh => { - stats.increment_me_route_drop_queue_full(); - stats.increment_me_route_drop_queue_full_high(); - } - RouteResult::Routed => {} + let routed = + route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await; + if matches!(routed, RouteResult::Routed) { + data_route_queue_full_streak.remove(&cid); + continue; + } + match routed { + RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), + RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(), + RouteResult::QueueFullBase => { + stats.increment_me_route_drop_queue_full(); + stats.increment_me_route_drop_queue_full_base(); } + RouteResult::QueueFullHigh => { + stats.increment_me_route_drop_queue_full(); + 
stats.increment_me_route_drop_queue_full_high(); + } + RouteResult::Routed => {} + } + if should_close_on_route_result_for_data(routed) { + data_route_queue_full_streak.remove(&cid); reg.unregister(cid).await; send_close_conn(&tx, cid).await; + continue; + } + if is_data_route_queue_full(routed) { + let streak = data_route_queue_full_streak.entry(cid).or_insert(0); + *streak = streak.saturating_add(1); + if should_close_on_queue_full_streak(*streak) { + data_route_queue_full_streak.remove(&cid); + reg.unregister(cid).await; + send_close_conn(&tx, cid).await; + } } } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); @@ -221,11 +248,13 @@ pub(crate) async fn reader_loop( debug!(cid, "RPC_CLOSE_EXT from ME"); let _ = reg.route_nowait(cid, MeResponse::Close).await; reg.unregister(cid).await; + data_route_queue_full_streak.remove(&cid); } else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); debug!(cid, "RPC_CLOSE_CONN from ME"); let _ = reg.route_nowait(cid, MeResponse::Close).await; reg.unregister(cid).await; + data_route_queue_full_streak.remove(&cid); } else if pt == RPC_PING_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); trace!(ping_id, "RPC_PING -> RPC_PONG"); @@ -292,26 +321,50 @@ mod tests { use crate::transport::middle_proxy::ConnRegistry; use super::{ - MeResponse, RouteResult, route_data_with_retry, should_close_on_route_result_for_ack, + MeResponse, RouteResult, is_data_route_queue_full, route_data_with_retry, + should_close_on_queue_full_streak, should_close_on_route_result_for_ack, should_close_on_route_result_for_data, }; #[test] - fn data_route_failure_always_closes_session() { + fn data_route_only_fatal_results_close_immediately() { assert!(!should_close_on_route_result_for_data(RouteResult::Routed)); + assert!(!should_close_on_route_result_for_data( + RouteResult::QueueFullBase + 
)); + assert!(!should_close_on_route_result_for_data( + RouteResult::QueueFullHigh + )); assert!(should_close_on_route_result_for_data(RouteResult::NoConn)); - assert!(should_close_on_route_result_for_data(RouteResult::ChannelClosed)); - assert!(should_close_on_route_result_for_data(RouteResult::QueueFullBase)); - assert!(should_close_on_route_result_for_data(RouteResult::QueueFullHigh)); + assert!(should_close_on_route_result_for_data( + RouteResult::ChannelClosed + )); + } + + #[test] + fn data_route_queue_full_uses_starvation_threshold() { + assert!(is_data_route_queue_full(RouteResult::QueueFullBase)); + assert!(is_data_route_queue_full(RouteResult::QueueFullHigh)); + assert!(!is_data_route_queue_full(RouteResult::NoConn)); + assert!(!should_close_on_queue_full_streak(1)); + assert!(!should_close_on_queue_full_streak(2)); + assert!(should_close_on_queue_full_streak(3)); + assert!(should_close_on_queue_full_streak(u8::MAX)); } #[test] fn ack_queue_full_is_soft_dropped_without_forced_close() { assert!(!should_close_on_route_result_for_ack(RouteResult::Routed)); - assert!(!should_close_on_route_result_for_ack(RouteResult::QueueFullBase)); - assert!(!should_close_on_route_result_for_ack(RouteResult::QueueFullHigh)); + assert!(!should_close_on_route_result_for_ack( + RouteResult::QueueFullBase + )); + assert!(!should_close_on_route_result_for_ack( + RouteResult::QueueFullHigh + )); assert!(should_close_on_route_result_for_ack(RouteResult::NoConn)); - assert!(should_close_on_route_result_for_ack(RouteResult::ChannelClosed)); + assert!(should_close_on_route_result_for_ack( + RouteResult::ChannelClosed + )); } #[tokio::test] @@ -319,8 +372,7 @@ mod tests { let reg = ConnRegistry::with_route_channel_capacity(1); let (conn_id, mut rx) = reg.register().await; - let routed = - route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await; + let routed = route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await; assert!(matches!(routed, 
RouteResult::Routed)); match rx.recv().await { Some(MeResponse::Data { flags, data }) => { @@ -341,8 +393,7 @@ mod tests { RouteResult::Routed )); - let routed = - route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 0).await; + let routed = route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 0).await; assert!(matches!( routed, RouteResult::QueueFullBase | RouteResult::QueueFullHigh diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 17fce47..d8625f2 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -356,13 +356,9 @@ impl ConnRegistry { .entry(writer_id) .or_insert_with(HashSet::new) .insert(conn_id); - self.hot_binding.map.insert( - conn_id, - HotConnBinding { - writer_id, - meta, - }, - ); + self.hot_binding + .map + .insert(conn_id, HotConnBinding { writer_id, meta }); true } @@ -427,8 +423,16 @@ impl ConnRegistry { return None; } - let writer_id = self.hot_binding.map.get(&conn_id).map(|entry| entry.writer_id)?; - let writer = self.writers.map.get(&writer_id).map(|entry| entry.value().clone())?; + let writer_id = self + .hot_binding + .map + .get(&conn_id) + .map(|entry| entry.writer_id)?; + let writer = self + .writers + .map + .get(&writer_id) + .map(|entry| entry.value().clone())?; Some(ConnWriter { writer_id, tx: writer,