Merge resolve-conflicts-419: conflict resolution + Copilot review fixes

This commit is contained in:
David Osipov 2026-03-14 22:01:06 +04:00
commit 0903dd36c7
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
9 changed files with 933 additions and 164 deletions

7
Cargo.lock generated
View File

@ -2025,6 +2025,12 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "subtle"
version = "2.6.1"
@ -2127,6 +2133,7 @@ dependencies = [
"sha1",
"sha2",
"socket2 0.5.10",
"static_assertions",
"subtle",
"thiserror 2.0.18",
"tokio",

View File

@ -70,6 +70,7 @@ tokio-test = "0.4"
criterion = "0.5"
proptest = "1.4"
futures = "0.3"
static_assertions = "1.1"
[[bench]]
name = "crypto_bench"

View File

@ -62,7 +62,7 @@ use runtime_zero::{
use runtime_watch::spawn_runtime_watchers;
use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config};
pub struct ApiRuntimeState {
pub(super) struct ApiRuntimeState {
pub(super) process_started_at_epoch_secs: u64,
pub(super) config_reload_count: AtomicU64,
pub(super) last_config_reload_epoch_secs: AtomicU64,
@ -70,7 +70,7 @@ pub struct ApiRuntimeState {
}
#[derive(Clone)]
pub struct ApiShared {
pub(super) struct ApiShared {
pub(super) stats: Arc<Stats>,
pub(super) ip_tracker: Arc<UserIpTracker>,
pub(super) me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
@ -565,18 +565,19 @@ async fn handle(
}
}
// XOR-fold constant-time comparison. Running time depends only on the length of the
// expected token (`b`), not on min(a.len(), b.len()), to prevent a timing oracle
// where an attacker reduces the iteration count by sending a shorter candidate
// (OWASP ASVS V6.6.1). Bitwise `&` on bool is eager — it never short-circuits —
// so both the length check and the byte fold always execute.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let mut diff = 0u8;
    for i in 0..b.len() {
        // Out-of-range reads of `a` fold in 0, so the loop length never leaks
        // anything about the attacker-controlled input's length.
        let x = a.get(i).copied().unwrap_or(0);
        let y = b[i];
        diff |= x ^ y;
    }
    // `&` (not `&&`) keeps both operands evaluated unconditionally: no early
    // exit on a length mismatch.
    (a.len() == b.len()) & (diff == 0)
}
#[cfg(test)]
@ -765,4 +766,106 @@ mod tests {
c[0] = 0xde;
assert!(!constant_time_eq(&a, &c));
}
// ── Timing-oracle adversarial tests (length-iteration invariant) ──────────
// An active censor/attacker who knows the token format may probe with truncated
// inputs to narrow down the token length via a timing side-channel.
// `constant_time_eq` always iterates over `b.len()` (the expected token
// length), so every strict prefix of the expected token must be rejected
// while the iteration count stays constant.
#[test]
fn constant_time_eq_every_prefix_of_expected_token_is_rejected() {
    let expected = b"Bearer super-secret-api-token-abc123xyz";
    let mut prefix_len = 0usize;
    while prefix_len < expected.len() {
        let attacker_input = &expected[..prefix_len];
        assert!(
            !constant_time_eq(attacker_input, expected),
            "prefix of length {prefix_len} must not authenticate against full token"
        );
        prefix_len += 1;
    }
}
// Reversed case: candidate longer than the expected token — the trailing bytes
// must force rejection even when the first b.len() bytes are all correct.
#[test]
fn constant_time_eq_input_with_correct_prefix_plus_extra_bytes_is_rejected() {
    let expected = b"secret-token";
    for extra in 1usize..=32 {
        // Zero padding: the XOR fold over the matching prefix is 0, so only the
        // length check stands between the attacker and a false positive.
        let mut zero_padded = expected.to_vec();
        zero_padded.resize(expected.len() + extra, 0u8);
        assert!(
            !constant_time_eq(&zero_padded, expected),
            "input extended by {extra} zero bytes must not authenticate"
        );
        // Padding with a byte value taken from the token itself: the fold over
        // the expected-length portion stays 0; only the lengths differ.
        let mut repeat_padded = expected.to_vec();
        repeat_padded.resize(expected.len() + extra, expected[0]);
        assert!(
            !constant_time_eq(&repeat_padded, expected),
            "input extended by {extra} repeated bytes must not authenticate"
        );
    }
}
// Null-byte injection: ensure the function does not mis-parse embedded
// NUL characters as C-string terminators and accept a shorter match.
// (Rust slices are length-delimited, so this documents the invariant rather
// than guarding against an actual C-string code path in this function.)
#[test]
fn constant_time_eq_null_byte_injection_is_rejected() {
    // Token containing a null byte — must only match itself exactly.
    let expected: &[u8] = b"token\x00suffix";
    assert!(constant_time_eq(expected, expected));
    // Truncations at and around the embedded NUL must all be rejected.
    assert!(!constant_time_eq(b"token", expected));
    assert!(!constant_time_eq(b"token\x00", expected));
    assert!(!constant_time_eq(b"token\x00suffi", expected));
    // Null-prefixed input of the same length must not match a non-null token.
    let real_token: &[u8] = b"real-secret-value";
    let mut null_injected = vec![0u8; real_token.len()];
    null_injected[0] = real_token[0];
    assert!(!constant_time_eq(&null_injected, real_token));
}
// High-byte (0xFF) values throughout: 0xFF ^ 0xFF = 0, so equal all-0xFF
// slices must compare equal, and any single-byte difference must not.
#[test]
fn constant_time_eq_high_byte_edge_cases() {
    let token = [0xffu8; 20].to_vec();
    assert!(constant_time_eq(&token, &token));
    // Flip the low bit of one byte in the middle — must be detected.
    let mut tampered = token.clone();
    tampered[10] = 0xfe;
    assert!(!constant_time_eq(&tampered, &token));
    // A truncated all-0xFF slice must also be rejected (length check).
    assert!(!constant_time_eq(&token[..19], &token));
}
// Accumulator-saturation attack: if all bytes of `a` have been XOR-folded to
// 0xFF (i.e. acc is saturated), but the remaining bytes of `b` are 0x00, the
// fold of 0x00 into 0xFF must keep acc ≠ 0 (since 0xFF | 0 = 0xFF).
// This guards against a misimplemented fold that resets acc on certain values.
#[test]
fn constant_time_eq_accumulator_never_resets_to_zero_after_mismatch() {
    // a[0] = 0xAA, b[0] = 0x55 → XOR = 0xFF.
    // Subsequent bytes all match (XOR = 0x00). Accumulator must remain 0xFF.
    let mut a = vec![0x55u8; 16];
    let mut b = vec![0x55u8; 16];
    a[0] = 0xAA; // deliberate mismatch at position 0
    assert!(!constant_time_eq(&a, &b));
    // Verify with mismatch only at last position to test late detection.
    a[0] = b[0];
    a[15] = 0xAA;
    // NOTE(review): b[15] is already 0x55 from the vec! initializer — this
    // reassignment is a no-op kept for explicitness/symmetry with a[15].
    b[15] = 0x55;
    assert!(!constant_time_eq(&a, &b));
}
// Zero-length expected token: the empty input — and only the empty input —
// may compare equal.
#[test]
fn constant_time_eq_zero_length_expected_only_matches_empty_input() {
    assert!(constant_time_eq(b"", b""));
    // Any non-empty candidate, including a lone NUL, must be rejected.
    let non_empty_candidates: [&[u8]; 2] = [b"\x00", b"x"];
    for candidate in non_empty_candidates {
        assert!(!constant_time_eq(candidate, b""));
    }
}
}

View File

@ -21,9 +21,11 @@
//! `network.*`, `use_middle_proxy`) are **not** applied; a warning is emitted.
//! Non-hot changes are never mixed into the runtime config snapshot.
use std::collections::BTreeSet;
use std::net::IpAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock as StdRwLock};
use std::time::Duration;
use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher};
use tokio::sync::{mpsc, watch};
@ -33,7 +35,10 @@ use crate::config::{
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
MeWriterPickMode,
};
use super::load::ProxyConfig;
use super::load::{LoadedConfig, ProxyConfig};
const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2;
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
// ── Hot fields ────────────────────────────────────────────────────────────────
@ -287,6 +292,149 @@ fn listeners_equal(
})
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct WatchManifest {
    files: BTreeSet<PathBuf>,
    dirs: BTreeSet<PathBuf>,
}

