From 635bea4de4eed088806e559deb3dbd62801a3410 Mon Sep 17 00:00:00 2001 From: Mirotin Artem Date: Sat, 25 Apr 2026 00:02:32 +0300 Subject: [PATCH 1/8] feat(api): add Patch enum for JSON merge-patch semantics Introduce a three-state Patch (Unchanged / Remove / Set) and a serde helper patch_field that distinguishes an omitted JSON field from an explicit null. Wired up next as the field type for the removable settings on PATCH /v1/users/{user}. --- src/api/mod.rs | 1 + src/api/patch.rs | 79 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 src/api/patch.rs diff --git a/src/api/mod.rs b/src/api/mod.rs index f33a89b..1778d7d 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -28,6 +28,7 @@ mod config_store; mod events; mod http_utils; mod model; +mod patch; mod runtime_edge; mod runtime_init; mod runtime_min; diff --git a/src/api/patch.rs b/src/api/patch.rs new file mode 100644 index 0000000..6425af1 --- /dev/null +++ b/src/api/patch.rs @@ -0,0 +1,79 @@ +use serde::Deserialize; + +/// Three-state field for JSON Merge Patch semantics on the `PATCH /v1/users/{user}` +/// endpoint. +/// +/// `Unchanged` is produced when the JSON body omits the field entirely and tells the +/// handler to leave the corresponding configuration entry untouched. `Remove` is +/// produced when the JSON body sets the field to `null` and instructs the handler to +/// drop the entry from the corresponding access HashMap. `Set` carries an explicit +/// new value, including zero, which is preserved verbatim in the configuration. +#[derive(Debug)] +pub(super) enum Patch { + Unchanged, + Remove, + Set(T), +} + +impl Default for Patch { + fn default() -> Self { + Self::Unchanged + } +} + +/// Serde deserializer adapter for fields that follow JSON Merge Patch semantics. +/// +/// Pair this with `#[serde(default, deserialize_with = "patch_field")]` on a +/// `Patch` field. An omitted field falls back to `Patch::Unchanged` via +/// `Default`; an explicit JSON `null` becomes `Patch::Remove`; any other value +/// becomes `Patch::Set(v)`. +pub(super) fn patch_field<'de, D, T>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, + T: serde::Deserialize<'de>, +{ + Option::::deserialize(deserializer).map(|opt| match opt { + Some(value) => Patch::Set(value), + None => Patch::Remove, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::Deserialize; + + #[derive(Deserialize)] + struct Holder { + #[serde(default, deserialize_with = "patch_field")] + value: Patch, + } + + fn parse(json: &str) -> Holder { + serde_json::from_str(json).expect("valid json") + } + + #[test] + fn omitted_field_yields_unchanged() { + let h = parse("{}"); + assert!(matches!(h.value, Patch::Unchanged)); + } + + #[test] + fn explicit_null_yields_remove() { + let h = parse(r#"{"value": null}"#); + assert!(matches!(h.value, Patch::Remove)); + } + + #[test] + fn explicit_value_yields_set() { + let h = parse(r#"{"value": 42}"#); + assert!(matches!(h.value, Patch::Set(42))); + } + + #[test] + fn explicit_zero_yields_set_zero() { + let h = parse(r#"{"value": 0}"#); + assert!(matches!(h.value, Patch::Set(0))); + } +} From 4ed87d194688e84f0202cec4f5246b0cd985d17c Mon Sep 17 00:00:00 2001 From: Mirotin Artem Date: Sat, 25 Apr 2026 00:22:09 +0300 Subject: [PATCH 2/8] feat(api): support null-removal in PATCH /v1/users/{user} PatchUserRequest now uses Patch for the five removable fields (user_ad_tag, max_tcp_conns, expiration_rfc3339, data_quota_bytes, max_unique_ips). Sending JSON null drops the entry from the corresponding access HashMap; sending 0 is preserved as a literal limit; omitted fields stay untouched. The handler synchronises the in-memory ip_tracker on both set and remove of max_unique_ips. A helper parse_patch_expiration mirrors parse_optional_expiration for the new three-state field. Runtime semantics are unchanged. --- src/api/model.rs | 30 ++++++++++++++--- src/api/patch.rs | 51 +++++++++++++++++++++++++++++ src/api/users.rs | 84 ++++++++++++++++++++++++++++++++++-------------- 3 files changed, 135 insertions(+), 30 deletions(-) diff --git a/src/api/model.rs b/src/api/model.rs index fa1f063..1ca9f33 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -5,6 +5,7 @@ use chrono::{DateTime, Utc}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; +use super::patch::{Patch, patch_field}; use crate::crypto::SecureRandom; const MAX_USERNAME_LEN: usize = 64; @@ -507,11 +508,16 @@ pub(super) struct CreateUserRequest { #[derive(Deserialize)] pub(super) struct PatchUserRequest { pub(super) secret: Option, - pub(super) user_ad_tag: Option, - pub(super) max_tcp_conns: Option, - pub(super) expiration_rfc3339: Option, - pub(super) data_quota_bytes: Option, - pub(super) max_unique_ips: Option, + #[serde(default, deserialize_with = "patch_field")] + pub(super) user_ad_tag: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) max_tcp_conns: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) expiration_rfc3339: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) data_quota_bytes: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) max_unique_ips: Patch, } #[derive(Default, Deserialize)] @@ -530,6 +536,20 @@ pub(super) fn parse_optional_expiration( Ok(Some(parsed.with_timezone(&Utc))) } +pub(super) fn parse_patch_expiration( + value: &Patch, +) -> Result>, ApiFailure> { + match value { + Patch::Unchanged => Ok(Patch::Unchanged), + Patch::Remove => Ok(Patch::Remove), + Patch::Set(raw) => { + let parsed = DateTime::parse_from_rfc3339(raw) + .map_err(|_| ApiFailure::bad_request("expiration_rfc3339 must be valid RFC3339"))?; + Ok(Patch::Set(parsed.with_timezone(&Utc))) + } + } +} + pub(super) fn is_valid_user_secret(secret: &str) -> bool { secret.len() == 32 && secret.chars().all(|c| c.is_ascii_hexdigit()) } diff --git a/src/api/patch.rs b/src/api/patch.rs index 6425af1..65cd191 100644 --- a/src/api/patch.rs +++ b/src/api/patch.rs @@ -41,6 +41,8 @@ where #[cfg(test)] mod tests { use super::*; + use crate::api::model::{PatchUserRequest, parse_patch_expiration}; + use chrono::{TimeZone, Utc}; use serde::Deserialize; #[derive(Deserialize)] @@ -76,4 +78,53 @@ mod tests { let h = parse(r#"{"value": 0}"#); assert!(matches!(h.value, Patch::Set(0))); } + + #[test] + fn parse_patch_expiration_passes_unchanged_and_remove_through() { + assert!(matches!( + parse_patch_expiration(&Patch::Unchanged), + Ok(Patch::Unchanged) + )); + assert!(matches!( + parse_patch_expiration(&Patch::Remove), + Ok(Patch::Remove) + )); + } + + #[test] + fn parse_patch_expiration_parses_set_value() { + let parsed = + parse_patch_expiration(&Patch::Set("2030-01-02T03:04:05Z".into())).expect("valid"); + match parsed { + Patch::Set(dt) => { + assert_eq!(dt, Utc.with_ymd_and_hms(2030, 1, 2, 3, 4, 5).unwrap()); + } + other => panic!("expected Patch::Set, got {:?}", other), + } + } + + #[test] + fn parse_patch_expiration_rejects_invalid_set_value() { + assert!(parse_patch_expiration(&Patch::Set("not-a-date".into())).is_err()); + } + + #[test] + fn patch_user_request_deserializes_mixed_states() { + let raw = r#"{ + "secret": "00112233445566778899aabbccddeeff", + "max_tcp_conns": 0, + "max_unique_ips": null, + "data_quota_bytes": 1024 + }"#; + let req: PatchUserRequest = serde_json::from_str(raw).expect("valid json"); + assert_eq!( + req.secret.as_deref(), + Some("00112233445566778899aabbccddeeff") + ); + assert!(matches!(req.max_tcp_conns, Patch::Set(0))); + assert!(matches!(req.max_unique_ips, Patch::Remove)); + assert!(matches!(req.data_quota_bytes, Patch::Set(1024))); + assert!(matches!(req.expiration_rfc3339, Patch::Unchanged)); + assert!(matches!(req.user_ad_tag, Patch::Unchanged)); + } } diff --git a/src/api/users.rs b/src/api/users.rs index 6b20b85..ef8f10a 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -14,8 +14,9 @@ use super::config_store::{ use super::model::{ ApiFailure, CreateUserRequest, CreateUserResponse, PatchUserRequest, RotateSecretRequest, UserInfo, UserLinks, is_valid_ad_tag, is_valid_user_secret, is_valid_username, - parse_optional_expiration, random_user_secret, + parse_optional_expiration, parse_patch_expiration, random_user_secret, }; +use super::patch::Patch; pub(super) async fn create_user( body: CreateUserRequest, @@ -182,14 +183,14 @@ pub(super) async fn patch_user( "secret must be exactly 32 hex characters", )); } - if let Some(ad_tag) = body.user_ad_tag.as_ref() + if let Patch::Set(ad_tag) = &body.user_ad_tag && !is_valid_ad_tag(ad_tag) { return Err(ApiFailure::bad_request( "user_ad_tag must be exactly 32 hex characters", )); } - let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?; + let expiration = parse_patch_expiration(&body.expiration_rfc3339)?; let _guard = shared.mutation_lock.lock().await; let mut cfg = load_config_from_disk(&shared.config_path).await?; ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; @@ -205,38 +206,71 @@ pub(super) async fn patch_user( if let Some(secret) = body.secret { cfg.access.users.insert(user.to_string(), secret); } - if let Some(ad_tag) = body.user_ad_tag { - cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); + match body.user_ad_tag { + Patch::Unchanged => {} + Patch::Remove => { + cfg.access.user_ad_tags.remove(user); + } + Patch::Set(ad_tag) => { + cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); + } } - if let Some(limit) = body.max_tcp_conns { - cfg.access - .user_max_tcp_conns - .insert(user.to_string(), limit); + match body.max_tcp_conns { + Patch::Unchanged => {} + Patch::Remove => { + cfg.access.user_max_tcp_conns.remove(user); + } + Patch::Set(limit) => { + cfg.access + .user_max_tcp_conns + .insert(user.to_string(), limit); + } } - if let Some(expiration) = expiration { - cfg.access - .user_expirations - .insert(user.to_string(), expiration); + match expiration { + Patch::Unchanged => {} + Patch::Remove => { + cfg.access.user_expirations.remove(user); + } + Patch::Set(expiration) => { + cfg.access + .user_expirations + .insert(user.to_string(), expiration); + } } - if let Some(quota) = body.data_quota_bytes { - cfg.access.user_data_quota.insert(user.to_string(), quota); - } - - let mut updated_limit = None; - if let Some(limit) = body.max_unique_ips { - cfg.access - .user_max_unique_ips - .insert(user.to_string(), limit); - updated_limit = Some(limit); + match body.data_quota_bytes { + Patch::Unchanged => {} + Patch::Remove => { + cfg.access.user_data_quota.remove(user); + } + Patch::Set(quota) => { + cfg.access.user_data_quota.insert(user.to_string(), quota); + } } + // Capture how the per-user IP limit changed, so the in-memory ip_tracker + // can be synced (set or removed) after the config is persisted. + let max_unique_ips_change = match body.max_unique_ips { + Patch::Unchanged => None, + Patch::Remove => { + cfg.access.user_max_unique_ips.remove(user); + Some(None) + } + Patch::Set(limit) => { + cfg.access + .user_max_unique_ips + .insert(user.to_string(), limit); + Some(Some(limit)) + } + }; cfg.validate() .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; let revision = save_config_to_disk(&shared.config_path, &cfg).await?; drop(_guard); - if let Some(limit) = updated_limit { - shared.ip_tracker.set_user_limit(user, limit).await; + match max_unique_ips_change { + Some(Some(limit)) => shared.ip_tracker.set_user_limit(user, limit).await, + Some(None) => shared.ip_tracker.remove_user_limit(user).await, + None => {} } let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); let users = users_from_config( From e78592ef9b2a5f56c577775275465c3bb7601f5b Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 25 Apr 2026 12:00:46 +0300 Subject: [PATCH 3/8] Avoid IP tracking when unique-IP limits are disabled and cap beobachten memory Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com> --- src/proxy/client.rs | 82 +++++++++++++++--------- src/proxy/tests/client_security_tests.rs | 25 +++++++- src/stats/beobachten.rs | 24 +++++-- 3 files changed, 91 insertions(+), 40 deletions(-) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 2d4dd42..2ab02ce 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -31,16 +31,24 @@ struct UserConnectionReservation { ip_tracker: Arc, user: String, ip: IpAddr, + tracks_ip: bool, active: bool, } impl UserConnectionReservation { - fn new(stats: Arc, ip_tracker: Arc, user: String, ip: IpAddr) -> Self { + fn new( + stats: Arc, + ip_tracker: Arc, + user: String, + ip: IpAddr, + tracks_ip: bool, + ) -> Self { Self { stats, ip_tracker, user, ip, + tracks_ip, active: true, } } @@ -49,7 +57,9 @@ impl UserConnectionReservation { if !self.active { return; } - self.ip_tracker.remove_ip(&self.user, self.ip).await; + if self.tracks_ip { + self.ip_tracker.remove_ip(&self.user, self.ip).await; + } self.active = false; self.stats.decrement_user_curr_connects(&self.user); } @@ -62,7 +72,9 @@ impl Drop for UserConnectionReservation { } self.active = false; self.stats.decrement_user_curr_connects(&self.user); - self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip); + if self.tracks_ip { + self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip); + } } } @@ -1600,19 +1612,22 @@ impl RunningClientHandler { }); } - match ip_tracker.check_and_add(user, peer_addr.ip()).await { - Ok(()) => {} - Err(reason) => { - stats.decrement_user_curr_connects(user); - warn!( - user = %user, - ip = %peer_addr.ip(), - reason = %reason, - "IP limit exceeded" - ); - return Err(ProxyError::ConnectionLimitExceeded { - user: user.to_string(), - }); + let tracks_ip = ip_tracker.get_user_limit(user).await.is_some(); + if tracks_ip { + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => {} + Err(reason) => { + stats.decrement_user_curr_connects(user); + warn!( + user = %user, + ip = %peer_addr.ip(), + reason = %reason, + "IP limit exceeded" + ); + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } } } @@ -1621,6 +1636,7 @@ impl RunningClientHandler { ip_tracker, user.to_string(), peer_addr.ip(), + tracks_ip, )) } @@ -1663,25 +1679,27 @@ impl RunningClientHandler { }); } - match ip_tracker.check_and_add(user, peer_addr.ip()).await { - Ok(()) => { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.decrement_user_curr_connects(user); - } - Err(reason) => { - stats.decrement_user_curr_connects(user); - warn!( - user = %user, - ip = %peer_addr.ip(), - reason = %reason, - "IP limit exceeded" - ); - return Err(ProxyError::ConnectionLimitExceeded { - user: user.to_string(), - }); + if ip_tracker.get_user_limit(user).await.is_some() { + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => { + ip_tracker.remove_ip(user, peer_addr.ip()).await; + } + Err(reason) => { + stats.decrement_user_curr_connects(user); + warn!( + user = %user, + ip = %peer_addr.ip(), + reason = %reason, + "IP limit exceeded" + ); + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } } } + stats.decrement_user_curr_connects(user); Ok(()) } } diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 480b33d..4505e17 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -281,8 +281,13 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1); assert_eq!(stats.get_user_curr_connects(&user), 1); - let reservation = - UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip); + let reservation = UserConnectionReservation::new( + stats.clone(), + ip_tracker.clone(), + user.clone(), + ip, + true, + ); // Drop the reservation synchronously without any tokio::spawn/await yielding! drop(reservation); @@ -320,6 +325,7 @@ async fn relay_task_abort_releases_user_gate_and_ip_reservation() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; let mut cfg = ProxyConfig::default(); cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); @@ -437,6 +443,7 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; let mut cfg = ProxyConfig::default(); cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); @@ -2879,6 +2886,7 @@ async fn explicit_reservation_release_cleans_user_and_ip_immediately() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; let reservation = RunningClientHandler::acquire_user_connection_reservation_static( user, @@ -2917,6 +2925,7 @@ async fn explicit_reservation_release_does_not_double_decrement_on_drop() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; let reservation = RunningClientHandler::acquire_user_connection_reservation_static( user, @@ -2947,6 +2956,7 @@ async fn drop_fallback_eventually_cleans_user_and_ip_reservation() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; let reservation = RunningClientHandler::acquire_user_connection_reservation_static( user, @@ -3029,6 +3039,7 @@ async fn release_abort_storm_does_not_leak_user_or_ip_reservations() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, ATTEMPTS + 16).await; for idx in 0..ATTEMPTS { let peer = SocketAddr::new( @@ -3079,6 +3090,7 @@ async fn release_abort_loop_preserves_immediate_same_ip_reacquire() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; for _ in 0..ITERATIONS { let reservation = RunningClientHandler::acquire_user_connection_reservation_static( @@ -3137,6 +3149,7 @@ async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, RESERVATIONS + 8).await; let mut reservations = Vec::with_capacity(RESERVATIONS); for idx in 0..RESERVATIONS { @@ -3217,6 +3230,8 @@ async fn parallel_users_abort_release_isolation_preserves_independent_cleanup() let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user_a, 64).await; + ip_tracker.set_user_limit(user_b, 64).await; let mut tasks = tokio::task::JoinSet::new(); for idx in 0..64usize { @@ -3278,6 +3293,7 @@ async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, RESERVATIONS + 8).await; let mut reservations = Vec::with_capacity(RESERVATIONS); for idx in 0..RESERVATIONS { @@ -3332,6 +3348,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; let mut config = ProxyConfig::default(); config.access.user_max_tcp_conns.insert(user.to_string(), 1); @@ -3427,6 +3444,7 @@ async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static( user, @@ -3487,6 +3505,7 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static( user, @@ -3696,6 +3715,7 @@ async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; let reservation = RunningClientHandler::acquire_user_connection_reservation_static( user, @@ -3740,6 +3760,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() { let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; let reservation = RunningClientHandler::acquire_user_connection_reservation_static( user, diff --git a/src/stats/beobachten.rs b/src/stats/beobachten.rs index 3d3a2da..79b2bcd 100644 --- a/src/stats/beobachten.rs +++ b/src/stats/beobachten.rs @@ -7,6 +7,7 @@ use std::time::{Duration, Instant}; use parking_lot::Mutex; const CLEANUP_INTERVAL: Duration = Duration::from_secs(30); +const MAX_BEOBACHTEN_ENTRIES: usize = 65_536; #[derive(Default)] struct BeobachtenInner { @@ -48,12 +49,23 @@ impl BeobachtenStore { Self::cleanup_if_needed(&mut guard, now, ttl); let key = (class.to_string(), ip); - let entry = guard.entries.entry(key).or_insert(BeobachtenEntry { - tries: 0, - last_seen: now, - }); - entry.tries = entry.tries.saturating_add(1); - entry.last_seen = now; + if let Some(entry) = guard.entries.get_mut(&key) { + entry.tries = entry.tries.saturating_add(1); + entry.last_seen = now; + return; + } + + if guard.entries.len() >= MAX_BEOBACHTEN_ENTRIES { + return; + } + + guard.entries.insert( + key, + BeobachtenEntry { + tries: 1, + last_seen: now, + }, + ); } pub fn snapshot_text(&self, ttl: Duration) -> String { From 27b5d576c0fbf713a8f5db02b8337d902e1cc958 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 25 Apr 2026 12:16:26 +0300 Subject: [PATCH 4/8] Bound hot-path pressure in ME Relay + Handshake Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com> --- src/proxy/handshake.rs | 32 ++++++- src/proxy/middle_relay.rs | 75 ++++++++++++++- src/proxy/tests/handshake_security_tests.rs | 91 +++++++++++++++++++ ...le_relay_stub_completion_security_tests.rs | 12 ++- 4 files changed, 202 insertions(+), 8 deletions(-) diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index cdfd844..f719349 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -55,6 +55,7 @@ const STICKY_HINT_MAX_ENTRIES: usize = 65_536; const CANDIDATE_HINT_TRACK_CAP: usize = 64; const OVERLOAD_CANDIDATE_BUDGET_HINTED: usize = 16; const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8; +const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64; const RECENT_USER_RING_SCAN_LIMIT: usize = 32; type HmacSha256 = Hmac; @@ -551,6 +552,19 @@ fn auth_probe_note_saturation_in(shared: &ProxySharedState, now: Instant) { } } +fn auth_probe_note_expensive_invalid_scan_in( + shared: &ProxySharedState, + now: Instant, + validation_checks: usize, + overload: bool, +) { + if overload || validation_checks < EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD { + return; + } + + auth_probe_note_saturation_in(shared, now); +} + fn auth_probe_record_failure_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = &shared.handshake.auth_probe; @@ -1378,7 +1392,14 @@ where } if !matched { - auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + let failure_now = Instant::now(); + auth_probe_note_expensive_invalid_scan_in( + shared, + failure_now, + validation_checks, + overload, + ); + auth_probe_record_failure_in(shared, peer.ip(), failure_now); maybe_apply_server_hello_delay(config).await; debug!( peer = %peer, @@ -1753,7 +1774,14 @@ where } if !matched { - auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + let failure_now = Instant::now(); + auth_probe_note_expensive_invalid_scan_in( + shared, + failure_now, + validation_checks, + overload, + ); + auth_probe_record_failure_in(shared, peer.ip(), failure_now); maybe_apply_server_hello_delay(config).await; debug!( peer = %peer, diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index f1f6584..b0ddb8f 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot, watch}; use tokio::time::timeout; use tracing::{debug, info, trace, warn}; @@ -36,7 +36,11 @@ use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; enum C2MeCommand { - Data { payload: PooledBuffer, flags: u32 }, + Data { + payload: PooledBuffer, + flags: u32, + _permit: OwnedSemaphorePermit, + }, Close, } @@ -47,6 +51,8 @@ const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; +const C2ME_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024; +const C2ME_QUEUED_PERMITS_PER_SLOT: usize = 4; const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); const TINY_FRAME_DEBT_PER_TINY: u32 = 8; const TINY_FRAME_DEBT_LIMIT: u32 = 512; @@ -571,6 +577,43 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET } +fn c2me_payload_permits(payload_len: usize) -> u32 { + payload_len + .max(1) + .div_ceil(C2ME_QUEUED_BYTE_PERMIT_UNIT) + .min(u32::MAX as usize) as u32 +} + +fn c2me_queued_permit_budget(channel_capacity: usize, frame_limit: usize) -> usize { + channel_capacity + .saturating_mul(C2ME_QUEUED_PERMITS_PER_SLOT) + .max(c2me_payload_permits(frame_limit) as usize) + .max(1) +} + +async fn acquire_c2me_payload_permit( + semaphore: &Arc, + payload_len: usize, + send_timeout: Option, + stats: &Stats, +) -> Result { + let permits = c2me_payload_permits(payload_len); + let acquire = semaphore.clone().acquire_many_owned(permits); + match send_timeout { + Some(send_timeout) => match timeout(send_timeout, acquire).await { + Ok(Ok(permit)) => Ok(permit), + Ok(Err(_)) => Err(ProxyError::Proxy("ME sender byte budget closed".into())), + Err(_) => { + stats.increment_me_c2me_send_timeout_total(); + Err(ProxyError::Proxy("ME sender byte budget timeout".into())) + } + }, + None => acquire + .await + .map_err(|_| ProxyError::Proxy("ME sender byte budget closed".into())), + } +} + fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { limit.saturating_add(overshoot) } @@ -1122,13 +1165,19 @@ where 0 => None, timeout_ms => Some(Duration::from_millis(timeout_ms)), }; + let c2me_byte_budget = c2me_queued_permit_budget(c2me_channel_capacity, frame_limit); + let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget)); let (c2me_tx, mut c2me_rx) = mpsc::channel::(c2me_channel_capacity); let me_pool_c2me = me_pool.clone(); let c2me_sender = tokio::spawn(async move { let mut sent_since_yield = 0usize; while let Some(cmd) = c2me_rx.recv().await { match cmd { - C2MeCommand::Data { payload, flags } => { + C2MeCommand::Data { + payload, + flags, + _permit, + } => { me_pool_c2me .send_proxy_req( conn_id, @@ -1624,11 +1673,29 @@ where if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) { flags |= RPC_FLAG_NOT_ENCRYPTED; } + let payload_permit = match acquire_c2me_payload_permit( + &c2me_byte_semaphore, + payload.len(), + c2me_send_timeout, + stats.as_ref(), + ) + .await + { + Ok(permit) => permit, + Err(e) => { + main_result = Err(e); + break; + } + }; // Keep client read loop lightweight: route heavy ME send path via a dedicated task. if enqueue_c2me_command_in( shared.as_ref(), &c2me_tx, - C2MeCommand::Data { payload, flags }, + C2MeCommand::Data { + payload, + flags, + _permit: payload_permit, + }, c2me_send_timeout, stats.as_ref(), ) diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs index df91cac..dd7ad08 100644 --- a/src/proxy/tests/handshake_security_tests.rs +++ b/src/proxy/tests/handshake_security_tests.rs @@ -1252,6 +1252,97 @@ async fn tls_overload_budget_limits_candidate_scan_depth() { ); } +#[tokio::test] +async fn tls_expensive_invalid_scan_activates_saturation_budget() { + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.ignore_time_skew = true; + for idx in 0..80u8 { + config.access.users.insert( + format!("user-{idx}"), + format!("{:032x}", u128::from(idx) + 1), + ); + } + config.rebuild_runtime_user_auth().unwrap(); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let shared = ProxySharedState::new(); + let attacker_secret = [0xEFu8; 16]; + let handshake = make_valid_tls_handshake(&attacker_secret, 0); + + let first_peer: SocketAddr = "198.51.100.214:44326".parse().unwrap(); + let first = handle_tls_handshake_with_shared( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + first_peer, + &config, + &replay_checker, + &rng, + None, + shared.as_ref(), + ) + .await; + + assert!(matches!(first, HandshakeResult::BadClient { .. })); + assert!( + auth_probe_saturation_state_for_testing_in_shared(shared.as_ref()) + .lock() + .unwrap() + .is_some(), + "expensive invalid scan must activate global saturation" + ); + assert_eq!( + shared + .handshake + .auth_expensive_checks_total + .load(Ordering::Relaxed), + 80, + "first invalid probe preserves full first-hit compatibility before enabling saturation" + ); + + { + let mut saturation = auth_probe_saturation_state_for_testing_in_shared(shared.as_ref()) + .lock() + .unwrap(); + let state = saturation.as_mut().expect("saturation must be present"); + state.blocked_until = Instant::now() + Duration::from_millis(200); + } + + let second_peer: SocketAddr = "198.51.100.215:44326".parse().unwrap(); + let second = handle_tls_handshake_with_shared( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + second_peer, + &config, + &replay_checker, + &rng, + None, + shared.as_ref(), + ) + .await; + + assert!(matches!(second, HandshakeResult::BadClient { .. })); + assert_eq!( + shared + .handshake + .auth_budget_exhausted_total + .load(Ordering::Relaxed), + 1, + "second invalid probe must be capped by overload budget" + ); + assert_eq!( + shared + .handshake + .auth_expensive_checks_total + .load(Ordering::Relaxed), + 80 + OVERLOAD_CANDIDATE_BUDGET_UNHINTED as u64, + "saturation budget must bound follow-up invalid scans" + ); +} + #[tokio::test] async fn mtproto_runtime_snapshot_prefers_preferred_user_hint() { let mut config = ProxyConfig::default(); diff --git a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs index 54eb784..6d398c8 100644 --- a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs +++ b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs @@ -12,6 +12,12 @@ fn make_pooled_payload(data: &[u8]) -> PooledBuffer { payload } +fn make_c2me_permit() -> tokio::sync::OwnedSemaphorePermit { + Arc::new(tokio::sync::Semaphore::new(1)) + .try_acquire_many_owned(1) + .expect("test permit must be available") +} + #[test] #[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"] fn should_emit_full_desync_filters_duplicates() { @@ -107,6 +113,7 @@ async fn c2me_channel_full_path_yields_then_sends() { tx.send(C2MeCommand::Data { payload: make_pooled_payload(&[0xAA]), flags: 1, + _permit: make_c2me_permit(), }) .await .expect("priming queue with one frame must succeed"); @@ -119,6 +126,7 @@ async fn c2me_channel_full_path_yields_then_sends() { C2MeCommand::Data { payload: make_pooled_payload(&[0xBB, 0xCC]), flags: 2, + _permit: make_c2me_permit(), }, None, &stats, @@ -138,7 +146,7 @@ async fn c2me_channel_full_path_yields_then_sends() { .expect("receiver should observe primed frame") .expect("first queued command must exist"); match first { - C2MeCommand::Data { payload, flags } => { + C2MeCommand::Data { payload, flags, .. } => { assert_eq!(payload.as_ref(), &[0xAA]); assert_eq!(flags, 1); } @@ -155,7 +163,7 @@ async fn c2me_channel_full_path_yields_then_sends() { .expect("receiver should observe backpressure-resumed frame") .expect("second queued command must exist"); match second { - C2MeCommand::Data { payload, flags } => { + C2MeCommand::Data { payload, flags, .. } => { assert_eq!(payload.as_ref(), &[0xBB, 0xCC]); assert_eq!(flags, 2); } From 1df668144c392052c737c665675c9b9f1612a8ba Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 25 Apr 2026 13:09:10 +0300 Subject: [PATCH 5/8] Bounded ME Route fairness and IP-Cleanup-Backlog Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com> --- .../tests/load_memory_envelope_tests.rs | 14 +++++++++++ src/ip_tracker.rs | 25 +++++++++++++------ src/tests/ip_tracker_regression_tests.rs | 19 ++++++++++++++ 3 files changed, 50 insertions(+), 8 deletions(-) diff --git a/src/config/tests/load_memory_envelope_tests.rs b/src/config/tests/load_memory_envelope_tests.rs index ea78498..1c201cc 100644 --- a/src/config/tests/load_memory_envelope_tests.rs +++ b/src/config/tests/load_memory_envelope_tests.rs @@ -17,6 +17,20 @@ fn remove_temp_config(path: &PathBuf) { let _ = fs::remove_file(path); } +#[test] +fn defaults_enable_byte_bounded_route_fairness() { + let cfg = ProxyConfig::default(); + + assert!( + cfg.general.me_route_fairshare_enabled, + "D2C route fairness must be enabled by default to bound queued bytes" + ); + assert!( + cfg.general.me_route_backpressure_enabled, + "D2C route backpressure must be enabled by default to shed under sustained pressure" + ); +} + #[test] fn load_rejects_writer_cmd_capacity_above_upper_bound() { let path = write_temp_config( diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index b4d934f..e3993f1 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -22,7 +22,7 @@ pub struct UserIpTracker { limit_mode: Arc>, limit_window: Arc>, last_compact_epoch_secs: Arc, - cleanup_queue: Arc>>, + cleanup_queue: Arc>>, cleanup_drain_lock: Arc>, } @@ -45,17 +45,21 @@ impl UserIpTracker { limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)), limit_window: Arc::new(RwLock::new(Duration::from_secs(30))), last_compact_epoch_secs: Arc::new(AtomicU64::new(0)), - cleanup_queue: Arc::new(Mutex::new(Vec::new())), + cleanup_queue: Arc::new(Mutex::new(HashMap::new())), cleanup_drain_lock: Arc::new(AsyncMutex::new(())), } } pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) { match self.cleanup_queue.lock() { - Ok(mut queue) => queue.push((user, ip)), + Ok(mut queue) => { + let count = queue.entry((user, ip)).or_insert(0); + *count = count.saturating_add(1); + } Err(poisoned) => { let mut queue = poisoned.into_inner(); - queue.push((user.clone(), ip)); + let count = queue.entry((user.clone(), ip)).or_insert(0); + *count = count.saturating_add(1); self.cleanup_queue.clear_poison(); tracing::warn!( "UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})", @@ -75,7 +79,9 @@ impl UserIpTracker { } #[cfg(test)] - pub(crate) fn cleanup_queue_mutex_for_tests(&self) -> Arc>> { + pub(crate) fn cleanup_queue_mutex_for_tests( + &self, + ) -> Arc>> { Arc::clone(&self.cleanup_queue) } @@ -105,11 +111,14 @@ impl UserIpTracker { }; let mut active_ips = self.active_ips.write().await; - for (user, ip) in to_remove { + for ((user, ip), pending_count) in to_remove { + if pending_count == 0 { + continue; + } if let Some(user_ips) = active_ips.get_mut(&user) { if let Some(count) = user_ips.get_mut(&ip) { - if *count > 1 { - *count -= 1; + if *count > pending_count { + *count -= pending_count; } else { user_ips.remove(&ip); } diff --git a/src/tests/ip_tracker_regression_tests.rs b/src/tests/ip_tracker_regression_tests.rs index 0e6656e..193c9c3 100644 --- a/src/tests/ip_tracker_regression_tests.rs +++ b/src/tests/ip_tracker_regression_tests.rs @@ -649,6 +649,25 @@ async fn duplicate_cleanup_entries_do_not_break_future_admission() { ); } +#[tokio::test] +async fn duplicate_cleanup_entries_are_coalesced_until_drain() { + let tracker = UserIpTracker::new(); + let ip = ip_from_idx(7150); + + tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip); + tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip); + tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip); + + assert_eq!( + tracker.cleanup_queue_len_for_tests(), + 1, + "duplicate queued cleanup entries must retain one allocation slot" + ); + + tracker.drain_cleanup_queue().await; + assert_eq!(tracker.cleanup_queue_len_for_tests(), 0); +} + #[tokio::test] async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() { let tracker = UserIpTracker::new(); From 2f2fe9d5d3da7f07a7c6c41e362b47e1b70168e4 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 25 Apr 2026 13:54:20 +0300 Subject: [PATCH 6/8] Bound relay queues by bytes Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com> --- .../tests/load_memory_envelope_tests.rs | 14 -- src/proxy/middle_relay.rs | 2 +- ...ddle_relay_atomic_quota_invariant_tests.rs | 2 + src/transport/middle_proxy/mod.rs | 26 ++- src/transport/middle_proxy/reader.rs | 3 +- src/transport/middle_proxy/registry.rs | 190 +++++++++++++++++- 6 files changed, 211 insertions(+), 26 deletions(-) diff --git a/src/config/tests/load_memory_envelope_tests.rs b/src/config/tests/load_memory_envelope_tests.rs index 1c201cc..ea78498 100644 --- a/src/config/tests/load_memory_envelope_tests.rs +++ b/src/config/tests/load_memory_envelope_tests.rs @@ -17,20 +17,6 @@ fn remove_temp_config(path: &PathBuf) { let _ = fs::remove_file(path); } -#[test] -fn defaults_enable_byte_bounded_route_fairness() { - let cfg = ProxyConfig::default(); - - assert!( - cfg.general.me_route_fairshare_enabled, - "D2C route fairness must be enabled by default to bound queued bytes" - ); - assert!( - cfg.general.me_route_backpressure_enabled, - "D2C route backpressure must be enabled by default to shed under sustained pressure" - ); -} - #[test] fn load_rejects_writer_cmd_capacity_above_upper_bound() { let path = write_temp_config( diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index b0ddb8f..e4b4fe6 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -2329,7 +2329,7 @@ where W: AsyncWrite + Unpin + Send + 'static, { match response { - MeResponse::Data { flags, data } => { + MeResponse::Data { flags, data, .. } => { if batched { trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)"); } else { diff --git a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs index 7c176bc..18bd583 100644 --- a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs +++ b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs @@ -70,6 +70,7 @@ async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { MeResponse::Data { flags: 0, data: payload.clone(), + route_permit: None, }, &mut writer, ProtoTag::Intermediate, @@ -139,6 +140,7 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() { MeResponse::Data { flags: 0, data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]), + route_permit: None, }, &mut writer, ProtoTag::Intermediate, diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 992fec3..3f46a80 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -46,6 +46,7 @@ mod send_adversarial_tests; mod wire; use bytes::Bytes; +use tokio::sync::OwnedSemaphorePermit; #[allow(unused_imports)] pub use config_updater::{ @@ -68,9 +69,32 @@ pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream}; pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots}; pub use wire::proto_flags_for_tag; +/// Holds D2C queued-byte capacity until a routed payload is consumed or dropped. +pub struct RouteBytePermit { + _permit: OwnedSemaphorePermit, +} + +impl std::fmt::Debug for RouteBytePermit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RouteBytePermit").finish_non_exhaustive() + } +} + +impl RouteBytePermit { + pub(crate) fn new(permit: OwnedSemaphorePermit) -> Self { + Self { _permit: permit } + } +} + +/// Response routed from middle proxy readers to client relay tasks. #[derive(Debug)] pub enum MeResponse { - Data { flags: u32, data: Bytes }, + /// Downstream payload with its queued-byte reservation. + Data { + flags: u32, + data: Bytes, + route_permit: Option, + }, Ack(u32), Close, } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 97fa329..2dae1f1 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -84,6 +84,7 @@ async fn route_data_with_retry( MeResponse::Data { flags, data: data.clone(), + route_permit: None, }, timeout_ms, ) @@ -639,7 +640,7 @@ mod tests { let routed = route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await; assert!(matches!(routed, RouteResult::Routed)); match rx.recv().await { - Some(MeResponse::Data { flags, data }) => { + Some(MeResponse::Data { flags, data, .. }) => { assert_eq!(flags, 0); assert_eq!(data, Bytes::from_static(b"a")); } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 0c7a0a9..ee2598d 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -1,18 +1,22 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; +use std::sync::Arc; use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use dashmap::DashMap; use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::{Mutex, mpsc}; +use tokio::sync::{Mutex, Semaphore, mpsc}; -use super::MeResponse; +use super::{MeResponse, RouteBytePermit}; use super::codec::WriterCommand; const ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS: u64 = 25; const ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS: u64 = 120; const ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT: u8 = 80; +const ROUTE_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024; +const ROUTE_QUEUED_PERMITS_PER_SLOT: usize = 4; +const ROUTE_QUEUED_MAX_FRAME_PERMITS: usize = 1024; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteResult { @@ -53,6 +57,7 @@ pub(super) struct WriterActivitySnapshot { struct RoutingTable { map: DashMap>, + byte_budget: DashMap>, } struct WriterTable { @@ -105,6 +110,7 @@ pub struct ConnRegistry { route_backpressure_base_timeout_ms: AtomicU64, route_backpressure_high_timeout_ms: AtomicU64, route_backpressure_high_watermark_pct: AtomicU8, + route_byte_permits_per_conn: usize, } impl ConnRegistry { @@ -116,10 +122,20 @@ impl ConnRegistry { } pub fn with_route_channel_capacity(route_channel_capacity: usize) -> Self { + let route_channel_capacity = route_channel_capacity.max(1); + Self::with_route_limits( + route_channel_capacity, + Self::route_byte_permit_budget(route_channel_capacity), + ) + } + + fn with_route_limits(route_channel_capacity: usize, route_byte_permits_per_conn: usize) -> Self { let start = rand::random::() | 1; + let route_channel_capacity = route_channel_capacity.max(1); Self { routing: RoutingTable { map: DashMap::new(), + byte_budget: DashMap::new(), }, writers: WriterTable { map: DashMap::new(), @@ -131,15 +147,30 @@ impl ConnRegistry { inner: Mutex::new(BindingInner::new()), }, next_id: AtomicU64::new(start), - route_channel_capacity: route_channel_capacity.max(1), + route_channel_capacity, route_backpressure_base_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS), route_backpressure_high_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS), route_backpressure_high_watermark_pct: AtomicU8::new( ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT, ), + route_byte_permits_per_conn: route_byte_permits_per_conn.max(1), } } + fn route_data_permits(data_len: usize) -> u32 { + data_len + .max(1) + .div_ceil(ROUTE_QUEUED_BYTE_PERMIT_UNIT) + .min(u32::MAX as usize) as u32 + } + + fn route_byte_permit_budget(route_channel_capacity: usize) -> usize { + route_channel_capacity + .saturating_mul(ROUTE_QUEUED_PERMITS_PER_SLOT) + .max(ROUTE_QUEUED_MAX_FRAME_PERMITS) + .max(1) + } + pub fn route_channel_capacity(&self) -> usize { self.route_channel_capacity } @@ -149,6 +180,14 @@ impl ConnRegistry { Self::with_route_channel_capacity(4096) } + #[cfg(test)] + fn with_route_byte_permits_for_tests( + route_channel_capacity: usize, + route_byte_permits_per_conn: usize, + ) -> Self { + Self::with_route_limits(route_channel_capacity, route_byte_permits_per_conn) + } + pub fn update_route_backpressure_policy( &self, base_timeout_ms: u64, @@ -170,6 +209,9 @@ impl ConnRegistry { let id = self.next_id.fetch_add(1, Ordering::Relaxed); let (tx, rx) = mpsc::channel(self.route_channel_capacity); self.routing.map.insert(id, tx); + self.routing + .byte_budget + .insert(id, Arc::new(Semaphore::new(self.route_byte_permits_per_conn))); (id, rx) } @@ -186,6 +228,7 @@ impl ConnRegistry { /// Unregister connection, returning associated writer_id if any. pub async fn unregister(&self, id: u64) -> Option { self.routing.map.remove(&id); + self.routing.byte_budget.remove(&id); self.hot_binding.map.remove(&id); let mut binding = self.binding.inner.lock().await; binding.meta.remove(&id); @@ -206,6 +249,65 @@ impl ConnRegistry { None } + async fn attach_route_byte_permit( + &self, + id: u64, + resp: MeResponse, + timeout_ms: Option, + ) -> std::result::Result { + let MeResponse::Data { + flags, + data, + route_permit, + } = resp + else { + return Ok(resp); + }; + + if route_permit.is_some() { + return Ok(MeResponse::Data { + flags, + data, + route_permit, + }); + } + + let Some(semaphore) = self + .routing + .byte_budget + .get(&id) + .map(|entry| entry.value().clone()) + else { + return Err(RouteResult::NoConn); + }; + let permits = Self::route_data_permits(data.len()); + let permit = match timeout_ms { + Some(0) => semaphore + .try_acquire_many_owned(permits) + .map_err(|_| RouteResult::QueueFullHigh)?, + Some(timeout_ms) => { + let acquire = semaphore.acquire_many_owned(permits); + match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire) + .await + { + Ok(Ok(permit)) => permit, + Ok(Err(_)) => return Err(RouteResult::ChannelClosed), + Err(_) => return Err(RouteResult::QueueFullHigh), + } + } + None => semaphore + .acquire_many_owned(permits) + .await + .map_err(|_| RouteResult::ChannelClosed)?, + }; + + Ok(MeResponse::Data { + flags, + data, + route_permit: Some(RouteBytePermit::new(permit)), + }) + } + #[allow(dead_code)] pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); @@ -214,15 +316,23 @@ impl ConnRegistry { return RouteResult::NoConn; }; + let base_timeout_ms = self + .route_backpressure_base_timeout_ms + .load(Ordering::Relaxed) + .max(1); + let resp = match self + .attach_route_byte_permit(id, resp, Some(base_timeout_ms)) + .await + { + Ok(resp) => resp, + Err(result) => return result, + }; + match tx.try_send(resp) { Ok(()) => RouteResult::Routed, Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, Err(TrySendError::Full(resp)) => { // Absorb short bursts without dropping/closing the session immediately. - let base_timeout_ms = self - .route_backpressure_base_timeout_ms - .load(Ordering::Relaxed) - .max(1); let high_timeout_ms = self .route_backpressure_high_timeout_ms .load(Ordering::Relaxed) @@ -266,6 +376,10 @@ impl ConnRegistry { let Some(tx) = tx else { return RouteResult::NoConn; }; + let resp = match self.attach_route_byte_permit(id, resp, Some(0)).await { + Ok(resp) => resp, + Err(result) => return result, + }; match tx.try_send(resp) { Ok(()) => RouteResult::Routed, @@ -289,6 +403,13 @@ impl ConnRegistry { let Some(tx) = tx else { return RouteResult::NoConn; }; + let resp = match self + .attach_route_byte_permit(id, resp, Some(timeout_ms)) + .await + { + Ok(resp) => resp, + Err(result) => return result, + }; match tx.try_send(resp) { Ok(()) => RouteResult::Routed, @@ -541,8 +662,10 @@ impl ConnRegistry { mod tests { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - use super::ConnMeta; - use super::ConnRegistry; + use bytes::Bytes; + + use super::{ConnMeta, ConnRegistry, RouteResult}; + use crate::transport::middle_proxy::MeResponse; #[tokio::test] async fn writer_activity_snapshot_tracks_writer_and_dc_load() { @@ -608,6 +731,55 @@ mod tests { assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1)); } + #[tokio::test] + async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() { + let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1); + let (conn_id, mut rx) = registry.register().await; + let routed = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA]), + route_permit: None, + }, + ) + .await; + assert!(matches!(routed, RouteResult::Routed)); + + let blocked = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xBB]), + route_permit: None, + }, + ) + .await; + assert!( + matches!(blocked, RouteResult::QueueFullHigh), + "byte budget must reject data before count capacity is exhausted" + ); + + drop(rx.recv().await); + + let routed_after_drain = registry + .route_nowait( + conn_id, + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xCC]), + route_permit: None, + }, + ) + .await; + assert!( + matches!(routed_after_drain, RouteResult::Routed), + "receiving queued data must release byte permits" + ); + } + #[tokio::test] async fn bind_writer_rebinds_conn_atomically() { let registry = ConnRegistry::new(); From 37c916056a46a50abed2cb005efbc3a0c1e30ad1 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 25 Apr 2026 14:35:35 +0300 Subject: [PATCH 7/8] Rustfmt --- src/proxy/tests/client_security_tests.rs | 9 ++------- src/transport/middle_proxy/registry.rs | 17 ++++++++++------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 4505e17..abd4213 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -281,13 +281,8 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1); assert_eq!(stats.get_user_curr_connects(&user), 1); - let reservation = UserConnectionReservation::new( - stats.clone(), - ip_tracker.clone(), - user.clone(), - ip, - true, - ); + let reservation = + UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip, true); // Drop the reservation synchronously without any tokio::spawn/await yielding! drop(reservation); diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index ee2598d..aca6f9c 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -8,8 +8,8 @@ use dashmap::DashMap; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::{Mutex, Semaphore, mpsc}; -use super::{MeResponse, RouteBytePermit}; use super::codec::WriterCommand; +use super::{MeResponse, RouteBytePermit}; const ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS: u64 = 25; const ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS: u64 = 120; @@ -129,7 +129,10 @@ impl ConnRegistry { ) } - fn with_route_limits(route_channel_capacity: usize, route_byte_permits_per_conn: usize) -> Self { + fn with_route_limits( + route_channel_capacity: usize, + route_byte_permits_per_conn: usize, + ) -> Self { let start = rand::random::() | 1; let route_channel_capacity = route_channel_capacity.max(1); Self { @@ -209,9 +212,10 @@ impl ConnRegistry { let id = self.next_id.fetch_add(1, Ordering::Relaxed); let (tx, rx) = mpsc::channel(self.route_channel_capacity); self.routing.map.insert(id, tx); - self.routing - .byte_budget - .insert(id, Arc::new(Semaphore::new(self.route_byte_permits_per_conn))); + self.routing.byte_budget.insert( + id, + Arc::new(Semaphore::new(self.route_byte_permits_per_conn)), + ); (id, rx) } @@ -287,8 +291,7 @@ impl ConnRegistry { .map_err(|_| RouteResult::QueueFullHigh)?, Some(timeout_ms) => { let acquire = semaphore.acquire_many_owned(permits); - match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire) - .await + match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire).await { Ok(Ok(permit)) => permit, Ok(Err(_)) => return Err(RouteResult::ChannelClosed), From e217371dc8983df1d70a39a2d3c830a90948e774 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 25 Apr 2026 14:36:51 +0300 Subject: [PATCH 8/8] Bump --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8ede4c6..02bfcda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2791,7 +2791,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.4.6" +version = "3.4.7" dependencies = [ "aes", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 8983d48..a40ce33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.4.6" +version = "3.4.7" edition = "2024" [features]