From 822bcbf7a50c5bc826e313699cf83a0b12113e42 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 17 Mar 2026 11:21:35 +0300 Subject: [PATCH 1/3] Update Cargo.toml --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a482ca4..4e12cad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.3.19" +version = "3.3.20" edition = "2024" [dependencies] From d78360982cff2962b1cccc7393c4d23716abb7b7 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:02:12 +0300 Subject: [PATCH 2/3] Hot-Reload fixes --- src/config/hot_reload.rs | 100 ++++++++++++++++------------------- src/maestro/helpers.rs | 48 +++++++++++++++++ src/maestro/mod.rs | 18 +++++-- src/maestro/runtime_tasks.rs | 6 +-- 4 files changed, 109 insertions(+), 63 deletions(-) diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 6f07a4b..d781f67 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -37,7 +37,6 @@ use crate::config::{ }; use super::load::{LoadedConfig, ProxyConfig}; -const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2; const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); // ── Hot fields ──────────────────────────────────────────────────────────────── @@ -329,41 +328,19 @@ impl WatchManifest { #[derive(Debug, Default)] struct ReloadState { applied_snapshot_hash: Option, - candidate_snapshot_hash: Option, - candidate_hits: u8, } impl ReloadState { fn new(applied_snapshot_hash: Option) -> Self { - Self { - applied_snapshot_hash, - candidate_snapshot_hash: None, - candidate_hits: 0, - } + Self { applied_snapshot_hash } } fn is_applied(&self, hash: u64) -> bool { self.applied_snapshot_hash == Some(hash) } - fn observe_candidate(&mut self, hash: u64) -> u8 { - if self.candidate_snapshot_hash == Some(hash) { - self.candidate_hits = self.candidate_hits.saturating_add(1); - } else { - self.candidate_snapshot_hash = Some(hash); - self.candidate_hits = 1; - } - self.candidate_hits - } - - fn reset_candidate(&mut self) { - self.candidate_snapshot_hash = None; - self.candidate_hits = 0; - } - fn mark_applied(&mut self, hash: u64) { self.applied_snapshot_hash = Some(hash); - self.reset_candidate(); } } @@ -1138,7 +1115,6 @@ fn reload_config( let loaded = match ProxyConfig::load_with_metadata(config_path) { Ok(loaded) => loaded, Err(e) => { - reload_state.reset_candidate(); error!("config reload: failed to parse {:?}: {}", config_path, e); return None; } @@ -1151,7 +1127,6 @@ fn reload_config( let next_manifest = WatchManifest::from_source_files(&source_files); if let Err(e) = new_cfg.validate() { - reload_state.reset_candidate(); error!("config reload: validation failed: {}; keeping old config", e); return Some(next_manifest); } @@ -1160,17 +1135,6 @@ fn reload_config( return Some(next_manifest); } - let candidate_hits = reload_state.observe_candidate(rendered_hash); - if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS { - info!( - snapshot_hash = rendered_hash, - candidate_hits, - required_hits = HOT_RELOAD_STABLE_SNAPSHOTS, - "config reload: candidate snapshot observed but not stable yet" - ); - return Some(next_manifest); - } - let old_cfg = config_tx.borrow().clone(); let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg); let old_hot = HotFields::from_config(&old_cfg); @@ -1190,7 +1154,6 @@ fn reload_config( if old_hot.dns_overrides != applied_hot.dns_overrides && let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides) { - reload_state.reset_candidate(); error!( "config reload: invalid network.dns_overrides: {}; keeping old config", e @@ -1334,14 +1297,28 @@ pub fn spawn_config_watcher( tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; while notify_rx.try_recv().is_ok() {} - if let Some(next_manifest) = reload_config( + let mut next_manifest = reload_config( &config_path, &config_tx, &log_tx, detected_ip_v4, detected_ip_v6, &mut reload_state, - ) { + ); + if next_manifest.is_none() { + tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; + while notify_rx.try_recv().is_ok() {} + next_manifest = reload_config( + &config_path, + &config_tx, + &log_tx, + detected_ip_v4, + detected_ip_v6, + &mut reload_state, + ); + } + + if let Some(next_manifest) = next_manifest { apply_watch_manifest( inotify_watcher.as_mut(), poll_watcher.as_mut(), @@ -1466,7 +1443,7 @@ mod tests { } #[test] - fn reload_requires_stable_snapshot_before_hot_apply() { + fn reload_applies_hot_change_on_first_observed_snapshot() { let initial_tag = "11111111111111111111111111111111"; let final_tag = "22222222222222222222222222222222"; let path = temp_config_path("telemt_hot_reload_stable"); @@ -1478,20 +1455,7 @@ mod tests { let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let mut reload_state = ReloadState::new(Some(initial_hash)); - write_reload_config(&path, None, None); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!( - config_tx.borrow().general.ad_tag.as_deref(), - Some(initial_tag) - ); - write_reload_config(&path, Some(final_tag), None); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!( - config_tx.borrow().general.ad_tag.as_deref(), - Some(initial_tag) - ); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); @@ -1513,7 +1477,6 @@ mod tests { write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1)); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); let applied = config_tx.borrow().clone(); assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag)); @@ -1521,4 +1484,31 @@ mod tests { let _ = std::fs::remove_file(path); } + + #[test] + fn reload_recovers_after_parse_error_on_next_attempt() { + let initial_tag = "cccccccccccccccccccccccccccccccc"; + let final_tag = "dddddddddddddddddddddddddddddddd"; + let path = temp_config_path("telemt_hot_reload_parse_recovery"); + + write_reload_config(&path, Some(initial_tag), None); + let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); + let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); + let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); + let mut reload_state = ReloadState::new(Some(initial_hash)); + + std::fs::write(&path, "[access.users\nuser = \"broken\"\n").unwrap(); + assert!(reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).is_none()); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(initial_tag) + ); + + write_reload_config(&path, Some(final_tag), None); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); + + let _ = std::fs::remove_file(path); + } } diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index 78f3ec4..029d0ee 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -10,6 +10,16 @@ use crate::transport::middle_proxy::{ ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, }; +pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf { + let raw = PathBuf::from(config_path_cli); + let absolute = if raw.is_absolute() { + raw + } else { + startup_cwd.join(raw) + }; + absolute.canonicalize().unwrap_or(absolute) +} + pub(crate) fn parse_cli() -> (String, Option, bool, Option) { let mut config_path = "config.toml".to_string(); let mut data_path: Option = None; @@ -96,6 +106,44 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { (config_path, data_path, silent, log_level) } +#[cfg(test)] +mod tests { + use super::resolve_runtime_config_path; + + #[test] + fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_{nonce}")); + std::fs::create_dir_all(&startup_cwd).unwrap(); + let target = startup_cwd.join("config.toml"); + std::fs::write(&target, " ").unwrap(); + + let resolved = resolve_runtime_config_path("config.toml", &startup_cwd); + assert_eq!(resolved, target.canonicalize().unwrap()); + + let _ = std::fs::remove_file(&target); + let _ = std::fs::remove_dir(&startup_cwd); + } + + #[test] + fn resolve_runtime_config_path_keeps_absolute_for_missing_file() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_missing_{nonce}")); + std::fs::create_dir_all(&startup_cwd).unwrap(); + + let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd); + assert_eq!(resolved, startup_cwd.join("missing.toml")); + + let _ = std::fs::remove_dir(&startup_cwd); + } +} + pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { info!(target: "telemt::links", "--- Proxy Links ({}) ---", host); for user_name in config.general.links.show.resolve_users(&config.access.users) { diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index da00b40..d4ce2e0 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -45,7 +45,7 @@ use crate::startup::{ use crate::stream::BufferPool; use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; -use helpers::parse_cli; +use helpers::{parse_cli, resolve_runtime_config_path}; /// Runs the full telemt runtime startup pipeline and blocks until shutdown. pub async fn run() -> std::result::Result<(), Box> { @@ -58,18 +58,26 @@ pub async fn run() -> std::result::Result<(), Box> { startup_tracker .start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string())) .await; - let (config_path, data_path, cli_silent, cli_log_level) = parse_cli(); + let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli(); + let startup_cwd = match std::env::current_dir() { + Ok(cwd) => cwd, + Err(e) => { + eprintln!("[telemt] Can't read current_dir: {}", e); + std::process::exit(1); + } + }; + let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd); let mut config = match ProxyConfig::load(&config_path) { Ok(c) => c, Err(e) => { - if std::path::Path::new(&config_path).exists() { + if config_path.exists() { eprintln!("[telemt] Error: {}", e); std::process::exit(1); } else { let default = ProxyConfig::default(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); - eprintln!("[telemt] Created default config at {}", config_path); + eprintln!("[telemt] Created default config at {}", config_path.display()); default } } @@ -258,7 +266,7 @@ pub async fn run() -> std::result::Result<(), Box> { let route_runtime_api = route_runtime.clone(); let config_rx_api = api_config_rx.clone(); let admission_rx_api = admission_rx.clone(); - let config_path_api = std::path::PathBuf::from(&config_path); + let config_path_api = config_path.clone(); let startup_tracker_api = startup_tracker.clone(); let detected_ips_rx_api = detected_ips_rx.clone(); tokio::spawn(async move { diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index 329e267..d4bda4d 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -1,5 +1,5 @@ use std::net::IpAddr; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use tokio::sync::{mpsc, watch}; @@ -32,7 +32,7 @@ pub(crate) struct RuntimeWatches { #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, - config_path: &str, + config_path: &Path, probe: &NetworkProbe, prefer_ipv6: bool, decision_ipv4_dc: bool, @@ -83,7 +83,7 @@ pub(crate) async fn spawn_runtime_tasks( watch::Receiver>, watch::Receiver, ) = spawn_config_watcher( - PathBuf::from(config_path), + config_path.to_path_buf(), config.clone(), detected_ip_v4, detected_ip_v6, From 2e8be87ccfd2287d01858e7e863131e52812b5b8 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:58:01 +0300 Subject: [PATCH 3/3] ME Writer Draining-state fixes --- src/transport/middle_proxy/health.rs | 235 +++++++++++++++++----- src/transport/middle_proxy/pool.rs | 28 +++ src/transport/middle_proxy/pool_reinit.rs | 52 ++++- src/transport/middle_proxy/pool_writer.rs | 2 + src/transport/middle_proxy/registry.rs | 44 ++++ 5 files changed, 309 insertions(+), 52 deletions(-) diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index e5f4260..a2e107d 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -115,59 +115,109 @@ async fn reap_draining_writers( pool: &Arc, warn_next_allowed: &mut HashMap, ) { + if pool.draining_active_runtime() == 0 { + return; + } + let now_epoch_secs = MePool::now_epoch_secs(); let now = Instant::now(); let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed); let drain_threshold = pool .me_pool_drain_threshold .load(std::sync::atomic::Ordering::Relaxed); - let writers = pool.writers.read().await.clone(); - let mut draining_writers = Vec::new(); - for writer in writers { - if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { - continue; + let mut draining_writers = { + let writers = pool.writers.read().await; + let mut draining_writers = Vec::::new(); + for writer in writers.iter() { + if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { + continue; + } + draining_writers.push(DrainingWriterSnapshot { + id: writer.id, + writer_dc: writer.writer_dc, + addr: writer.addr, + generation: writer.generation, + created_at: writer.created_at, + draining_started_at_epoch_secs: writer + .draining_started_at_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + drain_deadline_epoch_secs: writer + .drain_deadline_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback: writer + .allow_drain_fallback + .load(std::sync::atomic::Ordering::Relaxed), + }); } - let is_empty = pool.registry.is_writer_empty(writer.id).await; - if is_empty { - pool.remove_writer_and_close_clients(writer.id).await; - continue; - } - draining_writers.push(writer); + draining_writers + }; + + if draining_writers.is_empty() { + return; } - if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { - draining_writers.sort_by(|left, right| { - let left_started = left - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - let right_started = right - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - left_started - .cmp(&right_started) - .then_with(|| left.created_at.cmp(&right.created_at)) - .then_with(|| left.id.cmp(&right.id)) - }); - let overflow = draining_writers.len().saturating_sub(drain_threshold as usize); - warn!( - draining_writers = draining_writers.len(), - me_pool_drain_threshold = drain_threshold, - removing_writers = overflow, - "ME draining writer threshold exceeded, force-closing oldest draining writers" - ); - for writer in draining_writers.drain(..overflow) { - pool.stats.increment_pool_force_close_total(); + let draining_ids: Vec = draining_writers.iter().map(|writer| writer.id).collect(); + let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await; + let mut non_empty_draining_writers = + Vec::::with_capacity(draining_writers.len()); + for writer in draining_writers.drain(..) { + if non_empty_writer_ids.contains(&writer.id) { + non_empty_draining_writers.push(writer); + } else { pool.remove_writer_and_close_clients(writer.id).await; } } + draining_writers = non_empty_draining_writers; + if draining_writers.is_empty() { + return; + } + + let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { + draining_writers.len().saturating_sub(drain_threshold as usize) + } else { + 0 + }; + let has_deadline_expired = draining_writers.iter().any(|writer| { + writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs + }); + let can_drop_with_replacement = if overflow > 0 || has_deadline_expired { + pool.has_non_draining_writer_per_desired_dc_group().await + } else { + false + }; + + if overflow > 0 { + if can_drop_with_replacement { + draining_writers.sort_by(|left, right| { + left.draining_started_at_epoch_secs + .cmp(&right.draining_started_at_epoch_secs) + .then_with(|| left.created_at.cmp(&right.created_at)) + .then_with(|| left.id.cmp(&right.id)) + }); + warn!( + draining_writers = draining_writers.len(), + me_pool_drain_threshold = drain_threshold, + removing_writers = overflow, + "ME draining writer threshold exceeded, force-closing oldest draining writers" + ); + for writer in draining_writers.drain(..overflow) { + pool.stats.increment_pool_force_close_total(); + pool.remove_writer_and_close_clients(writer.id).await; + } + } else { + warn!( + draining_writers = draining_writers.len(), + me_pool_drain_threshold = drain_threshold, + overflow, + "ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers" + ); + } + } for writer in draining_writers { - let drain_started_at_epoch_secs = writer - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); if drain_ttl_secs > 0 - && drain_started_at_epoch_secs != 0 - && now_epoch_secs.saturating_sub(drain_started_at_epoch_secs) > drain_ttl_secs + && writer.draining_started_at_epoch_secs != 0 + && now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs && should_emit_writer_warn( warn_next_allowed, writer.id, @@ -182,21 +232,45 @@ async fn reap_draining_writers( generation = writer.generation, drain_ttl_secs, force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed), - allow_drain_fallback = writer.allow_drain_fallback.load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback = writer.allow_drain_fallback, "ME draining writer remains non-empty past drain TTL" ); } - let deadline_epoch_secs = writer - .drain_deadline_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs { - warn!(writer_id = writer.id, "Drain timeout, force-closing"); - pool.stats.increment_pool_force_close_total(); - pool.remove_writer_and_close_clients(writer.id).await; + if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs + { + if can_drop_with_replacement { + warn!(writer_id = writer.id, "Drain timeout, force-closing"); + pool.stats.increment_pool_force_close_total(); + pool.remove_writer_and_close_clients(writer.id).await; + } else if should_emit_writer_warn( + warn_next_allowed, + writer.id, + now, + pool.warn_rate_limit_duration(), + ) { + warn!( + writer_id = writer.id, + writer_dc = writer.writer_dc, + endpoint = %writer.addr, + "Drain timeout reached, but replacement coverage is incomplete; keeping draining writer" + ); + } } } } +#[derive(Debug, Clone)] +struct DrainingWriterSnapshot { + id: u64, + writer_dc: i32, + addr: SocketAddr, + generation: u64, + created_at: Instant, + draining_started_at_epoch_secs: u64, + drain_deadline_epoch_secs: u64, + allow_drain_fallback: bool, +} + fn should_emit_writer_warn( next_allowed: &mut HashMap, writer_id: u64, @@ -1330,6 +1404,15 @@ mod tests { me_pool_drain_threshold, ..GeneralConfig::default() }; + let mut proxy_map_v4 = HashMap::new(); + proxy_map_v4.insert( + 2, + vec![(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 443)], + ); + let decision = NetworkDecision { + ipv4_me: true, + ..NetworkDecision::default() + }; MePool::new( None, vec![1u8; 32], @@ -1341,10 +1424,10 @@ mod tests { None, 12, 1200, - HashMap::new(), + proxy_map_v4, HashMap::new(), None, - NetworkDecision::default(), + decision, None, Arc::new(SecureRandom::new()), Arc::new(Stats::default()), @@ -1438,6 +1521,7 @@ mod tests { pool.writers.write().await.push(writer); pool.registry.register_writer(writer_id, tx).await; pool.conn_count.fetch_add(1, Ordering::Relaxed); + pool.increment_draining_active_runtime(); assert!( pool.registry .bind_writer( @@ -1455,8 +1539,56 @@ mod tests { conn_id } + async fn insert_live_writer(pool: &Arc, writer_id: u64, writer_dc: i32) { + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (writer_id as u8).saturating_add(1))), + 4000 + writer_id as u16, + ), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc, + generation: 2, + contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())), + created_at: Instant::now(), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(false)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + } + #[tokio::test] async fn reap_draining_writers_force_closes_oldest_over_threshold() { + let pool = make_pool(2).await; + insert_live_writer(&pool, 1, 2).await; + assert!(pool.has_non_draining_writer_per_desired_dc_group().await); + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; + let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await; + let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + writer_ids.sort_unstable(); + assert_eq!(writer_ids, vec![1, 20, 30]); + assert!(pool.registry.get_writer(conn_a).await.is_none()); + assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); + assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30); + } + + #[tokio::test] + async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() { let pool = make_pool(2).await; let now_epoch_secs = MePool::now_epoch_secs(); let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; @@ -1466,9 +1598,10 @@ mod tests { reap_draining_writers(&pool, &mut warn_next_allowed).await; - let writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); - assert_eq!(writer_ids, vec![20, 30]); - assert!(pool.registry.get_writer(conn_a).await.is_none()); + let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + writer_ids.sort_unstable(); + assert_eq!(writer_ids, vec![10, 20, 30]); + assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10); assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30); } diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 2a65160..56f3fbf 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -160,6 +160,7 @@ pub struct MePool { pub(super) refill_inflight: Arc>>, pub(super) refill_inflight_dc: Arc>>, pub(super) conn_count: AtomicUsize, + pub(super) draining_active_runtime: AtomicU64, pub(super) stats: Arc, pub(super) generation: AtomicU64, pub(super) active_generation: AtomicU64, @@ -438,6 +439,7 @@ impl MePool { refill_inflight: Arc::new(Mutex::new(HashSet::new())), refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())), conn_count: AtomicUsize::new(0), + draining_active_runtime: AtomicU64::new(0), generation: AtomicU64::new(1), active_generation: AtomicU64::new(1), warm_generation: AtomicU64::new(0), @@ -690,6 +692,32 @@ impl MePool { } } + pub(super) fn draining_active_runtime(&self) -> u64 { + self.draining_active_runtime.load(Ordering::Relaxed) + } + + pub(super) fn increment_draining_active_runtime(&self) { + self.draining_active_runtime.fetch_add(1, Ordering::Relaxed); + } + + pub(super) fn decrement_draining_active_runtime(&self) { + let mut current = self.draining_active_runtime.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match self.draining_active_runtime.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + pub(super) async fn key_selector(&self) -> u32 { self.proxy_secret.read().await.key_selector } diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs index 3d9d679..3cfc834 100644 --- a/src/transport/middle_proxy/pool_reinit.rs +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -141,6 +141,38 @@ impl MePool { out } + pub(super) async fn has_non_draining_writer_per_desired_dc_group(&self) -> bool { + let desired_by_dc = self.desired_dc_endpoints().await; + let required_dcs: HashSet = desired_by_dc + .iter() + .filter_map(|(dc, endpoints)| { + if endpoints.is_empty() { + None + } else { + Some(*dc) + } + }) + .collect(); + if required_dcs.is_empty() { + return true; + } + + let ws = self.writers.read().await; + let mut covered_dcs = HashSet::::with_capacity(required_dcs.len()); + for writer in ws.iter() { + if writer.draining.load(Ordering::Relaxed) { + continue; + } + if required_dcs.contains(&writer.writer_dc) { + covered_dcs.insert(writer.writer_dc); + if covered_dcs.len() == required_dcs.len() { + return true; + } + } + } + false + } + fn hardswap_warmup_connect_delay_ms(&self) -> u64 { let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed); let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed); @@ -475,12 +507,30 @@ impl MePool { coverage_ratio = format_args!("{coverage_ratio:.3}"), min_ratio = format_args!("{min_ratio:.3}"), drain_timeout_secs, - "ME reinit cycle covered; draining stale writers" + "ME reinit cycle covered; processing stale writers" ); self.stats.increment_pool_swap_total(); + let can_drop_with_replacement = self + .has_non_draining_writer_per_desired_dc_group() + .await; + if can_drop_with_replacement { + info!( + stale_writers = stale_writer_ids.len(), + "ME reinit stale writers: replacement coverage ready, force-closing clients for fast rebind" + ); + } else { + warn!( + stale_writers = stale_writer_ids.len(), + "ME reinit stale writers: replacement coverage incomplete, keeping draining fallback" + ); + } for writer_id in stale_writer_ids { self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap) .await; + if can_drop_with_replacement { + self.stats.increment_pool_force_close_total(); + self.remove_writer_and_close_clients(writer_id).await; + } } if hardswap { self.clear_pending_hardswap_state(); diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 8ce3de3..7490a98 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -514,6 +514,7 @@ impl MePool { let was_draining = w.draining.load(Ordering::Relaxed); if was_draining { self.stats.decrement_pool_drain_active(); + self.decrement_draining_active_runtime(); } self.stats.increment_me_writer_removed_total(); w.cancel.cancel(); @@ -572,6 +573,7 @@ impl MePool { .store(drain_deadline_epoch_secs, Ordering::Relaxed); if !already_draining { self.stats.increment_pool_drain_active(); + self.increment_draining_active_runtime(); } w.contour .store(WriterContour::Draining.as_u8(), Ordering::Relaxed); diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index cc3028b..cbe1d9a 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -436,6 +436,19 @@ impl ConnRegistry { .map(|s| s.is_empty()) .unwrap_or(true) } + + pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet { + let inner = self.inner.read().await; + let mut out = HashSet::::with_capacity(writer_ids.len()); + for writer_id in writer_ids { + if let Some(conns) = inner.conns_for_writer.get(writer_id) + && !conns.is_empty() + { + out.insert(*writer_id); + } + } + out + } } #[cfg(test)] @@ -634,4 +647,35 @@ mod tests { ); assert!(registry.get_writer(conn_id).await.is_none()); } + + #[tokio::test] + async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() { + let registry = ConnRegistry::new(); + let (conn_id, _rx) = registry.register().await; + let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); + let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); + registry.register_writer(10, writer_tx_a).await; + registry.register_writer(20, writer_tx_b).await; + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + assert!( + registry + .bind_writer( + conn_id, + 10, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + + let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await; + assert!(non_empty.contains(&10)); + assert!(!non_empty.contains(&20)); + assert!(!non_empty.contains(&30)); + } }