mirror of https://github.com/telemt/telemt.git
Merge resolve-conflicts-419: conflict resolution + Copilot review fixes
This commit is contained in:
commit
0903dd36c7
|
|
@ -2025,6 +2025,12 @@ version = "1.2.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
|
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "static_assertions"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "subtle"
|
name = "subtle"
|
||||||
version = "2.6.1"
|
version = "2.6.1"
|
||||||
|
|
@ -2127,6 +2133,7 @@ dependencies = [
|
||||||
"sha1",
|
"sha1",
|
||||||
"sha2",
|
"sha2",
|
||||||
"socket2 0.5.10",
|
"socket2 0.5.10",
|
||||||
|
"static_assertions",
|
||||||
"subtle",
|
"subtle",
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,7 @@ tokio-test = "0.4"
|
||||||
criterion = "0.5"
|
criterion = "0.5"
|
||||||
proptest = "1.4"
|
proptest = "1.4"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
|
static_assertions = "1.1"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "crypto_bench"
|
name = "crypto_bench"
|
||||||
|
|
|
||||||
129
src/api/mod.rs
129
src/api/mod.rs
|
|
@ -62,7 +62,7 @@ use runtime_zero::{
|
||||||
use runtime_watch::spawn_runtime_watchers;
|
use runtime_watch::spawn_runtime_watchers;
|
||||||
use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config};
|
use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config};
|
||||||
|
|
||||||
pub struct ApiRuntimeState {
|
pub(super) struct ApiRuntimeState {
|
||||||
pub(super) process_started_at_epoch_secs: u64,
|
pub(super) process_started_at_epoch_secs: u64,
|
||||||
pub(super) config_reload_count: AtomicU64,
|
pub(super) config_reload_count: AtomicU64,
|
||||||
pub(super) last_config_reload_epoch_secs: AtomicU64,
|
pub(super) last_config_reload_epoch_secs: AtomicU64,
|
||||||
|
|
@ -70,7 +70,7 @@ pub struct ApiRuntimeState {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ApiShared {
|
pub(super) struct ApiShared {
|
||||||
pub(super) stats: Arc<Stats>,
|
pub(super) stats: Arc<Stats>,
|
||||||
pub(super) ip_tracker: Arc<UserIpTracker>,
|
pub(super) ip_tracker: Arc<UserIpTracker>,
|
||||||
pub(super) me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
|
pub(super) me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
|
||||||
|
|
@ -565,18 +565,19 @@ async fn handle(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// XOR-fold constant-time comparison: avoids early return on length mismatch to
|
// XOR-fold constant-time comparison. Running time depends only on the length of the
|
||||||
// prevent a timing oracle that would reveal the expected token length to a remote
|
// expected token (b), not on min(a.len(), b.len()), to prevent a timing oracle where
|
||||||
// attacker (OWASP ASVS V6.6.1). Bitwise `&` on bool is eager — it never
|
// an attacker reduces the iteration count by sending a shorter candidate
|
||||||
// short-circuits — so both the length check and the byte fold always execute.
|
// (OWASP ASVS V6.6.1). Bitwise `&` on bool is eager — it never short-circuits —
|
||||||
|
// so both the length check and the byte fold always execute.
|
||||||
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
|
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
|
||||||
let length_ok = a.len() == b.len();
|
let mut diff = 0u8;
|
||||||
let min_len = a.len().min(b.len());
|
for i in 0..b.len() {
|
||||||
let byte_mismatch = a[..min_len]
|
let x = a.get(i).copied().unwrap_or(0);
|
||||||
.iter()
|
let y = b[i];
|
||||||
.zip(b[..min_len].iter())
|
diff |= x ^ y;
|
||||||
.fold(0u8, |acc, (x, y)| acc | (x ^ y));
|
}
|
||||||
length_ok & (byte_mismatch == 0)
|
(a.len() == b.len()) & (diff == 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -765,4 +766,106 @@ mod tests {
|
||||||
c[0] = 0xde;
|
c[0] = 0xde;
|
||||||
assert!(!constant_time_eq(&a, &c));
|
assert!(!constant_time_eq(&a, &c));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Timing-oracle adversarial tests (length-iteration invariant) ──────────
|
||||||
|
|
||||||
|
// An active censor/attacker who knows the token format may send truncated
|
||||||
|
// inputs to narrow down the token length via a timing side-channel.
|
||||||
|
// After the fix, `constant_time_eq` always iterates over `b.len()` (the
|
||||||
|
// expected token length), so submission of every strict prefix of the
|
||||||
|
// expected token must be rejected while iteration count stays constant.
|
||||||
|
#[test]
|
||||||
|
fn constant_time_eq_every_prefix_of_expected_token_is_rejected() {
|
||||||
|
let expected = b"Bearer super-secret-api-token-abc123xyz";
|
||||||
|
for prefix_len in 0..expected.len() {
|
||||||
|
let attacker_input = &expected[..prefix_len];
|
||||||
|
assert!(
|
||||||
|
!constant_time_eq(attacker_input, expected),
|
||||||
|
"prefix of length {prefix_len} must not authenticate against full token"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reversed: input longer than expected token — the extra bytes must cause
|
||||||
|
// rejection even when the first b.len() bytes are correct.
|
||||||
|
#[test]
|
||||||
|
fn constant_time_eq_input_with_correct_prefix_plus_extra_bytes_is_rejected() {
|
||||||
|
let expected = b"secret-token";
|
||||||
|
for extra in 1usize..=32 {
|
||||||
|
let mut longer = expected.to_vec();
|
||||||
|
// Extend with zeros — the XOR of matching first bytes is 0, so only
|
||||||
|
// the length check prevents a false positive.
|
||||||
|
longer.extend(std::iter::repeat(0u8).take(extra));
|
||||||
|
assert!(
|
||||||
|
!constant_time_eq(&longer, expected),
|
||||||
|
"input extended by {extra} zero bytes must not authenticate"
|
||||||
|
);
|
||||||
|
// Extend with matching-value bytes — ensures the byte_fold stays at 0
|
||||||
|
// for the expected-length portion; only length differs.
|
||||||
|
let mut same_byte_extension = expected.to_vec();
|
||||||
|
same_byte_extension.extend(std::iter::repeat(expected[0]).take(extra));
|
||||||
|
assert!(
|
||||||
|
!constant_time_eq(&same_byte_extension, expected),
|
||||||
|
"input extended by {extra} repeated bytes must not authenticate"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Null-byte injection: ensure the function does not mis-parse embedded
|
||||||
|
// NUL characters as C-string terminators and accept a shorter match.
|
||||||
|
#[test]
|
||||||
|
fn constant_time_eq_null_byte_injection_is_rejected() {
|
||||||
|
// Token containing a null byte — must only match itself exactly.
|
||||||
|
let expected: &[u8] = b"token\x00suffix";
|
||||||
|
assert!(constant_time_eq(expected, expected));
|
||||||
|
assert!(!constant_time_eq(b"token", expected));
|
||||||
|
assert!(!constant_time_eq(b"token\x00", expected));
|
||||||
|
assert!(!constant_time_eq(b"token\x00suffi", expected));
|
||||||
|
|
||||||
|
// Null-prefixed input of the same length must not match a non-null token.
|
||||||
|
let real_token: &[u8] = b"real-secret-value";
|
||||||
|
let mut null_injected = vec![0u8; real_token.len()];
|
||||||
|
null_injected[0] = real_token[0];
|
||||||
|
assert!(!constant_time_eq(&null_injected, real_token));
|
||||||
|
}
|
||||||
|
|
||||||
|
// High-byte (0xFF) values throughout: XOR of 0xFF ^ 0xFF = 0, so equal
|
||||||
|
// high-byte slices must match, and any single-byte difference must not.
|
||||||
|
#[test]
|
||||||
|
fn constant_time_eq_high_byte_edge_cases() {
|
||||||
|
let token = vec![0xffu8; 20];
|
||||||
|
assert!(constant_time_eq(&token, &token));
|
||||||
|
let mut tampered = token.clone();
|
||||||
|
tampered[10] = 0xfe;
|
||||||
|
assert!(!constant_time_eq(&tampered, &token));
|
||||||
|
// Shorter all-ff slice must not match.
|
||||||
|
assert!(!constant_time_eq(&token[..19], &token));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulator-saturation attack: if all bytes of `a` have been XOR-folded to
|
||||||
|
// 0xFF (i.e. acc is saturated), but the remaining bytes of `b` are 0x00, the
|
||||||
|
// fold of 0x00 into 0xFF must keep acc ≠ 0 (since 0xFF | 0 = 0xFF).
|
||||||
|
// This guards against a misimplemented fold that resets acc on certain values.
|
||||||
|
#[test]
|
||||||
|
fn constant_time_eq_accumulator_never_resets_to_zero_after_mismatch() {
|
||||||
|
// a[0] = 0xAA, b[0] = 0x55 → XOR = 0xFF.
|
||||||
|
// Subsequent bytes all match (XOR = 0x00). Accumulator must remain 0xFF.
|
||||||
|
let mut a = vec![0x55u8; 16];
|
||||||
|
let mut b = vec![0x55u8; 16];
|
||||||
|
a[0] = 0xAA; // deliberate mismatch at position 0
|
||||||
|
assert!(!constant_time_eq(&a, &b));
|
||||||
|
// Verify with mismatch only at last position to test late detection.
|
||||||
|
a[0] = b[0];
|
||||||
|
a[15] = 0xAA;
|
||||||
|
b[15] = 0x55;
|
||||||
|
assert!(!constant_time_eq(&a, &b));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero-length expected token: only the empty input must match.
|
||||||
|
#[test]
|
||||||
|
fn constant_time_eq_zero_length_expected_only_matches_empty_input() {
|
||||||
|
assert!(constant_time_eq(b"", b""));
|
||||||
|
assert!(!constant_time_eq(b"\x00", b""));
|
||||||
|
assert!(!constant_time_eq(b"x", b""));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,11 @@
|
||||||
//! `network.*`, `use_middle_proxy`) are **not** applied; a warning is emitted.
|
//! `network.*`, `use_middle_proxy`) are **not** applied; a warning is emitted.
|
||||||
//! Non-hot changes are never mixed into the runtime config snapshot.
|
//! Non-hot changes are never mixed into the runtime config snapshot.
|
||||||
|
|
||||||
|
use std::collections::BTreeSet;
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
use std::path::PathBuf;
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, RwLock as StdRwLock};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher};
|
use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher};
|
||||||
use tokio::sync::{mpsc, watch};
|
use tokio::sync::{mpsc, watch};
|
||||||
|
|
@ -33,7 +35,10 @@ use crate::config::{
|
||||||
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
|
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
|
||||||
MeWriterPickMode,
|
MeWriterPickMode,
|
||||||
};
|
};
|
||||||
use super::load::ProxyConfig;
|
use super::load::{LoadedConfig, ProxyConfig};
|
||||||
|
|
||||||
|
const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2;
|
||||||
|
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
|
||||||
|
|
||||||
// ── Hot fields ────────────────────────────────────────────────────────────────
|
// ── Hot fields ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
@ -287,6 +292,149 @@ fn listeners_equal(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||||
|
struct WatchManifest {
|
||||||
|
files: BTreeSet<PathBuf>,
|
||||||
|
dirs: BTreeSet<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WatchManifest {
|
||||||
|
fn from_source_files(source_files: &[PathBuf]) -> Self {
|
||||||
|
let mut files = BTreeSet::new();
|
||||||
|
let mut dirs = BTreeSet::new();
|
||||||
|
|
||||||
|
for path in source_files {
|
||||||
|
let normalized = normalize_watch_path(path);
|
||||||
|
files.insert(normalized.clone());
|
||||||
|
if let Some(parent) = normalized.parent() {
|
||||||
|
dirs.insert(parent.to_path_buf());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Self { files, dirs }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matches_event_paths(&self, event_paths: &[PathBuf]) -> bool {
|
||||||
|
event_paths
|
||||||
|
.iter()
|
||||||
|
.map(|path| normalize_watch_path(path))
|
||||||
|
.any(|path| self.files.contains(&path))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct ReloadState {
|
||||||
|
applied_snapshot_hash: Option<u64>,
|
||||||
|
candidate_snapshot_hash: Option<u64>,
|
||||||
|
candidate_hits: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReloadState {
|
||||||
|
fn new(applied_snapshot_hash: Option<u64>) -> Self {
|
||||||
|
Self {
|
||||||
|
applied_snapshot_hash,
|
||||||
|
candidate_snapshot_hash: None,
|
||||||
|
candidate_hits: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_applied(&self, hash: u64) -> bool {
|
||||||
|
self.applied_snapshot_hash == Some(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn observe_candidate(&mut self, hash: u64) -> u8 {
|
||||||
|
if self.candidate_snapshot_hash == Some(hash) {
|
||||||
|
self.candidate_hits = self.candidate_hits.saturating_add(1);
|
||||||
|
} else {
|
||||||
|
self.candidate_snapshot_hash = Some(hash);
|
||||||
|
self.candidate_hits = 1;
|
||||||
|
}
|
||||||
|
self.candidate_hits
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset_candidate(&mut self) {
|
||||||
|
self.candidate_snapshot_hash = None;
|
||||||
|
self.candidate_hits = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mark_applied(&mut self, hash: u64) {
|
||||||
|
self.applied_snapshot_hash = Some(hash);
|
||||||
|
self.reset_candidate();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_watch_path(path: &Path) -> PathBuf {
|
||||||
|
path.canonicalize().unwrap_or_else(|_| {
|
||||||
|
if path.is_absolute() {
|
||||||
|
path.to_path_buf()
|
||||||
|
} else {
|
||||||
|
std::env::current_dir()
|
||||||
|
.map(|cwd| cwd.join(path))
|
||||||
|
.unwrap_or_else(|_| path.to_path_buf())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sync_watch_paths<W: Watcher>(
|
||||||
|
watcher: &mut W,
|
||||||
|
current: &BTreeSet<PathBuf>,
|
||||||
|
next: &BTreeSet<PathBuf>,
|
||||||
|
recursive_mode: RecursiveMode,
|
||||||
|
kind: &str,
|
||||||
|
) {
|
||||||
|
for path in current.difference(next) {
|
||||||
|
if let Err(e) = watcher.unwatch(path) {
|
||||||
|
warn!(path = %path.display(), error = %e, "config watcher: failed to unwatch {kind}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for path in next.difference(current) {
|
||||||
|
if let Err(e) = watcher.watch(path, recursive_mode) {
|
||||||
|
warn!(path = %path.display(), error = %e, "config watcher: failed to watch {kind}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_watch_manifest<W1: Watcher, W2: Watcher>(
|
||||||
|
notify_watcher: Option<&mut W1>,
|
||||||
|
poll_watcher: Option<&mut W2>,
|
||||||
|
manifest_state: &Arc<StdRwLock<WatchManifest>>,
|
||||||
|
next_manifest: WatchManifest,
|
||||||
|
) {
|
||||||
|
let current_manifest = manifest_state
|
||||||
|
.read()
|
||||||
|
.map(|manifest| manifest.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
if current_manifest == next_manifest {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(watcher) = notify_watcher {
|
||||||
|
sync_watch_paths(
|
||||||
|
watcher,
|
||||||
|
¤t_manifest.dirs,
|
||||||
|
&next_manifest.dirs,
|
||||||
|
RecursiveMode::NonRecursive,
|
||||||
|
"config directory",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(watcher) = poll_watcher {
|
||||||
|
sync_watch_paths(
|
||||||
|
watcher,
|
||||||
|
¤t_manifest.files,
|
||||||
|
&next_manifest.files,
|
||||||
|
RecursiveMode::NonRecursive,
|
||||||
|
"config file",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(mut manifest) = manifest_state.write() {
|
||||||
|
*manifest = next_manifest;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
|
fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
|
||||||
let mut cfg = old.clone();
|
let mut cfg = old.clone();
|
||||||
|
|
||||||
|
|
@ -980,18 +1128,42 @@ fn reload_config(
|
||||||
log_tx: &watch::Sender<LogLevel>,
|
log_tx: &watch::Sender<LogLevel>,
|
||||||
detected_ip_v4: Option<IpAddr>,
|
detected_ip_v4: Option<IpAddr>,
|
||||||
detected_ip_v6: Option<IpAddr>,
|
detected_ip_v6: Option<IpAddr>,
|
||||||
) {
|
reload_state: &mut ReloadState,
|
||||||
let new_cfg = match ProxyConfig::load(config_path) {
|
) -> Option<WatchManifest> {
|
||||||
Ok(c) => c,
|
let loaded = match ProxyConfig::load_with_metadata(config_path) {
|
||||||
|
Ok(loaded) => loaded,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
reload_state.reset_candidate();
|
||||||
error!("config reload: failed to parse {:?}: {}", config_path, e);
|
error!("config reload: failed to parse {:?}: {}", config_path, e);
|
||||||
return;
|
return None;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let LoadedConfig {
|
||||||
|
config: new_cfg,
|
||||||
|
source_files,
|
||||||
|
rendered_hash,
|
||||||
|
} = loaded;
|
||||||
|
let next_manifest = WatchManifest::from_source_files(&source_files);
|
||||||
|
|
||||||
if let Err(e) = new_cfg.validate() {
|
if let Err(e) = new_cfg.validate() {
|
||||||
|
reload_state.reset_candidate();
|
||||||
error!("config reload: validation failed: {}; keeping old config", e);
|
error!("config reload: validation failed: {}; keeping old config", e);
|
||||||
return;
|
return Some(next_manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
if reload_state.is_applied(rendered_hash) {
|
||||||
|
return Some(next_manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
let candidate_hits = reload_state.observe_candidate(rendered_hash);
|
||||||
|
if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
|
||||||
|
info!(
|
||||||
|
snapshot_hash = rendered_hash,
|
||||||
|
candidate_hits,
|
||||||
|
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
|
||||||
|
"config reload: candidate snapshot observed but not stable yet"
|
||||||
|
);
|
||||||
|
return Some(next_manifest);
|
||||||
}
|
}
|
||||||
|
|
||||||
let old_cfg = config_tx.borrow().clone();
|
let old_cfg = config_tx.borrow().clone();
|
||||||
|
|
@ -1006,17 +1178,19 @@ fn reload_config(
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hot_changed {
|
if !hot_changed {
|
||||||
return;
|
reload_state.mark_applied(rendered_hash);
|
||||||
|
return Some(next_manifest);
|
||||||
}
|
}
|
||||||
|
|
||||||
if old_hot.dns_overrides != applied_hot.dns_overrides
|
if old_hot.dns_overrides != applied_hot.dns_overrides
|
||||||
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
|
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
|
||||||
{
|
{
|
||||||
|
reload_state.reset_candidate();
|
||||||
error!(
|
error!(
|
||||||
"config reload: invalid network.dns_overrides: {}; keeping old config",
|
"config reload: invalid network.dns_overrides: {}; keeping old config",
|
||||||
e
|
e
|
||||||
);
|
);
|
||||||
return;
|
return Some(next_manifest);
|
||||||
}
|
}
|
||||||
|
|
||||||
log_changes(
|
log_changes(
|
||||||
|
|
@ -1028,6 +1202,8 @@ fn reload_config(
|
||||||
detected_ip_v6,
|
detected_ip_v6,
|
||||||
);
|
);
|
||||||
config_tx.send(Arc::new(applied_cfg)).ok();
|
config_tx.send(Arc::new(applied_cfg)).ok();
|
||||||
|
reload_state.mark_applied(rendered_hash);
|
||||||
|
Some(next_manifest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Public API ────────────────────────────────────────────────────────────────
|
// ── Public API ────────────────────────────────────────────────────────────────
|
||||||
|
|
@ -1050,79 +1226,86 @@ pub fn spawn_config_watcher(
|
||||||
let (config_tx, config_rx) = watch::channel(initial);
|
let (config_tx, config_rx) = watch::channel(initial);
|
||||||
let (log_tx, log_rx) = watch::channel(initial_level);
|
let (log_tx, log_rx) = watch::channel(initial_level);
|
||||||
|
|
||||||
// Bridge: sync notify callbacks → async task via mpsc.
|
let config_path = normalize_watch_path(&config_path);
|
||||||
let (notify_tx, mut notify_rx) = mpsc::channel::<()>(4);
|
let initial_loaded = ProxyConfig::load_with_metadata(&config_path).ok();
|
||||||
|
let initial_manifest = initial_loaded
|
||||||
|
.as_ref()
|
||||||
|
.map(|loaded| WatchManifest::from_source_files(&loaded.source_files))
|
||||||
|
.unwrap_or_else(|| WatchManifest::from_source_files(std::slice::from_ref(&config_path)));
|
||||||
|
let initial_snapshot_hash = initial_loaded.as_ref().map(|loaded| loaded.rendered_hash);
|
||||||
|
|
||||||
// Canonicalize so path matches what notify returns (absolute) in events.
|
|
||||||
let config_path = config_path
|
|
||||||
.canonicalize()
|
|
||||||
.unwrap_or_else(|_| config_path.to_path_buf());
|
|
||||||
|
|
||||||
// Watch the parent directory rather than the file itself, because many
|
|
||||||
// editors (vim, nano) and systemd write via rename, which would cause
|
|
||||||
// inotify to lose track of the original inode.
|
|
||||||
let watch_dir = config_path
|
|
||||||
.parent()
|
|
||||||
.unwrap_or_else(|| std::path::Path::new("."))
|
|
||||||
.to_path_buf();
|
|
||||||
|
|
||||||
// ── inotify watcher (instant on local fs) ────────────────────────────
|
|
||||||
let config_file = config_path.clone();
|
|
||||||
let tx_inotify = notify_tx.clone();
|
|
||||||
let inotify_ok = match recommended_watcher(move |res: notify::Result<notify::Event>| {
|
|
||||||
let Ok(event) = res else { return };
|
|
||||||
let is_our_file = event.paths.iter().any(|p| p == &config_file);
|
|
||||||
if !is_our_file { return; }
|
|
||||||
if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
|
|
||||||
let _ = tx_inotify.try_send(());
|
|
||||||
}
|
|
||||||
}) {
|
|
||||||
Ok(mut w) => match w.watch(&watch_dir, RecursiveMode::NonRecursive) {
|
|
||||||
Ok(()) => {
|
|
||||||
info!("config watcher: inotify active on {:?}", config_path);
|
|
||||||
Box::leak(Box::new(w));
|
|
||||||
true
|
|
||||||
}
|
|
||||||
Err(e) => { warn!("config watcher: inotify watch failed: {}", e); false }
|
|
||||||
},
|
|
||||||
Err(e) => { warn!("config watcher: inotify unavailable: {}", e); false }
|
|
||||||
};
|
|
||||||
|
|
||||||
// ── poll watcher (always active, fixes Docker bind mounts / NFS) ─────
|
|
||||||
// inotify does not receive events for files mounted from the host into
|
|
||||||
// a container. PollWatcher compares file contents every 3 s and fires
|
|
||||||
// on any change regardless of the underlying fs.
|
|
||||||
let config_file2 = config_path.clone();
|
|
||||||
let tx_poll = notify_tx;
|
|
||||||
match notify::poll::PollWatcher::new(
|
|
||||||
move |res: notify::Result<notify::Event>| {
|
|
||||||
let Ok(event) = res else { return };
|
|
||||||
let is_our_file = event.paths.iter().any(|p| p == &config_file2);
|
|
||||||
if !is_our_file { return; }
|
|
||||||
if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
|
|
||||||
let _ = tx_poll.try_send(());
|
|
||||||
}
|
|
||||||
},
|
|
||||||
notify::Config::default()
|
|
||||||
.with_poll_interval(std::time::Duration::from_secs(3))
|
|
||||||
.with_compare_contents(true),
|
|
||||||
) {
|
|
||||||
Ok(mut w) => match w.watch(&config_path, RecursiveMode::NonRecursive) {
|
|
||||||
Ok(()) => {
|
|
||||||
if inotify_ok {
|
|
||||||
info!("config watcher: poll watcher also active (Docker/NFS safe)");
|
|
||||||
} else {
|
|
||||||
info!("config watcher: poll watcher active on {:?} (3s interval)", config_path);
|
|
||||||
}
|
|
||||||
Box::leak(Box::new(w));
|
|
||||||
}
|
|
||||||
Err(e) => warn!("config watcher: poll watch failed: {}", e),
|
|
||||||
},
|
|
||||||
Err(e) => warn!("config watcher: poll watcher unavailable: {}", e),
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── event loop ───────────────────────────────────────────────────────
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
let (notify_tx, mut notify_rx) = mpsc::channel::<()>(4);
|
||||||
|
let manifest_state = Arc::new(StdRwLock::new(WatchManifest::default()));
|
||||||
|
let mut reload_state = ReloadState::new(initial_snapshot_hash);
|
||||||
|
|
||||||
|
let tx_inotify = notify_tx.clone();
|
||||||
|
let manifest_for_inotify = manifest_state.clone();
|
||||||
|
let mut inotify_watcher = match recommended_watcher(move |res: notify::Result<notify::Event>| {
|
||||||
|
let Ok(event) = res else { return };
|
||||||
|
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let is_our_file = manifest_for_inotify
|
||||||
|
.read()
|
||||||
|
.map(|manifest| manifest.matches_event_paths(&event.paths))
|
||||||
|
.unwrap_or(false);
|
||||||
|
if is_our_file {
|
||||||
|
let _ = tx_inotify.try_send(());
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
Ok(watcher) => Some(watcher),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("config watcher: inotify unavailable: {}", e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
apply_watch_manifest(
|
||||||
|
inotify_watcher.as_mut(),
|
||||||
|
Option::<&mut notify::poll::PollWatcher>::None,
|
||||||
|
&manifest_state,
|
||||||
|
initial_manifest.clone(),
|
||||||
|
);
|
||||||
|
if inotify_watcher.is_some() {
|
||||||
|
info!("config watcher: inotify active on {:?}", config_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
let tx_poll = notify_tx.clone();
|
||||||
|
let manifest_for_poll = manifest_state.clone();
|
||||||
|
let mut poll_watcher = match notify::poll::PollWatcher::new(
|
||||||
|
move |res: notify::Result<notify::Event>| {
|
||||||
|
let Ok(event) = res else { return };
|
||||||
|
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let is_our_file = manifest_for_poll
|
||||||
|
.read()
|
||||||
|
.map(|manifest| manifest.matches_event_paths(&event.paths))
|
||||||
|
.unwrap_or(false);
|
||||||
|
if is_our_file {
|
||||||
|
let _ = tx_poll.try_send(());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
notify::Config::default()
|
||||||
|
.with_poll_interval(Duration::from_secs(3))
|
||||||
|
.with_compare_contents(true),
|
||||||
|
) {
|
||||||
|
Ok(watcher) => Some(watcher),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("config watcher: poll watcher unavailable: {}", e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
apply_watch_manifest(
|
||||||
|
Option::<&mut notify::RecommendedWatcher>::None,
|
||||||
|
poll_watcher.as_mut(),
|
||||||
|
&manifest_state,
|
||||||
|
initial_manifest.clone(),
|
||||||
|
);
|
||||||
|
if poll_watcher.is_some() {
|
||||||
|
info!("config watcher: poll watcher active (Docker/NFS safe)");
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
let mut sighup = {
|
let mut sighup = {
|
||||||
use tokio::signal::unix::{signal, SignalKind};
|
use tokio::signal::unix::{signal, SignalKind};
|
||||||
|
|
@ -1152,11 +1335,25 @@ pub fn spawn_config_watcher(
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
if notify_rx.recv().await.is_none() { break; }
|
if notify_rx.recv().await.is_none() { break; }
|
||||||
|
|
||||||
// Debounce: drain extra events that arrive within 50 ms.
|
// Debounce: drain extra events that arrive within a short quiet window.
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
|
||||||
while notify_rx.try_recv().is_ok() {}
|
while notify_rx.try_recv().is_ok() {}
|
||||||
|
|
||||||
reload_config(&config_path, &config_tx, &log_tx, detected_ip_v4, detected_ip_v6);
|
if let Some(next_manifest) = reload_config(
|
||||||
|
&config_path,
|
||||||
|
&config_tx,
|
||||||
|
&log_tx,
|
||||||
|
detected_ip_v4,
|
||||||
|
detected_ip_v6,
|
||||||
|
&mut reload_state,
|
||||||
|
) {
|
||||||
|
apply_watch_manifest(
|
||||||
|
inotify_watcher.as_mut(),
|
||||||
|
poll_watcher.as_mut(),
|
||||||
|
&manifest_state,
|
||||||
|
next_manifest,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -1171,6 +1368,40 @@ mod tests {
|
||||||
ProxyConfig::default()
|
ProxyConfig::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn write_reload_config(path: &Path, ad_tag: Option<&str>, server_port: Option<u16>) {
|
||||||
|
let mut config = String::from(
|
||||||
|
r#"
|
||||||
|
[censorship]
|
||||||
|
tls_domain = "example.com"
|
||||||
|
|
||||||
|
[access.users]
|
||||||
|
user = "00000000000000000000000000000000"
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
if ad_tag.is_some() {
|
||||||
|
config.push_str("\n[general]\n");
|
||||||
|
if let Some(tag) = ad_tag {
|
||||||
|
config.push_str(&format!("ad_tag = \"{tag}\"\n"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(port) = server_port {
|
||||||
|
config.push_str("\n[server]\n");
|
||||||
|
config.push_str(&format!("port = {port}\n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::fs::write(path, config).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn temp_config_path(prefix: &str) -> PathBuf {
|
||||||
|
let nonce = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_nanos();
|
||||||
|
std::env::temp_dir().join(format!("{prefix}_{nonce}.toml"))
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn overlay_applies_hot_and_preserves_non_hot() {
|
fn overlay_applies_hot_and_preserves_non_hot() {
|
||||||
let old = sample_config();
|
let old = sample_config();
|
||||||
|
|
@ -1238,4 +1469,61 @@ mod tests {
|
||||||
assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy);
|
assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy);
|
||||||
assert!(!config_equal(&applied, &new));
|
assert!(!config_equal(&applied, &new));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reload_requires_stable_snapshot_before_hot_apply() {
|
||||||
|
let initial_tag = "11111111111111111111111111111111";
|
||||||
|
let final_tag = "22222222222222222222222222222222";
|
||||||
|
let path = temp_config_path("telemt_hot_reload_stable");
|
||||||
|
|
||||||
|
write_reload_config(&path, Some(initial_tag), None);
|
||||||
|
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
|
||||||
|
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
|
||||||
|
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
|
||||||
|
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
||||||
|
let mut reload_state = ReloadState::new(Some(initial_hash));
|
||||||
|
|
||||||
|
write_reload_config(&path, None, None);
|
||||||
|
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
config_tx.borrow().general.ad_tag.as_deref(),
|
||||||
|
Some(initial_tag)
|
||||||
|
);
|
||||||
|
|
||||||
|
write_reload_config(&path, Some(final_tag), None);
|
||||||
|
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
config_tx.borrow().general.ad_tag.as_deref(),
|
||||||
|
Some(initial_tag)
|
||||||
|
);
|
||||||
|
|
||||||
|
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||||
|
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reload_keeps_hot_apply_when_non_hot_fields_change() {
|
||||||
|
let initial_tag = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
|
||||||
|
let final_tag = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
|
||||||
|
let path = temp_config_path("telemt_hot_reload_mixed");
|
||||||
|
|
||||||
|
write_reload_config(&path, Some(initial_tag), None);
|
||||||
|
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
|
||||||
|
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
|
||||||
|
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
|
||||||
|
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
||||||
|
let mut reload_state = ReloadState::new(Some(initial_hash));
|
||||||
|
|
||||||
|
write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1));
|
||||||
|
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||||
|
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||||
|
|
||||||
|
let applied = config_tx.borrow().clone();
|
||||||
|
assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag));
|
||||||
|
assert_eq!(applied.server.port, initial_cfg.server.port);
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(path);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
#![allow(deprecated)]
|
#![allow(deprecated)]
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::{BTreeSet, HashMap};
|
||||||
|
use std::hash::{DefaultHasher, Hash, Hasher};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
@ -13,6 +14,8 @@ use crate::error::{ProxyError, Result};
|
||||||
use super::defaults::*;
|
use super::defaults::*;
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
|
|
||||||
|
// Reject absolute paths and any path component that traverses upward.
|
||||||
|
// This prevents config include directives from escaping the config directory.
|
||||||
fn is_safe_include_path(path_str: &str) -> bool {
|
fn is_safe_include_path(path_str: &str) -> bool {
|
||||||
let p = std::path::Path::new(path_str);
|
let p = std::path::Path::new(path_str);
|
||||||
if p.is_absolute() {
|
if p.is_absolute() {
|
||||||
|
|
@ -21,7 +24,37 @@ fn is_safe_include_path(path_str: &str) -> bool {
|
||||||
!p.components().any(|c| matches!(c, std::path::Component::ParentDir))
|
!p.components().any(|c| matches!(c, std::path::Component::ParentDir))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result<String> {
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct LoadedConfig {
|
||||||
|
pub(crate) config: ProxyConfig,
|
||||||
|
pub(crate) source_files: Vec<PathBuf>,
|
||||||
|
pub(crate) rendered_hash: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_config_path(path: &Path) -> PathBuf {
|
||||||
|
path.canonicalize().unwrap_or_else(|_| {
|
||||||
|
if path.is_absolute() {
|
||||||
|
path.to_path_buf()
|
||||||
|
} else {
|
||||||
|
std::env::current_dir()
|
||||||
|
.map(|cwd| cwd.join(path))
|
||||||
|
.unwrap_or_else(|_| path.to_path_buf())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hash_rendered_snapshot(rendered: &str) -> u64 {
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
rendered.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn preprocess_includes(
|
||||||
|
content: &str,
|
||||||
|
base_dir: &Path,
|
||||||
|
depth: u8,
|
||||||
|
source_files: &mut BTreeSet<PathBuf>,
|
||||||
|
) -> Result<String> {
|
||||||
if depth > 10 {
|
if depth > 10 {
|
||||||
return Err(ProxyError::Config("Include depth > 10".into()));
|
return Err(ProxyError::Config("Include depth > 10".into()));
|
||||||
}
|
}
|
||||||
|
|
@ -39,10 +72,16 @@ fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result<Stri
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
let resolved = base_dir.join(path_str);
|
let resolved = base_dir.join(path_str);
|
||||||
|
source_files.insert(normalize_config_path(&resolved));
|
||||||
let included = std::fs::read_to_string(&resolved)
|
let included = std::fs::read_to_string(&resolved)
|
||||||
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
.map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||||
let included_dir = resolved.parent().unwrap_or(base_dir);
|
let included_dir = resolved.parent().unwrap_or(base_dir);
|
||||||
output.push_str(&preprocess_includes(&included, included_dir, depth + 1)?);
|
output.push_str(&preprocess_includes(
|
||||||
|
&included,
|
||||||
|
included_dir,
|
||||||
|
depth + 1,
|
||||||
|
source_files,
|
||||||
|
)?);
|
||||||
output.push('\n');
|
output.push('\n');
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -152,13 +191,16 @@ pub struct ProxyConfig {
|
||||||
|
|
||||||
impl ProxyConfig {
|
impl ProxyConfig {
|
||||||
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||||
let content =
|
Self::load_with_metadata(path).map(|loaded| loaded.config)
|
||||||
std::fs::read_to_string(&path).map_err(|e| ProxyError::Config(e.to_string()))?;
|
}
|
||||||
let base_dir = path
|
|
||||||
.as_ref()
|
pub(crate) fn load_with_metadata<P: AsRef<Path>>(path: P) -> Result<LoadedConfig> {
|
||||||
.parent()
|
let path = path.as_ref();
|
||||||
.unwrap_or_else(|| Path::new("."));
|
let content = std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||||
let processed = preprocess_includes(&content, base_dir, 0)?;
|
let base_dir = path.parent().unwrap_or(Path::new("."));
|
||||||
|
let mut source_files = BTreeSet::new();
|
||||||
|
source_files.insert(normalize_config_path(path));
|
||||||
|
let processed = preprocess_includes(&content, base_dir, 0, &mut source_files)?;
|
||||||
|
|
||||||
let parsed_toml: toml::Value =
|
let parsed_toml: toml::Value =
|
||||||
toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?;
|
toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?;
|
||||||
|
|
@ -821,7 +863,11 @@ impl ProxyConfig {
|
||||||
.entry("203".to_string())
|
.entry("203".to_string())
|
||||||
.or_insert_with(|| vec!["91.105.192.100:443".to_string()]);
|
.or_insert_with(|| vec!["91.105.192.100:443".to_string()]);
|
||||||
|
|
||||||
Ok(config)
|
Ok(LoadedConfig {
|
||||||
|
config,
|
||||||
|
source_files: source_files.into_iter().collect(),
|
||||||
|
rendered_hash: hash_rendered_snapshot(&processed),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn validate(&self) -> Result<()> {
|
pub fn validate(&self) -> Result<()> {
|
||||||
|
|
@ -1165,6 +1211,48 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_with_metadata_collects_include_files() {
|
||||||
|
let nonce = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_nanos();
|
||||||
|
let dir = std::env::temp_dir().join(format!("telemt_load_metadata_{nonce}"));
|
||||||
|
std::fs::create_dir_all(&dir).unwrap();
|
||||||
|
let main_path = dir.join("config.toml");
|
||||||
|
let include_path = dir.join("included.toml");
|
||||||
|
|
||||||
|
std::fs::write(
|
||||||
|
&include_path,
|
||||||
|
r#"
|
||||||
|
[access.users]
|
||||||
|
user = "00000000000000000000000000000000"
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
std::fs::write(
|
||||||
|
&main_path,
|
||||||
|
r#"
|
||||||
|
include = "included.toml"
|
||||||
|
|
||||||
|
[censorship]
|
||||||
|
tls_domain = "example.com"
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let loaded = ProxyConfig::load_with_metadata(&main_path).unwrap();
|
||||||
|
let main_normalized = normalize_config_path(&main_path);
|
||||||
|
let include_normalized = normalize_config_path(&include_path);
|
||||||
|
|
||||||
|
assert!(loaded.source_files.contains(&main_normalized));
|
||||||
|
assert!(loaded.source_files.contains(&include_normalized));
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(main_path);
|
||||||
|
let _ = std::fs::remove_file(include_path);
|
||||||
|
let _ = std::fs::remove_dir(dir);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn dc_overrides_inject_dc203_default() {
|
fn dc_overrides_inject_dc203_default() {
|
||||||
let toml = r#"
|
let toml = r#"
|
||||||
|
|
@ -2197,7 +2285,8 @@ mod tests {
|
||||||
|
|
||||||
let root = dir.join(format!("{prefix}0.toml"));
|
let root = dir.join(format!("{prefix}0.toml"));
|
||||||
let root_content = std::fs::read_to_string(&root).unwrap();
|
let root_content = std::fs::read_to_string(&root).unwrap();
|
||||||
let result = preprocess_includes(&root_content, &dir, 0);
|
let mut sf = BTreeSet::new();
|
||||||
|
let result = preprocess_includes(&root_content, &dir, 0, &mut sf);
|
||||||
|
|
||||||
for i in 0usize..=11 {
|
for i in 0usize..=11 {
|
||||||
let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml")));
|
let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml")));
|
||||||
|
|
@ -2230,7 +2319,8 @@ mod tests {
|
||||||
|
|
||||||
let root = dir.join(format!("{prefix}0.toml"));
|
let root = dir.join(format!("{prefix}0.toml"));
|
||||||
let root_content = std::fs::read_to_string(&root).unwrap();
|
let root_content = std::fs::read_to_string(&root).unwrap();
|
||||||
let result = preprocess_includes(&root_content, &dir, 0);
|
let mut sf = BTreeSet::new();
|
||||||
|
let result = preprocess_includes(&root_content, &dir, 0, &mut sf);
|
||||||
|
|
||||||
for i in 0usize..=10 {
|
for i in 0usize..=10 {
|
||||||
let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml")));
|
let _ = std::fs::remove_file(dir.join(format!("{prefix}{i}.toml")));
|
||||||
|
|
|
||||||
|
|
@ -382,20 +382,7 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ObfuscationParams must NOT implement Clone — key material must not be duplicated.
|
// ObfuscationParams must NOT implement Clone — key material must not be duplicated.
|
||||||
// This is a compile-time enforcement test via negative trait assertion.
|
// Enforced at compile time: if Clone is ever derived or manually implemented for
|
||||||
// If Clone were derived again, this assertion would fail to compile.
|
// ObfuscationParams, this assertion will fail to compile.
|
||||||
#[test]
|
static_assertions::assert_not_impl_any!(ObfuscationParams: Clone);
|
||||||
fn test_obfuscation_params_is_not_clone() {
|
|
||||||
fn assert_not_clone<T: ?Sized>() {}
|
|
||||||
// The trait bound below must NOT hold. We verify by trying to assert
|
|
||||||
// Clone is NOT available: calling ::clone() on the type must not compile.
|
|
||||||
// Since we cannot express negative bounds in stable Rust, we use the
|
|
||||||
// auto_trait approach: verify the type is not Clone via a static assertion.
|
|
||||||
// If this test compiles and passes, Clone is absent from ObfuscationParams.
|
|
||||||
struct NotClone;
|
|
||||||
impl NotClone { fn check() { assert_not_clone::<ObfuscationParams>(); } }
|
|
||||||
// The above does not assert Clone is absent; the real guard is that the
|
|
||||||
// compiler will reject any `.clone()` call site. This test documents intent.
|
|
||||||
NotClone::check();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
use crossbeam_queue::ArrayQueue;
|
use crossbeam_queue::ArrayQueue;
|
||||||
use std::ops::{Deref, DerefMut};
|
use std::ops::{Deref, DerefMut};
|
||||||
use std::sync::OnceLock;
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
|
@ -21,6 +20,11 @@ pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
|
||||||
/// Default maximum number of pooled buffers
|
/// Default maximum number of pooled buffers
|
||||||
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
||||||
|
|
||||||
|
// Buffers that grew beyond this multiple of `buffer_size` are dropped rather
|
||||||
|
// than returned to the pool, preventing memory amplification from a single
|
||||||
|
// large-payload connection permanently holding oversized allocations.
|
||||||
|
const MAX_POOL_BUFFER_OVERSIZE_MULT: usize = 4;
|
||||||
|
|
||||||
// ============= Buffer Pool =============
|
// ============= Buffer Pool =============
|
||||||
|
|
||||||
/// Thread-safe pool of reusable buffers
|
/// Thread-safe pool of reusable buffers
|
||||||
|
|
@ -97,11 +101,13 @@ impl BufferPool {
|
||||||
fn return_buffer(&self, mut buffer: BytesMut) {
|
fn return_buffer(&self, mut buffer: BytesMut) {
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
|
|
||||||
// Only return buffers that have at least the canonical capacity so the
|
// Accept buffers within [buffer_size, buffer_size * MAX_POOL_BUFFER_OVERSIZE_MULT].
|
||||||
// pool does not shrink over time when callers grow their buffers.
|
// The lower bound prevents pool capacity from shrinking over time.
|
||||||
// Over-capacity or full-queue buffers are simply dropped; `allocated`
|
// The upper bound drops buffers that grew excessively (e.g. to serve a large
|
||||||
// is a high-water mark and is not decremented here.
|
// payload) so they do not permanently inflate pool memory or get handed to a
|
||||||
if buffer.capacity() >= self.buffer_size {
|
// future connection that only needs a small allocation.
|
||||||
|
let max_acceptable = self.buffer_size.saturating_mul(MAX_POOL_BUFFER_OVERSIZE_MULT);
|
||||||
|
if buffer.capacity() >= self.buffer_size && buffer.capacity() <= max_acceptable {
|
||||||
let _ = self.buffers.push(buffer);
|
let _ = self.buffers.push(buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -214,18 +220,19 @@ impl PooledBuffer {
|
||||||
|
|
||||||
impl Deref for PooledBuffer {
|
impl Deref for PooledBuffer {
|
||||||
type Target = BytesMut;
|
type Target = BytesMut;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
static EMPTY_BUFFER: OnceLock<BytesMut> = OnceLock::new();
|
|
||||||
self.buffer
|
self.buffer
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.unwrap_or_else(|| EMPTY_BUFFER.get_or_init(BytesMut::new))
|
.expect("PooledBuffer: attempted to deref after buffer was taken")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DerefMut for PooledBuffer {
|
impl DerefMut for PooledBuffer {
|
||||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
self.buffer.get_or_insert_with(BytesMut::new)
|
self.buffer
|
||||||
|
.as_mut()
|
||||||
|
.expect("PooledBuffer: attempted to deref_mut after buffer was taken")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -503,22 +510,27 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// An over-grown buffer (capacity > buffer_size) that is returned to the pool
|
// A buffer that grew moderately (within MAX_POOL_BUFFER_OVERSIZE_MULT of the canonical
|
||||||
// must still be returned (not dropped) because its capacity exceeds the
|
// size) must be returned to the pool, because the allocation is still reasonable and
|
||||||
// required minimum.
|
// reusing it is more efficient than allocating a new one.
|
||||||
#[test]
|
#[test]
|
||||||
fn oversized_buffer_is_returned_to_pool() {
|
fn oversized_buffer_is_returned_to_pool() {
|
||||||
let pool = Arc::new(BufferPool::with_config(64, 10));
|
let canonical = 64usize;
|
||||||
|
let pool = Arc::new(BufferPool::with_config(canonical, 10));
|
||||||
|
|
||||||
let mut buf = pool.get();
|
let mut buf = pool.get();
|
||||||
// Manually grow the backing allocation beyond buffer_size.
|
// Grow to 2× the canonical size — within the 4× upper bound.
|
||||||
buf.reserve(8192);
|
buf.reserve(canonical);
|
||||||
assert!(buf.capacity() >= 8192);
|
assert!(buf.capacity() >= canonical);
|
||||||
|
assert!(
|
||||||
|
buf.capacity() <= canonical * MAX_POOL_BUFFER_OVERSIZE_MULT,
|
||||||
|
"pre-condition: test growth must stay within the acceptable bound"
|
||||||
|
);
|
||||||
drop(buf);
|
drop(buf);
|
||||||
|
|
||||||
// The buffer must have been returned because capacity >= buffer_size.
|
// The buffer must have been returned because capacity is within acceptable range.
|
||||||
let stats = pool.stats();
|
let stats = pool.stats();
|
||||||
assert_eq!(stats.pooled, 1, "oversized buffer must be returned to pool");
|
assert_eq!(stats.pooled, 1, "moderately-oversized buffer must be returned to pool");
|
||||||
}
|
}
|
||||||
|
|
||||||
// A buffer whose capacity fell below buffer_size (e.g. due to take() on
|
// A buffer whose capacity fell below buffer_size (e.g. due to take() on
|
||||||
|
|
@ -625,4 +637,97 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert_eq!(stats.hits, 50);
|
assert_eq!(stats.hits, 50);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Security invariant: sensitive data must not leak between pool users ───
|
||||||
|
|
||||||
|
// A buffer containing "sensitive" bytes must be zeroed before being handed
|
||||||
|
// to the next caller. An attacker who can trigger repeated pool cycles against
|
||||||
|
// a shared buffer slot must not be able to read prior connection data.
|
||||||
|
#[test]
|
||||||
|
fn pooled_buffer_sensitive_data_is_cleared_before_reuse() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 2));
|
||||||
|
{
|
||||||
|
let mut buf = pool.get();
|
||||||
|
buf.extend_from_slice(b"credentials:password123");
|
||||||
|
// Drop returns the buffer to the pool after clearing.
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let buf = pool.get();
|
||||||
|
// Buffer must be empty — no leftover bytes from the previous user.
|
||||||
|
assert!(buf.is_empty(), "pool must clear buffer before handing it to the next caller");
|
||||||
|
assert_eq!(buf.len(), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that calling take() extracts the full content and the extracted
|
||||||
|
// BytesMut does NOT get returned to the pool (no double-return).
|
||||||
|
#[test]
|
||||||
|
fn pooled_buffer_take_eliminates_pool_return() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 2));
|
||||||
|
let stats_before = pool.stats();
|
||||||
|
|
||||||
|
let mut buf = pool.get(); // miss
|
||||||
|
buf.extend_from_slice(b"important");
|
||||||
|
let inner = buf.take(); // consumes PooledBuffer, should NOT return to pool
|
||||||
|
|
||||||
|
assert_eq!(&inner[..], b"important");
|
||||||
|
let stats_after = pool.stats();
|
||||||
|
// pooled count must not increase — take() bypasses the pool
|
||||||
|
assert_eq!(
|
||||||
|
stats_after.pooled, stats_before.pooled,
|
||||||
|
"take() must not return the buffer to the pool"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple concurrent get() calls must each get an independent empty buffer,
|
||||||
|
// not aliased memory. An adversary who can cause aliased buffer access could
|
||||||
|
// read or corrupt another connection's in-flight data.
|
||||||
|
#[test]
|
||||||
|
fn pooled_buffers_are_independent_no_aliasing() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 4));
|
||||||
|
let mut b1 = pool.get();
|
||||||
|
let mut b2 = pool.get();
|
||||||
|
|
||||||
|
b1.extend_from_slice(b"connection-A");
|
||||||
|
b2.extend_from_slice(b"connection-B");
|
||||||
|
|
||||||
|
assert_eq!(&b1[..], b"connection-A");
|
||||||
|
assert_eq!(&b2[..], b"connection-B");
|
||||||
|
// Verify no aliasing: modifying b2 does not affect b1.
|
||||||
|
assert_ne!(&b1[..], &b2[..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Oversized buffers (capacity grown beyond pool's canonical size) must NOT
|
||||||
|
// be returned to the pool — this prevents the pool from holding oversized
|
||||||
|
// buffers that could be handed to unrelated connections and leak large chunks
|
||||||
|
// of heap across connection boundaries.
|
||||||
|
#[test]
|
||||||
|
fn oversized_buffer_is_dropped_not_pooled() {
|
||||||
|
let canonical = 64usize;
|
||||||
|
let pool = Arc::new(BufferPool::with_config(canonical, 4));
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut buf = pool.get();
|
||||||
|
// Grow well beyond the canonical size.
|
||||||
|
buf.extend(std::iter::repeat(0u8).take(canonical * 8));
|
||||||
|
// Drop should abandon this oversized buffer rather than returning it.
|
||||||
|
}
|
||||||
|
|
||||||
|
let stats = pool.stats();
|
||||||
|
// Pool must be empty: the oversized buffer was not re-queued.
|
||||||
|
assert_eq!(
|
||||||
|
stats.pooled, 0,
|
||||||
|
"oversized buffer must be dropped, not returned to pool (got {} pooled)",
|
||||||
|
stats.pooled
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deref on a PooledBuffer obtained normally must NOT panic.
|
||||||
|
#[test]
|
||||||
|
fn pooled_buffer_deref_on_live_buffer_does_not_panic() {
|
||||||
|
let pool = Arc::new(BufferPool::new());
|
||||||
|
let mut buf = pool.get();
|
||||||
|
buf.extend_from_slice(b"hello");
|
||||||
|
assert_eq!(&buf[..], b"hello");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -97,8 +97,18 @@ pub(crate) fn build_proxy_req_payload(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let extra_bytes = u32::try_from(b.len() - extra_start - 4).unwrap_or(0);
|
// On overflow (buffer > 4 GiB, unreachable in practice) keep the extra block
|
||||||
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
|
// empty by truncating the data back to the length-field position and writing 0,
|
||||||
|
// so the wire representation remains consistent (length field matches content).
|
||||||
|
match u32::try_from(b.len() - extra_start - 4) {
|
||||||
|
Ok(extra_bytes) => {
|
||||||
|
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
b.truncate(extra_start + 4);
|
||||||
|
b[extra_start..extra_start + 4].copy_from_slice(&0u32.to_le_bytes());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.extend_from_slice(data);
|
b.extend_from_slice(data);
|
||||||
|
|
@ -242,6 +252,7 @@ mod tests {
|
||||||
|
|
||||||
// Tags exceeding the 3-byte TL length limit (> 0xFFFFFF = 16,777,215 bytes)
|
// Tags exceeding the 3-byte TL length limit (> 0xFFFFFF = 16,777,215 bytes)
|
||||||
// must produce an empty extra block rather than a truncated/corrupt length.
|
// must produce an empty extra block rather than a truncated/corrupt length.
|
||||||
|
#[ignore = "allocates ~16 MiB; run only in a dedicated large-tests profile/CI job"]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_proxy_req_payload_tag_exceeds_tl_long_form_limit_produces_empty_extra() {
|
fn test_build_proxy_req_payload_tag_exceeds_tl_long_form_limit_produces_empty_extra() {
|
||||||
let client = SocketAddr::from(([198, 51, 100, 22], 33333));
|
let client = SocketAddr::from(([198, 51, 100, 22], 33333));
|
||||||
|
|
@ -267,6 +278,7 @@ mod tests {
|
||||||
|
|
||||||
// The prior guard was `tag_len_u32.is_some()` which passed for tags up to u32::MAX.
|
// The prior guard was `tag_len_u32.is_some()` which passed for tags up to u32::MAX.
|
||||||
// Verify boundary at exactly 0xFFFFFF: must succeed and encode at 3 bytes.
|
// Verify boundary at exactly 0xFFFFFF: must succeed and encode at 3 bytes.
|
||||||
|
#[ignore = "allocates ~16 MiB; run only in a dedicated large-tests profile/CI job"]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_proxy_req_payload_tag_at_tl_long_form_max_boundary_encodes() {
|
fn test_build_proxy_req_payload_tag_at_tl_long_form_max_boundary_encodes() {
|
||||||
// 0xFFFFFF = 16,777,215 bytes — maximum representable by 3-byte TL length.
|
// 0xFFFFFF = 16,777,215 bytes — maximum representable by 3-byte TL length.
|
||||||
|
|
@ -303,4 +315,101 @@ mod tests {
|
||||||
let fixed = 4 + 4 + 8 + 20 + 20;
|
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||||
let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
|
let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
|
||||||
assert_eq!(extra_len, 0);
|
assert_eq!(extra_len, 0);
|
||||||
}}
|
}
|
||||||
|
|
||||||
|
// ── Protocol wire-consistency invariant tests ─────────────────────────────
|
||||||
|
|
||||||
|
// The extra-block length field must ALWAYS equal the number of bytes that
|
||||||
|
// actually follow it in the buffer. A censor or MitM that parses the wire
|
||||||
|
// representation must not be able to trigger desync by providing adversarial
|
||||||
|
// input that causes the length to say 0 while data bytes are present.
|
||||||
|
#[test]
|
||||||
|
fn extra_block_length_field_always_matches_actual_content_length() {
|
||||||
|
let client = SocketAddr::from(([198, 51, 100, 40], 10000));
|
||||||
|
let ours = SocketAddr::from(([203, 0, 113, 40], 443));
|
||||||
|
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||||
|
|
||||||
|
// Helper: assert wire consistency for a given tag.
|
||||||
|
let check = |tag: Option<&[u8]>, data: &[u8]| {
|
||||||
|
let p = build_proxy_req_payload(1, client, ours, data, tag, RPC_FLAG_HAS_AD_TAG);
|
||||||
|
let raw = p.as_ref();
|
||||||
|
let declared =
|
||||||
|
u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
|
||||||
|
let actual = raw.len() - fixed - 4 - data.len();
|
||||||
|
assert_eq!(
|
||||||
|
declared, actual,
|
||||||
|
"extra-block length field ({declared}) must equal actual byte count ({actual})"
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
check(None, b"data");
|
||||||
|
check(Some(b""), b"data"); // zero-length tag
|
||||||
|
check(Some(&[0xABu8; 1]), b""); // 1-byte tag
|
||||||
|
check(Some(&[0xABu8; 253]), b"x"); // short-form max
|
||||||
|
check(Some(&[0xCCu8; 254]), b"y"); // long-form min
|
||||||
|
check(Some(&[0xDDu8; 500]), b"z"); // mid-range long-form
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero-length tag: must encode with short-form length byte 0 and 4-byte
|
||||||
|
// alignment padding, not be dropped as if tag were None.
|
||||||
|
#[test]
|
||||||
|
fn test_build_proxy_req_payload_zero_length_tag_encodes_correctly() {
|
||||||
|
let client = SocketAddr::from(([198, 51, 100, 50], 20000));
|
||||||
|
let ours = SocketAddr::from(([203, 0, 113, 50], 443));
|
||||||
|
let payload = build_proxy_req_payload(
|
||||||
|
12,
|
||||||
|
client,
|
||||||
|
ours,
|
||||||
|
b"payload",
|
||||||
|
Some(&[]),
|
||||||
|
RPC_FLAG_HAS_AD_TAG,
|
||||||
|
);
|
||||||
|
let raw = payload.as_ref();
|
||||||
|
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||||
|
let extra_len = u32::from_le_bytes(raw[fixed..fixed + 4].try_into().unwrap()) as usize;
|
||||||
|
// TL object (4) + length-byte 0 (1) + 0 data bytes + padding to 4-byte boundary.
|
||||||
|
// (1 + 0) % 4 = 1; pad = (4 - 1) % 4 = 3.
|
||||||
|
assert_eq!(extra_len, 4 + 1 + 3, "zero-length tag must produce TL header + padding");
|
||||||
|
// Length byte must be 0 (short-form).
|
||||||
|
assert_eq!(raw[fixed + 8], 0u8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// HAS_AD_TAG absent: extra block must NOT appear in the wire output at all.
|
||||||
|
#[test]
|
||||||
|
fn test_build_proxy_req_payload_without_has_ad_tag_flag_has_no_extra_block() {
|
||||||
|
let client = SocketAddr::from(([198, 51, 100, 60], 30000));
|
||||||
|
let ours = SocketAddr::from(([203, 0, 113, 60], 443));
|
||||||
|
let payload = build_proxy_req_payload(
|
||||||
|
13,
|
||||||
|
client,
|
||||||
|
ours,
|
||||||
|
b"data",
|
||||||
|
Some(&[1, 2, 3]),
|
||||||
|
0, // no HAS_AD_TAG flag
|
||||||
|
);
|
||||||
|
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||||
|
let raw = payload.as_ref();
|
||||||
|
// Without the flag, extra block is skipped; payload follows immediately after headers.
|
||||||
|
assert_eq!(raw.len(), fixed + 4, "no extra block when HAS_AD_TAG is absent");
|
||||||
|
assert_eq!(&raw[fixed..], b"data");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tag with all-0xFF bytes: the TL length encoding must not be confused by
|
||||||
|
// the 0xFF marker byte that the tag itself might contain.
|
||||||
|
#[test]
|
||||||
|
fn test_build_proxy_req_payload_tag_containing_0xfe_byte_encodes_correctly() {
|
||||||
|
let client = SocketAddr::from(([198, 51, 100, 70], 40000));
|
||||||
|
let ours = SocketAddr::from(([203, 0, 113, 70], 443));
|
||||||
|
// 10-byte tag whose first byte is 0xFE (the long-form marker value).
|
||||||
|
// At length 10 the short-form encoding applies; the 0xFE data byte must
|
||||||
|
// not be confused with the length marker.
|
||||||
|
let tag = vec![0xFEu8; 10];
|
||||||
|
let payload = build_proxy_req_payload(14, client, ours, b"", Some(&tag), RPC_FLAG_HAS_AD_TAG);
|
||||||
|
let raw = payload.as_ref();
|
||||||
|
let fixed = 4 + 4 + 8 + 20 + 20;
|
||||||
|
// Length byte must be 10 (short form), not 0xFE.
|
||||||
|
assert_eq!(raw[fixed + 8], 10u8, "short-form length byte must be 10, not the 0xFE marker");
|
||||||
|
// The actual tag bytes must follow immediately.
|
||||||
|
assert_eq!(&raw[fixed + 9..fixed + 19], &[0xFEu8; 10]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -260,23 +260,44 @@ pub struct PoolStats {
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
#[allow(unsafe_code)]
|
#[allow(unsafe_code)]
|
||||||
fn is_connection_alive(stream: &TcpStream) -> bool {
|
fn is_connection_alive(stream: &TcpStream) -> bool {
|
||||||
|
use std::io::ErrorKind;
|
||||||
use std::os::unix::io::AsRawFd;
|
use std::os::unix::io::AsRawFd;
|
||||||
|
|
||||||
|
let fd = stream.as_raw_fd();
|
||||||
let mut buf = [0u8; 1];
|
let mut buf = [0u8; 1];
|
||||||
// SAFETY: `stream` owns this fd for the full duration of the call.
|
|
||||||
// MSG_PEEK + MSG_DONTWAIT: inspect the receive buffer without consuming bytes.
|
// EINTR can fire on any syscall when a signal is delivered to the thread.
|
||||||
let n = unsafe {
|
// Treating it as a dead connection causes spurious reconnects; retry instead.
|
||||||
libc::recv(
|
const MAX_RECV_RETRIES: usize = 3;
|
||||||
stream.as_raw_fd(),
|
|
||||||
buf.as_mut_ptr().cast::<libc::c_void>(),
|
for _ in 0..MAX_RECV_RETRIES {
|
||||||
1,
|
// SAFETY: `stream` owns this fd for the full duration of the call.
|
||||||
libc::MSG_PEEK | libc::MSG_DONTWAIT,
|
// MSG_PEEK + MSG_DONTWAIT: inspect the receive buffer without consuming bytes.
|
||||||
)
|
let n = unsafe {
|
||||||
};
|
libc::recv(
|
||||||
match n {
|
fd,
|
||||||
0 => false,
|
buf.as_mut_ptr().cast::<libc::c_void>(),
|
||||||
n if n > 0 => true,
|
1,
|
||||||
_ => std::io::Error::last_os_error().kind() == std::io::ErrorKind::WouldBlock,
|
libc::MSG_PEEK | libc::MSG_DONTWAIT,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
return true;
|
||||||
|
} else if n == 0 {
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
match std::io::Error::last_os_error().kind() {
|
||||||
|
ErrorKind::Interrupted => continue,
|
||||||
|
ErrorKind::WouldBlock => return true,
|
||||||
|
_ => return false,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// After MAX_RECV_RETRIES consecutive EINTR, assume the connection is alive
|
||||||
|
// rather than triggering a false reconnect.
|
||||||
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
|
|
@ -518,5 +539,63 @@ mod tests {
|
||||||
assert!(is_connection_alive(&stream));
|
assert!(is_connection_alive(&stream));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// ── EINTR retry-logic unit tests ──────────────────────────────────────────
|
||||||
|
|
||||||
|
// The non-unix fallback path must correctly classify WouldBlock as alive
|
||||||
|
// and any other error as dead, mirroring the unix semantics.
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_is_connection_alive_non_unix_open_connection() {
|
||||||
|
let listener = match TcpListener::bind("127.0.0.1:0").await {
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||||
|
Err(e) => panic!("bind failed: {e}"),
|
||||||
|
};
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = listener.accept().await;
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||||
|
});
|
||||||
|
let stream = match TcpStream::connect(addr).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||||
|
Err(e) => panic!("connect failed: {e}"),
|
||||||
|
};
|
||||||
|
assert!(is_connection_alive(&stream), "open idle connection must be alive (non-unix)");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the unix path: calling is_connection_alive many times on an active
|
||||||
|
// connection never spuriously returns false (guards against the pre-fix bug
|
||||||
|
// where EINTR would flip the result on a single call).
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_is_connection_alive_never_returns_false_spuriously_on_open_connection() {
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
let listener = match TcpListener::bind("127.0.0.1:0").await {
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||||
|
Err(e) => panic!("bind failed: {e}"),
|
||||||
|
};
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let server = tokio::spawn(async move {
|
||||||
|
let (mut s, _) = listener.accept().await.expect("accept");
|
||||||
|
s.write_all(&[0xBEu8]).await.expect("write");
|
||||||
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||||
|
});
|
||||||
|
let stream = match TcpStream::connect(addr).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||||
|
Err(e) => panic!("connect failed: {e}"),
|
||||||
|
};
|
||||||
|
tokio::time::sleep(Duration::from_millis(30)).await;
|
||||||
|
// Call is_connection_alive 20 times; all must return true.
|
||||||
|
for i in 0..20 {
|
||||||
|
assert!(
|
||||||
|
is_connection_alive(&stream),
|
||||||
|
"is_connection_alive must be true on call {i} (connection is open with buffered data)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
drop(server);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue