Merge upstream/flow-sec into pr-sec-1

This commit is contained in:
David Osipov 2026-03-17 19:48:53 +04:00
commit 50a827e7fd
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
8 changed files with 234 additions and 64 deletions

View File

@ -37,7 +37,6 @@ use crate::config::{
};
use super::load::{LoadedConfig, ProxyConfig};
const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2;
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
// ── Hot fields ────────────────────────────────────────────────────────────────
@ -329,41 +328,19 @@ impl WatchManifest {
#[derive(Debug, Default)]
// Tracks which rendered-config snapshot is currently applied, so the reload
// loop can skip re-applying an identical snapshot.
struct ReloadState {
// Hash of the rendered config snapshot last applied to the runtime.
applied_snapshot_hash: Option<u64>,
// NOTE(review): the candidate fields appear to track a not-yet-applied
// snapshot across successive reload attempts — confirm against the
// reload loop before relying on this.
candidate_snapshot_hash: Option<u64>,
candidate_hits: u8,
}
impl ReloadState {
fn new(applied_snapshot_hash: Option<u64>) -> Self {
Self {
applied_snapshot_hash,
candidate_snapshot_hash: None,
candidate_hits: 0,
}
Self { applied_snapshot_hash }
}
fn is_applied(&self, hash: u64) -> bool {
self.applied_snapshot_hash == Some(hash)
}
fn observe_candidate(&mut self, hash: u64) -> u8 {
if self.candidate_snapshot_hash == Some(hash) {
self.candidate_hits = self.candidate_hits.saturating_add(1);
} else {
self.candidate_snapshot_hash = Some(hash);
self.candidate_hits = 1;
}
self.candidate_hits
}
fn reset_candidate(&mut self) {
self.candidate_snapshot_hash = None;
self.candidate_hits = 0;
}
fn mark_applied(&mut self, hash: u64) {
self.applied_snapshot_hash = Some(hash);
self.reset_candidate();
}
}
@ -1138,7 +1115,6 @@ fn reload_config(
let loaded = match ProxyConfig::load_with_metadata(config_path) {
Ok(loaded) => loaded,
Err(e) => {
reload_state.reset_candidate();
error!("config reload: failed to parse {:?}: {}", config_path, e);
return None;
}
@ -1151,7 +1127,6 @@ fn reload_config(
let next_manifest = WatchManifest::from_source_files(&source_files);
if let Err(e) = new_cfg.validate() {
reload_state.reset_candidate();
error!("config reload: validation failed: {}; keeping old config", e);
return Some(next_manifest);
}
@ -1160,17 +1135,6 @@ fn reload_config(
return Some(next_manifest);
}
let candidate_hits = reload_state.observe_candidate(rendered_hash);
if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
info!(
snapshot_hash = rendered_hash,
candidate_hits,
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
"config reload: candidate snapshot observed but not stable yet"
);
return Some(next_manifest);
}
let old_cfg = config_tx.borrow().clone();
let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg);
let old_hot = HotFields::from_config(&old_cfg);
@ -1190,7 +1154,6 @@ fn reload_config(
if old_hot.dns_overrides != applied_hot.dns_overrides
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
{
reload_state.reset_candidate();
error!(
"config reload: invalid network.dns_overrides: {}; keeping old config",
e
@ -1334,14 +1297,28 @@ pub fn spawn_config_watcher(
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
while notify_rx.try_recv().is_ok() {}
if let Some(next_manifest) = reload_config(
let mut next_manifest = reload_config(
&config_path,
&config_tx,
&log_tx,
detected_ip_v4,
detected_ip_v6,
&mut reload_state,
) {
);
if next_manifest.is_none() {
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
while notify_rx.try_recv().is_ok() {}
next_manifest = reload_config(
&config_path,
&config_tx,
&log_tx,
detected_ip_v4,
detected_ip_v6,
&mut reload_state,
);
}
if let Some(next_manifest) = next_manifest {
apply_watch_manifest(
inotify_watcher.as_mut(),
poll_watcher.as_mut(),
@ -1466,7 +1443,7 @@ mod tests {
}
#[test]
fn reload_requires_stable_snapshot_before_hot_apply() {
fn reload_applies_hot_change_on_first_observed_snapshot() {
let initial_tag = "11111111111111111111111111111111";
let final_tag = "22222222222222222222222222222222";
let path = temp_config_path("telemt_hot_reload_stable");
@ -1478,20 +1455,7 @@ mod tests {
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash));
write_reload_config(&path, None, None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(initial_tag)
);
write_reload_config(&path, Some(final_tag), None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(initial_tag)
);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
@ -1513,7 +1477,6 @@ mod tests {
write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1));
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
let applied = config_tx.borrow().clone();
assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag));
@ -1521,4 +1484,31 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[test]
// A parse failure must keep the old config applied, and the very next reload
// attempt with a valid file must succeed (no stuck state after the error).
fn reload_recovers_after_parse_error_on_next_attempt() {
// Two distinct 32-char ad tags: one in the initial config, one written
// after the simulated parse failure to prove recovery took effect.
let initial_tag = "cccccccccccccccccccccccccccccccc";
let final_tag = "dddddddddddddddddddddddddddddddd";
let path = temp_config_path("telemt_hot_reload_parse_recovery");
write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash));
// Corrupt the file with unterminated TOML so the reload's parse step fails.
std::fs::write(&path, "[access.users\nuser = \"broken\"\n").unwrap();
assert!(reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).is_none());
// The previously applied config must survive the failed reload untouched.
assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(initial_tag)
);
// A valid config written afterwards is applied on the next attempt.
write_reload_config(&path, Some(final_tag), None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
let _ = std::fs::remove_file(path);
}
}

View File

@ -10,6 +10,16 @@ use crate::transport::middle_proxy::{
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
};
/// Resolves the CLI-supplied config path against the directory the process
/// was started from.
///
/// A relative path is anchored to `startup_cwd`; an absolute path is kept
/// as-is. The result is canonicalized when the target exists; otherwise the
/// (possibly non-existent) absolute path is returned unchanged.
pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf {
    let candidate = std::path::Path::new(config_path_cli);
    let absolute = if candidate.is_absolute() {
        candidate.to_path_buf()
    } else {
        startup_cwd.join(candidate)
    };
    // canonicalize() fails for paths that do not exist yet; fall back to the
    // plain absolute form in that case.
    match absolute.canonicalize() {
        Ok(canonical) => canonical,
        Err(_) => absolute,
    }
}
pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
let mut config_path = "config.toml".to_string();
let mut data_path: Option<PathBuf> = None;
@ -96,6 +106,44 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
(config_path, data_path, silent, log_level)
}
#[cfg(test)]
mod tests {
    use super::resolve_runtime_config_path;

    /// Creates a unique scratch directory under the system temp dir so the
    /// two tests below cannot collide with each other or with prior runs.
    fn unique_temp_dir(prefix: &str) -> std::path::PathBuf {
        let nonce = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!("{prefix}_{nonce}"));
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() {
        let startup_cwd = unique_temp_dir("telemt_cfg_path");
        let target = startup_cwd.join("config.toml");
        std::fs::write(&target, " ").unwrap();
        // Existing file: the relative name resolves to the canonical path
        // under the startup cwd.
        let resolved = resolve_runtime_config_path("config.toml", &startup_cwd);
        assert_eq!(resolved, target.canonicalize().unwrap());
        let _ = std::fs::remove_file(&target);
        let _ = std::fs::remove_dir(&startup_cwd);
    }

    #[test]
    fn resolve_runtime_config_path_keeps_absolute_for_missing_file() {
        let startup_cwd = unique_temp_dir("telemt_cfg_path_missing");
        // Missing file: canonicalization fails, so the plain joined path is
        // returned unchanged.
        let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd);
        assert_eq!(resolved, startup_cwd.join("missing.toml"));
        let _ = std::fs::remove_dir(&startup_cwd);
    }
}
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
info!(target: "telemt::links", "--- Proxy Links ({}) ---", host);
for user_name in config.general.links.show.resolve_users(&config.access.users) {

View File

@ -45,7 +45,7 @@ use crate::startup::{
use crate::stream::BufferPool;
use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager;
use helpers::parse_cli;
use helpers::{parse_cli, resolve_runtime_config_path};
/// Runs the full telemt runtime startup pipeline and blocks until shutdown.
pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
@ -58,18 +58,26 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
startup_tracker
.start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string()))
.await;
let (config_path, data_path, cli_silent, cli_log_level) = parse_cli();
let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli();
let startup_cwd = match std::env::current_dir() {
Ok(cwd) => cwd,
Err(e) => {
eprintln!("[telemt] Can't read current_dir: {}", e);
std::process::exit(1);
}
};
let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd);
let mut config = match ProxyConfig::load(&config_path) {
Ok(c) => c,
Err(e) => {
if std::path::Path::new(&config_path).exists() {
if config_path.exists() {
eprintln!("[telemt] Error: {}", e);
std::process::exit(1);
} else {
let default = ProxyConfig::default();
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
eprintln!("[telemt] Created default config at {}", config_path);
eprintln!("[telemt] Created default config at {}", config_path.display());
default
}
}
@ -258,7 +266,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let route_runtime_api = route_runtime.clone();
let config_rx_api = api_config_rx.clone();
let admission_rx_api = admission_rx.clone();
let config_path_api = std::path::PathBuf::from(&config_path);
let config_path_api = config_path.clone();
let startup_tracker_api = startup_tracker.clone();
let detected_ips_rx_api = detected_ips_rx.clone();
tokio::spawn(async move {

View File

@ -1,5 +1,5 @@
use std::net::IpAddr;
use std::path::PathBuf;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::{mpsc, watch};
@ -32,7 +32,7 @@ pub(crate) struct RuntimeWatches {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn spawn_runtime_tasks(
config: &Arc<ProxyConfig>,
config_path: &str,
config_path: &Path,
probe: &NetworkProbe,
prefer_ipv6: bool,
decision_ipv4_dc: bool,
@ -83,7 +83,7 @@ pub(crate) async fn spawn_runtime_tasks(
watch::Receiver<Arc<ProxyConfig>>,
watch::Receiver<LogLevel>,
) = spawn_config_watcher(
PathBuf::from(config_path),
config_path.to_path_buf(),
config.clone(),
detected_ip_v4,
detected_ip_v6,

View File

@ -160,6 +160,7 @@ pub struct MePool {
pub(super) refill_inflight: Arc<Mutex<HashSet<RefillEndpointKey>>>,
pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>,
pub(super) conn_count: AtomicUsize,
pub(super) draining_active_runtime: AtomicU64,
pub(super) stats: Arc<crate::stats::Stats>,
pub(super) generation: AtomicU64,
pub(super) active_generation: AtomicU64,
@ -438,6 +439,7 @@ impl MePool {
refill_inflight: Arc::new(Mutex::new(HashSet::new())),
refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())),
conn_count: AtomicUsize::new(0),
draining_active_runtime: AtomicU64::new(0),
generation: AtomicU64::new(1),
active_generation: AtomicU64::new(1),
warm_generation: AtomicU64::new(0),
@ -690,6 +692,32 @@ impl MePool {
}
}
/// Snapshot of the number of writers currently draining. Relaxed load:
/// the value is advisory (logging/decisions), not a synchronization point.
pub(super) fn draining_active_runtime(&self) -> u64 {
self.draining_active_runtime.load(Ordering::Relaxed)
}
/// Bumps the draining-writers gauge; paired with
/// `decrement_draining_active_runtime` when a writer leaves the draining set.
pub(super) fn increment_draining_active_runtime(&self) {
self.draining_active_runtime.fetch_add(1, Ordering::Relaxed);
}
/// Decrements the draining-writers gauge, saturating at zero so a stray
/// double-decrement can never underflow the counter.
pub(super) fn decrement_draining_active_runtime(&self) {
    // `fetch_update` runs the CAS retry loop internally; `checked_sub`
    // returns None at zero, which makes the whole update a no-op (Err) —
    // exactly the hand-rolled loop this replaces.
    let _ = self.draining_active_runtime.fetch_update(
        Ordering::Relaxed,
        Ordering::Relaxed,
        |current| current.checked_sub(1),
    );
}
pub(super) async fn key_selector(&self) -> u32 {
self.proxy_secret.read().await.key_selector
}

View File

@ -141,6 +141,38 @@ impl MePool {
out
}
/// True when every DC that still has desired endpoints is served by at
/// least one writer that is not draining.
///
/// Vacuously true when no DC currently has any desired endpoints.
pub(super) async fn has_non_draining_writer_per_desired_dc_group(&self) -> bool {
    let desired_by_dc = self.desired_dc_endpoints().await;
    // DCs with at least one desired endpoint must each be covered.
    let required_dcs: HashSet<i32> = desired_by_dc
        .iter()
        .filter_map(|(dc, endpoints)| (!endpoints.is_empty()).then_some(*dc))
        .collect();
    if required_dcs.is_empty() {
        return true;
    }
    // A required DC counts as covered once any non-draining writer serves it.
    let ws = self.writers.read().await;
    let covered_dcs: HashSet<i32> = ws
        .iter()
        .filter(|writer| !writer.draining.load(Ordering::Relaxed))
        .map(|writer| writer.writer_dc)
        .filter(|dc| required_dcs.contains(dc))
        .collect();
    covered_dcs.len() == required_dcs.len()
}
fn hardswap_warmup_connect_delay_ms(&self) -> u64 {
let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed);
let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed);
@ -475,12 +507,30 @@ impl MePool {
coverage_ratio = format_args!("{coverage_ratio:.3}"),
min_ratio = format_args!("{min_ratio:.3}"),
drain_timeout_secs,
"ME reinit cycle covered; draining stale writers"
"ME reinit cycle covered; processing stale writers"
);
self.stats.increment_pool_swap_total();
let can_drop_with_replacement = self
.has_non_draining_writer_per_desired_dc_group()
.await;
if can_drop_with_replacement {
info!(
stale_writers = stale_writer_ids.len(),
"ME reinit stale writers: replacement coverage ready, force-closing clients for fast rebind"
);
} else {
warn!(
stale_writers = stale_writer_ids.len(),
"ME reinit stale writers: replacement coverage incomplete, keeping draining fallback"
);
}
for writer_id in stale_writer_ids {
self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap)
.await;
if can_drop_with_replacement {
self.stats.increment_pool_force_close_total();
self.remove_writer_and_close_clients(writer_id).await;
}
}
if hardswap {
self.clear_pending_hardswap_state();

View File

@ -514,6 +514,7 @@ impl MePool {
let was_draining = w.draining.load(Ordering::Relaxed);
if was_draining {
self.stats.decrement_pool_drain_active();
self.decrement_draining_active_runtime();
}
self.stats.increment_me_writer_removed_total();
w.cancel.cancel();
@ -572,6 +573,7 @@ impl MePool {
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
if !already_draining {
self.stats.increment_pool_drain_active();
self.increment_draining_active_runtime();
}
w.contour
.store(WriterContour::Draining.as_u8(), Ordering::Relaxed);

View File

@ -436,6 +436,19 @@ impl ConnRegistry {
.map(|s| s.is_empty())
.unwrap_or(true)
}
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
let inner = self.inner.read().await;
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
for writer_id in writer_ids {
if let Some(conns) = inner.conns_for_writer.get(writer_id)
&& !conns.is_empty()
{
out.insert(*writer_id);
}
}
out
}
}
#[cfg(test)]
@ -634,4 +647,35 @@ mod tests {
);
assert!(registry.get_writer(conn_id).await.is_none());
}
#[tokio::test]
// Writer 10 gets one bound client; writer 20 is registered but unbound;
// writer 30 is never registered. Only writer 10 must be reported non-empty.
async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() {
let registry = ConnRegistry::new();
let (conn_id, _rx) = registry.register().await;
let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
registry.register_writer(10, writer_tx_a).await;
registry.register_writer(20, writer_tx_b).await;
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
// Bind the single client connection to writer 10 only.
assert!(
registry
.bind_writer(
conn_id,
10,
ConnMeta {
target_dc: 2,
client_addr: addr,
our_addr: addr,
proto_flags: 0,
},
)
.await
);
// Query includes an unbound writer (20) and an unknown one (30).
let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await;
assert!(non_empty.contains(&10));
assert!(!non_empty.contains(&20));
assert!(!non_empty.contains(&30));
}
}