mirror of https://github.com/telemt/telemt.git
Compare commits
33 Commits
c0c7f28f32
...
ad9c6dba7b
| Author | SHA1 | Date |
|---|---|---|
|
|
ad9c6dba7b | |
|
|
44376b5652 | |
|
|
c7cf37898b | |
|
|
20e205189c | |
|
|
97d4a1c5c8 | |
|
|
c2443e6f1a | |
|
|
a7cffb547e | |
|
|
f0c37f233e | |
|
|
60953bcc2c | |
|
|
2c06288b40 | |
|
|
0284b9f9e3 | |
|
|
4e3f42dce3 | |
|
|
50a827e7fd | |
|
|
d81140ccec | |
|
|
c540a6657f | |
|
|
4808a30185 | |
|
|
1357f3cc4c | |
|
|
d9aa6f4956 | |
|
|
37a31c13cb | |
|
|
35bca7d4cc | |
|
|
f39d317d93 | |
|
|
d4d93aabf5 | |
|
|
c9271d9083 | |
|
|
4f55d08c51 | |
|
|
9c9ba4becd | |
|
|
93caab1aec | |
|
|
0c6bb3a641 | |
|
|
b2e15327fe | |
|
|
2e8be87ccf | |
|
|
d78360982c | |
|
|
bd0cefdb12 | |
|
|
e2ed1eb286 | |
|
|
a74def9561 |
|
|
@ -21,3 +21,4 @@ target
|
|||
#.idea/
|
||||
|
||||
proxy-secret
|
||||
coverage-html/
|
||||
|
|
@ -390,6 +390,12 @@ you MUST explain why existing invariants remain valid.
|
|||
- Do not modify existing tests unless the task explicitly requires it.
|
||||
- Do not weaken assertions.
|
||||
- Preserve determinism in testable components.
|
||||
- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new.
|
||||
- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code.
|
||||
- Differential/model catches logic drift over time.
|
||||
- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find subtle bugs; 10,000 concurrent calls will find it on the first run.
|
||||
- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly.
|
||||
- Dead parameter is a code smell rule.
|
||||
|
||||
### 15. Security Constraints
|
||||
|
||||
|
|
|
|||
|
|
@ -425,6 +425,32 @@ dependencies = [
|
|||
"cipher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "curve25519-dalek"
|
||||
version = "4.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"curve25519-dalek-derive",
|
||||
"fiat-crypto",
|
||||
"rustc_version",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "curve25519-dalek-derive"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.5.3"
|
||||
|
|
@ -517,6 +543,12 @@ version = "2.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "fiat-crypto"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
|
||||
|
||||
[[package]]
|
||||
name = "filetime"
|
||||
version = "0.2.27"
|
||||
|
|
@ -1609,7 +1641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
||||
dependencies = [
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
"rand_core 0.9.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1619,9 +1651,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
"rand_core 0.9.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.9.5"
|
||||
|
|
@ -1637,7 +1675,7 @@ version = "0.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a"
|
||||
dependencies = [
|
||||
"rand_core",
|
||||
"rand_core 0.9.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2093,7 +2131,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "telemt"
|
||||
version = "3.3.19"
|
||||
version = "3.3.20"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"anyhow",
|
||||
|
|
@ -2145,6 +2183,7 @@ dependencies = [
|
|||
"tracing-subscriber",
|
||||
"url",
|
||||
"webpki-roots 0.26.11",
|
||||
"x25519-dalek",
|
||||
"x509-parser",
|
||||
"zeroize",
|
||||
]
|
||||
|
|
@ -3144,6 +3183,18 @@ version = "0.6.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
|
||||
|
||||
[[package]]
|
||||
name = "x25519-dalek"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277"
|
||||
dependencies = [
|
||||
"curve25519-dalek",
|
||||
"rand_core 0.6.4",
|
||||
"serde",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x509-parser"
|
||||
version = "0.15.1"
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ regex = "1.11"
|
|||
crossbeam-queue = "0.3"
|
||||
num-bigint = "0.4"
|
||||
num-traits = "0.2"
|
||||
x25519-dalek = "2"
|
||||
anyhow = "1.0"
|
||||
|
||||
# HTTP
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ show = "*"
|
|||
port = 443
|
||||
# proxy_protocol = false # Enable if behind HAProxy/nginx with PROXY protocol
|
||||
# metrics_port = 9090
|
||||
# metrics_listen = "0.0.0.0:9090" # Listen address for metrics (overrides metrics_port)
|
||||
# metrics_whitelist = ["127.0.0.1", "::1", "0.0.0.0/0"]
|
||||
|
||||
[server.api]
|
||||
|
|
|
|||
|
|
@ -38,8 +38,9 @@ umweltschutz.de -> A-запись 198.18.88.88
|
|||
|
||||
В конфигурации Telemt:
|
||||
|
||||
```
|
||||
tls_domain = umweltschutz.de
|
||||
```toml
|
||||
[censorship]
|
||||
tls_domain = "umweltschutz.de"
|
||||
```
|
||||
|
||||
Этот домен используется клиентом как SNI в ClientHello
|
||||
|
|
@ -56,8 +57,9 @@ tls_domain = umweltschutz.de
|
|||
|
||||
В конфигурации Telemt:
|
||||
|
||||
```
|
||||
mask_host = 127.0.0.1
|
||||
```toml
|
||||
[censorship]
|
||||
mask_host = "127.0.0.1"
|
||||
mask_port = 8443
|
||||
```
|
||||
|
||||
|
|
@ -151,16 +153,18 @@ mask_host:mask_port
|
|||
|
||||
Например:
|
||||
|
||||
```
|
||||
tls_domain = github.com
|
||||
mask_host = github.com
|
||||
```toml
|
||||
[censorship]
|
||||
tls_domain = "github.com"
|
||||
mask_host = "github.com"
|
||||
mask_port = 443
|
||||
```
|
||||
|
||||
или
|
||||
|
||||
```
|
||||
mask_host = 140.82.121.4
|
||||
```toml
|
||||
[censorship]
|
||||
mask_host = "140.82.121.4"
|
||||
```
|
||||
|
||||
В этом случае:
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ tls_full_cert_ttl_secs = 90
|
|||
|
||||
[access]
|
||||
replay_check_len = 65536
|
||||
replay_window_secs = 1800
|
||||
replay_window_secs = 120
|
||||
ignore_time_skew = false
|
||||
|
||||
[access.users]
|
||||
|
|
|
|||
|
|
@ -73,7 +73,9 @@ pub(crate) fn default_replay_check_len() -> usize {
|
|||
}
|
||||
|
||||
pub(crate) fn default_replay_window_secs() -> u64 {
|
||||
1800
|
||||
// Keep replay cache TTL tight by default to reduce replay surface.
|
||||
// Deployments with higher RTT or longer reconnect jitter can override this in config.
|
||||
120
|
||||
}
|
||||
|
||||
pub(crate) fn default_handshake_timeout() -> u64 {
|
||||
|
|
@ -456,11 +458,11 @@ pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 {
|
|||
}
|
||||
|
||||
pub(crate) fn default_server_hello_delay_min_ms() -> u64 {
|
||||
0
|
||||
8
|
||||
}
|
||||
|
||||
pub(crate) fn default_server_hello_delay_max_ms() -> u64 {
|
||||
0
|
||||
24
|
||||
}
|
||||
|
||||
pub(crate) fn default_alpn_enforce() -> bool {
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ use crate::config::{
|
|||
};
|
||||
use super::load::{LoadedConfig, ProxyConfig};
|
||||
|
||||
const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2;
|
||||
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
|
||||
|
||||
// ── Hot fields ────────────────────────────────────────────────────────────────
|
||||
|
|
@ -329,41 +328,19 @@ impl WatchManifest {
|
|||
#[derive(Debug, Default)]
|
||||
struct ReloadState {
|
||||
applied_snapshot_hash: Option<u64>,
|
||||
candidate_snapshot_hash: Option<u64>,
|
||||
candidate_hits: u8,
|
||||
}
|
||||
|
||||
impl ReloadState {
|
||||
fn new(applied_snapshot_hash: Option<u64>) -> Self {
|
||||
Self {
|
||||
applied_snapshot_hash,
|
||||
candidate_snapshot_hash: None,
|
||||
candidate_hits: 0,
|
||||
}
|
||||
Self { applied_snapshot_hash }
|
||||
}
|
||||
|
||||
fn is_applied(&self, hash: u64) -> bool {
|
||||
self.applied_snapshot_hash == Some(hash)
|
||||
}
|
||||
|
||||
fn observe_candidate(&mut self, hash: u64) -> u8 {
|
||||
if self.candidate_snapshot_hash == Some(hash) {
|
||||
self.candidate_hits = self.candidate_hits.saturating_add(1);
|
||||
} else {
|
||||
self.candidate_snapshot_hash = Some(hash);
|
||||
self.candidate_hits = 1;
|
||||
}
|
||||
self.candidate_hits
|
||||
}
|
||||
|
||||
fn reset_candidate(&mut self) {
|
||||
self.candidate_snapshot_hash = None;
|
||||
self.candidate_hits = 0;
|
||||
}
|
||||
|
||||
fn mark_applied(&mut self, hash: u64) {
|
||||
self.applied_snapshot_hash = Some(hash);
|
||||
self.reset_candidate();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1138,7 +1115,6 @@ fn reload_config(
|
|||
let loaded = match ProxyConfig::load_with_metadata(config_path) {
|
||||
Ok(loaded) => loaded,
|
||||
Err(e) => {
|
||||
reload_state.reset_candidate();
|
||||
error!("config reload: failed to parse {:?}: {}", config_path, e);
|
||||
return None;
|
||||
}
|
||||
|
|
@ -1151,7 +1127,6 @@ fn reload_config(
|
|||
let next_manifest = WatchManifest::from_source_files(&source_files);
|
||||
|
||||
if let Err(e) = new_cfg.validate() {
|
||||
reload_state.reset_candidate();
|
||||
error!("config reload: validation failed: {}; keeping old config", e);
|
||||
return Some(next_manifest);
|
||||
}
|
||||
|
|
@ -1160,17 +1135,6 @@ fn reload_config(
|
|||
return Some(next_manifest);
|
||||
}
|
||||
|
||||
let candidate_hits = reload_state.observe_candidate(rendered_hash);
|
||||
if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
|
||||
info!(
|
||||
snapshot_hash = rendered_hash,
|
||||
candidate_hits,
|
||||
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
|
||||
"config reload: candidate snapshot observed but not stable yet"
|
||||
);
|
||||
return Some(next_manifest);
|
||||
}
|
||||
|
||||
let old_cfg = config_tx.borrow().clone();
|
||||
let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg);
|
||||
let old_hot = HotFields::from_config(&old_cfg);
|
||||
|
|
@ -1190,7 +1154,6 @@ fn reload_config(
|
|||
if old_hot.dns_overrides != applied_hot.dns_overrides
|
||||
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
|
||||
{
|
||||
reload_state.reset_candidate();
|
||||
error!(
|
||||
"config reload: invalid network.dns_overrides: {}; keeping old config",
|
||||
e
|
||||
|
|
@ -1334,14 +1297,28 @@ pub fn spawn_config_watcher(
|
|||
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
|
||||
while notify_rx.try_recv().is_ok() {}
|
||||
|
||||
if let Some(next_manifest) = reload_config(
|
||||
let mut next_manifest = reload_config(
|
||||
&config_path,
|
||||
&config_tx,
|
||||
&log_tx,
|
||||
detected_ip_v4,
|
||||
detected_ip_v6,
|
||||
&mut reload_state,
|
||||
) {
|
||||
);
|
||||
if next_manifest.is_none() {
|
||||
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
|
||||
while notify_rx.try_recv().is_ok() {}
|
||||
next_manifest = reload_config(
|
||||
&config_path,
|
||||
&config_tx,
|
||||
&log_tx,
|
||||
detected_ip_v4,
|
||||
detected_ip_v6,
|
||||
&mut reload_state,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(next_manifest) = next_manifest {
|
||||
apply_watch_manifest(
|
||||
inotify_watcher.as_mut(),
|
||||
poll_watcher.as_mut(),
|
||||
|
|
@ -1466,7 +1443,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn reload_requires_stable_snapshot_before_hot_apply() {
|
||||
fn reload_applies_hot_change_on_first_observed_snapshot() {
|
||||
let initial_tag = "11111111111111111111111111111111";
|
||||
let final_tag = "22222222222222222222222222222222";
|
||||
let path = temp_config_path("telemt_hot_reload_stable");
|
||||
|
|
@ -1478,20 +1455,7 @@ mod tests {
|
|||
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
||||
let mut reload_state = ReloadState::new(Some(initial_hash));
|
||||
|
||||
write_reload_config(&path, None, None);
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
assert_eq!(
|
||||
config_tx.borrow().general.ad_tag.as_deref(),
|
||||
Some(initial_tag)
|
||||
);
|
||||
|
||||
write_reload_config(&path, Some(final_tag), None);
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
assert_eq!(
|
||||
config_tx.borrow().general.ad_tag.as_deref(),
|
||||
Some(initial_tag)
|
||||
);
|
||||
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
|
||||
|
||||
|
|
@ -1513,7 +1477,6 @@ mod tests {
|
|||
|
||||
write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1));
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
|
||||
let applied = config_tx.borrow().clone();
|
||||
assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag));
|
||||
|
|
@ -1521,4 +1484,31 @@ mod tests {
|
|||
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reload_recovers_after_parse_error_on_next_attempt() {
|
||||
let initial_tag = "cccccccccccccccccccccccccccccccc";
|
||||
let final_tag = "dddddddddddddddddddddddddddddddd";
|
||||
let path = temp_config_path("telemt_hot_reload_parse_recovery");
|
||||
|
||||
write_reload_config(&path, Some(initial_tag), None);
|
||||
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
|
||||
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
|
||||
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
|
||||
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
||||
let mut reload_state = ReloadState::new(Some(initial_hash));
|
||||
|
||||
std::fs::write(&path, "[access.users\nuser = \"broken\"\n").unwrap();
|
||||
assert!(reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).is_none());
|
||||
assert_eq!(
|
||||
config_tx.borrow().general.ad_tag.as_deref(),
|
||||
Some(initial_tag)
|
||||
);
|
||||
|
||||
write_reload_config(&path, Some(final_tag), None);
|
||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
|
||||
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1163,9 +1163,17 @@ pub struct ServerConfig {
|
|||
#[serde(default)]
|
||||
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
|
||||
|
||||
/// Port for the Prometheus-compatible metrics endpoint.
|
||||
/// Enables metrics when set; binds on all interfaces (dual-stack) by default.
|
||||
#[serde(default)]
|
||||
pub metrics_port: Option<u16>,
|
||||
|
||||
/// Listen address for metrics in `IP:PORT` format (e.g. `"127.0.0.1:9090"`).
|
||||
/// When set, takes precedence over `metrics_port` and binds on the specified address only.
|
||||
#[serde(default)]
|
||||
pub metrics_listen: Option<String>,
|
||||
|
||||
/// CIDR whitelist for the metrics endpoint.
|
||||
#[serde(default = "default_metrics_whitelist")]
|
||||
pub metrics_whitelist: Vec<IpNetwork>,
|
||||
|
||||
|
|
@ -1194,6 +1202,7 @@ impl Default for ServerConfig {
|
|||
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
|
||||
proxy_protocol_trusted_cidrs: Vec::new(),
|
||||
metrics_port: None,
|
||||
metrics_listen: None,
|
||||
metrics_whitelist: default_metrics_whitelist(),
|
||||
api: ApiConfig::default(),
|
||||
listeners: Vec::new(),
|
||||
|
|
|
|||
|
|
@ -7,8 +7,9 @@ use std::net::IpAddr;
|
|||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::sync::Mutex;
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::{Mutex as AsyncMutex, RwLock};
|
||||
|
||||
use crate::config::UserMaxUniqueIpsMode;
|
||||
|
||||
|
|
@ -21,6 +22,8 @@ pub struct UserIpTracker {
|
|||
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
|
||||
limit_window: Arc<RwLock<Duration>>,
|
||||
last_compact_epoch_secs: Arc<AtomicU64>,
|
||||
pub(crate) cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
|
||||
cleanup_drain_lock: Arc<AsyncMutex<()>>,
|
||||
}
|
||||
|
||||
impl UserIpTracker {
|
||||
|
|
@ -33,6 +36,67 @@ impl UserIpTracker {
|
|||
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
|
||||
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
|
||||
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||
cleanup_queue: Arc::new(Mutex::new(Vec::new())),
|
||||
cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
|
||||
match self.cleanup_queue.lock() {
|
||||
Ok(mut queue) => queue.push((user, ip)),
|
||||
Err(poisoned) => {
|
||||
let mut queue = poisoned.into_inner();
|
||||
queue.push((user.clone(), ip));
|
||||
self.cleanup_queue.clear_poison();
|
||||
tracing::warn!(
|
||||
"UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
|
||||
user,
|
||||
ip
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn drain_cleanup_queue(&self) {
|
||||
// Serialize queue draining and active-IP mutation so check-and-add cannot
|
||||
// observe stale active entries that are already queued for removal.
|
||||
let _drain_guard = self.cleanup_drain_lock.lock().await;
|
||||
let to_remove = {
|
||||
match self.cleanup_queue.lock() {
|
||||
Ok(mut queue) => {
|
||||
if queue.is_empty() {
|
||||
return;
|
||||
}
|
||||
std::mem::take(&mut *queue)
|
||||
}
|
||||
Err(poisoned) => {
|
||||
let mut queue = poisoned.into_inner();
|
||||
if queue.is_empty() {
|
||||
self.cleanup_queue.clear_poison();
|
||||
return;
|
||||
}
|
||||
let drained = std::mem::take(&mut *queue);
|
||||
self.cleanup_queue.clear_poison();
|
||||
drained
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut active_ips = self.active_ips.write().await;
|
||||
for (user, ip) in to_remove {
|
||||
if let Some(user_ips) = active_ips.get_mut(&user) {
|
||||
if let Some(count) = user_ips.get_mut(&ip) {
|
||||
if *count > 1 {
|
||||
*count -= 1;
|
||||
} else {
|
||||
user_ips.remove(&ip);
|
||||
}
|
||||
}
|
||||
if user_ips.is_empty() {
|
||||
active_ips.remove(&user);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -118,6 +182,7 @@ impl UserIpTracker {
|
|||
}
|
||||
|
||||
pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
|
||||
self.drain_cleanup_queue().await;
|
||||
self.maybe_compact_empty_users().await;
|
||||
let default_max_ips = *self.default_max_ips.read().await;
|
||||
let limit = {
|
||||
|
|
@ -194,6 +259,7 @@ impl UserIpTracker {
|
|||
}
|
||||
|
||||
pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap<String, usize> {
|
||||
self.drain_cleanup_queue().await;
|
||||
let window = *self.limit_window.read().await;
|
||||
let now = Instant::now();
|
||||
let recent_ips = self.recent_ips.read().await;
|
||||
|
|
@ -214,6 +280,7 @@ impl UserIpTracker {
|
|||
}
|
||||
|
||||
pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
|
||||
self.drain_cleanup_queue().await;
|
||||
let active_ips = self.active_ips.read().await;
|
||||
let mut out = HashMap::with_capacity(users.len());
|
||||
for user in users {
|
||||
|
|
@ -228,6 +295,7 @@ impl UserIpTracker {
|
|||
}
|
||||
|
||||
pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
|
||||
self.drain_cleanup_queue().await;
|
||||
let window = *self.limit_window.read().await;
|
||||
let now = Instant::now();
|
||||
let recent_ips = self.recent_ips.read().await;
|
||||
|
|
@ -250,11 +318,13 @@ impl UserIpTracker {
|
|||
}
|
||||
|
||||
pub async fn get_active_ip_count(&self, username: &str) -> usize {
|
||||
self.drain_cleanup_queue().await;
|
||||
let active_ips = self.active_ips.read().await;
|
||||
active_ips.get(username).map(|ips| ips.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
pub async fn get_active_ips(&self, username: &str) -> Vec<IpAddr> {
|
||||
self.drain_cleanup_queue().await;
|
||||
let active_ips = self.active_ips.read().await;
|
||||
active_ips
|
||||
.get(username)
|
||||
|
|
@ -263,6 +333,7 @@ impl UserIpTracker {
|
|||
}
|
||||
|
||||
pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
|
||||
self.drain_cleanup_queue().await;
|
||||
let active_ips = self.active_ips.read().await;
|
||||
let max_ips = self.max_ips.read().await;
|
||||
let default_max_ips = *self.default_max_ips.read().await;
|
||||
|
|
@ -301,6 +372,7 @@ impl UserIpTracker {
|
|||
}
|
||||
|
||||
pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool {
|
||||
self.drain_cleanup_queue().await;
|
||||
let active_ips = self.active_ips.read().await;
|
||||
active_ips
|
||||
.get(username)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,619 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config::UserMaxUniqueIpsMode;
|
||||
use crate::ip_tracker::UserIpTracker;
|
||||
|
||||
fn ip_from_idx(idx: u32) -> IpAddr {
|
||||
let a = 10u8;
|
||||
let b = ((idx / 65_536) % 256) as u8;
|
||||
let c = ((idx / 256) % 256) as u8;
|
||||
let d = (idx % 256) as u8;
|
||||
IpAddr::V4(Ipv4Addr::new(a, b, c, d))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn active_window_enforces_large_unique_ip_burst() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("burst_user", 64).await;
|
||||
tracker
|
||||
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
|
||||
.await;
|
||||
|
||||
for idx in 0..64 {
|
||||
assert!(tracker.check_and_add("burst_user", ip_from_idx(idx)).await.is_ok());
|
||||
}
|
||||
assert!(tracker.check_and_add("burst_user", ip_from_idx(9_999)).await.is_err());
|
||||
assert_eq!(tracker.get_active_ip_count("burst_user").await, 64);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn global_limit_applies_across_many_users() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.load_limits(3, &HashMap::new()).await;
|
||||
|
||||
for user_idx in 0..150u32 {
|
||||
let user = format!("u{}", user_idx);
|
||||
assert!(tracker.check_and_add(&user, ip_from_idx(user_idx * 10)).await.is_ok());
|
||||
assert!(tracker
|
||||
.check_and_add(&user, ip_from_idx(user_idx * 10 + 1))
|
||||
.await
|
||||
.is_ok());
|
||||
assert!(tracker
|
||||
.check_and_add(&user, ip_from_idx(user_idx * 10 + 2))
|
||||
.await
|
||||
.is_ok());
|
||||
assert!(tracker
|
||||
.check_and_add(&user, ip_from_idx(user_idx * 10 + 3))
|
||||
.await
|
||||
.is_err());
|
||||
}
|
||||
|
||||
assert_eq!(tracker.get_stats().await.len(), 150);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn user_zero_override_falls_back_to_global_limit() {
|
||||
let tracker = UserIpTracker::new();
|
||||
let mut limits = HashMap::new();
|
||||
limits.insert("target".to_string(), 0);
|
||||
tracker.load_limits(2, &limits).await;
|
||||
|
||||
assert!(tracker.check_and_add("target", ip_from_idx(1)).await.is_ok());
|
||||
assert!(tracker.check_and_add("target", ip_from_idx(2)).await.is_ok());
|
||||
assert!(tracker.check_and_add("target", ip_from_idx(3)).await.is_err());
|
||||
assert_eq!(tracker.get_user_limit("target").await, Some(2));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_ip_is_idempotent_after_counter_reaches_zero() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("u", 2).await;
|
||||
let ip = ip_from_idx(42);
|
||||
|
||||
tracker.check_and_add("u", ip).await.unwrap();
|
||||
tracker.remove_ip("u", ip).await;
|
||||
tracker.remove_ip("u", ip).await;
|
||||
tracker.remove_ip("u", ip).await;
|
||||
|
||||
assert_eq!(tracker.get_active_ip_count("u").await, 0);
|
||||
assert!(!tracker.is_ip_active("u", ip).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clear_user_ips_resets_active_and_recent() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("u", 10).await;
|
||||
|
||||
for idx in 0..6 {
|
||||
tracker.check_and_add("u", ip_from_idx(idx)).await.unwrap();
|
||||
}
|
||||
|
||||
tracker.clear_user_ips("u").await;
|
||||
|
||||
assert_eq!(tracker.get_active_ip_count("u").await, 0);
|
||||
let counts = tracker
|
||||
.get_recent_counts_for_users(&["u".to_string()])
|
||||
.await;
|
||||
assert_eq!(counts.get("u").copied().unwrap_or(0), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clear_all_resets_multi_user_state() {
|
||||
let tracker = UserIpTracker::new();
|
||||
|
||||
for user_idx in 0..80u32 {
|
||||
let user = format!("u{}", user_idx);
|
||||
for ip_idx in 0..3 {
|
||||
tracker
|
||||
.check_and_add(&user, ip_from_idx(user_idx * 100 + ip_idx))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
tracker.clear_all().await;
|
||||
|
||||
assert!(tracker.get_stats().await.is_empty());
|
||||
let users = (0..80u32)
|
||||
.map(|idx| format!("u{}", idx))
|
||||
.collect::<Vec<_>>();
|
||||
let recent = tracker.get_recent_counts_for_users(&users).await;
|
||||
assert!(recent.values().all(|count| *count == 0));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_active_ips_for_users_are_sorted() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("user", 10).await;
|
||||
|
||||
tracker
|
||||
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
|
||||
.await
|
||||
.unwrap();
|
||||
tracker
|
||||
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
|
||||
.await
|
||||
.unwrap();
|
||||
tracker
|
||||
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let map = tracker
|
||||
.get_active_ips_for_users(&["user".to_string()])
|
||||
.await;
|
||||
let ips = map.get("user").cloned().unwrap_or_default();
|
||||
|
||||
assert_eq!(
|
||||
ips,
|
||||
vec![
|
||||
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
|
||||
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)),
|
||||
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_recent_ips_for_users_are_sorted() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("user", 10).await;
|
||||
|
||||
tracker
|
||||
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)))
|
||||
.await
|
||||
.unwrap();
|
||||
tracker
|
||||
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)))
|
||||
.await
|
||||
.unwrap();
|
||||
tracker
|
||||
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let map = tracker
|
||||
.get_recent_ips_for_users(&["user".to_string()])
|
||||
.await;
|
||||
let ips = map.get("user").cloned().unwrap_or_default();
|
||||
|
||||
assert_eq!(
|
||||
ips,
|
||||
vec![
|
||||
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)),
|
||||
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)),
|
||||
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn time_window_expires_for_large_rotation() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("tw", 1).await;
|
||||
tracker
|
||||
.set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
|
||||
.await;
|
||||
|
||||
tracker.check_and_add("tw", ip_from_idx(1)).await.unwrap();
|
||||
tracker.remove_ip("tw", ip_from_idx(1)).await;
|
||||
assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_err());
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1_100)).await;
|
||||
assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn combined_mode_blocks_recent_after_disconnect() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("cmb", 1).await;
|
||||
tracker
|
||||
.set_limit_policy(UserMaxUniqueIpsMode::Combined, 2)
|
||||
.await;
|
||||
|
||||
tracker.check_and_add("cmb", ip_from_idx(11)).await.unwrap();
|
||||
tracker.remove_ip("cmb", ip_from_idx(11)).await;
|
||||
|
||||
assert!(tracker.check_and_add("cmb", ip_from_idx(12)).await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_limits_replaces_large_limit_map() {
|
||||
let tracker = UserIpTracker::new();
|
||||
let mut first = HashMap::new();
|
||||
let mut second = HashMap::new();
|
||||
|
||||
for idx in 0..300usize {
|
||||
first.insert(format!("u{}", idx), 2usize);
|
||||
}
|
||||
for idx in 150..450usize {
|
||||
second.insert(format!("u{}", idx), 4usize);
|
||||
}
|
||||
|
||||
tracker.load_limits(0, &first).await;
|
||||
tracker.load_limits(0, &second).await;
|
||||
|
||||
assert_eq!(tracker.get_user_limit("u20").await, None);
|
||||
assert_eq!(tracker.get_user_limit("u200").await, Some(4));
|
||||
assert_eq!(tracker.get_user_limit("u420").await, Some(4));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn concurrent_same_user_unique_ip_pressure_stays_bounded() {
|
||||
let tracker = Arc::new(UserIpTracker::new());
|
||||
tracker.set_user_limit("hot", 32).await;
|
||||
tracker
|
||||
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
|
||||
.await;
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for worker in 0..16u32 {
|
||||
let tracker_cloned = tracker.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let base = worker * 200;
|
||||
for step in 0..200u32 {
|
||||
let _ = tracker_cloned
|
||||
.check_and_add("hot", ip_from_idx(base + step))
|
||||
.await;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
assert!(tracker.get_active_ip_count("hot").await <= 32);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn concurrent_many_users_isolate_limits() {
|
||||
let tracker = Arc::new(UserIpTracker::new());
|
||||
tracker.load_limits(4, &HashMap::new()).await;
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for user_idx in 0..120u32 {
|
||||
let tracker_cloned = tracker.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let user = format!("u{}", user_idx);
|
||||
for ip_idx in 0..10u32 {
|
||||
let _ = tracker_cloned
|
||||
.check_and_add(&user, ip_from_idx(user_idx * 1_000 + ip_idx))
|
||||
.await;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
let stats = tracker.get_stats().await;
|
||||
assert_eq!(stats.len(), 120);
|
||||
assert!(stats.iter().all(|(_, active, limit)| *active <= 4 && *limit == 4));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn same_ip_reconnect_high_frequency_keeps_single_unique() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("same", 2).await;
|
||||
let ip = ip_from_idx(9);
|
||||
|
||||
for _ in 0..2_000 {
|
||||
tracker.check_and_add("same", ip).await.unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(tracker.get_active_ip_count("same").await, 1);
|
||||
assert!(tracker.is_ip_active("same", ip).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn format_stats_contains_expected_limited_and_unlimited_markers() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("limited", 2).await;
|
||||
tracker.check_and_add("limited", ip_from_idx(1)).await.unwrap();
|
||||
tracker.check_and_add("open", ip_from_idx(2)).await.unwrap();
|
||||
|
||||
let text = tracker.format_stats().await;
|
||||
|
||||
assert!(text.contains("limited"));
|
||||
assert!(text.contains("open"));
|
||||
assert!(text.contains("unlimited"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stats_report_global_default_for_users_without_override() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.load_limits(5, &HashMap::new()).await;
|
||||
|
||||
tracker.check_and_add("a", ip_from_idx(1)).await.unwrap();
|
||||
tracker.check_and_add("b", ip_from_idx(2)).await.unwrap();
|
||||
|
||||
let stats = tracker.get_stats().await;
|
||||
assert!(stats.iter().any(|(user, _, limit)| user == "a" && *limit == 5));
|
||||
assert!(stats.iter().any(|(user, _, limit)| user == "b" && *limit == 5));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stress_cycle_add_remove_clear_preserves_empty_end_state() {
|
||||
let tracker = UserIpTracker::new();
|
||||
|
||||
for cycle in 0..50u32 {
|
||||
let user = format!("cycle{}", cycle);
|
||||
tracker.set_user_limit(&user, 128).await;
|
||||
|
||||
for ip_idx in 0..128u32 {
|
||||
tracker
|
||||
.check_and_add(&user, ip_from_idx(cycle * 10_000 + ip_idx))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for ip_idx in 0..128u32 {
|
||||
tracker
|
||||
.remove_ip(&user, ip_from_idx(cycle * 10_000 + ip_idx))
|
||||
.await;
|
||||
}
|
||||
|
||||
tracker.clear_user_ips(&user).await;
|
||||
}
|
||||
|
||||
assert!(tracker.get_stats().await.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_unknown_user_or_ip_does_not_corrupt_state() {
|
||||
let tracker = UserIpTracker::new();
|
||||
|
||||
tracker.remove_ip("no_user", ip_from_idx(1)).await;
|
||||
tracker.check_and_add("x", ip_from_idx(2)).await.unwrap();
|
||||
tracker.remove_ip("x", ip_from_idx(3)).await;
|
||||
|
||||
assert_eq!(tracker.get_active_ip_count("x").await, 1);
|
||||
assert!(tracker.is_ip_active("x", ip_from_idx(2)).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn active_and_recent_views_match_after_mixed_workload() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("mix", 16).await;
|
||||
|
||||
for ip_idx in 0..12u32 {
|
||||
tracker.check_and_add("mix", ip_from_idx(ip_idx)).await.unwrap();
|
||||
}
|
||||
for ip_idx in 0..6u32 {
|
||||
tracker.remove_ip("mix", ip_from_idx(ip_idx)).await;
|
||||
}
|
||||
|
||||
let active = tracker
|
||||
.get_active_ips_for_users(&["mix".to_string()])
|
||||
.await
|
||||
.get("mix")
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let recent_count = tracker
|
||||
.get_recent_counts_for_users(&["mix".to_string()])
|
||||
.await
|
||||
.get("mix")
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
|
||||
assert_eq!(active.len(), 6);
|
||||
assert!(recent_count >= active.len());
|
||||
assert!(recent_count <= 12);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn global_limit_switch_updates_enforcement_immediately() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.load_limits(2, &HashMap::new()).await;
|
||||
|
||||
assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
|
||||
assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
|
||||
assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_err());
|
||||
|
||||
tracker.clear_user_ips("u").await;
|
||||
tracker.load_limits(4, &HashMap::new()).await;
|
||||
|
||||
assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
|
||||
assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
|
||||
assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_ok());
|
||||
assert!(tracker.check_and_add("u", ip_from_idx(4)).await.is_ok());
|
||||
assert!(tracker.check_and_add("u", ip_from_idx(5)).await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() {
|
||||
let tracker = Arc::new(UserIpTracker::new());
|
||||
tracker.set_user_limit("cc", 8).await;
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for worker in 0..8u32 {
|
||||
let tracker_cloned = tracker.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let ip = ip_from_idx(50 + worker);
|
||||
for _ in 0..500u32 {
|
||||
let _ = tracker_cloned.check_and_add("cc", ip).await;
|
||||
tracker_cloned.remove_ip("cc", ip).await;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
assert!(tracker.get_active_ip_count("cc").await <= 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_cleanup_recovers_from_poisoned_mutex() {
|
||||
let tracker = UserIpTracker::new();
|
||||
let ip = ip_from_idx(99);
|
||||
|
||||
// Poison the lock by panicking while holding it
|
||||
let result = std::panic::catch_unwind(|| {
|
||||
let _guard = tracker.cleanup_queue.lock().unwrap();
|
||||
panic!("Intentional poison panic");
|
||||
});
|
||||
assert!(result.is_err(), "Expected panic to poison mutex");
|
||||
|
||||
// Attempt to enqueue anyway; should hit the poison catch arm and still insert
|
||||
tracker.enqueue_cleanup("poison-user".to_string(), ip);
|
||||
|
||||
tracker.drain_cleanup_queue().await;
|
||||
|
||||
assert_eq!(tracker.get_active_ip_count("poison-user").await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn mass_reconnect_sync_cleanup_prevents_temporary_reservation_bloat() {
|
||||
// Tests that synchronous M-01 drop mechanism protects against starvation
|
||||
let tracker = Arc::new(UserIpTracker::new());
|
||||
tracker.set_user_limit("mass", 5).await;
|
||||
|
||||
let ip = ip_from_idx(42);
|
||||
let mut join_handles = Vec::new();
|
||||
|
||||
// 10,000 rapid concurrent requests hitting the same IP limit
|
||||
for _ in 0..10_000 {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
if tracker_clone.check_and_add("mass", ip).await.is_ok() {
|
||||
// Instantly enqueue cleanup, simulating synchronous reservation drop
|
||||
tracker_clone.enqueue_cleanup("mass".to_string(), ip);
|
||||
// The next caller will drain it before acquiring again
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in join_handles {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
// Force flush
|
||||
tracker.drain_cleanup_queue().await;
|
||||
assert_eq!(tracker.get_active_ip_count("mass").await, 0, "No leaked footprints");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() {
|
||||
// Regression guard: concurrent cleanup draining must not produce false
|
||||
// limit denials for a new IP when the previous IP is already queued.
|
||||
let tracker = Arc::new(UserIpTracker::new());
|
||||
tracker.set_user_limit("racer", 1).await;
|
||||
let ip1 = ip_from_idx(1);
|
||||
let ip2 = ip_from_idx(2);
|
||||
|
||||
// Initial state: add ip1
|
||||
tracker.check_and_add("racer", ip1).await.unwrap();
|
||||
|
||||
// User disconnects from ip1, queuing it
|
||||
tracker.enqueue_cleanup("racer".to_string(), ip1);
|
||||
|
||||
let mut saw_false_rejection = false;
|
||||
for _ in 0..100 {
|
||||
// Queue cleanup then race explicit drain and check-and-add on the alternative IP.
|
||||
tracker.enqueue_cleanup("racer".to_string(), ip1);
|
||||
let tracker_a = tracker.clone();
|
||||
let tracker_b = tracker.clone();
|
||||
|
||||
let drain_handle = tokio::spawn(async move {
|
||||
tracker_a.drain_cleanup_queue().await;
|
||||
});
|
||||
let handle = tokio::spawn(async move {
|
||||
tracker_b.check_and_add("racer", ip2).await
|
||||
});
|
||||
|
||||
drain_handle.await.unwrap();
|
||||
let res = handle.await.unwrap();
|
||||
if res.is_err() {
|
||||
saw_false_rejection = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// Restore baseline for next iteration.
|
||||
tracker.remove_ip("racer", ip2).await;
|
||||
tracker.check_and_add("racer", ip1).await.unwrap();
|
||||
}
|
||||
|
||||
assert!(
|
||||
!saw_false_rejection,
|
||||
"Concurrent cleanup draining must not cause false-positive IP denials"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poisoned_cleanup_queue_still_releases_slot_for_next_ip() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("poison-slot", 1).await;
|
||||
let ip1 = ip_from_idx(7001);
|
||||
let ip2 = ip_from_idx(7002);
|
||||
|
||||
tracker.check_and_add("poison-slot", ip1).await.unwrap();
|
||||
|
||||
// Poison the queue lock as an adversarial condition.
|
||||
let _ = std::panic::catch_unwind(|| {
|
||||
let _guard = tracker.cleanup_queue.lock().unwrap();
|
||||
panic!("intentional queue poison");
|
||||
});
|
||||
|
||||
// Disconnect path must still queue cleanup so the next IP can be admitted.
|
||||
tracker.enqueue_cleanup("poison-slot".to_string(), ip1);
|
||||
let admitted = tracker.check_and_add("poison-slot", ip2).await;
|
||||
assert!(
|
||||
admitted.is_ok(),
|
||||
"cleanup queue poison must not permanently block slot release for the next IP"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn duplicate_cleanup_entries_do_not_break_future_admission() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("dup-cleanup", 1).await;
|
||||
let ip1 = ip_from_idx(7101);
|
||||
let ip2 = ip_from_idx(7102);
|
||||
|
||||
tracker.check_and_add("dup-cleanup", ip1).await.unwrap();
|
||||
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
|
||||
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
|
||||
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
|
||||
|
||||
tracker.drain_cleanup_queue().await;
|
||||
|
||||
assert_eq!(tracker.get_active_ip_count("dup-cleanup").await, 0);
|
||||
assert!(
|
||||
tracker.check_and_add("dup-cleanup", ip2).await.is_ok(),
|
||||
"extra queued cleanup entries must not leave user stuck in denied state"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("poison-stress", 1).await;
|
||||
let ip_primary = ip_from_idx(7201);
|
||||
let ip_alt = ip_from_idx(7202);
|
||||
|
||||
tracker.check_and_add("poison-stress", ip_primary).await.unwrap();
|
||||
|
||||
for _ in 0..64 {
|
||||
let _ = std::panic::catch_unwind(|| {
|
||||
let _guard = tracker.cleanup_queue.lock().unwrap();
|
||||
panic!("intentional queue poison in stress loop");
|
||||
});
|
||||
|
||||
tracker.enqueue_cleanup("poison-stress".to_string(), ip_primary);
|
||||
|
||||
assert!(
|
||||
tracker.check_and_add("poison-stress", ip_alt).await.is_ok(),
|
||||
"poison recovery must preserve admission progress under repeated queue poisoning"
|
||||
);
|
||||
|
||||
tracker.remove_ip("poison-stress", ip_alt).await;
|
||||
tracker.check_and_add("poison-stress", ip_primary).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
@ -10,6 +10,16 @@ use crate::transport::middle_proxy::{
|
|||
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
|
||||
};
|
||||
|
||||
pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf {
|
||||
let raw = PathBuf::from(config_path_cli);
|
||||
let absolute = if raw.is_absolute() {
|
||||
raw
|
||||
} else {
|
||||
startup_cwd.join(raw)
|
||||
};
|
||||
absolute.canonicalize().unwrap_or(absolute)
|
||||
}
|
||||
|
||||
pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
|
||||
let mut config_path = "config.toml".to_string();
|
||||
let mut data_path: Option<PathBuf> = None;
|
||||
|
|
@ -96,6 +106,44 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
|
|||
(config_path, data_path, silent, log_level)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::resolve_runtime_config_path;
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_{nonce}"));
|
||||
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||
let target = startup_cwd.join("config.toml");
|
||||
std::fs::write(&target, " ").unwrap();
|
||||
|
||||
let resolved = resolve_runtime_config_path("config.toml", &startup_cwd);
|
||||
assert_eq!(resolved, target.canonicalize().unwrap());
|
||||
|
||||
let _ = std::fs::remove_file(&target);
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_config_path_keeps_absolute_for_missing_file() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_missing_{nonce}"));
|
||||
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||
|
||||
let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd);
|
||||
assert_eq!(resolved, startup_cwd.join("missing.toml"));
|
||||
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
|
||||
info!(target: "telemt::links", "--- Proxy Links ({}) ---", host);
|
||||
for user_name in config.general.links.show.resolve_users(&config.access.users) {
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ use crate::startup::{
|
|||
use crate::stream::BufferPool;
|
||||
use crate::transport::middle_proxy::MePool;
|
||||
use crate::transport::UpstreamManager;
|
||||
use helpers::parse_cli;
|
||||
use helpers::{parse_cli, resolve_runtime_config_path};
|
||||
|
||||
/// Runs the full telemt runtime startup pipeline and blocks until shutdown.
|
||||
pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
|
|
@ -58,18 +58,26 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||
startup_tracker
|
||||
.start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string()))
|
||||
.await;
|
||||
let (config_path, data_path, cli_silent, cli_log_level) = parse_cli();
|
||||
let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli();
|
||||
let startup_cwd = match std::env::current_dir() {
|
||||
Ok(cwd) => cwd,
|
||||
Err(e) => {
|
||||
eprintln!("[telemt] Can't read current_dir: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd);
|
||||
|
||||
let mut config = match ProxyConfig::load(&config_path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
if std::path::Path::new(&config_path).exists() {
|
||||
if config_path.exists() {
|
||||
eprintln!("[telemt] Error: {}", e);
|
||||
std::process::exit(1);
|
||||
} else {
|
||||
let default = ProxyConfig::default();
|
||||
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
|
||||
eprintln!("[telemt] Created default config at {}", config_path);
|
||||
eprintln!("[telemt] Created default config at {}", config_path.display());
|
||||
default
|
||||
}
|
||||
}
|
||||
|
|
@ -258,7 +266,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||
let route_runtime_api = route_runtime.clone();
|
||||
let config_rx_api = api_config_rx.clone();
|
||||
let admission_rx_api = admission_rx.clone();
|
||||
let config_path_api = std::path::PathBuf::from(&config_path);
|
||||
let config_path_api = config_path.clone();
|
||||
let startup_tracker_api = startup_tracker.clone();
|
||||
let detected_ips_rx_api = detected_ips_rx.clone();
|
||||
tokio::spawn(async move {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use std::net::IpAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
|
@ -32,7 +32,7 @@ pub(crate) struct RuntimeWatches {
|
|||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn spawn_runtime_tasks(
|
||||
config: &Arc<ProxyConfig>,
|
||||
config_path: &str,
|
||||
config_path: &Path,
|
||||
probe: &NetworkProbe,
|
||||
prefer_ipv6: bool,
|
||||
decision_ipv4_dc: bool,
|
||||
|
|
@ -83,7 +83,7 @@ pub(crate) async fn spawn_runtime_tasks(
|
|||
watch::Receiver<Arc<ProxyConfig>>,
|
||||
watch::Receiver<LogLevel>,
|
||||
) = spawn_config_watcher(
|
||||
PathBuf::from(config_path),
|
||||
config_path.to_path_buf(),
|
||||
config.clone(),
|
||||
detected_ip_v4,
|
||||
detected_ip_v6,
|
||||
|
|
@ -279,11 +279,32 @@ pub(crate) async fn spawn_metrics_if_configured(
|
|||
ip_tracker: Arc<UserIpTracker>,
|
||||
config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
||||
) {
|
||||
if let Some(port) = config.server.metrics_port {
|
||||
// metrics_listen takes precedence; fall back to metrics_port for backward compat.
|
||||
let metrics_target: Option<(u16, Option<String>)> =
|
||||
if let Some(ref listen) = config.server.metrics_listen {
|
||||
match listen.parse::<std::net::SocketAddr>() {
|
||||
Ok(addr) => Some((addr.port(), Some(listen.clone()))),
|
||||
Err(e) => {
|
||||
startup_tracker
|
||||
.skip_component(
|
||||
COMPONENT_METRICS_START,
|
||||
Some(format!("invalid metrics_listen \"{}\": {}", listen, e)),
|
||||
)
|
||||
.await;
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
config.server.metrics_port.map(|p| (p, None))
|
||||
};
|
||||
|
||||
if let Some((port, listen)) = metrics_target {
|
||||
let fallback_label = format!("port {}", port);
|
||||
let label = listen.as_deref().unwrap_or(&fallback_label);
|
||||
startup_tracker
|
||||
.start_component(
|
||||
COMPONENT_METRICS_START,
|
||||
Some(format!("spawn metrics endpoint on {}", port)),
|
||||
Some(format!("spawn metrics endpoint on {}", label)),
|
||||
)
|
||||
.await;
|
||||
let stats = stats.clone();
|
||||
|
|
@ -294,6 +315,7 @@ pub(crate) async fn spawn_metrics_if_configured(
|
|||
tokio::spawn(async move {
|
||||
metrics::serve(
|
||||
port,
|
||||
listen,
|
||||
stats,
|
||||
beobachten,
|
||||
ip_tracker_metrics,
|
||||
|
|
@ -308,7 +330,7 @@ pub(crate) async fn spawn_metrics_if_configured(
|
|||
Some("metrics task spawned".to_string()),
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
} else if config.server.metrics_listen.is_none() {
|
||||
startup_tracker
|
||||
.skip_component(
|
||||
COMPONENT_METRICS_START,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ mod config;
|
|||
mod crypto;
|
||||
mod error;
|
||||
mod ip_tracker;
|
||||
#[cfg(test)]
|
||||
mod ip_tracker_regression_tests;
|
||||
mod maestro;
|
||||
mod metrics;
|
||||
mod network;
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ use crate::transport::{ListenOptions, create_listener};
|
|||
|
||||
pub async fn serve(
|
||||
port: u16,
|
||||
listen: Option<String>,
|
||||
stats: Arc<Stats>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
|
|
@ -28,6 +29,33 @@ pub async fn serve(
|
|||
whitelist: Vec<IpNetwork>,
|
||||
) {
|
||||
let whitelist = Arc::new(whitelist);
|
||||
|
||||
// If `metrics_listen` is set, bind on that single address only.
|
||||
if let Some(ref listen_addr) = listen {
|
||||
let addr: SocketAddr = match listen_addr.parse() {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Invalid metrics_listen address: {}", listen_addr);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let is_ipv6 = addr.is_ipv6();
|
||||
match bind_metrics_listener(addr, is_ipv6) {
|
||||
Ok(listener) => {
|
||||
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr);
|
||||
serve_listener(
|
||||
listener, stats, beobachten, ip_tracker, config_rx, whitelist,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to bind metrics on {}", addr);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback: bind on 0.0.0.0 and [::] using metrics_port.
|
||||
let mut listener_v4 = None;
|
||||
let mut listener_v6 = None;
|
||||
|
||||
|
|
|
|||
|
|
@ -11,9 +11,8 @@ use crate::crypto::{sha256_hmac, SecureRandom};
|
|||
use crate::error::ProxyError;
|
||||
use super::constants::*;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::One;
|
||||
use subtle::ConstantTimeEq;
|
||||
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
|
||||
|
||||
// ============= Public Constants =============
|
||||
|
||||
|
|
@ -27,10 +26,17 @@ pub const TLS_DIGEST_POS: usize = 11;
|
|||
pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
||||
|
||||
/// Time skew limits for anti-replay (in seconds)
|
||||
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
||||
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
||||
///
|
||||
/// The default window is intentionally narrow to reduce replay acceptance.
|
||||
/// Operators with known clock-drifted clients should tune deployment config
|
||||
/// (for example replay-window policy) to match their environment.
|
||||
pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before
|
||||
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
|
||||
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
|
||||
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
|
||||
/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance
|
||||
/// windows when replay TTL is configured very large.
|
||||
pub const BOOT_TIME_COMPAT_MAX_SECS: u32 = 2 * 60;
|
||||
|
||||
// ============= Private Constants =============
|
||||
|
||||
|
|
@ -63,6 +69,7 @@ pub struct TlsValidation {
|
|||
/// Client digest for response generation
|
||||
pub digest: [u8; TLS_DIGEST_LEN],
|
||||
/// Timestamp extracted from digest
|
||||
|
||||
pub timestamp: u32,
|
||||
}
|
||||
|
||||
|
|
@ -117,28 +124,8 @@ impl TlsExtensionBuilder {
|
|||
self
|
||||
}
|
||||
|
||||
/// Add ALPN extension with a single selected protocol.
|
||||
fn add_alpn(&mut self, proto: &[u8]) -> &mut Self {
|
||||
// Extension type: ALPN (0x0010)
|
||||
self.extensions.extend_from_slice(&extension_type::ALPN.to_be_bytes());
|
||||
|
||||
// ALPN extension format:
|
||||
// extension_data length (2 bytes)
|
||||
// protocols length (2 bytes)
|
||||
// protocol name length (1 byte)
|
||||
// protocol name bytes
|
||||
let proto_len = proto.len() as u8;
|
||||
let list_len: u16 = 1 + u16::from(proto_len);
|
||||
let ext_len: u16 = 2 + list_len;
|
||||
|
||||
self.extensions.extend_from_slice(&ext_len.to_be_bytes());
|
||||
self.extensions.extend_from_slice(&list_len.to_be_bytes());
|
||||
self.extensions.push(proto_len);
|
||||
self.extensions.extend_from_slice(proto);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build final extensions with length prefix
|
||||
|
||||
fn build(self) -> Vec<u8> {
|
||||
let mut result = Vec::with_capacity(2 + self.extensions.len());
|
||||
|
||||
|
|
@ -153,7 +140,7 @@ impl TlsExtensionBuilder {
|
|||
}
|
||||
|
||||
/// Get current extensions without length prefix (for calculation)
|
||||
#[allow(dead_code)]
|
||||
|
||||
fn as_bytes(&self) -> &[u8] {
|
||||
&self.extensions
|
||||
}
|
||||
|
|
@ -173,8 +160,6 @@ struct ServerHelloBuilder {
|
|||
compression: u8,
|
||||
/// Extensions
|
||||
extensions: TlsExtensionBuilder,
|
||||
/// Selected ALPN protocol (if any)
|
||||
alpn: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl ServerHelloBuilder {
|
||||
|
|
@ -185,7 +170,6 @@ impl ServerHelloBuilder {
|
|||
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
|
||||
compression: 0x00,
|
||||
extensions: TlsExtensionBuilder::new(),
|
||||
alpn: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -200,18 +184,9 @@ impl ServerHelloBuilder {
|
|||
self
|
||||
}
|
||||
|
||||
fn with_alpn(mut self, proto: Option<Vec<u8>>) -> Self {
|
||||
self.alpn = proto;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build ServerHello message (without record header)
|
||||
fn build_message(&self) -> Vec<u8> {
|
||||
let mut ext_builder = self.extensions.clone();
|
||||
if let Some(ref alpn) = self.alpn {
|
||||
ext_builder.add_alpn(alpn);
|
||||
}
|
||||
let extensions = ext_builder.extensions.clone();
|
||||
let extensions = self.extensions.extensions.clone();
|
||||
let extensions_len = extensions.len() as u16;
|
||||
|
||||
// Calculate total length
|
||||
|
|
@ -281,6 +256,7 @@ impl ServerHelloBuilder {
|
|||
/// Returns validation result if a matching user is found.
|
||||
/// The result **must** be used — ignoring it silently bypasses authentication.
|
||||
#[must_use]
|
||||
|
||||
pub fn validate_tls_handshake(
|
||||
handshake: &[u8],
|
||||
secrets: &[(String, Vec<u8>)],
|
||||
|
|
@ -296,9 +272,9 @@ pub fn validate_tls_handshake(
|
|||
|
||||
/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL.
|
||||
///
|
||||
/// A boot-time timestamp is only accepted when it falls below both
|
||||
/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp
|
||||
/// reuse outside replay cache coverage.
|
||||
/// A boot-time timestamp is only accepted when it falls below all three
|
||||
/// bounds: `BOOT_TIME_MAX_SECS`, configured replay window, and
|
||||
/// `BOOT_TIME_COMPAT_MAX_SECS`, preventing oversized compatibility windows.
|
||||
#[must_use]
|
||||
pub fn validate_tls_handshake_with_replay_window(
|
||||
handshake: &[u8],
|
||||
|
|
@ -316,7 +292,16 @@ pub fn validate_tls_handshake_with_replay_window(
|
|||
};
|
||||
|
||||
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
||||
let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32);
|
||||
// Boot-time bypass and ignore_time_skew serve different compatibility paths.
|
||||
// When skew checks are disabled, force boot-time cap to zero to prevent
|
||||
// accidental future coupling of boot-time logic into the ignore-skew path.
|
||||
let boot_time_cap_secs = if ignore_time_skew {
|
||||
0
|
||||
} else {
|
||||
BOOT_TIME_MAX_SECS
|
||||
.min(replay_window_u32)
|
||||
.min(BOOT_TIME_COMPAT_MAX_SECS)
|
||||
};
|
||||
|
||||
validate_tls_handshake_at_time_with_boot_cap(
|
||||
handshake,
|
||||
|
|
@ -335,6 +320,7 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
|
|||
i64::try_from(d.as_secs()).ok()
|
||||
}
|
||||
|
||||
|
||||
fn validate_tls_handshake_at_time(
|
||||
handshake: &[u8],
|
||||
secrets: &[(String, Vec<u8>)],
|
||||
|
|
@ -369,6 +355,9 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
|||
// Extract session ID
|
||||
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
|
||||
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
|
||||
if session_id_len > 32 {
|
||||
return None;
|
||||
}
|
||||
let session_id_start = session_id_len_pos + 1;
|
||||
|
||||
if handshake.len() < session_id_start + session_id_len {
|
||||
|
|
@ -381,7 +370,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
|||
let mut msg = handshake.to_vec();
|
||||
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
|
||||
|
||||
let mut first_match: Option<TlsValidation> = None;
|
||||
let mut first_match: Option<(&String, u32)> = None;
|
||||
|
||||
for (user, secret) in secrets {
|
||||
let computed = sha256_hmac(secret, &msg);
|
||||
|
|
@ -411,7 +400,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
|||
if !ignore_time_skew {
|
||||
// Allow very small timestamps (boot time instead of unix time)
|
||||
// This is a quirk in some clients that use uptime instead of real time
|
||||
let is_boot_time = timestamp < boot_time_cap_secs;
|
||||
let is_boot_time = boot_time_cap_secs > 0 && timestamp < boot_time_cap_secs;
|
||||
if !is_boot_time {
|
||||
let time_diff = now - i64::from(timestamp);
|
||||
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
|
||||
|
|
@ -421,39 +410,26 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
|||
}
|
||||
|
||||
if first_match.is_none() {
|
||||
first_match = Some(TlsValidation {
|
||||
user: user.clone(),
|
||||
session_id: session_id.clone(),
|
||||
digest,
|
||||
timestamp,
|
||||
});
|
||||
first_match = Some((user, timestamp));
|
||||
}
|
||||
}
|
||||
|
||||
first_match
|
||||
}
|
||||
|
||||
fn curve25519_prime() -> BigUint {
|
||||
(BigUint::one() << 255) - BigUint::from(19u32)
|
||||
first_match.map(|(user, timestamp)| TlsValidation {
|
||||
user: user.clone(),
|
||||
session_id,
|
||||
digest,
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a fake X25519 public key for TLS
|
||||
///
|
||||
/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p,
|
||||
/// which matches Python/C behavior and avoids DPI fingerprinting.
|
||||
/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint,
|
||||
/// yielding distribution-consistent public keys for anti-fingerprinting.
|
||||
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
|
||||
let mut n_bytes = [0u8; 32];
|
||||
n_bytes.copy_from_slice(&rng.bytes(32));
|
||||
|
||||
let n = BigUint::from_bytes_le(&n_bytes);
|
||||
let p = curve25519_prime();
|
||||
let pk = (&n * &n) % &p;
|
||||
|
||||
let mut out = pk.to_bytes_le();
|
||||
out.resize(32, 0);
|
||||
let mut result = [0u8; 32];
|
||||
result.copy_from_slice(&out[..32]);
|
||||
result
|
||||
let mut scalar = [0u8; 32];
|
||||
scalar.copy_from_slice(&rng.bytes(32));
|
||||
x25519(scalar, X25519_BASEPOINT_BYTES)
|
||||
}
|
||||
|
||||
/// Build TLS ServerHello response
|
||||
|
|
@ -482,7 +458,6 @@ pub fn build_server_hello(
|
|||
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
|
||||
.with_x25519_key(&x25519_key)
|
||||
.with_tls13_version()
|
||||
.with_alpn(alpn)
|
||||
.build_record();
|
||||
|
||||
// Build Change Cipher Spec record
|
||||
|
|
@ -493,8 +468,27 @@ pub fn build_server_hello(
|
|||
0x01, // CCS byte
|
||||
];
|
||||
|
||||
// Build fake certificate (Application Data record)
|
||||
let fake_cert = rng.bytes(fake_cert_len);
|
||||
// Build first encrypted flight mimic as opaque ApplicationData bytes.
|
||||
// Embed a compact EncryptedExtensions-like ALPN block when selected.
|
||||
let mut fake_cert = Vec::with_capacity(fake_cert_len);
|
||||
if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) {
|
||||
let proto_list_len = 1usize + proto.len();
|
||||
let ext_data_len = 2usize + proto_list_len;
|
||||
let marker_len = 4usize + ext_data_len;
|
||||
if marker_len <= fake_cert_len {
|
||||
fake_cert.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||
fake_cert.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
|
||||
fake_cert.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
|
||||
fake_cert.push(proto.len() as u8);
|
||||
fake_cert.extend_from_slice(proto);
|
||||
}
|
||||
}
|
||||
if fake_cert.len() < fake_cert_len {
|
||||
fake_cert.extend_from_slice(&rng.bytes(fake_cert_len - fake_cert.len()));
|
||||
} else if fake_cert.len() > fake_cert_len {
|
||||
fake_cert.truncate(fake_cert_len);
|
||||
}
|
||||
|
||||
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
|
||||
app_data_record.push(TLS_RECORD_APPLICATION);
|
||||
app_data_record.extend_from_slice(&TLS_VERSION);
|
||||
|
|
@ -506,8 +500,9 @@ pub fn build_server_hello(
|
|||
// Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted;
|
||||
// here we mimic with opaque ApplicationData records of plausible size).
|
||||
let mut tickets = Vec::new();
|
||||
if new_session_tickets > 0 {
|
||||
for _ in 0..new_session_tickets {
|
||||
let ticket_count = new_session_tickets.min(4);
|
||||
if ticket_count > 0 {
|
||||
for _ in 0..ticket_count {
|
||||
let ticket_len: usize = rng.range(48) + 48; // 48-95 bytes
|
||||
let mut record = Vec::with_capacity(5 + ticket_len);
|
||||
record.push(TLS_RECORD_APPLICATION);
|
||||
|
|
@ -705,13 +700,14 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
|
|||
return false;
|
||||
}
|
||||
|
||||
// TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0)
|
||||
// TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303.
|
||||
first_bytes[0] == TLS_RECORD_HANDSHAKE
|
||||
&& first_bytes[1] == 0x03
|
||||
&& first_bytes[2] == 0x01
|
||||
&& (first_bytes[2] == 0x01 || first_bytes[2] == 0x03)
|
||||
}
|
||||
|
||||
/// Parse TLS record header, returns (record_type, length)
|
||||
|
||||
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
|
||||
let record_type = header[0];
|
||||
let version = [header[1], header[2]];
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -24,6 +24,47 @@ enum HandshakeOutcome {
|
|||
Handled,
|
||||
}
|
||||
|
||||
#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"]
|
||||
struct UserConnectionReservation {
|
||||
stats: Arc<Stats>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
user: String,
|
||||
ip: IpAddr,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
impl UserConnectionReservation {
|
||||
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
|
||||
Self {
|
||||
stats,
|
||||
ip_tracker,
|
||||
user,
|
||||
ip,
|
||||
active: true,
|
||||
}
|
||||
}
|
||||
|
||||
async fn release(mut self) {
|
||||
if !self.active {
|
||||
return;
|
||||
}
|
||||
self.ip_tracker.remove_ip(&self.user, self.ip).await;
|
||||
self.active = false;
|
||||
self.stats.decrement_user_curr_connects(&self.user);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UserConnectionReservation {
|
||||
fn drop(&mut self) {
|
||||
if !self.active {
|
||||
return;
|
||||
}
|
||||
self.active = false;
|
||||
self.stats.decrement_user_curr_connects(&self.user);
|
||||
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
|
||||
}
|
||||
}
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
|
||||
|
|
@ -45,7 +86,19 @@ use crate::proxy::middle_relay::handle_via_middle_proxy;
|
|||
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||
|
||||
fn beobachten_ttl(config: &ProxyConfig) -> Duration {
|
||||
Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60))
|
||||
let minutes = config.general.beobachten_minutes;
|
||||
if minutes == 0 {
|
||||
static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock<AtomicBool> = OnceLock::new();
|
||||
let warned = BEOBACHTEN_ZERO_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false));
|
||||
if !warned.swap(true, Ordering::Relaxed) {
|
||||
warn!(
|
||||
"general.beobachten_minutes=0 is insecure because entries expire immediately; forcing minimum TTL to 1 minute"
|
||||
);
|
||||
}
|
||||
return Duration::from_secs(60);
|
||||
}
|
||||
|
||||
Duration::from_secs(minutes.saturating_mul(60))
|
||||
}
|
||||
|
||||
fn record_beobachten_class(
|
||||
|
|
@ -90,6 +143,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
|
|||
trusted.iter().any(|cidr| cidr.contains(peer_ip))
|
||||
}
|
||||
|
||||
fn synthetic_local_addr(port: u16) -> SocketAddr {
|
||||
SocketAddr::from(([0, 0, 0, 0], port))
|
||||
}
|
||||
|
||||
pub async fn handle_client_stream<S>(
|
||||
mut stream: S,
|
||||
peer: SocketAddr,
|
||||
|
|
@ -113,9 +170,7 @@ where
|
|||
let mut real_peer = normalize_ip(peer);
|
||||
|
||||
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
|
||||
let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
|
||||
.parse()
|
||||
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
|
||||
let mut local_addr = synthetic_local_addr(config.server.port);
|
||||
|
||||
if proxy_protocol_enabled {
|
||||
let proxy_header_timeout = Duration::from_millis(
|
||||
|
|
@ -245,7 +300,7 @@ where
|
|||
handle_bad_client(
|
||||
reader,
|
||||
writer,
|
||||
&mtproto_handshake,
|
||||
&handshake,
|
||||
real_peer,
|
||||
local_addr,
|
||||
&config,
|
||||
|
|
@ -426,7 +481,6 @@ impl RunningClientHandler {
|
|||
pub async fn run(self) -> Result<()> {
|
||||
self.stats.increment_connects_all();
|
||||
let peer = self.peer;
|
||||
let _ip_tracker = self.ip_tracker.clone();
|
||||
debug!(peer = %peer, "New connection");
|
||||
|
||||
if let Err(e) = configure_client_socket(
|
||||
|
|
@ -557,7 +611,6 @@ impl RunningClientHandler {
|
|||
|
||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||
let peer = self.peer;
|
||||
let _ip_tracker = self.ip_tracker.clone();
|
||||
|
||||
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
||||
|
||||
|
|
@ -570,7 +623,6 @@ impl RunningClientHandler {
|
|||
|
||||
async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
|
||||
let peer = self.peer;
|
||||
let _ip_tracker = self.ip_tracker.clone();
|
||||
|
||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||
|
||||
|
|
@ -661,7 +713,7 @@ impl RunningClientHandler {
|
|||
handle_bad_client(
|
||||
reader,
|
||||
writer,
|
||||
&mtproto_handshake,
|
||||
&handshake,
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
|
|
@ -694,7 +746,6 @@ impl RunningClientHandler {
|
|||
|
||||
async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
|
||||
let peer = self.peer;
|
||||
let _ip_tracker = self.ip_tracker.clone();
|
||||
|
||||
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||
|
|
@ -798,10 +849,22 @@ impl RunningClientHandler {
|
|||
{
|
||||
let user = success.user.clone();
|
||||
|
||||
if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await {
|
||||
warn!(user = %user, error = %e, "User limit exceeded");
|
||||
return Err(e);
|
||||
}
|
||||
let user_limit_reservation =
|
||||
match Self::acquire_user_connection_reservation_static(
|
||||
&user,
|
||||
&config,
|
||||
stats.clone(),
|
||||
peer_addr,
|
||||
ip_tracker,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(reservation) => reservation,
|
||||
Err(e) => {
|
||||
warn!(user = %user, error = %e, "User admission check failed");
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let route_snapshot = route_runtime.snapshot();
|
||||
let session_id = rng.u64();
|
||||
|
|
@ -858,15 +921,68 @@ impl RunningClientHandler {
|
|||
)
|
||||
.await
|
||||
};
|
||||
|
||||
stats.decrement_user_curr_connects(&user);
|
||||
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
|
||||
user_limit_reservation.release().await;
|
||||
relay_result
|
||||
}
|
||||
|
||||
async fn acquire_user_connection_reservation_static(
|
||||
user: &str,
|
||||
config: &ProxyConfig,
|
||||
stats: Arc<Stats>,
|
||||
peer_addr: SocketAddr,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
) -> Result<UserConnectionReservation> {
|
||||
if let Some(expiration) = config.access.user_expirations.get(user)
|
||||
&& chrono::Utc::now() > *expiration
|
||||
{
|
||||
return Err(ProxyError::UserExpired {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||
&& stats.get_user_total_octets(user) >= *quota
|
||||
{
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
|
||||
if !stats.try_acquire_user_curr_connects(user, limit) {
|
||||
return Err(ProxyError::ConnectionLimitExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||
Ok(()) => {}
|
||||
Err(reason) => {
|
||||
stats.decrement_user_curr_connects(user);
|
||||
warn!(
|
||||
user = %user,
|
||||
ip = %peer_addr.ip(),
|
||||
reason = %reason,
|
||||
"IP limit exceeded"
|
||||
);
|
||||
return Err(ProxyError::ConnectionLimitExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(UserConnectionReservation::new(
|
||||
stats,
|
||||
ip_tracker,
|
||||
user.to_string(),
|
||||
peer_addr.ip(),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn check_user_limits_static(
|
||||
user: &str,
|
||||
config: &ProxyConfig,
|
||||
user: &str,
|
||||
config: &ProxyConfig,
|
||||
stats: &Stats,
|
||||
peer_addr: SocketAddr,
|
||||
ip_tracker: &UserIpTracker,
|
||||
|
|
@ -899,7 +1015,10 @@ impl RunningClientHandler {
|
|||
}
|
||||
|
||||
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||
Ok(()) => {}
|
||||
Ok(()) => {
|
||||
ip_tracker.remove_ip(user, peer_addr.ip()).await;
|
||||
stats.decrement_user_curr_connects(user);
|
||||
}
|
||||
Err(reason) => {
|
||||
stats.decrement_user_curr_connects(user);
|
||||
warn!(
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,8 @@
|
|||
use std::ffi::OsString;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
|
@ -24,14 +26,44 @@ use crate::stats::Stats;
|
|||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||
use crate::transport::UpstreamManager;
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
|
||||
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
|
||||
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
|
||||
const MAX_SCOPE_HINT_LEN: usize = 64;
|
||||
|
||||
fn validated_scope_hint(user: &str) -> Option<&str> {
|
||||
let scope = user.strip_prefix("scope_")?;
|
||||
if scope.is_empty() || scope.len() > MAX_SCOPE_HINT_LEN {
|
||||
return None;
|
||||
}
|
||||
if scope
|
||||
.bytes()
|
||||
.all(|b| b.is_ascii_alphanumeric() || b == b'-')
|
||||
{
|
||||
Some(scope)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SanitizedUnknownDcLogPath {
|
||||
resolved_path: PathBuf,
|
||||
allowed_parent: PathBuf,
|
||||
file_name: OsString,
|
||||
}
|
||||
|
||||
// In tests, this function shares global mutable state. Callers that also use
|
||||
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
|
||||
// deterministic under parallel execution.
|
||||
fn should_log_unknown_dc(dc_idx: i16) -> bool {
|
||||
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
|
||||
should_log_unknown_dc_with_set(set, dc_idx)
|
||||
}
|
||||
|
||||
fn should_log_unknown_dc_with_set(set: &Mutex<HashSet<i16>>, dc_idx: i16) -> bool {
|
||||
match set.lock() {
|
||||
Ok(mut guard) => {
|
||||
if guard.contains(&dc_idx) {
|
||||
|
|
@ -42,9 +74,85 @@ fn should_log_unknown_dc(dc_idx: i16) -> bool {
|
|||
}
|
||||
guard.insert(dc_idx)
|
||||
}
|
||||
// If the lock is poisoned, keep logging rather than silently dropping
|
||||
// operator-visible diagnostics.
|
||||
Err(_) => true,
|
||||
// Fail closed on poisoned state to avoid unbounded blocking log writes.
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn sanitize_unknown_dc_log_path(path: &str) -> Option<SanitizedUnknownDcLogPath> {
|
||||
let candidate = Path::new(path);
|
||||
if candidate.as_os_str().is_empty() {
|
||||
return None;
|
||||
}
|
||||
if candidate
|
||||
.components()
|
||||
.any(|component| matches!(component, Component::ParentDir))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let cwd = std::env::current_dir().ok()?;
|
||||
let file_name = candidate.file_name()?;
|
||||
let parent = candidate.parent().unwrap_or_else(|| Path::new("."));
|
||||
let parent_path = if parent.is_absolute() {
|
||||
parent.to_path_buf()
|
||||
} else {
|
||||
cwd.join(parent)
|
||||
};
|
||||
let canonical_parent = parent_path.canonicalize().ok()?;
|
||||
if !canonical_parent.is_dir() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(SanitizedUnknownDcLogPath {
|
||||
resolved_path: canonical_parent.join(file_name),
|
||||
allowed_parent: canonical_parent,
|
||||
file_name: file_name.to_os_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
|
||||
let Some(parent) = path.resolved_path.parent() else {
|
||||
return false;
|
||||
};
|
||||
let Ok(current_parent) = parent.canonicalize() else {
|
||||
return false;
|
||||
};
|
||||
if current_parent != path.allowed_parent {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Ok(canonical_target) = path.resolved_path.canonicalize() {
|
||||
let Some(target_parent) = canonical_target.parent() else {
|
||||
return false;
|
||||
};
|
||||
let Some(target_name) = canonical_target.file_name() else {
|
||||
return false;
|
||||
};
|
||||
if target_parent != path.allowed_parent || target_name != path.file_name {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.custom_flags(libc::O_NOFOLLOW)
|
||||
.open(path)
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
let _ = path;
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::PermissionDenied,
|
||||
"unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -93,8 +201,15 @@ where
|
|||
"Connecting to Telegram DC"
|
||||
);
|
||||
|
||||
let scope_hint = validated_scope_hint(user);
|
||||
if user.starts_with("scope_") && scope_hint.is_none() {
|
||||
warn!(
|
||||
user = %user,
|
||||
"Ignoring invalid scope hint and falling back to default upstream selection"
|
||||
);
|
||||
}
|
||||
let tg_stream = upstream_manager
|
||||
.connect(dc_addr, Some(success.dc_idx), user.strip_prefix("scope_").filter(|s| !s.is_empty()))
|
||||
.connect(dc_addr, Some(success.dc_idx), scope_hint)
|
||||
.await?;
|
||||
|
||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
|
||||
|
|
@ -105,7 +220,7 @@ where
|
|||
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||
|
||||
stats.increment_user_connects(user);
|
||||
stats.increment_current_connections_direct();
|
||||
let _direct_connection_lease = stats.acquire_direct_connection_lease();
|
||||
|
||||
let relay_result = relay_bidirectional(
|
||||
client_reader,
|
||||
|
|
@ -116,6 +231,7 @@ where
|
|||
config.general.direct_relay_copy_buf_s2c_bytes,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
config.access.user_data_quota.get(user).copied(),
|
||||
buffer_pool,
|
||||
);
|
||||
tokio::pin!(relay_result);
|
||||
|
|
@ -148,8 +264,6 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
stats.decrement_current_connections_direct();
|
||||
|
||||
match &relay_result {
|
||||
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
||||
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
|
||||
|
|
@ -199,15 +313,21 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|||
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
|
||||
if config.general.unknown_dc_file_log_enabled
|
||||
&& let Some(path) = &config.general.unknown_dc_log_path
|
||||
&& should_log_unknown_dc(dc_idx)
|
||||
&& let Ok(handle) = tokio::runtime::Handle::try_current()
|
||||
{
|
||||
let path = path.clone();
|
||||
handle.spawn_blocking(move || {
|
||||
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
|
||||
let _ = writeln!(file, "dc_idx={dc_idx}");
|
||||
if let Some(path) = sanitize_unknown_dc_log_path(path) {
|
||||
if should_log_unknown_dc(dc_idx) {
|
||||
handle.spawn_blocking(move || {
|
||||
if unknown_dc_log_path_is_still_safe(&path)
|
||||
&& let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path)
|
||||
{
|
||||
let _ = writeln!(file, "dc_idx={dc_idx}");
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
} else {
|
||||
warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -4,14 +4,17 @@
|
|||
|
||||
use std::net::SocketAddr;
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::collections::hash_map::RandomState;
|
||||
use std::net::{IpAddr, Ipv6Addr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::hash::{BuildHasher, Hash, Hasher};
|
||||
use std::time::{Duration, Instant};
|
||||
use dashmap::DashMap;
|
||||
use dashmap::mapref::entry::Entry;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tracing::{debug, warn, trace};
|
||||
use zeroize::Zeroize;
|
||||
use zeroize::{Zeroize, Zeroizing};
|
||||
|
||||
use crate::crypto::{sha256, AesCtr, SecureRandom};
|
||||
use rand::Rng;
|
||||
|
|
@ -25,6 +28,10 @@ use crate::tls_front::{TlsFrontCache, emulator};
|
|||
|
||||
const ACCESS_SECRET_BYTES: usize = 16;
|
||||
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
|
||||
#[cfg(test)]
|
||||
const WARNED_SECRET_MAX_ENTRIES: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const WARNED_SECRET_MAX_ENTRIES: usize = 1_024;
|
||||
|
||||
const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60;
|
||||
#[cfg(test)]
|
||||
|
|
@ -33,6 +40,7 @@ const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256;
|
|||
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
|
||||
const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024;
|
||||
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
|
||||
const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2;
|
||||
|
||||
#[cfg(test)]
|
||||
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
|
||||
|
|
@ -51,12 +59,35 @@ struct AuthProbeState {
|
|||
last_seen: Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct AuthProbeSaturationState {
|
||||
fail_streak: u32,
|
||||
blocked_until: Instant,
|
||||
last_seen: Instant,
|
||||
}
|
||||
|
||||
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
|
||||
static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> = OnceLock::new();
|
||||
static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new();
|
||||
|
||||
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
|
||||
AUTH_PROBE_STATE.get_or_init(DashMap::new)
|
||||
}
|
||||
|
||||
fn auth_probe_saturation_state() -> &'static Mutex<Option<AuthProbeSaturationState>> {
|
||||
AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None))
|
||||
}
|
||||
|
||||
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
|
||||
match peer_ip {
|
||||
IpAddr::V4(ip) => IpAddr::V4(ip),
|
||||
IpAddr::V6(ip) => {
|
||||
let [a, b, c, d, _, _, _, _] = ip.segments();
|
||||
IpAddr::V6(Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_probe_backoff(fail_streak: u32) -> Duration {
|
||||
if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
return Duration::ZERO;
|
||||
|
|
@ -74,7 +105,16 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
|
|||
now.duration_since(state.last_seen) > retention
|
||||
}
|
||||
|
||||
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
|
||||
let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new);
|
||||
let mut hasher = hasher_state.build_hasher();
|
||||
peer_ip.hash(&mut hasher);
|
||||
now.hash(&mut hasher);
|
||||
hasher.finish() as usize
|
||||
}
|
||||
|
||||
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||
let state = auth_probe_state_map();
|
||||
let Some(entry) = state.get(&peer_ip) else {
|
||||
return false;
|
||||
|
|
@ -87,7 +127,85 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
|||
now < entry.blocked_until
|
||||
}
|
||||
|
||||
fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool {
|
||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||
let state = auth_probe_state_map();
|
||||
let Some(entry) = state.get(&peer_ip) else {
|
||||
return false;
|
||||
};
|
||||
if auth_probe_state_expired(&entry, now) {
|
||||
drop(entry);
|
||||
state.remove(&peer_ip);
|
||||
return false;
|
||||
}
|
||||
|
||||
entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
|
||||
}
|
||||
|
||||
fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool {
|
||||
if !auth_probe_is_throttled(peer_ip, now) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if !auth_probe_saturation_is_throttled(now) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auth_probe_saturation_grace_exhausted(peer_ip, now)
|
||||
}
|
||||
|
||||
fn auth_probe_saturation_is_throttled(now: Instant) -> bool {
|
||||
let saturation = auth_probe_saturation_state();
|
||||
let mut guard = match saturation.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
let Some(state) = guard.as_mut() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if now.duration_since(state.last_seen) > Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) {
|
||||
*guard = None;
|
||||
return false;
|
||||
}
|
||||
|
||||
if now < state.blocked_until {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn auth_probe_note_saturation(now: Instant) {
|
||||
let saturation = auth_probe_saturation_state();
|
||||
let mut guard = match saturation.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
match guard.as_mut() {
|
||||
Some(state)
|
||||
if now.duration_since(state.last_seen)
|
||||
<= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) =>
|
||||
{
|
||||
state.fail_streak = state.fail_streak.saturating_add(1);
|
||||
state.last_seen = now;
|
||||
state.blocked_until = now + auth_probe_backoff(state.fail_streak);
|
||||
}
|
||||
_ => {
|
||||
let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS;
|
||||
*guard = Some(AuthProbeSaturationState {
|
||||
fail_streak,
|
||||
blocked_until: now + auth_probe_backoff(fail_streak),
|
||||
last_seen: now,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
|
||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||
let state = auth_probe_state_map();
|
||||
auth_probe_record_failure_with_state(state, peer_ip, now);
|
||||
}
|
||||
|
|
@ -97,49 +215,138 @@ fn auth_probe_record_failure_with_state(
|
|||
peer_ip: IpAddr,
|
||||
now: Instant,
|
||||
) {
|
||||
if let Some(mut entry) = state.get_mut(&peer_ip) {
|
||||
if auth_probe_state_expired(&entry, now) {
|
||||
*entry = AuthProbeState {
|
||||
fail_streak: 1,
|
||||
blocked_until: now + auth_probe_backoff(1),
|
||||
last_seen: now,
|
||||
};
|
||||
return;
|
||||
}
|
||||
entry.fail_streak = entry.fail_streak.saturating_add(1);
|
||||
entry.last_seen = now;
|
||||
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
|
||||
return;
|
||||
let make_new_state = || AuthProbeState {
|
||||
fail_streak: 1,
|
||||
blocked_until: now + auth_probe_backoff(1),
|
||||
last_seen: now,
|
||||
};
|
||||
|
||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
let mut stale_keys = Vec::new();
|
||||
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
||||
if auth_probe_state_expired(entry.value(), now) {
|
||||
stale_keys.push(*entry.key());
|
||||
}
|
||||
let update_existing = |entry: &mut AuthProbeState| {
|
||||
if auth_probe_state_expired(entry, now) {
|
||||
*entry = make_new_state();
|
||||
} else {
|
||||
entry.fail_streak = entry.fail_streak.saturating_add(1);
|
||||
entry.last_seen = now;
|
||||
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
|
||||
}
|
||||
for stale_key in stale_keys {
|
||||
state.remove(&stale_key);
|
||||
}
|
||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
};
|
||||
|
||||
match state.entry(peer_ip) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
update_existing(entry.get_mut());
|
||||
return;
|
||||
}
|
||||
Entry::Vacant(_) => {}
|
||||
}
|
||||
|
||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
let mut rounds = 0usize;
|
||||
while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
rounds += 1;
|
||||
if rounds > 8 {
|
||||
auth_probe_note_saturation(now);
|
||||
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
|
||||
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
||||
let key = *entry.key();
|
||||
let fail_streak = entry.value().fail_streak;
|
||||
let last_seen = entry.value().last_seen;
|
||||
match eviction_candidate {
|
||||
Some((_, current_fail, current_seen))
|
||||
if fail_streak > current_fail
|
||||
|| (fail_streak == current_fail && last_seen >= current_seen) =>
|
||||
{
|
||||
}
|
||||
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||
}
|
||||
}
|
||||
|
||||
let Some((evict_key, _, _)) = eviction_candidate else {
|
||||
return;
|
||||
};
|
||||
state.remove(&evict_key);
|
||||
break;
|
||||
}
|
||||
|
||||
let mut stale_keys = Vec::new();
|
||||
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
|
||||
let state_len = state.len();
|
||||
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
|
||||
let start_offset = if state_len == 0 {
|
||||
0
|
||||
} else {
|
||||
auth_probe_eviction_offset(peer_ip, now) % state_len
|
||||
};
|
||||
|
||||
let mut scanned = 0usize;
|
||||
for entry in state.iter().skip(start_offset) {
|
||||
let key = *entry.key();
|
||||
let fail_streak = entry.value().fail_streak;
|
||||
let last_seen = entry.value().last_seen;
|
||||
match eviction_candidate {
|
||||
Some((_, current_fail, current_seen))
|
||||
if fail_streak > current_fail
|
||||
|| (fail_streak == current_fail && last_seen >= current_seen) =>
|
||||
{
|
||||
}
|
||||
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||
}
|
||||
if auth_probe_state_expired(entry.value(), now) {
|
||||
stale_keys.push(key);
|
||||
}
|
||||
scanned += 1;
|
||||
if scanned >= scan_limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if scanned < scan_limit {
|
||||
for entry in state.iter().take(scan_limit - scanned) {
|
||||
let key = *entry.key();
|
||||
let fail_streak = entry.value().fail_streak;
|
||||
let last_seen = entry.value().last_seen;
|
||||
match eviction_candidate {
|
||||
Some((_, current_fail, current_seen))
|
||||
if fail_streak > current_fail
|
||||
|| (fail_streak == current_fail && last_seen >= current_seen) =>
|
||||
{
|
||||
}
|
||||
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||
}
|
||||
if auth_probe_state_expired(entry.value(), now) {
|
||||
stale_keys.push(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for stale_key in stale_keys {
|
||||
state.remove(&stale_key);
|
||||
}
|
||||
|
||||
if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
break;
|
||||
}
|
||||
|
||||
let Some((evict_key, _, _)) = eviction_candidate else {
|
||||
auth_probe_note_saturation(now);
|
||||
return;
|
||||
};
|
||||
state.remove(&evict_key);
|
||||
auth_probe_note_saturation(now);
|
||||
}
|
||||
}
|
||||
|
||||
state.insert(peer_ip, AuthProbeState {
|
||||
fail_streak: 0,
|
||||
blocked_until: now,
|
||||
last_seen: now,
|
||||
});
|
||||
|
||||
if let Some(mut entry) = state.get_mut(&peer_ip) {
|
||||
entry.fail_streak = 1;
|
||||
entry.blocked_until = now + auth_probe_backoff(1);
|
||||
match state.entry(peer_ip) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
update_existing(entry.get_mut());
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(make_new_state());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_probe_record_success(peer_ip: IpAddr) {
|
||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||
let state = auth_probe_state_map();
|
||||
state.remove(&peer_ip);
|
||||
}
|
||||
|
|
@ -149,10 +356,16 @@ fn clear_auth_probe_state_for_testing() {
|
|||
if let Some(state) = AUTH_PROBE_STATE.get() {
|
||||
state.clear();
|
||||
}
|
||||
if let Some(saturation) = AUTH_PROBE_SATURATION_STATE.get()
|
||||
&& let Ok(mut guard) = saturation.lock()
|
||||
{
|
||||
*guard = None;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
|
||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||
let state = AUTH_PROBE_STATE.get()?;
|
||||
state.get(&peer_ip).map(|entry| entry.fail_streak)
|
||||
}
|
||||
|
|
@ -162,6 +375,16 @@ fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
|
|||
auth_probe_is_throttled(peer_ip, Instant::now())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn auth_probe_saturation_is_throttled_for_testing() -> bool {
|
||||
auth_probe_saturation_is_throttled(Instant::now())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool {
|
||||
auth_probe_saturation_is_throttled(now)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn auth_probe_test_lock() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
|
@ -177,11 +400,23 @@ fn clear_warned_secrets_for_testing() {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn warned_secrets_test_lock() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
|
||||
let key = (name.to_string(), reason.to_string());
|
||||
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
|
||||
let should_warn = match warned.lock() {
|
||||
Ok(mut guard) => guard.insert(key),
|
||||
Ok(mut guard) => {
|
||||
if !guard.contains(&key) && guard.len() >= WARNED_SECRET_MAX_ENTRIES {
|
||||
false
|
||||
} else {
|
||||
guard.insert(key)
|
||||
}
|
||||
}
|
||||
Err(_) => true,
|
||||
};
|
||||
|
||||
|
|
@ -273,6 +508,24 @@ fn decode_user_secrets(
|
|||
secrets
|
||||
}
|
||||
|
||||
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
|
||||
if config.censorship.server_hello_delay_max_ms == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let min = config.censorship.server_hello_delay_min_ms;
|
||||
let max = config.censorship.server_hello_delay_max_ms.max(min);
|
||||
let delay_ms = if max == min {
|
||||
max
|
||||
} else {
|
||||
rand::rng().random_range(min..=max)
|
||||
};
|
||||
|
||||
if delay_ms > 0 {
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of successful handshake
|
||||
///
|
||||
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
|
||||
|
|
@ -294,6 +547,7 @@ pub struct HandshakeSuccess {
|
|||
/// Client address
|
||||
pub peer: SocketAddr,
|
||||
/// Whether TLS was used
|
||||
|
||||
pub is_tls: bool,
|
||||
}
|
||||
|
||||
|
|
@ -323,17 +577,22 @@ where
|
|||
{
|
||||
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
||||
|
||||
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
|
||||
let throttle_now = Instant::now();
|
||||
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(peer = %peer, "TLS handshake too short");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
let secrets = decode_user_secrets(config, None);
|
||||
let client_sni = tls::extract_sni_from_client_hello(handshake);
|
||||
let secrets = decode_user_secrets(config, client_sni.as_deref());
|
||||
|
||||
let validation = match tls::validate_tls_handshake_with_replay_window(
|
||||
handshake,
|
||||
|
|
@ -344,6 +603,7 @@ where
|
|||
Some(v) => v,
|
||||
None => {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(
|
||||
peer = %peer,
|
||||
ignore_time_skew = config.access.ignore_time_skew,
|
||||
|
|
@ -358,20 +618,24 @@ where
|
|||
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||
if replay_checker.check_and_add_tls_digest(digest_half) {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||
Some((_, s)) => s,
|
||||
None => return HandshakeResult::BadClient { reader, writer },
|
||||
None => {
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
};
|
||||
|
||||
let cached = if config.censorship.tls_emulation {
|
||||
if let Some(cache) = tls_cache.as_ref() {
|
||||
let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) {
|
||||
let selected_domain = if let Some(sni) = client_sni.as_ref() {
|
||||
if cache.contains_domain(&sni).await {
|
||||
sni
|
||||
sni.clone()
|
||||
} else {
|
||||
config.censorship.tls_domain.clone()
|
||||
}
|
||||
|
|
@ -404,6 +668,7 @@ where
|
|||
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
|
||||
Some(b"http/1.1".to_vec())
|
||||
} else if !alpn_list.is_empty() {
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
} else {
|
||||
|
|
@ -436,19 +701,9 @@ where
|
|||
)
|
||||
};
|
||||
|
||||
// Optional anti-fingerprint delay before sending ServerHello.
|
||||
if config.censorship.server_hello_delay_max_ms > 0 {
|
||||
let min = config.censorship.server_hello_delay_min_ms;
|
||||
let max = config.censorship.server_hello_delay_max_ms.max(min);
|
||||
let delay_ms = if max == min {
|
||||
max
|
||||
} else {
|
||||
rand::rng().random_range(min..=max)
|
||||
};
|
||||
if delay_ms > 0 {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
}
|
||||
// Apply the same optional delay budget used by reject paths to reduce
|
||||
// distinguishability between success and fail-closed handshakes.
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
|
||||
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
||||
|
||||
|
|
@ -492,9 +747,19 @@ where
|
|||
R: AsyncRead + Unpin + Send,
|
||||
W: AsyncWrite + Unpin + Send,
|
||||
{
|
||||
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
||||
let handshake_fingerprint = {
|
||||
let digest = sha256(&handshake[..8]);
|
||||
hex::encode(&digest[..4])
|
||||
};
|
||||
trace!(
|
||||
peer = %peer,
|
||||
handshake_fingerprint = %handshake_fingerprint,
|
||||
"MTProto handshake prefix"
|
||||
);
|
||||
|
||||
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
|
||||
let throttle_now = Instant::now();
|
||||
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
|
@ -510,7 +775,7 @@ where
|
|||
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
|
||||
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
|
||||
|
||||
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
|
||||
dec_key_input.extend_from_slice(dec_prekey);
|
||||
dec_key_input.extend_from_slice(&secret);
|
||||
let dec_key = sha256(&dec_key_input);
|
||||
|
|
@ -546,7 +811,7 @@ where
|
|||
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
||||
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
||||
|
||||
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
|
||||
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
|
||||
enc_key_input.extend_from_slice(enc_prekey);
|
||||
enc_key_input.extend_from_slice(&secret);
|
||||
let enc_key = sha256(&enc_key_input);
|
||||
|
|
@ -565,6 +830,7 @@ where
|
|||
// authentication check first to avoid poisoning the replay cache.
|
||||
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
|
@ -601,6 +867,7 @@ where
|
|||
}
|
||||
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
||||
HandshakeResult::BadClient { reader, writer }
|
||||
}
|
||||
|
|
@ -633,7 +900,7 @@ pub fn generate_tg_nonce(
|
|||
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
|
||||
|
||||
if fast_mode {
|
||||
let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN);
|
||||
let mut key_iv = Zeroizing::new(Vec::with_capacity(KEY_LEN + IV_LEN));
|
||||
key_iv.extend_from_slice(client_enc_key);
|
||||
key_iv.extend_from_slice(&client_enc_iv.to_be_bytes());
|
||||
key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce
|
||||
|
|
@ -641,7 +908,7 @@ pub fn generate_tg_nonce(
|
|||
}
|
||||
|
||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
||||
let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::<Vec<u8>>());
|
||||
|
||||
let mut tg_enc_key = [0u8; 32];
|
||||
tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
|
||||
|
|
@ -662,7 +929,7 @@ pub fn generate_tg_nonce(
|
|||
/// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state
|
||||
pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, AesCtr, AesCtr) {
|
||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
||||
let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::<Vec<u8>>());
|
||||
|
||||
let mut enc_key = [0u8; 32];
|
||||
enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
|
||||
|
|
@ -683,11 +950,14 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
|
|||
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
|
||||
|
||||
let decryptor = AesCtr::new(&dec_key, dec_iv);
|
||||
enc_key.zeroize();
|
||||
dec_key.zeroize();
|
||||
|
||||
(result, encryptor, decryptor)
|
||||
}
|
||||
|
||||
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
|
||||
|
||||
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
||||
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
|
||||
encrypted
|
||||
|
|
@ -697,6 +967,10 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
|||
#[path = "handshake_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"]
|
||||
mod gap_short_tls_probe_throttle_security_tests;
|
||||
|
||||
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
|
||||
/// must never be Copy. A Copy impl would allow silent key duplication,
|
||||
/// undermining the zeroize-on-drop guarantee.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
use super::*;
|
||||
use crate::stats::ReplayChecker;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
cfg.access
|
||||
.users
|
||||
.insert("user".to_string(), secret_hex.to_string());
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn gap_t01_short_tls_probe_burst_is_throttled() {
|
||||
let _guard = auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
clear_auth_probe_state_for_testing();
|
||||
|
||||
let config = test_config_with_secret_hex("11111111111111111111111111111111");
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
|
||||
|
||||
let too_short = vec![0x16, 0x03, 0x01];
|
||||
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
let result = handle_tls_handshake(
|
||||
&too_short,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip())
|
||||
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
|
||||
"short TLS probe bursts must increase auth-probe fail streak"
|
||||
);
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -7,7 +7,7 @@ use tokio::net::TcpStream;
|
|||
#[cfg(unix)]
|
||||
use tokio::net::UnixStream;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::time::timeout;
|
||||
use tokio::time::{Instant, timeout};
|
||||
use tracing::debug;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::network::dns_overrides::resolve_socket_addr;
|
||||
|
|
@ -24,8 +24,36 @@ const MASK_TIMEOUT: Duration = Duration::from_millis(50);
|
|||
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
#[cfg(test)]
|
||||
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
|
||||
#[cfg(not(test))]
|
||||
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
#[cfg(test)]
|
||||
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
|
||||
const MASK_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W)
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = [0u8; MASK_BUFFER_SIZE];
|
||||
loop {
|
||||
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
|
||||
let n = match read_res {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(_)) | Err(_) => break,
|
||||
};
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let write_res = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&buf[..n])).await;
|
||||
match write_res {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(_)) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
|
|
@ -49,6 +77,20 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
async fn wait_mask_connect_budget(started: Instant) {
|
||||
let elapsed = started.elapsed();
|
||||
if elapsed < MASK_TIMEOUT {
|
||||
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_mask_outcome_budget(started: Instant) {
|
||||
let elapsed = started.elapsed();
|
||||
if elapsed < MASK_TIMEOUT {
|
||||
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect client type based on initial data
|
||||
fn detect_client_type(data: &[u8]) -> &'static str {
|
||||
// Check for HTTP request
|
||||
|
|
@ -107,6 +149,8 @@ where
|
|||
// Connect via Unix socket or TCP
|
||||
#[cfg(unix)]
|
||||
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
|
||||
let outcome_started = Instant::now();
|
||||
let connect_started = Instant::now();
|
||||
debug!(
|
||||
client_type = client_type,
|
||||
sock = %sock_path,
|
||||
|
|
@ -137,20 +181,25 @@ where
|
|||
};
|
||||
if let Some(header) = proxy_header {
|
||||
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
|
||||
wait_mask_outcome_budget(outcome_started).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
|
||||
debug!("Mask relay timed out (unix socket)");
|
||||
}
|
||||
wait_mask_outcome_budget(outcome_started).await;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
wait_mask_connect_budget(connect_started).await;
|
||||
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
wait_mask_outcome_budget(outcome_started).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask unix socket");
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
wait_mask_outcome_budget(outcome_started).await;
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
|
@ -172,6 +221,8 @@ where
|
|||
let mask_addr = resolve_socket_addr(mask_host, mask_port)
|
||||
.map(|addr| addr.to_string())
|
||||
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
|
||||
let outcome_started = Instant::now();
|
||||
let connect_started = Instant::now();
|
||||
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
|
||||
match connect_result {
|
||||
Ok(Ok(stream)) => {
|
||||
|
|
@ -196,20 +247,25 @@ where
|
|||
let (mask_read, mut mask_write) = stream.into_split();
|
||||
if let Some(header) = proxy_header {
|
||||
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
|
||||
wait_mask_outcome_budget(outcome_started).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
|
||||
debug!("Mask relay timed out");
|
||||
}
|
||||
wait_mask_outcome_budget(outcome_started).await;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
wait_mask_connect_budget(connect_started).await;
|
||||
debug!(error = %e, "Failed to connect to mask host");
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
wait_mask_outcome_budget(outcome_started).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask host");
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
wait_mask_outcome_budget(outcome_started).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -223,10 +279,10 @@ async fn relay_to_mask<R, W, MR, MW>(
|
|||
initial_data: &[u8],
|
||||
)
|
||||
where
|
||||
R: AsyncRead + Unpin + Send,
|
||||
W: AsyncWrite + Unpin + Send,
|
||||
MR: AsyncRead + Unpin + Send,
|
||||
MW: AsyncWrite + Unpin + Send,
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
MR: AsyncRead + Unpin + Send + 'static,
|
||||
MW: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
// Send initial data to mask host
|
||||
if mask_write.write_all(initial_data).await.is_err() {
|
||||
|
|
@ -236,39 +292,16 @@ where
|
|||
return;
|
||||
}
|
||||
|
||||
let mut client_buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
client_read = reader.read(&mut client_buf) => {
|
||||
match client_read {
|
||||
Ok(0) | Err(_) => {
|
||||
let _ = mask_write.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
if mask_write.write_all(&client_buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mask_read_res = mask_read.read(&mut mask_buf) => {
|
||||
match mask_read_res {
|
||||
Ok(0) | Err(_) => {
|
||||
let _ = writer.shutdown().await;
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
if writer.write_all(&mask_buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = tokio::join!(
|
||||
async {
|
||||
copy_with_idle_timeout(&mut reader, &mut mask_write).await;
|
||||
let _ = mask_write.shutdown().await;
|
||||
},
|
||||
async {
|
||||
copy_with_idle_timeout(&mut mask_read, &mut writer).await;
|
||||
let _ = writer.shutdown().await;
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
/// Just consume all data from client without responding
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,16 +1,14 @@
|
|||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::hash_map::RandomState;
|
||||
use std::hash::BuildHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::time::{Duration, Instant};
|
||||
#[cfg(test)]
|
||||
use std::sync::Mutex;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use dashmap::DashMap;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
|
|
@ -24,24 +22,37 @@ use crate::proxy::route_mode::{
|
|||
cutover_stagger_delay,
|
||||
};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
||||
|
||||
enum C2MeCommand {
|
||||
Data { payload: Bytes, flags: u32 },
|
||||
Data { payload: PooledBuffer, flags: u32 },
|
||||
Close,
|
||||
}
|
||||
|
||||
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
|
||||
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
|
||||
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
|
||||
const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000);
|
||||
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
|
||||
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
||||
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
|
||||
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
||||
#[cfg(test)]
|
||||
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
|
||||
#[cfg(not(test))]
|
||||
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
||||
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
||||
#[cfg(test)]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
||||
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
|
||||
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
|
||||
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
|
||||
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
|
||||
|
||||
struct RelayForensicsState {
|
||||
trace_id: u64,
|
||||
|
|
@ -81,7 +92,8 @@ impl MeD2cFlushPolicy {
|
|||
}
|
||||
|
||||
fn hash_value<T: Hash>(value: &T) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
let state = DESYNC_HASHER.get_or_init(RandomState::new);
|
||||
let mut hasher = state.build_hasher();
|
||||
value.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
|
@ -96,6 +108,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
|||
}
|
||||
|
||||
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||
let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
|
||||
let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false));
|
||||
if saturated_before {
|
||||
ever_saturated.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
if let Some(mut seen_at) = dedup.get_mut(&key) {
|
||||
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
|
||||
|
|
@ -107,8 +124,17 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
|||
|
||||
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||
let mut stale_keys = Vec::new();
|
||||
let mut oldest_candidate: Option<(u64, Instant)> = None;
|
||||
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
|
||||
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
|
||||
let key = *entry.key();
|
||||
let seen_at = *entry.value();
|
||||
|
||||
match oldest_candidate {
|
||||
Some((_, oldest_seen)) if seen_at >= oldest_seen => {}
|
||||
_ => oldest_candidate = Some((key, seen_at)),
|
||||
}
|
||||
|
||||
if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW {
|
||||
stale_keys.push(*entry.key());
|
||||
}
|
||||
}
|
||||
|
|
@ -116,12 +142,57 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
|||
dedup.remove(&stale_key);
|
||||
}
|
||||
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||
return false;
|
||||
let Some((evict_key, _)) = oldest_candidate else {
|
||||
return false;
|
||||
};
|
||||
dedup.remove(&evict_key);
|
||||
dedup.insert(key, now);
|
||||
return should_emit_full_desync_full_cache(now);
|
||||
}
|
||||
}
|
||||
|
||||
dedup.insert(key, now);
|
||||
true
|
||||
let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
|
||||
// Preserve the first sequential insert that reaches capacity as a normal
|
||||
// emit, while still gating concurrent newcomer churn after the cache has
|
||||
// ever been observed at saturation.
|
||||
let was_ever_saturated = if saturated_after {
|
||||
ever_saturated.swap(true, Ordering::Relaxed)
|
||||
} else {
|
||||
ever_saturated.load(Ordering::Relaxed)
|
||||
};
|
||||
|
||||
if saturated_before || (saturated_after && was_ever_saturated) {
|
||||
should_emit_full_desync_full_cache(now)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn should_emit_full_desync_full_cache(now: Instant) -> bool {
|
||||
let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
|
||||
let Ok(mut last_emit_at) = gate.lock() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
match *last_emit_at {
|
||||
None => {
|
||||
*last_emit_at = Some(now);
|
||||
true
|
||||
}
|
||||
Some(last) => {
|
||||
let Some(elapsed) = now.checked_duration_since(last) else {
|
||||
*last_emit_at = Some(now);
|
||||
return true;
|
||||
};
|
||||
if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL {
|
||||
*last_emit_at = Some(now);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -129,6 +200,21 @@ fn clear_desync_dedup_for_testing() {
|
|||
if let Some(dedup) = DESYNC_DEDUP.get() {
|
||||
dedup.clear();
|
||||
}
|
||||
if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() {
|
||||
ever_saturated.store(false, Ordering::Relaxed);
|
||||
}
|
||||
if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() {
|
||||
match last_emit_at.lock() {
|
||||
Ok(mut guard) => {
|
||||
*guard = None;
|
||||
}
|
||||
Err(poisoned) => {
|
||||
let mut guard = poisoned.into_inner();
|
||||
*guard = None;
|
||||
last_emit_at.clear_poison();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -232,6 +318,46 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
|
|||
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
|
||||
}
|
||||
|
||||
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
|
||||
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
|
||||
}
|
||||
|
||||
fn quota_would_be_exceeded_for_user(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
bytes: u64,
|
||||
) -> bool {
|
||||
quota_limit.is_some_and(|quota| {
|
||||
let used = stats.get_user_total_octets(user);
|
||||
used >= quota || bytes > quota.saturating_sub(used)
|
||||
})
|
||||
}
|
||||
|
||||
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
return Arc::new(AsyncMutex::new(()));
|
||||
}
|
||||
|
||||
let created = Arc::new(AsyncMutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn enqueue_c2me_command(
|
||||
tx: &mpsc::Sender<C2MeCommand>,
|
||||
cmd: C2MeCommand,
|
||||
|
|
@ -244,7 +370,14 @@ async fn enqueue_c2me_command(
|
|||
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
tx.send(cmd).await
|
||||
match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await {
|
||||
Ok(Ok(permit)) => {
|
||||
permit.send(cmd);
|
||||
Ok(())
|
||||
}
|
||||
Ok(Err(_)) => Err(mpsc::error::SendError(cmd)),
|
||||
Err(_) => Err(mpsc::error::SendError(cmd)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -268,6 +401,7 @@ where
|
|||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let user = success.user.clone();
|
||||
let quota_limit = config.access.user_data_quota.get(&user).copied();
|
||||
let peer = success.peer;
|
||||
let proto_tag = success.proto_tag;
|
||||
let pool_generation = me_pool.current_generation();
|
||||
|
|
@ -283,7 +417,7 @@ where
|
|||
);
|
||||
|
||||
let (conn_id, me_rx) = me_pool.registry().register().await;
|
||||
let trace_id = conn_id;
|
||||
let trace_id = session_id;
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let mut forensics = RelayForensicsState {
|
||||
trace_id,
|
||||
|
|
@ -298,7 +432,7 @@ where
|
|||
};
|
||||
|
||||
stats.increment_user_connects(&user);
|
||||
stats.increment_current_connections_me();
|
||||
let _me_connection_lease = stats.acquire_me_connection_lease();
|
||||
|
||||
if let Some(cutover) = affected_cutover_state(
|
||||
&route_rx,
|
||||
|
|
@ -316,7 +450,6 @@ where
|
|||
tokio::time::sleep(delay).await;
|
||||
let _ = me_pool.send_close(conn_id).await;
|
||||
me_pool.registry().unregister(conn_id).await;
|
||||
stats.decrement_current_connections_me();
|
||||
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
|
||||
}
|
||||
|
||||
|
|
@ -417,6 +550,7 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_limit,
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -449,6 +583,7 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_limit,
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -481,6 +616,7 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_limit,
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -513,6 +649,7 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_limit,
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -594,7 +731,19 @@ where
|
|||
forensics.bytes_c2me = forensics
|
||||
.bytes_c2me
|
||||
.saturating_add(payload.len() as u64);
|
||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||
if let Some(limit) = quota_limit {
|
||||
let quota_lock = quota_user_lock(&user);
|
||||
let _quota_guard = quota_lock.lock().await;
|
||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
|
||||
main_result = Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.clone(),
|
||||
});
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||
}
|
||||
let mut flags = proto_flags;
|
||||
if quickack {
|
||||
flags |= RPC_FLAG_QUICKACK;
|
||||
|
|
@ -664,7 +813,6 @@ where
|
|||
"ME relay cleanup"
|
||||
);
|
||||
me_pool.registry().unregister(conn_id).await;
|
||||
stats.decrement_current_connections_me();
|
||||
result
|
||||
}
|
||||
|
||||
|
|
@ -677,7 +825,7 @@ async fn read_client_payload<R>(
|
|||
forensics: &RelayForensicsState,
|
||||
frame_counter: &mut u64,
|
||||
stats: &Stats,
|
||||
) -> Result<Option<(Bytes, bool)>>
|
||||
) -> Result<Option<(PooledBuffer, bool)>>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
|
|
@ -784,25 +932,21 @@ where
|
|||
len
|
||||
};
|
||||
|
||||
let chunk_cap = buffer_pool.buffer_size().max(1024);
|
||||
let mut payload = BytesMut::with_capacity(len.min(chunk_cap));
|
||||
let mut remaining = len;
|
||||
while remaining > 0 {
|
||||
let chunk_len = remaining.min(chunk_cap);
|
||||
let mut chunk = buffer_pool.get();
|
||||
chunk.resize(chunk_len, 0);
|
||||
read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout)
|
||||
.await?;
|
||||
payload.extend_from_slice(&chunk[..chunk_len]);
|
||||
remaining -= chunk_len;
|
||||
let mut payload = buffer_pool.get();
|
||||
payload.clear();
|
||||
let current_cap = payload.capacity();
|
||||
if current_cap < len {
|
||||
payload.reserve(len - current_cap);
|
||||
}
|
||||
payload.resize(len, 0);
|
||||
read_exact_with_timeout(client_reader, &mut payload[..len], frame_read_timeout).await?;
|
||||
|
||||
// Secure Intermediate: strip validated trailing padding bytes.
|
||||
if proto_tag == ProtoTag::Secure {
|
||||
payload.truncate(secure_payload_len);
|
||||
}
|
||||
*frame_counter += 1;
|
||||
return Ok(Some((payload.freeze(), quickack)));
|
||||
return Ok(Some((payload, quickack)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -823,6 +967,7 @@ async fn process_me_writer_response<W>(
|
|||
frame_buf: &mut Vec<u8>,
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
bytes_me2c: &AtomicU64,
|
||||
conn_id: u64,
|
||||
ack_flush_immediate: bool,
|
||||
|
|
@ -838,17 +983,47 @@ where
|
|||
} else {
|
||||
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
|
||||
}
|
||||
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data.len() as u64);
|
||||
write_client_payload(
|
||||
client_writer,
|
||||
proto_tag,
|
||||
flags,
|
||||
&data,
|
||||
rng,
|
||||
frame_buf,
|
||||
)
|
||||
.await?;
|
||||
let data_len = data.len() as u64;
|
||||
if let Some(limit) = quota_limit {
|
||||
let quota_lock = quota_user_lock(user);
|
||||
let _quota_guard = quota_lock.lock().await;
|
||||
if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) {
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
write_client_payload(
|
||||
client_writer,
|
||||
proto_tag,
|
||||
flags,
|
||||
&data,
|
||||
rng,
|
||||
frame_buf,
|
||||
)
|
||||
.await?;
|
||||
|
||||
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data.len() as u64);
|
||||
|
||||
if quota_exceeded_for_user(stats, user, Some(limit)) {
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
write_client_payload(
|
||||
client_writer,
|
||||
proto_tag,
|
||||
flags,
|
||||
&data,
|
||||
rng,
|
||||
frame_buf,
|
||||
)
|
||||
.await?;
|
||||
|
||||
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data.len() as u64);
|
||||
}
|
||||
|
||||
Ok(MeWriterResponseOutcome::Continue {
|
||||
frames: 1,
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -53,16 +53,17 @@
|
|||
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use dashmap::DashMap;
|
||||
use tokio::io::{
|
||||
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
|
||||
};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, trace, warn};
|
||||
use crate::error::Result;
|
||||
use crate::error::{ProxyError, Result};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
|
||||
|
|
@ -205,6 +206,10 @@ struct StatsIo<S> {
|
|||
counters: Arc<SharedCounters>,
|
||||
stats: Arc<Stats>,
|
||||
user: String,
|
||||
quota_limit: Option<u64>,
|
||||
quota_exceeded: Arc<AtomicBool>,
|
||||
quota_read_wake_scheduled: bool,
|
||||
quota_write_wake_scheduled: bool,
|
||||
epoch: Instant,
|
||||
}
|
||||
|
||||
|
|
@ -214,11 +219,64 @@ impl<S> StatsIo<S> {
|
|||
counters: Arc<SharedCounters>,
|
||||
stats: Arc<Stats>,
|
||||
user: String,
|
||||
quota_limit: Option<u64>,
|
||||
quota_exceeded: Arc<AtomicBool>,
|
||||
epoch: Instant,
|
||||
) -> Self {
|
||||
// Mark initial activity so the watchdog doesn't fire before data flows
|
||||
counters.touch(Instant::now(), epoch);
|
||||
Self { inner, counters, stats, user, epoch }
|
||||
Self {
|
||||
inner,
|
||||
counters,
|
||||
stats,
|
||||
user,
|
||||
quota_limit,
|
||||
quota_exceeded,
|
||||
quota_read_wake_scheduled: false,
|
||||
quota_write_wake_scheduled: false,
|
||||
epoch,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct QuotaIoSentinel;
|
||||
|
||||
impl std::fmt::Display for QuotaIoSentinel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("user data quota exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for QuotaIoSentinel {}
|
||||
|
||||
fn quota_io_error() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
|
||||
}
|
||||
|
||||
fn is_quota_io_error(err: &io::Error) -> bool {
|
||||
err.kind() == io::ErrorKind::PermissionDenied
|
||||
&& err
|
||||
.get_ref()
|
||||
.and_then(|source| source.downcast_ref::<QuotaIoSentinel>())
|
||||
.is_some()
|
||||
}
|
||||
|
||||
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||
|
||||
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
let created = Arc::new(Mutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -229,6 +287,42 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
|||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
if this.quota_exceeded.load(Ordering::Relaxed) {
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let quota_lock = this
|
||||
.quota_limit
|
||||
.is_some()
|
||||
.then(|| quota_user_lock(&this.user));
|
||||
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => {
|
||||
this.quota_read_wake_scheduled = false;
|
||||
Some(guard)
|
||||
}
|
||||
Err(_) => {
|
||||
if !this.quota_read_wake_scheduled {
|
||||
this.quota_read_wake_scheduled = true;
|
||||
let waker = cx.waker().clone();
|
||||
tokio::task::spawn(async move {
|
||||
tokio::task::yield_now().await;
|
||||
waker.wake();
|
||||
});
|
||||
}
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(limit) = this.quota_limit
|
||||
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||
{
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
let before = buf.filled().len();
|
||||
|
||||
match Pin::new(&mut this.inner).poll_read(cx, buf) {
|
||||
|
|
@ -243,6 +337,13 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
|||
this.stats.add_user_octets_from(&this.user, n as u64);
|
||||
this.stats.increment_user_msgs_from(&this.user);
|
||||
|
||||
if let Some(limit) = this.quota_limit
|
||||
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||
{
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
trace!(user = %this.user, bytes = n, "C->S");
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
|
|
@ -259,8 +360,56 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
|||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
if this.quota_exceeded.load(Ordering::Relaxed) {
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
match Pin::new(&mut this.inner).poll_write(cx, buf) {
|
||||
let quota_lock = this
|
||||
.quota_limit
|
||||
.is_some()
|
||||
.then(|| quota_user_lock(&this.user));
|
||||
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => {
|
||||
this.quota_write_wake_scheduled = false;
|
||||
Some(guard)
|
||||
}
|
||||
Err(_) => {
|
||||
if !this.quota_write_wake_scheduled {
|
||||
this.quota_write_wake_scheduled = true;
|
||||
let waker = cx.waker().clone();
|
||||
tokio::task::spawn(async move {
|
||||
tokio::task::yield_now().await;
|
||||
waker.wake();
|
||||
});
|
||||
}
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let write_buf = if let Some(limit) = this.quota_limit {
|
||||
let used = this.stats.get_user_total_octets(&this.user);
|
||||
if used >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let remaining = (limit - used) as usize;
|
||||
if buf.len() > remaining {
|
||||
// Fail closed: do not emit partial S->C payload when remaining
|
||||
// quota cannot accommodate the pending write request.
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
buf
|
||||
} else {
|
||||
buf
|
||||
};
|
||||
|
||||
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n > 0 {
|
||||
// S→C: data written to client
|
||||
|
|
@ -271,6 +420,13 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
|||
this.stats.add_user_octets_to(&this.user, n as u64);
|
||||
this.stats.increment_user_msgs_to(&this.user);
|
||||
|
||||
if let Some(limit) = this.quota_limit
|
||||
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||
{
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
trace!(user = %this.user, bytes = n, "S->C");
|
||||
}
|
||||
Poll::Ready(Ok(n))
|
||||
|
|
@ -307,7 +463,8 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
|||
/// - Per-user stats: bytes and ops counted per direction
|
||||
/// - Periodic rate logging: every 10 seconds when active
|
||||
/// - Clean shutdown: both write sides are shut down on exit
|
||||
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
|
||||
/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
|
||||
/// other I/O failures are returned as `ProxyError::Io`
|
||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||
client_reader: CR,
|
||||
client_writer: CW,
|
||||
|
|
@ -317,6 +474,7 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
|||
s2c_buf_size: usize,
|
||||
user: &str,
|
||||
stats: Arc<Stats>,
|
||||
quota_limit: Option<u64>,
|
||||
_buffer_pool: Arc<BufferPool>,
|
||||
) -> Result<()>
|
||||
where
|
||||
|
|
@ -327,6 +485,7 @@ where
|
|||
{
|
||||
let epoch = Instant::now();
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let user_owned = user.to_string();
|
||||
|
||||
// ── Combine split halves into bidirectional streams ──────────────
|
||||
|
|
@ -339,12 +498,15 @@ where
|
|||
Arc::clone(&counters),
|
||||
Arc::clone(&stats),
|
||||
user_owned.clone(),
|
||||
quota_limit,
|
||||
Arc::clone("a_exceeded),
|
||||
epoch,
|
||||
);
|
||||
|
||||
// ── Watchdog: activity timeout + periodic rate logging ──────────
|
||||
let wd_counters = Arc::clone(&counters);
|
||||
let wd_user = user_owned.clone();
|
||||
let wd_quota_exceeded = Arc::clone("a_exceeded);
|
||||
|
||||
let watchdog = async {
|
||||
let mut prev_c2s: u64 = 0;
|
||||
|
|
@ -356,6 +518,11 @@ where
|
|||
let now = Instant::now();
|
||||
let idle = wd_counters.idle_duration(now, epoch);
|
||||
|
||||
if wd_quota_exceeded.load(Ordering::Relaxed) {
|
||||
warn!(user = %wd_user, "User data quota reached, closing relay");
|
||||
return;
|
||||
}
|
||||
|
||||
// ── Activity timeout ────────────────────────────────────
|
||||
if idle >= ACTIVITY_TIMEOUT {
|
||||
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
|
||||
|
|
@ -439,6 +606,22 @@ where
|
|||
);
|
||||
Ok(())
|
||||
}
|
||||
Some(Err(e)) if is_quota_io_error(&e) => {
|
||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||
warn!(
|
||||
user = %user_owned,
|
||||
c2s_bytes = c2s,
|
||||
s2c_bytes = s2c,
|
||||
c2s_msgs = c2s_ops,
|
||||
s2c_msgs = s2c_ops,
|
||||
duration_secs = duration.as_secs(),
|
||||
"Data quota reached, closing relay"
|
||||
);
|
||||
Err(ProxyError::DataQuotaExceeded {
|
||||
user: user_owned.clone(),
|
||||
})
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
// I/O error in one of the directions
|
||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||
|
|
@ -472,3 +655,7 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,10 +1,10 @@
|
|||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use tokio::sync::watch;
|
||||
|
||||
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover";
|
||||
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated";
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
|
|
@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode {
|
|||
}
|
||||
|
||||
impl RelayRouteMode {
|
||||
pub(crate) fn as_u8(self) -> u8 {
|
||||
self as u8
|
||||
}
|
||||
|
||||
pub(crate) fn from_u8(value: u8) -> Self {
|
||||
match value {
|
||||
1 => Self::Middle,
|
||||
_ => Self::Direct,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Direct => "direct",
|
||||
|
|
@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState {
|
|||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RouteRuntimeController {
|
||||
mode: Arc<AtomicU8>,
|
||||
generation: Arc<AtomicU64>,
|
||||
direct_since_epoch_secs: Arc<AtomicU64>,
|
||||
tx: watch::Sender<RouteCutoverState>,
|
||||
}
|
||||
|
|
@ -60,18 +47,13 @@ impl RouteRuntimeController {
|
|||
0
|
||||
};
|
||||
Self {
|
||||
mode: Arc::new(AtomicU8::new(initial_mode.as_u8())),
|
||||
generation: Arc::new(AtomicU64::new(0)),
|
||||
direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)),
|
||||
tx,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn snapshot(&self) -> RouteCutoverState {
|
||||
RouteCutoverState {
|
||||
mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)),
|
||||
generation: self.generation.load(Ordering::Relaxed),
|
||||
}
|
||||
*self.tx.borrow()
|
||||
}
|
||||
|
||||
pub(crate) fn subscribe(&self) -> watch::Receiver<RouteCutoverState> {
|
||||
|
|
@ -84,20 +66,29 @@ impl RouteRuntimeController {
|
|||
}
|
||||
|
||||
pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
|
||||
let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed);
|
||||
if previous == mode.as_u8() {
|
||||
let mut next = None;
|
||||
let changed = self.tx.send_if_modified(|state| {
|
||||
if state.mode == mode {
|
||||
return false;
|
||||
}
|
||||
state.mode = mode;
|
||||
state.generation = state.generation.saturating_add(1);
|
||||
next = Some(*state);
|
||||
true
|
||||
});
|
||||
|
||||
if !changed {
|
||||
return None;
|
||||
}
|
||||
|
||||
if matches!(mode, RelayRouteMode::Direct) {
|
||||
self.direct_since_epoch_secs
|
||||
.store(now_epoch_secs(), Ordering::Relaxed);
|
||||
} else {
|
||||
self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
|
||||
}
|
||||
let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
let next = RouteCutoverState { mode, generation };
|
||||
self.tx.send_replace(next);
|
||||
Some(next)
|
||||
|
||||
next
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -110,10 +101,10 @@ fn now_epoch_secs() -> u64 {
|
|||
|
||||
pub(crate) fn is_session_affected_by_cutover(
|
||||
current: RouteCutoverState,
|
||||
_session_mode: RelayRouteMode,
|
||||
session_mode: RelayRouteMode,
|
||||
session_generation: u64,
|
||||
) -> bool {
|
||||
current.generation > session_generation
|
||||
current.generation > session_generation && current.mode != session_mode
|
||||
}
|
||||
|
||||
pub(crate) fn affected_cutover_state(
|
||||
|
|
@ -140,3 +131,7 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio
|
|||
let ms = 1000 + (value % 1000);
|
||||
Duration::from_millis(ms)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "route_mode_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,406 @@
|
|||
use super::*;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::rngs::StdRng;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
#[test]
|
||||
fn cutover_stagger_delay_is_deterministic_for_same_inputs() {
|
||||
let d1 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42);
|
||||
let d2 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42);
|
||||
assert_eq!(
|
||||
d1, d2,
|
||||
"stagger delay must be deterministic for identical session/generation inputs"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cutover_stagger_delay_stays_within_budget_bounds() {
|
||||
// Black-hat model: censors trigger many cutovers and correlate disconnect timing.
|
||||
// Keep delay inside a narrow coarse window to avoid long-tail spikes.
|
||||
for generation in [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX] {
|
||||
for session_id in [
|
||||
0u64,
|
||||
1,
|
||||
2,
|
||||
0xdead_beef,
|
||||
0xfeed_face_cafe_babe,
|
||||
u64::MAX,
|
||||
] {
|
||||
let delay = cutover_stagger_delay(session_id, generation);
|
||||
assert!(
|
||||
(1000..=1999).contains(&delay.as_millis()),
|
||||
"stagger delay must remain in fixed 1000..=1999ms budget"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cutover_stagger_delay_changes_with_generation_for_same_session() {
|
||||
let session_id = 0x0123_4567_89ab_cdef;
|
||||
let first = cutover_stagger_delay(session_id, 100);
|
||||
let second = cutover_stagger_delay(session_id, 101);
|
||||
assert_ne!(
|
||||
first, second,
|
||||
"adjacent cutover generations should decorrelate disconnect delays"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn route_runtime_set_mode_is_idempotent_for_same_mode() {
|
||||
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||
let first = runtime.snapshot();
|
||||
let changed = runtime.set_mode(RelayRouteMode::Direct);
|
||||
let second = runtime.snapshot();
|
||||
|
||||
assert!(
|
||||
changed.is_none(),
|
||||
"setting already-active mode must not produce a cutover event"
|
||||
);
|
||||
assert_eq!(
|
||||
first.generation, second.generation,
|
||||
"idempotent mode set must not bump generation"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn affected_cutover_state_triggers_only_for_newer_generation() {
|
||||
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||
let rx = runtime.subscribe();
|
||||
let initial = runtime.snapshot();
|
||||
|
||||
assert!(
|
||||
affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation).is_none(),
|
||||
"current generation must not be considered a cutover for existing session"
|
||||
);
|
||||
|
||||
let next = runtime
|
||||
.set_mode(RelayRouteMode::Middle)
|
||||
.expect("mode change must produce cutover state");
|
||||
let seen = affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation)
|
||||
.expect("newer generation must be observed as cutover");
|
||||
|
||||
assert_eq!(seen.generation, next.generation);
|
||||
assert_eq!(seen.mode, RelayRouteMode::Middle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn integration_watch_and_snapshot_follow_same_transition_sequence() {
|
||||
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||
let rx = runtime.subscribe();
|
||||
|
||||
let sequence = [
|
||||
RelayRouteMode::Middle,
|
||||
RelayRouteMode::Middle,
|
||||
RelayRouteMode::Direct,
|
||||
RelayRouteMode::Direct,
|
||||
RelayRouteMode::Middle,
|
||||
];
|
||||
|
||||
let mut expected_generation = 0u64;
|
||||
let mut expected_mode = RelayRouteMode::Direct;
|
||||
|
||||
for target in sequence {
|
||||
let changed = runtime.set_mode(target);
|
||||
if target == expected_mode {
|
||||
assert!(changed.is_none(), "idempotent transition must return none");
|
||||
} else {
|
||||
expected_mode = target;
|
||||
expected_generation = expected_generation.saturating_add(1);
|
||||
let emitted = changed.expect("real transition must emit cutover state");
|
||||
assert_eq!(emitted.mode, expected_mode);
|
||||
assert_eq!(emitted.generation, expected_generation);
|
||||
}
|
||||
|
||||
let snap = runtime.snapshot();
|
||||
let watched = *rx.borrow();
|
||||
assert_eq!(snap, watched, "snapshot and watch state must stay aligned");
|
||||
assert_eq!(snap.mode, expected_mode);
|
||||
assert_eq!(snap.generation, expected_generation);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_is_not_affected_when_mode_matches_even_if_generation_advanced() {
|
||||
let session_mode = RelayRouteMode::Direct;
|
||||
let current = RouteCutoverState {
|
||||
mode: RelayRouteMode::Direct,
|
||||
generation: 2,
|
||||
};
|
||||
let session_generation = 0;
|
||||
|
||||
assert!(
|
||||
!is_session_affected_by_cutover(current, session_mode, session_generation),
|
||||
"session on matching final route mode should not be force-cut over on intermediate generation bumps"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cutover_predicate_rejects_equal_generation_even_if_mode_differs() {
|
||||
let current = RouteCutoverState {
|
||||
mode: RelayRouteMode::Middle,
|
||||
generation: 77,
|
||||
};
|
||||
assert!(
|
||||
!is_session_affected_by_cutover(current, RelayRouteMode::Direct, 77),
|
||||
"equal generation must never trigger cutover regardless of mode mismatch"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adversarial_route_oscillation_only_cuts_over_sessions_with_different_final_mode() {
|
||||
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||
let rx = runtime.subscribe();
|
||||
let session_generation = runtime.snapshot().generation;
|
||||
|
||||
runtime
|
||||
.set_mode(RelayRouteMode::Middle)
|
||||
.expect("direct->middle must transition");
|
||||
runtime
|
||||
.set_mode(RelayRouteMode::Direct)
|
||||
.expect("middle->direct must transition");
|
||||
|
||||
assert!(
|
||||
affected_cutover_state(&rx, RelayRouteMode::Direct, session_generation).is_none(),
|
||||
"direct session should survive when final mode returns to direct"
|
||||
);
|
||||
assert!(
|
||||
affected_cutover_state(&rx, RelayRouteMode::Middle, session_generation).is_some(),
|
||||
"middle session should be cut over when final mode is direct"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_cutover_predicate_matches_reference_oracle() {
|
||||
let mut rng = StdRng::seed_from_u64(0xC0DEC0DE5EED);
|
||||
for _ in 0..20_000 {
|
||||
let current = RouteCutoverState {
|
||||
mode: if rng.random::<bool>() {
|
||||
RelayRouteMode::Direct
|
||||
} else {
|
||||
RelayRouteMode::Middle
|
||||
},
|
||||
generation: rng.random_range(0u64..1_000_000),
|
||||
};
|
||||
let session_mode = if rng.random::<bool>() {
|
||||
RelayRouteMode::Direct
|
||||
} else {
|
||||
RelayRouteMode::Middle
|
||||
};
|
||||
let session_generation = rng.random_range(0u64..1_000_000);
|
||||
|
||||
let expected = current.generation > session_generation && current.mode != session_mode;
|
||||
let actual = is_session_affected_by_cutover(current, session_mode, session_generation);
|
||||
assert_eq!(
|
||||
actual, expected,
|
||||
"cutover predicate must match mode-aware generation oracle"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_set_mode_generation_tracks_only_real_transitions() {
|
||||
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||
let mut rng = StdRng::seed_from_u64(0x0DDC0FFE);
|
||||
|
||||
let mut expected_mode = RelayRouteMode::Direct;
|
||||
let mut expected_generation = 0u64;
|
||||
|
||||
for _ in 0..10_000 {
|
||||
let candidate = if rng.random::<bool>() {
|
||||
RelayRouteMode::Direct
|
||||
} else {
|
||||
RelayRouteMode::Middle
|
||||
};
|
||||
let changed = runtime.set_mode(candidate);
|
||||
|
||||
if candidate == expected_mode {
|
||||
assert!(changed.is_none(), "idempotent set_mode must not emit cutover state");
|
||||
} else {
|
||||
expected_mode = candidate;
|
||||
expected_generation = expected_generation.saturating_add(1);
|
||||
let next = changed.expect("mode transition must emit cutover state");
|
||||
assert_eq!(next.mode, expected_mode);
|
||||
assert_eq!(next.generation, expected_generation);
|
||||
}
|
||||
}
|
||||
|
||||
let final_state = runtime.snapshot();
|
||||
assert_eq!(final_state.mode, expected_mode);
|
||||
assert_eq!(final_state.generation, expected_generation);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_snapshot_and_watch_state_remain_consistent_under_concurrent_switch_storm() {
|
||||
let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
|
||||
std::thread::scope(|scope| {
|
||||
let mut writers = Vec::new();
|
||||
for worker in 0..4usize {
|
||||
let runtime = Arc::clone(&runtime);
|
||||
writers.push(scope.spawn(move || {
|
||||
for step in 0..20_000usize {
|
||||
let mode = if (worker + step) % 2 == 0 {
|
||||
RelayRouteMode::Direct
|
||||
} else {
|
||||
RelayRouteMode::Middle
|
||||
};
|
||||
let _ = runtime.set_mode(mode);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for writer in writers {
|
||||
writer
|
||||
.join()
|
||||
.expect("route mode writer thread must not panic");
|
||||
}
|
||||
|
||||
let rx = runtime.subscribe();
|
||||
for _ in 0..128 {
|
||||
assert_eq!(
|
||||
runtime.snapshot(),
|
||||
*rx.borrow(),
|
||||
"snapshot and watch state must converge after concurrent set_mode churn"
|
||||
);
|
||||
std::thread::yield_now();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_concurrent_transition_count_matches_final_generation() {
|
||||
let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let successful_transitions = Arc::new(AtomicU64::new(0));
|
||||
|
||||
std::thread::scope(|scope| {
|
||||
let mut workers = Vec::new();
|
||||
for worker in 0..6usize {
|
||||
let runtime = Arc::clone(&runtime);
|
||||
let successful_transitions = Arc::clone(&successful_transitions);
|
||||
workers.push(scope.spawn(move || {
|
||||
let mut state = (worker as u64 + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15);
|
||||
for _ in 0..25_000usize {
|
||||
state ^= state << 7;
|
||||
state ^= state >> 9;
|
||||
state ^= state << 8;
|
||||
let mode = if (state & 1) == 0 {
|
||||
RelayRouteMode::Direct
|
||||
} else {
|
||||
RelayRouteMode::Middle
|
||||
};
|
||||
if runtime.set_mode(mode).is_some() {
|
||||
successful_transitions.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for worker in workers {
|
||||
worker.join().expect("route mode transition worker must not panic");
|
||||
}
|
||||
});
|
||||
|
||||
let final_state = runtime.snapshot();
|
||||
assert_eq!(
|
||||
final_state.generation,
|
||||
successful_transitions.load(Ordering::Relaxed),
|
||||
"final generation must equal number of accepted mode transitions"
|
||||
);
|
||||
assert_eq!(
|
||||
final_state,
|
||||
*runtime.subscribe().borrow(),
|
||||
"watch and snapshot state must match after concurrent transition accounting"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() {
|
||||
// Deterministic xorshift fuzzing keeps this test stable across runs.
|
||||
let mut s: u64 = 0x9E37_79B9_7F4A_7C15;
|
||||
|
||||
for _ in 0..20_000 {
|
||||
s ^= s << 7;
|
||||
s ^= s >> 9;
|
||||
s ^= s << 8;
|
||||
let session_id = s;
|
||||
|
||||
s ^= s << 7;
|
||||
s ^= s >> 9;
|
||||
s ^= s << 8;
|
||||
let generation = s;
|
||||
|
||||
let delay = cutover_stagger_delay(session_id, generation);
|
||||
assert!(
|
||||
(1000..=1999).contains(&delay.as_millis()),
|
||||
"fuzzed inputs must always map into fixed stagger window"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cutover_stagger_delay_distribution_has_no_empty_buckets_under_sequential_sessions() {
|
||||
let mut buckets = [0usize; 1000];
|
||||
let generation = 4242u64;
|
||||
|
||||
for session_id in 0..250_000u64 {
|
||||
let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize;
|
||||
let idx = delay_ms - 1000;
|
||||
buckets[idx] += 1;
|
||||
}
|
||||
|
||||
let empty = buckets.iter().filter(|&&count| count == 0).count();
|
||||
assert_eq!(
|
||||
empty, 0,
|
||||
"all 1000 delay buckets must be exercised to avoid cutover herd clustering"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_cutover_stagger_delay_distribution_stays_reasonably_uniform() {
|
||||
let mut buckets = [0usize; 1000];
|
||||
let mut s: u64 = 0x1BAD_B002_CAFE_F00D;
|
||||
|
||||
for _ in 0..300_000usize {
|
||||
s ^= s << 7;
|
||||
s ^= s >> 9;
|
||||
s ^= s << 8;
|
||||
let session_id = s;
|
||||
|
||||
s ^= s << 7;
|
||||
s ^= s >> 9;
|
||||
s ^= s << 8;
|
||||
let generation = s;
|
||||
|
||||
let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize;
|
||||
buckets[delay_ms - 1000] += 1;
|
||||
}
|
||||
|
||||
let min = *buckets.iter().min().unwrap_or(&0);
|
||||
let max = *buckets.iter().max().unwrap_or(&0);
|
||||
assert!(min > 0, "fuzzed distribution must not leave empty buckets");
|
||||
assert!(
|
||||
max <= min.saturating_mul(3),
|
||||
"bucket skew is too high for anti-herd staggering (max={max}, min={min})"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_cutover_stagger_delay_distribution_remains_stable_across_generations() {
|
||||
for generation in [0u64, 1, 7, 31, 255, 1024, u32::MAX as u64, u64::MAX - 1] {
|
||||
let mut buckets = [0usize; 1000];
|
||||
for session_id in 0..100_000u64 {
|
||||
let delay_ms = cutover_stagger_delay(session_id ^ 0x9E37_79B9, generation)
|
||||
.as_millis() as usize;
|
||||
buckets[delay_ms - 1000] += 1;
|
||||
}
|
||||
|
||||
let min = *buckets.iter().min().unwrap_or(&0);
|
||||
let max = *buckets.iter().max().unwrap_or(&0);
|
||||
assert!(
|
||||
max <= min.saturating_mul(4).max(1),
|
||||
"generation={generation}: distribution collapsed (max={max}, min={min})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,265 @@
|
|||
use super::*;
|
||||
use std::panic::{self, AssertUnwindSafe};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Barrier;
|
||||
|
||||
#[test]
|
||||
fn direct_connection_lease_balances_on_drop() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
assert_eq!(stats.get_current_connections_direct(), 0);
|
||||
|
||||
{
|
||||
let _lease = stats.acquire_direct_connection_lease();
|
||||
assert_eq!(stats.get_current_connections_direct(), 1);
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_current_connections_direct(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn middle_connection_lease_balances_on_drop() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
assert_eq!(stats.get_current_connections_me(), 0);
|
||||
|
||||
{
|
||||
let _lease = stats.acquire_me_connection_lease();
|
||||
assert_eq!(stats.get_current_connections_me(), 1);
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_current_connections_me(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connection_lease_disarm_prevents_double_release() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let mut lease = stats.acquire_direct_connection_lease();
|
||||
assert_eq!(stats.get_current_connections_direct(), 1);
|
||||
|
||||
stats.decrement_current_connections_direct();
|
||||
assert_eq!(stats.get_current_connections_direct(), 0);
|
||||
|
||||
lease.disarm();
|
||||
drop(lease);
|
||||
|
||||
assert_eq!(stats.get_current_connections_direct(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn direct_connection_lease_balances_on_panic_unwind() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let stats_for_panic = stats.clone();
|
||||
|
||||
let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
|
||||
let _lease = stats_for_panic.acquire_direct_connection_lease();
|
||||
panic!("intentional panic to verify lease drop path");
|
||||
}));
|
||||
|
||||
assert!(panic_result.is_err(), "panic must propagate from test closure");
|
||||
assert_eq!(
|
||||
stats.get_current_connections_direct(),
|
||||
0,
|
||||
"panic unwind must release direct route gauge"
|
||||
);
|
||||
}
|
||||
|
||||
/// Same unwind guarantee as the direct-route test, but for the
/// middle-route (ME) gauge.
#[test]
fn middle_connection_lease_balances_on_panic_unwind() {
    let stats = Arc::new(Stats::new());
    let stats_for_panic = stats.clone();

    let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
        let _lease = stats_for_panic.acquire_me_connection_lease();
        panic!("intentional panic to verify middle lease drop path");
    }));

    assert!(panic_result.is_err(), "panic must propagate from test closure");
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "panic unwind must release middle route gauge"
    );
}
|
||||
|
||||
/// Scheduler-pressure test: many tasks acquire/drop leases on both routes
/// concurrently; every acquire must pair with exactly one release, so both
/// gauges must read zero once all workers finish.
#[tokio::test]
async fn concurrent_mixed_route_lease_churn_balances_to_zero() {
    const TASKS: usize = 48;
    const ITERATIONS_PER_TASK: usize = 256;

    let stats = Arc::new(Stats::new());
    // Barrier releases all tasks at once to maximize increment/decrement overlap.
    let barrier = Arc::new(Barrier::new(TASKS));
    let mut workers = Vec::with_capacity(TASKS);

    for task_idx in 0..TASKS {
        let stats_for_task = stats.clone();
        let barrier_for_task = barrier.clone();
        workers.push(tokio::spawn(async move {
            barrier_for_task.wait().await;
            for iter in 0..ITERATIONS_PER_TASK {
                // Alternate routes per task/iteration so both gauges churn.
                if (task_idx + iter) % 2 == 0 {
                    let _lease = stats_for_task.acquire_direct_connection_lease();
                    // Yield while holding the lease to interleave drops.
                    tokio::task::yield_now().await;
                } else {
                    let _lease = stats_for_task.acquire_me_connection_lease();
                    tokio::task::yield_now().await;
                }
            }
        }));
    }

    for worker in workers {
        worker
            .await
            .expect("lease churn worker must not panic");
    }

    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "direct route gauge must return to zero after concurrent lease churn"
    );
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "middle route gauge must return to zero after concurrent lease churn"
    );
}
|
||||
|
||||
/// Abort a storm of tasks while each holds a route lease: task teardown
/// must drop every lease, draining both gauges back to zero.
#[tokio::test]
async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() {
    const TASKS: usize = 64;

    let stats = Arc::new(Stats::new());
    let mut workers = Vec::with_capacity(TASKS);

    for task_idx in 0..TASKS {
        let stats_for_task = stats.clone();
        workers.push(tokio::spawn(async move {
            // Hold the lease across a long sleep so abort() lands mid-lease.
            if task_idx % 2 == 0 {
                let _lease = stats_for_task.acquire_direct_connection_lease();
                tokio::time::sleep(Duration::from_secs(60)).await;
            } else {
                let _lease = stats_for_task.acquire_me_connection_lease();
                tokio::time::sleep(Duration::from_secs(60)).await;
            }
        }));
    }

    // Bounded poll until every task has acquired its lease (sum of gauges == TASKS).
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            let total = stats.get_current_connections_direct() + stats.get_current_connections_me();
            if total == TASKS as u64 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("all storm tasks must acquire route leases before abort");

    for worker in &workers {
        worker.abort();
    }
    for worker in workers {
        let joined = worker.await;
        assert!(joined.is_err(), "aborted worker must return join error");
    }

    // Abort cleanup is asynchronous; poll (bounded) for both gauges to drain.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("all route gauges must drain to zero after abort storm");
}
|
||||
|
||||
/// Racing decrements on already-zero gauges must saturate at zero and
/// never wrap a u64 counter around to a huge value.
#[test]
fn saturating_route_decrements_do_not_underflow_under_race() {
    const THREADS: usize = 16;
    const DECREMENTS_PER_THREAD: usize = 4096;

    let stats = Arc::new(Stats::new());
    let mut workers = Vec::with_capacity(THREADS);

    for _ in 0..THREADS {
        let stats_for_thread = stats.clone();
        // OS threads (not tokio tasks): the decrement path is synchronous.
        workers.push(std::thread::spawn(move || {
            for _ in 0..DECREMENTS_PER_THREAD {
                stats_for_thread.decrement_current_connections_direct();
                stats_for_thread.decrement_current_connections_me();
            }
        }));
    }

    for worker in workers {
        worker
            .join()
            .expect("decrement race worker must not panic");
    }

    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "direct route decrement races must never underflow"
    );
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "middle route decrement races must never underflow"
    );
}
|
||||
|
||||
/// Aborting a single task that holds a direct-route lease must release
/// the gauge via the lease's Drop during task teardown.
#[tokio::test]
async fn direct_connection_lease_balances_on_task_abort() {
    let stats = Arc::new(Stats::new());
    let stats_for_task = stats.clone();

    let task = tokio::spawn(async move {
        let _lease = stats_for_task.acquire_direct_connection_lease();
        // Long sleep so the abort always lands while the lease is held.
        tokio::time::sleep(Duration::from_secs(60)).await;
    });

    // Give the spawned task time to run and acquire before asserting.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(stats.get_current_connections_direct(), 1);

    task.abort();
    let joined = task.await;
    assert!(joined.is_err(), "aborted task must return a join error");

    // Small grace period for abort cleanup before checking the gauge.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "aborted task must release direct route gauge"
    );
}
|
||||
|
||||
/// Same abort guarantee as the direct-route test, but for the
/// middle-route (ME) gauge.
#[tokio::test]
async fn middle_connection_lease_balances_on_task_abort() {
    let stats = Arc::new(Stats::new());
    let stats_for_task = stats.clone();

    let task = tokio::spawn(async move {
        let _lease = stats_for_task.acquire_me_connection_lease();
        tokio::time::sleep(Duration::from_secs(60)).await;
    });

    // Give the spawned task time to run and acquire before asserting.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(stats.get_current_connections_me(), 1);

    task.abort();
    let joined = task.await;
    assert!(joined.is_err(), "aborted task must return a join error");

    // Small grace period for abort cleanup before checking the gauge.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "aborted task must release middle route gauge"
    );
}
|
||||
125
src/stats/mod.rs
125
src/stats/mod.rs
|
|
@ -6,6 +6,7 @@ pub mod beobachten;
|
|||
pub mod telemetry;
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::Mutex;
|
||||
|
|
@ -19,6 +20,46 @@ use tracing::debug;
|
|||
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
|
||||
use self::telemetry::TelemetryPolicy;
|
||||
|
||||
/// Which per-route connection gauge a lease owns one increment on.
#[derive(Clone, Copy)]
enum RouteConnectionGauge {
    /// The `current_connections_direct` gauge.
    Direct,
    /// The `current_connections_me` (middle-route) gauge.
    Middle,
}
|
||||
|
||||
/// RAII guard over a single route-gauge increment: dropping the lease
/// releases the increment (see the `Drop` impl), so panics and task
/// aborts cannot leak gauge counts.
#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"]
pub struct RouteConnectionLease {
    // Owning Arc keeps Stats alive so Drop can always decrement.
    stats: Arc<Stats>,
    // Which gauge this lease incremented (and must release).
    gauge: RouteConnectionGauge,
    // When false (after disarm), Drop skips the decrement.
    active: bool,
}
|
||||
|
||||
impl RouteConnectionLease {
    /// Wrap an already-performed gauge increment. Callers (the
    /// `acquire_*_connection_lease` methods on `Stats`) increment the
    /// gauge first, then construct the lease to own that increment.
    fn new(stats: Arc<Stats>, gauge: RouteConnectionGauge) -> Self {
        Self {
            stats,
            gauge,
            active: true,
        }
    }

    /// Test-only: defuse the lease so its Drop performs no decrement.
    /// Used to prove a manually-released increment is not released twice.
    #[cfg(test)]
    fn disarm(&mut self) {
        self.active = false;
    }
}
|
||||
|
||||
impl Drop for RouteConnectionLease {
    fn drop(&mut self) {
        // A disarmed lease's increment was already released elsewhere.
        if !self.active {
            return;
        }
        // Release exactly the gauge this lease incremented; the decrement
        // helpers saturate at zero, so release can never underflow.
        match self.gauge {
            RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(),
            RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(),
        }
    }
}
|
||||
|
||||
// ============= Stats =============
|
||||
|
||||
#[derive(Default)]
|
||||
|
|
@ -285,6 +326,16 @@ impl Stats {
|
|||
/// Decrement the middle-route connection gauge, saturating at zero
/// (never wraps on a spurious extra decrement).
pub fn decrement_current_connections_me(&self) {
    Self::decrement_atomic_saturating(&self.current_connections_me);
}
|
||||
|
||||
/// Increment the direct-route gauge and return an RAII lease that
/// releases the increment when dropped. Prefer this over manual
/// increment/decrement pairs so unwinds cannot leak the gauge.
pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
    self.increment_current_connections_direct();
    RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct)
}
|
||||
|
||||
/// Increment the middle-route (ME) gauge and return an RAII lease that
/// releases the increment when dropped.
pub fn acquire_me_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
    self.increment_current_connections_me();
    RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle)
}
|
||||
pub fn increment_handshake_timeouts(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
||||
|
|
@ -1457,9 +1508,11 @@ impl Stats {
|
|||
// ============= Replay Checker =============
|
||||
|
||||
pub struct ReplayChecker {
|
||||
shards: Vec<Mutex<ReplayShard>>,
|
||||
handshake_shards: Vec<Mutex<ReplayShard>>,
|
||||
tls_shards: Vec<Mutex<ReplayShard>>,
|
||||
shard_mask: usize,
|
||||
window: Duration,
|
||||
tls_window: Duration,
|
||||
checks: AtomicU64,
|
||||
hits: AtomicU64,
|
||||
additions: AtomicU64,
|
||||
|
|
@ -1536,19 +1589,24 @@ impl ReplayShard {
|
|||
|
||||
impl ReplayChecker {
|
||||
pub fn new(total_capacity: usize, window: Duration) -> Self {
|
||||
const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120);
|
||||
let num_shards = 64;
|
||||
let shard_capacity = (total_capacity / num_shards).max(1);
|
||||
let cap = NonZeroUsize::new(shard_capacity).unwrap();
|
||||
|
||||
let mut shards = Vec::with_capacity(num_shards);
|
||||
let mut handshake_shards = Vec::with_capacity(num_shards);
|
||||
let mut tls_shards = Vec::with_capacity(num_shards);
|
||||
for _ in 0..num_shards {
|
||||
shards.push(Mutex::new(ReplayShard::new(cap)));
|
||||
handshake_shards.push(Mutex::new(ReplayShard::new(cap)));
|
||||
tls_shards.push(Mutex::new(ReplayShard::new(cap)));
|
||||
}
|
||||
|
||||
Self {
|
||||
shards,
|
||||
handshake_shards,
|
||||
tls_shards,
|
||||
shard_mask: num_shards - 1,
|
||||
window,
|
||||
tls_window: window.max(MIN_TLS_REPLAY_WINDOW),
|
||||
checks: AtomicU64::new(0),
|
||||
hits: AtomicU64::new(0),
|
||||
additions: AtomicU64::new(0),
|
||||
|
|
@ -1562,46 +1620,60 @@ impl ReplayChecker {
|
|||
(hasher.finish() as usize) & self.shard_mask
|
||||
}
|
||||
|
||||
/// Atomically check-then-record `data` in the given shard domain.
///
/// Returns `true` when `data` was already seen within `window` (replay
/// hit); otherwise records it and returns `false`. Check and add happen
/// under one shard lock, so two concurrent calls with the same key cannot
/// both observe "fresh".
fn check_and_add_internal(
    &self,
    data: &[u8],
    shards: &[Mutex<ReplayShard>],
    window: Duration,
) -> bool {
    self.checks.fetch_add(1, Ordering::Relaxed);
    let idx = self.get_shard_idx(data);
    let mut shard = shards[idx].lock();
    let now = Instant::now();
    let found = shard.check(data, now, window);
    if found {
        self.hits.fetch_add(1, Ordering::Relaxed);
    } else {
        shard.add(data, now, window);
        self.additions.fetch_add(1, Ordering::Relaxed);
    }
    found
}
|
||||
|
||||
/// Record `data` in the given shard domain without checking for a prior
/// occurrence. Compatibility path for callers using split check/add;
/// prefer the atomic `check_and_add_*` entry points.
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
    self.additions.fetch_add(1, Ordering::Relaxed);
    let idx = self.get_shard_idx(data);
    let mut shard = shards[idx].lock();
    shard.add(data, Instant::now(), window);
}
|
||||
|
||||
/// Replay-check a handshake key against the handshake shard domain,
/// using the configured window.
pub fn check_and_add_handshake(&self, data: &[u8]) -> bool {
    self.check_and_add_internal(data, &self.handshake_shards, self.window)
}
|
||||
|
||||
/// Replay-check a TLS digest against its own shard domain, using
/// `tls_window` (the configured window clamped up to the TLS minimum
/// in `ReplayChecker::new`).
pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool {
    self.check_and_add_internal(data, &self.tls_shards, self.tls_window)
}
|
||||
|
||||
// Compatibility helpers (non-atomic split operations) — prefer check_and_add_*.
/// Compat alias: checks AND records, same as check_and_add_handshake.
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) }
/// Compat: record a handshake key into the handshake domain only.
pub fn add_handshake(&self, data: &[u8]) {
    self.add_only(data, &self.handshake_shards, self.window)
}
/// Compat alias: checks AND records, same as check_and_add_tls_digest.
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) }
/// Compat: record a TLS digest into the TLS domain only.
pub fn add_tls_digest(&self, data: &[u8]) {
    self.add_only(data, &self.tls_shards, self.tls_window)
}
|
||||
|
||||
pub fn stats(&self) -> ReplayStats {
|
||||
let mut total_entries = 0;
|
||||
let mut total_queue_len = 0;
|
||||
for shard in &self.shards {
|
||||
for shard in &self.handshake_shards {
|
||||
let s = shard.lock();
|
||||
total_entries += s.cache.len();
|
||||
total_queue_len += s.queue.len();
|
||||
}
|
||||
for shard in &self.tls_shards {
|
||||
let s = shard.lock();
|
||||
total_entries += s.cache.len();
|
||||
total_queue_len += s.queue.len();
|
||||
|
|
@ -1614,7 +1686,7 @@ impl ReplayChecker {
|
|||
total_hits: self.hits.load(Ordering::Relaxed),
|
||||
total_additions: self.additions.load(Ordering::Relaxed),
|
||||
total_cleanups: self.cleanups.load(Ordering::Relaxed),
|
||||
num_shards: self.shards.len(),
|
||||
num_shards: self.handshake_shards.len() + self.tls_shards.len(),
|
||||
window_secs: self.window.as_secs(),
|
||||
}
|
||||
}
|
||||
|
|
@ -1632,13 +1704,20 @@ impl ReplayChecker {
|
|||
let now = Instant::now();
|
||||
let mut cleaned = 0usize;
|
||||
|
||||
for shard_mutex in &self.shards {
|
||||
for shard_mutex in &self.handshake_shards {
|
||||
let mut shard = shard_mutex.lock();
|
||||
let before = shard.len();
|
||||
shard.cleanup(now, self.window);
|
||||
let after = shard.len();
|
||||
cleaned += before.saturating_sub(after);
|
||||
}
|
||||
for shard_mutex in &self.tls_shards {
|
||||
let mut shard = shard_mutex.lock();
|
||||
let before = shard.len();
|
||||
shard.cleanup(now, self.tls_window);
|
||||
let after = shard.len();
|
||||
cleaned += before.saturating_sub(after);
|
||||
}
|
||||
|
||||
self.cleanups.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
|
|
@ -1764,7 +1843,7 @@ mod tests {
|
|||
fn test_replay_checker_many_keys() {
|
||||
let checker = ReplayChecker::new(10_000, Duration::from_secs(60));
|
||||
for i in 0..500u32 {
|
||||
checker.add_only(&i.to_le_bytes());
|
||||
checker.add_handshake(&i.to_le_bytes());
|
||||
}
|
||||
for i in 0..500u32 {
|
||||
assert!(checker.check_handshake(&i.to_le_bytes()));
|
||||
|
|
@ -1772,3 +1851,11 @@ mod tests {
|
|||
assert_eq!(checker.stats().total_entries, 500);
|
||||
}
|
||||
}
|
||||
|
||||
// Security-focused test suites live in sibling files but compile as child
// modules of this one so they can reach crate-private items.
#[cfg(test)]
#[path = "connection_lease_security_tests.rs"]
mod connection_lease_security_tests;

#[cfg(test)]
#[path = "replay_checker_security_tests.rs"]
mod replay_checker_security_tests;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
/// The same key bytes must be tracked independently per domain: a
/// handshake sighting must not pre-mark the TLS domain, and vice versa.
#[test]
fn replay_checker_keeps_tls_and_handshake_domains_isolated_for_same_key() {
    let checker = ReplayChecker::new(128, Duration::from_millis(20));
    let key = b"same-key-domain-separation";

    assert!(
        !checker.check_and_add_handshake(key),
        "first handshake use should be fresh"
    );
    assert!(
        !checker.check_and_add_tls_digest(key),
        "same bytes in TLS domain should still be fresh"
    );

    assert!(
        checker.check_and_add_handshake(key),
        "second handshake use should be replay-hit"
    );
    assert!(
        checker.check_and_add_tls_digest(key),
        "second TLS use should be replay-hit independently"
    );
}
|
||||
|
||||
/// With a tiny configured window (20ms), handshake entries may expire
/// quickly, but the TLS window is clamped up to a secure minimum and its
/// entries must still be replay-hits after the short sleep.
#[test]
fn replay_checker_tls_window_is_clamped_beyond_small_handshake_window() {
    let checker = ReplayChecker::new(128, Duration::from_millis(20));
    let handshake_key = b"short-window-handshake";
    let tls_key = b"short-window-tls";

    assert!(!checker.check_and_add_handshake(handshake_key));
    assert!(!checker.check_and_add_tls_digest(tls_key));

    // Well past the 20ms handshake window, far below the TLS minimum.
    std::thread::sleep(Duration::from_millis(80));

    assert!(
        !checker.check_and_add_handshake(handshake_key),
        "handshake key should expire under short configured window"
    );
    assert!(
        checker.check_and_add_tls_digest(tls_key),
        "TLS key should still be replay-hit because TLS window is clamped to a secure minimum"
    );
}
|
||||
|
||||
/// The add-only compat helpers must write into their own domain only:
/// an add_handshake must not make the TLS domain report a replay.
#[test]
fn replay_checker_compat_add_paths_do_not_cross_pollute_domains() {
    let checker = ReplayChecker::new(128, Duration::from_secs(1));
    let key = b"compat-domain-separation";

    checker.add_handshake(key);
    assert!(
        checker.check_and_add_handshake(key),
        "handshake add helper must populate handshake domain"
    );
    assert!(
        !checker.check_and_add_tls_digest(key),
        "handshake add helper must not pollute TLS domain"
    );

    checker.add_tls_digest(key);
    assert!(
        checker.check_and_add_tls_digest(key),
        "TLS add helper must populate TLS domain"
    );
}
|
||||
|
||||
/// stats() must count shards from both domains: 64 handshake + 64 TLS.
#[test]
fn replay_checker_stats_reflect_dual_shard_domains() {
    let checker = ReplayChecker::new(128, Duration::from_secs(1));
    let stats = checker.stats();

    assert_eq!(
        stats.num_shards, 128,
        "stats should expose both shard domains (handshake + TLS)"
    );
}
|
||||
|
|
@ -117,15 +117,6 @@ pub fn build_emulated_server_hello(
|
|||
extensions.extend_from_slice(&0x002bu16.to_be_bytes());
|
||||
extensions.extend_from_slice(&(2u16).to_be_bytes());
|
||||
extensions.extend_from_slice(&0x0304u16.to_be_bytes());
|
||||
if let Some(alpn_proto) = &alpn {
|
||||
extensions.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||
let list_len: u16 = 1 + alpn_proto.len() as u16;
|
||||
let ext_len: u16 = 2 + list_len;
|
||||
extensions.extend_from_slice(&ext_len.to_be_bytes());
|
||||
extensions.extend_from_slice(&list_len.to_be_bytes());
|
||||
extensions.push(alpn_proto.len() as u8);
|
||||
extensions.extend_from_slice(alpn_proto);
|
||||
}
|
||||
let extensions_len = extensions.len() as u16;
|
||||
|
||||
let body_len = 2 + // version
|
||||
|
|
@ -207,8 +198,22 @@ pub fn build_emulated_server_hello(
|
|||
}
|
||||
|
||||
let mut app_data = Vec::new();
|
||||
let alpn_marker = alpn
|
||||
.as_ref()
|
||||
.filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize)
|
||||
.map(|proto| {
|
||||
let proto_list_len = 1usize + proto.len();
|
||||
let ext_data_len = 2usize + proto_list_len;
|
||||
let mut marker = Vec::with_capacity(4 + ext_data_len);
|
||||
marker.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||
marker.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
|
||||
marker.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
|
||||
marker.push(proto.len() as u8);
|
||||
marker.extend_from_slice(proto);
|
||||
marker
|
||||
});
|
||||
let mut payload_offset = 0usize;
|
||||
for size in sizes {
|
||||
for (idx, size) in sizes.into_iter().enumerate() {
|
||||
let mut rec = Vec::with_capacity(5 + size);
|
||||
rec.push(TLS_RECORD_APPLICATION);
|
||||
rec.extend_from_slice(&TLS_VERSION);
|
||||
|
|
@ -233,7 +238,20 @@ pub fn build_emulated_server_hello(
|
|||
}
|
||||
} else if size > 17 {
|
||||
let body_len = size - 17;
|
||||
rec.extend_from_slice(&rng.bytes(body_len));
|
||||
let mut body = Vec::with_capacity(body_len);
|
||||
if idx == 0 && let Some(marker) = &alpn_marker {
|
||||
if marker.len() <= body_len {
|
||||
body.extend_from_slice(marker);
|
||||
if body_len > marker.len() {
|
||||
body.extend_from_slice(&rng.bytes(body_len - marker.len()));
|
||||
}
|
||||
} else {
|
||||
body.extend_from_slice(&rng.bytes(body_len));
|
||||
}
|
||||
} else {
|
||||
body.extend_from_slice(&rng.bytes(body_len));
|
||||
}
|
||||
rec.extend_from_slice(&body);
|
||||
rec.push(0x16); // inner content type marker (handshake)
|
||||
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag
|
||||
} else {
|
||||
|
|
@ -245,8 +263,9 @@ pub fn build_emulated_server_hello(
|
|||
// --- Combine ---
|
||||
// Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
|
||||
let mut tickets = Vec::new();
|
||||
if new_session_tickets > 0 {
|
||||
for _ in 0..new_session_tickets {
|
||||
let ticket_count = new_session_tickets.min(4);
|
||||
if ticket_count > 0 {
|
||||
for _ in 0..ticket_count {
|
||||
let ticket_len: usize = rng.range(48) + 48;
|
||||
let mut rec = Vec::with_capacity(5 + ticket_len);
|
||||
rec.push(TLS_RECORD_APPLICATION);
|
||||
|
|
@ -273,6 +292,10 @@ pub fn build_emulated_server_hello(
|
|||
response
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "emulator_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::SystemTime;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,136 @@
|
|||
use std::time::SystemTime;
|
||||
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE};
|
||||
use crate::tls_front::emulator::build_emulated_server_hello;
|
||||
use crate::tls_front::types::{
|
||||
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
|
||||
};
|
||||
|
||||
/// Build a minimal CachedTlsData fixture: fixed ServerHello template
/// bytes, one 64-byte application-data record, default behavior profile,
/// and an optional certificate payload supplied by the caller.
fn make_cached(cert_payload: Option<crate::tls_front::types::TlsCertPayload>) -> CachedTlsData {
    CachedTlsData {
        server_hello_template: ParsedServerHello {
            version: [0x03, 0x03],
            random: [0u8; 32],
            session_id: Vec::new(),
            cipher_suite: [0x13, 0x01],
            compression: 0,
            extensions: Vec::new(),
        },
        cert_info: None,
        cert_payload,
        app_data_records_sizes: vec![64],
        total_app_data_len: 64,
        behavior_profile: TlsBehaviorProfile {
            change_cipher_spec_count: 1,
            app_data_record_sizes: vec![64],
            ticket_record_sizes: Vec::new(),
            source: TlsProfileSource::Default,
        },
        fetched_at: SystemTime::now(),
        domain: "example.com".to_string(),
    }
}
|
||||
|
||||
/// Walk the emulated response's record framing — ServerHello, then
/// ChangeCipherSpec, then the first ApplicationData record — and return
/// that first application record's body.
///
/// Each TLS record is laid out as: type (1) + version (2) + length (2) + body.
fn first_app_data_payload(response: &[u8]) -> &[u8] {
    // Big-endian u16 body length at record header offset +3.
    let record_len = |start: usize| -> usize {
        u16::from_be_bytes([response[start + 3], response[start + 4]]) as usize
    };

    let hello_start = 0usize;
    let ccs_start = hello_start + 5 + record_len(hello_start);
    let app_start = ccs_start + 5 + record_len(ccs_start);
    let body_start = app_start + 5;

    &response[body_start..body_start + record_len(app_start)]
}
|
||||
|
||||
/// An ALPN protocol longer than 255 bytes cannot be length-prefixed with
/// one byte, so the emulator must drop it entirely rather than embed a
/// truncated/corrupt marker in the first application record.
#[test]
fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
    let cached = make_cached(None);
    let rng = SecureRandom::new();
    let oversized_alpn = vec![0xAB; u8::MAX as usize + 1];

    let response = build_emulated_server_hello(
        b"secret",
        &[0x11; 32],
        &[0x22; 16],
        &cached,
        true,
        &rng,
        Some(oversized_alpn),
        0,
    );

    // Sanity-check record framing: ServerHello, CCS, then app data.
    assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
    let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_start = 5 + hello_len;
    assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
    let app_start = ccs_start + 6;
    assert_eq!(response[app_start], TLS_RECORD_APPLICATION);

    let payload = first_app_data_payload(&response);
    // The prefix a (wrapping) marker for the 256-byte ALPN would have had.
    let mut marker_prefix = Vec::new();
    marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
    marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
    marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
    marker_prefix.push(0xff);
    marker_prefix.extend_from_slice(&[0xab; 8]);
    assert!(
        !payload.starts_with(&marker_prefix),
        "oversized ALPN must not be partially embedded into the emulated first application record"
    );
}
|
||||
|
||||
/// With a short ALPN ("h2") and room in the 64-byte first application
/// record, the complete ALPN extension marker must lead the record body.
#[test]
fn emulated_server_hello_embeds_full_alpn_marker_when_body_can_fit() {
    let cached = make_cached(None);
    let rng = SecureRandom::new();

    let response = build_emulated_server_hello(
        b"secret",
        &[0x31; 32],
        &[0x41; 16],
        &cached,
        true,
        &rng,
        Some(b"h2".to_vec()),
        0,
    );

    let payload = first_app_data_payload(&response);
    // 0x0010 ext type, ext len 5, list len 3, proto len 2, "h2".
    let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
    assert!(
        payload.starts_with(&expected),
        "when body has enough capacity, emulated first application record must include full ALPN marker"
    );
}
|
||||
|
||||
/// When a cached certificate payload exists, it takes the first record;
/// the ALPN marker must not displace the certificate bytes.
#[test]
fn emulated_server_hello_prefers_cert_payload_over_alpn_marker() {
    // Handshake-type 0x0b (Certificate) message with a tiny opaque body.
    let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
    let cached = make_cached(Some(TlsCertPayload {
        cert_chain_der: vec![vec![0x30, 0x01, 0x00]],
        certificate_message: cert_msg.clone(),
    }));
    let rng = SecureRandom::new();

    let response = build_emulated_server_hello(
        b"secret",
        &[0x32; 32],
        &[0x42; 16],
        &cached,
        true,
        &rng,
        Some(b"h2".to_vec()),
        0,
    );

    let payload = first_app_data_payload(&response);
    let alpn_marker = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];

    assert!(
        payload.starts_with(&cert_msg),
        "when certificate payload is available, first record must start with cert payload bytes"
    );
    assert!(
        !payload.starts_with(&alpn_marker),
        "ALPN marker must not displace selected certificate payload"
    );
}
|
||||
|
|
@ -25,6 +25,9 @@ const HEALTH_RECONNECT_BUDGET_PER_CORE: usize = 2;
|
|||
const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1;
|
||||
const HEALTH_RECONNECT_BUDGET_MIN: usize = 4;
|
||||
const HEALTH_RECONNECT_BUDGET_MAX: usize = 128;
|
||||
const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16;
|
||||
const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16;
|
||||
const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DcFloorPlanEntry {
|
||||
|
|
@ -111,7 +114,7 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
|
|||
}
|
||||
}
|
||||
|
||||
async fn reap_draining_writers(
|
||||
pub(super) async fn reap_draining_writers(
|
||||
pool: &Arc<MePool>,
|
||||
warn_next_allowed: &mut HashMap<u64, Instant>,
|
||||
) {
|
||||
|
|
@ -121,34 +124,57 @@ async fn reap_draining_writers(
|
|||
let drain_threshold = pool
|
||||
.me_pool_drain_threshold
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let writers = pool.writers.read().await.clone();
|
||||
let mut draining_writers = Vec::new();
|
||||
for writer in writers {
|
||||
let activity = pool.registry.writer_activity_snapshot().await;
|
||||
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
|
||||
let mut empty_writer_ids = Vec::<u64>::new();
|
||||
let mut force_close_writer_ids = Vec::<u64>::new();
|
||||
let writers = pool.writers.read().await;
|
||||
for writer in writers.iter() {
|
||||
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
continue;
|
||||
}
|
||||
let is_empty = pool.registry.is_writer_empty(writer.id).await;
|
||||
if is_empty {
|
||||
pool.remove_writer_and_close_clients(writer.id).await;
|
||||
if activity
|
||||
.bound_clients_by_writer
|
||||
.get(&writer.id)
|
||||
.copied()
|
||||
.unwrap_or(0)
|
||||
== 0
|
||||
{
|
||||
empty_writer_ids.push(writer.id);
|
||||
continue;
|
||||
}
|
||||
draining_writers.push(writer);
|
||||
draining_writers.push(DrainingWriterSnapshot {
|
||||
id: writer.id,
|
||||
writer_dc: writer.writer_dc,
|
||||
addr: writer.addr,
|
||||
generation: writer.generation,
|
||||
created_at: writer.created_at,
|
||||
draining_started_at_epoch_secs: writer
|
||||
.draining_started_at_epoch_secs
|
||||
.load(std::sync::atomic::Ordering::Relaxed),
|
||||
drain_deadline_epoch_secs: writer
|
||||
.drain_deadline_epoch_secs
|
||||
.load(std::sync::atomic::Ordering::Relaxed),
|
||||
allow_drain_fallback: writer
|
||||
.allow_drain_fallback
|
||||
.load(std::sync::atomic::Ordering::Relaxed),
|
||||
});
|
||||
}
|
||||
drop(writers);
|
||||
|
||||
if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
|
||||
let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
|
||||
draining_writers.len().saturating_sub(drain_threshold as usize)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if overflow > 0 {
|
||||
draining_writers.sort_by(|left, right| {
|
||||
let left_started = left
|
||||
.draining_started_at_epoch_secs
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let right_started = right
|
||||
.draining_started_at_epoch_secs
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
left_started
|
||||
.cmp(&right_started)
|
||||
left.draining_started_at_epoch_secs
|
||||
.cmp(&right.draining_started_at_epoch_secs)
|
||||
.then_with(|| left.created_at.cmp(&right.created_at))
|
||||
.then_with(|| left.id.cmp(&right.id))
|
||||
});
|
||||
let overflow = draining_writers.len().saturating_sub(drain_threshold as usize);
|
||||
warn!(
|
||||
draining_writers = draining_writers.len(),
|
||||
me_pool_drain_threshold = drain_threshold,
|
||||
|
|
@ -156,18 +182,14 @@ async fn reap_draining_writers(
|
|||
"ME draining writer threshold exceeded, force-closing oldest draining writers"
|
||||
);
|
||||
for writer in draining_writers.drain(..overflow) {
|
||||
pool.stats.increment_pool_force_close_total();
|
||||
pool.remove_writer_and_close_clients(writer.id).await;
|
||||
force_close_writer_ids.push(writer.id);
|
||||
}
|
||||
}
|
||||
|
||||
for writer in draining_writers {
|
||||
let drain_started_at_epoch_secs = writer
|
||||
.draining_started_at_epoch_secs
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
if drain_ttl_secs > 0
|
||||
&& drain_started_at_epoch_secs != 0
|
||||
&& now_epoch_secs.saturating_sub(drain_started_at_epoch_secs) > drain_ttl_secs
|
||||
&& writer.draining_started_at_epoch_secs != 0
|
||||
&& now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs
|
||||
&& should_emit_writer_warn(
|
||||
warn_next_allowed,
|
||||
writer.id,
|
||||
|
|
@ -182,19 +204,89 @@ async fn reap_draining_writers(
|
|||
generation = writer.generation,
|
||||
drain_ttl_secs,
|
||||
force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed),
|
||||
allow_drain_fallback = writer.allow_drain_fallback.load(std::sync::atomic::Ordering::Relaxed),
|
||||
allow_drain_fallback = writer.allow_drain_fallback,
|
||||
"ME draining writer remains non-empty past drain TTL"
|
||||
);
|
||||
}
|
||||
let deadline_epoch_secs = writer
|
||||
.drain_deadline_epoch_secs
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs {
|
||||
if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
|
||||
{
|
||||
warn!(writer_id = writer.id, "Drain timeout, force-closing");
|
||||
pool.stats.increment_pool_force_close_total();
|
||||
pool.remove_writer_and_close_clients(writer.id).await;
|
||||
force_close_writer_ids.push(writer.id);
|
||||
}
|
||||
}
|
||||
|
||||
let close_budget = health_drain_close_budget();
|
||||
let requested_force_close = force_close_writer_ids.len();
|
||||
let requested_empty_close = empty_writer_ids.len();
|
||||
let requested_close_total = requested_force_close.saturating_add(requested_empty_close);
|
||||
let mut closed_writer_ids = HashSet::<u64>::new();
|
||||
let mut closed_total = 0usize;
|
||||
for writer_id in force_close_writer_ids {
|
||||
if closed_total >= close_budget {
|
||||
break;
|
||||
}
|
||||
if !closed_writer_ids.insert(writer_id) {
|
||||
continue;
|
||||
}
|
||||
pool.stats.increment_pool_force_close_total();
|
||||
pool.remove_writer_and_close_clients(writer_id).await;
|
||||
closed_total = closed_total.saturating_add(1);
|
||||
}
|
||||
for writer_id in empty_writer_ids {
|
||||
if closed_total >= close_budget {
|
||||
break;
|
||||
}
|
||||
if !closed_writer_ids.insert(writer_id) {
|
||||
continue;
|
||||
}
|
||||
if !pool.remove_writer_if_empty(writer_id).await {
|
||||
continue;
|
||||
}
|
||||
closed_total = closed_total.saturating_add(1);
|
||||
}
|
||||
|
||||
let pending_close_total = requested_close_total.saturating_sub(closed_total);
|
||||
if pending_close_total > 0 {
|
||||
warn!(
|
||||
close_budget,
|
||||
closed_total,
|
||||
pending_close_total,
|
||||
"ME draining close backlog deferred to next health cycle"
|
||||
);
|
||||
}
|
||||
|
||||
// Keep warn cooldown state for draining writers still present in the pool;
|
||||
// drop state only once a writer is actually removed.
|
||||
let active_draining_writer_ids = {
|
||||
let writers = pool.writers.read().await;
|
||||
writers
|
||||
.iter()
|
||||
.filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed))
|
||||
.map(|writer| writer.id)
|
||||
.collect::<HashSet<u64>>()
|
||||
};
|
||||
warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id));
|
||||
}
|
||||
|
||||
pub(super) fn health_drain_close_budget() -> usize {
|
||||
let cpu_cores = std::thread::available_parallelism()
|
||||
.map(std::num::NonZeroUsize::get)
|
||||
.unwrap_or(1);
|
||||
cpu_cores
|
||||
.saturating_mul(HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE)
|
||||
.clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DrainingWriterSnapshot {
|
||||
id: u64,
|
||||
writer_dc: i32,
|
||||
addr: SocketAddr,
|
||||
generation: u64,
|
||||
created_at: Instant,
|
||||
draining_started_at_epoch_secs: u64,
|
||||
drain_deadline_epoch_secs: u64,
|
||||
allow_drain_fallback: bool,
|
||||
}
|
||||
|
||||
fn should_emit_writer_warn(
|
||||
|
|
@ -1330,6 +1422,15 @@ mod tests {
|
|||
me_pool_drain_threshold,
|
||||
..GeneralConfig::default()
|
||||
};
|
||||
let mut proxy_map_v4 = HashMap::new();
|
||||
proxy_map_v4.insert(
|
||||
2,
|
||||
vec![(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 443)],
|
||||
);
|
||||
let decision = NetworkDecision {
|
||||
ipv4_me: true,
|
||||
..NetworkDecision::default()
|
||||
};
|
||||
MePool::new(
|
||||
None,
|
||||
vec![1u8; 32],
|
||||
|
|
@ -1341,10 +1442,10 @@ mod tests {
|
|||
None,
|
||||
12,
|
||||
1200,
|
||||
HashMap::new(),
|
||||
proxy_map_v4,
|
||||
HashMap::new(),
|
||||
None,
|
||||
NetworkDecision::default(),
|
||||
decision,
|
||||
None,
|
||||
Arc::new(SecureRandom::new()),
|
||||
Arc::new(Stats::default()),
|
||||
|
|
@ -1455,8 +1556,55 @@ mod tests {
|
|||
conn_id
|
||||
}
|
||||
|
||||
async fn insert_live_writer(pool: &Arc<MePool>, writer_id: u64, writer_dc: i32) {
|
||||
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||
let writer = MeWriter {
|
||||
id: writer_id,
|
||||
addr: SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::new(203, 0, 113, (writer_id as u8).saturating_add(1))),
|
||||
4000 + writer_id as u16,
|
||||
),
|
||||
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||
writer_dc,
|
||||
generation: 2,
|
||||
contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())),
|
||||
created_at: Instant::now(),
|
||||
tx: tx.clone(),
|
||||
cancel: CancellationToken::new(),
|
||||
degraded: Arc::new(AtomicBool::new(false)),
|
||||
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||
draining: Arc::new(AtomicBool::new(false)),
|
||||
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
pool.writers.write().await.push(writer);
|
||||
pool.registry.register_writer(writer_id, tx).await;
|
||||
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
|
||||
let pool = make_pool(2).await;
|
||||
insert_live_writer(&pool, 1, 2).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
||||
let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await;
|
||||
let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
||||
writer_ids.sort_unstable();
|
||||
assert_eq!(writer_ids, vec![1, 20, 30]);
|
||||
assert!(pool.registry.get_writer(conn_a).await.is_none());
|
||||
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
||||
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_force_closes_overflow_without_replacement() {
|
||||
let pool = make_pool(2).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
||||
|
|
@ -1466,7 +1614,8 @@ mod tests {
|
|||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
let writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
||||
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
||||
writer_ids.sort_unstable();
|
||||
assert_eq!(writer_ids, vec![20, 30]);
|
||||
assert!(pool.registry.get_writer(conn_a).await.is_none());
|
||||
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,615 @@
|
|||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::codec::WriterCommand;
|
||||
use super::health::{health_drain_close_budget, reap_draining_writers};
|
||||
use super::pool::{MePool, MeWriter, WriterContour};
|
||||
use super::registry::ConnMeta;
|
||||
use super::me_health_monitor;
|
||||
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::network::probe::NetworkDecision;
|
||||
use crate::stats::Stats;
|
||||
|
||||
async fn make_pool(
|
||||
me_pool_drain_threshold: u64,
|
||||
me_health_interval_ms_unhealthy: u64,
|
||||
me_health_interval_ms_healthy: u64,
|
||||
) -> (Arc<MePool>, Arc<SecureRandom>) {
|
||||
let general = GeneralConfig {
|
||||
me_pool_drain_threshold,
|
||||
me_health_interval_ms_unhealthy,
|
||||
me_health_interval_ms_healthy,
|
||||
..GeneralConfig::default()
|
||||
};
|
||||
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let pool = MePool::new(
|
||||
None,
|
||||
vec![1u8; 32],
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
1200,
|
||||
HashMap::new(),
|
||||
HashMap::new(),
|
||||
None,
|
||||
NetworkDecision::default(),
|
||||
None,
|
||||
rng.clone(),
|
||||
Arc::new(Stats::default()),
|
||||
general.me_keepalive_enabled,
|
||||
general.me_keepalive_interval_secs,
|
||||
general.me_keepalive_jitter_secs,
|
||||
general.me_keepalive_payload_random,
|
||||
general.rpc_proxy_req_every,
|
||||
general.me_warmup_stagger_enabled,
|
||||
general.me_warmup_step_delay_ms,
|
||||
general.me_warmup_step_jitter_ms,
|
||||
general.me_reconnect_max_concurrent_per_dc,
|
||||
general.me_reconnect_backoff_base_ms,
|
||||
general.me_reconnect_backoff_cap_ms,
|
||||
general.me_reconnect_fast_retry_count,
|
||||
general.me_single_endpoint_shadow_writers,
|
||||
general.me_single_endpoint_outage_mode_enabled,
|
||||
general.me_single_endpoint_outage_disable_quarantine,
|
||||
general.me_single_endpoint_outage_backoff_min_ms,
|
||||
general.me_single_endpoint_outage_backoff_max_ms,
|
||||
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||
general.me_floor_mode,
|
||||
general.me_adaptive_floor_idle_secs,
|
||||
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||
general.me_adaptive_floor_recover_grace_secs,
|
||||
general.me_adaptive_floor_writers_per_core_total,
|
||||
general.me_adaptive_floor_cpu_cores_override,
|
||||
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||
general.me_adaptive_floor_max_active_writers_per_core,
|
||||
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||
general.me_adaptive_floor_max_active_writers_global,
|
||||
general.me_adaptive_floor_max_warm_writers_global,
|
||||
general.hardswap,
|
||||
general.me_pool_drain_ttl_secs,
|
||||
general.me_pool_drain_threshold,
|
||||
general.effective_me_pool_force_close_secs(),
|
||||
general.me_pool_min_fresh_ratio,
|
||||
general.me_hardswap_warmup_delay_min_ms,
|
||||
general.me_hardswap_warmup_delay_max_ms,
|
||||
general.me_hardswap_warmup_extra_passes,
|
||||
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||
general.me_bind_stale_mode,
|
||||
general.me_bind_stale_ttl_secs,
|
||||
general.me_secret_atomic_snapshot,
|
||||
general.me_deterministic_writer_sort,
|
||||
MeWriterPickMode::default(),
|
||||
general.me_writer_pick_sample_size,
|
||||
MeSocksKdfPolicy::default(),
|
||||
general.me_writer_cmd_channel_capacity,
|
||||
general.me_route_channel_capacity,
|
||||
general.me_route_backpressure_base_timeout_ms,
|
||||
general.me_route_backpressure_high_timeout_ms,
|
||||
general.me_route_backpressure_high_watermark_pct,
|
||||
general.me_reader_route_data_wait_ms,
|
||||
general.me_health_interval_ms_unhealthy,
|
||||
general.me_health_interval_ms_healthy,
|
||||
general.me_warn_rate_limit_ms,
|
||||
MeRouteNoWriterMode::default(),
|
||||
general.me_route_no_writer_wait_ms,
|
||||
general.me_route_inline_recovery_attempts,
|
||||
general.me_route_inline_recovery_wait_ms,
|
||||
);
|
||||
|
||||
(pool, rng)
|
||||
}
|
||||
|
||||
async fn insert_draining_writer(
|
||||
pool: &Arc<MePool>,
|
||||
writer_id: u64,
|
||||
drain_started_at_epoch_secs: u64,
|
||||
bound_clients: usize,
|
||||
drain_deadline_epoch_secs: u64,
|
||||
) {
|
||||
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||
let writer = MeWriter {
|
||||
id: writer_id,
|
||||
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6000 + writer_id as u16),
|
||||
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||
writer_dc: 2,
|
||||
generation: 1,
|
||||
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
|
||||
created_at: Instant::now() - Duration::from_secs(writer_id),
|
||||
tx: tx.clone(),
|
||||
cancel: CancellationToken::new(),
|
||||
degraded: Arc::new(AtomicBool::new(false)),
|
||||
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||
draining: Arc::new(AtomicBool::new(true)),
|
||||
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
|
||||
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
|
||||
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
pool.writers.write().await.push(writer);
|
||||
pool.registry.register_writer(writer_id, tx).await;
|
||||
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
for idx in 0..bound_clients {
|
||||
let (conn_id, _rx) = pool.registry.register().await;
|
||||
assert!(
|
||||
pool.registry
|
||||
.bind_writer(
|
||||
conn_id,
|
||||
writer_id,
|
||||
ConnMeta {
|
||||
target_dc: 2,
|
||||
client_addr: SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||
8000 + idx as u16,
|
||||
),
|
||||
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||
proto_flags: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn writer_count(pool: &Arc<MePool>) -> usize {
|
||||
pool.writers.read().await.len()
|
||||
}
|
||||
|
||||
async fn sorted_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
|
||||
let mut ids = pool
|
||||
.writers
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|writer| writer.id)
|
||||
.collect::<Vec<_>>();
|
||||
ids.sort_unstable();
|
||||
ids
|
||||
}
|
||||
|
||||
fn lcg_next(state: &mut u64) -> u64 {
|
||||
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
*state
|
||||
}
|
||||
|
||||
async fn draining_writer_ids(pool: &Arc<MePool>) -> HashSet<u64> {
|
||||
pool.writers
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.filter(|writer| writer.draining.load(Ordering::Relaxed))
|
||||
.map(|writer| writer.id)
|
||||
.collect::<HashSet<u64>>()
|
||||
}
|
||||
|
||||
async fn set_writer_runtime_state(
|
||||
pool: &Arc<MePool>,
|
||||
writer_id: u64,
|
||||
draining: bool,
|
||||
drain_started_at_epoch_secs: u64,
|
||||
drain_deadline_epoch_secs: u64,
|
||||
) {
|
||||
let writers = pool.writers.read().await;
|
||||
if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
|
||||
writer.draining.store(draining, Ordering::Relaxed);
|
||||
writer
|
||||
.draining_started_at_epoch_secs
|
||||
.store(drain_started_at_epoch_secs, Ordering::Relaxed);
|
||||
writer
|
||||
.drain_deadline_epoch_secs
|
||||
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_clears_warn_state_when_pool_empty() {
|
||||
let (pool, _rng) = make_pool(128, 1, 1).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
warn_next_allowed.insert(11, Instant::now() + Duration::from_secs(5));
|
||||
warn_next_allowed.insert(22, Instant::now() + Duration::from_secs(5));
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert!(warn_next_allowed.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_respects_threshold_across_multiple_overflow_cycles() {
|
||||
let threshold = 3u64;
|
||||
let (pool, _rng) = make_pool(threshold, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
|
||||
for writer_id in 1..=60u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(600).saturating_add(writer_id),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
for _ in 0..64 {
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
if writer_count(&pool).await <= threshold as usize {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(writer_count(&pool).await, threshold as usize);
|
||||
assert_eq!(sorted_writer_ids(&pool).await, vec![58, 59, 60]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_handles_large_empty_writer_population() {
|
||||
let (pool, _rng) = make_pool(128, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let total = health_drain_close_budget().saturating_mul(3).saturating_add(27);
|
||||
|
||||
for writer_id in 1..=total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(120),
|
||||
0,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
for _ in 0..24 {
|
||||
if writer_count(&pool).await == 0 {
|
||||
break;
|
||||
}
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
}
|
||||
|
||||
assert_eq!(writer_count(&pool).await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_processes_mass_deadline_expiry_without_unbounded_growth() {
|
||||
let (pool, _rng) = make_pool(128, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let total = health_drain_close_budget().saturating_mul(4).saturating_add(31);
|
||||
|
||||
for writer_id in 1..=total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(180),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
for _ in 0..40 {
|
||||
if writer_count(&pool).await == 0 {
|
||||
break;
|
||||
}
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
}
|
||||
|
||||
assert_eq!(writer_count(&pool).await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_churn() {
|
||||
let (pool, _rng) = make_pool(128, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
for wave in 0..40u64 {
|
||||
for offset in 0..8u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
wave * 100 + offset,
|
||||
now_epoch_secs.saturating_sub(400 + offset),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(warn_next_allowed.len() <= writer_count(&pool).await);
|
||||
|
||||
let ids = sorted_writer_ids(&pool).await;
|
||||
for writer_id in ids.into_iter().take(3) {
|
||||
let _ = pool.remove_writer_and_close_clients(writer_id).await;
|
||||
}
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(warn_next_allowed.len() <= writer_count(&pool).await);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_budgeted_cleanup_never_increases_pool_size() {
|
||||
let (pool, _rng) = make_pool(5, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
|
||||
for writer_id in 1..=200u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(240).saturating_add(writer_id),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
let mut previous = writer_count(&pool).await;
|
||||
for _ in 0..32 {
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
let current = writer_count(&pool).await;
|
||||
assert!(current <= previous);
|
||||
previous = current;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn me_health_monitor_converges_to_threshold_under_live_injection_churn() {
|
||||
let threshold = 7u64;
|
||||
let (pool, rng) = make_pool(threshold, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
|
||||
for writer_id in 1..=40u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||
|
||||
for wave in 0..8u64 {
|
||||
for offset in 0..10u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
1000 + wave * 100 + offset,
|
||||
now_epoch_secs.saturating_sub(120).saturating_add(offset),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(120)).await;
|
||||
monitor.abort();
|
||||
let _ = monitor.await;
|
||||
|
||||
assert!(writer_count(&pool).await <= threshold as usize);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn me_health_monitor_drains_deadline_storm_with_budgeted_progress() {
|
||||
let (pool, rng) = make_pool(128, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
|
||||
for writer_id in 1..=220u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(120),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||
tokio::time::sleep(Duration::from_millis(120)).await;
|
||||
monitor.abort();
|
||||
let _ = monitor.await;
|
||||
|
||||
assert_eq!(writer_count(&pool).await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() {
|
||||
let threshold = 12u64;
|
||||
let (pool, rng) = make_pool(threshold, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
|
||||
for writer_id in 1..=180u64 {
|
||||
let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
|
||||
let deadline = if writer_id % 2 == 0 {
|
||||
now_epoch_secs.saturating_sub(1)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(250).saturating_add(writer_id),
|
||||
bound_clients,
|
||||
deadline,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||
tokio::time::sleep(Duration::from_millis(140)).await;
|
||||
monitor.abort();
|
||||
let _ = monitor.await;
|
||||
|
||||
assert!(writer_count(&pool).await <= threshold as usize);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() {
|
||||
let threshold = 9u64;
|
||||
let (pool, _rng) = make_pool(threshold, 1, 1).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
let mut seed = 0x9E37_79B9_7F4A_7C15u64;
|
||||
let mut next_writer_id = 20_000u64;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
|
||||
for writer_id in 1..=72u64 {
|
||||
let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 };
|
||||
let deadline = if writer_id % 5 == 0 {
|
||||
now_epoch_secs.saturating_sub(1)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(500).saturating_add(writer_id),
|
||||
bound_clients,
|
||||
deadline,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
for _round in 0..90 {
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
let draining_ids = draining_writer_ids(&pool).await;
|
||||
assert!(
|
||||
warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
|
||||
"warn-state keys must always be a subset of live draining writers"
|
||||
);
|
||||
|
||||
let writer_ids = sorted_writer_ids(&pool).await;
|
||||
if writer_ids.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let remove_n = (lcg_next(&mut seed) % 3) as usize;
|
||||
for writer_id in writer_ids.iter().copied().take(remove_n) {
|
||||
let _ = pool.remove_writer_and_close_clients(writer_id).await;
|
||||
}
|
||||
|
||||
let survivors = sorted_writer_ids(&pool).await;
|
||||
if !survivors.is_empty() {
|
||||
let idx = (lcg_next(&mut seed) as usize) % survivors.len();
|
||||
let target = survivors[idx];
|
||||
set_writer_runtime_state(&pool, target, false, 0, 0).await;
|
||||
}
|
||||
|
||||
let survivors = sorted_writer_ids(&pool).await;
|
||||
if survivors.len() > 1 {
|
||||
let idx = (lcg_next(&mut seed) as usize) % survivors.len();
|
||||
let target = survivors[idx];
|
||||
let expired_deadline = if lcg_next(&mut seed) & 1 == 0 {
|
||||
now_epoch_secs.saturating_sub(1)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
set_writer_runtime_state(
|
||||
&pool,
|
||||
target,
|
||||
true,
|
||||
now_epoch_secs.saturating_sub(120),
|
||||
expired_deadline,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let inject_n = (lcg_next(&mut seed) % 4) as usize;
|
||||
for _ in 0..inject_n {
|
||||
let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 };
|
||||
let deadline = if lcg_next(&mut seed) & 1 == 0 {
|
||||
now_epoch_secs.saturating_sub(1)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
next_writer_id,
|
||||
now_epoch_secs.saturating_sub(240),
|
||||
bound_clients,
|
||||
deadline,
|
||||
)
|
||||
.await;
|
||||
next_writer_id = next_writer_id.saturating_add(1);
|
||||
}
|
||||
}
|
||||
|
||||
for _ in 0..64 {
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
if writer_count(&pool).await <= threshold as usize {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(writer_count(&pool).await <= threshold as usize);
|
||||
let draining_ids = draining_writer_ids(&pool).await;
|
||||
assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() {
|
||||
let (pool, _rng) = make_pool(64, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
|
||||
for writer_id in 1..=24u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(240),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
for _round in 0..48u64 {
|
||||
for writer_id in 1..=24u64 {
|
||||
let draining = (writer_id + _round) % 3 != 0;
|
||||
set_writer_runtime_state(
|
||||
&pool,
|
||||
writer_id,
|
||||
draining,
|
||||
now_epoch_secs.saturating_sub(120),
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
let draining_ids = draining_writer_ids(&pool).await;
|
||||
assert!(
|
||||
warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
|
||||
"warn-state map must not retain entries for writers outside draining set"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn health_drain_close_budget_is_within_expected_bounds() {
|
||||
let budget = health_drain_close_budget();
|
||||
assert!((16..=256).contains(&budget));
|
||||
}
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::codec::WriterCommand;
|
||||
use super::health::health_drain_close_budget;
|
||||
use super::pool::{MePool, MeWriter, WriterContour};
|
||||
use super::registry::ConnMeta;
|
||||
use super::me_health_monitor;
|
||||
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::network::probe::NetworkDecision;
|
||||
use crate::stats::Stats;
|
||||
|
||||
async fn make_pool(
|
||||
me_pool_drain_threshold: u64,
|
||||
me_health_interval_ms_unhealthy: u64,
|
||||
me_health_interval_ms_healthy: u64,
|
||||
) -> (Arc<MePool>, Arc<SecureRandom>) {
|
||||
let general = GeneralConfig {
|
||||
me_pool_drain_threshold,
|
||||
me_health_interval_ms_unhealthy,
|
||||
me_health_interval_ms_healthy,
|
||||
..GeneralConfig::default()
|
||||
};
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let pool = MePool::new(
|
||||
None,
|
||||
vec![1u8; 32],
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
1200,
|
||||
HashMap::new(),
|
||||
HashMap::new(),
|
||||
None,
|
||||
NetworkDecision::default(),
|
||||
None,
|
||||
rng.clone(),
|
||||
Arc::new(Stats::default()),
|
||||
general.me_keepalive_enabled,
|
||||
general.me_keepalive_interval_secs,
|
||||
general.me_keepalive_jitter_secs,
|
||||
general.me_keepalive_payload_random,
|
||||
general.rpc_proxy_req_every,
|
||||
general.me_warmup_stagger_enabled,
|
||||
general.me_warmup_step_delay_ms,
|
||||
general.me_warmup_step_jitter_ms,
|
||||
general.me_reconnect_max_concurrent_per_dc,
|
||||
general.me_reconnect_backoff_base_ms,
|
||||
general.me_reconnect_backoff_cap_ms,
|
||||
general.me_reconnect_fast_retry_count,
|
||||
general.me_single_endpoint_shadow_writers,
|
||||
general.me_single_endpoint_outage_mode_enabled,
|
||||
general.me_single_endpoint_outage_disable_quarantine,
|
||||
general.me_single_endpoint_outage_backoff_min_ms,
|
||||
general.me_single_endpoint_outage_backoff_max_ms,
|
||||
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||
general.me_floor_mode,
|
||||
general.me_adaptive_floor_idle_secs,
|
||||
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||
general.me_adaptive_floor_recover_grace_secs,
|
||||
general.me_adaptive_floor_writers_per_core_total,
|
||||
general.me_adaptive_floor_cpu_cores_override,
|
||||
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||
general.me_adaptive_floor_max_active_writers_per_core,
|
||||
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||
general.me_adaptive_floor_max_active_writers_global,
|
||||
general.me_adaptive_floor_max_warm_writers_global,
|
||||
general.hardswap,
|
||||
general.me_pool_drain_ttl_secs,
|
||||
general.me_pool_drain_threshold,
|
||||
general.effective_me_pool_force_close_secs(),
|
||||
general.me_pool_min_fresh_ratio,
|
||||
general.me_hardswap_warmup_delay_min_ms,
|
||||
general.me_hardswap_warmup_delay_max_ms,
|
||||
general.me_hardswap_warmup_extra_passes,
|
||||
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||
general.me_bind_stale_mode,
|
||||
general.me_bind_stale_ttl_secs,
|
||||
general.me_secret_atomic_snapshot,
|
||||
general.me_deterministic_writer_sort,
|
||||
MeWriterPickMode::default(),
|
||||
general.me_writer_pick_sample_size,
|
||||
MeSocksKdfPolicy::default(),
|
||||
general.me_writer_cmd_channel_capacity,
|
||||
general.me_route_channel_capacity,
|
||||
general.me_route_backpressure_base_timeout_ms,
|
||||
general.me_route_backpressure_high_timeout_ms,
|
||||
general.me_route_backpressure_high_watermark_pct,
|
||||
general.me_reader_route_data_wait_ms,
|
||||
general.me_health_interval_ms_unhealthy,
|
||||
general.me_health_interval_ms_healthy,
|
||||
general.me_warn_rate_limit_ms,
|
||||
MeRouteNoWriterMode::default(),
|
||||
general.me_route_no_writer_wait_ms,
|
||||
general.me_route_inline_recovery_attempts,
|
||||
general.me_route_inline_recovery_wait_ms,
|
||||
);
|
||||
(pool, rng)
|
||||
}
|
||||
|
||||
async fn insert_draining_writer(
|
||||
pool: &Arc<MePool>,
|
||||
writer_id: u64,
|
||||
drain_started_at_epoch_secs: u64,
|
||||
bound_clients: usize,
|
||||
drain_deadline_epoch_secs: u64,
|
||||
) {
|
||||
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||
let writer = MeWriter {
|
||||
id: writer_id,
|
||||
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5500 + writer_id as u16),
|
||||
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||
writer_dc: 2,
|
||||
generation: 1,
|
||||
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
|
||||
created_at: Instant::now() - Duration::from_secs(writer_id),
|
||||
tx: tx.clone(),
|
||||
cancel: CancellationToken::new(),
|
||||
degraded: Arc::new(AtomicBool::new(false)),
|
||||
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||
draining: Arc::new(AtomicBool::new(true)),
|
||||
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
|
||||
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
|
||||
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
pool.writers.write().await.push(writer);
|
||||
pool.registry.register_writer(writer_id, tx).await;
|
||||
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||
for idx in 0..bound_clients {
|
||||
let (conn_id, _rx) = pool.registry.register().await;
|
||||
assert!(
|
||||
pool.registry
|
||||
.bind_writer(
|
||||
conn_id,
|
||||
writer_id,
|
||||
ConnMeta {
|
||||
target_dc: 2,
|
||||
client_addr: SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||
7200 + idx as u16,
|
||||
),
|
||||
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||
proto_flags: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_pool_empty(pool: &Arc<MePool>, timeout: Duration) {
|
||||
let start = Instant::now();
|
||||
loop {
|
||||
if pool.writers.read().await.is_empty() {
|
||||
return;
|
||||
}
|
||||
assert!(
|
||||
start.elapsed() < timeout,
|
||||
"timed out waiting for pool.writers to become empty"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() {
|
||||
let (pool, rng) = make_pool(128, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let writer_total = health_drain_close_budget().saturating_mul(2).saturating_add(9);
|
||||
for writer_id in 1..=writer_total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(120),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
|
||||
monitor.abort();
|
||||
let _ = monitor.await;
|
||||
|
||||
assert!(pool.writers.read().await.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() {
|
||||
let (pool, rng) = make_pool(128, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
for writer_id in 1..=24u64 {
|
||||
insert_draining_writer(&pool, writer_id, now_epoch_secs.saturating_sub(60), 0, 0).await;
|
||||
}
|
||||
|
||||
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
|
||||
monitor.abort();
|
||||
let _ = monitor.await;
|
||||
|
||||
assert!(pool.writers.read().await.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() {
|
||||
let threshold = 4u64;
|
||||
let (pool, rng) = make_pool(threshold, 1, 1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let writer_total = threshold as usize + health_drain_close_budget().saturating_add(11);
|
||||
for writer_id in 1..=writer_total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
|
||||
monitor.abort();
|
||||
let _ = monitor.await;
|
||||
|
||||
assert!(pool.writers.read().await.is_empty());
|
||||
}
|
||||
|
|
@ -0,0 +1,658 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::codec::WriterCommand;
|
||||
use super::health::{health_drain_close_budget, reap_draining_writers};
|
||||
use super::pool::{MePool, MeWriter, WriterContour};
|
||||
use super::registry::ConnMeta;
|
||||
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::network::probe::NetworkDecision;
|
||||
use crate::stats::Stats;
|
||||
|
||||
async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
|
||||
let general = GeneralConfig {
|
||||
me_pool_drain_threshold,
|
||||
..GeneralConfig::default()
|
||||
};
|
||||
|
||||
MePool::new(
|
||||
None,
|
||||
vec![1u8; 32],
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
1200,
|
||||
HashMap::new(),
|
||||
HashMap::new(),
|
||||
None,
|
||||
NetworkDecision::default(),
|
||||
None,
|
||||
Arc::new(SecureRandom::new()),
|
||||
Arc::new(Stats::default()),
|
||||
general.me_keepalive_enabled,
|
||||
general.me_keepalive_interval_secs,
|
||||
general.me_keepalive_jitter_secs,
|
||||
general.me_keepalive_payload_random,
|
||||
general.rpc_proxy_req_every,
|
||||
general.me_warmup_stagger_enabled,
|
||||
general.me_warmup_step_delay_ms,
|
||||
general.me_warmup_step_jitter_ms,
|
||||
general.me_reconnect_max_concurrent_per_dc,
|
||||
general.me_reconnect_backoff_base_ms,
|
||||
general.me_reconnect_backoff_cap_ms,
|
||||
general.me_reconnect_fast_retry_count,
|
||||
general.me_single_endpoint_shadow_writers,
|
||||
general.me_single_endpoint_outage_mode_enabled,
|
||||
general.me_single_endpoint_outage_disable_quarantine,
|
||||
general.me_single_endpoint_outage_backoff_min_ms,
|
||||
general.me_single_endpoint_outage_backoff_max_ms,
|
||||
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||
general.me_floor_mode,
|
||||
general.me_adaptive_floor_idle_secs,
|
||||
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||
general.me_adaptive_floor_recover_grace_secs,
|
||||
general.me_adaptive_floor_writers_per_core_total,
|
||||
general.me_adaptive_floor_cpu_cores_override,
|
||||
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||
general.me_adaptive_floor_max_active_writers_per_core,
|
||||
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||
general.me_adaptive_floor_max_active_writers_global,
|
||||
general.me_adaptive_floor_max_warm_writers_global,
|
||||
general.hardswap,
|
||||
general.me_pool_drain_ttl_secs,
|
||||
general.me_pool_drain_threshold,
|
||||
general.effective_me_pool_force_close_secs(),
|
||||
general.me_pool_min_fresh_ratio,
|
||||
general.me_hardswap_warmup_delay_min_ms,
|
||||
general.me_hardswap_warmup_delay_max_ms,
|
||||
general.me_hardswap_warmup_extra_passes,
|
||||
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||
general.me_bind_stale_mode,
|
||||
general.me_bind_stale_ttl_secs,
|
||||
general.me_secret_atomic_snapshot,
|
||||
general.me_deterministic_writer_sort,
|
||||
MeWriterPickMode::default(),
|
||||
general.me_writer_pick_sample_size,
|
||||
MeSocksKdfPolicy::default(),
|
||||
general.me_writer_cmd_channel_capacity,
|
||||
general.me_route_channel_capacity,
|
||||
general.me_route_backpressure_base_timeout_ms,
|
||||
general.me_route_backpressure_high_timeout_ms,
|
||||
general.me_route_backpressure_high_watermark_pct,
|
||||
general.me_reader_route_data_wait_ms,
|
||||
general.me_health_interval_ms_unhealthy,
|
||||
general.me_health_interval_ms_healthy,
|
||||
general.me_warn_rate_limit_ms,
|
||||
MeRouteNoWriterMode::default(),
|
||||
general.me_route_no_writer_wait_ms,
|
||||
general.me_route_inline_recovery_attempts,
|
||||
general.me_route_inline_recovery_wait_ms,
|
||||
)
|
||||
}
|
||||
|
||||
async fn insert_draining_writer(
|
||||
pool: &Arc<MePool>,
|
||||
writer_id: u64,
|
||||
drain_started_at_epoch_secs: u64,
|
||||
bound_clients: usize,
|
||||
drain_deadline_epoch_secs: u64,
|
||||
) -> Vec<u64> {
|
||||
let mut conn_ids = Vec::with_capacity(bound_clients);
|
||||
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||
let writer = MeWriter {
|
||||
id: writer_id,
|
||||
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4500 + writer_id as u16),
|
||||
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||
writer_dc: 2,
|
||||
generation: 1,
|
||||
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
|
||||
created_at: Instant::now() - Duration::from_secs(writer_id),
|
||||
tx: tx.clone(),
|
||||
cancel: CancellationToken::new(),
|
||||
degraded: Arc::new(AtomicBool::new(false)),
|
||||
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||
draining: Arc::new(AtomicBool::new(true)),
|
||||
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
|
||||
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
|
||||
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
pool.writers.write().await.push(writer);
|
||||
pool.registry.register_writer(writer_id, tx).await;
|
||||
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||
for idx in 0..bound_clients {
|
||||
let (conn_id, _rx) = pool.registry.register().await;
|
||||
assert!(
|
||||
pool.registry
|
||||
.bind_writer(
|
||||
conn_id,
|
||||
writer_id,
|
||||
ConnMeta {
|
||||
target_dc: 2,
|
||||
client_addr: SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||
6200 + idx as u16,
|
||||
),
|
||||
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||
proto_flags: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
);
|
||||
conn_ids.push(conn_id);
|
||||
}
|
||||
conn_ids
|
||||
}
|
||||
|
||||
async fn current_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
|
||||
let mut writer_ids = pool
|
||||
.writers
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|writer| writer.id)
|
||||
.collect::<Vec<_>>();
|
||||
writer_ids.sort_unstable();
|
||||
writer_ids
|
||||
}
|
||||
|
||||
async fn writer_exists(pool: &Arc<MePool>, writer_id: u64) -> bool {
|
||||
pool.writers
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.any(|writer| writer.id == writer_id)
|
||||
}
|
||||
|
||||
async fn set_writer_draining(pool: &Arc<MePool>, writer_id: u64, draining: bool) {
|
||||
let writers = pool.writers.read().await;
|
||||
if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
|
||||
writer.draining.store(draining, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_drops_warn_state_for_removed_writer() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let conn_ids =
|
||||
insert_draining_writer(&pool, 7, now_epoch_secs.saturating_sub(180), 1, 0).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(warn_next_allowed.contains_key(&7));
|
||||
|
||||
let _ = pool.remove_writer_and_close_clients(7).await;
|
||||
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(!warn_next_allowed.contains_key(&7));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_removes_empty_draining_writers() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(40), 0, 0).await;
|
||||
insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(30), 0, 0).await;
|
||||
insert_draining_writer(&pool, 3, now_epoch_secs.saturating_sub(20), 1, 0).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert_eq!(current_writer_ids(&pool).await, vec![3]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() {
|
||||
let pool = make_pool(2).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
insert_draining_writer(&pool, 11, now_epoch_secs.saturating_sub(40), 1, 0).await;
|
||||
insert_draining_writer(&pool, 22, now_epoch_secs.saturating_sub(30), 1, 0).await;
|
||||
insert_draining_writer(&pool, 33, now_epoch_secs.saturating_sub(20), 1, 0).await;
|
||||
insert_draining_writer(&pool, 44, now_epoch_secs.saturating_sub(10), 1, 0).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert_eq!(current_writer_ids(&pool).await, vec![33, 44]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_deadline_force_close_applies_under_threshold() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
50,
|
||||
now_epoch_secs.saturating_sub(15),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert!(current_writer_ids(&pool).await.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_limits_closes_per_health_tick() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let close_budget = health_drain_close_budget();
|
||||
let writer_total = close_budget.saturating_add(19);
|
||||
for writer_id in 1..=writer_total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(20),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert_eq!(pool.writers.read().await.len(), writer_total - close_budget);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() {
|
||||
let pool = make_pool(0).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let close_budget = health_drain_close_budget();
|
||||
let writer_total = close_budget.saturating_add(5);
|
||||
for writer_id in 1..=writer_total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(60),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let target_writer_id = writer_total as u64;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
warn_next_allowed.insert(
|
||||
target_writer_id,
|
||||
Instant::now() + Duration::from_secs(300),
|
||||
);
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert!(writer_exists(&pool, target_writer_id).await);
|
||||
assert!(warn_next_allowed.contains_key(&target_writer_id));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() {
|
||||
let pool = make_pool(1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let close_budget = health_drain_close_budget();
|
||||
let writer_total = close_budget.saturating_add(6);
|
||||
for writer_id in 1..=writer_total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let target_writer_id = writer_total.saturating_sub(1) as u64;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
warn_next_allowed.insert(
|
||||
target_writer_id,
|
||||
Instant::now() + Duration::from_secs(300),
|
||||
);
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert!(writer_exists(&pool, target_writer_id).await);
|
||||
assert!(warn_next_allowed.contains_key(&target_writer_id));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await;
|
||||
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300));
|
||||
|
||||
set_writer_draining(&pool, 71, false).await;
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert!(writer_exists(&pool, 71).await);
|
||||
assert!(
|
||||
!warn_next_allowed.contains_key(&71),
|
||||
"warn cooldown state must be dropped after writer leaves draining state"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() {
|
||||
let pool = make_pool(0).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let close_budget = health_drain_close_budget();
|
||||
let writer_total = close_budget.saturating_mul(2).saturating_add(1);
|
||||
for writer_id in 1..=writer_total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(120),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let tail_writer_id = writer_total as u64;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
warn_next_allowed.insert(
|
||||
tail_writer_id,
|
||||
Instant::now() + Duration::from_secs(300),
|
||||
);
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(writer_exists(&pool, tail_writer_id).await);
|
||||
assert!(warn_next_allowed.contains_key(&tail_writer_id));
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(writer_exists(&pool, tail_writer_id).await);
|
||||
assert!(warn_next_allowed.contains_key(&tail_writer_id));
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(!writer_exists(&pool, tail_writer_id).await);
|
||||
assert!(
|
||||
!warn_next_allowed.contains_key(&tail_writer_id),
|
||||
"warn cooldown state must clear once writer is actually removed"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_backlog_drains_across_ticks() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let close_budget = health_drain_close_budget();
|
||||
let writer_total = close_budget.saturating_mul(2).saturating_add(7);
|
||||
for writer_id in 1..=writer_total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(20),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
for _ in 0..8 {
|
||||
if pool.writers.read().await.is_empty() {
|
||||
break;
|
||||
}
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
}
|
||||
|
||||
assert!(pool.writers.read().await.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_threshold_backlog_converges_to_threshold() {
|
||||
let threshold = 5u64;
|
||||
let pool = make_pool(threshold).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let close_budget = health_drain_close_budget();
|
||||
let writer_total = threshold as usize + close_budget.saturating_add(12);
|
||||
for writer_id in 1..=writer_total as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(200).saturating_add(writer_id),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
for _ in 0..16 {
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
if pool.writers.read().await.len() <= threshold as usize {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(pool.writers.read().await.len(), threshold as usize);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_threshold_zero_preserves_non_expired_non_empty_writers() {
|
||||
let pool = make_pool(0).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(40), 1, 0).await;
|
||||
insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(30), 1, 0).await;
|
||||
insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(20), 1, 0).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert_eq!(current_writer_ids(&pool).await, vec![10, 20, 30]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_prioritizes_force_close_before_empty_cleanup() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let close_budget = health_drain_close_budget();
|
||||
for writer_id in 1..=close_budget as u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(20),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let empty_writer_id = close_budget as u64 + 1;
|
||||
insert_draining_writer(&pool, empty_writer_id, now_epoch_secs.saturating_sub(20), 0, 0).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert_eq!(current_writer_ids(&pool).await, vec![empty_writer_id]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_empty_cleanup_does_not_increment_force_close_metric() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(60), 0, 0).await;
|
||||
insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(50), 0, 0).await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert!(current_writer_ids(&pool).await.is_empty());
|
||||
assert_eq!(pool.stats.get_pool_force_close_total(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_handles_duplicate_force_close_requests_for_same_writer() {
|
||||
let pool = make_pool(1).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
10,
|
||||
now_epoch_secs.saturating_sub(30),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
20,
|
||||
now_epoch_secs.saturating_sub(20),
|
||||
1,
|
||||
now_epoch_secs.saturating_sub(1),
|
||||
)
|
||||
.await;
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
|
||||
assert!(current_writer_ids(&pool).await.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population_under_churn() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
for wave in 0..12u64 {
|
||||
for offset in 0..9u64 {
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
wave * 100 + offset,
|
||||
now_epoch_secs.saturating_sub(120 + offset),
|
||||
1,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
|
||||
|
||||
let existing_writer_ids = current_writer_ids(&pool).await;
|
||||
for writer_id in existing_writer_ids.into_iter().take(4) {
|
||||
let _ = pool.remove_writer_and_close_clients(writer_id).await;
|
||||
}
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_state() {
|
||||
let pool = make_pool(6).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let mut warn_next_allowed = HashMap::new();
|
||||
|
||||
for writer_id in 1..=18u64 {
|
||||
let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
|
||||
let deadline = if writer_id % 2 == 0 {
|
||||
now_epoch_secs.saturating_sub(1)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
writer_id,
|
||||
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
|
||||
bound_clients,
|
||||
deadline,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
for _ in 0..16 {
|
||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||
if pool.writers.read().await.len() <= 6 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(pool.writers.read().await.len() <= 6);
|
||||
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn general_config_default_drain_threshold_remains_enabled() {
|
||||
assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
|
||||
let empty_writer_id = 700u64;
|
||||
insert_draining_writer(
|
||||
&pool,
|
||||
empty_writer_id,
|
||||
now_epoch_secs.saturating_sub(60),
|
||||
0,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
|
||||
let stale_empty_snapshot = vec![empty_writer_id];
|
||||
let (rebound_conn_id, _rx) = pool.registry.register().await;
|
||||
assert!(
|
||||
pool.registry
|
||||
.bind_writer(
|
||||
rebound_conn_id,
|
||||
empty_writer_id,
|
||||
ConnMeta {
|
||||
target_dc: 2,
|
||||
client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050),
|
||||
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||
proto_flags: 0,
|
||||
},
|
||||
)
|
||||
.await,
|
||||
"writer should accept a new bind after stale empty snapshot"
|
||||
);
|
||||
|
||||
for writer_id in stale_empty_snapshot {
|
||||
assert!(
|
||||
!pool.remove_writer_if_empty(writer_id).await,
|
||||
"atomic empty cleanup must reject writers that gained bound clients"
|
||||
);
|
||||
}
|
||||
|
||||
assert!(
|
||||
writer_exists(&pool, empty_writer_id).await,
|
||||
"empty-path cleanup must not remove a writer that gained a bound client"
|
||||
);
|
||||
assert_eq!(
|
||||
pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id),
|
||||
Some(empty_writer_id)
|
||||
);
|
||||
|
||||
let _ = pool.registry.unregister(rebound_conn_id).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() {
|
||||
let pool = make_pool(128).await;
|
||||
let now_epoch_secs = MePool::now_epoch_secs();
|
||||
let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await;
|
||||
|
||||
pool.prune_closed_writers().await;
|
||||
|
||||
assert!(!writer_exists(&pool, 910).await);
|
||||
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
|
||||
}
|
||||
|
|
@ -21,6 +21,14 @@ mod secret;
|
|||
mod selftest;
|
||||
mod wire;
|
||||
mod pool_status;
|
||||
#[cfg(test)]
|
||||
mod health_regression_tests;
|
||||
#[cfg(test)]
|
||||
mod health_integration_tests;
|
||||
#[cfg(test)]
|
||||
mod health_adversarial_tests;
|
||||
#[cfg(test)]
|
||||
mod send_adversarial_tests;
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
|
|
|
|||
|
|
@ -160,6 +160,7 @@ pub struct MePool {
|
|||
pub(super) refill_inflight: Arc<Mutex<HashSet<RefillEndpointKey>>>,
|
||||
pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>,
|
||||
pub(super) conn_count: AtomicUsize,
|
||||
pub(super) draining_active_runtime: AtomicU64,
|
||||
pub(super) stats: Arc<crate::stats::Stats>,
|
||||
pub(super) generation: AtomicU64,
|
||||
pub(super) active_generation: AtomicU64,
|
||||
|
|
@ -438,6 +439,7 @@ impl MePool {
|
|||
refill_inflight: Arc::new(Mutex::new(HashSet::new())),
|
||||
refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())),
|
||||
conn_count: AtomicUsize::new(0),
|
||||
draining_active_runtime: AtomicU64::new(0),
|
||||
generation: AtomicU64::new(1),
|
||||
active_generation: AtomicU64::new(1),
|
||||
warm_generation: AtomicU64::new(0),
|
||||
|
|
@ -690,6 +692,33 @@ impl MePool {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(super) fn draining_active_runtime(&self) -> u64 {
|
||||
self.draining_active_runtime.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub(super) fn increment_draining_active_runtime(&self) {
|
||||
self.draining_active_runtime.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub(super) fn decrement_draining_active_runtime(&self) {
|
||||
let mut current = self.draining_active_runtime.load(Ordering::Relaxed);
|
||||
loop {
|
||||
if current == 0 {
|
||||
break;
|
||||
}
|
||||
match self.draining_active_runtime.compare_exchange_weak(
|
||||
current,
|
||||
current - 1,
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => break,
|
||||
Err(actual) => current = actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn key_selector(&self) -> u32 {
|
||||
self.proxy_secret.read().await.key_selector
|
||||
}
|
||||
|
|
|
|||
|
|
@ -141,6 +141,38 @@ impl MePool {
|
|||
out
|
||||
}
|
||||
|
||||
pub(super) async fn has_non_draining_writer_per_desired_dc_group(&self) -> bool {
|
||||
let desired_by_dc = self.desired_dc_endpoints().await;
|
||||
let required_dcs: HashSet<i32> = desired_by_dc
|
||||
.iter()
|
||||
.filter_map(|(dc, endpoints)| {
|
||||
if endpoints.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(*dc)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
if required_dcs.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let ws = self.writers.read().await;
|
||||
let mut covered_dcs = HashSet::<i32>::with_capacity(required_dcs.len());
|
||||
for writer in ws.iter() {
|
||||
if writer.draining.load(Ordering::Relaxed) {
|
||||
continue;
|
||||
}
|
||||
if required_dcs.contains(&writer.writer_dc) {
|
||||
covered_dcs.insert(writer.writer_dc);
|
||||
if covered_dcs.len() == required_dcs.len() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn hardswap_warmup_connect_delay_ms(&self) -> u64 {
|
||||
let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed);
|
||||
let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed);
|
||||
|
|
@ -475,12 +507,30 @@ impl MePool {
|
|||
coverage_ratio = format_args!("{coverage_ratio:.3}"),
|
||||
min_ratio = format_args!("{min_ratio:.3}"),
|
||||
drain_timeout_secs,
|
||||
"ME reinit cycle covered; draining stale writers"
|
||||
"ME reinit cycle covered; processing stale writers"
|
||||
);
|
||||
self.stats.increment_pool_swap_total();
|
||||
let can_drop_with_replacement = self
|
||||
.has_non_draining_writer_per_desired_dc_group()
|
||||
.await;
|
||||
if can_drop_with_replacement {
|
||||
info!(
|
||||
stale_writers = stale_writer_ids.len(),
|
||||
"ME reinit stale writers: replacement coverage ready, force-closing clients for fast rebind"
|
||||
);
|
||||
} else {
|
||||
warn!(
|
||||
stale_writers = stale_writer_ids.len(),
|
||||
"ME reinit stale writers: replacement coverage incomplete, keeping draining fallback"
|
||||
);
|
||||
}
|
||||
for writer_id in stale_writer_ids {
|
||||
self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap)
|
||||
.await;
|
||||
if can_drop_with_replacement {
|
||||
self.stats.increment_pool_force_close_total();
|
||||
self.remove_writer_and_close_clients(writer_id).await;
|
||||
}
|
||||
}
|
||||
if hardswap {
|
||||
self.clear_pending_hardswap_state();
|
||||
|
|
|
|||
|
|
@ -42,11 +42,10 @@ impl MePool {
|
|||
}
|
||||
|
||||
for writer_id in closed_writer_ids {
|
||||
if self.registry.is_writer_empty(writer_id).await {
|
||||
let _ = self.remove_writer_only(writer_id).await;
|
||||
} else {
|
||||
let _ = self.remove_writer_and_close_clients(writer_id).await;
|
||||
if self.remove_writer_if_empty(writer_id).await {
|
||||
continue;
|
||||
}
|
||||
let _ = self.remove_writer_and_close_clients(writer_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -501,6 +500,17 @@ impl MePool {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_writer_if_empty(self: &Arc<Self>, writer_id: u64) -> bool {
|
||||
if !self.registry.unregister_writer_if_empty(writer_id).await {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The registry empty-check and unregister are atomic with respect to binds,
|
||||
// so remove_writer_only cannot return active bound sessions here.
|
||||
let _ = self.remove_writer_only(writer_id).await;
|
||||
true
|
||||
}
|
||||
|
||||
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
|
||||
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
|
||||
let mut removed_addr: Option<SocketAddr> = None;
|
||||
|
|
@ -514,6 +524,7 @@ impl MePool {
|
|||
let was_draining = w.draining.load(Ordering::Relaxed);
|
||||
if was_draining {
|
||||
self.stats.decrement_pool_drain_active();
|
||||
self.decrement_draining_active_runtime();
|
||||
}
|
||||
self.stats.increment_me_writer_removed_total();
|
||||
w.cancel.cancel();
|
||||
|
|
@ -572,6 +583,7 @@ impl MePool {
|
|||
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
|
||||
if !already_draining {
|
||||
self.stats.increment_pool_drain_active();
|
||||
self.increment_draining_active_runtime();
|
||||
}
|
||||
w.contour
|
||||
.store(WriterContour::Draining.as_u8(), Ordering::Relaxed);
|
||||
|
|
|
|||
|
|
@ -436,6 +436,37 @@ impl ConnRegistry {
|
|||
.map(|s| s.is_empty())
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
|
||||
let mut inner = self.inner.write().await;
|
||||
let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else {
|
||||
// Writer is already absent from the registry.
|
||||
return true;
|
||||
};
|
||||
if !conn_ids.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
inner.writers.remove(&writer_id);
|
||||
inner.last_meta_for_writer.remove(&writer_id);
|
||||
inner.writer_idle_since_epoch_secs.remove(&writer_id);
|
||||
inner.conns_for_writer.remove(&writer_id);
|
||||
true
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
|
||||
let inner = self.inner.read().await;
|
||||
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
|
||||
for writer_id in writer_ids {
|
||||
if let Some(conns) = inner.conns_for_writer.get(writer_id)
|
||||
&& !conns.is_empty()
|
||||
{
|
||||
out.insert(*writer_id);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -634,4 +665,35 @@ mod tests {
|
|||
);
|
||||
assert!(registry.get_writer(conn_id).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() {
|
||||
let registry = ConnRegistry::new();
|
||||
let (conn_id, _rx) = registry.register().await;
|
||||
let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
|
||||
let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
|
||||
registry.register_writer(10, writer_tx_a).await;
|
||||
registry.register_writer(20, writer_tx_b).await;
|
||||
|
||||
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
|
||||
assert!(
|
||||
registry
|
||||
.bind_writer(
|
||||
conn_id,
|
||||
10,
|
||||
ConnMeta {
|
||||
target_dc: 2,
|
||||
client_addr: addr,
|
||||
our_addr: addr,
|
||||
proto_flags: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
);
|
||||
|
||||
let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await;
|
||||
assert!(non_empty.contains(&10));
|
||||
assert!(!non_empty.contains(&20));
|
||||
assert!(!non_empty.contains(&30));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -372,17 +372,20 @@ impl MePool {
|
|||
}
|
||||
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
||||
let (payload, meta) = build_routed_payload(effective_our_addr);
|
||||
match w.tx.try_send(WriterCommand::Data(payload.clone())) {
|
||||
Ok(()) => {
|
||||
self.stats.increment_me_writer_pick_success_try_total(pick_mode);
|
||||
match w.tx.clone().try_reserve_owned() {
|
||||
Ok(permit) => {
|
||||
if !self.registry.bind_writer(conn_id, w.id, meta).await {
|
||||
debug!(
|
||||
conn_id,
|
||||
writer_id = w.id,
|
||||
"ME writer disappeared before bind commit, retrying"
|
||||
"ME writer disappeared before bind commit, pruning stale writer"
|
||||
);
|
||||
drop(permit);
|
||||
self.remove_writer_and_close_clients(w.id).await;
|
||||
continue;
|
||||
}
|
||||
permit.send(WriterCommand::Data(payload.clone()));
|
||||
self.stats.increment_me_writer_pick_success_try_total(pick_mode);
|
||||
if w.generation < self.current_generation() {
|
||||
self.stats.increment_pool_stale_pick_total();
|
||||
debug!(
|
||||
|
|
@ -422,18 +425,21 @@ impl MePool {
|
|||
self.stats.increment_me_writer_pick_blocking_fallback_total();
|
||||
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
||||
let (payload, meta) = build_routed_payload(effective_our_addr);
|
||||
match w.tx.send(WriterCommand::Data(payload.clone())).await {
|
||||
Ok(()) => {
|
||||
self.stats
|
||||
.increment_me_writer_pick_success_fallback_total(pick_mode);
|
||||
match w.tx.clone().reserve_owned().await {
|
||||
Ok(permit) => {
|
||||
if !self.registry.bind_writer(conn_id, w.id, meta).await {
|
||||
debug!(
|
||||
conn_id,
|
||||
writer_id = w.id,
|
||||
"ME writer disappeared before fallback bind commit, retrying"
|
||||
"ME writer disappeared before fallback bind commit, pruning stale writer"
|
||||
);
|
||||
drop(permit);
|
||||
self.remove_writer_and_close_clients(w.id).await;
|
||||
continue;
|
||||
}
|
||||
permit.send(WriterCommand::Data(payload.clone()));
|
||||
self.stats
|
||||
.increment_me_writer_pick_success_fallback_total(pick_mode);
|
||||
if w.generation < self.current_generation() {
|
||||
self.stats.increment_pool_stale_pick_total();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,263 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::codec::WriterCommand;
|
||||
use super::pool::{MePool, MeWriter, WriterContour};
|
||||
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::network::probe::NetworkDecision;
|
||||
use crate::stats::Stats;
|
||||
|
||||
async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) {
|
||||
let general = GeneralConfig {
|
||||
me_route_no_writer_mode: MeRouteNoWriterMode::AsyncRecoveryFailfast,
|
||||
me_route_no_writer_wait_ms: 50,
|
||||
me_writer_pick_mode: MeWriterPickMode::SortedRr,
|
||||
me_deterministic_writer_sort: true,
|
||||
..GeneralConfig::default()
|
||||
};
|
||||
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let pool = MePool::new(
|
||||
None,
|
||||
vec![1u8; 32],
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
1200,
|
||||
HashMap::new(),
|
||||
HashMap::new(),
|
||||
None,
|
||||
NetworkDecision::default(),
|
||||
None,
|
||||
rng.clone(),
|
||||
Arc::new(Stats::default()),
|
||||
general.me_keepalive_enabled,
|
||||
general.me_keepalive_interval_secs,
|
||||
general.me_keepalive_jitter_secs,
|
||||
general.me_keepalive_payload_random,
|
||||
general.rpc_proxy_req_every,
|
||||
general.me_warmup_stagger_enabled,
|
||||
general.me_warmup_step_delay_ms,
|
||||
general.me_warmup_step_jitter_ms,
|
||||
general.me_reconnect_max_concurrent_per_dc,
|
||||
general.me_reconnect_backoff_base_ms,
|
||||
general.me_reconnect_backoff_cap_ms,
|
||||
general.me_reconnect_fast_retry_count,
|
||||
general.me_single_endpoint_shadow_writers,
|
||||
general.me_single_endpoint_outage_mode_enabled,
|
||||
general.me_single_endpoint_outage_disable_quarantine,
|
||||
general.me_single_endpoint_outage_backoff_min_ms,
|
||||
general.me_single_endpoint_outage_backoff_max_ms,
|
||||
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||
general.me_floor_mode,
|
||||
general.me_adaptive_floor_idle_secs,
|
||||
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||
general.me_adaptive_floor_recover_grace_secs,
|
||||
general.me_adaptive_floor_writers_per_core_total,
|
||||
general.me_adaptive_floor_cpu_cores_override,
|
||||
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||
general.me_adaptive_floor_max_active_writers_per_core,
|
||||
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||
general.me_adaptive_floor_max_active_writers_global,
|
||||
general.me_adaptive_floor_max_warm_writers_global,
|
||||
general.hardswap,
|
||||
general.me_pool_drain_ttl_secs,
|
||||
general.me_pool_drain_threshold,
|
||||
general.effective_me_pool_force_close_secs(),
|
||||
general.me_pool_min_fresh_ratio,
|
||||
general.me_hardswap_warmup_delay_min_ms,
|
||||
general.me_hardswap_warmup_delay_max_ms,
|
||||
general.me_hardswap_warmup_extra_passes,
|
||||
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||
general.me_bind_stale_mode,
|
||||
general.me_bind_stale_ttl_secs,
|
||||
general.me_secret_atomic_snapshot,
|
||||
general.me_deterministic_writer_sort,
|
||||
general.me_writer_pick_mode,
|
||||
general.me_writer_pick_sample_size,
|
||||
MeSocksKdfPolicy::default(),
|
||||
general.me_writer_cmd_channel_capacity,
|
||||
general.me_route_channel_capacity,
|
||||
general.me_route_backpressure_base_timeout_ms,
|
||||
general.me_route_backpressure_high_timeout_ms,
|
||||
general.me_route_backpressure_high_watermark_pct,
|
||||
general.me_reader_route_data_wait_ms,
|
||||
general.me_health_interval_ms_unhealthy,
|
||||
general.me_health_interval_ms_healthy,
|
||||
general.me_warn_rate_limit_ms,
|
||||
general.me_route_no_writer_mode,
|
||||
general.me_route_no_writer_wait_ms,
|
||||
general.me_route_inline_recovery_attempts,
|
||||
general.me_route_inline_recovery_wait_ms,
|
||||
);
|
||||
|
||||
(pool, rng)
|
||||
}
|
||||
|
||||
async fn insert_writer(
|
||||
pool: &Arc<MePool>,
|
||||
writer_id: u64,
|
||||
writer_dc: i32,
|
||||
addr: SocketAddr,
|
||||
register_in_registry: bool,
|
||||
) -> mpsc::Receiver<WriterCommand> {
|
||||
let (tx, rx) = mpsc::channel::<WriterCommand>(8);
|
||||
let writer = MeWriter {
|
||||
id: writer_id,
|
||||
addr,
|
||||
source_ip: addr.ip(),
|
||||
writer_dc,
|
||||
generation: pool.current_generation(),
|
||||
contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())),
|
||||
created_at: Instant::now(),
|
||||
tx: tx.clone(),
|
||||
cancel: CancellationToken::new(),
|
||||
degraded: Arc::new(AtomicBool::new(false)),
|
||||
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||
draining: Arc::new(AtomicBool::new(false)),
|
||||
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
pool.writers.write().await.push(writer);
|
||||
{
|
||||
let mut map = pool.proxy_map_v4.write().await;
|
||||
map.entry(writer_dc)
|
||||
.or_insert_with(Vec::new)
|
||||
.push((addr.ip(), addr.port()));
|
||||
}
|
||||
pool.rebuild_endpoint_dc_map().await;
|
||||
if register_in_registry {
|
||||
pool.registry.register_writer(writer_id, tx).await;
|
||||
}
|
||||
rx
|
||||
}
|
||||
|
||||
async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duration) -> usize {
|
||||
let start = Instant::now();
|
||||
let mut data_count = 0usize;
|
||||
while Instant::now().duration_since(start) < budget {
|
||||
let remaining = budget.saturating_sub(Instant::now().duration_since(start));
|
||||
match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
|
||||
Ok(Some(WriterCommand::Data(_))) => data_count += 1,
|
||||
Ok(Some(WriterCommand::DataAndFlush(_))) => data_count += 1,
|
||||
Ok(Some(WriterCommand::Close)) => {}
|
||||
Ok(None) => break,
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
data_count
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
|
||||
let (pool, _rng) = make_pool().await;
|
||||
pool.rr.store(0, Ordering::Relaxed);
|
||||
|
||||
let (conn_id, _rx) = pool.registry.register().await;
|
||||
let mut stale_rx = insert_writer(
|
||||
&pool,
|
||||
10,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 10)), 443),
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
let mut live_rx = insert_writer(
|
||||
&pool,
|
||||
11,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 11)), 443),
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = pool
|
||||
.send_proxy_req(
|
||||
conn_id,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30000),
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||
b"hello",
|
||||
0,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(recv_data_count(&mut stale_rx, Duration::from_millis(50)).await, 0);
|
||||
assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
|
||||
|
||||
let bound = pool.registry.get_writer(conn_id).await;
|
||||
assert!(bound.is_some());
|
||||
assert_eq!(bound.expect("writer should be bound").writer_id, 11);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay() {
|
||||
let (pool, _rng) = make_pool().await;
|
||||
pool.rr.store(0, Ordering::Relaxed);
|
||||
|
||||
let (conn_id, _rx) = pool.registry.register().await;
|
||||
|
||||
let mut stale_rx_1 = insert_writer(
|
||||
&pool,
|
||||
21,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 21)), 443),
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
let mut stale_rx_2 = insert_writer(
|
||||
&pool,
|
||||
22,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 22)), 443),
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
let mut live_rx = insert_writer(
|
||||
&pool,
|
||||
23,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 23)), 443),
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = pool
|
||||
.send_proxy_req(
|
||||
conn_id,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30001),
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||
b"storm",
|
||||
0,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(recv_data_count(&mut stale_rx_1, Duration::from_millis(50)).await, 0);
|
||||
assert_eq!(recv_data_count(&mut stale_rx_2, Duration::from_millis(50)).await, 0);
|
||||
assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
|
||||
|
||||
let writers = pool.writers.read().await;
|
||||
let writer_ids = writers.iter().map(|w| w.id).collect::<Vec<_>>();
|
||||
drop(writers);
|
||||
assert_eq!(writer_ids, vec![23]);
|
||||
}
|
||||
Loading…
Reference in New Issue