diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 6f07a4b..d781f67 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -37,7 +37,6 @@ use crate::config::{ }; use super::load::{LoadedConfig, ProxyConfig}; -const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2; const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); // ── Hot fields ──────────────────────────────────────────────────────────────── @@ -329,41 +328,19 @@ impl WatchManifest { #[derive(Debug, Default)] struct ReloadState { applied_snapshot_hash: Option, - candidate_snapshot_hash: Option, - candidate_hits: u8, } impl ReloadState { fn new(applied_snapshot_hash: Option) -> Self { - Self { - applied_snapshot_hash, - candidate_snapshot_hash: None, - candidate_hits: 0, - } + Self { applied_snapshot_hash } } fn is_applied(&self, hash: u64) -> bool { self.applied_snapshot_hash == Some(hash) } - fn observe_candidate(&mut self, hash: u64) -> u8 { - if self.candidate_snapshot_hash == Some(hash) { - self.candidate_hits = self.candidate_hits.saturating_add(1); - } else { - self.candidate_snapshot_hash = Some(hash); - self.candidate_hits = 1; - } - self.candidate_hits - } - - fn reset_candidate(&mut self) { - self.candidate_snapshot_hash = None; - self.candidate_hits = 0; - } - fn mark_applied(&mut self, hash: u64) { self.applied_snapshot_hash = Some(hash); - self.reset_candidate(); } } @@ -1138,7 +1115,6 @@ fn reload_config( let loaded = match ProxyConfig::load_with_metadata(config_path) { Ok(loaded) => loaded, Err(e) => { - reload_state.reset_candidate(); error!("config reload: failed to parse {:?}: {}", config_path, e); return None; } @@ -1151,7 +1127,6 @@ fn reload_config( let next_manifest = WatchManifest::from_source_files(&source_files); if let Err(e) = new_cfg.validate() { - reload_state.reset_candidate(); error!("config reload: validation failed: {}; keeping old config", e); return Some(next_manifest); } @@ -1160,17 +1135,6 @@ fn reload_config( return Some(next_manifest); } - let candidate_hits = reload_state.observe_candidate(rendered_hash); - if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS { - info!( - snapshot_hash = rendered_hash, - candidate_hits, - required_hits = HOT_RELOAD_STABLE_SNAPSHOTS, - "config reload: candidate snapshot observed but not stable yet" - ); - return Some(next_manifest); - } - let old_cfg = config_tx.borrow().clone(); let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg); let old_hot = HotFields::from_config(&old_cfg); @@ -1190,7 +1154,6 @@ fn reload_config( if old_hot.dns_overrides != applied_hot.dns_overrides && let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides) { - reload_state.reset_candidate(); error!( "config reload: invalid network.dns_overrides: {}; keeping old config", e @@ -1334,14 +1297,28 @@ pub fn spawn_config_watcher( tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; while notify_rx.try_recv().is_ok() {} - if let Some(next_manifest) = reload_config( + let mut next_manifest = reload_config( &config_path, &config_tx, &log_tx, detected_ip_v4, detected_ip_v6, &mut reload_state, - ) { + ); + if next_manifest.is_none() { + tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; + while notify_rx.try_recv().is_ok() {} + next_manifest = reload_config( + &config_path, + &config_tx, + &log_tx, + detected_ip_v4, + detected_ip_v6, + &mut reload_state, + ); + } + + if let Some(next_manifest) = next_manifest { apply_watch_manifest( inotify_watcher.as_mut(), poll_watcher.as_mut(), @@ -1466,7 +1443,7 @@ mod tests { } #[test] - fn reload_requires_stable_snapshot_before_hot_apply() { + fn reload_applies_hot_change_on_first_observed_snapshot() { let initial_tag = "11111111111111111111111111111111"; let final_tag = "22222222222222222222222222222222"; let path = temp_config_path("telemt_hot_reload_stable"); @@ -1478,20 +1455,7 @@ mod tests { let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let mut reload_state = ReloadState::new(Some(initial_hash)); - write_reload_config(&path, None, None); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!( - config_tx.borrow().general.ad_tag.as_deref(), - Some(initial_tag) - ); - write_reload_config(&path, Some(final_tag), None); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!( - config_tx.borrow().general.ad_tag.as_deref(), - Some(initial_tag) - ); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); @@ -1513,7 +1477,6 @@ mod tests { write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1)); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); let applied = config_tx.borrow().clone(); assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag)); @@ -1521,4 +1484,31 @@ mod tests { let _ = std::fs::remove_file(path); } + + #[test] + fn reload_recovers_after_parse_error_on_next_attempt() { + let initial_tag = "cccccccccccccccccccccccccccccccc"; + let final_tag = "dddddddddddddddddddddddddddddddd"; + let path = temp_config_path("telemt_hot_reload_parse_recovery"); + + write_reload_config(&path, Some(initial_tag), None); + let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); + let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); + let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); + let mut reload_state = ReloadState::new(Some(initial_hash)); + + std::fs::write(&path, "[access.users\nuser = \"broken\"\n").unwrap(); + assert!(reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).is_none()); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(initial_tag) + ); + + write_reload_config(&path, Some(final_tag), None); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); + + let _ = std::fs::remove_file(path); + } } diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index 78f3ec4..029d0ee 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -10,6 +10,16 @@ use crate::transport::middle_proxy::{ ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, }; +pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf { + let raw = PathBuf::from(config_path_cli); + let absolute = if raw.is_absolute() { + raw + } else { + startup_cwd.join(raw) + }; + absolute.canonicalize().unwrap_or(absolute) +} + pub(crate) fn parse_cli() -> (String, Option, bool, Option) { let mut config_path = "config.toml".to_string(); let mut data_path: Option = None; @@ -96,6 +106,44 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { (config_path, data_path, silent, log_level) } +#[cfg(test)] +mod tests { + use super::resolve_runtime_config_path; + + #[test] + fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_{nonce}")); + std::fs::create_dir_all(&startup_cwd).unwrap(); + let target = startup_cwd.join("config.toml"); + std::fs::write(&target, " ").unwrap(); + + let resolved = resolve_runtime_config_path("config.toml", &startup_cwd); + assert_eq!(resolved, target.canonicalize().unwrap()); + + let _ = std::fs::remove_file(&target); + let _ = std::fs::remove_dir(&startup_cwd); + } + + #[test] + fn resolve_runtime_config_path_keeps_absolute_for_missing_file() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_missing_{nonce}")); + std::fs::create_dir_all(&startup_cwd).unwrap(); + + let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd); + assert_eq!(resolved, startup_cwd.join("missing.toml")); + + let _ = std::fs::remove_dir(&startup_cwd); + } +} + pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { info!(target: "telemt::links", "--- Proxy Links ({}) ---", host); for user_name in config.general.links.show.resolve_users(&config.access.users) { diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index da00b40..d4ce2e0 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -45,7 +45,7 @@ use crate::startup::{ use crate::stream::BufferPool; use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; -use helpers::parse_cli; +use helpers::{parse_cli, resolve_runtime_config_path}; /// Runs the full telemt runtime startup pipeline and blocks until shutdown. pub async fn run() -> std::result::Result<(), Box> { @@ -58,18 +58,26 @@ pub async fn run() -> std::result::Result<(), Box> { startup_tracker .start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string())) .await; - let (config_path, data_path, cli_silent, cli_log_level) = parse_cli(); + let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli(); + let startup_cwd = match std::env::current_dir() { + Ok(cwd) => cwd, + Err(e) => { + eprintln!("[telemt] Can't read current_dir: {}", e); + std::process::exit(1); + } + }; + let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd); let mut config = match ProxyConfig::load(&config_path) { Ok(c) => c, Err(e) => { - if std::path::Path::new(&config_path).exists() { + if config_path.exists() { eprintln!("[telemt] Error: {}", e); std::process::exit(1); } else { let default = ProxyConfig::default(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); - eprintln!("[telemt] Created default config at {}", config_path); + eprintln!("[telemt] Created default config at {}", config_path.display()); default } } @@ -258,7 +266,7 @@ pub async fn run() -> std::result::Result<(), Box> { let route_runtime_api = route_runtime.clone(); let config_rx_api = api_config_rx.clone(); let admission_rx_api = admission_rx.clone(); - let config_path_api = std::path::PathBuf::from(&config_path); + let config_path_api = config_path.clone(); let startup_tracker_api = startup_tracker.clone(); let detected_ips_rx_api = detected_ips_rx.clone(); tokio::spawn(async move { diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index d9691a8..c2233c7 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -1,5 +1,5 @@ use std::net::IpAddr; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use tokio::sync::{mpsc, watch}; @@ -32,7 +32,7 @@ pub(crate) struct RuntimeWatches { #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, - config_path: &str, + config_path: &Path, probe: &NetworkProbe, prefer_ipv6: bool, decision_ipv4_dc: bool, @@ -83,7 +83,7 @@ pub(crate) async fn spawn_runtime_tasks( watch::Receiver>, watch::Receiver, ) = spawn_config_watcher( - PathBuf::from(config_path), + config_path.to_path_buf(), config.clone(), detected_ip_v4, detected_ip_v6, diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 2a65160..56f3fbf 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -160,6 +160,7 @@ pub struct MePool { pub(super) refill_inflight: Arc>>, pub(super) refill_inflight_dc: Arc>>, pub(super) conn_count: AtomicUsize, + pub(super) draining_active_runtime: AtomicU64, pub(super) stats: Arc, pub(super) generation: AtomicU64, pub(super) active_generation: AtomicU64, @@ -438,6 +439,7 @@ impl MePool { refill_inflight: Arc::new(Mutex::new(HashSet::new())), refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())), conn_count: AtomicUsize::new(0), + draining_active_runtime: AtomicU64::new(0), generation: AtomicU64::new(1), active_generation: AtomicU64::new(1), warm_generation: AtomicU64::new(0), @@ -690,6 +692,32 @@ impl MePool { } } + pub(super) fn draining_active_runtime(&self) -> u64 { + self.draining_active_runtime.load(Ordering::Relaxed) + } + + pub(super) fn increment_draining_active_runtime(&self) { + self.draining_active_runtime.fetch_add(1, Ordering::Relaxed); + } + + pub(super) fn decrement_draining_active_runtime(&self) { + let mut current = self.draining_active_runtime.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match self.draining_active_runtime.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + pub(super) async fn key_selector(&self) -> u32 { self.proxy_secret.read().await.key_selector } diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs index 3d9d679..3cfc834 100644 --- a/src/transport/middle_proxy/pool_reinit.rs +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -141,6 +141,38 @@ impl MePool { out } + pub(super) async fn has_non_draining_writer_per_desired_dc_group(&self) -> bool { + let desired_by_dc = self.desired_dc_endpoints().await; + let required_dcs: HashSet = desired_by_dc + .iter() + .filter_map(|(dc, endpoints)| { + if endpoints.is_empty() { + None + } else { + Some(*dc) + } + }) + .collect(); + if required_dcs.is_empty() { + return true; + } + + let ws = self.writers.read().await; + let mut covered_dcs = HashSet::::with_capacity(required_dcs.len()); + for writer in ws.iter() { + if writer.draining.load(Ordering::Relaxed) { + continue; + } + if required_dcs.contains(&writer.writer_dc) { + covered_dcs.insert(writer.writer_dc); + if covered_dcs.len() == required_dcs.len() { + return true; + } + } + } + false + } + fn hardswap_warmup_connect_delay_ms(&self) -> u64 { let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed); let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed); @@ -475,12 +507,30 @@ impl MePool { coverage_ratio = format_args!("{coverage_ratio:.3}"), min_ratio = format_args!("{min_ratio:.3}"), drain_timeout_secs, - "ME reinit cycle covered; draining stale writers" + "ME reinit cycle covered; processing stale writers" ); self.stats.increment_pool_swap_total(); + let can_drop_with_replacement = self + .has_non_draining_writer_per_desired_dc_group() + .await; + if can_drop_with_replacement { + info!( + stale_writers = stale_writer_ids.len(), + "ME reinit stale writers: replacement coverage ready, force-closing clients for fast rebind" + ); + } else { + warn!( + stale_writers = stale_writer_ids.len(), + "ME reinit stale writers: replacement coverage incomplete, keeping draining fallback" + ); + } for writer_id in stale_writer_ids { self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap) .await; + if can_drop_with_replacement { + self.stats.increment_pool_force_close_total(); + self.remove_writer_and_close_clients(writer_id).await; + } } if hardswap { self.clear_pending_hardswap_state(); diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 8ce3de3..7490a98 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -514,6 +514,7 @@ impl MePool { let was_draining = w.draining.load(Ordering::Relaxed); if was_draining { self.stats.decrement_pool_drain_active(); + self.decrement_draining_active_runtime(); } self.stats.increment_me_writer_removed_total(); w.cancel.cancel(); @@ -572,6 +573,7 @@ impl MePool { .store(drain_deadline_epoch_secs, Ordering::Relaxed); if !already_draining { self.stats.increment_pool_drain_active(); + self.increment_draining_active_runtime(); } w.contour .store(WriterContour::Draining.as_u8(), Ordering::Relaxed); diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index cc3028b..cbe1d9a 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -436,6 +436,19 @@ impl ConnRegistry { .map(|s| s.is_empty()) .unwrap_or(true) } + + pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet { + let inner = self.inner.read().await; + let mut out = HashSet::::with_capacity(writer_ids.len()); + for writer_id in writer_ids { + if let Some(conns) = inner.conns_for_writer.get(writer_id) + && !conns.is_empty() + { + out.insert(*writer_id); + } + } + out + } } #[cfg(test)] @@ -634,4 +647,35 @@ mod tests { ); assert!(registry.get_writer(conn_id).await.is_none()); } + + #[tokio::test] + async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() { + let registry = ConnRegistry::new(); + let (conn_id, _rx) = registry.register().await; + let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); + let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); + registry.register_writer(10, writer_tx_a).await; + registry.register_writer(20, writer_tx_b).await; + + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + assert!( + registry + .bind_writer( + conn_id, + 10, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + + let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await; + assert!(non_empty.contains(&10)); + assert!(!non_empty.contains(&20)); + assert!(!non_empty.contains(&30)); + } }