From 0e2d42624f1bb12145b8f12363ab922bb63560c7 Mon Sep 17 00:00:00 2001
From: Alexey <247128645+axkurcom@users.noreply.github.com>
Date: Tue, 24 Feb 2026 00:04:12 +0300
Subject: [PATCH] ME Pool Hardswap

Add an opt-in C-like hard swap for the middle-end (ME) writer pool.

With hardswap = true, a map update prewarms a fresh writer generation
for every desired DC and the swap happens only once fresh coverage is
complete; stale generations are then drained. In the default soft mode,
stale writers keep draining as before, but may still take new bindings
as a fallback for up to me_pool_drain_ttl_secs, and draining is gated
on a minimum desired-DC coverage ratio (me_pool_min_fresh_ratio).

me_reinit_drain_timeout_secs now acts as the force-close timeout for
draining writers; its default drops from 300 to 120, and values below
the drain TTL are bumped to the TTL at load time. The default
update_every grows from 2h to 12h. All new knobs are hot-reloadable
and take effect on the next ME map update. New stats:
telemt_pool_swap_total, telemt_pool_drain_active,
telemt_pool_force_close_total, telemt_pool_stale_pick_total.
---
 src/cli.rs                                   |   5 +-
 src/config/defaults.rs                       |  17 +-
 src/config/hot_reload.rs                     |  30 +++
 src/config/load.rs                           |  59 +++++
 src/config/types.rs                          |  24 ++
 src/main.rs                                  |  47 ++--
 src/metrics.rs                               |  24 ++
 src/proxy/middle_relay.rs                    |   3 +
 src/stats/mod.rs                             |  45 ++++
 src/transport/middle_proxy/config_updater.rs |  20 +-
 src/transport/middle_proxy/health.rs         |   1 +
 src/transport/middle_proxy/pool.rs           | 253 +++++++++++++++++--
 src/transport/middle_proxy/send.rs           |  27 +-
 13 files changed, 491 insertions(+), 64 deletions(-)

diff --git a/src/cli.rs b/src/cli.rs
index 3525a22..a1182a7 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -196,7 +196,10 @@ use_middle_proxy = false
 log_level = "normal"
 desync_all_full = false
 update_every = 43200
-me_reinit_drain_timeout_secs = 300
+hardswap = false
+me_pool_drain_ttl_secs = 90
+me_pool_min_fresh_ratio = 0.8
+me_reinit_drain_timeout_secs = 120
 
 [network]
 ipv4 = true
diff --git a/src/config/defaults.rs b/src/config/defaults.rs
index 01cdcb0..775692e 100644
--- a/src/config/defaults.rs
+++ b/src/config/defaults.rs
@@ -1,4 +1,3 @@
-use std::net::IpAddr;
 use std::collections::HashMap;
 use ipnetwork::IpNetwork;
 use serde::Deserialize;
@@ -172,11 +171,23 @@ pub(crate) fn default_proxy_config_reload_secs() -> u64 {
 }
 
 pub(crate) fn default_update_every_secs() -> u64 {
-    2 * 60 * 60
+    12 * 60 * 60
 }
 
 pub(crate) fn default_me_reinit_drain_timeout_secs() -> u64 {
-    300
+    120
+}
+
+pub(crate) fn default_me_pool_drain_ttl_secs() -> u64 {
+    90
+}
+
+pub(crate) fn default_me_pool_min_fresh_ratio() -> f32 {
+    0.8
+}
+
+pub(crate) fn default_hardswap() -> bool {
+    false
 }
 
 pub(crate) fn default_ntp_check() -> bool {
diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs
index 5c7263f..7f121f6 100644
--- a/src/config/hot_reload.rs
+++ b/src/config/hot_reload.rs
@@ -12,6 +12,9 @@
 //! | `general` | `me_keepalive_*` | Passed on next connection |
 //! | `general` | `desync_all_full` | Applied immediately |
 //! | `general` | `update_every` | Applied to ME updater immediately |
+//! | `general` | `hardswap` | Applied on next ME map update |
+//! | `general` | `me_pool_drain_ttl_secs` | Applied on next ME map update |
+//! | `general` | `me_pool_min_fresh_ratio` | Applied on next ME map update |
 //! | `general` | `me_reinit_drain_timeout_secs`| Applied on next ME map update |
 //! | `access` | All user/quota fields | Effective immediately |
 //!
@@ -39,6 +42,9 @@ pub struct HotFields { pub middle_proxy_pool_size: usize, pub desync_all_full: bool, pub update_every_secs: u64, + pub hardswap: bool, + pub me_pool_drain_ttl_secs: u64, + pub me_pool_min_fresh_ratio: f32, pub me_reinit_drain_timeout_secs: u64, pub me_keepalive_enabled: bool, pub me_keepalive_interval_secs: u64, @@ -55,6 +61,9 @@ impl HotFields { middle_proxy_pool_size: cfg.general.middle_proxy_pool_size, desync_all_full: cfg.general.desync_all_full, update_every_secs: cfg.general.effective_update_every_secs(), + hardswap: cfg.general.hardswap, + me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs, + me_pool_min_fresh_ratio: cfg.general.me_pool_min_fresh_ratio, me_reinit_drain_timeout_secs: cfg.general.me_reinit_drain_timeout_secs, me_keepalive_enabled: cfg.general.me_keepalive_enabled, me_keepalive_interval_secs: cfg.general.me_keepalive_interval_secs, @@ -198,6 +207,27 @@ fn log_changes( ); } + if old_hot.hardswap != new_hot.hardswap { + info!( + "config reload: hardswap: {} → {}", + old_hot.hardswap, new_hot.hardswap, + ); + } + + if old_hot.me_pool_drain_ttl_secs != new_hot.me_pool_drain_ttl_secs { + info!( + "config reload: me_pool_drain_ttl_secs: {}s → {}s", + old_hot.me_pool_drain_ttl_secs, new_hot.me_pool_drain_ttl_secs, + ); + } + + if (old_hot.me_pool_min_fresh_ratio - new_hot.me_pool_min_fresh_ratio).abs() > f32::EPSILON { + info!( + "config reload: me_pool_min_fresh_ratio: {:.3} → {:.3}", + old_hot.me_pool_min_fresh_ratio, new_hot.me_pool_min_fresh_ratio, + ); + } + if old_hot.me_reinit_drain_timeout_secs != new_hot.me_reinit_drain_timeout_secs { info!( "config reload: me_reinit_drain_timeout_secs: {}s → {}s", diff --git a/src/config/load.rs b/src/config/load.rs index fa61539..4d59e60 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -145,6 +145,24 @@ impl ProxyConfig { } } + if !(0.0..=1.0).contains(&config.general.me_pool_min_fresh_ratio) { + return Err(ProxyError::Config( + "general.me_pool_min_fresh_ratio must be within [0.0, 1.0]".to_string(), + )); + } + + if config.general.effective_me_pool_force_close_secs() > 0 + && config.general.effective_me_pool_force_close_secs() + < config.general.me_pool_drain_ttl_secs + { + warn!( + me_pool_drain_ttl_secs = config.general.me_pool_drain_ttl_secs, + me_reinit_drain_timeout_secs = config.general.effective_me_pool_force_close_secs(), + "force-close timeout is lower than drain TTL; bumping force-close timeout to TTL" + ); + config.general.me_reinit_drain_timeout_secs = config.general.me_pool_drain_ttl_secs; + } + // Validate secrets. 
for (user, secret) in &config.access.users { if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { @@ -439,4 +457,45 @@ mod tests { assert!(err.contains("general.update_every must be > 0")); let _ = std::fs::remove_file(path); } + + #[test] + fn me_pool_min_fresh_ratio_out_of_range_is_rejected() { + let toml = r#" + [general] + me_pool_min_fresh_ratio = 1.5 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_pool_min_ratio_invalid_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.me_pool_min_fresh_ratio must be within [0.0, 1.0]")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn force_close_bumped_when_below_drain_ttl() { + let toml = r#" + [general] + me_pool_drain_ttl_secs = 90 + me_reinit_drain_timeout_secs = 30 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_force_close_bump_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!(cfg.general.me_reinit_drain_timeout_secs, 90); + let _ = std::fs::remove_file(path); + } } diff --git a/src/config/types.rs b/src/config/types.rs index eb16885..b33b5fd 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -206,6 +206,11 @@ pub struct GeneralConfig { #[serde(default = "default_desync_all_full")] pub desync_all_full: bool, + /// Enable C-like hard-swap for ME pool generations. + /// When true, Telemt prewarms a new generation and switches once full coverage is reached. + #[serde(default = "default_hardswap")] + pub hardswap: bool, + /// Enable staggered warmup of extra ME writers. #[serde(default = "default_true")] pub me_warmup_stagger_enabled: bool, @@ -262,6 +267,16 @@ pub struct GeneralConfig { #[serde(default)] pub update_every: Option, + /// Drain-TTL in seconds for stale ME writers after endpoint map changes. + /// During TTL, stale writers may be used only as fallback for new bindings. + #[serde(default = "default_me_pool_drain_ttl_secs")] + pub me_pool_drain_ttl_secs: u64, + + /// Minimum desired-DC coverage ratio required before draining stale writers. + /// Range: 0.0..=1.0. + #[serde(default = "default_me_pool_min_fresh_ratio")] + pub me_pool_min_fresh_ratio: f32, + /// Drain timeout in seconds for stale ME writers after endpoint map changes. /// Set to 0 to keep stale writers draining indefinitely (no force-close). 
#[serde(default = "default_me_reinit_drain_timeout_secs")] @@ -328,8 +343,11 @@ impl Default for GeneralConfig { crypto_pending_buffer: default_crypto_pending_buffer(), max_client_frame: default_max_client_frame(), desync_all_full: default_desync_all_full(), + hardswap: default_hardswap(), fast_mode_min_tls_record: default_fast_mode_min_tls_record(), update_every: Some(default_update_every_secs()), + me_pool_drain_ttl_secs: default_me_pool_drain_ttl_secs(), + me_pool_min_fresh_ratio: default_me_pool_min_fresh_ratio(), me_reinit_drain_timeout_secs: default_me_reinit_drain_timeout_secs(), proxy_secret_auto_reload_secs: default_proxy_secret_reload_secs(), proxy_config_auto_reload_secs: default_proxy_config_reload_secs(), @@ -348,6 +366,12 @@ impl GeneralConfig { self.update_every .unwrap_or_else(|| self.proxy_secret_auto_reload_secs.min(self.proxy_config_auto_reload_secs)) } + + /// Resolve force-close timeout for stale writers. + /// `me_reinit_drain_timeout_secs` remains backward-compatible alias. + pub fn effective_me_pool_force_close_secs(&self) -> u64 { + self.me_reinit_drain_timeout_secs + } } /// `[general.links]` — proxy link generation settings. diff --git a/src/main.rs b/src/main.rs index 3a6ad1a..0601215 100644 --- a/src/main.rs +++ b/src/main.rs @@ -73,36 +73,27 @@ fn parse_cli() -> (String, bool, Option) { log_level = Some(s.trim_start_matches("--log-level=").to_string()); } "--help" | "-h" => { - eprintln!("telemt - Telegram MTProto Proxy v{}", env!("CARGO_PKG_VERSION")); + eprintln!("Usage: telemt [config.toml] [OPTIONS]"); eprintln!(); - eprintln!("USAGE:"); - eprintln!(" telemt [CONFIG] [OPTIONS]"); - eprintln!(" telemt --init [INIT_OPTIONS]"); + eprintln!("Options:"); + eprintln!(" --silent, -s Suppress info logs"); + eprintln!(" --log-level debug|verbose|normal|silent"); + eprintln!(" --help, -h Show this help"); eprintln!(); - eprintln!("ARGS:"); - eprintln!(" Path to config file (default: config.toml)"); - eprintln!(); - eprintln!("OPTIONS:"); - eprintln!(" -s, --silent Suppress info logs (equivalent to --log-level silent)"); - eprintln!(" --log-level Set log level [possible values: debug, verbose, normal, silent]"); - eprintln!(" -h, --help Show this help message"); - eprintln!(" -V, --version Print version number"); - eprintln!(); - eprintln!("INIT OPTIONS (fire-and-forget setup):"); - eprintln!(" --init Generate config, install systemd service, and start"); + eprintln!("Setup (fire-and-forget):"); + eprintln!( + " --init Generate config, install systemd service, start" + ); eprintln!(" --port Listen port (default: 443)"); - eprintln!(" --domain TLS domain for masking (default: www.google.com)"); - eprintln!(" --secret 32-char hex secret (auto-generated if omitted)"); - eprintln!(" --user Username for proxy access (default: user)"); + eprintln!( + " --domain TLS domain for masking (default: www.google.com)" + ); + eprintln!( + " --secret 32-char hex secret (auto-generated if omitted)" + ); + eprintln!(" --user Username (default: user)"); eprintln!(" --config-dir Config directory (default: /etc/telemt)"); - eprintln!(" --no-start Create config and service but don't start"); - eprintln!(); - eprintln!("EXAMPLES:"); - eprintln!(" telemt # Run with default config"); - eprintln!(" telemt /etc/telemt/config.toml # Run with specific config"); - eprintln!(" telemt --log-level debug # Run with debug logging"); - eprintln!(" telemt --init # Quick setup with defaults"); - eprintln!(" telemt --init --port 8443 --user admin # Custom setup"); + eprintln!(" --no-start Don't start the 
service after install"); std::process::exit(0); } "--version" | "-V" => { @@ -371,6 +362,10 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai config.general.me_reconnect_backoff_base_ms, config.general.me_reconnect_backoff_cap_ms, config.general.me_reconnect_fast_retry_count, + config.general.hardswap, + config.general.me_pool_drain_ttl_secs, + config.general.effective_me_pool_force_close_secs(), + config.general.me_pool_min_fresh_ratio, ); let pool_size = config.general.middle_proxy_pool_size.max(1); diff --git a/src/metrics.rs b/src/metrics.rs index 326d333..d11c302 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -175,6 +175,30 @@ fn render_metrics(stats: &Stats) -> String { stats.get_desync_frames_bucket_gt_10() ); + let _ = writeln!(out, "# HELP telemt_pool_swap_total Successful ME pool swaps"); + let _ = writeln!(out, "# TYPE telemt_pool_swap_total counter"); + let _ = writeln!(out, "telemt_pool_swap_total {}", stats.get_pool_swap_total()); + + let _ = writeln!(out, "# HELP telemt_pool_drain_active Active draining ME writers"); + let _ = writeln!(out, "# TYPE telemt_pool_drain_active gauge"); + let _ = writeln!(out, "telemt_pool_drain_active {}", stats.get_pool_drain_active()); + + let _ = writeln!(out, "# HELP telemt_pool_force_close_total Forced close events for draining writers"); + let _ = writeln!(out, "# TYPE telemt_pool_force_close_total counter"); + let _ = writeln!( + out, + "telemt_pool_force_close_total {}", + stats.get_pool_force_close_total() + ); + + let _ = writeln!(out, "# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds"); + let _ = writeln!(out, "# TYPE telemt_pool_stale_pick_total counter"); + let _ = writeln!( + out, + "telemt_pool_stale_pick_total {}", + stats.get_pool_stale_pick_total() + ); + let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index d55e5a2..a6a11e1 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -184,6 +184,7 @@ where let user = success.user.clone(); let peer = success.peer; let proto_tag = success.proto_tag; + let pool_generation = me_pool.current_generation(); info!( user = %user, @@ -191,6 +192,7 @@ where dc = success.dc_idx, proto = ?proto_tag, mode = "middle_proxy", + pool_generation, "Routing via Middle-End" ); @@ -220,6 +222,7 @@ where peer_hash = format_args!("0x{:016x}", forensics.peer_hash), desync_all_full = forensics.desync_all_full, proto_flags = format_args!("0x{:08x}", proto_flags), + pool_generation, "ME relay started" ); diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 4c16d25..1994b36 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -38,6 +38,10 @@ pub struct Stats { desync_frames_bucket_1_2: AtomicU64, desync_frames_bucket_3_10: AtomicU64, desync_frames_bucket_gt_10: AtomicU64, + pool_swap_total: AtomicU64, + pool_drain_active: AtomicU64, + pool_force_close_total: AtomicU64, + pool_stale_pick_total: AtomicU64, user_stats: DashMap, start_time: parking_lot::RwLock>, } @@ -108,6 +112,35 @@ impl Stats { } } } + pub fn increment_pool_swap_total(&self) { + self.pool_swap_total.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_pool_drain_active(&self) { + self.pool_drain_active.fetch_add(1, Ordering::Relaxed); + } + pub fn decrement_pool_drain_active(&self) { + 
let mut current = self.pool_drain_active.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match self.pool_drain_active.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + pub fn increment_pool_force_close_total(&self) { + self.pool_force_close_total.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_pool_stale_pick_total(&self) { + self.pool_stale_pick_total.fetch_add(1, Ordering::Relaxed); + } pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } @@ -149,6 +182,18 @@ impl Stats { pub fn get_desync_frames_bucket_gt_10(&self) -> u64 { self.desync_frames_bucket_gt_10.load(Ordering::Relaxed) } + pub fn get_pool_swap_total(&self) -> u64 { + self.pool_swap_total.load(Ordering::Relaxed) + } + pub fn get_pool_drain_active(&self) -> u64 { + self.pool_drain_active.load(Ordering::Relaxed) + } + pub fn get_pool_force_close_total(&self) -> u64 { + self.pool_force_close_total.load(Ordering::Relaxed) + } + pub fn get_pool_stale_pick_total(&self) -> u64 { + self.pool_stale_pick_total.load(Ordering::Relaxed) + } pub fn increment_user_connects(&self, user: &str) { self.user_stats.entry(user.to_string()).or_default() diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 479a880..96d5f91 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -131,6 +131,13 @@ pub async fn fetch_proxy_config(url: &str) -> Result { } async fn run_update_cycle(pool: &Arc, rng: &Arc, cfg: &ProxyConfig) { + pool.update_runtime_reinit_policy( + cfg.general.hardswap, + cfg.general.me_pool_drain_ttl_secs, + cfg.general.effective_me_pool_force_close_secs(), + cfg.general.me_pool_min_fresh_ratio, + ); + let mut maps_changed = false; // Update proxy config v4 @@ -162,12 +169,7 @@ async fn run_update_cycle(pool: &Arc, rng: &Arc, cfg: &Pro } if maps_changed { - let drain_timeout = if cfg.general.me_reinit_drain_timeout_secs == 0 { - None - } else { - Some(Duration::from_secs(cfg.general.me_reinit_drain_timeout_secs)) - }; - pool.zero_downtime_reinit_after_map_change(rng.as_ref(), drain_timeout) + pool.zero_downtime_reinit_after_map_change(rng.as_ref()) .await; } @@ -224,6 +226,12 @@ pub async fn me_config_updater( break; } let cfg = config_rx.borrow().clone(); + pool.update_runtime_reinit_policy( + cfg.general.hardswap, + cfg.general.me_pool_drain_ttl_secs, + cfg.general.effective_me_pool_force_close_secs(), + cfg.general.me_pool_min_fresh_ratio, + ); let new_secs = cfg.general.effective_update_every_secs().max(1); if new_secs == update_every_secs { continue; diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index d4d4a70..18814cd 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -68,6 +68,7 @@ async fn check_family( .read() .await .iter() + .filter(|w| !w.draining.load(std::sync::atomic::Ordering::Relaxed)) .map(|w| w.addr) .collect(); diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index bd7c9cc..8e159db 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -1,14 +1,14 @@ use std::collections::{HashMap, HashSet}; use 
std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU32, AtomicU64, AtomicUsize, Ordering}; use bytes::BytesMut; use rand::Rng; use rand::seq::SliceRandom; use tokio::sync::{Mutex, RwLock, mpsc, Notify}; use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; -use std::time::{Duration, Instant}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; @@ -27,10 +27,13 @@ const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; pub struct MeWriter { pub id: u64, pub addr: SocketAddr, + pub generation: u64, pub tx: mpsc::Sender, pub cancel: CancellationToken, pub degraded: Arc, pub draining: Arc, + pub draining_started_at_epoch_secs: Arc, + pub allow_drain_fallback: Arc, } pub struct MePool { @@ -73,6 +76,11 @@ pub struct MePool { pub(super) writer_available: Arc, pub(super) conn_count: AtomicUsize, pub(super) stats: Arc, + pub(super) generation: AtomicU64, + pub(super) hardswap: AtomicBool, + pub(super) me_pool_drain_ttl_secs: AtomicU64, + pub(super) me_pool_force_close_secs: AtomicU64, + pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, pool_size: usize, } @@ -83,6 +91,22 @@ pub struct NatReflectionCache { } impl MePool { + fn ratio_to_permille(ratio: f32) -> u32 { + let clamped = ratio.clamp(0.0, 1.0); + (clamped * 1000.0).round() as u32 + } + + fn permille_to_ratio(permille: u32) -> f32 { + (permille.min(1000) as f32) / 1000.0 + } + + fn now_epoch_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + } + pub fn new( proxy_tag: Option>, proxy_secret: Vec, @@ -110,6 +134,10 @@ impl MePool { me_reconnect_backoff_base_ms: u64, me_reconnect_backoff_cap_ms: u64, me_reconnect_fast_retry_count: u32, + hardswap: bool, + me_pool_drain_ttl_secs: u64, + me_pool_force_close_secs: u64, + me_pool_min_fresh_ratio: f32, ) -> Arc { Arc::new(Self { registry: Arc::new(ConnRegistry::new()), @@ -152,6 +180,11 @@ impl MePool { nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), writer_available: Arc::new(Notify::new()), conn_count: AtomicUsize::new(0), + generation: AtomicU64::new(1), + hardswap: AtomicBool::new(hardswap), + me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), + me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs), + me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille(me_pool_min_fresh_ratio)), }) } @@ -159,6 +192,25 @@ impl MePool { self.proxy_tag.is_some() } + pub fn current_generation(&self) -> u64 { + self.generation.load(Ordering::Relaxed) + } + + pub fn update_runtime_reinit_policy( + &self, + hardswap: bool, + drain_ttl_secs: u64, + force_close_secs: u64, + min_fresh_ratio: f32, + ) { + self.hardswap.store(hardswap, Ordering::Relaxed); + self.me_pool_drain_ttl_secs.store(drain_ttl_secs, Ordering::Relaxed); + self.me_pool_force_close_secs + .store(force_close_secs, Ordering::Relaxed); + self.me_pool_min_fresh_ratio_permille + .store(Self::ratio_to_permille(min_fresh_ratio), Ordering::Relaxed); + } + pub fn reset_stun_state(&self) { self.nat_probe_attempts.store(0, Ordering::Relaxed); self.nat_probe_disabled.store(false, Ordering::Relaxed); @@ -177,6 +229,42 @@ impl MePool { self.writers.clone() } + fn force_close_timeout(&self) -> Option { + let secs = self.me_pool_force_close_secs.load(Ordering::Relaxed); + if secs == 0 { + None + } else { 
+ Some(Duration::from_secs(secs)) + } + } + + fn coverage_ratio( + desired_by_dc: &HashMap>, + active_writer_addrs: &HashSet, + ) -> (f32, Vec) { + if desired_by_dc.is_empty() { + return (1.0, Vec::new()); + } + + let mut missing_dc = Vec::::new(); + let mut covered = 0usize; + for (dc, endpoints) in desired_by_dc { + if endpoints.is_empty() { + continue; + } + if endpoints.iter().any(|addr| active_writer_addrs.contains(addr)) { + covered += 1; + } else { + missing_dc.push(*dc); + } + } + + missing_dc.sort_unstable(); + let total = desired_by_dc.len().max(1); + let ratio = (covered as f32) / (total as f32); + (ratio, missing_dc) + } + pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { let writers = self.writers.read().await; let current: HashSet = writers @@ -235,39 +323,104 @@ impl MePool { out } + async fn warmup_generation_for_all_dcs( + self: &Arc, + rng: &SecureRandom, + generation: u64, + desired_by_dc: &HashMap>, + ) { + for endpoints in desired_by_dc.values() { + if endpoints.is_empty() { + continue; + } + + let has_fresh = { + let ws = self.writers.read().await; + ws.iter().any(|w| { + !w.draining.load(Ordering::Relaxed) + && w.generation == generation + && endpoints.contains(&w.addr) + }) + }; + + if has_fresh { + continue; + } + + let mut shuffled: Vec = endpoints.iter().copied().collect(); + shuffled.shuffle(&mut rand::rng()); + for addr in shuffled { + if self.connect_one(addr, rng).await.is_ok() { + break; + } + } + } + } + pub async fn zero_downtime_reinit_after_map_change( self: &Arc, rng: &SecureRandom, - drain_timeout: Option, ) { - // Stage 1: prewarm writers for new endpoint maps before draining old ones. - self.reconcile_connections(rng).await; - let desired_by_dc = self.desired_dc_endpoints().await; if desired_by_dc.is_empty() { warn!("ME endpoint map is empty after update; skipping stale writer drain"); return; } + let previous_generation = self.current_generation(); + let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; + let hardswap = self.hardswap.load(Ordering::Relaxed); + + if hardswap { + self.warmup_generation_for_all_dcs(rng, generation, &desired_by_dc) + .await; + } else { + self.reconcile_connections(rng).await; + } + let writers = self.writers.read().await; let active_writer_addrs: HashSet = writers .iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) .map(|w| w.addr) .collect(); - - let mut missing_dc = Vec::::new(); - for (dc, endpoints) in &desired_by_dc { - if endpoints.is_empty() { - continue; - } - if !endpoints.iter().any(|addr| active_writer_addrs.contains(addr)) { - missing_dc.push(*dc); - } + let min_ratio = Self::permille_to_ratio( + self.me_pool_min_fresh_ratio_permille + .load(Ordering::Relaxed), + ); + let (coverage_ratio, missing_dc) = Self::coverage_ratio(&desired_by_dc, &active_writer_addrs); + if !hardswap && coverage_ratio < min_ratio { + warn!( + previous_generation, + generation, + coverage_ratio = format_args!("{coverage_ratio:.3}"), + min_ratio = format_args!("{min_ratio:.3}"), + missing_dc = ?missing_dc, + "ME reinit coverage below threshold; keeping stale writers" + ); + return; } - if !missing_dc.is_empty() { - missing_dc.sort_unstable(); + if hardswap { + let fresh_writer_addrs: HashSet = writers + .iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| w.generation == generation) + .map(|w| w.addr) + .collect(); + let (fresh_ratio, fresh_missing_dc) = + Self::coverage_ratio(&desired_by_dc, &fresh_writer_addrs); + if !fresh_missing_dc.is_empty() { + warn!( + 
previous_generation, + generation, + fresh_ratio = format_args!("{fresh_ratio:.3}"), + missing_dc = ?fresh_missing_dc, + "ME hardswap pending: fresh generation coverage incomplete" + ); + return; + } + } else if !missing_dc.is_empty() { warn!( missing_dc = ?missing_dc, // Keep stale writers alive when fresh coverage is incomplete. @@ -284,7 +437,13 @@ impl MePool { let stale_writer_ids: Vec = writers .iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) - .filter(|w| !desired_addrs.contains(&w.addr)) + .filter(|w| { + if hardswap { + w.generation < generation + } else { + !desired_addrs.contains(&w.addr) + } + }) .map(|w| w.id) .collect(); drop(writers); @@ -294,14 +453,21 @@ impl MePool { return; } + let drain_timeout = self.force_close_timeout(); let drain_timeout_secs = drain_timeout.map(|d| d.as_secs()).unwrap_or(0); info!( stale_writers = stale_writer_ids.len(), + previous_generation, + generation, + hardswap, + coverage_ratio = format_args!("{coverage_ratio:.3}"), + min_ratio = format_args!("{min_ratio:.3}"), drain_timeout_secs, "ME map update covered; draining stale writers" ); + self.stats.increment_pool_swap_total(); for writer_id in stale_writer_ids { - self.mark_writer_draining_with_timeout(writer_id, drain_timeout) + self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap) .await; } } @@ -507,9 +673,12 @@ impl MePool { let hs = self.handshake_only(stream, addr, rng).await?; let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); + let generation = self.current_generation(); let cancel = CancellationToken::new(); let degraded = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false)); + let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0)); + let allow_drain_fallback = Arc::new(AtomicBool::new(false)); let (tx, mut rx) = mpsc::channel::(4096); let mut rpc_writer = RpcWriter { writer: hs.wr, @@ -540,10 +709,13 @@ impl MePool { let writer = MeWriter { id: writer_id, addr, + generation, tx: tx.clone(), cancel: cancel.clone(), degraded: degraded.clone(), draining: draining.clone(), + draining_started_at_epoch_secs: draining_started_at_epoch_secs.clone(), + allow_drain_fallback: allow_drain_fallback.clone(), }; self.writers.write().await.push(writer.clone()); self.conn_count.fetch_add(1, Ordering::Relaxed); @@ -715,6 +887,9 @@ impl MePool { let mut ws = self.writers.write().await; if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { let w = ws.remove(pos); + if w.draining.load(Ordering::Relaxed) { + self.stats.decrement_pool_drain_active(); + } w.cancel.cancel(); close_tx = Some(w.tx.clone()); self.conn_count.fetch_sub(1, Ordering::Relaxed); @@ -731,11 +906,20 @@ impl MePool { self: &Arc, writer_id: u64, timeout: Option, + allow_drain_fallback: bool, ) { let timeout = timeout.filter(|d| !d.is_zero()); let found = { let mut ws = self.writers.write().await; if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) { + let already_draining = w.draining.swap(true, Ordering::Relaxed); + w.allow_drain_fallback + .store(allow_drain_fallback, Ordering::Relaxed); + w.draining_started_at_epoch_secs + .store(Self::now_epoch_secs(), Ordering::Relaxed); + if !already_draining { + self.stats.increment_pool_drain_active(); + } w.draining.store(true, Ordering::Relaxed); true } else { @@ -748,7 +932,12 @@ impl MePool { } let timeout_secs = timeout.map(|d| d.as_secs()).unwrap_or(0); - debug!(writer_id, timeout_secs, "ME writer marked draining"); + debug!( + writer_id, + timeout_secs, + allow_drain_fallback, + "ME writer 
marked draining" + ); let pool = Arc::downgrade(self); tokio::spawn(async move { @@ -758,6 +947,7 @@ impl MePool { if let Some(deadline_at) = deadline { if Instant::now() >= deadline_at { warn!(writer_id, "Drain timeout, force-closing"); + p.stats.increment_pool_force_close_total(); let _ = p.remove_writer_and_close_clients(writer_id).await; break; } @@ -775,10 +965,31 @@ impl MePool { } pub(crate) async fn mark_writer_draining(self: &Arc, writer_id: u64) { - self.mark_writer_draining_with_timeout(writer_id, Some(Duration::from_secs(300))) + self.mark_writer_draining_with_timeout(writer_id, Some(Duration::from_secs(300)), false) .await; } + pub(super) fn writer_accepts_new_binding(&self, writer: &MeWriter) -> bool { + if !writer.draining.load(Ordering::Relaxed) { + return true; + } + if !writer.allow_drain_fallback.load(Ordering::Relaxed) { + return false; + } + + let ttl_secs = self.me_pool_drain_ttl_secs.load(Ordering::Relaxed); + if ttl_secs == 0 { + return true; + } + + let started = writer.draining_started_at_epoch_secs.load(Ordering::Relaxed); + if started == 0 { + return false; + } + + Self::now_epoch_secs().saturating_sub(started) <= ttl_secs + } + } fn hex_dump(data: &[u8]) -> String { diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 2ebafea..56bd17a 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -134,8 +134,8 @@ impl MePool { candidate_indices.sort_by_key(|idx| { let w = &writers_snapshot[*idx]; let degraded = w.degraded.load(Ordering::Relaxed); - let draining = w.draining.load(Ordering::Relaxed); - (draining as usize, degraded as usize) + let stale = (w.generation < self.current_generation()) as usize; + (stale, degraded as usize) }); let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len(); @@ -143,13 +143,23 @@ impl MePool { for offset in 0..candidate_indices.len() { let idx = candidate_indices[(start + offset) % candidate_indices.len()]; let w = &writers_snapshot[idx]; - if w.draining.load(Ordering::Relaxed) { + if !self.writer_accepts_new_binding(w) { continue; } if w.tx.send(WriterCommand::Data(payload.clone())).await.is_ok() { self.registry .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone()) .await; + if w.generation < self.current_generation() { + self.stats.increment_pool_stale_pick_total(); + debug!( + conn_id, + writer_id = w.id, + writer_generation = w.generation, + current_generation = self.current_generation(), + "Selected stale ME writer for fallback bind" + ); + } return Ok(()); } else { warn!(writer_id = w.id, "ME writer channel closed"); @@ -159,7 +169,7 @@ impl MePool { } let w = writers_snapshot[candidate_indices[start]].clone(); - if w.draining.load(Ordering::Relaxed) { + if !self.writer_accepts_new_binding(&w) { continue; } match w.tx.send(WriterCommand::Data(payload.clone())).await { @@ -167,6 +177,9 @@ impl MePool { self.registry .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone()) .await; + if w.generation < self.current_generation() { + self.stats.increment_pool_stale_pick_total(); + } return Ok(()); } Err(_) => { @@ -245,13 +258,13 @@ impl MePool { if preferred.is_empty() { return (0..writers.len()) - .filter(|i| !writers[*i].draining.load(Ordering::Relaxed)) + .filter(|i| self.writer_accepts_new_binding(&writers[*i])) .collect(); } let mut out = Vec::new(); for (idx, w) in writers.iter().enumerate() { - if w.draining.load(Ordering::Relaxed) { + if !self.writer_accepts_new_binding(w) { continue; } if preferred.iter().any(|p| *p == 
w.addr) { @@ -260,7 +273,7 @@ impl MePool { } if out.is_empty() { return (0..writers.len()) - .filter(|i| !writers[*i].draining.load(Ordering::Relaxed)) + .filter(|i| self.writer_accepts_new_binding(&writers[*i])) .collect(); } out
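
For review convenience, the new `[general]` knobs collected in one
place. The values shown are the defaults this patch installs and match
the generated template in src/cli.rs:

    [general]
    # Opt-in C-like hard swap: prewarm a fresh writer generation and
    # switch only once every desired DC is covered.
    hardswap = false
    # Soft mode: how long a draining writer may still accept new
    # bindings as a fallback (0 = no TTL cutoff).
    me_pool_drain_ttl_secs = 90
    # Soft mode: minimum desired-DC coverage required before stale
    # writers are drained; must be within [0.0, 1.0].
    me_pool_min_fresh_ratio = 0.8
    # Force-close timeout for draining writers (0 = drain forever);
    # values below the drain TTL are bumped to the TTL at load time.
    me_reinit_drain_timeout_secs = 120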
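
A worked timeline with those defaults, in soft mode (hardswap = false):
a map change at t=0 marks writers outside the desired endpoint set as
draining with fallback allowed; until t=90s they may still take a new
binding when no fresh writer covers the DC (each such pick increments
telemt_pool_stale_pick_total); at t=120s any writer still draining is
force-closed (telemt_pool_force_close_total). Load-time validation
guarantees the force-close timeout is never below the drain TTL. With
hardswap = true the fallback path is disabled (allow_drain_fallback is
stored as false) and the swap itself is deferred until the fresh
generation covers every desired DC.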
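
A metrics scrape will then carry the four new series in the exposition
format emitted by render_metrics(); the sample values below are
illustrative only:

    # HELP telemt_pool_swap_total Successful ME pool swaps
    # TYPE telemt_pool_swap_total counter
    telemt_pool_swap_total 2
    # HELP telemt_pool_drain_active Active draining ME writers
    # TYPE telemt_pool_drain_active gauge
    telemt_pool_drain_active 1
    # HELP telemt_pool_force_close_total Forced close events for draining writers
    # TYPE telemt_pool_force_close_total counter
    telemt_pool_force_close_total 0
    # HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds
    # TYPE telemt_pool_stale_pick_total counter
    telemt_pool_stale_pick_total 5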