impl WatchManifest {
    /// Build the watch set from the config's source files: every file is
    /// normalized, and its parent directory is recorded so directory-level
    /// watchers can be installed alongside file-level ones.
    fn from_source_files(source_files: &[PathBuf]) -> Self {
        let mut manifest = Self::default();
        for raw in source_files {
            let file = normalize_watch_path(raw);
            if let Some(dir) = file.parent() {
                manifest.dirs.insert(dir.to_path_buf());
            }
            manifest.files.insert(file);
        }
        manifest
    }

    /// True when any path reported in a watcher event refers to a watched
    /// config source file (after the same normalization used at build time).
    fn matches_event_paths(&self, event_paths: &[PathBuf]) -> bool {
        event_paths
            .iter()
            .any(|raw| self.files.contains(&normalize_watch_path(raw)))
    }
}
#[derive(Debug, Default)]
struct ReloadState {
    applied_snapshot_hash: Option<u64>,
    candidate_snapshot_hash: Option<u64>,
    candidate_hits: u8,
}

impl ReloadState {
    /// Start tracking from an optionally already-applied snapshot hash.
    fn new(applied_snapshot_hash: Option<u64>) -> Self {
        Self {
            applied_snapshot_hash,
            ..Self::default()
        }
    }

    /// Whether `hash` is the snapshot that is currently applied.
    fn is_applied(&self, hash: u64) -> bool {
        self.applied_snapshot_hash == Some(hash)
    }

    /// Record one observation of a candidate snapshot; returns how many
    /// consecutive times this exact hash has now been seen. A different hash
    /// restarts the count at 1; the counter saturates instead of wrapping.
    fn observe_candidate(&mut self, hash: u64) -> u8 {
        if self.candidate_snapshot_hash != Some(hash) {
            self.candidate_snapshot_hash = Some(hash);
            self.candidate_hits = 0;
        }
        self.candidate_hits = self.candidate_hits.saturating_add(1);
        self.candidate_hits
    }

    /// Forget the current candidate (e.g. after a parse/validation failure).
    fn reset_candidate(&mut self) {
        self.candidate_snapshot_hash = None;
        self.candidate_hits = 0;
    }

