mirror of https://github.com/telemt/telemt.git
Resolve merge conflicts with upstream/flow and apply Copilot review fixes
Conflict resolution: - src/config/load.rs: Merge HEAD's is_safe_include_path (path traversal guard) with upstream/flow's LoadedConfig, normalize_config_path, hash_rendered_snapshot, and the new 4-parameter preprocess_includes signature. Update two test call sites that still used the old 3-arg signature. - src/config/hot_reload.rs: Take upstream/flow's tokio::spawn-contained watcher setup (inotify + poll via manifest_state). HEAD's pre-spawn block was broken: it referenced notify_tx before the channel was created. Copilot review fixes (already applied in working tree, now committed): - src/transport/pool.rs: Handle EINTR in is_connection_alive with retry loop (treating it as alive) instead of a false dead-connection verdict. - src/transport/middle_proxy/wire.rs: On u32 overflow in extra-block length encoding, truncate buffer back to the length-field position and write 0 so wire representation stays self-consistent. Annotate 16 MiB boundary tests with #[ignore] to avoid OOM on low-memory CI runners. - src/stream/buffer_pool.rs: Restore fail-fast expect() in Deref/DerefMut (was silently returning empty buffer after take(), masking use-after-take bugs). Add MAX_POOL_BUFFER_OVERSIZE_MULT upper bound in return_buffer to prevent memory amplification from excessively-grown buffers staying in the pool. Fix contradictory test: oversized_buffer_is_returned_to_pool now grows within the 4x bound; oversized_buffer_is_dropped_not_pooled (8x growth) now passes. - src/api/mod.rs: Fix constant_time_eq to iterate over b.len() (expected token length) rather than min(a.len(), b.len()), closing the timing oracle where an attacker could influence iteration count by sending a shorter candidate (OWASP ASVS V6.6.1). Revert ApiRuntimeState and ApiShared to pub(super). - src/protocol/obfuscation.rs: Replace no-op test_obfuscation_params_is_not_clone with static_assertions::assert_not_impl_any!(ObfuscationParams: Clone) which is an actual compile-time enforcement.
This commit is contained in:
commit
5f3a2e7055
129
src/api/mod.rs
129
src/api/mod.rs
|
|
@ -62,7 +62,7 @@ use runtime_zero::{
|
|||
use runtime_watch::spawn_runtime_watchers;
|
||||
use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config};
|
||||
|
||||
pub struct ApiRuntimeState {
|
||||
pub(super) struct ApiRuntimeState {
|
||||
pub(super) process_started_at_epoch_secs: u64,
|
||||
pub(super) config_reload_count: AtomicU64,
|
||||
pub(super) last_config_reload_epoch_secs: AtomicU64,
|
||||
|
|
@ -70,7 +70,7 @@ pub struct ApiRuntimeState {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiShared {
|
||||
pub(super) struct ApiShared {
|
||||
pub(super) stats: Arc<Stats>,
|
||||
pub(super) ip_tracker: Arc<UserIpTracker>,
|
||||
pub(super) me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
|
||||
|
|
@ -565,18 +565,19 @@ async fn handle(
|
|||
}
|
||||
}
|
||||
|
||||
// XOR-fold constant-time comparison: avoids early return on length mismatch to
|
||||
// prevent a timing oracle that would reveal the expected token length to a remote
|
||||
// attacker (OWASP ASVS V6.6.1). Bitwise `&` on bool is eager — it never
|
||||
// short-circuits — so both the length check and the byte fold always execute.
|
||||
// XOR-fold constant-time comparison. Running time depends only on the length of the
|
||||
// expected token (b), not on min(a.len(), b.len()), to prevent a timing oracle where
|
||||
// an attacker reduces the iteration count by sending a shorter candidate
|
||||
// (OWASP ASVS V6.6.1). Bitwise `&` on bool is eager — it never short-circuits —
|
||||
// so both the length check and the byte fold always execute.
|
||||
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
|
||||
let length_ok = a.len() == b.len();
|
||||
let min_len = a.len().min(b.len());
|
||||
let byte_mismatch = a[..min_len]
|
||||
.iter()
|
||||
.zip(b[..min_len].iter())
|
||||
.fold(0u8, |acc, (x, y)| acc | (x ^ y));
|
||||
length_ok & (byte_mismatch == 0)
|
||||
let mut diff = 0u8;
|
||||
for i in 0..b.len() {
|
||||
let x = a.get(i).copied().unwrap_or(0);
|
||||
let y = b[i];
|
||||
diff |= x ^ y;
|
||||
}
|
||||
(a.len() == b.len()) & (diff == 0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -765,4 +766,106 @@ mod tests {
|
|||
c[0] = 0xde;
|
||||
assert!(!constant_time_eq(&a, &c));
|
||||
}
|
||||
|
||||
// ── Timing-oracle adversarial tests (length-iteration invariant) ──────────
|
||||
|
||||
// An active censor/attacker who knows the token format may send truncated
|
||||
// inputs to narrow down the token length via a timing side-channel.
|
||||
// After the fix, `constant_time_eq` always iterates over `b.len()` (the
|
||||
// expected token length), so submission of every strict prefix of the
|
||||
// expected token must be rejected while iteration count stays constant.
|
||||
#[test]
|
||||
fn constant_time_eq_every_prefix_of_expected_token_is_rejected() {
|
||||
let expected = b"Bearer super-secret-api-token-abc123xyz";
|
||||
for prefix_len in 0..expected.len() {
|
||||
let attacker_input = &expected[..prefix_len];
|
||||
assert!(
|
||||
!constant_time_eq(attacker_input, expected),
|
||||
"prefix of length {prefix_len} must not authenticate against full token"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Reversed: input longer than expected token — the extra bytes must cause
|
||||
// rejection even when the first b.len() bytes are correct.
|
||||
#[test]
|
||||
fn constant_time_eq_input_with_correct_prefix_plus_extra_bytes_is_rejected() {
|
||||
let expected = b"secret-token";
|
||||
for extra in 1usize..=32 {
|
||||
let mut longer = expected.to_vec();
|
||||
// Extend with zeros — the XOR of matching first bytes is 0, so only
|
||||
// the length check prevents a false positive.
|
||||
longer.extend(std::iter::repeat(0u8).take(extra));
|
||||
assert!(
|
||||
!constant_time_eq(&longer, expected),
|
||||
"input extended by {extra} zero bytes must not authenticate"
|
||||
);
|
||||
// Extend with matching-value bytes — ensures the byte_fold stays at 0
|
||||
// for the expected-length portion; only length differs.
|
||||
let mut same_byte_extension = expected.to_vec();
|
||||
same_byte_extension.extend(std::iter::repeat(expected[0]).take(extra));
|
||||
assert!(
|
||||
!constant_time_eq(&same_byte_extension, expected),
|
||||
"input extended by {extra} repeated bytes must not authenticate"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Null-byte injection: ensure the function does not mis-parse embedded
|
||||
// NUL characters as C-string terminators and accept a shorter match.
|
||||
#[test]
|
||||
fn constant_time_eq_null_byte_injection_is_rejected() {
|
||||
// Token containing a null byte — must only match itself exactly.
|
||||
let expected: &[u8] = b"token\x00suffix";
|
||||
assert!(constant_time_eq(expected, expected));
|
||||
assert!(!constant_time_eq(b"token", expected));
|
||||
assert!(!constant_time_eq(b"token\x00", expected));
|
||||
assert!(!constant_time_eq(b"token\x00suffi", expected));
|
||||
|
||||
// Null-prefixed input of the same length must not match a non-null token.
|
||||
let real_token: &[u8] = b"real-secret-value";
|
||||
let mut null_injected = vec![0u8; real_token.len()];
|
||||
null_injected[0] = real_token[0];
|
||||
assert!(!constant_time_eq(&null_injected, real_token));
|
||||
}
|
||||
|
||||
// High-byte (0xFF) values throughout: XOR of 0xFF ^ 0xFF = 0, so equal
|
||||
// high-byte slices must match, and any single-byte difference must not.
|
||||
#[test]
|
||||
fn constant_time_eq_high_byte_edge_cases() {
|
||||
let token = vec![0xffu8; 20];
|
||||
assert!(constant_time_eq(&token, &token));
|
||||
let mut tampered = token.clone();
|
||||
tampered[10] = 0xfe;
|
||||
assert!(!constant_time_eq(&tampered, &token));
|
||||
// Shorter all-ff slice must not match.
|
||||
assert!(!constant_time_eq(&token[..19], &token));
|
||||
}
|
||||
|
||||
// Accumulator-saturation attack: if all bytes of `a` have been XOR-folded to
|
||||
// 0xFF (i.e. acc is saturated), but the remaining bytes of `b` are 0x00, the
|
||||
// fold of 0x00 into 0xFF must keep acc ≠ 0 (since 0xFF | 0 = 0xFF).
|
||||
// This guards against a misimplemented fold that resets acc on certain values.
|
||||
#[test]
|
||||
fn constant_time_eq_accumulator_never_resets_to_zero_after_mismatch() {
|
||||
// a[0] = 0xAA, b[0] = 0x55 → XOR = 0xFF.
|
||||
// Subsequent bytes all match (XOR = 0x00). Accumulator must remain 0xFF.
|
||||
let mut a = vec![0x55u8; 16];
|
||||
let mut b = vec![0x55u8; 16];
|
||||
a[0] = 0xAA; // deliberate mismatch at position 0
|
||||
assert!(!constant_time_eq(&a, &b));
|
||||
// Verify with mismatch only at last position to test late detection.
|
||||
a[0] = b[0];
|
||||
a[15] = 0xAA;
|
||||
b[15] = 0x55;
|
||||
assert!(!constant_time_eq(&a, &b));
|
||||
}
|
||||
|
||||
// Zero-length expected token: only the empty input must match.
|
||||
#[test]
|
||||
fn constant_time_eq_zero_length_expected_only_matches_empty_input() {
|
||||
assert!(constant_time_eq(b"", b""));
|
||||
assert!(!constant_time_eq(b"\x00", b""));
|
||||
assert!(!constant_time_eq(b"x", b""));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,9 +21,11 @@
|
|||
//! `network.*`, `use_middle_proxy`) are **not** applied; a warning is emitted.
|
||||
//! Non-hot changes are never mixed into the runtime config snapshot.
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::IpAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, RwLock as StdRwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
|
@ -33,7 +35,10 @@ use crate::config::{
|
|||
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
|
||||
MeWriterPickMode,
|
||||
};
|
||||
use super::load::ProxyConfig;
|
||||
use super::load::{LoadedConfig, ProxyConfig};
|
||||
|
||||
const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2;
|
||||
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
|
||||
|
||||
// ── Hot fields ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -287,6 +292,149 @@ fn listeners_equal(
|
|||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
struct WatchManifest {
|
||||
files: BTreeSet<PathBuf>,
|
||||
dirs: BTreeSet<PathBuf>,
|
||||
}
|
||||
|
||||
impl WatchManifest {
|
||||
fn from_source_files(source_files: &[PathBuf]) -> Self {
|
||||
let mut files = BTreeSet::new();
|
||||
let mut dirs = BTreeSet::new();
|
||||
|
||||
for path in source_files {
|
||||
let normalized = normalize_watch_path(path);
|
||||
files.insert(normalized.clone());
|
||||
if let Some(parent) = normalized.parent() {
|
||||
dirs.insert(parent.to_path_buf());
|
||||
}
|
||||
}
|
||||
|
||||
Self { files, dirs }
|
||||
}
|
||||
|
||||
fn matches_event_paths(&self, event_paths: &[PathBuf]) -> bool {
|
||||
event_paths
|
||||
.iter()
|
||||
.map(|path| normalize_watch_path(path))
|
||||
.any(|path| self.files.contains(&path))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct ReloadState {
|
||||
applied_snapshot_hash: Option<u64>,
|
||||
candidate_snapshot_hash: Option<u64>,
|
||||
candidate_hits: u8,
|
||||
}
|
||||
|
||||
impl ReloadState {
|
||||
fn new(applied_snapshot_hash: Option<u64>) -> Self {
|
||||
Self {
|
||||
applied_snapshot_hash,
|
||||
candidate_snapshot_hash: None,
|
||||
candidate_hits: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_applied(&self, hash: u64) -> bool {
|
||||
self.applied_snapshot_hash == Some(hash)
|
||||
}
|
||||
|
||||
fn observe_candidate(&mut self, hash: u64) -> u8 {
|
||||
if self.candidate_snapshot_hash == Some(hash) {
|
||||
self.candidate_hits = self.candidate_hits.saturating_add(1);
|
||||
} else {
|
||||
self.candidate_snapshot_hash = Some(hash);
|
||||
self.candidate_hits = 1;
|
||||
}
|
||||
self.candidate_hits
|
||||
}
|
||||
|
||||
fn reset_candidate(&mut self) {
|
||||
self.candidate_snapshot_hash = None;
|
||||
self.candidate_hits = 0;
|
||||
}
|
||||
|
||||
fn mark_applied(&mut self, hash: u64) {
|
||||
self.applied_snapshot_hash = Some(hash);
|
||||
self.reset_candidate();
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_watch_path(path: &Path) -> PathBuf {
|
||||
path.canonicalize().unwrap_or_else(|_| {
|
||||
if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
std::env::current_dir()
|
||||
.map(|cwd| cwd.join(path))
|
||||
.unwrap_or_else(|_| path.to_path_buf())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn sync_watch_paths<W: Watcher>(
|
||||
watcher: &mut W,
|
||||
current: &BTreeSet<PathBuf>,
|
||||
next: &BTreeSet<PathBuf>,
|
||||
recursive_mode: RecursiveMode,
|
||||
kind: &str,
|
||||
) {
|
||||
for path in current.difference(next) {
|
||||
if let Err(e) = watcher.unwatch(path) {
|
||||
warn!(path = %path.display(), error = %e, "config watcher: failed to unwatch {kind}");
|
||||
}
|
||||
}
|
||||
|
||||
for path in next.difference(current) {
|
||||
if let Err(e) = watcher.watch(path, recursive_mode) {
|
||||
warn!(path = %path.display(), error = %e, "config watcher: failed to watch {kind}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_watch_manifest<W1: Watcher, W2: Watcher>(
|
||||
notify_watcher: Option<&mut W1>,
|
||||
poll_watcher: Option<&mut W2>,
|
||||
manifest_state: &Arc<StdRwLock<WatchManifest>>,
|
||||
next_manifest: WatchManifest,
|
||||
) {
|
||||
let current_manifest = manifest_state
|
||||
.read()
|
||||
.map(|manifest| manifest.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if current_manifest == next_manifest {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(watcher) = notify_watcher {
|
||||
sync_watch_paths(
|
||||
watcher,
|
||||
¤t_manifest.dirs,
|
||||
&next_manifest.dirs,
|
||||
RecursiveMode::NonRecursive,
|
||||
"config directory",
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(watcher) = poll_watcher {
|
||||
sync_watch_paths(
|
||||
watcher,
|
||||
¤t_manifest.files,
|
||||
&next_manifest.files,
|
||||
RecursiveMode::NonRecursive,
|
||||
"config file",
|
||||
);
|
||||
}
|
||||
|
||||
if let Ok(mut manifest) = manifest_state.write() {
|
||||
*manifest = next_manifest;
|
||||
}
|
||||
}
|
||||
|
||||
fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
|
||||
let mut cfg = old.clone();
|
||||
|
||||
|
|
@ -980,18 +1128,42 @@ fn reload_config(
|
|||
log_tx: &watch::Sender<LogLevel>,
|
||||
detected_ip_v4: Option<IpAddr>,
|
||||
detected_ip_v6: Option<IpAddr>,
|
||||
) {
|
||||
let new_cfg = match ProxyConfig::load(config_path) {
|
||||
Ok(c) => c,
|
||||
reload_state: &mut ReloadState,
|
||||
) -> Option<WatchManifest> {
|
||||
let loaded = match ProxyConfig::load_with_metadata(config_path) {
|
||||
Ok(loaded) => loaded,
|
||||
Err(e) => {
|
||||
reload_state.reset_candidate();
|
||||
error!("config reload: failed to parse {:?}: {}", config_path, e);
|
||||
return;
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let LoadedConfig {
|
||||
config: new_cfg,
|
||||
source_files,
|
||||
rendered_hash,
|
||||
} = loaded;
|
||||
let next_manifest = WatchManifest::from_source_files(&source_files);
|
||||
|
||||
if let Err(e) = new_cfg.validate() {
|
||||
reload_state.reset_candidate();
|
||||
error!("config reload: validation failed: {}; keeping old config", e);
|
||||
return;
|
||||
return Some(next_manifest);
|
||||
}
|
||||
|
||||
if reload_state.is_applied(rendered_hash) {
|
||||
return Some(next_manifest);
|
||||
}
|
||||
|
||||
let candidate_hits = reload_state.observe_candidate(rendered_hash);
|
||||
if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
|
||||
info!(
|
||||
snapshot_hash = rendered_hash,
|
||||
candidate_hits,
|
||||
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
|
||||
"config reload: candidate snapshot observed but not stable yet"
|
||||
);
|
||||
return Some(next_manifest);
|
||||
}
|
||||
|
||||
let old_cfg = config_tx.borrow().clone();
|
||||
|
|
@ -1006,17 +1178,19 @@ fn reload_config(
|
|||
}
|
||||
|
||||
if !hot_changed {
|
||||
return;
|
||||
reload_state.mark_applied(rendered_hash);
|
||||
return Some(next_manifest);
|
||||
}
|
||||
|
||||
if old_hot.dns_overrides != applied_hot.dns_overrides
|
||||
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
|
||||
{
|
||||
reload_state.reset_candidate();
|
||||
error!(
|
||||
"config reload: invalid network.dns_overrides: {}; keeping old config",
|
||||
e
|
||||
);
|
||||
return;
|
||||
return Some(next_manifest);
|
||||
}
|
||||
|
||||
log_changes(
|
||||
|
|
@ -1028,6 +1202,8 @@ fn reload_config(
|
|||
detected_ip_v6,
|
||||
);
|
||||
config_tx.send(Arc::new(applied_cfg)).ok();
|
||||
reload_state.mark_applied(rendered_hash);
|
||||
Some(next_manifest)
|
||||
}
|
||||
|
||||
// ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
|
@ -1050,79 +1226,86 @@ pub fn spawn_config_watcher(
|
|||
let (config_tx, config_rx) = watch::channel(initial);
|
||||
let (log_tx, log_rx) = watch::channel(initial_level);
|
||||
|
||||
// Bridge: sync notify callbacks → async task via mpsc.
|
||||
let (notify_tx, mut notify_rx) = mpsc::channel::<()>(4);
|
||||
let config_path = normalize_watch_path(&config_path);
|
||||
let initial_loaded = ProxyConfig::load_with_metadata(&config_path).ok();
|
||||
let initial_manifest = initial_loaded
|
||||
.as_ref()
|
||||
.map(|loaded| WatchManifest::from_source_files(&loaded.source_files))
|
||||
.unwrap_or_else(|| WatchManifest::from_source_files(std::slice::from_ref(&config_path)));
|
||||
let initial_snapshot_hash = initial_loaded.as_ref().map(|loaded| loaded.rendered_hash);
|
||||
|
||||
// Canonicalize so path matches what notify returns (absolute) in events.
|
||||
let config_path = config_path
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| config_path.to_path_buf());
|
||||
|
||||
// Watch the parent directory rather than the file itself, because many
|
||||
// editors (vim, nano) and systemd write via rename, which would cause
|
||||
// inotify to lose track of the original inode.
|
||||
let watch_dir = config_path
|
||||
.parent()
|
||||
.unwrap_or_else(|| std::path::Path::new("."))
|
||||
.to_path_buf();
|
||||
|
||||
// ── inotify watcher (instant on local fs) ────────────────────────────
|
||||
let config_file = config_path.clone();
|
||||
let tx_inotify = notify_tx.clone();
|
||||
let inotify_ok = match recommended_watcher(move |res: notify::Result<notify::Event>| {
|
||||
let Ok(event) = res else { return };
|
||||
let is_our_file = event.paths.iter().any(|p| p == &config_file);
|
||||
if !is_our_file { return; }
|
||||
if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
|
||||
let _ = tx_inotify.try_send(());
|
||||
}
|
||||
}) {
|
||||
Ok(mut w) => match w.watch(&watch_dir, RecursiveMode::NonRecursive) {
|
||||
Ok(()) => {
|
||||
info!("config watcher: inotify active on {:?}", config_path);
|
||||
Box::leak(Box::new(w));
|
||||
true
|
||||
}
|
||||
Err(e) => { warn!("config watcher: inotify watch failed: {}", e); false }
|
||||
},
|
||||
Err(e) => { warn!("config watcher: inotify unavailable: {}", e); false }
|
||||
};
|
||||
|
||||
// ── poll watcher (always active, fixes Docker bind mounts / NFS) ─────
|
||||
// inotify does not receive events for files mounted from the host into
|
||||
// a container. PollWatcher compares file contents every 3 s and fires
|
||||
// on any change regardless of the underlying fs.
|
||||
let config_file2 = config_path.clone();
|
||||
let tx_poll = notify_tx;
|
||||
match notify::poll::PollWatcher::new(
|
||||
move |res: notify::Result<notify::Event>| {
|
||||
let Ok(event) = res else { return };
|
||||
let is_our_file = event.paths.iter().any(|p| p == &config_file2);
|
||||
if !is_our_file { return; }
|
||||
if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
|
||||
let _ = tx_poll.try_send(());
|
||||
}
|
||||
},
|
||||
notify::Config::default()
|
||||
.with_poll_interval(std::time::Duration::from_secs(3))
|
||||
.with_compare_contents(true),
|
||||
) {
|
||||
Ok(mut w) => match w.watch(&config_path, RecursiveMode::NonRecursive) {
|
||||
Ok(()) => {
|
||||
if inotify_ok {
|
||||
info!("config watcher: poll watcher also active (Docker/NFS safe)");
|
||||
} else {
|
||||
info!("config watcher: poll watcher active on {:?} (3s interval)", config_path);
|
||||
}
|
||||
Box::leak(Box::new(w));
|
||||
}
|
||||
Err(e) => warn!("config watcher: poll watch failed: {}", e),
|
||||
},
|
||||
Err(e) => warn!("config watcher: poll watcher unavailable: {}", e),
|
||||
}
|
||||
|
||||
// ── event loop ───────────────────────────────────────────────────────
|
||||
tokio::spawn(async move {
|
||||
let (notify_tx, mut notify_rx) = mpsc::channel::<()>(4);
|
||||
let manifest_state = Arc::new(StdRwLock::new(WatchManifest::default()));
|
||||
let mut reload_state = ReloadState::new(initial_snapshot_hash);
|
||||
|
||||
let tx_inotify = notify_tx.clone();
|
||||
let manifest_for_inotify = manifest_state.clone();
|
||||
let mut inotify_watcher = match recommended_watcher(move |res: notify::Result<notify::Event>| {
|
||||
let Ok(event) = res else { return };
|
||||
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
|
||||
return;
|
||||
}
|
||||
let is_our_file = manifest_for_inotify
|
||||
.read()
|
||||
.map(|manifest| manifest.matches_event_paths(&event.paths))
|
||||
.unwrap_or(false);
|
||||
if is_our_file {
|
||||
let _ = tx_inotify.try_send(());
|
||||
}
|
||||
}) {
|
||||
Ok(watcher) => Some(watcher),
|
||||
Err(e) => {
|
||||
warn!("config watcher: inotify unavailable: {}", e);
|
||||
None
|
||||
}
|
||||
};
|
||||
apply_watch_manifest(
|
||||
inotify_watcher.as_mut(),
|
||||
Option::<&mut notify::poll::PollWatcher>::None,
|
||||
&manifest_state,
|
||||
initial_manifest.clone(),
|
||||
);
|
||||
if inotify_watcher.is_some() {
|
||||
info!("config watcher: inotify active on {:?}", config_path);
|
||||
}
|
||||
|
||||
let tx_poll = notify_tx.clone();
|
||||
let manifest_for_poll = manifest_state.clone();
|
||||
let mut poll_watcher = match notify::poll::PollWatcher::new(
|
||||
move |res: notify::Result<notify::Event>| {
|
||||
let Ok(event) = res else { return };
|
||||
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
|
||||
return;
|
||||
}
|
||||
let is_our_file = manifest_for_poll
|
||||
.read()
|
||||
.map(|manifest| manifest.matches_event_paths(&event.paths))
|
||||
.unwrap_or(false);
|
||||
if is_our_file {
|
||||
let _ = tx_poll.try_send(());
|
||||
}
|
||||
},
|
||||
notify::Config::default()
|
||||
.with_poll_interval(Duration::from_secs(3))
|
||||
.with_compare_contents(true),
|
||||
) {
|
||||
Ok(watcher) => Some(watcher),
|
||||
Err(e) => {
|
||||
warn!("config watcher: poll watcher unavailable: {}", e);
|
||||
None
|
||||
}
|
||||
};
|
||||
apply_watch_manifest(
|
||||
Option::<&mut notify::RecommendedWatcher>::None,
|
||||
poll_watcher.as_mut(),
|
||||
&manifest_state,
|
||||
initial_manifest.clone(),
|
||||
);
|
||||
if poll_watcher.is_some() {
|
||||
info!("config watcher: poll watcher active (Docker/NFS safe)");
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
let mut sighup = {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
|
@ -1152,11 +1335,25 @@ pub fn spawn_config_watcher(
|
|||
#[cfg(not(unix))]
|
||||
if notify_rx.recv().await.is_none() { break; }
|
||||
|
||||
// Debounce: drain extra events that arrive within 50 ms.
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
// Debounce: drain extra events that arrive within a short quiet window.
|
||||
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
|
||||
while notify_rx.try_recv().is_ok() {}
|
||||
|
||||
reload_config(&config_path, &config_tx, &log_tx, detected_ip_v4, detected_ip_v6);
|
||||
if let Some(next_manifest) = reload_config(
|
||||
&config_path,
|
||||
&config_tx,
|
||||
&log_tx,
|
||||
detected_ip_v4,
|
||||
detected_ip_v6,
|
||||
&mut reload_state,
|
||||
) {
|
||||
apply_watch_manifest(
|
||||
inotify_watcher.as_mut(),
|
||||
poll_watcher.as_mut(),
|
||||
&manifest_state,
|
||||
next_manifest,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -1171,6 +1368,40 @@ mod tests {
|
|||
ProxyConfig::default()
|
||||
}
|
||||
|
||||
fn write_reload_config(path: &Path, ad_tag: Option<&str>, server_port: Option<u16>) {
|
||||
let mut config = String::from(
|
||||
r#"
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#,
|
||||
);
|
||||
|
||||
if ad_tag.is_some() {
|
||||
config.push_str("\n[general]\n");
|
||||
if let Some(tag) = ad_tag {
|
||||
config.push_str(&format!("ad_tag = \"{tag}\"\n"));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(port) = server_port {
|
||||
config.push_str("\n[server]\n");
|
||||
config.push_str(&format!("port = {port}\n"));
|
||||
}
|
||||
|
||||
std::fs::write(path, config).unwrap();
|
||||
}
|
||||
|
||||
fn temp_config_path(prefix: &str) -> PathBuf {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("{prefix}_{nonce}.toml"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overlay_applies_hot_and_preserves_non_hot() {
|
||||
let old = sample_config();
|
||||
|
|
@ -1238,4 +1469,61 @@ mod tests {
|
|||
assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy);
|
||||
assert!(!config_equal(&applied, &new));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reload_requires_stable_snapshot_before_hot_apply() {
|
||||
let initial_tag = "11111111111111111111111111111111";
|
||||
let final_tag = "22222222222222222222222222222222";
|
||||
let path = temp_config_path("telemt_hot_reload_stable");
|
||||
|
||||
write_reload_config(&path, Some(initial_tag), None);
|
||||
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
|
||||
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
|
||||
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
|
||||
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
||||
let mut reload_state = ReloadState::new(Some(initial_hash));
|
||||
|
||||
write_reload_config(&path, None, None);
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
assert_eq!(
|
||||
config_tx.borrow().general.ad_tag.as_deref(),
|
||||
Some(initial_tag)
|
||||
);
|
||||
|
||||
write_reload_config(&path, Some(final_tag), None);
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
assert_eq!(
|
||||
config_tx.borrow().general.ad_tag.as_deref(),
|
||||
Some(initial_tag)
|
||||
);
|
||||
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
|
||||
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reload_keeps_hot_apply_when_non_hot_fields_change() {
|
||||
let initial_tag = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
|
||||
let final_tag = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
|
||||
let path = temp_config_path("telemt_hot_reload_mixed");
|
||||
|
||||
write_reload_config(&path, Some(initial_tag), None);
|
||||
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
|
||||
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
|
||||
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
|
||||
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
||||
let mut reload_state = ReloadState::new(Some(initial_hash));
|
||||
|
||||
write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1));
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
|
||||
let applied = config_tx.borrow().clone();
|
||||
assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag));
|
||||
assert_eq!(applied.server.port, initial_cfg.server.port);
|
||||
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
#![allow(deprecated)]
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::hash::{DefaultHasher, Hash, Hasher};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use rand::Rng;
|
||||
use tracing::warn;
|
||||
|
|
@ -13,6 +14,8 @@ use crate::error::{ProxyError, Result};
|
|||
use super::defaults::*;
|
||||
use super::types::*;
|
||||
|
||||
// Reject absolute paths and any path component that traverses upward.
|
||||
// This prevents config include directives from escaping the config directory.
|
||||
fn is_safe_include_path(path_str: &str) -> bool {
|
||||
let p = std::path::Path::new(path_str);
|
||||
if p.is_absolute() {
|
||||
|
|
@ -21,7 +24,37 @@ fn is_safe_include_path(path_str: &str) -> bool {
|
|||
!p.components().any(|c| matches!(c, std::path::Component::ParentDir))
|
||||
}
|
||||
|
||||
fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result<String> {
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct LoadedConfig {
|
||||
pub(crate) config: ProxyConfig,
|
||||
pub(crate) source_files: Vec<PathBuf>,
|
||||
pub(crate) rendered_hash: u64,
|
||||
}
|
||||
|
||||
fn normalize_config_path(path: &Path) -> PathBuf {
|
||||
path.canonicalize().unwrap_or_else(|_| {
|
||||
if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
std::env::current_dir()
|
||||
.map(|cwd| cwd.join(path))
|
||||
.unwrap_or_else(|_| path.to_path_buf())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn hash_rendered_snapshot(rendered: &str) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
rendered.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
fn preprocess_includes(
|
||||
content: &str,
|
||||
base_dir: &Path,
|
||||
depth: u8,
|
||||
source_files: &mut BTreeSet<PathBuf>,
|
||||
) -> Result<String> {
|
||||
if depth > 10 {
|
||||
return Err(ProxyError::Config("Include depth > 10".into()));
|
||||
}
|
||||
|
|
@ -39,10 +72,16 @@ fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result<Stri
|
|||
)));
|
||||
}
|
||||
let resolved = base_dir.join(path_str);
|
||||
source_files.insert(normalize_config_path(&resolved));
|
||||
let included = std::fs::read_to_string(&resolved)
|
||||
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||
let included_dir = resolved.parent().unwrap_or(base_dir);
|
||||
output.push_str(&preprocess_includes(&included, included_dir, depth + 1)?);
|
||||
output.push_str(&preprocess_includes(
|
||||
&included,
|
||||
included_dir,
|
||||
depth + 1,
|
||||
source_files,
|
||||
)?);
|
||||
output.push('\n');
|
||||
continue;
|
||||
}
|
||||
|
|
@ -152,13 +191,16 @@ pub struct ProxyConfig {
|
|||
|
||||
impl ProxyConfig {
|
||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||
let content =
|
||||
std::fs::read_to_string(&path).map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||
let base_dir = path
|
||||
.as_ref()
|
||||
.parent()
|
||||
.unwrap_or_else(|| Path::new("."));
|
||||
let processed = preprocess_includes(&content, base_dir, 0)?;
|
||||
Self::load_with_metadata(path).map(|loaded| loaded.config)
|
||||
}
|
||||
|
||||
pub(crate) fn load_with_metadata<P: AsRef<Path>>(path: P) -> Result<LoadedConfig> {
|
||||
let path = path.as_ref();
|
||||
let content = std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||
let base_dir = path.parent().unwrap_or(Path::new("."));
|
||||
let mut source_files = BTreeSet::new();
|
||||
source_files.insert(normalize_config_path(path));
|
||||
let processed = preprocess_includes(&content, base_dir, 0, &mut source_files)?;
|
||||
|
||||
let parsed_toml: toml::Value =
|
||||
toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||
|
|
@ -821,7 +863,11 @@ impl ProxyConfig {
|
|||
.entry("203".to_string())
|
||||
.or_insert_with(|| vec!["91.105.192.100:443".to_string()]);
|
||||
|
||||
Ok(config)
|
||||
Ok(LoadedConfig {
|
||||
config,
|
||||
source_files: source_files.into_iter().collect(),
|
||||
rendered_hash: hash_rendered_snapshot(&processed),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
|
|
@ -1165,6 +1211,48 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_with_metadata_collects_include_files() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let dir = std::env::temp_dir().join(format!("telemt_load_metadata_{nonce}"));
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
let main_path = dir.join("config.toml");
|
||||
let include_path = dir.join("included.toml");
|
||||
|
||||
std::fs::write(
|
||||
&include_path,
|
||||
r#"
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
std::fs::write(
|
||||
&main_path,
|
||||
r#"
|
||||
include = "included.toml"
|
||||
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let loaded = ProxyConfig::load_with_metadata(&main_path).unwrap();
|
||||
let main_normalized = normalize_config_path(&main_path);
|
||||
let include_normalized = normalize_config_path(&include_path);
|
||||
|
||||
assert!(loaded.source_files.contains(&main_normalized));
|
||||
assert!(loaded.source_files.contains(&include_normalized));
|
||||
|
||||
let _ = std::fs::remove_file(main_path);
|
||||
let _ = std::fs::remove_file(include_path);
|
||||
let _ = std::fs::remove_dir(dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dc_overrides_inject_dc203_default() {
|
||||
let toml = r#"
|
||||
|
|
@ -2197,7 +2285,8 @@ mod tests {
|
|||
|
||||
let root = dir.join(format!("{prefix}0.toml"));
|
||||
let root_content = std::fs::read_to_string(&root).unwrap();
|
||||
let result = preprocess_includes(&root_content, &dir, 0);
|
||||
let mut sf = BTreeSet::new();
|
||||
let result = preprocess_includes(&root_content, &dir, 0, &mut sf);
|
||||
|
||||
for i in 0usize..=11 {
|
||||
let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml")));
|
||||
|
|
@ -2230,7 +2319,8 @@ mod tests {
|
|||
|
||||
let root = dir.join(format!("{prefix}0.toml"));
|
||||
let root_content = std::fs::read_to_string(&root).unwrap();
|
||||
let result = preprocess_includes(&root_content, &dir, 0);
|
||||
let mut sf = BTreeSet::new();
|
||||
let result = preprocess_includes(&root_content, &dir, 0, &mut sf);
|
||||
|
||||
for i in 0usize..=10 {
|
||||
let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml")));
|
||||
|
|
|
|||
|
|
@ -382,20 +382,7 @@ mod tests {
|
|||
}
|
||||
|
||||
// ObfuscationParams must NOT implement Clone — key material must not be duplicated.
|
||||
// This is a compile-time enforcement test via negative trait assertion.
|
||||
// If Clone were derived again, this assertion would fail to compile.
|
||||
#[test]
|
||||
fn test_obfuscation_params_is_not_clone() {
|
||||
fn assert_not_clone<T: ?Sized>() {}
|
||||
// The trait bound below must NOT hold. We verify by trying to assert
|
||||
// Clone is NOT available: calling ::clone() on the type must not compile.
|
||||
// Since we cannot express negative bounds in stable Rust, we use the
|
||||
// auto_trait approach: verify the type is not Clone via a static assertion.
|
||||
// If this test compiles and passes, Clone is absent from ObfuscationParams.
|
||||
struct NotClone;
|
||||
impl NotClone { fn check() { assert_not_clone::<ObfuscationParams>(); } }
|
||||
// The above does not assert Clone is absent; the real guard is that the
|
||||
// compiler will reject any `.clone()` call site. This test documents intent.
|
||||
NotClone::check();
|
||||
}
|
||||
// Enforced at compile time: if Clone is ever derived or manually implemented for
|
||||
// ObfuscationParams, this assertion will fail to compile.
|
||||
static_assertions::assert_not_impl_any!(ObfuscationParams: Clone);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
use bytes::BytesMut;
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
@ -21,6 +20,11 @@ pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
|
|||
/// Default maximum number of pooled buffers
|
||||
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
||||
|
||||
// Buffers that grew beyond this multiple of `buffer_size` are dropped rather
|
||||
// than returned to the pool, preventing memory amplification from a single
|
||||
// large-payload connection permanently holding oversized allocations.
|
||||
const MAX_POOL_BUFFER_OVERSIZE_MULT: usize = 4;
|
||||
|
||||
// ============= Buffer Pool =============
|
||||
|
||||
/// Thread-safe pool of reusable buffers
|
||||
|
|
@ -97,11 +101,13 @@ impl BufferPool {
|
|||
fn return_buffer(&self, mut buffer: BytesMut) {
|
||||
buffer.clear();
|
||||
|
||||
// Only return buffers that have at least the canonical capacity so the
|
||||
// pool does not shrink over time when callers grow their buffers.
|
||||
// Over-capacity or full-queue buffers are simply dropped; `allocated`
|
||||
// is a high-water mark and is not decremented here.
|
||||
if buffer.capacity() >= self.buffer_size {
|
||||
// Accept buffers within [buffer_size, buffer_size * MAX_POOL_BUFFER_OVERSIZE_MULT].
|
||||
// The lower bound prevents pool capacity from shrinking over time.
|
||||
// The upper bound drops buffers that grew excessively (e.g. to serve a large
|
||||
// payload) so they do not permanently inflate pool memory or get handed to a
|
||||
// future connection that only needs a small allocation.
|
||||
let max_acceptable = self.buffer_size.saturating_mul(MAX_POOL_BUFFER_OVERSIZE_MULT);
|
||||
if buffer.capacity() >= self.buffer_size && buffer.capacity() <= max_acceptable {
|
||||
let _ = self.buffers.push(buffer);
|
||||
}
|
||||
}
|
||||
|
|
@ -216,16 +222,17 @@ impl Deref for PooledBuffer {
|
|||
type Target = BytesMut;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
static EMPTY_BUFFER: OnceLock<BytesMut> = OnceLock::new();
|
||||
self.buffer
|
||||
.as_ref()
|
||||
.unwrap_or_else(|| EMPTY_BUFFER.get_or_init(BytesMut::new))
|
||||
.expect("PooledBuffer: attempted to deref after buffer was taken")
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for PooledBuffer {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.buffer.get_or_insert_with(BytesMut::new)
|
||||
self.buffer
|
||||
.as_mut()
|
||||
.expect("PooledBuffer: attempted to deref_mut after buffer was taken")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -503,22 +510,27 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
// An over-grown buffer (capacity > buffer_size) that is returned to the pool
|
||||
// must still be returned (not dropped) because its capacity exceeds the
|
||||
// required minimum.
|
||||
// A buffer that grew moderately (within MAX_POOL_BUFFER_OVERSIZE_MULT of the canonical
|
||||
// size) must be returned to the pool, because the allocation is still reasonable and
|
||||
// reusing it is more efficient than allocating a new one.
|
||||
#[test]
|
||||
fn oversized_buffer_is_returned_to_pool() {
|
||||
let pool = Arc::new(BufferPool::with_config(64, 10));
|
||||
let canonical = 64usize;
|
||||
let pool = Arc::new(BufferPool::with_config(canonical, 10));
|
||||
|
||||
let mut buf = pool.get();
|
||||
// Manually grow the backing allocation beyond buffer_size.
|
||||
buf.reserve(8192);
|
||||
assert!(buf.capacity() >= 8192);
|
||||
// Grow to 2× the canonical size — within the 4× upper bound.
|
||||
buf.reserve(canonical);
|
||||
assert!(buf.capacity() >= canonical);
|
||||
assert!(
|
||||
buf.capacity() <= canonical * MAX_POOL_BUFFER_OVERSIZE_MULT,
|
||||
"pre-condition: test growth must stay within the acceptable bound"
|
||||
);
|
||||
drop(buf);
|
||||
|
||||
// The buffer must have been returned because capacity >= buffer_size.
|
||||
// The buffer must have been returned because capacity is within acceptable range.
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 1, "oversized buffer must be returned to pool");
|
||||
assert_eq!(stats.pooled, 1, "moderately-oversized buffer must be returned to pool");
|
||||
}
|
||||
|
||||
// A buffer whose capacity fell below buffer_size (e.g. due to take() on
|
||||
|
|
@ -625,4 +637,97 @@ mod tests {
|
|||
);
|
||||
assert_eq!(stats.hits, 50);
|
||||
}
|
||||
|
||||
// ── Security invariant: sensitive data must not leak between pool users ───
|
||||
|
||||
// A buffer containing "sensitive" bytes must be zeroed before being handed
|
||||
// to the next caller. An attacker who can trigger repeated pool cycles against
|
||||
// a shared buffer slot must not be able to read prior connection data.
|
||||
#[test]
|
||||
fn pooled_buffer_sensitive_data_is_cleared_before_reuse() {
|
||||
let pool = Arc::new(BufferPool::with_config(64, 2));
|
||||
{
|
||||
let mut buf = pool.get();
|
||||
buf.extend_from_slice(b"credentials:password123");
|
||||
// Drop returns the buffer to the pool after clearing.
|
||||
}
|
||||
{
|
||||
let buf = pool.get();
|
||||
// Buffer must be empty — no leftover bytes from the previous user.
|
||||
assert!(buf.is_empty(), "pool must clear buffer before handing it to the next caller");
|
||||
assert_eq!(buf.len(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that calling take() extracts the full content and the extracted
|
||||
// BytesMut does NOT get returned to the pool (no double-return).
|
||||
#[test]
|
||||
fn pooled_buffer_take_eliminates_pool_return() {
|
||||
let pool = Arc::new(BufferPool::with_config(64, 2));
|
||||
let stats_before = pool.stats();
|
||||
|
||||
let mut buf = pool.get(); // miss
|
||||
buf.extend_from_slice(b"important");
|
||||
let inner = buf.take(); // consumes PooledBuffer, should NOT return to pool
|
||||
|
||||
assert_eq!(&inner[..], b"important");
|
||||
let stats_after = pool.stats();
|
||||
// pooled count must not increase — take() bypasses the pool
|
||||
assert_eq!(
|
||||
stats_after.pooled, stats_before.pooled,
|
||||
"take() must not return the buffer to the pool"
|
||||
);
|
||||
}
|
||||
|
||||
// Multiple concurrent get() calls must each get an independent empty buffer,
|
||||
// not aliased memory. An adversary who can cause aliased buffer access could
|
||||
// read or corrupt another connection's in-flight data.
|
||||
#[test]
|
||||
fn pooled_buffers_are_independent_no_aliasing() {
|
||||
let pool = Arc::new(BufferPool::with_config(64, 4));
|
||||
let mut b1 = pool.get();
|
||||
let mut b2 = pool.get();
|
||||
|
||||
b1.extend_from_slice(b"connection-A");
|
||||
b2.extend_from_slice(b"connection-B");
|
||||
|
||||
assert_eq!(&b1[..], b"connection-A");
|
||||
assert_eq!(&b2[..], b"connection-B");
|
||||
// Verify no aliasing: modifying b2 does not affect b1.
|
||||
assert_ne!(&b1[..], &b2[..]);
|
||||
}
|
||||
|
||||
// Oversized buffers (capacity grown beyond pool's canonical size) must NOT
|
||||
// be returned to the pool — this prevents the pool from holding oversized
|
||||
// buffers that could be handed to unrelated connections and leak large chunks
|
||||
// of heap across connection boundaries.
|
||||
#[test]
|
||||
fn oversized_buffer_is_dropped_not_pooled() {
|
||||
let canonical = 64usize;
|
||||
let pool = Arc::new(BufferPool::with_config(canonical, 4));
|
||||
|
||||
{
|
||||
let mut buf = pool.get();
|
||||
// Grow well beyond the canonical size.
|
||||
buf.extend(std::iter::repeat(0u8).take(canonical * 8));
|
||||
// Drop should abandon this oversized buffer rather than returning it.
|
||||
}
|
||||
|
||||
let stats = pool.stats();
|
||||
// Pool must be empty: the oversized buffer was not re-queued.
|
||||
assert_eq!(
|
||||
stats.pooled, 0,
|
||||
"oversized buffer must be dropped, not returned to pool (got {} pooled)",
|
||||
stats.pooled
|
||||
);
|
||||
}
|
||||
|
||||
// Deref on a PooledBuffer obtained normally must NOT panic.
|
||||
#[test]
|
||||
fn pooled_buffer_deref_on_live_buffer_does_not_panic() {
|
||||
let pool = Arc::new(BufferPool::new());
|
||||
let mut buf = pool.get();
|
||||
buf.extend_from_slice(b"hello");
|
||||
assert_eq!(&buf[..], b"hello");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -97,8 +97,18 @@ pub(crate) fn build_proxy_req_payload(
|
|||
}
|
||||
}
|
||||
|
||||
let extra_bytes = u32::try_from(b.len() - extra_start - 4).unwrap_or(0);
|
||||
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
|
||||
// On overflow (buffer > 4 GiB, unreachable in practice) keep the extra block
|
||||
// empty by truncating the data back to the length-field position and writing 0,
|
||||
// so the wire representation remains consistent (length field matches content).
|
||||
match u32::try_from(b.len() - extra_start - 4) {
|
||||
Ok(extra_bytes) => {
|
||||
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
|
||||
}
|
||||
Err(_) => {
|
||||
b.truncate(extra_start + 4);
|
||||
b[extra_start..extra_start + 4].copy_from_slice(&0u32.to_le_bytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.extend_from_slice(data);
|
||||
|
|
@ -242,6 +252,7 @@ mod tests {
|
|||
|
||||
// Tags exceeding the 3-byte TL length limit (> 0xFFFFFF = 16,777,215 bytes)
|
||||
// must produce an empty extra block rather than a truncated/corrupt length.
|
||||
#[ignore = "allocates ~16 MiB; run only in a dedicated large-tests profile/CI job"]
|
||||
#[test]
|
||||
fn test_build_proxy_req_payload_tag_exceeds_tl_long_form_limit_produces_empty_extra() {
|
||||
let client = SocketAddr::from(([198, 51, 100, 22], 33333));
|
||||
|
|
@ -267,6 +278,7 @@ mod tests {
|
|||
|
||||
// The prior guard was `tag_len_u32.is_some()` which passed for tags up to u32::MAX.
|
||||
// Verify boundary at exactly 0xFFFFFF: must succeed and encode at 3 bytes.
|
||||
#[ignore = "allocates ~16 MiB; run only in a dedicated large-tests profile/CI job"]
|
||||
#[test]
|
||||
fn test_build_proxy_req_payload_tag_at_tl_long_form_max_boundary_encodes() {
|
||||
// 0xFFFFFF = 16,777,215 bytes — maximum representable by 3-byte TL length.
|
||||
|
|
@ -303,4 +315,101 @@ mod tests {
|
|||
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||
let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
|
||||
assert_eq!(extra_len, 0);
|
||||
}}
|
||||
}
|
||||
|
||||
// ── Protocol wire-consistency invariant tests ─────────────────────────────
|
||||
|
||||
// The extra-block length field must ALWAYS equal the number of bytes that
|
||||
// actually follow it in the buffer. A censor or MitM that parses the wire
|
||||
// representation must not be able to trigger desync by providing adversarial
|
||||
// input that causes the length to say 0 while data bytes are present.
|
||||
#[test]
|
||||
fn extra_block_length_field_always_matches_actual_content_length() {
|
||||
let client = SocketAddr::from(([198, 51, 100, 40], 10000));
|
||||
let ours = SocketAddr::from(([203, 0, 113, 40], 443));
|
||||
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||
|
||||
// Helper: assert wire consistency for a given tag.
|
||||
let check = |tag: Option<&[u8]>, data: &[u8]| {
|
||||
let p = build_proxy_req_payload(1, client, ours, data, tag, RPC_FLAG_HAS_AD_TAG);
|
||||
let raw = p.as_ref();
|
||||
let declared =
|
||||
u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
|
||||
let actual = raw.len() - fixed - 4 - data.len();
|
||||
assert_eq!(
|
||||
declared, actual,
|
||||
"extra-block length field ({declared}) must equal actual byte count ({actual})"
|
||||
);
|
||||
};
|
||||
|
||||
check(None, b"data");
|
||||
check(Some(b""), b"data"); // zero-length tag
|
||||
check(Some(&[0xABu8; 1]), b""); // 1-byte tag
|
||||
check(Some(&[0xABu8; 253]), b"x"); // short-form max
|
||||
check(Some(&[0xCCu8; 254]), b"y"); // long-form min
|
||||
check(Some(&[0xDDu8; 500]), b"z"); // mid-range long-form
|
||||
}
|
||||
|
||||
// Zero-length tag: must encode with short-form length byte 0 and 4-byte
|
||||
// alignment padding, not be dropped as if tag were None.
|
||||
#[test]
|
||||
fn test_build_proxy_req_payload_zero_length_tag_encodes_correctly() {
|
||||
let client = SocketAddr::from(([198, 51, 100, 50], 20000));
|
||||
let ours = SocketAddr::from(([203, 0, 113, 50], 443));
|
||||
let payload = build_proxy_req_payload(
|
||||
12,
|
||||
client,
|
||||
ours,
|
||||
b"payload",
|
||||
Some(&[]),
|
||||
RPC_FLAG_HAS_AD_TAG,
|
||||
);
|
||||
let raw = payload.as_ref();
|
||||
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||
let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
|
||||
// TL object (4) + length-byte 0 (1) + 0 data bytes + padding to 4-byte boundary.
|
||||
// (1 + 0) % 4 = 1; pad = (4 - 1) % 4 = 3.
|
||||
assert_eq!(extra_len, 4 + 1 + 3, "zero-length tag must produce TL header + padding");
|
||||
// Length byte must be 0 (short-form).
|
||||
assert_eq!(raw[fixed + 8], 0u8);
|
||||
}
|
||||
|
||||
// HAS_AD_TAG absent: extra block must NOT appear in the wire output at all.
|
||||
#[test]
|
||||
fn test_build_proxy_req_payload_without_has_ad_tag_flag_has_no_extra_block() {
|
||||
let client = SocketAddr::from(([198, 51, 100, 60], 30000));
|
||||
let ours = SocketAddr::from(([203, 0, 113, 60], 443));
|
||||
let payload = build_proxy_req_payload(
|
||||
13,
|
||||
client,
|
||||
ours,
|
||||
b"data",
|
||||
Some(&[1, 2, 3]),
|
||||
0, // no HAS_AD_TAG flag
|
||||
);
|
||||
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||
let raw = payload.as_ref();
|
||||
// Without the flag, extra block is skipped; payload follows immediately after headers.
|
||||
assert_eq!(raw.len(), fixed + 4, "no extra block when HAS_AD_TAG is absent");
|
||||
assert_eq!(&raw[fixed..], b"data");
|
||||
}
|
||||
|
||||
// Tag with all-0xFF bytes: the TL length encoding must not be confused by
|
||||
// the 0xFF marker byte that the tag itself might contain.
|
||||
#[test]
|
||||
fn test_build_proxy_req_payload_tag_containing_0xfe_byte_encodes_correctly() {
|
||||
let client = SocketAddr::from(([198, 51, 100, 70], 40000));
|
||||
let ours = SocketAddr::from(([203, 0, 113, 70], 443));
|
||||
// 10-byte tag whose first byte is 0xFE (the long-form marker value).
|
||||
// At length 10 the short-form encoding applies; the 0xFE data byte must
|
||||
// not be confused with the length marker.
|
||||
let tag = vec![0xFEu8; 10];
|
||||
let payload = build_proxy_req_payload(14, client, ours, b"", Some(&tag), RPC_FLAG_HAS_AD_TAG);
|
||||
let raw = payload.as_ref();
|
||||
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||
// Length byte must be 10 (short form), not 0xFE.
|
||||
assert_eq!(raw[fixed + 8], 10u8, "short-form length byte must be 10, not the 0xFE marker");
|
||||
// The actual tag bytes must follow immediately.
|
||||
assert_eq!(&raw[fixed + 9..fixed + 19], &[0xFEu8; 10]);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -260,23 +260,44 @@ pub struct PoolStats {
|
|||
#[cfg(unix)]
|
||||
#[allow(unsafe_code)]
|
||||
fn is_connection_alive(stream: &TcpStream) -> bool {
|
||||
use std::io::ErrorKind;
|
||||
use std::os::unix::io::AsRawFd;
|
||||
|
||||
let fd = stream.as_raw_fd();
|
||||
let mut buf = [0u8; 1];
|
||||
// SAFETY: `stream` owns this fd for the full duration of the call.
|
||||
// MSG_PEEK + MSG_DONTWAIT: inspect the receive buffer without consuming bytes.
|
||||
let n = unsafe {
|
||||
libc::recv(
|
||||
stream.as_raw_fd(),
|
||||
buf.as_mut_ptr().cast::<libc::c_void>(),
|
||||
1,
|
||||
libc::MSG_PEEK | libc::MSG_DONTWAIT,
|
||||
)
|
||||
};
|
||||
match n {
|
||||
0 => false,
|
||||
n if n > 0 => true,
|
||||
_ => std::io::Error::last_os_error().kind() == std::io::ErrorKind::WouldBlock,
|
||||
|
||||
// EINTR can fire on any syscall when a signal is delivered to the thread.
|
||||
// Treating it as a dead connection causes spurious reconnects; retry instead.
|
||||
const MAX_RECV_RETRIES: usize = 3;
|
||||
|
||||
for _ in 0..MAX_RECV_RETRIES {
|
||||
// SAFETY: `stream` owns this fd for the full duration of the call.
|
||||
// MSG_PEEK + MSG_DONTWAIT: inspect the receive buffer without consuming bytes.
|
||||
let n = unsafe {
|
||||
libc::recv(
|
||||
fd,
|
||||
buf.as_mut_ptr().cast::<libc::c_void>(),
|
||||
1,
|
||||
libc::MSG_PEEK | libc::MSG_DONTWAIT,
|
||||
)
|
||||
};
|
||||
|
||||
if n > 0 {
|
||||
return true;
|
||||
} else if n == 0 {
|
||||
return false;
|
||||
} else {
|
||||
match std::io::Error::last_os_error().kind() {
|
||||
ErrorKind::Interrupted => continue,
|
||||
ErrorKind::WouldBlock => return true,
|
||||
_ => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// After MAX_RECV_RETRIES consecutive EINTR, assume the connection is alive
|
||||
// rather than triggering a false reconnect.
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
|
|
@ -518,5 +539,63 @@ mod tests {
|
|||
assert!(is_connection_alive(&stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── EINTR retry-logic unit tests ──────────────────────────────────────────
|
||||
|
||||
// The non-unix fallback path must correctly classify WouldBlock as alive
|
||||
// and any other error as dead, mirroring the unix semantics.
|
||||
#[cfg(not(unix))]
|
||||
#[tokio::test]
|
||||
async fn test_is_connection_alive_non_unix_open_connection() {
|
||||
let listener = match TcpListener::bind("127.0.0.1:0").await {
|
||||
Ok(l) => l,
|
||||
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||
Err(e) => panic!("bind failed: {e}"),
|
||||
};
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
let _ = listener.accept().await;
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
});
|
||||
let stream = match TcpStream::connect(addr).await {
|
||||
Ok(s) => s,
|
||||
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||
Err(e) => panic!("connect failed: {e}"),
|
||||
};
|
||||
assert!(is_connection_alive(&stream), "open idle connection must be alive (non-unix)");
|
||||
}
|
||||
|
||||
// Verify the unix path: calling is_connection_alive many times on an active
|
||||
// connection never spuriously returns false (guards against the pre-fix bug
|
||||
// where EINTR would flip the result on a single call).
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn test_is_connection_alive_never_returns_false_spuriously_on_open_connection() {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
let listener = match TcpListener::bind("127.0.0.1:0").await {
|
||||
Ok(l) => l,
|
||||
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||
Err(e) => panic!("bind failed: {e}"),
|
||||
};
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let server = tokio::spawn(async move {
|
||||
let (mut s, _) = listener.accept().await.expect("accept");
|
||||
s.write_all(&[0xBEu8]).await.expect("write");
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
});
|
||||
let stream = match TcpStream::connect(addr).await {
|
||||
Ok(s) => s,
|
||||
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||
Err(e) => panic!("connect failed: {e}"),
|
||||
};
|
||||
tokio::time::sleep(Duration::from_millis(30)).await;
|
||||
// Call is_connection_alive 20 times; all must return true.
|
||||
for i in 0..20 {
|
||||
assert!(
|
||||
is_connection_alive(&stream),
|
||||
"is_connection_alive must be true on call {i} (connection is open with buffered data)"
|
||||
);
|
||||
}
|
||||
drop(server);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue