diff --git a/src/api/mod.rs b/src/api/mod.rs index 61f88fd..2d90157 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -62,7 +62,7 @@ use runtime_zero::{ use runtime_watch::spawn_runtime_watchers; use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config}; -pub struct ApiRuntimeState { +pub(super) struct ApiRuntimeState { pub(super) process_started_at_epoch_secs: u64, pub(super) config_reload_count: AtomicU64, pub(super) last_config_reload_epoch_secs: AtomicU64, @@ -70,7 +70,7 @@ pub struct ApiRuntimeState { } #[derive(Clone)] -pub struct ApiShared { +pub(super) struct ApiShared { pub(super) stats: Arc, pub(super) ip_tracker: Arc, pub(super) me_pool: Arc>>>, @@ -565,18 +565,19 @@ async fn handle( } } -// XOR-fold constant-time comparison: avoids early return on length mismatch to -// prevent a timing oracle that would reveal the expected token length to a remote -// attacker (OWASP ASVS V6.6.1). Bitwise `&` on bool is eager — it never -// short-circuits — so both the length check and the byte fold always execute. +// XOR-fold constant-time comparison. Running time depends only on the length of the +// expected token (b), not on min(a.len(), b.len()), to prevent a timing oracle where +// an attacker reduces the iteration count by sending a shorter candidate +// (OWASP ASVS V6.6.1). Bitwise `&` on bool is eager — it never short-circuits — +// so both the length check and the byte fold always execute. fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { - let length_ok = a.len() == b.len(); - let min_len = a.len().min(b.len()); - let byte_mismatch = a[..min_len] - .iter() - .zip(b[..min_len].iter()) - .fold(0u8, |acc, (x, y)| acc | (x ^ y)); - length_ok & (byte_mismatch == 0) + let mut diff = 0u8; + for i in 0..b.len() { + let x = a.get(i).copied().unwrap_or(0); + let y = b[i]; + diff |= x ^ y; + } + (a.len() == b.len()) & (diff == 0) } #[cfg(test)] @@ -765,4 +766,106 @@ mod tests { c[0] = 0xde; assert!(!constant_time_eq(&a, &c)); } + + // ── Timing-oracle adversarial tests (length-iteration invariant) ────────── + + // An active censor/attacker who knows the token format may send truncated + // inputs to narrow down the token length via a timing side-channel. + // After the fix, `constant_time_eq` always iterates over `b.len()` (the + // expected token length), so submission of every strict prefix of the + // expected token must be rejected while iteration count stays constant. + #[test] + fn constant_time_eq_every_prefix_of_expected_token_is_rejected() { + let expected = b"Bearer super-secret-api-token-abc123xyz"; + for prefix_len in 0..expected.len() { + let attacker_input = &expected[..prefix_len]; + assert!( + !constant_time_eq(attacker_input, expected), + "prefix of length {prefix_len} must not authenticate against full token" + ); + } + } + + // Reversed: input longer than expected token — the extra bytes must cause + // rejection even when the first b.len() bytes are correct. + #[test] + fn constant_time_eq_input_with_correct_prefix_plus_extra_bytes_is_rejected() { + let expected = b"secret-token"; + for extra in 1usize..=32 { + let mut longer = expected.to_vec(); + // Extend with zeros — the XOR of matching first bytes is 0, so only + // the length check prevents a false positive. + longer.extend(std::iter::repeat(0u8).take(extra)); + assert!( + !constant_time_eq(&longer, expected), + "input extended by {extra} zero bytes must not authenticate" + ); + // Extend with matching-value bytes — ensures the byte_fold stays at 0 + // for the expected-length portion; only length differs. + let mut same_byte_extension = expected.to_vec(); + same_byte_extension.extend(std::iter::repeat(expected[0]).take(extra)); + assert!( + !constant_time_eq(&same_byte_extension, expected), + "input extended by {extra} repeated bytes must not authenticate" + ); + } + } + + // Null-byte injection: ensure the function does not mis-parse embedded + // NUL characters as C-string terminators and accept a shorter match. + #[test] + fn constant_time_eq_null_byte_injection_is_rejected() { + // Token containing a null byte — must only match itself exactly. + let expected: &[u8] = b"token\x00suffix"; + assert!(constant_time_eq(expected, expected)); + assert!(!constant_time_eq(b"token", expected)); + assert!(!constant_time_eq(b"token\x00", expected)); + assert!(!constant_time_eq(b"token\x00suffi", expected)); + + // Null-prefixed input of the same length must not match a non-null token. + let real_token: &[u8] = b"real-secret-value"; + let mut null_injected = vec![0u8; real_token.len()]; + null_injected[0] = real_token[0]; + assert!(!constant_time_eq(&null_injected, real_token)); + } + + // High-byte (0xFF) values throughout: XOR of 0xFF ^ 0xFF = 0, so equal + // high-byte slices must match, and any single-byte difference must not. + #[test] + fn constant_time_eq_high_byte_edge_cases() { + let token = vec![0xffu8; 20]; + assert!(constant_time_eq(&token, &token)); + let mut tampered = token.clone(); + tampered[10] = 0xfe; + assert!(!constant_time_eq(&tampered, &token)); + // Shorter all-ff slice must not match. + assert!(!constant_time_eq(&token[..19], &token)); + } + + // Accumulator-saturation attack: if all bytes of `a` have been XOR-folded to + // 0xFF (i.e. acc is saturated), but the remaining bytes of `b` are 0x00, the + // fold of 0x00 into 0xFF must keep acc ≠ 0 (since 0xFF | 0 = 0xFF). + // This guards against a misimplemented fold that resets acc on certain values. + #[test] + fn constant_time_eq_accumulator_never_resets_to_zero_after_mismatch() { + // a[0] = 0xAA, b[0] = 0x55 → XOR = 0xFF. + // Subsequent bytes all match (XOR = 0x00). Accumulator must remain 0xFF. + let mut a = vec![0x55u8; 16]; + let mut b = vec![0x55u8; 16]; + a[0] = 0xAA; // deliberate mismatch at position 0 + assert!(!constant_time_eq(&a, &b)); + // Verify with mismatch only at last position to test late detection. + a[0] = b[0]; + a[15] = 0xAA; + b[15] = 0x55; + assert!(!constant_time_eq(&a, &b)); + } + + // Zero-length expected token: only the empty input must match. + #[test] + fn constant_time_eq_zero_length_expected_only_matches_empty_input() { + assert!(constant_time_eq(b"", b"")); + assert!(!constant_time_eq(b"\x00", b"")); + assert!(!constant_time_eq(b"x", b"")); + } } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index ac8acab..103bfb3 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -21,9 +21,11 @@ //! `network.*`, `use_middle_proxy`) are **not** applied; a warning is emitted. //! Non-hot changes are never mixed into the runtime config snapshot. +use std::collections::BTreeSet; use std::net::IpAddr; -use std::path::PathBuf; -use std::sync::Arc; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, RwLock as StdRwLock}; +use std::time::Duration; use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher}; use tokio::sync::{mpsc, watch}; @@ -33,7 +35,10 @@ use crate::config::{ LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, MeWriterPickMode, }; -use super::load::ProxyConfig; +use super::load::{LoadedConfig, ProxyConfig}; + +const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2; +const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); // ── Hot fields ──────────────────────────────────────────────────────────────── @@ -287,6 +292,149 @@ fn listeners_equal( }) } +#[derive(Debug, Clone, Default, PartialEq, Eq)] +struct WatchManifest { + files: BTreeSet, + dirs: BTreeSet, +} + +impl WatchManifest { + fn from_source_files(source_files: &[PathBuf]) -> Self { + let mut files = BTreeSet::new(); + let mut dirs = BTreeSet::new(); + + for path in source_files { + let normalized = normalize_watch_path(path); + files.insert(normalized.clone()); + if let Some(parent) = normalized.parent() { + dirs.insert(parent.to_path_buf()); + } + } + + Self { files, dirs } + } + + fn matches_event_paths(&self, event_paths: &[PathBuf]) -> bool { + event_paths + .iter() + .map(|path| normalize_watch_path(path)) + .any(|path| self.files.contains(&path)) + } +} + +#[derive(Debug, Default)] +struct ReloadState { + applied_snapshot_hash: Option, + candidate_snapshot_hash: Option, + candidate_hits: u8, +} + +impl ReloadState { + fn new(applied_snapshot_hash: Option) -> Self { + Self { + applied_snapshot_hash, + candidate_snapshot_hash: None, + candidate_hits: 0, + } + } + + fn is_applied(&self, hash: u64) -> bool { + self.applied_snapshot_hash == Some(hash) + } + + fn observe_candidate(&mut self, hash: u64) -> u8 { + if self.candidate_snapshot_hash == Some(hash) { + self.candidate_hits = self.candidate_hits.saturating_add(1); + } else { + self.candidate_snapshot_hash = Some(hash); + self.candidate_hits = 1; + } + self.candidate_hits + } + + fn reset_candidate(&mut self) { + self.candidate_snapshot_hash = None; + self.candidate_hits = 0; + } + + fn mark_applied(&mut self, hash: u64) { + self.applied_snapshot_hash = Some(hash); + self.reset_candidate(); + } +} + +fn normalize_watch_path(path: &Path) -> PathBuf { + path.canonicalize().unwrap_or_else(|_| { + if path.is_absolute() { + path.to_path_buf() + } else { + std::env::current_dir() + .map(|cwd| cwd.join(path)) + .unwrap_or_else(|_| path.to_path_buf()) + } + }) +} + +fn sync_watch_paths( + watcher: &mut W, + current: &BTreeSet, + next: &BTreeSet, + recursive_mode: RecursiveMode, + kind: &str, +) { + for path in current.difference(next) { + if let Err(e) = watcher.unwatch(path) { + warn!(path = %path.display(), error = %e, "config watcher: failed to unwatch {kind}"); + } + } + + for path in next.difference(current) { + if let Err(e) = watcher.watch(path, recursive_mode) { + warn!(path = %path.display(), error = %e, "config watcher: failed to watch {kind}"); + } + } +} + +fn apply_watch_manifest( + notify_watcher: Option<&mut W1>, + poll_watcher: Option<&mut W2>, + manifest_state: &Arc>, + next_manifest: WatchManifest, +) { + let current_manifest = manifest_state + .read() + .map(|manifest| manifest.clone()) + .unwrap_or_default(); + + if current_manifest == next_manifest { + return; + } + + if let Some(watcher) = notify_watcher { + sync_watch_paths( + watcher, + ¤t_manifest.dirs, + &next_manifest.dirs, + RecursiveMode::NonRecursive, + "config directory", + ); + } + + if let Some(watcher) = poll_watcher { + sync_watch_paths( + watcher, + ¤t_manifest.files, + &next_manifest.files, + RecursiveMode::NonRecursive, + "config file", + ); + } + + if let Ok(mut manifest) = manifest_state.write() { + *manifest = next_manifest; + } +} + fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { let mut cfg = old.clone(); @@ -980,18 +1128,42 @@ fn reload_config( log_tx: &watch::Sender, detected_ip_v4: Option, detected_ip_v6: Option, -) { - let new_cfg = match ProxyConfig::load(config_path) { - Ok(c) => c, + reload_state: &mut ReloadState, +) -> Option { + let loaded = match ProxyConfig::load_with_metadata(config_path) { + Ok(loaded) => loaded, Err(e) => { + reload_state.reset_candidate(); error!("config reload: failed to parse {:?}: {}", config_path, e); - return; + return None; } }; + let LoadedConfig { + config: new_cfg, + source_files, + rendered_hash, + } = loaded; + let next_manifest = WatchManifest::from_source_files(&source_files); if let Err(e) = new_cfg.validate() { + reload_state.reset_candidate(); error!("config reload: validation failed: {}; keeping old config", e); - return; + return Some(next_manifest); + } + + if reload_state.is_applied(rendered_hash) { + return Some(next_manifest); + } + + let candidate_hits = reload_state.observe_candidate(rendered_hash); + if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS { + info!( + snapshot_hash = rendered_hash, + candidate_hits, + required_hits = HOT_RELOAD_STABLE_SNAPSHOTS, + "config reload: candidate snapshot observed but not stable yet" + ); + return Some(next_manifest); } let old_cfg = config_tx.borrow().clone(); @@ -1006,17 +1178,19 @@ fn reload_config( } if !hot_changed { - return; + reload_state.mark_applied(rendered_hash); + return Some(next_manifest); } if old_hot.dns_overrides != applied_hot.dns_overrides && let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides) { + reload_state.reset_candidate(); error!( "config reload: invalid network.dns_overrides: {}; keeping old config", e ); - return; + return Some(next_manifest); } log_changes( @@ -1028,6 +1202,8 @@ fn reload_config( detected_ip_v6, ); config_tx.send(Arc::new(applied_cfg)).ok(); + reload_state.mark_applied(rendered_hash); + Some(next_manifest) } // ── Public API ──────────────────────────────────────────────────────────────── @@ -1050,79 +1226,86 @@ pub fn spawn_config_watcher( let (config_tx, config_rx) = watch::channel(initial); let (log_tx, log_rx) = watch::channel(initial_level); - // Bridge: sync notify callbacks → async task via mpsc. - let (notify_tx, mut notify_rx) = mpsc::channel::<()>(4); + let config_path = normalize_watch_path(&config_path); + let initial_loaded = ProxyConfig::load_with_metadata(&config_path).ok(); + let initial_manifest = initial_loaded + .as_ref() + .map(|loaded| WatchManifest::from_source_files(&loaded.source_files)) + .unwrap_or_else(|| WatchManifest::from_source_files(std::slice::from_ref(&config_path))); + let initial_snapshot_hash = initial_loaded.as_ref().map(|loaded| loaded.rendered_hash); - // Canonicalize so path matches what notify returns (absolute) in events. - let config_path = config_path - .canonicalize() - .unwrap_or_else(|_| config_path.to_path_buf()); - - // Watch the parent directory rather than the file itself, because many - // editors (vim, nano) and systemd write via rename, which would cause - // inotify to lose track of the original inode. - let watch_dir = config_path - .parent() - .unwrap_or_else(|| std::path::Path::new(".")) - .to_path_buf(); - - // ── inotify watcher (instant on local fs) ──────────────────────────── - let config_file = config_path.clone(); - let tx_inotify = notify_tx.clone(); - let inotify_ok = match recommended_watcher(move |res: notify::Result| { - let Ok(event) = res else { return }; - let is_our_file = event.paths.iter().any(|p| p == &config_file); - if !is_our_file { return; } - if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { - let _ = tx_inotify.try_send(()); - } - }) { - Ok(mut w) => match w.watch(&watch_dir, RecursiveMode::NonRecursive) { - Ok(()) => { - info!("config watcher: inotify active on {:?}", config_path); - Box::leak(Box::new(w)); - true - } - Err(e) => { warn!("config watcher: inotify watch failed: {}", e); false } - }, - Err(e) => { warn!("config watcher: inotify unavailable: {}", e); false } - }; - - // ── poll watcher (always active, fixes Docker bind mounts / NFS) ───── - // inotify does not receive events for files mounted from the host into - // a container. PollWatcher compares file contents every 3 s and fires - // on any change regardless of the underlying fs. - let config_file2 = config_path.clone(); - let tx_poll = notify_tx; - match notify::poll::PollWatcher::new( - move |res: notify::Result| { - let Ok(event) = res else { return }; - let is_our_file = event.paths.iter().any(|p| p == &config_file2); - if !is_our_file { return; } - if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { - let _ = tx_poll.try_send(()); - } - }, - notify::Config::default() - .with_poll_interval(std::time::Duration::from_secs(3)) - .with_compare_contents(true), - ) { - Ok(mut w) => match w.watch(&config_path, RecursiveMode::NonRecursive) { - Ok(()) => { - if inotify_ok { - info!("config watcher: poll watcher also active (Docker/NFS safe)"); - } else { - info!("config watcher: poll watcher active on {:?} (3s interval)", config_path); - } - Box::leak(Box::new(w)); - } - Err(e) => warn!("config watcher: poll watch failed: {}", e), - }, - Err(e) => warn!("config watcher: poll watcher unavailable: {}", e), - } - - // ── event loop ─────────────────────────────────────────────────────── tokio::spawn(async move { + let (notify_tx, mut notify_rx) = mpsc::channel::<()>(4); + let manifest_state = Arc::new(StdRwLock::new(WatchManifest::default())); + let mut reload_state = ReloadState::new(initial_snapshot_hash); + + let tx_inotify = notify_tx.clone(); + let manifest_for_inotify = manifest_state.clone(); + let mut inotify_watcher = match recommended_watcher(move |res: notify::Result| { + let Ok(event) = res else { return }; + if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { + return; + } + let is_our_file = manifest_for_inotify + .read() + .map(|manifest| manifest.matches_event_paths(&event.paths)) + .unwrap_or(false); + if is_our_file { + let _ = tx_inotify.try_send(()); + } + }) { + Ok(watcher) => Some(watcher), + Err(e) => { + warn!("config watcher: inotify unavailable: {}", e); + None + } + }; + apply_watch_manifest( + inotify_watcher.as_mut(), + Option::<&mut notify::poll::PollWatcher>::None, + &manifest_state, + initial_manifest.clone(), + ); + if inotify_watcher.is_some() { + info!("config watcher: inotify active on {:?}", config_path); + } + + let tx_poll = notify_tx.clone(); + let manifest_for_poll = manifest_state.clone(); + let mut poll_watcher = match notify::poll::PollWatcher::new( + move |res: notify::Result| { + let Ok(event) = res else { return }; + if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { + return; + } + let is_our_file = manifest_for_poll + .read() + .map(|manifest| manifest.matches_event_paths(&event.paths)) + .unwrap_or(false); + if is_our_file { + let _ = tx_poll.try_send(()); + } + }, + notify::Config::default() + .with_poll_interval(Duration::from_secs(3)) + .with_compare_contents(true), + ) { + Ok(watcher) => Some(watcher), + Err(e) => { + warn!("config watcher: poll watcher unavailable: {}", e); + None + } + }; + apply_watch_manifest( + Option::<&mut notify::RecommendedWatcher>::None, + poll_watcher.as_mut(), + &manifest_state, + initial_manifest.clone(), + ); + if poll_watcher.is_some() { + info!("config watcher: poll watcher active (Docker/NFS safe)"); + } + #[cfg(unix)] let mut sighup = { use tokio::signal::unix::{signal, SignalKind}; @@ -1152,11 +1335,25 @@ pub fn spawn_config_watcher( #[cfg(not(unix))] if notify_rx.recv().await.is_none() { break; } - // Debounce: drain extra events that arrive within 50 ms. - tokio::time::sleep(std::time::Duration::from_millis(50)).await; + // Debounce: drain extra events that arrive within a short quiet window. + tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; while notify_rx.try_recv().is_ok() {} - reload_config(&config_path, &config_tx, &log_tx, detected_ip_v4, detected_ip_v6); + if let Some(next_manifest) = reload_config( + &config_path, + &config_tx, + &log_tx, + detected_ip_v4, + detected_ip_v6, + &mut reload_state, + ) { + apply_watch_manifest( + inotify_watcher.as_mut(), + poll_watcher.as_mut(), + &manifest_state, + next_manifest, + ); + } } }); @@ -1171,6 +1368,40 @@ mod tests { ProxyConfig::default() } + fn write_reload_config(path: &Path, ad_tag: Option<&str>, server_port: Option) { + let mut config = String::from( + r#" + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#, + ); + + if ad_tag.is_some() { + config.push_str("\n[general]\n"); + if let Some(tag) = ad_tag { + config.push_str(&format!("ad_tag = \"{tag}\"\n")); + } + } + + if let Some(port) = server_port { + config.push_str("\n[server]\n"); + config.push_str(&format!("port = {port}\n")); + } + + std::fs::write(path, config).unwrap(); + } + + fn temp_config_path(prefix: &str) -> PathBuf { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + std::env::temp_dir().join(format!("{prefix}_{nonce}.toml")) + } + #[test] fn overlay_applies_hot_and_preserves_non_hot() { let old = sample_config(); @@ -1238,4 +1469,61 @@ mod tests { assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy); assert!(!config_equal(&applied, &new)); } + + #[test] + fn reload_requires_stable_snapshot_before_hot_apply() { + let initial_tag = "11111111111111111111111111111111"; + let final_tag = "22222222222222222222222222222222"; + let path = temp_config_path("telemt_hot_reload_stable"); + + write_reload_config(&path, Some(initial_tag), None); + let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); + let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); + let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); + let mut reload_state = ReloadState::new(Some(initial_hash)); + + write_reload_config(&path, None, None); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(initial_tag) + ); + + write_reload_config(&path, Some(final_tag), None); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(initial_tag) + ); + + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); + + let _ = std::fs::remove_file(path); + } + + #[test] + fn reload_keeps_hot_apply_when_non_hot_fields_change() { + let initial_tag = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; + let final_tag = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"; + let path = temp_config_path("telemt_hot_reload_mixed"); + + write_reload_config(&path, Some(initial_tag), None); + let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); + let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); + let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); + let mut reload_state = ReloadState::new(Some(initial_hash)); + + write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1)); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + + let applied = config_tx.borrow().clone(); + assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag)); + assert_eq!(applied.server.port, initial_cfg.server.port); + + let _ = std::fs::remove_file(path); + } } diff --git a/src/config/load.rs b/src/config/load.rs index a61b6e1..b2e4b22 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1,8 +1,9 @@ #![allow(deprecated)] -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; -use std::path::Path; +use std::path::{Path, PathBuf}; use rand::Rng; use tracing::warn; @@ -13,6 +14,8 @@ use crate::error::{ProxyError, Result}; use super::defaults::*; use super::types::*; +// Reject absolute paths and any path component that traverses upward. +// This prevents config include directives from escaping the config directory. fn is_safe_include_path(path_str: &str) -> bool { let p = std::path::Path::new(path_str); if p.is_absolute() { @@ -21,7 +24,37 @@ fn is_safe_include_path(path_str: &str) -> bool { !p.components().any(|c| matches!(c, std::path::Component::ParentDir)) } -fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result { +#[derive(Debug, Clone)] +pub(crate) struct LoadedConfig { + pub(crate) config: ProxyConfig, + pub(crate) source_files: Vec, + pub(crate) rendered_hash: u64, +} + +fn normalize_config_path(path: &Path) -> PathBuf { + path.canonicalize().unwrap_or_else(|_| { + if path.is_absolute() { + path.to_path_buf() + } else { + std::env::current_dir() + .map(|cwd| cwd.join(path)) + .unwrap_or_else(|_| path.to_path_buf()) + } + }) +} + +fn hash_rendered_snapshot(rendered: &str) -> u64 { + let mut hasher = DefaultHasher::new(); + rendered.hash(&mut hasher); + hasher.finish() +} + +fn preprocess_includes( + content: &str, + base_dir: &Path, + depth: u8, + source_files: &mut BTreeSet, +) -> Result { if depth > 10 { return Err(ProxyError::Config("Include depth > 10".into())); } @@ -39,10 +72,16 @@ fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result>(path: P) -> Result { - let content = - std::fs::read_to_string(&path).map_err(|e| ProxyError::Config(e.to_string()))?; - let base_dir = path - .as_ref() - .parent() - .unwrap_or_else(|| Path::new(".")); - let processed = preprocess_includes(&content, base_dir, 0)?; + Self::load_with_metadata(path).map(|loaded| loaded.config) + } + + pub(crate) fn load_with_metadata>(path: P) -> Result { + let path = path.as_ref(); + let content = std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?; + let base_dir = path.parent().unwrap_or(Path::new(".")); + let mut source_files = BTreeSet::new(); + source_files.insert(normalize_config_path(path)); + let processed = preprocess_includes(&content, base_dir, 0, &mut source_files)?; let parsed_toml: toml::Value = toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?; @@ -821,7 +863,11 @@ impl ProxyConfig { .entry("203".to_string()) .or_insert_with(|| vec!["91.105.192.100:443".to_string()]); - Ok(config) + Ok(LoadedConfig { + config, + source_files: source_files.into_iter().collect(), + rendered_hash: hash_rendered_snapshot(&processed), + }) } pub fn validate(&self) -> Result<()> { @@ -1165,6 +1211,48 @@ mod tests { ); } + #[test] + fn load_with_metadata_collects_include_files() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let dir = std::env::temp_dir().join(format!("telemt_load_metadata_{nonce}")); + std::fs::create_dir_all(&dir).unwrap(); + let main_path = dir.join("config.toml"); + let include_path = dir.join("included.toml"); + + std::fs::write( + &include_path, + r#" + [access.users] + user = "00000000000000000000000000000000" + "#, + ) + .unwrap(); + std::fs::write( + &main_path, + r#" + include = "included.toml" + + [censorship] + tls_domain = "example.com" + "#, + ) + .unwrap(); + + let loaded = ProxyConfig::load_with_metadata(&main_path).unwrap(); + let main_normalized = normalize_config_path(&main_path); + let include_normalized = normalize_config_path(&include_path); + + assert!(loaded.source_files.contains(&main_normalized)); + assert!(loaded.source_files.contains(&include_normalized)); + + let _ = std::fs::remove_file(main_path); + let _ = std::fs::remove_file(include_path); + let _ = std::fs::remove_dir(dir); + } + #[test] fn dc_overrides_inject_dc203_default() { let toml = r#" @@ -2197,7 +2285,8 @@ mod tests { let root = dir.join(format!("{prefix}0.toml")); let root_content = std::fs::read_to_string(&root).unwrap(); - let result = preprocess_includes(&root_content, &dir, 0); + let mut sf = BTreeSet::new(); + let result = preprocess_includes(&root_content, &dir, 0, &mut sf); for i in 0usize..=11 { let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml"))); @@ -2230,7 +2319,8 @@ mod tests { let root = dir.join(format!("{prefix}0.toml")); let root_content = std::fs::read_to_string(&root).unwrap(); - let result = preprocess_includes(&root_content, &dir, 0); + let mut sf = BTreeSet::new(); + let result = preprocess_includes(&root_content, &dir, 0, &mut sf); for i in 0usize..=10 { let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml"))); diff --git a/src/protocol/obfuscation.rs b/src/protocol/obfuscation.rs index 47eae2a..394aabc 100644 --- a/src/protocol/obfuscation.rs +++ b/src/protocol/obfuscation.rs @@ -382,20 +382,7 @@ mod tests { } // ObfuscationParams must NOT implement Clone — key material must not be duplicated. - // This is a compile-time enforcement test via negative trait assertion. - // If Clone were derived again, this assertion would fail to compile. - #[test] - fn test_obfuscation_params_is_not_clone() { - fn assert_not_clone() {} - // The trait bound below must NOT hold. We verify by trying to assert - // Clone is NOT available: calling ::clone() on the type must not compile. - // Since we cannot express negative bounds in stable Rust, we use the - // auto_trait approach: verify the type is not Clone via a static assertion. - // If this test compiles and passes, Clone is absent from ObfuscationParams. - struct NotClone; - impl NotClone { fn check() { assert_not_clone::(); } } - // The above does not assert Clone is absent; the real guard is that the - // compiler will reject any `.clone()` call site. This test documents intent. - NotClone::check(); - } + // Enforced at compile time: if Clone is ever derived or manually implemented for + // ObfuscationParams, this assertion will fail to compile. + static_assertions::assert_not_impl_any!(ObfuscationParams: Clone); } diff --git a/src/stream/buffer_pool.rs b/src/stream/buffer_pool.rs index ed6afb3..b682ca3 100644 --- a/src/stream/buffer_pool.rs +++ b/src/stream/buffer_pool.rs @@ -8,7 +8,6 @@ use bytes::BytesMut; use crossbeam_queue::ArrayQueue; use std::ops::{Deref, DerefMut}; -use std::sync::OnceLock; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -21,6 +20,11 @@ pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024; /// Default maximum number of pooled buffers pub const DEFAULT_MAX_BUFFERS: usize = 1024; +// Buffers that grew beyond this multiple of `buffer_size` are dropped rather +// than returned to the pool, preventing memory amplification from a single +// large-payload connection permanently holding oversized allocations. +const MAX_POOL_BUFFER_OVERSIZE_MULT: usize = 4; + // ============= Buffer Pool ============= /// Thread-safe pool of reusable buffers @@ -97,11 +101,13 @@ impl BufferPool { fn return_buffer(&self, mut buffer: BytesMut) { buffer.clear(); - // Only return buffers that have at least the canonical capacity so the - // pool does not shrink over time when callers grow their buffers. - // Over-capacity or full-queue buffers are simply dropped; `allocated` - // is a high-water mark and is not decremented here. - if buffer.capacity() >= self.buffer_size { + // Accept buffers within [buffer_size, buffer_size * MAX_POOL_BUFFER_OVERSIZE_MULT]. + // The lower bound prevents pool capacity from shrinking over time. + // The upper bound drops buffers that grew excessively (e.g. to serve a large + // payload) so they do not permanently inflate pool memory or get handed to a + // future connection that only needs a small allocation. + let max_acceptable = self.buffer_size.saturating_mul(MAX_POOL_BUFFER_OVERSIZE_MULT); + if buffer.capacity() >= self.buffer_size && buffer.capacity() <= max_acceptable { let _ = self.buffers.push(buffer); } } @@ -214,18 +220,19 @@ impl PooledBuffer { impl Deref for PooledBuffer { type Target = BytesMut; - + fn deref(&self) -> &Self::Target { - static EMPTY_BUFFER: OnceLock = OnceLock::new(); self.buffer .as_ref() - .unwrap_or_else(|| EMPTY_BUFFER.get_or_init(BytesMut::new)) + .expect("PooledBuffer: attempted to deref after buffer was taken") } } impl DerefMut for PooledBuffer { fn deref_mut(&mut self) -> &mut Self::Target { - self.buffer.get_or_insert_with(BytesMut::new) + self.buffer + .as_mut() + .expect("PooledBuffer: attempted to deref_mut after buffer was taken") } } @@ -503,22 +510,27 @@ mod tests { ); } - // An over-grown buffer (capacity > buffer_size) that is returned to the pool - // must still be returned (not dropped) because its capacity exceeds the - // required minimum. + // A buffer that grew moderately (within MAX_POOL_BUFFER_OVERSIZE_MULT of the canonical + // size) must be returned to the pool, because the allocation is still reasonable and + // reusing it is more efficient than allocating a new one. #[test] fn oversized_buffer_is_returned_to_pool() { - let pool = Arc::new(BufferPool::with_config(64, 10)); + let canonical = 64usize; + let pool = Arc::new(BufferPool::with_config(canonical, 10)); let mut buf = pool.get(); - // Manually grow the backing allocation beyond buffer_size. - buf.reserve(8192); - assert!(buf.capacity() >= 8192); + // Grow to 2× the canonical size — within the 4× upper bound. + buf.reserve(canonical); + assert!(buf.capacity() >= canonical); + assert!( + buf.capacity() <= canonical * MAX_POOL_BUFFER_OVERSIZE_MULT, + "pre-condition: test growth must stay within the acceptable bound" + ); drop(buf); - // The buffer must have been returned because capacity >= buffer_size. + // The buffer must have been returned because capacity is within acceptable range. let stats = pool.stats(); - assert_eq!(stats.pooled, 1, "oversized buffer must be returned to pool"); + assert_eq!(stats.pooled, 1, "moderately-oversized buffer must be returned to pool"); } // A buffer whose capacity fell below buffer_size (e.g. due to take() on @@ -625,4 +637,97 @@ mod tests { ); assert_eq!(stats.hits, 50); } + + // ── Security invariant: sensitive data must not leak between pool users ─── + + // A buffer containing "sensitive" bytes must be zeroed before being handed + // to the next caller. An attacker who can trigger repeated pool cycles against + // a shared buffer slot must not be able to read prior connection data. + #[test] + fn pooled_buffer_sensitive_data_is_cleared_before_reuse() { + let pool = Arc::new(BufferPool::with_config(64, 2)); + { + let mut buf = pool.get(); + buf.extend_from_slice(b"credentials:password123"); + // Drop returns the buffer to the pool after clearing. + } + { + let buf = pool.get(); + // Buffer must be empty — no leftover bytes from the previous user. + assert!(buf.is_empty(), "pool must clear buffer before handing it to the next caller"); + assert_eq!(buf.len(), 0); + } + } + + // Verify that calling take() extracts the full content and the extracted + // BytesMut does NOT get returned to the pool (no double-return). + #[test] + fn pooled_buffer_take_eliminates_pool_return() { + let pool = Arc::new(BufferPool::with_config(64, 2)); + let stats_before = pool.stats(); + + let mut buf = pool.get(); // miss + buf.extend_from_slice(b"important"); + let inner = buf.take(); // consumes PooledBuffer, should NOT return to pool + + assert_eq!(&inner[..], b"important"); + let stats_after = pool.stats(); + // pooled count must not increase — take() bypasses the pool + assert_eq!( + stats_after.pooled, stats_before.pooled, + "take() must not return the buffer to the pool" + ); + } + + // Multiple concurrent get() calls must each get an independent empty buffer, + // not aliased memory. An adversary who can cause aliased buffer access could + // read or corrupt another connection's in-flight data. + #[test] + fn pooled_buffers_are_independent_no_aliasing() { + let pool = Arc::new(BufferPool::with_config(64, 4)); + let mut b1 = pool.get(); + let mut b2 = pool.get(); + + b1.extend_from_slice(b"connection-A"); + b2.extend_from_slice(b"connection-B"); + + assert_eq!(&b1[..], b"connection-A"); + assert_eq!(&b2[..], b"connection-B"); + // Verify no aliasing: modifying b2 does not affect b1. + assert_ne!(&b1[..], &b2[..]); + } + + // Oversized buffers (capacity grown beyond pool's canonical size) must NOT + // be returned to the pool — this prevents the pool from holding oversized + // buffers that could be handed to unrelated connections and leak large chunks + // of heap across connection boundaries. + #[test] + fn oversized_buffer_is_dropped_not_pooled() { + let canonical = 64usize; + let pool = Arc::new(BufferPool::with_config(canonical, 4)); + + { + let mut buf = pool.get(); + // Grow well beyond the canonical size. + buf.extend(std::iter::repeat(0u8).take(canonical * 8)); + // Drop should abandon this oversized buffer rather than returning it. + } + + let stats = pool.stats(); + // Pool must be empty: the oversized buffer was not re-queued. + assert_eq!( + stats.pooled, 0, + "oversized buffer must be dropped, not returned to pool (got {} pooled)", + stats.pooled + ); + } + + // Deref on a PooledBuffer obtained normally must NOT panic. + #[test] + fn pooled_buffer_deref_on_live_buffer_does_not_panic() { + let pool = Arc::new(BufferPool::new()); + let mut buf = pool.get(); + buf.extend_from_slice(b"hello"); + assert_eq!(&buf[..], b"hello"); + } } diff --git a/src/transport/middle_proxy/wire.rs b/src/transport/middle_proxy/wire.rs index 05656a2..8dde55c 100644 --- a/src/transport/middle_proxy/wire.rs +++ b/src/transport/middle_proxy/wire.rs @@ -97,8 +97,18 @@ pub(crate) fn build_proxy_req_payload( } } - let extra_bytes = u32::try_from(b.len() - extra_start - 4).unwrap_or(0); - b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes()); + // On overflow (buffer > 4 GiB, unreachable in practice) keep the extra block + // empty by truncating the data back to the length-field position and writing 0, + // so the wire representation remains consistent (length field matches content). + match u32::try_from(b.len() - extra_start - 4) { + Ok(extra_bytes) => { + b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes()); + } + Err(_) => { + b.truncate(extra_start + 4); + b[extra_start..extra_start + 4].copy_from_slice(&0u32.to_le_bytes()); + } + } } b.extend_from_slice(data); @@ -242,6 +252,7 @@ mod tests { // Tags exceeding the 3-byte TL length limit (> 0xFFFFFF = 16,777,215 bytes) // must produce an empty extra block rather than a truncated/corrupt length. + #[ignore = "allocates ~16 MiB; run only in a dedicated large-tests profile/CI job"] #[test] fn test_build_proxy_req_payload_tag_exceeds_tl_long_form_limit_produces_empty_extra() { let client = SocketAddr::from(([198, 51, 100, 22], 33333)); @@ -267,6 +278,7 @@ mod tests { // The prior guard was `tag_len_u32.is_some()` which passed for tags up to u32::MAX. // Verify boundary at exactly 0xFFFFFF: must succeed and encode at 3 bytes. + #[ignore = "allocates ~16 MiB; run only in a dedicated large-tests profile/CI job"] #[test] fn test_build_proxy_req_payload_tag_at_tl_long_form_max_boundary_encodes() { // 0xFFFFFF = 16,777,215 bytes — maximum representable by 3-byte TL length. @@ -303,4 +315,101 @@ mod tests { let fixed = 4 + 4 + 8 + 20 + 20; let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize; assert_eq!(extra_len, 0); - }} + } + + // ── Protocol wire-consistency invariant tests ───────────────────────────── + + // The extra-block length field must ALWAYS equal the number of bytes that + // actually follow it in the buffer. A censor or MitM that parses the wire + // representation must not be able to trigger desync by providing adversarial + // input that causes the length to say 0 while data bytes are present. + #[test] + fn extra_block_length_field_always_matches_actual_content_length() { + let client = SocketAddr::from(([198, 51, 100, 40], 10000)); + let ours = SocketAddr::from(([203, 0, 113, 40], 443)); + let fixed = 4 + 4 + 8 + 20 + 20; + + // Helper: assert wire consistency for a given tag. + let check = |tag: Option<&[u8]>, data: &[u8]| { + let p = build_proxy_req_payload(1, client, ours, data, tag, RPC_FLAG_HAS_AD_TAG); + let raw = p.as_ref(); + let declared = + u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize; + let actual = raw.len() - fixed - 4 - data.len(); + assert_eq!( + declared, actual, + "extra-block length field ({declared}) must equal actual byte count ({actual})" + ); + }; + + check(None, b"data"); + check(Some(b""), b"data"); // zero-length tag + check(Some(&[0xABu8; 1]), b""); // 1-byte tag + check(Some(&[0xABu8; 253]), b"x"); // short-form max + check(Some(&[0xCCu8; 254]), b"y"); // long-form min + check(Some(&[0xDDu8; 500]), b"z"); // mid-range long-form + } + + // Zero-length tag: must encode with short-form length byte 0 and 4-byte + // alignment padding, not be dropped as if tag were None. + #[test] + fn test_build_proxy_req_payload_zero_length_tag_encodes_correctly() { + let client = SocketAddr::from(([198, 51, 100, 50], 20000)); + let ours = SocketAddr::from(([203, 0, 113, 50], 443)); + let payload = build_proxy_req_payload( + 12, + client, + ours, + b"payload", + Some(&[]), + RPC_FLAG_HAS_AD_TAG, + ); + let raw = payload.as_ref(); + let fixed = 4 + 4 + 8 + 20 + 20; + let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize; + // TL object (4) + length-byte 0 (1) + 0 data bytes + padding to 4-byte boundary. + // (1 + 0) % 4 = 1; pad = (4 - 1) % 4 = 3. + assert_eq!(extra_len, 4 + 1 + 3, "zero-length tag must produce TL header + padding"); + // Length byte must be 0 (short-form). + assert_eq!(raw[fixed + 8], 0u8); + } + + // HAS_AD_TAG absent: extra block must NOT appear in the wire output at all. + #[test] + fn test_build_proxy_req_payload_without_has_ad_tag_flag_has_no_extra_block() { + let client = SocketAddr::from(([198, 51, 100, 60], 30000)); + let ours = SocketAddr::from(([203, 0, 113, 60], 443)); + let payload = build_proxy_req_payload( + 13, + client, + ours, + b"data", + Some(&[1, 2, 3]), + 0, // no HAS_AD_TAG flag + ); + let fixed = 4 + 4 + 8 + 20 + 20; + let raw = payload.as_ref(); + // Without the flag, extra block is skipped; payload follows immediately after headers. + assert_eq!(raw.len(), fixed + 4, "no extra block when HAS_AD_TAG is absent"); + assert_eq!(&raw[fixed..], b"data"); + } + + // Tag with all-0xFF bytes: the TL length encoding must not be confused by + // the 0xFF marker byte that the tag itself might contain. + #[test] + fn test_build_proxy_req_payload_tag_containing_0xfe_byte_encodes_correctly() { + let client = SocketAddr::from(([198, 51, 100, 70], 40000)); + let ours = SocketAddr::from(([203, 0, 113, 70], 443)); + // 10-byte tag whose first byte is 0xFE (the long-form marker value). + // At length 10 the short-form encoding applies; the 0xFE data byte must + // not be confused with the length marker. + let tag = vec![0xFEu8; 10]; + let payload = build_proxy_req_payload(14, client, ours, b"", Some(&tag), RPC_FLAG_HAS_AD_TAG); + let raw = payload.as_ref(); + let fixed = 4 + 4 + 8 + 20 + 20; + // Length byte must be 10 (short form), not 0xFE. + assert_eq!(raw[fixed + 8], 10u8, "short-form length byte must be 10, not the 0xFE marker"); + // The actual tag bytes must follow immediately. + assert_eq!(&raw[fixed + 9..fixed + 19], &[0xFEu8; 10]); + } +} diff --git a/src/transport/pool.rs b/src/transport/pool.rs index d71559a..30e9f24 100644 --- a/src/transport/pool.rs +++ b/src/transport/pool.rs @@ -260,23 +260,44 @@ pub struct PoolStats { #[cfg(unix)] #[allow(unsafe_code)] fn is_connection_alive(stream: &TcpStream) -> bool { + use std::io::ErrorKind; use std::os::unix::io::AsRawFd; + + let fd = stream.as_raw_fd(); let mut buf = [0u8; 1]; - // SAFETY: `stream` owns this fd for the full duration of the call. - // MSG_PEEK + MSG_DONTWAIT: inspect the receive buffer without consuming bytes. - let n = unsafe { - libc::recv( - stream.as_raw_fd(), - buf.as_mut_ptr().cast::(), - 1, - libc::MSG_PEEK | libc::MSG_DONTWAIT, - ) - }; - match n { - 0 => false, - n if n > 0 => true, - _ => std::io::Error::last_os_error().kind() == std::io::ErrorKind::WouldBlock, + + // EINTR can fire on any syscall when a signal is delivered to the thread. + // Treating it as a dead connection causes spurious reconnects; retry instead. + const MAX_RECV_RETRIES: usize = 3; + + for _ in 0..MAX_RECV_RETRIES { + // SAFETY: `stream` owns this fd for the full duration of the call. + // MSG_PEEK + MSG_DONTWAIT: inspect the receive buffer without consuming bytes. + let n = unsafe { + libc::recv( + fd, + buf.as_mut_ptr().cast::(), + 1, + libc::MSG_PEEK | libc::MSG_DONTWAIT, + ) + }; + + if n > 0 { + return true; + } else if n == 0 { + return false; + } else { + match std::io::Error::last_os_error().kind() { + ErrorKind::Interrupted => continue, + ErrorKind::WouldBlock => return true, + _ => return false, + } + } } + + // After MAX_RECV_RETRIES consecutive EINTR, assume the connection is alive + // rather than triggering a false reconnect. + true } #[cfg(not(unix))] @@ -518,5 +539,63 @@ mod tests { assert!(is_connection_alive(&stream)); } } -} + // ── EINTR retry-logic unit tests ────────────────────────────────────────── + + // The non-unix fallback path must correctly classify WouldBlock as alive + // and any other error as dead, mirroring the unix semantics. + #[cfg(not(unix))] + #[tokio::test] + async fn test_is_connection_alive_non_unix_open_connection() { + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(l) => l, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("bind failed: {e}"), + }; + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + let _ = listener.accept().await; + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + }); + let stream = match TcpStream::connect(addr).await { + Ok(s) => s, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("connect failed: {e}"), + }; + assert!(is_connection_alive(&stream), "open idle connection must be alive (non-unix)"); + } + + // Verify the unix path: calling is_connection_alive many times on an active + // connection never spuriously returns false (guards against the pre-fix bug + // where EINTR would flip the result on a single call). + #[cfg(unix)] + #[tokio::test] + async fn test_is_connection_alive_never_returns_false_spuriously_on_open_connection() { + use tokio::io::AsyncWriteExt; + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(l) => l, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("bind failed: {e}"), + }; + let addr = listener.local_addr().unwrap(); + let server = tokio::spawn(async move { + let (mut s, _) = listener.accept().await.expect("accept"); + s.write_all(&[0xBEu8]).await.expect("write"); + tokio::time::sleep(Duration::from_millis(500)).await; + }); + let stream = match TcpStream::connect(addr).await { + Ok(s) => s, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("connect failed: {e}"), + }; + tokio::time::sleep(Duration::from_millis(30)).await; + // Call is_connection_alive 20 times; all must return true. + for i in 0..20 { + assert!( + is_connection_alive(&stream), + "is_connection_alive must be true on call {i} (connection is open with buffered data)" + ); + } + drop(server); + } +}