    /// Mark `hash` as the applied snapshot and clear candidate tracking.
    fn mark_applied(&mut self, hash: u64) {
        self.applied_snapshot_hash = Some(hash);
        self.reset_candidate();
    }
}
/// Best-effort normalization for watch-path comparison: canonicalize when the
/// path exists; otherwise fall back to an absolute form (relative paths are
/// joined onto the current directory) so set lookups stay stable.
fn normalize_watch_path(path: &Path) -> PathBuf {
    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }
    if path.is_absolute() {
        return path.to_path_buf();
    }
    match std::env::current_dir() {
        Ok(cwd) => cwd.join(path),
        Err(_) => path.to_path_buf(),
    }
}
// Diff the current watch set against the desired one and update the watcher in
// place: unwatch paths that dropped out of the manifest, watch newly added
// ones. Failures are logged and otherwise ignored — a failed (un)watch is
// non-fatal here; `kind` only labels the log line ("config file"/"config directory").
fn sync_watch_paths<W: Watcher>(
    watcher: &mut W,
    current: &BTreeSet<PathBuf>,
    next: &BTreeSet<PathBuf>,
    recursive_mode: RecursiveMode,
    kind: &str,
) {
    // Paths present in the old manifest but not the new one → stop watching.
    for path in current.difference(next) {
        if let Err(e) = watcher.unwatch(path) {
            warn!(path = %path.display(), error = %e, "config watcher: failed to unwatch {kind}");
        }
    }
    // Paths new in the next manifest → start watching.
    for path in next.difference(current) {
        if let Err(e) = watcher.watch(path, recursive_mode) {
            warn!(path = %path.display(), error = %e, "config watcher: failed to watch {kind}");
        }
    }
}
// Swap the active watch manifest: diff `next_manifest` against the manifest
// currently stored in `manifest_state`, apply the difference to whichever
// watchers are present (the first watcher syncs directories, the second syncs
// files), then persist the new manifest. Either watcher may be absent.
fn apply_watch_manifest<W1: Watcher, W2: Watcher>(
    notify_watcher: Option<&mut W1>,
    poll_watcher: Option<&mut W2>,
    manifest_state: &Arc<StdRwLock<WatchManifest>>,
    next_manifest: WatchManifest,
) {
    // Clone under the read lock. A poisoned lock degrades to an empty manifest,
    // which makes every path in `next_manifest` look newly added — presumably
    // re-watching an already-watched path is harmless for notify backends
    // (TODO confirm against the notify crate's semantics).
    let current_manifest = manifest_state
        .read()
        .map(|manifest| manifest.clone())
        .unwrap_or_default();
    // Nothing to do when the watch set is unchanged.
    if current_manifest == next_manifest {
        return;
    }
    // Directory-level watches (inotify-style backend watches parent dirs).
    if let Some(watcher) = notify_watcher {
        sync_watch_paths(
            watcher,
            &current_manifest.dirs,
            &next_manifest.dirs,
            RecursiveMode::NonRecursive,
            "config directory",
        );
    }
    // File-level watches (poll backend compares file contents directly).
    if let Some(watcher) = poll_watcher {
        sync_watch_paths(
            watcher,
            &current_manifest.files,
            &next_manifest.files,
            RecursiveMode::NonRecursive,
            "config file",
        );
    }
    // NOTE(review): if the write lock is poisoned the stored manifest is left
    // stale while the watchers were already updated — accepted as best-effort.
    if let Ok(mut manifest) = manifest_state.write() {
        *manifest = next_manifest;
    }
}
fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
let mut cfg = old.clone();
@ -980,18 +1128,42 @@ fn reload_config(
log_tx: &watch::Sender<LogLevel>,
detected_ip_v4: Option<IpAddr>,
detected_ip_v6: Option<IpAddr>,
) {
let new_cfg = match ProxyConfig::load(config_path) {
Ok(c) => c,
reload_state: &mut ReloadState,
) -> Option<WatchManifest> {
let loaded = match ProxyConfig::load_with_metadata(config_path) {
Ok(loaded) => loaded,
Err(e) => {
reload_state.reset_candidate();
error!("config reload: failed to parse {:?}: {}", config_path, e);
return;
return None;
}
};
let LoadedConfig {
config: new_cfg,
source_files,
rendered_hash,
} = loaded;
let next_manifest = WatchManifest::from_source_files(&source_files);
if let Err(e) = new_cfg.validate() {
reload_state.reset_candidate();
error!("config reload: validation failed: {}; keeping old config", e);
return;
return Some(next_manifest);
}
if reload_state.is_applied(rendered_hash) {
return Some(next_manifest);
}
let candidate_hits = reload_state.observe_candidate(rendered_hash);
if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
info!(
snapshot_hash = rendered_hash,
candidate_hits,
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
"config reload: candidate snapshot observed but not stable yet"
);
return Some(next_manifest);
}
let old_cfg = config_tx.borrow().clone();
@ -1006,17 +1178,19 @@ fn reload_config(
}
if !hot_changed {
return;
reload_state.mark_applied(rendered_hash);
return Some(next_manifest);
}
if old_hot.dns_overrides != applied_hot.dns_overrides
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
{
reload_state.reset_candidate();
error!(
"config reload: invalid network.dns_overrides: {}; keeping old config",
e
);
return;
return Some(next_manifest);
}
log_changes(
@ -1028,6 +1202,8 @@ fn reload_config(
detected_ip_v6,
);
config_tx.send(Arc::new(applied_cfg)).ok();
reload_state.mark_applied(rendered_hash);
Some(next_manifest)
}
// ── Public API ────────────────────────────────────────────────────────────────
@ -1050,79 +1226,86 @@ pub fn spawn_config_watcher(
let (config_tx, config_rx) = watch::channel(initial);
let (log_tx, log_rx) = watch::channel(initial_level);
// Bridge: sync notify callbacks → async task via mpsc.
let config_path = normalize_watch_path(&config_path);
let initial_loaded = ProxyConfig::load_with_metadata(&config_path).ok();
let initial_manifest = initial_loaded
.as_ref()
.map(|loaded| WatchManifest::from_source_files(&loaded.source_files))
.unwrap_or_else(|| WatchManifest::from_source_files(std::slice::from_ref(&config_path)));
let initial_snapshot_hash = initial_loaded.as_ref().map(|loaded| loaded.rendered_hash);
tokio::spawn(async move {
let (notify_tx, mut notify_rx) = mpsc::channel::<()>(4);
let manifest_state = Arc::new(StdRwLock::new(WatchManifest::default()));
let mut reload_state = ReloadState::new(initial_snapshot_hash);
// Canonicalize so path matches what notify returns (absolute) in events.
let config_path = config_path
.canonicalize()
.unwrap_or_else(|_| config_path.to_path_buf());
// Watch the parent directory rather than the file itself, because many
// editors (vim, nano) and systemd write via rename, which would cause
// inotify to lose track of the original inode.
let watch_dir = config_path
.parent()
.unwrap_or_else(|| std::path::Path::new("."))
.to_path_buf();
// ── inotify watcher (instant on local fs) ────────────────────────────
let config_file = config_path.clone();
let tx_inotify = notify_tx.clone();
let inotify_ok = match recommended_watcher(move |res: notify::Result<notify::Event>| {
let manifest_for_inotify = manifest_state.clone();
let mut inotify_watcher = match recommended_watcher(move |res: notify::Result<notify::Event>| {
let Ok(event) = res else { return };
let is_our_file = event.paths.iter().any(|p| p == &config_file);
if !is_our_file { return; }
if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
return;
}
let is_our_file = manifest_for_inotify
.read()
.map(|manifest| manifest.matches_event_paths(&event.paths))
.unwrap_or(false);
if is_our_file {
let _ = tx_inotify.try_send(());
}
}) {
Ok(mut w) => match w.watch(&watch_dir, RecursiveMode::NonRecursive) {
Ok(()) => {
info!("config watcher: inotify active on {:?}", config_path);
Box::leak(Box::new(w));
true
Ok(watcher) => Some(watcher),
Err(e) => {
warn!("config watcher: inotify unavailable: {}", e);
None
}
Err(e) => { warn!("config watcher: inotify watch failed: {}", e); false }
},
Err(e) => { warn!("config watcher: inotify unavailable: {}", e); false }
};
apply_watch_manifest(
inotify_watcher.as_mut(),
Option::<&mut notify::poll::PollWatcher>::None,
&manifest_state,
initial_manifest.clone(),
);
if inotify_watcher.is_some() {
info!("config watcher: inotify active on {:?}", config_path);
}
// ── poll watcher (always active, fixes Docker bind mounts / NFS) ─────
// inotify does not receive events for files mounted from the host into
// a container. PollWatcher compares file contents every 3 s and fires
// on any change regardless of the underlying fs.
let config_file2 = config_path.clone();
let tx_poll = notify_tx;
match notify::poll::PollWatcher::new(
let tx_poll = notify_tx.clone();
let manifest_for_poll = manifest_state.clone();
let mut poll_watcher = match notify::poll::PollWatcher::new(
move |res: notify::Result<notify::Event>| {
let Ok(event) = res else { return };
let is_our_file = event.paths.iter().any(|p| p == &config_file2);
if !is_our_file { return; }
if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
return;
}
let is_our_file = manifest_for_poll
.read()
.map(|manifest| manifest.matches_event_paths(&event.paths))
.unwrap_or(false);
if is_our_file {
let _ = tx_poll.try_send(());
}
},
notify::Config::default()
.with_poll_interval(std::time::Duration::from_secs(3))
.with_poll_interval(Duration::from_secs(3))
.with_compare_contents(true),
) {
Ok(mut w) => match w.watch(&config_path, RecursiveMode::NonRecursive) {
Ok(()) => {
if inotify_ok {
info!("config watcher: poll watcher also active (Docker/NFS safe)");
} else {
info!("config watcher: poll watcher active on {:?} (3s interval)", config_path);
Ok(watcher) => Some(watcher),
Err(e) => {
warn!("config watcher: poll watcher unavailable: {}", e);
None
}
Box::leak(Box::new(w));
}
Err(e) => warn!("config watcher: poll watch failed: {}", e),
},
Err(e) => warn!("config watcher: poll watcher unavailable: {}", e),
};
apply_watch_manifest(
Option::<&mut notify::RecommendedWatcher>::None,
poll_watcher.as_mut(),
&manifest_state,
initial_manifest.clone(),
);
if poll_watcher.is_some() {
info!("config watcher: poll watcher active (Docker/NFS safe)");
}
// ── event loop ───────────────────────────────────────────────────────
tokio::spawn(async move {
#[cfg(unix)]
let mut sighup = {
use tokio::signal::unix::{signal, SignalKind};
@ -1152,11 +1335,25 @@ pub fn spawn_config_watcher(
#[cfg(not(unix))]
if notify_rx.recv().await.is_none() { break; }
// Debounce: drain extra events that arrive within 50 ms.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Debounce: drain extra events that arrive within a short quiet window.
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
while notify_rx.try_recv().is_ok() {}
reload_config(&config_path, &config_tx, &log_tx, detected_ip_v4, detected_ip_v6);
if let Some(next_manifest) = reload_config(
&config_path,
&config_tx,
&log_tx,
detected_ip_v4,
detected_ip_v6,
&mut reload_state,
) {
apply_watch_manifest(
inotify_watcher.as_mut(),
poll_watcher.as_mut(),
&manifest_state,
next_manifest,
);
}
}
});
@ -1171,6 +1368,40 @@ mod tests {
ProxyConfig::default()
}
/// Write a minimal TOML config for the reload tests. `ad_tag` toggles an
/// optional `[general]` section and `server_port` an optional `[server]`
/// section; `None` omits the section entirely.
fn write_reload_config(path: &Path, ad_tag: Option<&str>, server_port: Option<u16>) {
    let mut config = String::from(
        r#"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#,
    );
    // Collapsed the redundant `ad_tag.is_some()` guard plus inner `if let`
    // into a single `if let` — identical output, one less nesting level.
    if let Some(tag) = ad_tag {
        config.push_str("\n[general]\n");
        config.push_str(&format!("ad_tag = \"{tag}\"\n"));
    }
    if let Some(port) = server_port {
        config.push_str("\n[server]\n");
        config.push_str(&format!("port = {port}\n"));
    }
    std::fs::write(path, config).unwrap();
}
/// Unique temp-file path for a test, namespaced by `prefix` plus a
/// nanosecond-resolution nonce so parallel test runs do not collide.
fn temp_config_path(prefix: &str) -> PathBuf {
    let now = std::time::SystemTime::now();
    let nonce = now
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let file_name = format!("{prefix}_{nonce}.toml");
    std::env::temp_dir().join(file_name)
}
#[test]
fn overlay_applies_hot_and_preserves_non_hot() {
let old = sample_config();
@ -1238,4 +1469,61 @@ mod tests {
assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy);
assert!(!config_equal(&applied, &new));
}
// A changed config snapshot must be observed HOT_RELOAD_STABLE_SNAPSHOTS (2)
// times before it is applied — a single observation leaves the old config live.
#[test]
fn reload_requires_stable_snapshot_before_hot_apply() {
    let initial_tag = "11111111111111111111111111111111";
    let final_tag = "22222222222222222222222222222222";
    let path = temp_config_path("telemt_hot_reload_stable");
    write_reload_config(&path, Some(initial_tag), None);
    let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
    // Seed the reload state with the hash of the initially-applied snapshot.
    let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
    let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
    let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
    let mut reload_state = ReloadState::new(Some(initial_hash));
    // First change (ad_tag removed): one observation only → not applied yet.
    write_reload_config(&path, None, None);
    reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
    assert_eq!(
        config_tx.borrow().general.ad_tag.as_deref(),
        Some(initial_tag)
    );
    // Second change (new tag) resets the candidate: still one observation of
    // THIS snapshot → still not applied.
    write_reload_config(&path, Some(final_tag), None);
    reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
    assert_eq!(
        config_tx.borrow().general.ad_tag.as_deref(),
        Some(initial_tag)
    );
    // Second observation of the same snapshot → stable → applied.
    reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
    assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
    let _ = std::fs::remove_file(path);
}
// A snapshot that changes both a hot field (ad_tag) and a non-hot field
// (server.port) must, once stable, apply the hot change while preserving the
// old value of the non-hot field.
#[test]
fn reload_keeps_hot_apply_when_non_hot_fields_change() {
    let initial_tag = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
    let final_tag = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
    let path = temp_config_path("telemt_hot_reload_mixed");
    write_reload_config(&path, Some(initial_tag), None);
    let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
    let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
    let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
    let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
    let mut reload_state = ReloadState::new(Some(initial_hash));
    // Change hot (ad_tag) and non-hot (port) together; reload twice so the
    // snapshot reaches the required stability threshold.
    write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1));
    reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
    reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
    let applied = config_tx.borrow().clone();
    // Hot field applied; non-hot field (port) keeps its original value.
    assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag));
    assert_eq!(applied.server.port, initial_cfg.server.port);
    let _ = std::fs::remove_file(path);
}
}

View File

@ -1,8 +1,9 @@
#![allow(deprecated)]
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::net::{IpAddr, SocketAddr};
use std::path::Path;
use std::path::{Path, PathBuf};
use rand::Rng;
use tracing::warn;
@ -13,6 +14,8 @@ use crate::error::{ProxyError, Result};
use super::defaults::*;
use super::types::*;
// Reject absolute paths and any path component that traverses upward.
// This prevents config include directives from escaping the config directory.
fn is_safe_include_path(path_str: &str) -> bool {
let p = std::path::Path::new(path_str);
if p.is_absolute() {
@ -21,7 +24,37 @@ fn is_safe_include_path(path_str: &str) -> bool {
!p.components().any(|c| matches!(c, std::path::Component::ParentDir))
}
fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result<String> {
/// A parsed config together with provenance metadata used by hot reload.
#[derive(Debug, Clone)]
pub(crate) struct LoadedConfig {
    // The fully parsed configuration.
    pub(crate) config: ProxyConfig,
    // Every file that contributed content: the root config plus each file
    // pulled in via `include`, normalized to absolute paths.
    pub(crate) source_files: Vec<PathBuf>,
    // Hash of the include-expanded ("rendered") TOML text, used to detect
    // whether a reload actually changed anything.
    pub(crate) rendered_hash: u64,
}
/// Best-effort normalization of a config path: canonicalize when the path
/// exists; otherwise return an absolute form (relative paths are joined onto
/// the current directory) so equality checks against watcher events hold.
fn normalize_config_path(path: &Path) -> PathBuf {
    match path.canonicalize() {
        Ok(canonical) => canonical,
        Err(_) if path.is_absolute() => path.to_path_buf(),
        Err(_) => std::env::current_dir()
            .map(|cwd| cwd.join(path))
            .unwrap_or_else(|_| path.to_path_buf()),
    }
}
/// Fingerprint of the rendered (include-expanded) config text, used to decide
/// whether a reload produced a genuinely different snapshot.
fn hash_rendered_snapshot(rendered: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    Hash::hash(rendered, &mut hasher);
    hasher.finish()
}
fn preprocess_includes(
content: &str,
base_dir: &Path,
depth: u8,
source_files: &mut BTreeSet<PathBuf>,
) -> Result<String> {
if depth > 10 {
return Err(ProxyError::Config("Include depth > 10".into()));
}
@ -39,10 +72,16 @@ fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result<Stri
)));
}
let resolved = base_dir.join(path_str);
source_files.insert(normalize_config_path(&resolved));
let included = std::fs::read_to_string(&resolved)
.map_err(|e| ProxyError::Config(e.to_string()))?;
let included_dir = resolved.parent().unwrap_or(base_dir);
output.push_str(&preprocess_includes(&included, included_dir, depth + 1)?);
output.push_str(&preprocess_includes(
&included,
included_dir,
depth + 1,
source_files,
)?);
output.push('\n');
continue;
}
@ -152,13 +191,16 @@ pub struct ProxyConfig {
impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content =
std::fs::read_to_string(&path).map_err(|e| ProxyError::Config(e.to_string()))?;
let base_dir = path
.as_ref()
.parent()
.unwrap_or_else(|| Path::new("."));
let processed = preprocess_includes(&content, base_dir, 0)?;
Self::load_with_metadata(path).map(|loaded| loaded.config)
}
pub(crate) fn load_with_metadata<P: AsRef<Path>>(path: P) -> Result<LoadedConfig> {
let path = path.as_ref();
let content = std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?;
let base_dir = path.parent().unwrap_or(Path::new("."));
let mut source_files = BTreeSet::new();
source_files.insert(normalize_config_path(path));
let processed = preprocess_includes(&content, base_dir, 0, &mut source_files)?;
let parsed_toml: toml::Value =
toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?;
@ -821,7 +863,11 @@ impl ProxyConfig {
.entry("203".to_string())
.or_insert_with(|| vec!["91.105.192.100:443".to_string()]);
Ok(config)
Ok(LoadedConfig {
config,
source_files: source_files.into_iter().collect(),
rendered_hash: hash_rendered_snapshot(&processed),
})
}
pub fn validate(&self) -> Result<()> {
@ -1165,6 +1211,48 @@ mod tests {
);
}
// `load_with_metadata` must record every file that contributed content —
// the root config AND each included file — in `source_files`, normalized.
#[test]
fn load_with_metadata_collects_include_files() {
    // Nanosecond nonce keeps parallel test runs from colliding in temp_dir.
    let nonce = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let dir = std::env::temp_dir().join(format!("telemt_load_metadata_{nonce}"));
    std::fs::create_dir_all(&dir).unwrap();
    let main_path = dir.join("config.toml");
    let include_path = dir.join("included.toml");
    std::fs::write(
        &include_path,
        r#"
[access.users]
user = "00000000000000000000000000000000"
"#,
    )
    .unwrap();
    std::fs::write(
        &main_path,
        r#"
include = "included.toml"
[censorship]
tls_domain = "example.com"
"#,
    )
    .unwrap();
    let loaded = ProxyConfig::load_with_metadata(&main_path).unwrap();
    // Compare against the same normalization the loader applies.
    let main_normalized = normalize_config_path(&main_path);
    let include_normalized = normalize_config_path(&include_path);
    assert!(loaded.source_files.contains(&main_normalized));
    assert!(loaded.source_files.contains(&include_normalized));
    let _ = std::fs::remove_file(main_path);
    let _ = std::fs::remove_file(include_path);
    let _ = std::fs::remove_dir(dir);
}
#[test]
fn dc_overrides_inject_dc203_default() {
let toml = r#"
@ -2197,7 +2285,8 @@ mod tests {
let root = dir.join(format!("{prefix}0.toml"));
let root_content = std::fs::read_to_string(&root).unwrap();
let result = preprocess_includes(&root_content, &dir, 0);
let mut sf = BTreeSet::new();
let result = preprocess_includes(&root_content, &dir, 0, &mut sf);
for i in 0usize..=11 {
let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml")));
@ -2230,7 +2319,8 @@ mod tests {
let root = dir.join(format!("{prefix}0.toml"));
let root_content = std::fs::read_to_string(&root).unwrap();
let result = preprocess_includes(&root_content, &dir, 0);
let mut sf = BTreeSet::new();
let result = preprocess_includes(&root_content, &dir, 0, &mut sf);
for i in 0usize..=10 {
let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml")));

View File

@ -382,20 +382,7 @@ mod tests {
}
// ObfuscationParams must NOT implement Clone — key material must not be duplicated.
// Enforced at compile time: if Clone is ever derived or manually implemented for
// ObfuscationParams, this assertion will fail to compile. This replaces the older
// placeholder #[test], which could not actually express a negative trait bound.
static_assertions::assert_not_impl_any!(ObfuscationParams: Clone);
}

View File

@ -8,7 +8,6 @@
use bytes::BytesMut;
use crossbeam_queue::ArrayQueue;
use std::ops::{Deref, DerefMut};
use std::sync::OnceLock;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
@ -21,6 +20,11 @@ pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
/// Default maximum number of pooled buffers
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
// Buffers that grew beyond this multiple of `buffer_size` are dropped rather
// than returned to the pool, preventing memory amplification from a single
// large-payload connection permanently holding oversized allocations.
const MAX_POOL_BUFFER_OVERSIZE_MULT: usize = 4;
// ============= Buffer Pool =============
/// Thread-safe pool of reusable buffers
@ -97,11 +101,13 @@ impl BufferPool {
fn return_buffer(&self, mut buffer: BytesMut) {
buffer.clear();
// Only return buffers that have at least the canonical capacity so the
// pool does not shrink over time when callers grow their buffers.
// Over-capacity or full-queue buffers are simply dropped; `allocated`
// is a high-water mark and is not decremented here.
if buffer.capacity() >= self.buffer_size {
// Accept buffers within [buffer_size, buffer_size * MAX_POOL_BUFFER_OVERSIZE_MULT].
// The lower bound prevents pool capacity from shrinking over time.
// The upper bound drops buffers that grew excessively (e.g. to serve a large
// payload) so they do not permanently inflate pool memory or get handed to a
// future connection that only needs a small allocation.
let max_acceptable = self.buffer_size.saturating_mul(MAX_POOL_BUFFER_OVERSIZE_MULT);
if buffer.capacity() >= self.buffer_size && buffer.capacity() <= max_acceptable {
let _ = self.buffers.push(buffer);
}
}
@ -216,16 +222,17 @@ impl Deref for PooledBuffer {
type Target = BytesMut;
fn deref(&self) -> &Self::Target {
static EMPTY_BUFFER: OnceLock<BytesMut> = OnceLock::new();
self.buffer
.as_ref()
.unwrap_or_else(|| EMPTY_BUFFER.get_or_init(BytesMut::new))
.expect("PooledBuffer: attempted to deref after buffer was taken")
}
}
impl DerefMut for PooledBuffer {
fn deref_mut(&mut self) -> &mut Self::Target {
self.buffer.get_or_insert_with(BytesMut::new)
self.buffer
.as_mut()
.expect("PooledBuffer: attempted to deref_mut after buffer was taken")
}
}
@ -503,22 +510,27 @@ mod tests {
);
}
// A buffer that grew moderately (within MAX_POOL_BUFFER_OVERSIZE_MULT of the
// canonical size) must be returned to the pool: the allocation is still
// reasonable and reusing it is cheaper than allocating a fresh one.
#[test]
fn oversized_buffer_is_returned_to_pool() {
    let canonical = 64usize;
    let pool = Arc::new(BufferPool::with_config(canonical, 10));
    let mut buf = pool.get();
    // reserve(canonical) with len == 0 only guarantees capacity >= canonical,
    // so the buffer may or may not actually grow — either way its capacity
    // stays within the 4× acceptance band, which is what this test pins down.
    buf.reserve(canonical);
    assert!(buf.capacity() >= canonical);
    assert!(
        buf.capacity() <= canonical * MAX_POOL_BUFFER_OVERSIZE_MULT,
        "pre-condition: test growth must stay within the acceptable bound"
    );
    drop(buf);
    // The buffer must have been returned because capacity is within the band.
    let stats = pool.stats();
    assert_eq!(stats.pooled, 1, "moderately-oversized buffer must be returned to pool");
}
// A buffer whose capacity fell below buffer_size (e.g. due to take() on
@ -625,4 +637,97 @@ mod tests {
);
assert_eq!(stats.hits, 50);
}
// ── Security invariant: sensitive data must not leak between pool users ───
// After a buffer holding "sensitive" bytes is returned, the very next get()
// must observe an empty buffer: an attacker able to trigger repeated pool
// cycles must never read a prior connection's data.
#[test]
fn pooled_buffer_sensitive_data_is_cleared_before_reuse() {
    let pool = Arc::new(BufferPool::with_config(64, 2));
    let mut first = pool.get();
    first.extend_from_slice(b"credentials:password123");
    // Returning the buffer to the pool must clear its contents.
    drop(first);
    let second = pool.get();
    // No leftover bytes from the previous user may be visible.
    assert!(second.is_empty(), "pool must clear buffer before handing it to the next caller");
    assert_eq!(second.len(), 0);
}
// take() must hand the caller the full contents, and the extracted BytesMut
// must never flow back into the pool (no double-return).
#[test]
fn pooled_buffer_take_eliminates_pool_return() {
    let pool = Arc::new(BufferPool::with_config(64, 2));
    let pooled_before = pool.stats().pooled;
    let mut handle = pool.get(); // miss
    handle.extend_from_slice(b"important");
    // Consuming the PooledBuffer via take() bypasses the Drop-based return path.
    let extracted = handle.take();
    assert_eq!(&extracted[..], b"important");
    // The pooled count must be unchanged — nothing was re-queued.
    assert_eq!(
        pool.stats().pooled, pooled_before,
        "take() must not return the buffer to the pool"
    );
}
// Two simultaneously live buffers must be backed by distinct memory; aliased
// buffers would let one connection read or corrupt another's in-flight data.
#[test]
fn pooled_buffers_are_independent_no_aliasing() {
    let pool = Arc::new(BufferPool::with_config(64, 4));
    let mut first = pool.get();
    let mut second = pool.get();
    first.extend_from_slice(b"connection-A");
    second.extend_from_slice(b"connection-B");
    // Each buffer holds exactly what was written to it …
    assert_eq!(&first[..], b"connection-A");
    assert_eq!(&second[..], b"connection-B");
    // … and writing to one never bleeds into the other.
    assert_ne!(&first[..], &second[..]);
}
// A buffer grown far beyond the pool's canonical size must be abandoned on
// drop rather than re-queued: pooling huge allocations would hand large heap
// chunks to unrelated connections and leak memory across connection
// boundaries.
#[test]
fn oversized_buffer_is_dropped_not_pooled() {
    let canonical = 64usize;
    let pool = Arc::new(BufferPool::with_config(canonical, 4));
    {
        let mut buf = pool.get();
        // Inflate to 8× the canonical size — well past the oversize bound.
        buf.resize(canonical * 8, 0u8);
        // Dropping here must discard the oversized allocation entirely.
    }
    let stats = pool.stats();
    assert_eq!(
        stats.pooled, 0,
        "oversized buffer must be dropped, not returned to pool (got {} pooled)",
        stats.pooled
    );
}
// A live (not-yet-taken) PooledBuffer must support Deref without panicking.
#[test]
fn pooled_buffer_deref_on_live_buffer_does_not_panic() {
    let pool = Arc::new(BufferPool::new());
    let mut live = pool.get();
    live.extend_from_slice(b"hello");
    // Deref-based slicing must succeed and expose the written bytes.
    let contents: &[u8] = &live[..];
    assert_eq!(contents, b"hello");
}
}

View File

@ -97,9 +97,19 @@ pub(crate) fn build_proxy_req_payload(
}
}
let extra_bytes = u32::try_from(b.len() - extra_start - 4).unwrap_or(0);
// On overflow (buffer > 4 GiB, unreachable in practice) keep the extra block
// empty by truncating the data back to the length-field position and writing 0,
// so the wire representation remains consistent (length field matches content).
match u32::try_from(b.len() - extra_start - 4) {
Ok(extra_bytes) => {
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
}
Err(_) => {
b.truncate(extra_start + 4);
b[extra_start..extra_start + 4].copy_from_slice(&0u32.to_le_bytes());
}
}
}
b.extend_from_slice(data);
Bytes::from(b)
@ -242,6 +252,7 @@ mod tests {
// Tags exceeding the 3-byte TL length limit (> 0xFFFFFF = 16,777,215 bytes)
// must produce an empty extra block rather than a truncated/corrupt length.
#[ignore = "allocates ~16 MiB; run only in a dedicated large-tests profile/CI job"]
#[test]
fn test_build_proxy_req_payload_tag_exceeds_tl_long_form_limit_produces_empty_extra() {
let client = SocketAddr::from(([198, 51, 100, 22], 33333));
@ -267,6 +278,7 @@ mod tests {
// The prior guard was `tag_len_u32.is_some()` which passed for tags up to u32::MAX.
// Verify boundary at exactly 0xFFFFFF: must succeed and encode at 3 bytes.
#[ignore = "allocates ~16 MiB; run only in a dedicated large-tests profile/CI job"]
#[test]
fn test_build_proxy_req_payload_tag_at_tl_long_form_max_boundary_encodes() {
// 0xFFFFFF = 16,777,215 bytes — maximum representable by 3-byte TL length.
@ -303,4 +315,101 @@ mod tests {
let fixed = 4 + 4 + 8 + 20 + 20;
let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
assert_eq!(extra_len, 0);
}}
}
// ── Protocol wire-consistency invariant tests ─────────────────────────────
// The extra-block length field must ALWAYS equal the number of bytes that
// actually follow it in the buffer. A censor or MitM parsing the wire format
// must never be able to desync the stream via adversarial tag/data input that
// makes the length claim 0 while data bytes are present.
#[test]
fn extra_block_length_field_always_matches_actual_content_length() {
    let client = SocketAddr::from(([198, 51, 100, 40], 10000));
    let ours = SocketAddr::from(([203, 0, 113, 40], 443));
    let fixed = 4 + 4 + 8 + 20 + 20;
    // (tag, data) pairs covering every TL length-encoding regime.
    let cases: [(Option<&[u8]>, &[u8]); 6] = [
        (None, b"data"),
        (Some(b""), b"data"),         // zero-length tag
        (Some(&[0xABu8; 1]), b""),    // 1-byte tag
        (Some(&[0xABu8; 253]), b"x"), // short-form max
        (Some(&[0xCCu8; 254]), b"y"), // long-form min
        (Some(&[0xDDu8; 500]), b"z"), // mid-range long-form
    ];
    for (tag, data) in cases {
        let payload = build_proxy_req_payload(1, client, ours, data, tag, RPC_FLAG_HAS_AD_TAG);
        let raw = payload.as_ref();
        let declared = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
        let actual = raw.len() - fixed - 4 - data.len();
        assert_eq!(
            declared, actual,
            "extra-block length field ({declared}) must equal actual byte count ({actual})"
        );
    }
}
// A zero-length tag is still a tag: it must be encoded (short-form length
// byte 0 plus 4-byte alignment padding), not silently dropped as if tag were
// None.
#[test]
fn test_build_proxy_req_payload_zero_length_tag_encodes_correctly() {
    let client = SocketAddr::from(([198, 51, 100, 50], 20000));
    let ours = SocketAddr::from(([203, 0, 113, 50], 443));
    let payload =
        build_proxy_req_payload(12, client, ours, b"payload", Some(&[]), RPC_FLAG_HAS_AD_TAG);
    let raw = payload.as_ref();
    let fixed = 4 + 4 + 8 + 20 + 20;
    let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
    // Layout: TL object id (4) + short-form length byte 0 (1) + no data bytes,
    // padded to the next 4-byte boundary: (1 + 0) % 4 = 1 → pad = 3.
    assert_eq!(extra_len, 4 + 1 + 3, "zero-length tag must produce TL header + padding");
    // The short-form length byte itself must be 0.
    assert_eq!(raw[fixed + 8], 0u8);
}
}
// Without the HAS_AD_TAG flag the extra block must be omitted entirely —
// even when a tag value is supplied, none of it may appear on the wire.
#[test]
fn test_build_proxy_req_payload_without_has_ad_tag_flag_has_no_extra_block() {
    let client = SocketAddr::from(([198, 51, 100, 60], 30000));
    let ours = SocketAddr::from(([203, 0, 113, 60], 443));
    // Flags deliberately 0: HAS_AD_TAG absent despite a Some(..) tag.
    let payload = build_proxy_req_payload(13, client, ours, b"data", Some(&[1, 2, 3]), 0);
    let raw = payload.as_ref();
    let fixed = 4 + 4 + 8 + 20 + 20;
    // Payload bytes must follow the fixed headers immediately — no extra block.
    assert_eq!(raw.len(), fixed + 4, "no extra block when HAS_AD_TAG is absent");
    assert_eq!(&raw[fixed..], b"data");
}
// Tag data consisting of 0xFE bytes: 0xFE is also the TL long-form length
// marker, so the encoder must never confuse tag *content* with the length
// prefix.
#[test]
fn test_build_proxy_req_payload_tag_containing_0xfe_byte_encodes_correctly() {
    let client = SocketAddr::from(([198, 51, 100, 70], 40000));
    let ours = SocketAddr::from(([203, 0, 113, 70], 443));
    // 10 bytes of 0xFE: at this length the short-form encoding applies, and
    // the data bytes must pass through untouched.
    let tag = [0xFEu8; 10];
    let payload =
        build_proxy_req_payload(14, client, ours, b"", Some(&tag), RPC_FLAG_HAS_AD_TAG);
    let raw = payload.as_ref();
    let fixed = 4 + 4 + 8 + 20 + 20;
    // The short-form length byte is 10 — not mistaken for the 0xFE marker.
    assert_eq!(raw[fixed + 8], 10u8, "short-form length byte must be 10, not the 0xFE marker");
    // The tag bytes follow the length byte verbatim.
    assert_eq!(&raw[fixed + 9..fixed + 19], &[0xFEu8; 10]);
}
}

View File

@ -260,24 +260,45 @@ pub struct PoolStats {
#[cfg(unix)]
#[allow(unsafe_code)]
fn is_connection_alive(stream: &TcpStream) -> bool {
use std::io::ErrorKind;
use std::os::unix::io::AsRawFd;
let fd = stream.as_raw_fd();
let mut buf = [0u8; 1];
// EINTR can fire on any syscall when a signal is delivered to the thread.
// Treating it as a dead connection causes spurious reconnects; retry instead.
const MAX_RECV_RETRIES: usize = 3;
for _ in 0..MAX_RECV_RETRIES {
// SAFETY: `stream` owns this fd for the full duration of the call.
// MSG_PEEK + MSG_DONTWAIT: inspect the receive buffer without consuming bytes.
let n = unsafe {
libc::recv(
stream.as_raw_fd(),
fd,
buf.as_mut_ptr().cast::<libc::c_void>(),
1,
libc::MSG_PEEK | libc::MSG_DONTWAIT,
)
};
match n {
0 => false,
n if n > 0 => true,
_ => std::io::Error::last_os_error().kind() == std::io::ErrorKind::WouldBlock,
if n > 0 {
return true;
} else if n == 0 {
return false;
} else {
match std::io::Error::last_os_error().kind() {
ErrorKind::Interrupted => continue,
ErrorKind::WouldBlock => return true,
_ => return false,
}
}
}
// After MAX_RECV_RETRIES consecutive EINTR, assume the connection is alive
// rather than triggering a false reconnect.
true
}
#[cfg(not(unix))]
fn is_connection_alive(stream: &TcpStream) -> bool {
@ -518,5 +539,63 @@ mod tests {
assert!(is_connection_alive(&stream));
}
}
// ── EINTR retry-logic unit tests ──────────────────────────────────────────
// The non-unix fallback path must correctly classify WouldBlock as alive
// and any other error as dead, mirroring the unix semantics.
#[cfg(not(unix))]
#[tokio::test]
async fn test_is_connection_alive_non_unix_open_connection() {
// Bind an ephemeral-port listener. Sandboxed CI environments may forbid
// socket binding; skip the test (early return) rather than fail there.
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = listener.local_addr().unwrap();
// Server task: accept the connection, then keep it open for 500 ms so the
// probe below runs against an open, idle peer.
tokio::spawn(async move {
let _ = listener.accept().await;
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
});
// Same PermissionDenied skip on the connect side.
let stream = match TcpStream::connect(addr).await {
Ok(s) => s,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
// An open connection with no pending data must be classified as alive.
assert!(is_connection_alive(&stream), "open idle connection must be alive (non-unix)");
}
// Verify the unix path: calling is_connection_alive many times on an active
// connection never spuriously returns false (guards against the pre-fix bug
// where EINTR would flip the result on a single call).
#[cfg(unix)]
#[tokio::test]
async fn test_is_connection_alive_never_returns_false_spuriously_on_open_connection() {
use tokio::io::AsyncWriteExt;
// Bind an ephemeral-port listener; skip (early return) in sandboxes that
// forbid socket binding rather than failing the suite.
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = listener.local_addr().unwrap();
// Server task: accept, write one byte (so the client's receive buffer has
// pending data for MSG_PEEK to see), then hold the connection open 500 ms.
let server = tokio::spawn(async move {
let (mut s, _) = listener.accept().await.expect("accept");
s.write_all(&[0xBEu8]).await.expect("write");
tokio::time::sleep(Duration::from_millis(500)).await;
});
let stream = match TcpStream::connect(addr).await {
Ok(s) => s,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
// Give the server's single byte time to arrive in the receive buffer.
tokio::time::sleep(Duration::from_millis(30)).await;
// Call is_connection_alive 20 times; all must return true.
for i in 0..20 {
assert!(
is_connection_alive(&stream),
"is_connection_alive must be true on call {i} (connection is open with buffered data)"
);
}
drop(server);
}
}