From 3b717c75dae2a6e009260c422272102c83e39c8f Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:18:47 +0300 Subject: [PATCH] Memory Hard-bounds + Handshake Budget in Metrics + No mutable in hotpath ConnRegistry Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/load.rs | 32 +++++ .../tests/load_memory_envelope_tests.rs | 115 ++++++++++++++++ src/maestro/mod.rs | 1 + src/maestro/runtime_tasks.rs | 4 + src/metrics.rs | 128 ++++++++++++++++-- src/transport/middle_proxy/registry.rs | 83 +++++++----- 6 files changed, 317 insertions(+), 46 deletions(-) create mode 100644 src/config/tests/load_memory_envelope_tests.rs diff --git a/src/config/load.rs b/src/config/load.rs index 32f877d..f9e230c 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -17,6 +17,11 @@ use super::defaults::*; use super::types::*; const ACCESS_SECRET_BYTES: usize = 16; +const MAX_ME_WRITER_CMD_CHANNEL_CAPACITY: usize = 16_384; +const MAX_ME_ROUTE_CHANNEL_CAPACITY: usize = 8_192; +const MAX_ME_C2ME_CHANNEL_CAPACITY: usize = 8_192; +const MIN_MAX_CLIENT_FRAME_BYTES: usize = 4 * 1024; +const MAX_MAX_CLIENT_FRAME_BYTES: usize = 16 * 1024 * 1024; #[derive(Debug, Clone)] pub(crate) struct LoadedConfig { @@ -626,18 +631,41 @@ impl ProxyConfig { "general.me_writer_cmd_channel_capacity must be > 0".to_string(), )); } + if config.general.me_writer_cmd_channel_capacity > MAX_ME_WRITER_CMD_CHANNEL_CAPACITY { + return Err(ProxyError::Config(format!( + "general.me_writer_cmd_channel_capacity must be within [1, {MAX_ME_WRITER_CMD_CHANNEL_CAPACITY}]" + ))); + } if config.general.me_route_channel_capacity == 0 { return Err(ProxyError::Config( "general.me_route_channel_capacity must be > 0".to_string(), )); } + if config.general.me_route_channel_capacity > MAX_ME_ROUTE_CHANNEL_CAPACITY { + return Err(ProxyError::Config(format!( + "general.me_route_channel_capacity must be within [1, {MAX_ME_ROUTE_CHANNEL_CAPACITY}]" + ))); + } if config.general.me_c2me_channel_capacity == 0 { return Err(ProxyError::Config( "general.me_c2me_channel_capacity must be > 0".to_string(), )); } + if config.general.me_c2me_channel_capacity > MAX_ME_C2ME_CHANNEL_CAPACITY { + return Err(ProxyError::Config(format!( + "general.me_c2me_channel_capacity must be within [1, {MAX_ME_C2ME_CHANNEL_CAPACITY}]" + ))); + } + + if !(MIN_MAX_CLIENT_FRAME_BYTES..=MAX_MAX_CLIENT_FRAME_BYTES) + .contains(&config.general.max_client_frame) + { + return Err(ProxyError::Config(format!( + "general.max_client_frame must be within [{MIN_MAX_CLIENT_FRAME_BYTES}, {MAX_MAX_CLIENT_FRAME_BYTES}]" + ))); + } if config.general.me_c2me_send_timeout_ms > 60_000 { return Err(ProxyError::Config( @@ -1346,6 +1374,10 @@ mod load_mask_shape_security_tests; #[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"] mod load_mask_classifier_prefetch_timeout_security_tests; +#[cfg(test)] +#[path = "tests/load_memory_envelope_tests.rs"] +mod load_memory_envelope_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/src/config/tests/load_memory_envelope_tests.rs b/src/config/tests/load_memory_envelope_tests.rs new file mode 100644 index 0000000..b2d14fb --- /dev/null +++ b/src/config/tests/load_memory_envelope_tests.rs @@ -0,0 +1,115 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-load-memory-envelope-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_writer_cmd_capacity_above_upper_bound() { + let path = write_temp_config( + r#" +[general] +me_writer_cmd_channel_capacity = 16385 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("writer command capacity above hard cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("general.me_writer_cmd_channel_capacity must be within [1, 16384]"), + "error must explain writer command capacity hard cap, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_route_channel_capacity_above_upper_bound() { + let path = write_temp_config( + r#" +[general] +me_route_channel_capacity = 8193 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("route channel capacity above hard cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("general.me_route_channel_capacity must be within [1, 8192]"), + "error must explain route channel hard cap, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_c2me_channel_capacity_above_upper_bound() { + let path = write_temp_config( + r#" +[general] +me_c2me_channel_capacity = 8193 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("c2me channel capacity above hard cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("general.me_c2me_channel_capacity must be within [1, 8192]"), + "error must explain c2me channel hard cap, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_max_client_frame_above_upper_bound() { + let path = write_temp_config( + r#" +[general] +max_client_frame = 16777217 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("max_client_frame above hard cap must fail"); + let msg = err.to_string(); + assert!( + msg.contains("general.max_client_frame must be within [4096, 16777216]"), + "error must explain max_client_frame hard cap, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_memory_limits_at_hard_upper_bounds() { + let path = write_temp_config( + r#" +[general] +me_writer_cmd_channel_capacity = 16384 +me_route_channel_capacity = 8192 +me_c2me_channel_capacity = 8192 +max_client_frame = 16777216 +"#, + ); + + let cfg = ProxyConfig::load(&path).expect("hard upper bound values must be accepted"); + assert_eq!(cfg.general.me_writer_cmd_channel_capacity, 16384); + assert_eq!(cfg.general.me_route_channel_capacity, 8192); + assert_eq!(cfg.general.me_c2me_channel_capacity, 8192); + assert_eq!(cfg.general.max_client_frame, 16 * 1024 * 1024); + + remove_temp_config(&path); +} diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index eed8d2e..00b3b2d 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -786,6 +786,7 @@ async fn run_inner( &startup_tracker, stats.clone(), beobachten.clone(), + shared_state.clone(), ip_tracker.clone(), config_rx.clone(), ) diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index b8b10da..5b3f2e0 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -13,6 +13,7 @@ use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::metrics; use crate::network::probe::NetworkProbe; +use crate::proxy::shared_state::ProxySharedState; use crate::startup::{ COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, StartupTracker, @@ -287,6 +288,7 @@ pub(crate) async fn spawn_metrics_if_configured( startup_tracker: &Arc, stats: Arc, beobachten: Arc, + shared_state: Arc, ip_tracker: Arc, config_rx: watch::Receiver>, ) { @@ -320,6 +322,7 @@ pub(crate) async fn spawn_metrics_if_configured( .await; let stats = stats.clone(); let beobachten = beobachten.clone(); + let shared_state = shared_state.clone(); let config_rx_metrics = config_rx.clone(); let ip_tracker_metrics = ip_tracker.clone(); let whitelist = config.server.metrics_whitelist.clone(); @@ -331,6 +334,7 @@ pub(crate) async fn spawn_metrics_if_configured( listen_backlog, stats, beobachten, + shared_state, ip_tracker_metrics, config_rx_metrics, whitelist, diff --git a/src/metrics.rs b/src/metrics.rs index 5cb1e77..7130e28 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -15,6 +15,7 @@ use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::ip_tracker::UserIpTracker; +use crate::proxy::shared_state::ProxySharedState; use crate::stats::Stats; use crate::stats::beobachten::BeobachtenStore; use crate::transport::{ListenOptions, create_listener}; @@ -25,6 +26,7 @@ pub async fn serve( listen_backlog: u32, stats: Arc, beobachten: Arc, + shared_state: Arc, ip_tracker: Arc, config_rx: tokio::sync::watch::Receiver>, whitelist: Vec, @@ -45,7 +47,13 @@ pub async fn serve( Ok(listener) => { info!("Metrics endpoint: http://{}/metrics and /beobachten", addr); serve_listener( - listener, stats, beobachten, ip_tracker, config_rx, whitelist, + listener, + stats, + beobachten, + shared_state, + ip_tracker, + config_rx, + whitelist, ) .await; } @@ -94,13 +102,20 @@ pub async fn serve( } (Some(listener), None) | (None, Some(listener)) => { serve_listener( - listener, stats, beobachten, ip_tracker, config_rx, whitelist, + listener, + stats, + beobachten, + shared_state, + ip_tracker, + config_rx, + whitelist, ) .await; } (Some(listener4), Some(listener6)) => { let stats_v6 = stats.clone(); let beobachten_v6 = beobachten.clone(); + let shared_state_v6 = shared_state.clone(); let ip_tracker_v6 = ip_tracker.clone(); let config_rx_v6 = config_rx.clone(); let whitelist_v6 = whitelist.clone(); @@ -109,6 +124,7 @@ pub async fn serve( listener6, stats_v6, beobachten_v6, + shared_state_v6, ip_tracker_v6, config_rx_v6, whitelist_v6, @@ -116,7 +132,13 @@ pub async fn serve( .await; }); serve_listener( - listener4, stats, beobachten, ip_tracker, config_rx, whitelist, + listener4, + stats, + beobachten, + shared_state, + ip_tracker, + config_rx, + whitelist, ) .await; } @@ -142,6 +164,7 @@ async fn serve_listener( listener: TcpListener, stats: Arc, beobachten: Arc, + shared_state: Arc, ip_tracker: Arc, config_rx: tokio::sync::watch::Receiver>, whitelist: Arc>, @@ -162,15 +185,19 @@ async fn serve_listener( let stats = stats.clone(); let beobachten = beobachten.clone(); + let shared_state = shared_state.clone(); let ip_tracker = ip_tracker.clone(); let config_rx_conn = config_rx.clone(); tokio::spawn(async move { let svc = service_fn(move |req| { let stats = stats.clone(); let beobachten = beobachten.clone(); + let shared_state = shared_state.clone(); let ip_tracker = ip_tracker.clone(); let config = config_rx_conn.borrow().clone(); - async move { handle(req, &stats, &beobachten, &ip_tracker, &config).await } + async move { + handle(req, &stats, &beobachten, &shared_state, &ip_tracker, &config).await + } }); if let Err(e) = http1::Builder::new() .serve_connection(hyper_util::rt::TokioIo::new(stream), svc) @@ -186,11 +213,12 @@ async fn handle( req: Request, stats: &Stats, beobachten: &BeobachtenStore, + shared_state: &ProxySharedState, ip_tracker: &UserIpTracker, config: &ProxyConfig, ) -> Result>, Infallible> { if req.uri().path() == "/metrics" { - let body = render_metrics(stats, config, ip_tracker).await; + let body = render_metrics(stats, shared_state, config, ip_tracker).await; let resp = Response::builder() .status(StatusCode::OK) .header("content-type", "text/plain; version=0.0.4; charset=utf-8") @@ -225,7 +253,12 @@ fn render_beobachten(beobachten: &BeobachtenStore, config: &ProxyConfig) -> Stri beobachten.snapshot_text(ttl) } -async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIpTracker) -> String { +async fn render_metrics( + stats: &Stats, + shared_state: &ProxySharedState, + config: &ProxyConfig, + ip_tracker: &UserIpTracker, +) -> String { use std::fmt::Write; let mut out = String::with_capacity(4096); let telemetry = stats.telemetry_policy(); @@ -359,6 +392,42 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!( + out, + "# HELP telemt_auth_expensive_checks_total Expensive authentication candidate checks executed during handshake validation" + ); + let _ = writeln!(out, "# TYPE telemt_auth_expensive_checks_total counter"); + let _ = writeln!( + out, + "telemt_auth_expensive_checks_total {}", + if core_enabled { + shared_state + .handshake + .auth_expensive_checks_total + .load(std::sync::atomic::Ordering::Relaxed) + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_auth_budget_exhausted_total Handshake validations that hit authentication candidate budget limits" + ); + let _ = writeln!(out, "# TYPE telemt_auth_budget_exhausted_total counter"); + let _ = writeln!( + out, + "telemt_auth_budget_exhausted_total {}", + if core_enabled { + shared_state + .handshake + .auth_budget_exhausted_total + .load(std::sync::atomic::Ordering::Relaxed) + } else { + 0 + } + ); + let _ = writeln!( out, "# HELP telemt_accept_permit_timeout_total Accepted connections dropped due to permit wait timeout" @@ -2847,6 +2916,7 @@ mod tests { #[tokio::test] async fn test_render_metrics_format() { let stats = Arc::new(Stats::new()); + let shared_state = ProxySharedState::new(); let tracker = UserIpTracker::new(); let mut config = ProxyConfig::default(); config @@ -2858,6 +2928,14 @@ mod tests { stats.increment_connects_all(); stats.increment_connects_bad(); stats.increment_handshake_timeouts(); + shared_state + .handshake + .auth_expensive_checks_total + .fetch_add(9, std::sync::atomic::Ordering::Relaxed); + shared_state + .handshake + .auth_budget_exhausted_total + .fetch_add(2, std::sync::atomic::Ordering::Relaxed); stats.increment_upstream_connect_attempt_total(); stats.increment_upstream_connect_attempt_total(); stats.increment_upstream_connect_success_total(); @@ -2901,11 +2979,13 @@ mod tests { .await .unwrap(); - let output = render_metrics(&stats, &config, &tracker).await; + let output = render_metrics(&stats, shared_state.as_ref(), &config, &tracker).await; assert!(output.contains("telemt_connections_total 2")); assert!(output.contains("telemt_connections_bad_total 1")); assert!(output.contains("telemt_handshake_timeouts_total 1")); + assert!(output.contains("telemt_auth_expensive_checks_total 9")); + assert!(output.contains("telemt_auth_budget_exhausted_total 2")); assert!(output.contains("telemt_upstream_connect_attempt_total 2")); assert!(output.contains("telemt_upstream_connect_success_total 1")); assert!(output.contains("telemt_upstream_connect_fail_total 1")); @@ -2960,12 +3040,15 @@ mod tests { #[tokio::test] async fn test_render_empty_stats() { let stats = Stats::new(); + let shared_state = ProxySharedState::new(); let tracker = UserIpTracker::new(); let config = ProxyConfig::default(); - let output = render_metrics(&stats, &config, &tracker).await; + let output = render_metrics(&stats, &shared_state, &config, &tracker).await; assert!(output.contains("telemt_connections_total 0")); assert!(output.contains("telemt_connections_bad_total 0")); assert!(output.contains("telemt_handshake_timeouts_total 0")); + assert!(output.contains("telemt_auth_expensive_checks_total 0")); + assert!(output.contains("telemt_auth_budget_exhausted_total 0")); assert!(output.contains("telemt_user_unique_ips_current{user=")); assert!(output.contains("telemt_user_unique_ips_recent_window{user=")); } @@ -2973,6 +3056,7 @@ mod tests { #[tokio::test] async fn test_render_uses_global_each_unique_ip_limit() { let stats = Stats::new(); + let shared_state = ProxySharedState::new(); stats.increment_user_connects("alice"); stats.increment_user_curr_connects("alice"); let tracker = UserIpTracker::new(); @@ -2983,7 +3067,7 @@ mod tests { let mut config = ProxyConfig::default(); config.access.user_max_unique_ips_global_each = 2; - let output = render_metrics(&stats, &config, &tracker).await; + let output = render_metrics(&stats, &shared_state, &config, &tracker).await; assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 2")); assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.500000")); @@ -2992,13 +3076,16 @@ mod tests { #[tokio::test] async fn test_render_has_type_annotations() { let stats = Stats::new(); + let shared_state = ProxySharedState::new(); let tracker = UserIpTracker::new(); let config = ProxyConfig::default(); - let output = render_metrics(&stats, &config, &tracker).await; + let output = render_metrics(&stats, &shared_state, &config, &tracker).await; assert!(output.contains("# TYPE telemt_uptime_seconds gauge")); assert!(output.contains("# TYPE telemt_connections_total counter")); assert!(output.contains("# TYPE telemt_connections_bad_total counter")); assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter")); + assert!(output.contains("# TYPE telemt_auth_expensive_checks_total counter")); + assert!(output.contains("# TYPE telemt_auth_budget_exhausted_total counter")); assert!(output.contains("# TYPE telemt_upstream_connect_attempt_total counter")); assert!(output.contains("# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter")); assert!(output.contains("# TYPE telemt_me_idle_close_by_peer_total counter")); @@ -3035,6 +3122,7 @@ mod tests { async fn test_endpoint_integration() { let stats = Arc::new(Stats::new()); let beobachten = Arc::new(BeobachtenStore::new()); + let shared_state = ProxySharedState::new(); let tracker = UserIpTracker::new(); let mut config = ProxyConfig::default(); stats.increment_connects_all(); @@ -3042,7 +3130,7 @@ mod tests { stats.increment_connects_all(); let req = Request::builder().uri("/metrics").body(()).unwrap(); - let resp = handle(req, &stats, &beobachten, &tracker, &config) + let resp = handle(req, &stats, &beobachten, shared_state.as_ref(), &tracker, &config) .await .unwrap(); assert_eq!(resp.status(), StatusCode::OK); @@ -3061,7 +3149,14 @@ mod tests { Duration::from_secs(600), ); let req_beob = Request::builder().uri("/beobachten").body(()).unwrap(); - let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config) + let resp_beob = handle( + req_beob, + &stats, + &beobachten, + shared_state.as_ref(), + &tracker, + &config, + ) .await .unwrap(); assert_eq!(resp_beob.status(), StatusCode::OK); @@ -3071,7 +3166,14 @@ mod tests { assert!(beob_text.contains("203.0.113.10-1")); let req404 = Request::builder().uri("/other").body(()).unwrap(); - let resp404 = handle(req404, &stats, &beobachten, &tracker, &config) + let resp404 = handle( + req404, + &stats, + &beobachten, + shared_state.as_ref(), + &tracker, + &config, + ) .await .unwrap(); assert_eq!(resp404.status(), StatusCode::NOT_FOUND); diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index ff4a68b..17fce47 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -55,6 +55,20 @@ struct RoutingTable { map: DashMap>, } +struct WriterTable { + map: DashMap>, +} + +#[derive(Clone)] +struct HotConnBinding { + writer_id: u64, + meta: ConnMeta, +} + +struct HotBindingTable { + map: DashMap, +} + struct BindingState { inner: Mutex, } @@ -83,6 +97,8 @@ impl BindingInner { pub struct ConnRegistry { routing: RoutingTable, + writers: WriterTable, + hot_binding: HotBindingTable, binding: BindingState, next_id: AtomicU64, route_channel_capacity: usize, @@ -105,6 +121,12 @@ impl ConnRegistry { routing: RoutingTable { map: DashMap::new(), }, + writers: WriterTable { + map: DashMap::new(), + }, + hot_binding: HotBindingTable { + map: DashMap::new(), + }, binding: BindingState { inner: Mutex::new(BindingInner::new()), }, @@ -149,16 +171,18 @@ impl ConnRegistry { pub async fn register_writer(&self, writer_id: u64, tx: mpsc::Sender) { let mut binding = self.binding.inner.lock().await; - binding.writers.insert(writer_id, tx); + binding.writers.insert(writer_id, tx.clone()); binding .conns_for_writer .entry(writer_id) .or_insert_with(HashSet::new); + self.writers.map.insert(writer_id, tx); } /// Unregister connection, returning associated writer_id if any. pub async fn unregister(&self, id: u64) -> Option { self.routing.map.remove(&id); + self.hot_binding.map.remove(&id); let mut binding = self.binding.inner.lock().await; binding.meta.remove(&id); if let Some(writer_id) = binding.writer_for_conn.remove(&id) { @@ -325,13 +349,20 @@ impl ConnRegistry { } binding.meta.insert(conn_id, meta.clone()); - binding.last_meta_for_writer.insert(writer_id, meta); + binding.last_meta_for_writer.insert(writer_id, meta.clone()); binding.writer_idle_since_epoch_secs.remove(&writer_id); binding .conns_for_writer .entry(writer_id) .or_insert_with(HashSet::new) .insert(conn_id); + self.hot_binding.map.insert( + conn_id, + HotConnBinding { + writer_id, + meta, + }, + ); true } @@ -392,39 +423,12 @@ impl ConnRegistry { } pub async fn get_writer(&self, conn_id: u64) -> Option { - let mut binding = self.binding.inner.lock().await; - // ROUTING IS THE SOURCE OF TRUTH: - // stale bindings are ignored and lazily cleaned when routing no longer - // contains the connection. if !self.routing.map.contains_key(&conn_id) { - binding.meta.remove(&conn_id); - if let Some(stale_writer_id) = binding.writer_for_conn.remove(&conn_id) - && let Some(conns) = binding.conns_for_writer.get_mut(&stale_writer_id) - { - conns.remove(&conn_id); - if conns.is_empty() { - binding - .writer_idle_since_epoch_secs - .insert(stale_writer_id, Self::now_epoch_secs()); - } - } return None; } - let writer_id = binding.writer_for_conn.get(&conn_id).copied()?; - let Some(writer) = binding.writers.get(&writer_id).cloned() else { - binding.writer_for_conn.remove(&conn_id); - binding.meta.remove(&conn_id); - if let Some(conns) = binding.conns_for_writer.get_mut(&writer_id) { - conns.remove(&conn_id); - if conns.is_empty() { - binding - .writer_idle_since_epoch_secs - .insert(writer_id, Self::now_epoch_secs()); - } - } - return None; - }; + let writer_id = self.hot_binding.map.get(&conn_id).map(|entry| entry.writer_id)?; + let writer = self.writers.map.get(&writer_id).map(|entry| entry.value().clone())?; Some(ConnWriter { writer_id, tx: writer, @@ -439,6 +443,7 @@ impl ConnRegistry { pub async fn writer_lost(&self, writer_id: u64) -> Vec { let mut binding = self.binding.inner.lock().await; binding.writers.remove(&writer_id); + self.writers.map.remove(&writer_id); binding.last_meta_for_writer.remove(&writer_id); binding.writer_idle_since_epoch_secs.remove(&writer_id); let conns = binding @@ -454,6 +459,15 @@ impl ConnRegistry { continue; } binding.writer_for_conn.remove(&conn_id); + let remove_hot = self + .hot_binding + .map + .get(&conn_id) + .map(|hot| hot.writer_id == writer_id) + .unwrap_or(false); + if remove_hot { + self.hot_binding.map.remove(&conn_id); + } if let Some(m) = binding.meta.get(&conn_id) { out.push(BoundConn { conn_id, @@ -466,8 +480,10 @@ impl ConnRegistry { #[allow(dead_code)] pub async fn get_meta(&self, conn_id: u64) -> Option { - let binding = self.binding.inner.lock().await; - binding.meta.get(&conn_id).cloned() + self.hot_binding + .map + .get(&conn_id) + .map(|entry| entry.meta.clone()) } pub async fn is_writer_empty(&self, writer_id: u64) -> bool { @@ -491,6 +507,7 @@ impl ConnRegistry { } binding.writers.remove(&writer_id); + self.writers.map.remove(&writer_id); binding.last_meta_for_writer.remove(&writer_id); binding.writer_idle_since_epoch_secs.remove(&writer_id); binding.conns_for_writer.remove(&writer_id);