mirror of https://github.com/telemt/telemt.git
Merge 97d4a1c5c8 into 4f55d08c51
This commit is contained in:
commit
3db62aaea8
|
|
@ -390,6 +390,12 @@ you MUST explain why existing invariants remain valid.
|
||||||
- Do not modify existing tests unless the task explicitly requires it.
|
- Do not modify existing tests unless the task explicitly requires it.
|
||||||
- Do not weaken assertions.
|
- Do not weaken assertions.
|
||||||
- Preserve determinism in testable components.
|
- Preserve determinism in testable components.
|
||||||
|
- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new.
|
||||||
|
- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code.
|
||||||
|
- Differential/model catches logic drift over time.
|
||||||
|
- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find subtle bugs; 10,000 concurrent calls will find it on the first run.
|
||||||
|
- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly.
|
||||||
|
- Dead parameter is a code smell rule.
|
||||||
|
|
||||||
### 15. Security Constraints
|
### 15. Security Constraints
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -425,6 +425,32 @@ dependencies = [
|
||||||
"cipher",
|
"cipher",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "curve25519-dalek"
|
||||||
|
version = "4.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"cpufeatures",
|
||||||
|
"curve25519-dalek-derive",
|
||||||
|
"fiat-crypto",
|
||||||
|
"rustc_version",
|
||||||
|
"subtle",
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "curve25519-dalek-derive"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.114",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dashmap"
|
name = "dashmap"
|
||||||
version = "5.5.3"
|
version = "5.5.3"
|
||||||
|
|
@ -517,6 +543,12 @@ version = "2.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fiat-crypto"
|
||||||
|
version = "0.2.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "filetime"
|
name = "filetime"
|
||||||
version = "0.2.27"
|
version = "0.2.27"
|
||||||
|
|
@ -1609,7 +1641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"rand_chacha",
|
"rand_chacha",
|
||||||
"rand_core",
|
"rand_core 0.9.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1619,9 +1651,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ppv-lite86",
|
"ppv-lite86",
|
||||||
"rand_core",
|
"rand_core 0.9.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_core"
|
||||||
|
version = "0.6.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rand_core"
|
name = "rand_core"
|
||||||
version = "0.9.5"
|
version = "0.9.5"
|
||||||
|
|
@ -1637,7 +1675,7 @@ version = "0.4.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a"
|
checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"rand_core",
|
"rand_core 0.9.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -2093,7 +2131,7 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "telemt"
|
name = "telemt"
|
||||||
version = "3.3.19"
|
version = "3.3.20"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aes",
|
"aes",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
|
@ -2145,6 +2183,7 @@ dependencies = [
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"url",
|
"url",
|
||||||
"webpki-roots 0.26.11",
|
"webpki-roots 0.26.11",
|
||||||
|
"x25519-dalek",
|
||||||
"x509-parser",
|
"x509-parser",
|
||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
@ -3144,6 +3183,18 @@ version = "0.6.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
|
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "x25519-dalek"
|
||||||
|
version = "2.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277"
|
||||||
|
dependencies = [
|
||||||
|
"curve25519-dalek",
|
||||||
|
"rand_core 0.6.4",
|
||||||
|
"serde",
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "x509-parser"
|
name = "x509-parser"
|
||||||
version = "0.15.1"
|
version = "0.15.1"
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,7 @@ regex = "1.11"
|
||||||
crossbeam-queue = "0.3"
|
crossbeam-queue = "0.3"
|
||||||
num-bigint = "0.4"
|
num-bigint = "0.4"
|
||||||
num-traits = "0.2"
|
num-traits = "0.2"
|
||||||
|
x25519-dalek = "2"
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
|
|
||||||
# HTTP
|
# HTTP
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ show = "*"
|
||||||
port = 443
|
port = 443
|
||||||
# proxy_protocol = false # Enable if behind HAProxy/nginx with PROXY protocol
|
# proxy_protocol = false # Enable if behind HAProxy/nginx with PROXY protocol
|
||||||
# metrics_port = 9090
|
# metrics_port = 9090
|
||||||
|
# metrics_listen = "0.0.0.0:9090" # Listen address for metrics (overrides metrics_port)
|
||||||
# metrics_whitelist = ["127.0.0.1", "::1", "0.0.0.0/0"]
|
# metrics_whitelist = ["127.0.0.1", "::1", "0.0.0.0/0"]
|
||||||
|
|
||||||
[server.api]
|
[server.api]
|
||||||
|
|
|
||||||
|
|
@ -38,8 +38,9 @@ umweltschutz.de -> A-запись 198.18.88.88
|
||||||
|
|
||||||
В конфигурации Telemt:
|
В конфигурации Telemt:
|
||||||
|
|
||||||
```
|
```toml
|
||||||
tls_domain = umweltschutz.de
|
[censorship]
|
||||||
|
tls_domain = "umweltschutz.de"
|
||||||
```
|
```
|
||||||
|
|
||||||
Этот домен используется клиентом как SNI в ClientHello
|
Этот домен используется клиентом как SNI в ClientHello
|
||||||
|
|
@ -56,8 +57,9 @@ tls_domain = umweltschutz.de
|
||||||
|
|
||||||
В конфигурации Telemt:
|
В конфигурации Telemt:
|
||||||
|
|
||||||
```
|
```toml
|
||||||
mask_host = 127.0.0.1
|
[censorship]
|
||||||
|
mask_host = "127.0.0.1"
|
||||||
mask_port = 8443
|
mask_port = 8443
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -151,16 +153,18 @@ mask_host:mask_port
|
||||||
|
|
||||||
Например:
|
Например:
|
||||||
|
|
||||||
```
|
```toml
|
||||||
tls_domain = github.com
|
[censorship]
|
||||||
mask_host = github.com
|
tls_domain = "github.com"
|
||||||
|
mask_host = "github.com"
|
||||||
mask_port = 443
|
mask_port = 443
|
||||||
```
|
```
|
||||||
|
|
||||||
или
|
или
|
||||||
|
|
||||||
```
|
```toml
|
||||||
mask_host = 140.82.121.4
|
[censorship]
|
||||||
|
mask_host = "140.82.121.4"
|
||||||
```
|
```
|
||||||
|
|
||||||
В этом случае:
|
В этом случае:
|
||||||
|
|
|
||||||
|
|
@ -239,7 +239,7 @@ tls_full_cert_ttl_secs = 90
|
||||||
|
|
||||||
[access]
|
[access]
|
||||||
replay_check_len = 65536
|
replay_check_len = 65536
|
||||||
replay_window_secs = 1800
|
replay_window_secs = 120
|
||||||
ignore_time_skew = false
|
ignore_time_skew = false
|
||||||
|
|
||||||
[access.users]
|
[access.users]
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,9 @@ pub(crate) fn default_replay_check_len() -> usize {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn default_replay_window_secs() -> u64 {
|
pub(crate) fn default_replay_window_secs() -> u64 {
|
||||||
1800
|
// Keep replay cache TTL tight by default to reduce replay surface.
|
||||||
|
// Deployments with higher RTT or longer reconnect jitter can override this in config.
|
||||||
|
120
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn default_handshake_timeout() -> u64 {
|
pub(crate) fn default_handshake_timeout() -> u64 {
|
||||||
|
|
@ -456,11 +458,11 @@ pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn default_server_hello_delay_min_ms() -> u64 {
|
pub(crate) fn default_server_hello_delay_min_ms() -> u64 {
|
||||||
0
|
8
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn default_server_hello_delay_max_ms() -> u64 {
|
pub(crate) fn default_server_hello_delay_max_ms() -> u64 {
|
||||||
0
|
24
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn default_alpn_enforce() -> bool {
|
pub(crate) fn default_alpn_enforce() -> bool {
|
||||||
|
|
|
||||||
|
|
@ -1163,9 +1163,17 @@ pub struct ServerConfig {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
|
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
|
||||||
|
|
||||||
|
/// Port for the Prometheus-compatible metrics endpoint.
|
||||||
|
/// Enables metrics when set; binds on all interfaces (dual-stack) by default.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub metrics_port: Option<u16>,
|
pub metrics_port: Option<u16>,
|
||||||
|
|
||||||
|
/// Listen address for metrics in `IP:PORT` format (e.g. `"127.0.0.1:9090"`).
|
||||||
|
/// When set, takes precedence over `metrics_port` and binds on the specified address only.
|
||||||
|
#[serde(default)]
|
||||||
|
pub metrics_listen: Option<String>,
|
||||||
|
|
||||||
|
/// CIDR whitelist for the metrics endpoint.
|
||||||
#[serde(default = "default_metrics_whitelist")]
|
#[serde(default = "default_metrics_whitelist")]
|
||||||
pub metrics_whitelist: Vec<IpNetwork>,
|
pub metrics_whitelist: Vec<IpNetwork>,
|
||||||
|
|
||||||
|
|
@ -1194,6 +1202,7 @@ impl Default for ServerConfig {
|
||||||
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
|
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
|
||||||
proxy_protocol_trusted_cidrs: Vec::new(),
|
proxy_protocol_trusted_cidrs: Vec::new(),
|
||||||
metrics_port: None,
|
metrics_port: None,
|
||||||
|
metrics_listen: None,
|
||||||
metrics_whitelist: default_metrics_whitelist(),
|
metrics_whitelist: default_metrics_whitelist(),
|
||||||
api: ApiConfig::default(),
|
api: ApiConfig::default(),
|
||||||
listeners: Vec::new(),
|
listeners: Vec::new(),
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,450 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::config::UserMaxUniqueIpsMode;
|
||||||
|
use crate::ip_tracker::UserIpTracker;
|
||||||
|
|
||||||
|
fn ip_from_idx(idx: u32) -> IpAddr {
|
||||||
|
let a = 10u8;
|
||||||
|
let b = ((idx / 65_536) % 256) as u8;
|
||||||
|
let c = ((idx / 256) % 256) as u8;
|
||||||
|
let d = (idx % 256) as u8;
|
||||||
|
IpAddr::V4(Ipv4Addr::new(a, b, c, d))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn active_window_enforces_large_unique_ip_burst() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("burst_user", 64).await;
|
||||||
|
tracker
|
||||||
|
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
for idx in 0..64 {
|
||||||
|
assert!(tracker.check_and_add("burst_user", ip_from_idx(idx)).await.is_ok());
|
||||||
|
}
|
||||||
|
assert!(tracker.check_and_add("burst_user", ip_from_idx(9_999)).await.is_err());
|
||||||
|
assert_eq!(tracker.get_active_ip_count("burst_user").await, 64);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn global_limit_applies_across_many_users() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.load_limits(3, &HashMap::new()).await;
|
||||||
|
|
||||||
|
for user_idx in 0..150u32 {
|
||||||
|
let user = format!("u{}", user_idx);
|
||||||
|
assert!(tracker.check_and_add(&user, ip_from_idx(user_idx * 10)).await.is_ok());
|
||||||
|
assert!(tracker
|
||||||
|
.check_and_add(&user, ip_from_idx(user_idx * 10 + 1))
|
||||||
|
.await
|
||||||
|
.is_ok());
|
||||||
|
assert!(tracker
|
||||||
|
.check_and_add(&user, ip_from_idx(user_idx * 10 + 2))
|
||||||
|
.await
|
||||||
|
.is_ok());
|
||||||
|
assert!(tracker
|
||||||
|
.check_and_add(&user, ip_from_idx(user_idx * 10 + 3))
|
||||||
|
.await
|
||||||
|
.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(tracker.get_stats().await.len(), 150);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn user_zero_override_falls_back_to_global_limit() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
let mut limits = HashMap::new();
|
||||||
|
limits.insert("target".to_string(), 0);
|
||||||
|
tracker.load_limits(2, &limits).await;
|
||||||
|
|
||||||
|
assert!(tracker.check_and_add("target", ip_from_idx(1)).await.is_ok());
|
||||||
|
assert!(tracker.check_and_add("target", ip_from_idx(2)).await.is_ok());
|
||||||
|
assert!(tracker.check_and_add("target", ip_from_idx(3)).await.is_err());
|
||||||
|
assert_eq!(tracker.get_user_limit("target").await, Some(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn remove_ip_is_idempotent_after_counter_reaches_zero() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("u", 2).await;
|
||||||
|
let ip = ip_from_idx(42);
|
||||||
|
|
||||||
|
tracker.check_and_add("u", ip).await.unwrap();
|
||||||
|
tracker.remove_ip("u", ip).await;
|
||||||
|
tracker.remove_ip("u", ip).await;
|
||||||
|
tracker.remove_ip("u", ip).await;
|
||||||
|
|
||||||
|
assert_eq!(tracker.get_active_ip_count("u").await, 0);
|
||||||
|
assert!(!tracker.is_ip_active("u", ip).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn clear_user_ips_resets_active_and_recent() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("u", 10).await;
|
||||||
|
|
||||||
|
for idx in 0..6 {
|
||||||
|
tracker.check_and_add("u", ip_from_idx(idx)).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.clear_user_ips("u").await;
|
||||||
|
|
||||||
|
assert_eq!(tracker.get_active_ip_count("u").await, 0);
|
||||||
|
let counts = tracker
|
||||||
|
.get_recent_counts_for_users(&["u".to_string()])
|
||||||
|
.await;
|
||||||
|
assert_eq!(counts.get("u").copied().unwrap_or(0), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn clear_all_resets_multi_user_state() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
|
||||||
|
for user_idx in 0..80u32 {
|
||||||
|
let user = format!("u{}", user_idx);
|
||||||
|
for ip_idx in 0..3 {
|
||||||
|
tracker
|
||||||
|
.check_and_add(&user, ip_from_idx(user_idx * 100 + ip_idx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.clear_all().await;
|
||||||
|
|
||||||
|
assert!(tracker.get_stats().await.is_empty());
|
||||||
|
let users = (0..80u32)
|
||||||
|
.map(|idx| format!("u{}", idx))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let recent = tracker.get_recent_counts_for_users(&users).await;
|
||||||
|
assert!(recent.values().all(|count| *count == 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn get_active_ips_for_users_are_sorted() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("user", 10).await;
|
||||||
|
|
||||||
|
tracker
|
||||||
|
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tracker
|
||||||
|
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tracker
|
||||||
|
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let map = tracker
|
||||||
|
.get_active_ips_for_users(&["user".to_string()])
|
||||||
|
.await;
|
||||||
|
let ips = map.get("user").cloned().unwrap_or_default();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
ips,
|
||||||
|
vec![
|
||||||
|
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
|
||||||
|
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)),
|
||||||
|
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)),
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn get_recent_ips_for_users_are_sorted() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("user", 10).await;
|
||||||
|
|
||||||
|
tracker
|
||||||
|
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tracker
|
||||||
|
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tracker
|
||||||
|
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let map = tracker
|
||||||
|
.get_recent_ips_for_users(&["user".to_string()])
|
||||||
|
.await;
|
||||||
|
let ips = map.get("user").cloned().unwrap_or_default();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
ips,
|
||||||
|
vec![
|
||||||
|
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)),
|
||||||
|
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)),
|
||||||
|
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)),
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn time_window_expires_for_large_rotation() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("tw", 1).await;
|
||||||
|
tracker
|
||||||
|
.set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
tracker.check_and_add("tw", ip_from_idx(1)).await.unwrap();
|
||||||
|
tracker.remove_ip("tw", ip_from_idx(1)).await;
|
||||||
|
assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_err());
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(1_100)).await;
|
||||||
|
assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn combined_mode_blocks_recent_after_disconnect() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("cmb", 1).await;
|
||||||
|
tracker
|
||||||
|
.set_limit_policy(UserMaxUniqueIpsMode::Combined, 2)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
tracker.check_and_add("cmb", ip_from_idx(11)).await.unwrap();
|
||||||
|
tracker.remove_ip("cmb", ip_from_idx(11)).await;
|
||||||
|
|
||||||
|
assert!(tracker.check_and_add("cmb", ip_from_idx(12)).await.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn load_limits_replaces_large_limit_map() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
let mut first = HashMap::new();
|
||||||
|
let mut second = HashMap::new();
|
||||||
|
|
||||||
|
for idx in 0..300usize {
|
||||||
|
first.insert(format!("u{}", idx), 2usize);
|
||||||
|
}
|
||||||
|
for idx in 150..450usize {
|
||||||
|
second.insert(format!("u{}", idx), 4usize);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.load_limits(0, &first).await;
|
||||||
|
tracker.load_limits(0, &second).await;
|
||||||
|
|
||||||
|
assert_eq!(tracker.get_user_limit("u20").await, None);
|
||||||
|
assert_eq!(tracker.get_user_limit("u200").await, Some(4));
|
||||||
|
assert_eq!(tracker.get_user_limit("u420").await, Some(4));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
|
async fn concurrent_same_user_unique_ip_pressure_stays_bounded() {
|
||||||
|
let tracker = Arc::new(UserIpTracker::new());
|
||||||
|
tracker.set_user_limit("hot", 32).await;
|
||||||
|
tracker
|
||||||
|
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for worker in 0..16u32 {
|
||||||
|
let tracker_cloned = tracker.clone();
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
let base = worker * 200;
|
||||||
|
for step in 0..200u32 {
|
||||||
|
let _ = tracker_cloned
|
||||||
|
.check_and_add("hot", ip_from_idx(base + step))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for handle in handles {
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(tracker.get_active_ip_count("hot").await <= 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
|
async fn concurrent_many_users_isolate_limits() {
|
||||||
|
let tracker = Arc::new(UserIpTracker::new());
|
||||||
|
tracker.load_limits(4, &HashMap::new()).await;
|
||||||
|
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for user_idx in 0..120u32 {
|
||||||
|
let tracker_cloned = tracker.clone();
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
let user = format!("u{}", user_idx);
|
||||||
|
for ip_idx in 0..10u32 {
|
||||||
|
let _ = tracker_cloned
|
||||||
|
.check_and_add(&user, ip_from_idx(user_idx * 1_000 + ip_idx))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for handle in handles {
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let stats = tracker.get_stats().await;
|
||||||
|
assert_eq!(stats.len(), 120);
|
||||||
|
assert!(stats.iter().all(|(_, active, limit)| *active <= 4 && *limit == 4));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn same_ip_reconnect_high_frequency_keeps_single_unique() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("same", 2).await;
|
||||||
|
let ip = ip_from_idx(9);
|
||||||
|
|
||||||
|
for _ in 0..2_000 {
|
||||||
|
tracker.check_and_add("same", ip).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(tracker.get_active_ip_count("same").await, 1);
|
||||||
|
assert!(tracker.is_ip_active("same", ip).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn format_stats_contains_expected_limited_and_unlimited_markers() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("limited", 2).await;
|
||||||
|
tracker.check_and_add("limited", ip_from_idx(1)).await.unwrap();
|
||||||
|
tracker.check_and_add("open", ip_from_idx(2)).await.unwrap();
|
||||||
|
|
||||||
|
let text = tracker.format_stats().await;
|
||||||
|
|
||||||
|
assert!(text.contains("limited"));
|
||||||
|
assert!(text.contains("open"));
|
||||||
|
assert!(text.contains("unlimited"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn stats_report_global_default_for_users_without_override() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.load_limits(5, &HashMap::new()).await;
|
||||||
|
|
||||||
|
tracker.check_and_add("a", ip_from_idx(1)).await.unwrap();
|
||||||
|
tracker.check_and_add("b", ip_from_idx(2)).await.unwrap();
|
||||||
|
|
||||||
|
let stats = tracker.get_stats().await;
|
||||||
|
assert!(stats.iter().any(|(user, _, limit)| user == "a" && *limit == 5));
|
||||||
|
assert!(stats.iter().any(|(user, _, limit)| user == "b" && *limit == 5));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn stress_cycle_add_remove_clear_preserves_empty_end_state() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
|
||||||
|
for cycle in 0..50u32 {
|
||||||
|
let user = format!("cycle{}", cycle);
|
||||||
|
tracker.set_user_limit(&user, 128).await;
|
||||||
|
|
||||||
|
for ip_idx in 0..128u32 {
|
||||||
|
tracker
|
||||||
|
.check_and_add(&user, ip_from_idx(cycle * 10_000 + ip_idx))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
for ip_idx in 0..128u32 {
|
||||||
|
tracker
|
||||||
|
.remove_ip(&user, ip_from_idx(cycle * 10_000 + ip_idx))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.clear_user_ips(&user).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(tracker.get_stats().await.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn remove_unknown_user_or_ip_does_not_corrupt_state() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
|
||||||
|
tracker.remove_ip("no_user", ip_from_idx(1)).await;
|
||||||
|
tracker.check_and_add("x", ip_from_idx(2)).await.unwrap();
|
||||||
|
tracker.remove_ip("x", ip_from_idx(3)).await;
|
||||||
|
|
||||||
|
assert_eq!(tracker.get_active_ip_count("x").await, 1);
|
||||||
|
assert!(tracker.is_ip_active("x", ip_from_idx(2)).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn active_and_recent_views_match_after_mixed_workload() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.set_user_limit("mix", 16).await;
|
||||||
|
|
||||||
|
for ip_idx in 0..12u32 {
|
||||||
|
tracker.check_and_add("mix", ip_from_idx(ip_idx)).await.unwrap();
|
||||||
|
}
|
||||||
|
for ip_idx in 0..6u32 {
|
||||||
|
tracker.remove_ip("mix", ip_from_idx(ip_idx)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let active = tracker
|
||||||
|
.get_active_ips_for_users(&["mix".to_string()])
|
||||||
|
.await
|
||||||
|
.get("mix")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_default();
|
||||||
|
let recent_count = tracker
|
||||||
|
.get_recent_counts_for_users(&["mix".to_string()])
|
||||||
|
.await
|
||||||
|
.get("mix")
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
assert_eq!(active.len(), 6);
|
||||||
|
assert!(recent_count >= active.len());
|
||||||
|
assert!(recent_count <= 12);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn global_limit_switch_updates_enforcement_immediately() {
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
tracker.load_limits(2, &HashMap::new()).await;
|
||||||
|
|
||||||
|
assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
|
||||||
|
assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
|
||||||
|
assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_err());
|
||||||
|
|
||||||
|
tracker.clear_user_ips("u").await;
|
||||||
|
tracker.load_limits(4, &HashMap::new()).await;
|
||||||
|
|
||||||
|
assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
|
||||||
|
assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
|
||||||
|
assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_ok());
|
||||||
|
assert!(tracker.check_and_add("u", ip_from_idx(4)).await.is_ok());
|
||||||
|
assert!(tracker.check_and_add("u", ip_from_idx(5)).await.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
|
async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() {
|
||||||
|
let tracker = Arc::new(UserIpTracker::new());
|
||||||
|
tracker.set_user_limit("cc", 8).await;
|
||||||
|
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for worker in 0..8u32 {
|
||||||
|
let tracker_cloned = tracker.clone();
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
let ip = ip_from_idx(50 + worker);
|
||||||
|
for _ in 0..500u32 {
|
||||||
|
let _ = tracker_cloned.check_and_add("cc", ip).await;
|
||||||
|
tracker_cloned.remove_ip("cc", ip).await;
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for handle in handles {
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(tracker.get_active_ip_count("cc").await <= 8);
|
||||||
|
}
|
||||||
|
|
@ -279,11 +279,32 @@ pub(crate) async fn spawn_metrics_if_configured(
|
||||||
ip_tracker: Arc<UserIpTracker>,
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
||||||
) {
|
) {
|
||||||
if let Some(port) = config.server.metrics_port {
|
// metrics_listen takes precedence; fall back to metrics_port for backward compat.
|
||||||
|
let metrics_target: Option<(u16, Option<String>)> =
|
||||||
|
if let Some(ref listen) = config.server.metrics_listen {
|
||||||
|
match listen.parse::<std::net::SocketAddr>() {
|
||||||
|
Ok(addr) => Some((addr.port(), Some(listen.clone()))),
|
||||||
|
Err(e) => {
|
||||||
|
startup_tracker
|
||||||
|
.skip_component(
|
||||||
|
COMPONENT_METRICS_START,
|
||||||
|
Some(format!("invalid metrics_listen \"{}\": {}", listen, e)),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
config.server.metrics_port.map(|p| (p, None))
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some((port, listen)) = metrics_target {
|
||||||
|
let fallback_label = format!("port {}", port);
|
||||||
|
let label = listen.as_deref().unwrap_or(&fallback_label);
|
||||||
startup_tracker
|
startup_tracker
|
||||||
.start_component(
|
.start_component(
|
||||||
COMPONENT_METRICS_START,
|
COMPONENT_METRICS_START,
|
||||||
Some(format!("spawn metrics endpoint on {}", port)),
|
Some(format!("spawn metrics endpoint on {}", label)),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
let stats = stats.clone();
|
let stats = stats.clone();
|
||||||
|
|
@ -294,6 +315,7 @@ pub(crate) async fn spawn_metrics_if_configured(
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
metrics::serve(
|
metrics::serve(
|
||||||
port,
|
port,
|
||||||
|
listen,
|
||||||
stats,
|
stats,
|
||||||
beobachten,
|
beobachten,
|
||||||
ip_tracker_metrics,
|
ip_tracker_metrics,
|
||||||
|
|
@ -308,7 +330,7 @@ pub(crate) async fn spawn_metrics_if_configured(
|
||||||
Some("metrics task spawned".to_string()),
|
Some("metrics task spawned".to_string()),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
} else {
|
} else if config.server.metrics_listen.is_none() {
|
||||||
startup_tracker
|
startup_tracker
|
||||||
.skip_component(
|
.skip_component(
|
||||||
COMPONENT_METRICS_START,
|
COMPONENT_METRICS_START,
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ mod config;
|
||||||
mod crypto;
|
mod crypto;
|
||||||
mod error;
|
mod error;
|
||||||
mod ip_tracker;
|
mod ip_tracker;
|
||||||
|
#[cfg(test)]
|
||||||
|
mod ip_tracker_regression_tests;
|
||||||
mod maestro;
|
mod maestro;
|
||||||
mod metrics;
|
mod metrics;
|
||||||
mod network;
|
mod network;
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ use crate::transport::{ListenOptions, create_listener};
|
||||||
|
|
||||||
pub async fn serve(
|
pub async fn serve(
|
||||||
port: u16,
|
port: u16,
|
||||||
|
listen: Option<String>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
beobachten: Arc<BeobachtenStore>,
|
beobachten: Arc<BeobachtenStore>,
|
||||||
ip_tracker: Arc<UserIpTracker>,
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
|
|
@ -28,6 +29,33 @@ pub async fn serve(
|
||||||
whitelist: Vec<IpNetwork>,
|
whitelist: Vec<IpNetwork>,
|
||||||
) {
|
) {
|
||||||
let whitelist = Arc::new(whitelist);
|
let whitelist = Arc::new(whitelist);
|
||||||
|
|
||||||
|
// If `metrics_listen` is set, bind on that single address only.
|
||||||
|
if let Some(ref listen_addr) = listen {
|
||||||
|
let addr: SocketAddr = match listen_addr.parse() {
|
||||||
|
Ok(a) => a,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "Invalid metrics_listen address: {}", listen_addr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let is_ipv6 = addr.is_ipv6();
|
||||||
|
match bind_metrics_listener(addr, is_ipv6) {
|
||||||
|
Ok(listener) => {
|
||||||
|
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr);
|
||||||
|
serve_listener(
|
||||||
|
listener, stats, beobachten, ip_tracker, config_rx, whitelist,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "Failed to bind metrics on {}", addr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: bind on 0.0.0.0 and [::] using metrics_port.
|
||||||
let mut listener_v4 = None;
|
let mut listener_v4 = None;
|
||||||
let mut listener_v6 = None;
|
let mut listener_v6 = None;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,8 @@ use crate::crypto::{sha256_hmac, SecureRandom};
|
||||||
use crate::error::ProxyError;
|
use crate::error::ProxyError;
|
||||||
use super::constants::*;
|
use super::constants::*;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
use num_bigint::BigUint;
|
|
||||||
use num_traits::One;
|
|
||||||
use subtle::ConstantTimeEq;
|
use subtle::ConstantTimeEq;
|
||||||
|
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
|
||||||
|
|
||||||
// ============= Public Constants =============
|
// ============= Public Constants =============
|
||||||
|
|
||||||
|
|
@ -27,8 +26,12 @@ pub const TLS_DIGEST_POS: usize = 11;
|
||||||
pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
||||||
|
|
||||||
/// Time skew limits for anti-replay (in seconds)
|
/// Time skew limits for anti-replay (in seconds)
|
||||||
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
///
|
||||||
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
/// The default window is intentionally narrow to reduce replay acceptance.
|
||||||
|
/// Operators with known clock-drifted clients should tune deployment config
|
||||||
|
/// (for example replay-window policy) to match their environment.
|
||||||
|
pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before
|
||||||
|
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
|
||||||
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
|
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
|
||||||
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
|
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
|
||||||
|
|
||||||
|
|
@ -117,27 +120,6 @@ impl TlsExtensionBuilder {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add ALPN extension with a single selected protocol.
|
|
||||||
fn add_alpn(&mut self, proto: &[u8]) -> &mut Self {
|
|
||||||
// Extension type: ALPN (0x0010)
|
|
||||||
self.extensions.extend_from_slice(&extension_type::ALPN.to_be_bytes());
|
|
||||||
|
|
||||||
// ALPN extension format:
|
|
||||||
// extension_data length (2 bytes)
|
|
||||||
// protocols length (2 bytes)
|
|
||||||
// protocol name length (1 byte)
|
|
||||||
// protocol name bytes
|
|
||||||
let proto_len = proto.len() as u8;
|
|
||||||
let list_len: u16 = 1 + u16::from(proto_len);
|
|
||||||
let ext_len: u16 = 2 + list_len;
|
|
||||||
|
|
||||||
self.extensions.extend_from_slice(&ext_len.to_be_bytes());
|
|
||||||
self.extensions.extend_from_slice(&list_len.to_be_bytes());
|
|
||||||
self.extensions.push(proto_len);
|
|
||||||
self.extensions.extend_from_slice(proto);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build final extensions with length prefix
|
/// Build final extensions with length prefix
|
||||||
fn build(self) -> Vec<u8> {
|
fn build(self) -> Vec<u8> {
|
||||||
let mut result = Vec::with_capacity(2 + self.extensions.len());
|
let mut result = Vec::with_capacity(2 + self.extensions.len());
|
||||||
|
|
@ -173,8 +155,6 @@ struct ServerHelloBuilder {
|
||||||
compression: u8,
|
compression: u8,
|
||||||
/// Extensions
|
/// Extensions
|
||||||
extensions: TlsExtensionBuilder,
|
extensions: TlsExtensionBuilder,
|
||||||
/// Selected ALPN protocol (if any)
|
|
||||||
alpn: Option<Vec<u8>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerHelloBuilder {
|
impl ServerHelloBuilder {
|
||||||
|
|
@ -185,7 +165,6 @@ impl ServerHelloBuilder {
|
||||||
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
|
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
|
||||||
compression: 0x00,
|
compression: 0x00,
|
||||||
extensions: TlsExtensionBuilder::new(),
|
extensions: TlsExtensionBuilder::new(),
|
||||||
alpn: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -200,18 +179,9 @@ impl ServerHelloBuilder {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
fn with_alpn(mut self, proto: Option<Vec<u8>>) -> Self {
|
|
||||||
self.alpn = proto;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build ServerHello message (without record header)
|
/// Build ServerHello message (without record header)
|
||||||
fn build_message(&self) -> Vec<u8> {
|
fn build_message(&self) -> Vec<u8> {
|
||||||
let mut ext_builder = self.extensions.clone();
|
let extensions = self.extensions.extensions.clone();
|
||||||
if let Some(ref alpn) = self.alpn {
|
|
||||||
ext_builder.add_alpn(alpn);
|
|
||||||
}
|
|
||||||
let extensions = ext_builder.extensions.clone();
|
|
||||||
let extensions_len = extensions.len() as u16;
|
let extensions_len = extensions.len() as u16;
|
||||||
|
|
||||||
// Calculate total length
|
// Calculate total length
|
||||||
|
|
@ -316,7 +286,14 @@ pub fn validate_tls_handshake_with_replay_window(
|
||||||
};
|
};
|
||||||
|
|
||||||
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
||||||
let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32);
|
// Boot-time bypass and ignore_time_skew serve different compatibility paths.
|
||||||
|
// When skew checks are disabled, force boot-time cap to zero to prevent
|
||||||
|
// accidental future coupling of boot-time logic into the ignore-skew path.
|
||||||
|
let boot_time_cap_secs = if ignore_time_skew {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
BOOT_TIME_MAX_SECS.min(replay_window_u32)
|
||||||
|
};
|
||||||
|
|
||||||
validate_tls_handshake_at_time_with_boot_cap(
|
validate_tls_handshake_at_time_with_boot_cap(
|
||||||
handshake,
|
handshake,
|
||||||
|
|
@ -369,6 +346,9 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
||||||
// Extract session ID
|
// Extract session ID
|
||||||
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
|
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
|
||||||
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
|
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
|
||||||
|
if session_id_len > 32 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
let session_id_start = session_id_len_pos + 1;
|
let session_id_start = session_id_len_pos + 1;
|
||||||
|
|
||||||
if handshake.len() < session_id_start + session_id_len {
|
if handshake.len() < session_id_start + session_id_len {
|
||||||
|
|
@ -411,7 +391,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
||||||
if !ignore_time_skew {
|
if !ignore_time_skew {
|
||||||
// Allow very small timestamps (boot time instead of unix time)
|
// Allow very small timestamps (boot time instead of unix time)
|
||||||
// This is a quirk in some clients that use uptime instead of real time
|
// This is a quirk in some clients that use uptime instead of real time
|
||||||
let is_boot_time = timestamp < boot_time_cap_secs;
|
let is_boot_time = boot_time_cap_secs > 0 && timestamp < boot_time_cap_secs;
|
||||||
if !is_boot_time {
|
if !is_boot_time {
|
||||||
let time_diff = now - i64::from(timestamp);
|
let time_diff = now - i64::from(timestamp);
|
||||||
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
|
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
|
||||||
|
|
@ -433,27 +413,14 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn curve25519_prime() -> BigUint {
|
|
||||||
(BigUint::one() << 255) - BigUint::from(19u32)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate a fake X25519 public key for TLS
|
/// Generate a fake X25519 public key for TLS
|
||||||
///
|
///
|
||||||
/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p,
|
/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint,
|
||||||
/// which matches Python/C behavior and avoids DPI fingerprinting.
|
/// yielding distribution-consistent public keys for anti-fingerprinting.
|
||||||
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
|
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
|
||||||
let mut n_bytes = [0u8; 32];
|
let mut scalar = [0u8; 32];
|
||||||
n_bytes.copy_from_slice(&rng.bytes(32));
|
scalar.copy_from_slice(&rng.bytes(32));
|
||||||
|
x25519(scalar, X25519_BASEPOINT_BYTES)
|
||||||
let n = BigUint::from_bytes_le(&n_bytes);
|
|
||||||
let p = curve25519_prime();
|
|
||||||
let pk = (&n * &n) % &p;
|
|
||||||
|
|
||||||
let mut out = pk.to_bytes_le();
|
|
||||||
out.resize(32, 0);
|
|
||||||
let mut result = [0u8; 32];
|
|
||||||
result.copy_from_slice(&out[..32]);
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build TLS ServerHello response
|
/// Build TLS ServerHello response
|
||||||
|
|
@ -470,7 +437,7 @@ pub fn build_server_hello(
|
||||||
session_id: &[u8],
|
session_id: &[u8],
|
||||||
fake_cert_len: usize,
|
fake_cert_len: usize,
|
||||||
rng: &SecureRandom,
|
rng: &SecureRandom,
|
||||||
alpn: Option<Vec<u8>>,
|
_alpn: Option<Vec<u8>>,
|
||||||
new_session_tickets: u8,
|
new_session_tickets: u8,
|
||||||
) -> Vec<u8> {
|
) -> Vec<u8> {
|
||||||
const MIN_APP_DATA: usize = 64;
|
const MIN_APP_DATA: usize = 64;
|
||||||
|
|
@ -482,7 +449,6 @@ pub fn build_server_hello(
|
||||||
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
|
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
|
||||||
.with_x25519_key(&x25519_key)
|
.with_x25519_key(&x25519_key)
|
||||||
.with_tls13_version()
|
.with_tls13_version()
|
||||||
.with_alpn(alpn)
|
|
||||||
.build_record();
|
.build_record();
|
||||||
|
|
||||||
// Build Change Cipher Spec record
|
// Build Change Cipher Spec record
|
||||||
|
|
@ -705,10 +671,10 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0)
|
// TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303.
|
||||||
first_bytes[0] == TLS_RECORD_HANDSHAKE
|
first_bytes[0] == TLS_RECORD_HANDSHAKE
|
||||||
&& first_bytes[1] == 0x03
|
&& first_bytes[1] == 0x03
|
||||||
&& first_bytes[2] == 0x01
|
&& (first_bytes[2] == 0x01 || first_bytes[2] == 0x03)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse TLS record header, returns (record_type, length)
|
/// Parse TLS record header, returns (record_type, length)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::crypto::sha256_hmac;
|
use crate::crypto::sha256_hmac;
|
||||||
|
use crate::tls_front::emulator::build_emulated_server_hello;
|
||||||
|
use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource};
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
/// Build a TLS-handshake-like buffer that contains a valid HMAC digest
|
/// Build a TLS-handshake-like buffer that contains a valid HMAC digest
|
||||||
/// for the given `secret` and `timestamp`.
|
/// for the given `secret` and `timestamp`.
|
||||||
|
|
@ -369,16 +372,16 @@ fn one_byte_session_id_validates_and_is_preserved() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn max_session_id_len_255_with_valid_digest_is_accepted() {
|
fn max_session_id_len_255_with_valid_digest_is_rejected_by_rfc_cap() {
|
||||||
let secret = b"sid_len_255_test";
|
let secret = b"sid_len_255_test";
|
||||||
let session_id = vec![0xCCu8; 255];
|
let session_id = vec![0xCCu8; 255];
|
||||||
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id);
|
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id);
|
||||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
let result = validate_tls_handshake(&handshake, &secrets, true)
|
assert!(
|
||||||
.expect("session_id_len=255 with valid digest must validate");
|
validate_tls_handshake(&handshake, &secrets, true).is_none(),
|
||||||
assert_eq!(result.session_id.len(), 255);
|
"legacy_session_id length > 32 must be rejected even with valid digest"
|
||||||
assert_eq!(result.session_id, session_id);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------
|
// ------------------------------------------------------------------
|
||||||
|
|
@ -731,6 +734,246 @@ fn replay_window_cap_still_allows_small_boot_timestamp() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() {
|
||||||
|
let secret = b"ignore_skew_boot_cap_decouple_test";
|
||||||
|
let ts: u32 = 1;
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0);
|
||||||
|
let cap_nonzero =
|
||||||
|
validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_MAX_SECS);
|
||||||
|
|
||||||
|
assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC");
|
||||||
|
assert!(
|
||||||
|
cap_nonzero.is_some(),
|
||||||
|
"ignore_time_skew path must not depend on boot-time cap"
|
||||||
|
);
|
||||||
|
|
||||||
|
let a = cap_zero.unwrap();
|
||||||
|
let b = cap_nonzero.unwrap();
|
||||||
|
assert_eq!(a.user, b.user);
|
||||||
|
assert_eq!(a.timestamp, b.timestamp);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn adversarial_small_boot_timestamp_matrix_rejected_when_boot_cap_forced_zero() {
|
||||||
|
let secret = b"boot_cap_zero_matrix_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
let now: i64 = 1_700_000_000;
|
||||||
|
|
||||||
|
for ts in 0u32..1024u32 {
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0);
|
||||||
|
assert!(
|
||||||
|
result.is_none(),
|
||||||
|
"boot cap=0 must reject timestamp {ts} when skew checks are active"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_boot_cap_zero_rejects_small_timestamp_space() {
|
||||||
|
let secret = b"boot_cap_zero_fuzz_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
let now: i64 = 1_700_000_000;
|
||||||
|
let mut s: u64 = 0x9E37_79B9_7F4A_7C15;
|
||||||
|
|
||||||
|
for _ in 0..4096 {
|
||||||
|
s ^= s << 7;
|
||||||
|
s ^= s >> 9;
|
||||||
|
s ^= s << 8;
|
||||||
|
let ts = (s as u32) % 2048;
|
||||||
|
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0);
|
||||||
|
assert!(
|
||||||
|
result.is_none(),
|
||||||
|
"fuzzed boot-range timestamp {ts} must be rejected when cap=0"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stress_boot_cap_zero_rejection_is_deterministic_under_high_iteration_count() {
|
||||||
|
let secret = b"boot_cap_zero_stress_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
let now: i64 = 1_700_000_000;
|
||||||
|
|
||||||
|
for i in 0u32..20_000u32 {
|
||||||
|
let ts = i % 4096;
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0);
|
||||||
|
assert!(
|
||||||
|
result.is_none(),
|
||||||
|
"iteration {i}: timestamp {ts} must be rejected with cap=0"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn replay_window_one_allows_only_zero_timestamp_boot_bypass() {
|
||||||
|
let secret = b"replay_window_one_boot_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
let ts0 = make_valid_tls_handshake(secret, 0);
|
||||||
|
let ts1 = make_valid_tls_handshake(secret, 1);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 1).is_some(),
|
||||||
|
"replay_window=1 must allow timestamp 0 via boot-time compatibility"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 1).is_none(),
|
||||||
|
"replay_window=1 must reject timestamp 1 on normal wall-clock systems"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn replay_window_two_allows_ts0_ts1_but_rejects_ts2() {
|
||||||
|
let secret = b"replay_window_two_boot_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
let ts0 = make_valid_tls_handshake(secret, 0);
|
||||||
|
let ts1 = make_valid_tls_handshake(secret, 1);
|
||||||
|
let ts2 = make_valid_tls_handshake(secret, 2);
|
||||||
|
|
||||||
|
assert!(validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 2).is_some());
|
||||||
|
assert!(validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 2).is_some());
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake_with_replay_window(&ts2, &secrets, false, 2).is_none(),
|
||||||
|
"timestamp equal to replay-window cap must not use boot-time bypass"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disabled() {
|
||||||
|
let secret = b"skew_boundary_matrix_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
let now: i64 = 1_700_000_000;
|
||||||
|
|
||||||
|
for offset in -1500i64..=1500i64 {
|
||||||
|
let ts_i64 = now - offset;
|
||||||
|
let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix");
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
|
||||||
|
.is_some();
|
||||||
|
let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset);
|
||||||
|
assert_eq!(
|
||||||
|
accepted, expected,
|
||||||
|
"offset {offset} must match inclusive skew window when boot bypass is disabled"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() {
|
||||||
|
let secret = b"skew_outside_fuzz_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
let now: i64 = 1_700_000_000;
|
||||||
|
let mut s: u64 = 0x0123_4567_89AB_CDEF;
|
||||||
|
|
||||||
|
for _ in 0..4096 {
|
||||||
|
s ^= s << 7;
|
||||||
|
s ^= s >> 9;
|
||||||
|
s ^= s << 8;
|
||||||
|
|
||||||
|
let magnitude = 1300i64 + ((s % 2000u64) as i64);
|
||||||
|
let sign = if (s & 1) == 0 { 1i64 } else { -1i64 };
|
||||||
|
let offset = sign * magnitude;
|
||||||
|
let ts_i64 = now - offset;
|
||||||
|
let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test");
|
||||||
|
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
|
||||||
|
.is_some();
|
||||||
|
assert!(
|
||||||
|
!accepted,
|
||||||
|
"offset {offset} must be rejected outside strict skew window"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stress_boot_disabled_validation_matches_time_diff_oracle() {
|
||||||
|
let secret = b"boot_disabled_oracle_stress_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
let now: i64 = 1_700_000_000;
|
||||||
|
let mut s: u64 = 0xBADC_0FFE_EE11_2233;
|
||||||
|
|
||||||
|
for _ in 0..25_000 {
|
||||||
|
s ^= s << 7;
|
||||||
|
s ^= s >> 9;
|
||||||
|
s ^= s << 8;
|
||||||
|
let ts = s as u32;
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
|
||||||
|
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
|
||||||
|
.is_some();
|
||||||
|
let time_diff = now - i64::from(ts);
|
||||||
|
let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff);
|
||||||
|
assert_eq!(
|
||||||
|
accepted, expected,
|
||||||
|
"boot-disabled validation must match pure time-diff oracle"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() {
|
||||||
|
let now: i64 = 1_700_000_000;
|
||||||
|
let target_secret = b"target_user_secret";
|
||||||
|
let target_ts = (now - 1) as u32;
|
||||||
|
let handshake = make_valid_tls_handshake(target_secret, target_ts);
|
||||||
|
|
||||||
|
let mut secrets = Vec::new();
|
||||||
|
for i in 0..512u32 {
|
||||||
|
secrets.push((format!("noise-{i}"), format!("noise-secret-{i}").into_bytes()));
|
||||||
|
}
|
||||||
|
secrets.push(("target-user".to_string(), target_secret.to_vec()));
|
||||||
|
|
||||||
|
let result = validate_tls_handshake_at_time_with_boot_cap(&handshake, &secrets, false, now, 0)
|
||||||
|
.expect("matching user should validate within strict skew window");
|
||||||
|
assert_eq!(result.user, "target-user");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_ignore_time_skew_accepts_wide_timestamp_range_with_valid_hmac() {
|
||||||
|
let secret = b"ignore_skew_fuzz_accept_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
let mut s: u64 = 0xC0FF_EE11_2233_4455;
|
||||||
|
|
||||||
|
for _ in 0..2048 {
|
||||||
|
s ^= s << 7;
|
||||||
|
s ^= s >> 9;
|
||||||
|
s ^= s << 8;
|
||||||
|
let ts = s as u32;
|
||||||
|
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let result = validate_tls_handshake_with_replay_window(&h, &secrets, true, 60);
|
||||||
|
assert!(
|
||||||
|
result.is_some(),
|
||||||
|
"ignore_time_skew=true must accept valid HMAC for arbitrary timestamp"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_small_replay_window_rejects_far_timestamps_when_skew_enabled() {
|
||||||
|
let secret = b"replay_window_reject_fuzz_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
for ts in 300u32..=1323u32 {
|
||||||
|
let h = make_valid_tls_handshake(secret, ts);
|
||||||
|
let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, 0, 300);
|
||||||
|
assert!(
|
||||||
|
result.is_none(),
|
||||||
|
"with skew checks enabled and boot cap=300, timestamp >=300 at now=0 must be rejected"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------
|
// ------------------------------------------------------------------
|
||||||
// Extreme timestamp values
|
// Extreme timestamp values
|
||||||
// ------------------------------------------------------------------
|
// ------------------------------------------------------------------
|
||||||
|
|
@ -897,7 +1140,9 @@ fn first_matching_user_wins_over_later_duplicate_secret() {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_is_tls_handshake() {
|
fn test_is_tls_handshake() {
|
||||||
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
|
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
|
||||||
|
assert!(is_tls_handshake(&[0x16, 0x03, 0x03]));
|
||||||
assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00]));
|
assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00]));
|
||||||
|
assert!(is_tls_handshake(&[0x16, 0x03, 0x03, 0x02, 0x00]));
|
||||||
assert!(!is_tls_handshake(&[0x17, 0x03, 0x01]));
|
assert!(!is_tls_handshake(&[0x17, 0x03, 0x01]));
|
||||||
assert!(!is_tls_handshake(&[0x16, 0x03, 0x02]));
|
assert!(!is_tls_handshake(&[0x16, 0x03, 0x02]));
|
||||||
assert!(!is_tls_handshake(&[0x16, 0x03]));
|
assert!(!is_tls_handshake(&[0x16, 0x03]));
|
||||||
|
|
@ -945,17 +1190,158 @@ fn test_gen_fake_x25519_key() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_fake_x25519_key_is_quadratic_residue() {
|
fn test_fake_x25519_key_is_nonzero_and_varies() {
|
||||||
use num_bigint::BigUint;
|
|
||||||
use num_traits::One;
|
|
||||||
|
|
||||||
let rng = crate::crypto::SecureRandom::new();
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
let key = gen_fake_x25519_key(&rng);
|
let mut unique = std::collections::HashSet::new();
|
||||||
let p = curve25519_prime();
|
let mut saw_non_zero = false;
|
||||||
let k_num = BigUint::from_bytes_le(&key);
|
|
||||||
let exponent = (&p - BigUint::one()) >> 1;
|
for _ in 0..64 {
|
||||||
let legendre = k_num.modpow(&exponent, &p);
|
let key = gen_fake_x25519_key(&rng);
|
||||||
assert_eq!(legendre, BigUint::one());
|
if key != [0u8; 32] {
|
||||||
|
saw_non_zero = true;
|
||||||
|
}
|
||||||
|
unique.insert(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
saw_non_zero,
|
||||||
|
"generated X25519 public keys must not collapse to all-zero output"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
unique.len() > 1,
|
||||||
|
"generated X25519 public keys must vary across invocations"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_tls_handshake_rejects_session_id_longer_than_rfc_cap() {
|
||||||
|
let secret = b"session_id_cap_secret";
|
||||||
|
let oversized_sid = vec![0x42u8; 33];
|
||||||
|
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &oversized_sid);
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake(&handshake, &secrets, true).is_none(),
|
||||||
|
"legacy_session_id length > 32 must be rejected"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn server_hello_extension_types(record: &[u8]) -> Vec<u16> {
|
||||||
|
if record.len() < 9 || record[0] != TLS_RECORD_HANDSHAKE || record[5] != 0x02 {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let record_len = u16::from_be_bytes([record[3], record[4]]) as usize;
|
||||||
|
if record.len() < 5 + record_len {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let hs_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
|
||||||
|
let hs_start = 5;
|
||||||
|
let hs_end = hs_start + 4 + hs_len;
|
||||||
|
if hs_end > record.len() {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut pos = hs_start + 4 + 2 + 32;
|
||||||
|
if pos >= hs_end {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
let sid_len = record[pos] as usize;
|
||||||
|
pos += 1 + sid_len;
|
||||||
|
if pos + 2 + 1 + 2 > hs_end {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
pos += 2 + 1;
|
||||||
|
let ext_len = u16::from_be_bytes([record[pos], record[pos + 1]]) as usize;
|
||||||
|
pos += 2;
|
||||||
|
let ext_end = pos + ext_len;
|
||||||
|
if ext_end > hs_end {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out = Vec::new();
|
||||||
|
while pos + 4 <= ext_end {
|
||||||
|
let etype = u16::from_be_bytes([record[pos], record[pos + 1]]);
|
||||||
|
let elen = u16::from_be_bytes([record[pos + 2], record[pos + 3]]) as usize;
|
||||||
|
pos += 4;
|
||||||
|
if pos + elen > ext_end {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
out.push(etype);
|
||||||
|
pos += elen;
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_server_hello_never_places_alpn_in_server_hello_extensions() {
|
||||||
|
let secret = b"alpn_sh_forbidden";
|
||||||
|
let client_digest = [0x11u8; 32];
|
||||||
|
let session_id = vec![0xAA; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
|
||||||
|
let response = build_server_hello(
|
||||||
|
secret,
|
||||||
|
&client_digest,
|
||||||
|
&session_id,
|
||||||
|
1024,
|
||||||
|
&rng,
|
||||||
|
Some(b"h2".to_vec()),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
let exts = server_hello_extension_types(&response);
|
||||||
|
assert!(
|
||||||
|
!exts.contains(&0x0010),
|
||||||
|
"ALPN extension must not appear in ServerHello"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() {
|
||||||
|
let secret = b"alpn_emulated_forbidden";
|
||||||
|
let client_digest = [0x22u8; 32];
|
||||||
|
let session_id = vec![0xAB; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
let cached = CachedTlsData {
|
||||||
|
server_hello_template: ParsedServerHello {
|
||||||
|
version: TLS_VERSION,
|
||||||
|
random: [0u8; 32],
|
||||||
|
session_id: Vec::new(),
|
||||||
|
cipher_suite: [0x13, 0x01],
|
||||||
|
compression: 0,
|
||||||
|
extensions: Vec::new(),
|
||||||
|
},
|
||||||
|
cert_info: None,
|
||||||
|
cert_payload: None,
|
||||||
|
app_data_records_sizes: vec![1024],
|
||||||
|
total_app_data_len: 1024,
|
||||||
|
behavior_profile: TlsBehaviorProfile {
|
||||||
|
change_cipher_spec_count: 1,
|
||||||
|
app_data_record_sizes: vec![1024],
|
||||||
|
ticket_record_sizes: Vec::new(),
|
||||||
|
source: TlsProfileSource::Default,
|
||||||
|
},
|
||||||
|
fetched_at: SystemTime::now(),
|
||||||
|
domain: "example.com".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = build_emulated_server_hello(
|
||||||
|
secret,
|
||||||
|
&client_digest,
|
||||||
|
&session_id,
|
||||||
|
&cached,
|
||||||
|
false,
|
||||||
|
&rng,
|
||||||
|
Some(b"h2".to_vec()),
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
let exts = server_hello_extension_types(&response);
|
||||||
|
assert!(
|
||||||
|
!exts.contains(&0x0010),
|
||||||
|
"ALPN extension must not appear in emulated ServerHello"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -1394,3 +1780,191 @@ fn server_hello_application_data_payload_varies_across_runs() {
|
||||||
"ApplicationData payload should vary across runs to reduce fingerprintability"
|
"ApplicationData payload should vary across runs to reduce fingerprintability"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn replay_window_zero_disables_boot_bypass_for_any_nonzero_timestamp() {
|
||||||
|
let secret = b"window_zero_boot_bypass_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
let ts1 = make_valid_tls_handshake(secret, 1);
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 0).is_none(),
|
||||||
|
"replay_window_secs=0 must reject nonzero timestamps even in boot-time range"
|
||||||
|
);
|
||||||
|
|
||||||
|
let ts0 = make_valid_tls_handshake(secret, 0);
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 0).is_none(),
|
||||||
|
"replay_window_secs=0 enforces strict skew check and rejects timestamp=0 on normal wall-clock systems"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn large_replay_window_does_not_expand_time_skew_acceptance() {
|
||||||
|
let secret = b"large_replay_window_skew_bound_test";
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
let now: i64 = 1_700_000_000;
|
||||||
|
|
||||||
|
let ts_far_past = (now - 600) as u32;
|
||||||
|
let valid = make_valid_tls_handshake(secret, ts_far_past);
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake_with_replay_window(&valid, &secrets, false, 86_400).is_none(),
|
||||||
|
"large replay window must not relax strict skew check once boot-time bypass is not in play"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tls_record_header_accepts_tls_version_constant() {
|
||||||
|
let header = [TLS_RECORD_HANDSHAKE, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x2A];
|
||||||
|
let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted");
|
||||||
|
assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE);
|
||||||
|
assert_eq!(parsed.1, 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_clamps_fake_cert_len_lower_bound() {
|
||||||
|
let secret = b"fake_cert_lower_bound_test";
|
||||||
|
let client_digest = [0x11u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0x77; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
|
||||||
|
let response = build_server_hello(secret, &client_digest, &session_id, 1, &rng, None, 0);
|
||||||
|
|
||||||
|
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_pos = 5 + sh_len;
|
||||||
|
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
|
||||||
|
let app_pos = ccs_pos + 5 + ccs_len;
|
||||||
|
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
|
||||||
|
|
||||||
|
assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
|
||||||
|
assert_eq!(app_len, 64, "fake cert payload must be clamped to minimum 64 bytes");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_clamps_fake_cert_len_upper_bound() {
|
||||||
|
let secret = b"fake_cert_upper_bound_test";
|
||||||
|
let client_digest = [0x22u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0x66; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
|
||||||
|
let response = build_server_hello(secret, &client_digest, &session_id, 65_535, &rng, None, 0);
|
||||||
|
|
||||||
|
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
|
||||||
|
let ccs_pos = 5 + sh_len;
|
||||||
|
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
|
||||||
|
let app_pos = ccs_pos + 5 + ccs_len;
|
||||||
|
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
|
||||||
|
|
||||||
|
assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
|
||||||
|
assert_eq!(app_len, 16_640, "fake cert payload must be clamped to TLS record max bound");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_hello_new_session_ticket_count_matches_configuration() {
|
||||||
|
let secret = b"ticket_count_surface_test";
|
||||||
|
let client_digest = [0x33u8; TLS_DIGEST_LEN];
|
||||||
|
let session_id = vec![0x55; 32];
|
||||||
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
|
||||||
|
let tickets: u8 = 3;
|
||||||
|
let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, tickets);
|
||||||
|
|
||||||
|
let mut pos = 0usize;
|
||||||
|
let mut app_records = 0usize;
|
||||||
|
while pos + 5 <= response.len() {
|
||||||
|
let rtype = response[pos];
|
||||||
|
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
|
||||||
|
let next = pos + 5 + rlen;
|
||||||
|
assert!(next <= response.len(), "TLS record must stay inside response bounds");
|
||||||
|
if rtype == TLS_RECORD_APPLICATION {
|
||||||
|
app_records += 1;
|
||||||
|
}
|
||||||
|
pos = next;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
app_records,
|
||||||
|
1 + tickets as usize,
|
||||||
|
"response must contain one main application record plus configured ticket-like tail records"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn exhaustive_tls_minor_version_classification_matches_policy() {
|
||||||
|
for minor in 0u8..=u8::MAX {
|
||||||
|
let first = [TLS_RECORD_HANDSHAKE, 0x03, minor];
|
||||||
|
let expected = minor == 0x01 || minor == 0x03;
|
||||||
|
assert_eq!(
|
||||||
|
is_tls_handshake(&first),
|
||||||
|
expected,
|
||||||
|
"minor version {minor:#04x} classification mismatch"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_tls_header_classifier_and_parser_policy_consistency() {
|
||||||
|
// Deterministic xorshift state keeps this fuzz test reproducible.
|
||||||
|
let mut s: u64 = 0x9E37_79B9_AA95_5A5D;
|
||||||
|
|
||||||
|
for _ in 0..10_000 {
|
||||||
|
s ^= s << 7;
|
||||||
|
s ^= s >> 9;
|
||||||
|
s ^= s << 8;
|
||||||
|
|
||||||
|
let header = [
|
||||||
|
(s & 0xff) as u8,
|
||||||
|
((s >> 8) & 0xff) as u8,
|
||||||
|
((s >> 16) & 0xff) as u8,
|
||||||
|
((s >> 24) & 0xff) as u8,
|
||||||
|
((s >> 32) & 0xff) as u8,
|
||||||
|
];
|
||||||
|
|
||||||
|
let classified = is_tls_handshake(&header[..3]);
|
||||||
|
let expected_classified = header[0] == TLS_RECORD_HANDSHAKE
|
||||||
|
&& header[1] == 0x03
|
||||||
|
&& (header[2] == 0x01 || header[2] == 0x03);
|
||||||
|
assert_eq!(
|
||||||
|
classified,
|
||||||
|
expected_classified,
|
||||||
|
"classifier policy mismatch for header {header:02x?}"
|
||||||
|
);
|
||||||
|
|
||||||
|
let parsed = parse_tls_record_header(&header);
|
||||||
|
let expected_parsed = header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]);
|
||||||
|
assert_eq!(
|
||||||
|
parsed.is_some(),
|
||||||
|
expected_parsed,
|
||||||
|
"parser policy mismatch for header {header:02x?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stress_random_noise_handshakes_never_authenticate() {
|
||||||
|
let secret = b"stress_noise_secret";
|
||||||
|
let secrets = vec![("noise-user".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
// Deterministic xorshift state keeps this stress test reproducible.
|
||||||
|
let mut s: u64 = 0xD1B5_4A32_9C6E_77F1;
|
||||||
|
|
||||||
|
for _ in 0..5_000 {
|
||||||
|
s ^= s << 7;
|
||||||
|
s ^= s >> 9;
|
||||||
|
s ^= s << 8;
|
||||||
|
|
||||||
|
let len = 1 + ((s as usize) % 196);
|
||||||
|
let mut buf = vec![0u8; len];
|
||||||
|
for b in &mut buf {
|
||||||
|
s ^= s << 7;
|
||||||
|
s ^= s >> 9;
|
||||||
|
s ^= s << 8;
|
||||||
|
*b = (s & 0xff) as u8;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake(&buf, &secrets, true).is_none(),
|
||||||
|
"random noise must never authenticate"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,72 @@ enum HandshakeOutcome {
|
||||||
Handled,
|
Handled,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"]
|
||||||
|
struct UserConnectionReservation {
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
|
user: String,
|
||||||
|
ip: IpAddr,
|
||||||
|
active: bool,
|
||||||
|
runtime_handle: Option<tokio::runtime::Handle>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserConnectionReservation {
|
||||||
|
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
|
||||||
|
let runtime_handle = tokio::runtime::Handle::try_current().ok();
|
||||||
|
Self {
|
||||||
|
stats,
|
||||||
|
ip_tracker,
|
||||||
|
user,
|
||||||
|
ip,
|
||||||
|
active: true,
|
||||||
|
runtime_handle,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn release(mut self) {
|
||||||
|
if !self.active {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
self.ip_tracker.remove_ip(&self.user, self.ip).await;
|
||||||
|
self.active = false;
|
||||||
|
self.stats.decrement_user_curr_connects(&self.user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for UserConnectionReservation {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if !self.active {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
self.active = false;
|
||||||
|
self.stats.decrement_user_curr_connects(&self.user);
|
||||||
|
|
||||||
|
if let Some(handle) = &self.runtime_handle {
|
||||||
|
let ip_tracker = self.ip_tracker.clone();
|
||||||
|
let user = self.user.clone();
|
||||||
|
let ip = self.ip;
|
||||||
|
let handle = handle.clone();
|
||||||
|
handle.spawn(async move {
|
||||||
|
ip_tracker.remove_ip(&user, ip).await;
|
||||||
|
});
|
||||||
|
} else if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||||
|
let ip_tracker = self.ip_tracker.clone();
|
||||||
|
let user = self.user.clone();
|
||||||
|
let ip = self.ip;
|
||||||
|
handle.spawn(async move {
|
||||||
|
ip_tracker.remove_ip(&user, ip).await;
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
user = %self.user,
|
||||||
|
ip = %self.ip,
|
||||||
|
"UserConnectionReservation dropped without Tokio runtime; IP reservation cleanup skipped"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::crypto::SecureRandom;
|
use crate::crypto::SecureRandom;
|
||||||
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
|
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
|
||||||
|
|
@ -45,7 +111,19 @@ use crate::proxy::middle_relay::handle_via_middle_proxy;
|
||||||
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||||
|
|
||||||
fn beobachten_ttl(config: &ProxyConfig) -> Duration {
|
fn beobachten_ttl(config: &ProxyConfig) -> Duration {
|
||||||
Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60))
|
let minutes = config.general.beobachten_minutes;
|
||||||
|
if minutes == 0 {
|
||||||
|
static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock<AtomicBool> = OnceLock::new();
|
||||||
|
let warned = BEOBACHTEN_ZERO_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false));
|
||||||
|
if !warned.swap(true, Ordering::Relaxed) {
|
||||||
|
warn!(
|
||||||
|
"general.beobachten_minutes=0 is insecure because entries expire immediately; forcing minimum TTL to 1 minute"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return Duration::from_secs(60);
|
||||||
|
}
|
||||||
|
|
||||||
|
Duration::from_secs(minutes.saturating_mul(60))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn record_beobachten_class(
|
fn record_beobachten_class(
|
||||||
|
|
@ -90,6 +168,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
|
||||||
trusted.iter().any(|cidr| cidr.contains(peer_ip))
|
trusted.iter().any(|cidr| cidr.contains(peer_ip))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn synthetic_local_addr(port: u16) -> SocketAddr {
|
||||||
|
SocketAddr::from(([0, 0, 0, 0], port))
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn handle_client_stream<S>(
|
pub async fn handle_client_stream<S>(
|
||||||
mut stream: S,
|
mut stream: S,
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
|
|
@ -113,9 +195,7 @@ where
|
||||||
let mut real_peer = normalize_ip(peer);
|
let mut real_peer = normalize_ip(peer);
|
||||||
|
|
||||||
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
|
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
|
||||||
let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
|
let mut local_addr = synthetic_local_addr(config.server.port);
|
||||||
.parse()
|
|
||||||
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
|
|
||||||
|
|
||||||
if proxy_protocol_enabled {
|
if proxy_protocol_enabled {
|
||||||
let proxy_header_timeout = Duration::from_millis(
|
let proxy_header_timeout = Duration::from_millis(
|
||||||
|
|
@ -426,7 +506,6 @@ impl RunningClientHandler {
|
||||||
pub async fn run(self) -> Result<()> {
|
pub async fn run(self) -> Result<()> {
|
||||||
self.stats.increment_connects_all();
|
self.stats.increment_connects_all();
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
let _ip_tracker = self.ip_tracker.clone();
|
|
||||||
debug!(peer = %peer, "New connection");
|
debug!(peer = %peer, "New connection");
|
||||||
|
|
||||||
if let Err(e) = configure_client_socket(
|
if let Err(e) = configure_client_socket(
|
||||||
|
|
@ -557,7 +636,6 @@ impl RunningClientHandler {
|
||||||
|
|
||||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
let _ip_tracker = self.ip_tracker.clone();
|
|
||||||
|
|
||||||
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
||||||
|
|
||||||
|
|
@ -570,7 +648,6 @@ impl RunningClientHandler {
|
||||||
|
|
||||||
async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
|
async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
let _ip_tracker = self.ip_tracker.clone();
|
|
||||||
|
|
||||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||||
|
|
||||||
|
|
@ -694,7 +771,6 @@ impl RunningClientHandler {
|
||||||
|
|
||||||
async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
|
async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
let _ip_tracker = self.ip_tracker.clone();
|
|
||||||
|
|
||||||
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
||||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||||
|
|
@ -798,10 +874,22 @@ impl RunningClientHandler {
|
||||||
{
|
{
|
||||||
let user = success.user.clone();
|
let user = success.user.clone();
|
||||||
|
|
||||||
if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await {
|
let user_limit_reservation =
|
||||||
warn!(user = %user, error = %e, "User limit exceeded");
|
match Self::acquire_user_connection_reservation_static(
|
||||||
return Err(e);
|
&user,
|
||||||
}
|
&config,
|
||||||
|
stats.clone(),
|
||||||
|
peer_addr,
|
||||||
|
ip_tracker,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(reservation) => reservation,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(user = %user, error = %e, "User admission check failed");
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let route_snapshot = route_runtime.snapshot();
|
let route_snapshot = route_runtime.snapshot();
|
||||||
let session_id = rng.u64();
|
let session_id = rng.u64();
|
||||||
|
|
@ -858,15 +946,68 @@ impl RunningClientHandler {
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
};
|
};
|
||||||
|
user_limit_reservation.release().await;
|
||||||
stats.decrement_user_curr_connects(&user);
|
|
||||||
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
|
|
||||||
relay_result
|
relay_result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn acquire_user_connection_reservation_static(
|
||||||
|
user: &str,
|
||||||
|
config: &ProxyConfig,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
peer_addr: SocketAddr,
|
||||||
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
|
) -> Result<UserConnectionReservation> {
|
||||||
|
if let Some(expiration) = config.access.user_expirations.get(user)
|
||||||
|
&& chrono::Utc::now() > *expiration
|
||||||
|
{
|
||||||
|
return Err(ProxyError::UserExpired {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||||
|
&& stats.get_user_total_octets(user) >= *quota
|
||||||
|
{
|
||||||
|
return Err(ProxyError::DataQuotaExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
|
||||||
|
if !stats.try_acquire_user_curr_connects(user, limit) {
|
||||||
|
return Err(ProxyError::ConnectionLimitExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(reason) => {
|
||||||
|
stats.decrement_user_curr_connects(user);
|
||||||
|
warn!(
|
||||||
|
user = %user,
|
||||||
|
ip = %peer_addr.ip(),
|
||||||
|
reason = %reason,
|
||||||
|
"IP limit exceeded"
|
||||||
|
);
|
||||||
|
return Err(ProxyError::ConnectionLimitExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(UserConnectionReservation::new(
|
||||||
|
stats,
|
||||||
|
ip_tracker,
|
||||||
|
user.to_string(),
|
||||||
|
peer_addr.ip(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
async fn check_user_limits_static(
|
async fn check_user_limits_static(
|
||||||
user: &str,
|
user: &str,
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
stats: &Stats,
|
stats: &Stats,
|
||||||
peer_addr: SocketAddr,
|
peer_addr: SocketAddr,
|
||||||
ip_tracker: &UserIpTracker,
|
ip_tracker: &UserIpTracker,
|
||||||
|
|
@ -899,7 +1040,10 @@ impl RunningClientHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||||
Ok(()) => {}
|
Ok(()) => {
|
||||||
|
ip_tracker.remove_ip(user, peer_addr.ip()).await;
|
||||||
|
stats.decrement_user_curr_connects(user);
|
||||||
|
}
|
||||||
Err(reason) => {
|
Err(reason) => {
|
||||||
stats.decrement_user_curr_connects(user);
|
stats.decrement_user_curr_connects(user);
|
||||||
warn!(
|
warn!(
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,8 @@
|
||||||
|
use std::ffi::OsString;
|
||||||
use std::fs::OpenOptions;
|
use std::fs::OpenOptions;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::path::{Component, Path, PathBuf};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
|
@ -24,14 +26,28 @@ use crate::stats::Stats;
|
||||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||||
use crate::transport::UpstreamManager;
|
use crate::transport::UpstreamManager;
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
use std::os::unix::fs::OpenOptionsExt;
|
||||||
|
|
||||||
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
|
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
|
||||||
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
|
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct SanitizedUnknownDcLogPath {
|
||||||
|
resolved_path: PathBuf,
|
||||||
|
allowed_parent: PathBuf,
|
||||||
|
file_name: OsString,
|
||||||
|
}
|
||||||
|
|
||||||
// In tests, this function shares global mutable state. Callers that also use
|
// In tests, this function shares global mutable state. Callers that also use
|
||||||
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
|
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
|
||||||
// deterministic under parallel execution.
|
// deterministic under parallel execution.
|
||||||
fn should_log_unknown_dc(dc_idx: i16) -> bool {
|
fn should_log_unknown_dc(dc_idx: i16) -> bool {
|
||||||
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
|
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
|
||||||
|
should_log_unknown_dc_with_set(set, dc_idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_log_unknown_dc_with_set(set: &Mutex<HashSet<i16>>, dc_idx: i16) -> bool {
|
||||||
match set.lock() {
|
match set.lock() {
|
||||||
Ok(mut guard) => {
|
Ok(mut guard) => {
|
||||||
if guard.contains(&dc_idx) {
|
if guard.contains(&dc_idx) {
|
||||||
|
|
@ -42,9 +58,81 @@ fn should_log_unknown_dc(dc_idx: i16) -> bool {
|
||||||
}
|
}
|
||||||
guard.insert(dc_idx)
|
guard.insert(dc_idx)
|
||||||
}
|
}
|
||||||
// If the lock is poisoned, keep logging rather than silently dropping
|
// Fail closed on poisoned state to avoid unbounded blocking log writes.
|
||||||
// operator-visible diagnostics.
|
Err(_) => false,
|
||||||
Err(_) => true,
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sanitize_unknown_dc_log_path(path: &str) -> Option<SanitizedUnknownDcLogPath> {
|
||||||
|
let candidate = Path::new(path);
|
||||||
|
if candidate.as_os_str().is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if candidate
|
||||||
|
.components()
|
||||||
|
.any(|component| matches!(component, Component::ParentDir))
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cwd = std::env::current_dir().ok()?;
|
||||||
|
let file_name = candidate.file_name()?;
|
||||||
|
let parent = candidate.parent().unwrap_or_else(|| Path::new("."));
|
||||||
|
let parent_path = if parent.is_absolute() {
|
||||||
|
parent.to_path_buf()
|
||||||
|
} else {
|
||||||
|
cwd.join(parent)
|
||||||
|
};
|
||||||
|
let canonical_parent = parent_path.canonicalize().ok()?;
|
||||||
|
if !canonical_parent.is_dir() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(SanitizedUnknownDcLogPath {
|
||||||
|
resolved_path: canonical_parent.join(file_name),
|
||||||
|
allowed_parent: canonical_parent,
|
||||||
|
file_name: file_name.to_os_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
|
||||||
|
let Some(parent) = path.resolved_path.parent() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
let Ok(current_parent) = parent.canonicalize() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
if current_parent != path.allowed_parent {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(canonical_target) = path.resolved_path.canonicalize() {
|
||||||
|
let Some(target_parent) = canonical_target.parent() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
let Some(target_name) = canonical_target.file_name() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
if target_parent != path.allowed_parent || target_name != path.file_name {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.custom_flags(libc::O_NOFOLLOW)
|
||||||
|
.open(path)
|
||||||
|
}
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
{
|
||||||
|
OpenOptions::new().create(true).append(true).open(path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -105,7 +193,7 @@ where
|
||||||
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||||
|
|
||||||
stats.increment_user_connects(user);
|
stats.increment_user_connects(user);
|
||||||
stats.increment_current_connections_direct();
|
let _direct_connection_lease = stats.acquire_direct_connection_lease();
|
||||||
|
|
||||||
let relay_result = relay_bidirectional(
|
let relay_result = relay_bidirectional(
|
||||||
client_reader,
|
client_reader,
|
||||||
|
|
@ -148,8 +236,6 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
stats.decrement_current_connections_direct();
|
|
||||||
|
|
||||||
match &relay_result {
|
match &relay_result {
|
||||||
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
||||||
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
|
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
|
||||||
|
|
@ -202,12 +288,17 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
||||||
&& should_log_unknown_dc(dc_idx)
|
&& should_log_unknown_dc(dc_idx)
|
||||||
&& let Ok(handle) = tokio::runtime::Handle::try_current()
|
&& let Ok(handle) = tokio::runtime::Handle::try_current()
|
||||||
{
|
{
|
||||||
let path = path.clone();
|
if let Some(path) = sanitize_unknown_dc_log_path(path) {
|
||||||
handle.spawn_blocking(move || {
|
handle.spawn_blocking(move || {
|
||||||
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
|
if unknown_dc_log_path_is_still_safe(&path)
|
||||||
let _ = writeln!(file, "dc_idx={dc_idx}");
|
&& let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path)
|
||||||
}
|
{
|
||||||
});
|
let _ = writeln!(file, "dc_idx={dc_idx}");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -4,11 +4,11 @@
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
use std::collections::hash_map::RandomState;
|
||||||
use std::net::{IpAddr, Ipv6Addr};
|
use std::net::{IpAddr, Ipv6Addr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, OnceLock};
|
||||||
use std::collections::hash_map::DefaultHasher;
|
use std::hash::{BuildHasher, Hash, Hasher};
|
||||||
use std::hash::{Hash, Hasher};
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use dashmap::mapref::entry::Entry;
|
use dashmap::mapref::entry::Entry;
|
||||||
|
|
@ -36,6 +36,7 @@ const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256;
|
||||||
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
|
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
|
||||||
const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024;
|
const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024;
|
||||||
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
|
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
|
||||||
|
const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
|
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
|
||||||
|
|
@ -54,12 +55,25 @@ struct AuthProbeState {
|
||||||
last_seen: Instant,
|
last_seen: Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
struct AuthProbeSaturationState {
|
||||||
|
fail_streak: u32,
|
||||||
|
blocked_until: Instant,
|
||||||
|
last_seen: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
|
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
|
||||||
|
static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> = OnceLock::new();
|
||||||
|
static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new();
|
||||||
|
|
||||||
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
|
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
|
||||||
AUTH_PROBE_STATE.get_or_init(DashMap::new)
|
AUTH_PROBE_STATE.get_or_init(DashMap::new)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn auth_probe_saturation_state() -> &'static Mutex<Option<AuthProbeSaturationState>> {
|
||||||
|
AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None))
|
||||||
|
}
|
||||||
|
|
||||||
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
|
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
|
||||||
match peer_ip {
|
match peer_ip {
|
||||||
IpAddr::V4(ip) => IpAddr::V4(ip),
|
IpAddr::V4(ip) => IpAddr::V4(ip),
|
||||||
|
|
@ -88,7 +102,8 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
|
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
|
||||||
let mut hasher = DefaultHasher::new();
|
let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new);
|
||||||
|
let mut hasher = hasher_state.build_hasher();
|
||||||
peer_ip.hash(&mut hasher);
|
peer_ip.hash(&mut hasher);
|
||||||
now.hash(&mut hasher);
|
now.hash(&mut hasher);
|
||||||
hasher.finish() as usize
|
hasher.finish() as usize
|
||||||
|
|
@ -108,6 +123,83 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||||
now < entry.blocked_until
|
now < entry.blocked_until
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
|
let state = auth_probe_state_map();
|
||||||
|
let Some(entry) = state.get(&peer_ip) else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
if auth_probe_state_expired(&entry, now) {
|
||||||
|
drop(entry);
|
||||||
|
state.remove(&peer_ip);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool {
|
||||||
|
if !auth_probe_is_throttled(peer_ip, now) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !auth_probe_saturation_is_throttled(now) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auth_probe_saturation_grace_exhausted(peer_ip, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_saturation_is_throttled(now: Instant) -> bool {
|
||||||
|
let saturation = auth_probe_saturation_state();
|
||||||
|
let mut guard = match saturation.lock() {
|
||||||
|
Ok(guard) => guard,
|
||||||
|
Err(_) => return false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(state) = guard.as_mut() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
if now.duration_since(state.last_seen) > Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) {
|
||||||
|
*guard = None;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if now < state.blocked_until {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_note_saturation(now: Instant) {
|
||||||
|
let saturation = auth_probe_saturation_state();
|
||||||
|
let mut guard = match saturation.lock() {
|
||||||
|
Ok(guard) => guard,
|
||||||
|
Err(_) => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
match guard.as_mut() {
|
||||||
|
Some(state)
|
||||||
|
if now.duration_since(state.last_seen)
|
||||||
|
<= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) =>
|
||||||
|
{
|
||||||
|
state.fail_streak = state.fail_streak.saturating_add(1);
|
||||||
|
state.last_seen = now;
|
||||||
|
state.blocked_until = now + auth_probe_backoff(state.fail_streak);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS;
|
||||||
|
*guard = Some(AuthProbeSaturationState {
|
||||||
|
fail_streak,
|
||||||
|
blocked_until: now + auth_probe_backoff(fail_streak),
|
||||||
|
last_seen: now,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
|
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
|
||||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
let state = auth_probe_state_map();
|
let state = auth_probe_state_map();
|
||||||
|
|
@ -144,24 +236,79 @@ fn auth_probe_record_failure_with_state(
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
let mut stale_keys = Vec::new();
|
let mut rounds = 0usize;
|
||||||
let mut eviction_candidates = Vec::new();
|
while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
rounds += 1;
|
||||||
eviction_candidates.push(*entry.key());
|
if rounds > 8 {
|
||||||
if auth_probe_state_expired(entry.value(), now) {
|
auth_probe_note_saturation(now);
|
||||||
stale_keys.push(*entry.key());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for stale_key in stale_keys {
|
|
||||||
state.remove(&stale_key);
|
|
||||||
}
|
|
||||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
|
||||||
if eviction_candidates.is_empty() {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len();
|
|
||||||
let evict_key = eviction_candidates[idx];
|
let mut stale_keys = Vec::new();
|
||||||
|
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
|
||||||
|
let state_len = state.len();
|
||||||
|
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
|
||||||
|
let start_offset = if state_len == 0 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
auth_probe_eviction_offset(peer_ip, now) % state_len
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut scanned = 0usize;
|
||||||
|
for entry in state.iter().skip(start_offset) {
|
||||||
|
let key = *entry.key();
|
||||||
|
let fail_streak = entry.value().fail_streak;
|
||||||
|
let last_seen = entry.value().last_seen;
|
||||||
|
match eviction_candidate {
|
||||||
|
Some((_, current_fail, current_seen))
|
||||||
|
if fail_streak > current_fail
|
||||||
|
|| (fail_streak == current_fail && last_seen >= current_seen) =>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||||
|
}
|
||||||
|
if auth_probe_state_expired(entry.value(), now) {
|
||||||
|
stale_keys.push(key);
|
||||||
|
}
|
||||||
|
scanned += 1;
|
||||||
|
if scanned >= scan_limit {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if scanned < scan_limit {
|
||||||
|
for entry in state.iter().take(scan_limit - scanned) {
|
||||||
|
let key = *entry.key();
|
||||||
|
let fail_streak = entry.value().fail_streak;
|
||||||
|
let last_seen = entry.value().last_seen;
|
||||||
|
match eviction_candidate {
|
||||||
|
Some((_, current_fail, current_seen))
|
||||||
|
if fail_streak > current_fail
|
||||||
|
|| (fail_streak == current_fail && last_seen >= current_seen) =>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||||
|
}
|
||||||
|
if auth_probe_state_expired(entry.value(), now) {
|
||||||
|
stale_keys.push(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for stale_key in stale_keys {
|
||||||
|
state.remove(&stale_key);
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some((evict_key, _, _)) = eviction_candidate else {
|
||||||
|
auth_probe_note_saturation(now);
|
||||||
|
return;
|
||||||
|
};
|
||||||
state.remove(&evict_key);
|
state.remove(&evict_key);
|
||||||
|
auth_probe_note_saturation(now);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -186,6 +333,11 @@ fn clear_auth_probe_state_for_testing() {
|
||||||
if let Some(state) = AUTH_PROBE_STATE.get() {
|
if let Some(state) = AUTH_PROBE_STATE.get() {
|
||||||
state.clear();
|
state.clear();
|
||||||
}
|
}
|
||||||
|
if let Some(saturation) = AUTH_PROBE_SATURATION_STATE.get()
|
||||||
|
&& let Ok(mut guard) = saturation.lock()
|
||||||
|
{
|
||||||
|
*guard = None;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -200,6 +352,16 @@ fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
|
||||||
auth_probe_is_throttled(peer_ip, Instant::now())
|
auth_probe_is_throttled(peer_ip, Instant::now())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn auth_probe_saturation_is_throttled_for_testing() -> bool {
|
||||||
|
auth_probe_saturation_is_throttled(Instant::now())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool {
|
||||||
|
auth_probe_saturation_is_throttled(now)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn auth_probe_test_lock() -> &'static Mutex<()> {
|
fn auth_probe_test_lock() -> &'static Mutex<()> {
|
||||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
|
@ -317,6 +479,24 @@ fn decode_user_secrets(
|
||||||
secrets
|
secrets
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
|
||||||
|
if config.censorship.server_hello_delay_max_ms == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let min = config.censorship.server_hello_delay_min_ms;
|
||||||
|
let max = config.censorship.server_hello_delay_max_ms.max(min);
|
||||||
|
let delay_ms = if max == min {
|
||||||
|
max
|
||||||
|
} else {
|
||||||
|
rand::rng().random_range(min..=max)
|
||||||
|
};
|
||||||
|
|
||||||
|
if delay_ms > 0 {
|
||||||
|
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Result of successful handshake
|
/// Result of successful handshake
|
||||||
///
|
///
|
||||||
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
|
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
|
||||||
|
|
@ -367,17 +547,21 @@ where
|
||||||
{
|
{
|
||||||
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
||||||
|
|
||||||
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
|
let throttle_now = Instant::now();
|
||||||
|
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
|
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
debug!(peer = %peer, "TLS handshake too short");
|
debug!(peer = %peer, "TLS handshake too short");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
let secrets = decode_user_secrets(config, None);
|
let client_sni = tls::extract_sni_from_client_hello(handshake);
|
||||||
|
let secrets = decode_user_secrets(config, client_sni.as_deref());
|
||||||
|
|
||||||
let validation = match tls::validate_tls_handshake_with_replay_window(
|
let validation = match tls::validate_tls_handshake_with_replay_window(
|
||||||
handshake,
|
handshake,
|
||||||
|
|
@ -388,6 +572,7 @@ where
|
||||||
Some(v) => v,
|
Some(v) => v,
|
||||||
None => {
|
None => {
|
||||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
debug!(
|
debug!(
|
||||||
peer = %peer,
|
peer = %peer,
|
||||||
ignore_time_skew = config.access.ignore_time_skew,
|
ignore_time_skew = config.access.ignore_time_skew,
|
||||||
|
|
@ -402,20 +587,24 @@ where
|
||||||
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
|
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||||
if replay_checker.check_and_add_tls_digest(digest_half) {
|
if replay_checker.check_and_add_tls_digest(digest_half) {
|
||||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||||
Some((_, s)) => s,
|
Some((_, s)) => s,
|
||||||
None => return HandshakeResult::BadClient { reader, writer },
|
None => {
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let cached = if config.censorship.tls_emulation {
|
let cached = if config.censorship.tls_emulation {
|
||||||
if let Some(cache) = tls_cache.as_ref() {
|
if let Some(cache) = tls_cache.as_ref() {
|
||||||
let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) {
|
let selected_domain = if let Some(sni) = client_sni.as_ref() {
|
||||||
if cache.contains_domain(&sni).await {
|
if cache.contains_domain(&sni).await {
|
||||||
sni
|
sni.clone()
|
||||||
} else {
|
} else {
|
||||||
config.censorship.tls_domain.clone()
|
config.censorship.tls_domain.clone()
|
||||||
}
|
}
|
||||||
|
|
@ -448,6 +637,7 @@ where
|
||||||
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
|
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
|
||||||
Some(b"http/1.1".to_vec())
|
Some(b"http/1.1".to_vec())
|
||||||
} else if !alpn_list.is_empty() {
|
} else if !alpn_list.is_empty() {
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
|
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -480,19 +670,9 @@ where
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Optional anti-fingerprint delay before sending ServerHello.
|
// Apply the same optional delay budget used by reject paths to reduce
|
||||||
if config.censorship.server_hello_delay_max_ms > 0 {
|
// distinguishability between success and fail-closed handshakes.
|
||||||
let min = config.censorship.server_hello_delay_min_ms;
|
maybe_apply_server_hello_delay(config).await;
|
||||||
let max = config.censorship.server_hello_delay_max_ms.max(min);
|
|
||||||
let delay_ms = if max == min {
|
|
||||||
max
|
|
||||||
} else {
|
|
||||||
rand::rng().random_range(min..=max)
|
|
||||||
};
|
|
||||||
if delay_ms > 0 {
|
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
||||||
|
|
||||||
|
|
@ -538,7 +718,9 @@ where
|
||||||
{
|
{
|
||||||
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
||||||
|
|
||||||
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
|
let throttle_now = Instant::now();
|
||||||
|
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
|
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
@ -609,6 +791,7 @@ where
|
||||||
// authentication check first to avoid poisoning the replay cache.
|
// authentication check first to avoid poisoning the replay cache.
|
||||||
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
|
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
|
||||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
|
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
@ -645,6 +828,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||||
|
maybe_apply_server_hello_delay(config).await;
|
||||||
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
||||||
HandshakeResult::BadClient { reader, writer }
|
HandshakeResult::BadClient { reader, writer }
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -7,7 +7,7 @@ use tokio::net::TcpStream;
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use tokio::net::UnixStream;
|
use tokio::net::UnixStream;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::time::timeout;
|
use tokio::time::{Instant, timeout};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::network::dns_overrides::resolve_socket_addr;
|
use crate::network::dns_overrides::resolve_socket_addr;
|
||||||
|
|
@ -24,8 +24,36 @@ const MASK_TIMEOUT: Duration = Duration::from_millis(50);
|
||||||
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
|
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
|
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
|
||||||
|
#[cfg(not(test))]
|
||||||
|
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
#[cfg(test)]
|
||||||
|
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
|
||||||
const MASK_BUFFER_SIZE: usize = 8192;
|
const MASK_BUFFER_SIZE: usize = 8192;
|
||||||
|
|
||||||
|
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W)
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin,
|
||||||
|
W: AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
let mut buf = [0u8; MASK_BUFFER_SIZE];
|
||||||
|
loop {
|
||||||
|
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
|
||||||
|
let n = match read_res {
|
||||||
|
Ok(Ok(n)) => n,
|
||||||
|
Ok(Err(_)) | Err(_) => break,
|
||||||
|
};
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let write_res = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&buf[..n])).await;
|
||||||
|
match write_res {
|
||||||
|
Ok(Ok(())) => {}
|
||||||
|
Ok(Err(_)) | Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
|
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
|
||||||
where
|
where
|
||||||
W: AsyncWrite + Unpin,
|
W: AsyncWrite + Unpin,
|
||||||
|
|
@ -49,6 +77,20 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn wait_mask_connect_budget(started: Instant) {
|
||||||
|
let elapsed = started.elapsed();
|
||||||
|
if elapsed < MASK_TIMEOUT {
|
||||||
|
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_mask_outcome_budget(started: Instant) {
|
||||||
|
let elapsed = started.elapsed();
|
||||||
|
if elapsed < MASK_TIMEOUT {
|
||||||
|
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Detect client type based on initial data
|
/// Detect client type based on initial data
|
||||||
fn detect_client_type(data: &[u8]) -> &'static str {
|
fn detect_client_type(data: &[u8]) -> &'static str {
|
||||||
// Check for HTTP request
|
// Check for HTTP request
|
||||||
|
|
@ -107,6 +149,8 @@ where
|
||||||
// Connect via Unix socket or TCP
|
// Connect via Unix socket or TCP
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
|
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
|
||||||
|
let outcome_started = Instant::now();
|
||||||
|
let connect_started = Instant::now();
|
||||||
debug!(
|
debug!(
|
||||||
client_type = client_type,
|
client_type = client_type,
|
||||||
sock = %sock_path,
|
sock = %sock_path,
|
||||||
|
|
@ -143,14 +187,18 @@ where
|
||||||
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
|
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
|
||||||
debug!("Mask relay timed out (unix socket)");
|
debug!("Mask relay timed out (unix socket)");
|
||||||
}
|
}
|
||||||
|
wait_mask_outcome_budget(outcome_started).await;
|
||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
|
wait_mask_connect_budget(connect_started).await;
|
||||||
debug!(error = %e, "Failed to connect to mask unix socket");
|
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||||
consume_client_data_with_timeout(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
|
wait_mask_outcome_budget(outcome_started).await;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Timeout connecting to mask unix socket");
|
debug!("Timeout connecting to mask unix socket");
|
||||||
consume_client_data_with_timeout(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
|
wait_mask_outcome_budget(outcome_started).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
|
@ -172,6 +220,8 @@ where
|
||||||
let mask_addr = resolve_socket_addr(mask_host, mask_port)
|
let mask_addr = resolve_socket_addr(mask_host, mask_port)
|
||||||
.map(|addr| addr.to_string())
|
.map(|addr| addr.to_string())
|
||||||
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
|
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
|
||||||
|
let outcome_started = Instant::now();
|
||||||
|
let connect_started = Instant::now();
|
||||||
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
|
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
|
||||||
match connect_result {
|
match connect_result {
|
||||||
Ok(Ok(stream)) => {
|
Ok(Ok(stream)) => {
|
||||||
|
|
@ -202,14 +252,18 @@ where
|
||||||
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
|
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
|
||||||
debug!("Mask relay timed out");
|
debug!("Mask relay timed out");
|
||||||
}
|
}
|
||||||
|
wait_mask_outcome_budget(outcome_started).await;
|
||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
|
wait_mask_connect_budget(connect_started).await;
|
||||||
debug!(error = %e, "Failed to connect to mask host");
|
debug!(error = %e, "Failed to connect to mask host");
|
||||||
consume_client_data_with_timeout(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
|
wait_mask_outcome_budget(outcome_started).await;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Timeout connecting to mask host");
|
debug!("Timeout connecting to mask host");
|
||||||
consume_client_data_with_timeout(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
|
wait_mask_outcome_budget(outcome_started).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -238,11 +292,11 @@ where
|
||||||
|
|
||||||
let _ = tokio::join!(
|
let _ = tokio::join!(
|
||||||
async {
|
async {
|
||||||
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
|
copy_with_idle_timeout(&mut reader, &mut mask_write).await;
|
||||||
let _ = mask_write.shutdown().await;
|
let _ = mask_write.shutdown().await;
|
||||||
},
|
},
|
||||||
async {
|
async {
|
||||||
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
|
copy_with_idle_timeout(&mut mask_read, &mut writer).await;
|
||||||
let _ = writer.shutdown().await;
|
let _ = writer.shutdown().await;
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use tokio::net::UnixListener;
|
use tokio::net::UnixListener;
|
||||||
use tokio::time::{sleep, timeout, Duration};
|
use tokio::time::{Instant, sleep, timeout, Duration};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
|
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
|
||||||
|
|
@ -216,6 +216,373 @@ async fn backend_unavailable_falls_back_to_silent_consume() {
|
||||||
assert_eq!(n, 0);
|
assert_eq!(n, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() {
|
||||||
|
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let unused_port = temp_listener.local_addr().unwrap().port();
|
||||||
|
drop(temp_listener);
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = unused_port;
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n";
|
||||||
|
|
||||||
|
// Close client reader immediately to force the refusal path to rely on masking budget timing.
|
||||||
|
let (client_reader_side, client_reader) = duplex(256);
|
||||||
|
drop(client_reader_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
timeout(Duration::from_millis(35), task)
|
||||||
|
.await
|
||||||
|
.expect_err("masking fallback must not complete before connect budget elapses");
|
||||||
|
assert!(
|
||||||
|
started.elapsed() >= Duration::from_millis(35),
|
||||||
|
"fallback path must absorb immediate connect refusal into connect budget"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn backend_reachable_fast_response_waits_mask_outcome_budget() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let backend_reply = backend_reply.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
assert_eq!(received, probe);
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(256);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(512);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
started.elapsed() >= Duration::from_millis(45),
|
||||||
|
"reachable mask path must also satisfy coarse outcome budget"
|
||||||
|
);
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() {
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = false;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(256);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
b"x",
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
started.elapsed() < Duration::from_millis(20),
|
||||||
|
"mask-disabled fallback should keep immediate EOF behavior"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn backend_reachable_slow_response_not_padded_twice() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let backend_reply = backend_reply.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
assert_eq!(received, probe);
|
||||||
|
sleep(Duration::from_millis(90)).await;
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(256);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(512);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let elapsed = started.elapsed();
|
||||||
|
|
||||||
|
assert!(elapsed >= Duration::from_millis(85));
|
||||||
|
assert!(
|
||||||
|
elapsed < Duration::from_millis(170),
|
||||||
|
"slow reachable backend should not incur an extra full budget after already exceeding it"
|
||||||
|
);
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() {
|
||||||
|
const ITER: usize = 20;
|
||||||
|
const BUCKET_MS: u128 = 10;
|
||||||
|
|
||||||
|
let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n";
|
||||||
|
let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let mut refused = Vec::with_capacity(ITER);
|
||||||
|
for _ in 0..ITER {
|
||||||
|
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let unused_port = temp_listener.local_addr().unwrap().port();
|
||||||
|
drop(temp_listener);
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = unused_port;
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(256);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
refused.push(started.elapsed().as_millis());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut reachable = Vec::with_capacity(ITER);
|
||||||
|
for _ in 0..ITER {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe_vec = probe.to_vec();
|
||||||
|
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn(async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe_vec.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(256);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
reachable.push(started.elapsed().as_millis());
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let refused_mean = refused.iter().copied().sum::<u128>() as f64 / refused.len() as f64;
|
||||||
|
let reachable_mean = reachable.iter().copied().sum::<u128>() as f64 / reachable.len() as f64;
|
||||||
|
let refused_bucket = (refused_mean as u128) / BUCKET_MS;
|
||||||
|
let reachable_bucket = (reachable_mean as u128) / BUCKET_MS;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
refused_bucket.abs_diff(reachable_bucket) <= 1,
|
||||||
|
"enabled refused and reachable paths must collapse into the same coarse latency bucket"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() {
|
||||||
|
let mut seed: u64 = 0xA5A5_5A5A_1337_4242;
|
||||||
|
let mut next = || {
|
||||||
|
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||||
|
seed
|
||||||
|
};
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
for _ in 0..40 {
|
||||||
|
let probe_len = (next() as usize % 96).saturating_add(8);
|
||||||
|
let mut probe = vec![0u8; probe_len];
|
||||||
|
for byte in &mut probe {
|
||||||
|
*byte = next() as u8;
|
||||||
|
}
|
||||||
|
|
||||||
|
let use_reachable = (next() & 1) == 0;
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(512);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(512);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
if use_reachable {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
|
||||||
|
let probe_vec = probe.clone();
|
||||||
|
let accept_task = tokio::spawn(async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut observed = vec![0u8; probe_vec.len()];
|
||||||
|
stream.read_exact(&mut observed).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
} else {
|
||||||
|
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let unused_port = temp_listener.local_addr().unwrap().port();
|
||||||
|
drop(temp_listener);
|
||||||
|
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = unused_port;
|
||||||
|
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
started.elapsed() >= Duration::from_millis(45),
|
||||||
|
"mask-enabled fallback must preserve coarse timing budget under varied probe shapes"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn mask_disabled_consumes_client_data_without_response() {
|
async fn mask_disabled_consumes_client_data_without_response() {
|
||||||
let mut config = ProxyConfig::default();
|
let mut config = ProxyConfig::default();
|
||||||
|
|
@ -524,6 +891,59 @@ async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() {
|
||||||
timeout(Duration::from_secs(1), task).await.unwrap().unwrap();
|
timeout(Duration::from_secs(1), task).await.unwrap().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mask_enabled_idle_relay_is_closed_by_idle_timeout_before_global_relay_timeout() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = b"GET /idle HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
assert_eq!(received, probe);
|
||||||
|
sleep(Duration::from_millis(300)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "198.51.100.34:45456".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (_client_reader_side, client_reader) = duplex(512);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(512);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let elapsed = started.elapsed();
|
||||||
|
assert!(
|
||||||
|
elapsed < Duration::from_millis(150),
|
||||||
|
"idle unauth relay must terminate on idle timeout instead of waiting for full relay timeout"
|
||||||
|
);
|
||||||
|
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
struct PendingWriter;
|
struct PendingWriter;
|
||||||
|
|
||||||
impl tokio::io::AsyncWrite for PendingWriter {
|
impl tokio::io::AsyncWrite for PendingWriter {
|
||||||
|
|
@ -729,3 +1149,321 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
|
||||||
assert!(mask_reader_dropped.load(Ordering::SeqCst));
|
assert!(mask_reader_dropped.load(Ordering::SeqCst));
|
||||||
assert!(mask_writer_dropped.load(Ordering::SeqCst));
|
assert!(mask_writer_dropped.load(Ordering::SeqCst));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
|
||||||
|
async fn timing_matrix_masking_classes_under_controlled_inputs() {
|
||||||
|
const ITER: usize = 24;
|
||||||
|
const BUCKET_MS: u128 = 10;
|
||||||
|
|
||||||
|
let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n";
|
||||||
|
let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
// Class 1: masking disabled with immediate EOF (fast fail-closed consume path).
|
||||||
|
let mut disabled_samples = Vec::with_capacity(ITER);
|
||||||
|
for _ in 0..ITER {
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = false;
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(256);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
disabled_samples.push(started.elapsed().as_millis());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Class 2: masking enabled, backend connect refused.
|
||||||
|
let mut refused_samples = Vec::with_capacity(ITER);
|
||||||
|
for _ in 0..ITER {
|
||||||
|
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let unused_port = temp_listener.local_addr().unwrap().port();
|
||||||
|
drop(temp_listener);
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = unused_port;
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(256);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
refused_samples.push(started.elapsed().as_millis());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Class 3: masking enabled, backend reachable and immediately responds.
|
||||||
|
let mut reachable_samples = Vec::with_capacity(ITER);
|
||||||
|
for _ in 0..ITER {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
|
||||||
|
let probe_vec = probe.to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn(async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe_vec.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
assert_eq!(received, probe_vec);
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let (client_writer_side, client_reader) = duplex(256);
|
||||||
|
drop(client_writer_side);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
reachable_samples.push(started.elapsed().as_millis());
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) {
|
||||||
|
samples_ms.sort_unstable();
|
||||||
|
let sum: u128 = samples_ms.iter().copied().sum();
|
||||||
|
let mean = sum as f64 / samples_ms.len() as f64;
|
||||||
|
let min = samples_ms[0];
|
||||||
|
let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
|
||||||
|
let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
|
||||||
|
let max = samples_ms[samples_ms.len() - 1];
|
||||||
|
(mean, min, p95, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples);
|
||||||
|
let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples);
|
||||||
|
let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples);
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
|
||||||
|
disabled_mean,
|
||||||
|
disabled_min,
|
||||||
|
disabled_p95,
|
||||||
|
disabled_max,
|
||||||
|
(disabled_mean as u128) / BUCKET_MS
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
|
||||||
|
refused_mean,
|
||||||
|
refused_min,
|
||||||
|
refused_p95,
|
||||||
|
refused_max,
|
||||||
|
(refused_mean as u128) / BUCKET_MS
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
|
||||||
|
reachable_mean,
|
||||||
|
reachable_min,
|
||||||
|
reachable_p95,
|
||||||
|
reachable_max,
|
||||||
|
(reachable_mean as u128) / BUCKET_MS
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn backend_connect_refusal_completes_within_bounded_mask_budget() {
|
||||||
|
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let unused_port = temp_listener.local_addr().unwrap().port();
|
||||||
|
drop(temp_listener);
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = unused_port;
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.41:51001".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
let probe = b"GET /bounded HTTP/1.1\r\nHost: x\r\n\r\n";
|
||||||
|
|
||||||
|
let (_client_reader_side, client_reader) = duplex(256);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let elapsed = started.elapsed();
|
||||||
|
assert!(
|
||||||
|
elapsed >= Duration::from_millis(45),
|
||||||
|
"connect refusal path must respect minimum masking budget"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
elapsed < Duration::from_millis(500),
|
||||||
|
"connect refusal path must stay bounded and avoid unbounded stall"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = b"GET /oneshot HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let response = response.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
assert_eq!(received, probe);
|
||||||
|
stream.write_all(&response).await.unwrap();
|
||||||
|
sleep(Duration::from_millis(300)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.42:51002".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (_client_reader_side, client_reader) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(512);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let started = Instant::now();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let elapsed = started.elapsed();
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; response.len()];
|
||||||
|
client_visible_reader.read_exact(&mut observed).await.unwrap();
|
||||||
|
assert_eq!(observed, response);
|
||||||
|
assert!(
|
||||||
|
elapsed < Duration::from_millis(190),
|
||||||
|
"idle backend silence after first response must be cut by relay idle timeout"
|
||||||
|
);
|
||||||
|
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let initial = b"GET /drip HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let initial = initial.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut observed = vec![0u8; initial.len()];
|
||||||
|
stream.read_exact(&mut observed).await.unwrap();
|
||||||
|
assert_eq!(observed, initial);
|
||||||
|
|
||||||
|
let mut extra = [0u8; 1];
|
||||||
|
let read_res = timeout(Duration::from_millis(220), stream.read_exact(&mut extra)).await;
|
||||||
|
assert!(
|
||||||
|
read_res.is_err() || read_res.unwrap().is_err(),
|
||||||
|
"drip-fed post-probe byte arriving after idle timeout should not be forwarded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.43:51003".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (mut client_writer_side, client_reader) = duplex(256);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(async move {
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&initial,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
sleep(Duration::from_millis(160)).await;
|
||||||
|
let _ = client_writer_side.write_all(b"X").await;
|
||||||
|
drop(client_writer_side);
|
||||||
|
|
||||||
|
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap();
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
use std::collections::hash_map::DefaultHasher;
|
use std::collections::hash_map::RandomState;
|
||||||
|
use std::hash::BuildHasher;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
|
@ -41,6 +42,7 @@ const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
||||||
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
||||||
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
||||||
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
||||||
|
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
|
||||||
|
|
||||||
struct RelayForensicsState {
|
struct RelayForensicsState {
|
||||||
trace_id: u64,
|
trace_id: u64,
|
||||||
|
|
@ -80,7 +82,8 @@ impl MeD2cFlushPolicy {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn hash_value<T: Hash>(value: &T) -> u64 {
|
fn hash_value<T: Hash>(value: &T) -> u64 {
|
||||||
let mut hasher = DefaultHasher::new();
|
let state = DESYNC_HASHER.get_or_init(RandomState::new);
|
||||||
|
let mut hasher = state.build_hasher();
|
||||||
value.hash(&mut hasher);
|
value.hash(&mut hasher);
|
||||||
hasher.finish()
|
hasher.finish()
|
||||||
}
|
}
|
||||||
|
|
@ -106,12 +109,17 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
||||||
|
|
||||||
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||||
let mut stale_keys = Vec::new();
|
let mut stale_keys = Vec::new();
|
||||||
let mut eviction_candidate = None;
|
let mut oldest_candidate: Option<(u64, Instant)> = None;
|
||||||
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
|
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
|
||||||
if eviction_candidate.is_none() {
|
let key = *entry.key();
|
||||||
eviction_candidate = Some(*entry.key());
|
let seen_at = *entry.value();
|
||||||
|
|
||||||
|
match oldest_candidate {
|
||||||
|
Some((_, oldest_seen)) if seen_at >= oldest_seen => {}
|
||||||
|
_ => oldest_candidate = Some((key, seen_at)),
|
||||||
}
|
}
|
||||||
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
|
|
||||||
|
if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW {
|
||||||
stale_keys.push(*entry.key());
|
stale_keys.push(*entry.key());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -119,7 +127,7 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
||||||
dedup.remove(&stale_key);
|
dedup.remove(&stale_key);
|
||||||
}
|
}
|
||||||
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||||
let Some(evict_key) = eviction_candidate else {
|
let Some((evict_key, _)) = oldest_candidate else {
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
dedup.remove(&evict_key);
|
dedup.remove(&evict_key);
|
||||||
|
|
@ -306,7 +314,7 @@ where
|
||||||
};
|
};
|
||||||
|
|
||||||
stats.increment_user_connects(&user);
|
stats.increment_user_connects(&user);
|
||||||
stats.increment_current_connections_me();
|
let _me_connection_lease = stats.acquire_me_connection_lease();
|
||||||
|
|
||||||
if let Some(cutover) = affected_cutover_state(
|
if let Some(cutover) = affected_cutover_state(
|
||||||
&route_rx,
|
&route_rx,
|
||||||
|
|
@ -324,7 +332,6 @@ where
|
||||||
tokio::time::sleep(delay).await;
|
tokio::time::sleep(delay).await;
|
||||||
let _ = me_pool.send_close(conn_id).await;
|
let _ = me_pool.send_close(conn_id).await;
|
||||||
me_pool.registry().unregister(conn_id).await;
|
me_pool.registry().unregister(conn_id).await;
|
||||||
stats.decrement_current_connections_me();
|
|
||||||
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
|
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -672,7 +679,6 @@ where
|
||||||
"ME relay cleanup"
|
"ME relay cleanup"
|
||||||
);
|
);
|
||||||
me_pool.registry().unregister(conn_id).await;
|
me_pool.registry().unregister(conn_id).await;
|
||||||
stats.decrement_current_connections_me();
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,15 @@ use super::*;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use crate::crypto::AesCtr;
|
use crate::crypto::AesCtr;
|
||||||
use crate::crypto::SecureRandom;
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||||
|
use crate::network::probe::NetworkDecision;
|
||||||
|
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||||
|
use crate::transport::middle_proxy::MePool;
|
||||||
|
use rand::rngs::StdRng;
|
||||||
|
use rand::{Rng, SeedableRng};
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::AtomicU64;
|
use std::sync::atomic::AtomicU64;
|
||||||
|
|
@ -215,6 +222,190 @@ fn desync_dedup_full_cache_churn_stays_suppressed() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dedup_hash_is_stable_for_same_input_within_process() {
|
||||||
|
let sample = (
|
||||||
|
"scope_user",
|
||||||
|
hash_ip("198.51.100.7".parse().unwrap()),
|
||||||
|
ProtoTag::Secure,
|
||||||
|
);
|
||||||
|
let first = hash_value(&sample);
|
||||||
|
let second = hash_value(&sample);
|
||||||
|
assert_eq!(
|
||||||
|
first, second,
|
||||||
|
"dedup hash must be stable within a process for cache lookups"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() {
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
|
||||||
|
for octet in 1u16..=2048 {
|
||||||
|
let third = ((octet / 256) & 0xff) as u8;
|
||||||
|
let fourth = (octet & 0xff) as u8;
|
||||||
|
let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth));
|
||||||
|
let key = hash_value(&(
|
||||||
|
"scope_user",
|
||||||
|
hash_ip(ip),
|
||||||
|
ProtoTag::Secure,
|
||||||
|
DESYNC_ERROR_CLASS,
|
||||||
|
));
|
||||||
|
seen.insert(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
seen.len(),
|
||||||
|
2048,
|
||||||
|
"adversarial peer-IP burst should not collapse dedup keys via trivial collisions"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_dedup_hash_collision_rate_stays_negligible() {
|
||||||
|
let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4);
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
let samples = 8192usize;
|
||||||
|
|
||||||
|
for _ in 0..samples {
|
||||||
|
let user_seed: u64 = rng.random();
|
||||||
|
let peer_seed: u64 = rng.random();
|
||||||
|
let proto = if (peer_seed & 1) == 0 {
|
||||||
|
ProtoTag::Secure
|
||||||
|
} else {
|
||||||
|
ProtoTag::Intermediate
|
||||||
|
};
|
||||||
|
let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS));
|
||||||
|
seen.insert(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
let collisions = samples - seen.len();
|
||||||
|
assert!(
|
||||||
|
collisions <= 1,
|
||||||
|
"light fuzz collision count should remain negligible for 64-bit dedup keys"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn stress_desync_dedup_churn_keeps_cache_hard_bounded() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let total = DESYNC_DEDUP_MAX_ENTRIES + 8192;
|
||||||
|
|
||||||
|
for key in 0..total as u64 {
|
||||||
|
let emitted = should_emit_full_desync(key, false, now);
|
||||||
|
if key < DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
assert!(emitted, "keys below cap must be admitted initially");
|
||||||
|
} else {
|
||||||
|
assert!(
|
||||||
|
!emitted,
|
||||||
|
"new keys above cap must stay suppressed under sustained churn"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = DESYNC_DEDUP
|
||||||
|
.get()
|
||||||
|
.expect("dedup cache must be initialized by stress run")
|
||||||
|
.len();
|
||||||
|
assert!(
|
||||||
|
len <= DESYNC_DEDUP_MAX_ENTRIES,
|
||||||
|
"dedup cache must stay bounded under stress churn"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
|
let base_now = Instant::now();
|
||||||
|
|
||||||
|
// Fill with fresh entries so stale-pruning does not apply.
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
dedup.insert(key, base_now - TokioDuration::from_millis(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
let before_keys: std::collections::HashSet<u64> = dedup.iter().map(|e| *e.key()).collect();
|
||||||
|
|
||||||
|
let newcomer_key = u64::MAX;
|
||||||
|
let emitted = should_emit_full_desync(newcomer_key, false, base_now);
|
||||||
|
assert!(
|
||||||
|
!emitted,
|
||||||
|
"new entry under full fresh cache must stay suppressed"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
dedup.get(&newcomer_key).is_some(),
|
||||||
|
"new key must be inserted after bounded eviction"
|
||||||
|
);
|
||||||
|
|
||||||
|
let after_keys: std::collections::HashSet<u64> = dedup.iter().map(|e| *e.key()).collect();
|
||||||
|
let removed_count = before_keys.difference(&after_keys).count();
|
||||||
|
let added_count = after_keys.difference(&before_keys).count();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
removed_count, 1,
|
||||||
|
"full-cache insertion must evict exactly one prior key"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
added_count, 1,
|
||||||
|
"full-cache insertion must add exactly one newcomer key"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES,
|
||||||
|
"dedup cache must remain hard-bounded after full-cache churn"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let key = 0xC0DE_CAFE_u64;
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
should_emit_full_desync(key, false, start),
|
||||||
|
"first event for key must emit full forensic record"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Deterministic pseudo-random time deltas around dedup window edge.
|
||||||
|
let mut s: u64 = 0x1234_5678_9ABC_DEF0;
|
||||||
|
for _ in 0..2048 {
|
||||||
|
s ^= s << 7;
|
||||||
|
s ^= s >> 9;
|
||||||
|
s ^= s << 8;
|
||||||
|
|
||||||
|
let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1);
|
||||||
|
let now = start + TokioDuration::from_millis(delta_ms);
|
||||||
|
let emitted = should_emit_full_desync(key, false, now);
|
||||||
|
|
||||||
|
if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 {
|
||||||
|
assert!(
|
||||||
|
!emitted,
|
||||||
|
"events inside dedup window must remain suppressed"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Once window elapsed for this key, at least one sample should re-emit and refresh.
|
||||||
|
if emitted {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
panic!("expected at least one post-window sample to re-emit forensic record");
|
||||||
|
}
|
||||||
|
|
||||||
fn make_forensics_state() -> RelayForensicsState {
|
fn make_forensics_state() -> RelayForensicsState {
|
||||||
RelayForensicsState {
|
RelayForensicsState {
|
||||||
trace_id: 1,
|
trace_id: 1,
|
||||||
|
|
@ -229,18 +420,108 @@ fn make_forensics_state() -> RelayForensicsState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader<tokio::io::DuplexStream> {
|
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
|
||||||
|
where
|
||||||
|
R: tokio::io::AsyncRead + Unpin,
|
||||||
|
{
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter<tokio::io::DuplexStream> {
|
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||||
|
where
|
||||||
|
W: tokio::io::AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn make_me_pool_for_abort_test(stats: Arc<Stats>) -> Arc<MePool> {
|
||||||
|
let general = GeneralConfig::default();
|
||||||
|
|
||||||
|
MePool::new(
|
||||||
|
None,
|
||||||
|
vec![1u8; 32],
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
Vec::new(),
|
||||||
|
1,
|
||||||
|
None,
|
||||||
|
12,
|
||||||
|
1200,
|
||||||
|
HashMap::new(),
|
||||||
|
HashMap::new(),
|
||||||
|
None,
|
||||||
|
NetworkDecision::default(),
|
||||||
|
None,
|
||||||
|
Arc::new(SecureRandom::new()),
|
||||||
|
stats,
|
||||||
|
general.me_keepalive_enabled,
|
||||||
|
general.me_keepalive_interval_secs,
|
||||||
|
general.me_keepalive_jitter_secs,
|
||||||
|
general.me_keepalive_payload_random,
|
||||||
|
general.rpc_proxy_req_every,
|
||||||
|
general.me_warmup_stagger_enabled,
|
||||||
|
general.me_warmup_step_delay_ms,
|
||||||
|
general.me_warmup_step_jitter_ms,
|
||||||
|
general.me_reconnect_max_concurrent_per_dc,
|
||||||
|
general.me_reconnect_backoff_base_ms,
|
||||||
|
general.me_reconnect_backoff_cap_ms,
|
||||||
|
general.me_reconnect_fast_retry_count,
|
||||||
|
general.me_single_endpoint_shadow_writers,
|
||||||
|
general.me_single_endpoint_outage_mode_enabled,
|
||||||
|
general.me_single_endpoint_outage_disable_quarantine,
|
||||||
|
general.me_single_endpoint_outage_backoff_min_ms,
|
||||||
|
general.me_single_endpoint_outage_backoff_max_ms,
|
||||||
|
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||||
|
general.me_floor_mode,
|
||||||
|
general.me_adaptive_floor_idle_secs,
|
||||||
|
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||||
|
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||||
|
general.me_adaptive_floor_recover_grace_secs,
|
||||||
|
general.me_adaptive_floor_writers_per_core_total,
|
||||||
|
general.me_adaptive_floor_cpu_cores_override,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_global,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_global,
|
||||||
|
general.hardswap,
|
||||||
|
general.me_pool_drain_ttl_secs,
|
||||||
|
general.me_pool_drain_threshold,
|
||||||
|
general.effective_me_pool_force_close_secs(),
|
||||||
|
general.me_pool_min_fresh_ratio,
|
||||||
|
general.me_hardswap_warmup_delay_min_ms,
|
||||||
|
general.me_hardswap_warmup_delay_max_ms,
|
||||||
|
general.me_hardswap_warmup_extra_passes,
|
||||||
|
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||||
|
general.me_bind_stale_mode,
|
||||||
|
general.me_bind_stale_ttl_secs,
|
||||||
|
general.me_secret_atomic_snapshot,
|
||||||
|
general.me_deterministic_writer_sort,
|
||||||
|
MeWriterPickMode::default(),
|
||||||
|
general.me_writer_pick_sample_size,
|
||||||
|
MeSocksKdfPolicy::default(),
|
||||||
|
general.me_writer_cmd_channel_capacity,
|
||||||
|
general.me_route_channel_capacity,
|
||||||
|
general.me_route_backpressure_base_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_watermark_pct,
|
||||||
|
general.me_reader_route_data_wait_ms,
|
||||||
|
general.me_health_interval_ms_unhealthy,
|
||||||
|
general.me_health_interval_ms_healthy,
|
||||||
|
general.me_warn_rate_limit_ms,
|
||||||
|
MeRouteNoWriterMode::default(),
|
||||||
|
general.me_route_no_writer_wait_ms,
|
||||||
|
general.me_route_inline_recovery_attempts,
|
||||||
|
general.me_route_inline_recovery_wait_ms,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
|
|
@ -779,3 +1060,259 @@ async fn process_me_writer_response_data_updates_byte_accounting() {
|
||||||
"ME->C byte accounting must increase by emitted payload size"
|
"ME->C byte accounting must increase by emitted payload size"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn middle_relay_abort_midflight_releases_route_gauge() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
|
||||||
|
let config = Arc::new(ProxyConfig::default());
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
|
||||||
|
let route_snapshot = route_runtime.snapshot();
|
||||||
|
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let crypto_reader = make_crypto_reader(server_reader);
|
||||||
|
let crypto_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: "abort-middle-user".to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: "127.0.0.1:50001".parse().unwrap(),
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(handle_via_middle_proxy(
|
||||||
|
crypto_reader,
|
||||||
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
me_pool,
|
||||||
|
stats.clone(),
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
"127.0.0.1:443".parse().unwrap(),
|
||||||
|
rng,
|
||||||
|
route_runtime.subscribe(),
|
||||||
|
route_snapshot,
|
||||||
|
0xdecafbad,
|
||||||
|
));
|
||||||
|
|
||||||
|
let started = tokio::time::timeout(TokioDuration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_me() == 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
assert!(started.is_ok(), "middle relay must increment route gauge before abort");
|
||||||
|
|
||||||
|
relay_task.abort();
|
||||||
|
let joined = relay_task.await;
|
||||||
|
assert!(joined.is_err(), "aborted middle relay task must return join error");
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(20)).await;
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"route gauge must be released when middle relay task is aborted mid-flight"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_side);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn middle_relay_cutover_midflight_releases_route_gauge() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
|
||||||
|
let config = Arc::new(ProxyConfig::default());
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
|
||||||
|
let route_snapshot = route_runtime.snapshot();
|
||||||
|
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let crypto_reader = make_crypto_reader(server_reader);
|
||||||
|
let crypto_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: "cutover-middle-user".to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: "127.0.0.1:50003".parse().unwrap(),
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(handle_via_middle_proxy(
|
||||||
|
crypto_reader,
|
||||||
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
me_pool,
|
||||||
|
stats.clone(),
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
"127.0.0.1:443".parse().unwrap(),
|
||||||
|
rng,
|
||||||
|
route_runtime.subscribe(),
|
||||||
|
route_snapshot,
|
||||||
|
0xfeed_beef,
|
||||||
|
));
|
||||||
|
|
||||||
|
tokio::time::timeout(TokioDuration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_me() == 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("middle relay must increment route gauge before cutover");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
route_runtime.set_mode(RelayRouteMode::Direct).is_some(),
|
||||||
|
"cutover must advance route generation"
|
||||||
|
);
|
||||||
|
|
||||||
|
let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("middle relay must terminate after cutover")
|
||||||
|
.expect("middle relay task must not panic");
|
||||||
|
assert!(
|
||||||
|
relay_result.is_err(),
|
||||||
|
"cutover should terminate middle relay session"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(
|
||||||
|
relay_result,
|
||||||
|
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
|
||||||
|
),
|
||||||
|
"client-visible cutover error must stay generic and avoid route-internal metadata"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"route gauge must be released when middle relay exits on cutover"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_side);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() {
|
||||||
|
let session_count = 6usize;
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
|
||||||
|
let config = Arc::new(ProxyConfig::default());
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
|
||||||
|
let route_snapshot = route_runtime.snapshot();
|
||||||
|
|
||||||
|
let mut relay_tasks = Vec::with_capacity(session_count);
|
||||||
|
let mut client_sides = Vec::with_capacity(session_count);
|
||||||
|
|
||||||
|
for idx in 0..session_count {
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
client_sides.push(client_side);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let crypto_reader = make_crypto_reader(server_reader);
|
||||||
|
let crypto_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: format!("cutover-storm-middle-user-{idx}"),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: SocketAddr::new(
|
||||||
|
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
|
||||||
|
52000 + idx as u16,
|
||||||
|
),
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
relay_tasks.push(tokio::spawn(handle_via_middle_proxy(
|
||||||
|
crypto_reader,
|
||||||
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
me_pool.clone(),
|
||||||
|
stats.clone(),
|
||||||
|
config.clone(),
|
||||||
|
buffer_pool.clone(),
|
||||||
|
"127.0.0.1:443".parse().unwrap(),
|
||||||
|
rng.clone(),
|
||||||
|
route_runtime.subscribe(),
|
||||||
|
route_snapshot,
|
||||||
|
0xB000_0000 + idx as u64,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::timeout(TokioDuration::from_secs(4), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_me() == session_count as u64 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("all middle sessions must become active before cutover storm");
|
||||||
|
|
||||||
|
let route_runtime_flipper = route_runtime.clone();
|
||||||
|
let flipper = tokio::spawn(async move {
|
||||||
|
for step in 0..64u32 {
|
||||||
|
let mode = if (step & 1) == 0 {
|
||||||
|
RelayRouteMode::Direct
|
||||||
|
} else {
|
||||||
|
RelayRouteMode::Middle
|
||||||
|
};
|
||||||
|
let _ = route_runtime_flipper.set_mode(mode);
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(15)).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
for relay_task in relay_tasks {
|
||||||
|
let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task)
|
||||||
|
.await
|
||||||
|
.expect("middle relay task must finish under cutover storm")
|
||||||
|
.expect("middle relay task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(
|
||||||
|
relay_result,
|
||||||
|
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
|
||||||
|
),
|
||||||
|
"storm-cutover termination must remain generic for all middle sessions"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
flipper.abort();
|
||||||
|
let _ = flipper.await;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"middle route gauge must return to zero after cutover storm"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_sides);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
|
|
||||||
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover";
|
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated";
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
|
|
@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RelayRouteMode {
|
impl RelayRouteMode {
|
||||||
pub(crate) fn as_u8(self) -> u8 {
|
|
||||||
self as u8
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn from_u8(value: u8) -> Self {
|
|
||||||
match value {
|
|
||||||
1 => Self::Middle,
|
|
||||||
_ => Self::Direct,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn as_str(self) -> &'static str {
|
pub(crate) fn as_str(self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::Direct => "direct",
|
Self::Direct => "direct",
|
||||||
|
|
@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState {
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct RouteRuntimeController {
|
pub(crate) struct RouteRuntimeController {
|
||||||
mode: Arc<AtomicU8>,
|
|
||||||
generation: Arc<AtomicU64>,
|
|
||||||
direct_since_epoch_secs: Arc<AtomicU64>,
|
direct_since_epoch_secs: Arc<AtomicU64>,
|
||||||
tx: watch::Sender<RouteCutoverState>,
|
tx: watch::Sender<RouteCutoverState>,
|
||||||
}
|
}
|
||||||
|
|
@ -60,18 +47,13 @@ impl RouteRuntimeController {
|
||||||
0
|
0
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
mode: Arc::new(AtomicU8::new(initial_mode.as_u8())),
|
|
||||||
generation: Arc::new(AtomicU64::new(0)),
|
|
||||||
direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)),
|
direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)),
|
||||||
tx,
|
tx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn snapshot(&self) -> RouteCutoverState {
|
pub(crate) fn snapshot(&self) -> RouteCutoverState {
|
||||||
RouteCutoverState {
|
*self.tx.borrow()
|
||||||
mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)),
|
|
||||||
generation: self.generation.load(Ordering::Relaxed),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn subscribe(&self) -> watch::Receiver<RouteCutoverState> {
|
pub(crate) fn subscribe(&self) -> watch::Receiver<RouteCutoverState> {
|
||||||
|
|
@ -84,20 +66,29 @@ impl RouteRuntimeController {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
|
pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
|
||||||
let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed);
|
let mut next = None;
|
||||||
if previous == mode.as_u8() {
|
let changed = self.tx.send_if_modified(|state| {
|
||||||
|
if state.mode == mode {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
state.mode = mode;
|
||||||
|
state.generation = state.generation.saturating_add(1);
|
||||||
|
next = Some(*state);
|
||||||
|
true
|
||||||
|
});
|
||||||
|
|
||||||
|
if !changed {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
if matches!(mode, RelayRouteMode::Direct) {
|
if matches!(mode, RelayRouteMode::Direct) {
|
||||||
self.direct_since_epoch_secs
|
self.direct_since_epoch_secs
|
||||||
.store(now_epoch_secs(), Ordering::Relaxed);
|
.store(now_epoch_secs(), Ordering::Relaxed);
|
||||||
} else {
|
} else {
|
||||||
self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
|
self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
|
|
||||||
let next = RouteCutoverState { mode, generation };
|
next
|
||||||
self.tx.send_replace(next);
|
|
||||||
Some(next)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,10 +101,10 @@ fn now_epoch_secs() -> u64 {
|
||||||
|
|
||||||
pub(crate) fn is_session_affected_by_cutover(
|
pub(crate) fn is_session_affected_by_cutover(
|
||||||
current: RouteCutoverState,
|
current: RouteCutoverState,
|
||||||
_session_mode: RelayRouteMode,
|
session_mode: RelayRouteMode,
|
||||||
session_generation: u64,
|
session_generation: u64,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
current.generation > session_generation
|
current.generation > session_generation && current.mode != session_mode
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn affected_cutover_state(
|
pub(crate) fn affected_cutover_state(
|
||||||
|
|
@ -140,3 +131,7 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio
|
||||||
let ms = 1000 + (value % 1000);
|
let ms = 1000 + (value % 1000);
|
||||||
Duration::from_millis(ms)
|
Duration::from_millis(ms)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "route_mode_security_tests.rs"]
|
||||||
|
mod security_tests;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,340 @@
|
||||||
|
use super::*;
|
||||||
|
use rand::{Rng, SeedableRng};
|
||||||
|
use rand::rngs::StdRng;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cutover_stagger_delay_is_deterministic_for_same_inputs() {
|
||||||
|
let d1 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42);
|
||||||
|
let d2 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42);
|
||||||
|
assert_eq!(
|
||||||
|
d1, d2,
|
||||||
|
"stagger delay must be deterministic for identical session/generation inputs"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cutover_stagger_delay_stays_within_budget_bounds() {
|
||||||
|
// Black-hat model: censors trigger many cutovers and correlate disconnect timing.
|
||||||
|
// Keep delay inside a narrow coarse window to avoid long-tail spikes.
|
||||||
|
for generation in [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX] {
|
||||||
|
for session_id in [
|
||||||
|
0u64,
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
0xdead_beef,
|
||||||
|
0xfeed_face_cafe_babe,
|
||||||
|
u64::MAX,
|
||||||
|
] {
|
||||||
|
let delay = cutover_stagger_delay(session_id, generation);
|
||||||
|
assert!(
|
||||||
|
(1000..=1999).contains(&delay.as_millis()),
|
||||||
|
"stagger delay must remain in fixed 1000..=1999ms budget"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cutover_stagger_delay_changes_with_generation_for_same_session() {
|
||||||
|
let session_id = 0x0123_4567_89ab_cdef;
|
||||||
|
let first = cutover_stagger_delay(session_id, 100);
|
||||||
|
let second = cutover_stagger_delay(session_id, 101);
|
||||||
|
assert_ne!(
|
||||||
|
first, second,
|
||||||
|
"adjacent cutover generations should decorrelate disconnect delays"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn route_runtime_set_mode_is_idempotent_for_same_mode() {
|
||||||
|
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||||
|
let first = runtime.snapshot();
|
||||||
|
let changed = runtime.set_mode(RelayRouteMode::Direct);
|
||||||
|
let second = runtime.snapshot();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
changed.is_none(),
|
||||||
|
"setting already-active mode must not produce a cutover event"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
first.generation, second.generation,
|
||||||
|
"idempotent mode set must not bump generation"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn affected_cutover_state_triggers_only_for_newer_generation() {
|
||||||
|
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||||
|
let rx = runtime.subscribe();
|
||||||
|
let initial = runtime.snapshot();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation).is_none(),
|
||||||
|
"current generation must not be considered a cutover for existing session"
|
||||||
|
);
|
||||||
|
|
||||||
|
let next = runtime
|
||||||
|
.set_mode(RelayRouteMode::Middle)
|
||||||
|
.expect("mode change must produce cutover state");
|
||||||
|
let seen = affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation)
|
||||||
|
.expect("newer generation must be observed as cutover");
|
||||||
|
|
||||||
|
assert_eq!(seen.generation, next.generation);
|
||||||
|
assert_eq!(seen.mode, RelayRouteMode::Middle);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn integration_watch_and_snapshot_follow_same_transition_sequence() {
|
||||||
|
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||||
|
let rx = runtime.subscribe();
|
||||||
|
|
||||||
|
let sequence = [
|
||||||
|
RelayRouteMode::Middle,
|
||||||
|
RelayRouteMode::Middle,
|
||||||
|
RelayRouteMode::Direct,
|
||||||
|
RelayRouteMode::Direct,
|
||||||
|
RelayRouteMode::Middle,
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut expected_generation = 0u64;
|
||||||
|
let mut expected_mode = RelayRouteMode::Direct;
|
||||||
|
|
||||||
|
for target in sequence {
|
||||||
|
let changed = runtime.set_mode(target);
|
||||||
|
if target == expected_mode {
|
||||||
|
assert!(changed.is_none(), "idempotent transition must return none");
|
||||||
|
} else {
|
||||||
|
expected_mode = target;
|
||||||
|
expected_generation = expected_generation.saturating_add(1);
|
||||||
|
let emitted = changed.expect("real transition must emit cutover state");
|
||||||
|
assert_eq!(emitted.mode, expected_mode);
|
||||||
|
assert_eq!(emitted.generation, expected_generation);
|
||||||
|
}
|
||||||
|
|
||||||
|
let snap = runtime.snapshot();
|
||||||
|
let watched = *rx.borrow();
|
||||||
|
assert_eq!(snap, watched, "snapshot and watch state must stay aligned");
|
||||||
|
assert_eq!(snap.mode, expected_mode);
|
||||||
|
assert_eq!(snap.generation, expected_generation);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn session_is_not_affected_when_mode_matches_even_if_generation_advanced() {
|
||||||
|
let session_mode = RelayRouteMode::Direct;
|
||||||
|
let current = RouteCutoverState {
|
||||||
|
mode: RelayRouteMode::Direct,
|
||||||
|
generation: 2,
|
||||||
|
};
|
||||||
|
let session_generation = 0;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!is_session_affected_by_cutover(current, session_mode, session_generation),
|
||||||
|
"session on matching final route mode should not be force-cut over on intermediate generation bumps"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cutover_predicate_rejects_equal_generation_even_if_mode_differs() {
|
||||||
|
let current = RouteCutoverState {
|
||||||
|
mode: RelayRouteMode::Middle,
|
||||||
|
generation: 77,
|
||||||
|
};
|
||||||
|
assert!(
|
||||||
|
!is_session_affected_by_cutover(current, RelayRouteMode::Direct, 77),
|
||||||
|
"equal generation must never trigger cutover regardless of mode mismatch"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn adversarial_route_oscillation_only_cuts_over_sessions_with_different_final_mode() {
|
||||||
|
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||||
|
let rx = runtime.subscribe();
|
||||||
|
let session_generation = runtime.snapshot().generation;
|
||||||
|
|
||||||
|
runtime
|
||||||
|
.set_mode(RelayRouteMode::Middle)
|
||||||
|
.expect("direct->middle must transition");
|
||||||
|
runtime
|
||||||
|
.set_mode(RelayRouteMode::Direct)
|
||||||
|
.expect("middle->direct must transition");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
affected_cutover_state(&rx, RelayRouteMode::Direct, session_generation).is_none(),
|
||||||
|
"direct session should survive when final mode returns to direct"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
affected_cutover_state(&rx, RelayRouteMode::Middle, session_generation).is_some(),
|
||||||
|
"middle session should be cut over when final mode is direct"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_cutover_predicate_matches_reference_oracle() {
|
||||||
|
let mut rng = StdRng::seed_from_u64(0xC0DEC0DE5EED);
|
||||||
|
for _ in 0..20_000 {
|
||||||
|
let current = RouteCutoverState {
|
||||||
|
mode: if rng.random::<bool>() {
|
||||||
|
RelayRouteMode::Direct
|
||||||
|
} else {
|
||||||
|
RelayRouteMode::Middle
|
||||||
|
},
|
||||||
|
generation: rng.random_range(0u64..1_000_000),
|
||||||
|
};
|
||||||
|
let session_mode = if rng.random::<bool>() {
|
||||||
|
RelayRouteMode::Direct
|
||||||
|
} else {
|
||||||
|
RelayRouteMode::Middle
|
||||||
|
};
|
||||||
|
let session_generation = rng.random_range(0u64..1_000_000);
|
||||||
|
|
||||||
|
let expected = current.generation > session_generation && current.mode != session_mode;
|
||||||
|
let actual = is_session_affected_by_cutover(current, session_mode, session_generation);
|
||||||
|
assert_eq!(
|
||||||
|
actual, expected,
|
||||||
|
"cutover predicate must match mode-aware generation oracle"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn light_fuzz_set_mode_generation_tracks_only_real_transitions() {
|
||||||
|
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
|
||||||
|
let mut rng = StdRng::seed_from_u64(0x0DDC0FFE);
|
||||||
|
|
||||||
|
let mut expected_mode = RelayRouteMode::Direct;
|
||||||
|
let mut expected_generation = 0u64;
|
||||||
|
|
||||||
|
for _ in 0..10_000 {
|
||||||
|
let candidate = if rng.random::<bool>() {
|
||||||
|
RelayRouteMode::Direct
|
||||||
|
} else {
|
||||||
|
RelayRouteMode::Middle
|
||||||
|
};
|
||||||
|
let changed = runtime.set_mode(candidate);
|
||||||
|
|
||||||
|
if candidate == expected_mode {
|
||||||
|
assert!(changed.is_none(), "idempotent set_mode must not emit cutover state");
|
||||||
|
} else {
|
||||||
|
expected_mode = candidate;
|
||||||
|
expected_generation = expected_generation.saturating_add(1);
|
||||||
|
let next = changed.expect("mode transition must emit cutover state");
|
||||||
|
assert_eq!(next.mode, expected_mode);
|
||||||
|
assert_eq!(next.generation, expected_generation);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let final_state = runtime.snapshot();
|
||||||
|
assert_eq!(final_state.mode, expected_mode);
|
||||||
|
assert_eq!(final_state.generation, expected_generation);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn stress_snapshot_and_watch_state_remain_consistent_under_concurrent_switch_storm() {
    let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));

    std::thread::scope(|scope| {
        // Four writers hammer set_mode with alternating modes.
        let handles: Vec<_> = (0..4usize)
            .map(|worker| {
                let runtime = Arc::clone(&runtime);
                scope.spawn(move || {
                    for step in 0..20_000usize {
                        let mode = match (worker + step) % 2 {
                            0 => RelayRouteMode::Direct,
                            _ => RelayRouteMode::Middle,
                        };
                        let _ = runtime.set_mode(mode);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("route mode writer thread must not panic");
        }

        // After the storm, snapshot and the watch channel must agree, repeatedly.
        let rx = runtime.subscribe();
        for _ in 0..128 {
            assert_eq!(
                runtime.snapshot(),
                *rx.borrow(),
                "snapshot and watch state must converge after concurrent set_mode churn"
            );
            std::thread::yield_now();
        }
    });
}
|
||||||
|
|
||||||
|
#[test]
fn stress_concurrent_transition_count_matches_final_generation() {
    let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let accepted = Arc::new(AtomicU64::new(0));

    std::thread::scope(|scope| {
        let handles: Vec<_> = (0..6usize)
            .map(|worker| {
                let runtime = Arc::clone(&runtime);
                let accepted = Arc::clone(&accepted);
                scope.spawn(move || {
                    // Per-thread deterministic xorshift stream seeded from the worker index.
                    let mut seed = (worker as u64 + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15);
                    for _ in 0..25_000usize {
                        seed ^= seed << 7;
                        seed ^= seed >> 9;
                        seed ^= seed << 8;
                        let mode = match seed & 1 {
                            0 => RelayRouteMode::Direct,
                            _ => RelayRouteMode::Middle,
                        };
                        // Count only transitions the controller actually accepted.
                        if runtime.set_mode(mode).is_some() {
                            accepted.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("route mode transition worker must not panic");
        }
    });

    let final_state = runtime.snapshot();
    assert_eq!(
        final_state.generation,
        accepted.load(Ordering::Relaxed),
        "final generation must equal number of accepted mode transitions"
    );
    assert_eq!(
        final_state,
        *runtime.subscribe().borrow(),
        "watch and snapshot state must match after concurrent transition accounting"
    );
}
|
||||||
|
|
||||||
|
#[test]
fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() {
    // Deterministic xorshift fuzzing keeps this test stable across runs.
    let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
    let mut next = |s: &mut u64| {
        *s ^= *s << 7;
        *s ^= *s >> 9;
        *s ^= *s << 8;
        *s
    };

    for _ in 0..20_000 {
        let session_id = next(&mut state);
        let generation = next(&mut state);

        let delay = cutover_stagger_delay(session_id, generation);
        assert!(
            (1000..=1999).contains(&delay.as_millis()),
            "fuzzed inputs must always map into fixed stagger window"
        );
    }
}
|
||||||
|
|
@ -0,0 +1,265 @@
|
||||||
|
use super::*;
|
||||||
|
use std::panic::{self, AssertUnwindSafe};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::sync::Barrier;
|
||||||
|
|
||||||
|
#[test]
fn direct_connection_lease_balances_on_drop() {
    let stats = Arc::new(Stats::new());
    assert_eq!(stats.get_current_connections_direct(), 0);

    // Acquiring bumps the gauge; dropping the lease must undo it exactly once.
    let lease = stats.acquire_direct_connection_lease();
    assert_eq!(stats.get_current_connections_direct(), 1);
    drop(lease);

    assert_eq!(stats.get_current_connections_direct(), 0);
}
|
||||||
|
|
||||||
|
#[test]
fn middle_connection_lease_balances_on_drop() {
    let stats = Arc::new(Stats::new());
    assert_eq!(stats.get_current_connections_me(), 0);

    // Acquiring bumps the middle-route gauge; dropping the lease must undo it.
    let lease = stats.acquire_me_connection_lease();
    assert_eq!(stats.get_current_connections_me(), 1);
    drop(lease);

    assert_eq!(stats.get_current_connections_me(), 0);
}
|
||||||
|
|
||||||
|
#[test]
fn connection_lease_disarm_prevents_double_release() {
    let stats = Arc::new(Stats::new());

    let mut lease = stats.acquire_direct_connection_lease();
    assert_eq!(stats.get_current_connections_direct(), 1);

    // Simulate an external release, then disarm the lease so its Drop
    // does not decrement the gauge a second time.
    stats.decrement_current_connections_direct();
    assert_eq!(stats.get_current_connections_direct(), 0);

    lease.disarm();
    drop(lease);

    // The disarmed drop must leave the gauge untouched (no double release).
    assert_eq!(stats.get_current_connections_direct(), 0);
}
|
||||||
|
|
||||||
|
#[test]
fn direct_connection_lease_balances_on_panic_unwind() {
    let stats = Arc::new(Stats::new());
    let stats_inner = Arc::clone(&stats);

    // The lease lives inside a panicking closure; unwinding must run its Drop.
    let outcome = panic::catch_unwind(AssertUnwindSafe(move || {
        let _lease = stats_inner.acquire_direct_connection_lease();
        panic!("intentional panic to verify lease drop path");
    }));

    assert!(outcome.is_err(), "panic must propagate from test closure");
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "panic unwind must release direct route gauge"
    );
}
|
||||||
|
|
||||||
|
#[test]
fn middle_connection_lease_balances_on_panic_unwind() {
    let stats = Arc::new(Stats::new());
    let stats_inner = Arc::clone(&stats);

    // The lease lives inside a panicking closure; unwinding must run its Drop.
    let outcome = panic::catch_unwind(AssertUnwindSafe(move || {
        let _lease = stats_inner.acquire_me_connection_lease();
        panic!("intentional panic to verify middle lease drop path");
    }));

    assert!(outcome.is_err(), "panic must propagate from test closure");
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "panic unwind must release middle route gauge"
    );
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn concurrent_mixed_route_lease_churn_balances_to_zero() {
|
||||||
|
const TASKS: usize = 48;
|
||||||
|
const ITERATIONS_PER_TASK: usize = 256;
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let barrier = Arc::new(Barrier::new(TASKS));
|
||||||
|
let mut workers = Vec::with_capacity(TASKS);
|
||||||
|
|
||||||
|
for task_idx in 0..TASKS {
|
||||||
|
let stats_for_task = stats.clone();
|
||||||
|
let barrier_for_task = barrier.clone();
|
||||||
|
workers.push(tokio::spawn(async move {
|
||||||
|
barrier_for_task.wait().await;
|
||||||
|
for iter in 0..ITERATIONS_PER_TASK {
|
||||||
|
if (task_idx + iter) % 2 == 0 {
|
||||||
|
let _lease = stats_for_task.acquire_direct_connection_lease();
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
} else {
|
||||||
|
let _lease = stats_for_task.acquire_me_connection_lease();
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for worker in workers {
|
||||||
|
worker
|
||||||
|
.await
|
||||||
|
.expect("lease churn worker must not panic");
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_direct(),
|
||||||
|
0,
|
||||||
|
"direct route gauge must return to zero after concurrent lease churn"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"middle route gauge must return to zero after concurrent lease churn"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() {
|
||||||
|
const TASKS: usize = 64;
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let mut workers = Vec::with_capacity(TASKS);
|
||||||
|
|
||||||
|
for task_idx in 0..TASKS {
|
||||||
|
let stats_for_task = stats.clone();
|
||||||
|
workers.push(tokio::spawn(async move {
|
||||||
|
if task_idx % 2 == 0 {
|
||||||
|
let _lease = stats_for_task.acquire_direct_connection_lease();
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
} else {
|
||||||
|
let _lease = stats_for_task.acquire_me_connection_lease();
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::timeout(Duration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
let total = stats.get_current_connections_direct() + stats.get_current_connections_me();
|
||||||
|
if total == TASKS as u64 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("all storm tasks must acquire route leases before abort");
|
||||||
|
|
||||||
|
for worker in &workers {
|
||||||
|
worker.abort();
|
||||||
|
}
|
||||||
|
for worker in workers {
|
||||||
|
let joined = worker.await;
|
||||||
|
assert!(joined.is_err(), "aborted worker must return join error");
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::timeout(Duration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("all route gauges must drain to zero after abort storm");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn saturating_route_decrements_do_not_underflow_under_race() {
    const THREADS: usize = 16;
    const DECREMENTS_PER_THREAD: usize = 4096;

    let stats = Arc::new(Stats::new());

    // All gauges start at zero; concurrent decrements must saturate, not wrap.
    let workers: Vec<_> = (0..THREADS)
        .map(|_| {
            let stats = Arc::clone(&stats);
            std::thread::spawn(move || {
                for _ in 0..DECREMENTS_PER_THREAD {
                    stats.decrement_current_connections_direct();
                    stats.decrement_current_connections_me();
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().expect("decrement race worker must not panic");
    }

    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "direct route decrement races must never underflow"
    );
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "middle route decrement races must never underflow"
    );
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn direct_connection_lease_balances_on_task_abort() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let stats_for_task = stats.clone();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let _lease = stats_for_task.acquire_direct_connection_lease();
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(stats.get_current_connections_direct(), 1);
|
||||||
|
|
||||||
|
task.abort();
|
||||||
|
let joined = task.await;
|
||||||
|
assert!(joined.is_err(), "aborted task must return a join error");
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_direct(),
|
||||||
|
0,
|
||||||
|
"aborted task must release direct route gauge"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn middle_connection_lease_balances_on_task_abort() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let stats_for_task = stats.clone();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let _lease = stats_for_task.acquire_me_connection_lease();
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(stats.get_current_connections_me(), 1);
|
||||||
|
|
||||||
|
task.abort();
|
||||||
|
let joined = task.await;
|
||||||
|
assert!(joined.is_err(), "aborted task must return a join error");
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"aborted task must release middle route gauge"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -6,6 +6,7 @@ pub mod beobachten;
|
||||||
pub mod telemetry;
|
pub mod telemetry;
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
|
@ -19,6 +20,46 @@ use tracing::debug;
|
||||||
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
|
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
|
||||||
use self::telemetry::TelemetryPolicy;
|
use self::telemetry::TelemetryPolicy;
|
||||||
|
|
||||||
|
/// Selects which per-route connection gauge a `RouteConnectionLease`
/// releases when it is dropped.
#[derive(Clone, Copy)]
enum RouteConnectionGauge {
    // Maps to `current_connections_direct` (released via
    // `decrement_current_connections_direct`).
    Direct,
    // Maps to `current_connections_me` (released via
    // `decrement_current_connections_me`).
    Middle,
}
|
||||||
|
|
||||||
|
/// Guard that holds a single increment of one route connection gauge and
/// releases it on drop (unless disarmed). Acquired via
/// `Stats::acquire_direct_connection_lease` / `acquire_me_connection_lease`.
#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"]
pub struct RouteConnectionLease {
    // Owner of the gauges; kept alive for the lease's whole lifetime.
    stats: Arc<Stats>,
    // Which route gauge this lease decrements on release.
    gauge: RouteConnectionGauge,
    // When false, Drop performs no decrement (see `disarm`).
    active: bool,
}
|
||||||
|
|
||||||
|
impl RouteConnectionLease {
    /// Builds a live lease. The caller is expected to have already
    /// incremented the matching gauge (the `acquire_*` helpers do this).
    fn new(stats: Arc<Stats>, gauge: RouteConnectionGauge) -> Self {
        Self {
            stats,
            gauge,
            active: true,
        }
    }

    /// Test-only escape hatch: marks the lease inactive so its Drop performs
    /// no decrement. Used to verify double-release protection in tests.
    #[cfg(test)]
    fn disarm(&mut self) {
        self.active = false;
    }
}
|
||||||
|
|
||||||
|
impl Drop for RouteConnectionLease {
    fn drop(&mut self) {
        // A disarmed lease releases nothing (prevents a double decrement).
        if !self.active {
            return;
        }
        // Release exactly the gauge this lease was created for.
        match self.gauge {
            RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(),
            RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(),
        }
    }
}
|
||||||
|
|
||||||
// ============= Stats =============
|
// ============= Stats =============
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
|
@ -285,6 +326,16 @@ impl Stats {
|
||||||
pub fn decrement_current_connections_me(&self) {
|
pub fn decrement_current_connections_me(&self) {
|
||||||
Self::decrement_atomic_saturating(&self.current_connections_me);
|
Self::decrement_atomic_saturating(&self.current_connections_me);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Increments the direct-route connection gauge and returns a lease that
/// decrements it again when dropped (unless the lease is disarmed).
pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
    self.increment_current_connections_direct();
    RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct)
}
|
||||||
|
|
||||||
|
/// Increments the middle-route (ME) connection gauge and returns a lease
/// that decrements it again when dropped (unless the lease is disarmed).
pub fn acquire_me_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
    self.increment_current_connections_me();
    RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle)
}
|
||||||
pub fn increment_handshake_timeouts(&self) {
|
pub fn increment_handshake_timeouts(&self) {
|
||||||
if self.telemetry_core_enabled() {
|
if self.telemetry_core_enabled() {
|
||||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
@ -1772,3 +1823,7 @@ mod tests {
|
||||||
assert_eq!(checker.stats().total_entries, 500);
|
assert_eq!(checker.stats().total_entries, 500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "connection_lease_security_tests.rs"]
|
||||||
|
mod connection_lease_security_tests;
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,7 @@ pub fn build_emulated_server_hello(
|
||||||
cached: &CachedTlsData,
|
cached: &CachedTlsData,
|
||||||
use_full_cert_payload: bool,
|
use_full_cert_payload: bool,
|
||||||
rng: &SecureRandom,
|
rng: &SecureRandom,
|
||||||
alpn: Option<Vec<u8>>,
|
_alpn: Option<Vec<u8>>,
|
||||||
new_session_tickets: u8,
|
new_session_tickets: u8,
|
||||||
) -> Vec<u8> {
|
) -> Vec<u8> {
|
||||||
// --- ServerHello ---
|
// --- ServerHello ---
|
||||||
|
|
@ -117,15 +117,6 @@ pub fn build_emulated_server_hello(
|
||||||
extensions.extend_from_slice(&0x002bu16.to_be_bytes());
|
extensions.extend_from_slice(&0x002bu16.to_be_bytes());
|
||||||
extensions.extend_from_slice(&(2u16).to_be_bytes());
|
extensions.extend_from_slice(&(2u16).to_be_bytes());
|
||||||
extensions.extend_from_slice(&0x0304u16.to_be_bytes());
|
extensions.extend_from_slice(&0x0304u16.to_be_bytes());
|
||||||
if let Some(alpn_proto) = &alpn {
|
|
||||||
extensions.extend_from_slice(&0x0010u16.to_be_bytes());
|
|
||||||
let list_len: u16 = 1 + alpn_proto.len() as u16;
|
|
||||||
let ext_len: u16 = 2 + list_len;
|
|
||||||
extensions.extend_from_slice(&ext_len.to_be_bytes());
|
|
||||||
extensions.extend_from_slice(&list_len.to_be_bytes());
|
|
||||||
extensions.push(alpn_proto.len() as u8);
|
|
||||||
extensions.extend_from_slice(alpn_proto);
|
|
||||||
}
|
|
||||||
let extensions_len = extensions.len() as u16;
|
let extensions_len = extensions.len() as u16;
|
||||||
|
|
||||||
let body_len = 2 + // version
|
let body_len = 2 + // version
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,9 @@ const HEALTH_RECONNECT_BUDGET_PER_CORE: usize = 2;
|
||||||
const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1;
|
const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1;
|
||||||
const HEALTH_RECONNECT_BUDGET_MIN: usize = 4;
|
const HEALTH_RECONNECT_BUDGET_MIN: usize = 4;
|
||||||
const HEALTH_RECONNECT_BUDGET_MAX: usize = 128;
|
const HEALTH_RECONNECT_BUDGET_MAX: usize = 128;
|
||||||
|
const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16;
|
||||||
|
const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16;
|
||||||
|
const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct DcFloorPlanEntry {
|
struct DcFloorPlanEntry {
|
||||||
|
|
@ -111,106 +114,75 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn reap_draining_writers(
|
pub(super) async fn reap_draining_writers(
|
||||||
pool: &Arc<MePool>,
|
pool: &Arc<MePool>,
|
||||||
warn_next_allowed: &mut HashMap<u64, Instant>,
|
warn_next_allowed: &mut HashMap<u64, Instant>,
|
||||||
) {
|
) {
|
||||||
if pool.draining_active_runtime() == 0 {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let now_epoch_secs = MePool::now_epoch_secs();
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
|
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
let drain_threshold = pool
|
let drain_threshold = pool
|
||||||
.me_pool_drain_threshold
|
.me_pool_drain_threshold
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
let mut draining_writers = {
|
let activity = pool.registry.writer_activity_snapshot().await;
|
||||||
let writers = pool.writers.read().await;
|
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
|
||||||
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
|
let mut empty_writer_ids = Vec::<u64>::new();
|
||||||
for writer in writers.iter() {
|
let mut force_close_writer_ids = Vec::<u64>::new();
|
||||||
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
|
let writers = pool.writers.read().await;
|
||||||
continue;
|
for writer in writers.iter() {
|
||||||
}
|
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
draining_writers.push(DrainingWriterSnapshot {
|
continue;
|
||||||
id: writer.id,
|
|
||||||
writer_dc: writer.writer_dc,
|
|
||||||
addr: writer.addr,
|
|
||||||
generation: writer.generation,
|
|
||||||
created_at: writer.created_at,
|
|
||||||
draining_started_at_epoch_secs: writer
|
|
||||||
.draining_started_at_epoch_secs
|
|
||||||
.load(std::sync::atomic::Ordering::Relaxed),
|
|
||||||
drain_deadline_epoch_secs: writer
|
|
||||||
.drain_deadline_epoch_secs
|
|
||||||
.load(std::sync::atomic::Ordering::Relaxed),
|
|
||||||
allow_drain_fallback: writer
|
|
||||||
.allow_drain_fallback
|
|
||||||
.load(std::sync::atomic::Ordering::Relaxed),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
draining_writers
|
if activity
|
||||||
};
|
.bound_clients_by_writer
|
||||||
|
.get(&writer.id)
|
||||||
if draining_writers.is_empty() {
|
.copied()
|
||||||
return;
|
.unwrap_or(0)
|
||||||
}
|
== 0
|
||||||
|
{
|
||||||
let draining_ids: Vec<u64> = draining_writers.iter().map(|writer| writer.id).collect();
|
empty_writer_ids.push(writer.id);
|
||||||
let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await;
|
continue;
|
||||||
let mut non_empty_draining_writers =
|
|
||||||
Vec::<DrainingWriterSnapshot>::with_capacity(draining_writers.len());
|
|
||||||
for writer in draining_writers.drain(..) {
|
|
||||||
if non_empty_writer_ids.contains(&writer.id) {
|
|
||||||
non_empty_draining_writers.push(writer);
|
|
||||||
} else {
|
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
|
||||||
}
|
}
|
||||||
|
draining_writers.push(DrainingWriterSnapshot {
|
||||||
|
id: writer.id,
|
||||||
|
writer_dc: writer.writer_dc,
|
||||||
|
addr: writer.addr,
|
||||||
|
generation: writer.generation,
|
||||||
|
created_at: writer.created_at,
|
||||||
|
draining_started_at_epoch_secs: writer
|
||||||
|
.draining_started_at_epoch_secs
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
drain_deadline_epoch_secs: writer
|
||||||
|
.drain_deadline_epoch_secs
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
allow_drain_fallback: writer
|
||||||
|
.allow_drain_fallback
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
draining_writers = non_empty_draining_writers;
|
drop(writers);
|
||||||
if draining_writers.is_empty() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
|
let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
|
||||||
draining_writers.len().saturating_sub(drain_threshold as usize)
|
draining_writers.len().saturating_sub(drain_threshold as usize)
|
||||||
} else {
|
} else {
|
||||||
0
|
0
|
||||||
};
|
};
|
||||||
let has_deadline_expired = draining_writers.iter().any(|writer| {
|
|
||||||
writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
|
|
||||||
});
|
|
||||||
let can_drop_with_replacement = if overflow > 0 || has_deadline_expired {
|
|
||||||
pool.has_non_draining_writer_per_desired_dc_group().await
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
|
|
||||||
if overflow > 0 {
|
if overflow > 0 {
|
||||||
if can_drop_with_replacement {
|
draining_writers.sort_by(|left, right| {
|
||||||
draining_writers.sort_by(|left, right| {
|
left.draining_started_at_epoch_secs
|
||||||
left.draining_started_at_epoch_secs
|
.cmp(&right.draining_started_at_epoch_secs)
|
||||||
.cmp(&right.draining_started_at_epoch_secs)
|
.then_with(|| left.created_at.cmp(&right.created_at))
|
||||||
.then_with(|| left.created_at.cmp(&right.created_at))
|
.then_with(|| left.id.cmp(&right.id))
|
||||||
.then_with(|| left.id.cmp(&right.id))
|
});
|
||||||
});
|
warn!(
|
||||||
warn!(
|
draining_writers = draining_writers.len(),
|
||||||
draining_writers = draining_writers.len(),
|
me_pool_drain_threshold = drain_threshold,
|
||||||
me_pool_drain_threshold = drain_threshold,
|
removing_writers = overflow,
|
||||||
removing_writers = overflow,
|
"ME draining writer threshold exceeded, force-closing oldest draining writers"
|
||||||
"ME draining writer threshold exceeded, force-closing oldest draining writers"
|
);
|
||||||
);
|
for writer in draining_writers.drain(..overflow) {
|
||||||
for writer in draining_writers.drain(..overflow) {
|
force_close_writer_ids.push(writer.id);
|
||||||
pool.stats.increment_pool_force_close_total();
|
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
warn!(
|
|
||||||
draining_writers = draining_writers.len(),
|
|
||||||
me_pool_drain_threshold = drain_threshold,
|
|
||||||
overflow,
|
|
||||||
"ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -238,25 +210,71 @@ async fn reap_draining_writers(
|
||||||
}
|
}
|
||||||
if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
|
if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
|
||||||
{
|
{
|
||||||
if can_drop_with_replacement {
|
warn!(writer_id = writer.id, "Drain timeout, force-closing");
|
||||||
warn!(writer_id = writer.id, "Drain timeout, force-closing");
|
force_close_writer_ids.push(writer.id);
|
||||||
pool.stats.increment_pool_force_close_total();
|
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
|
||||||
} else if should_emit_writer_warn(
|
|
||||||
warn_next_allowed,
|
|
||||||
writer.id,
|
|
||||||
now,
|
|
||||||
pool.warn_rate_limit_duration(),
|
|
||||||
) {
|
|
||||||
warn!(
|
|
||||||
writer_id = writer.id,
|
|
||||||
writer_dc = writer.writer_dc,
|
|
||||||
endpoint = %writer.addr,
|
|
||||||
"Drain timeout reached, but replacement coverage is incomplete; keeping draining writer"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let close_budget = health_drain_close_budget();
|
||||||
|
let requested_force_close = force_close_writer_ids.len();
|
||||||
|
let requested_empty_close = empty_writer_ids.len();
|
||||||
|
let requested_close_total = requested_force_close.saturating_add(requested_empty_close);
|
||||||
|
let mut closed_writer_ids = HashSet::<u64>::new();
|
||||||
|
let mut closed_total = 0usize;
|
||||||
|
for writer_id in force_close_writer_ids {
|
||||||
|
if closed_total >= close_budget {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if !closed_writer_ids.insert(writer_id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
pool.stats.increment_pool_force_close_total();
|
||||||
|
pool.remove_writer_and_close_clients(writer_id).await;
|
||||||
|
closed_total = closed_total.saturating_add(1);
|
||||||
|
}
|
||||||
|
for writer_id in empty_writer_ids {
|
||||||
|
if closed_total >= close_budget {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if !closed_writer_ids.insert(writer_id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if !pool.remove_writer_if_empty(writer_id).await {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
closed_total = closed_total.saturating_add(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let pending_close_total = requested_close_total.saturating_sub(closed_total);
|
||||||
|
if pending_close_total > 0 {
|
||||||
|
warn!(
|
||||||
|
close_budget,
|
||||||
|
closed_total,
|
||||||
|
pending_close_total,
|
||||||
|
"ME draining close backlog deferred to next health cycle"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep warn cooldown state for draining writers still present in the pool;
|
||||||
|
// drop state only once a writer is actually removed.
|
||||||
|
let active_draining_writer_ids = {
|
||||||
|
let writers = pool.writers.read().await;
|
||||||
|
writers
|
||||||
|
.iter()
|
||||||
|
.filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed))
|
||||||
|
.map(|writer| writer.id)
|
||||||
|
.collect::<HashSet<u64>>()
|
||||||
|
};
|
||||||
|
warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn health_drain_close_budget() -> usize {
|
||||||
|
let cpu_cores = std::thread::available_parallelism()
|
||||||
|
.map(std::num::NonZeroUsize::get)
|
||||||
|
.unwrap_or(1);
|
||||||
|
cpu_cores
|
||||||
|
.saturating_mul(HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE)
|
||||||
|
.clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|
@ -1521,7 +1539,6 @@ mod tests {
|
||||||
pool.writers.write().await.push(writer);
|
pool.writers.write().await.push(writer);
|
||||||
pool.registry.register_writer(writer_id, tx).await;
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||||
pool.increment_draining_active_runtime();
|
|
||||||
assert!(
|
assert!(
|
||||||
pool.registry
|
pool.registry
|
||||||
.bind_writer(
|
.bind_writer(
|
||||||
|
|
@ -1570,7 +1587,6 @@ mod tests {
|
||||||
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
|
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
|
||||||
let pool = make_pool(2).await;
|
let pool = make_pool(2).await;
|
||||||
insert_live_writer(&pool, 1, 2).await;
|
insert_live_writer(&pool, 1, 2).await;
|
||||||
assert!(pool.has_non_draining_writer_per_desired_dc_group().await);
|
|
||||||
let now_epoch_secs = MePool::now_epoch_secs();
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
||||||
let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await;
|
let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await;
|
||||||
|
|
@ -1588,7 +1604,7 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() {
|
async fn reap_draining_writers_force_closes_overflow_without_replacement() {
|
||||||
let pool = make_pool(2).await;
|
let pool = make_pool(2).await;
|
||||||
let now_epoch_secs = MePool::now_epoch_secs();
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
||||||
|
|
@ -1600,8 +1616,8 @@ mod tests {
|
||||||
|
|
||||||
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
||||||
writer_ids.sort_unstable();
|
writer_ids.sort_unstable();
|
||||||
assert_eq!(writer_ids, vec![10, 20, 30]);
|
assert_eq!(writer_ids, vec![20, 30]);
|
||||||
assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10);
|
assert!(pool.registry.get_writer(conn_a).await.is_none());
|
||||||
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
||||||
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,615 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use super::codec::WriterCommand;
|
||||||
|
use super::health::{health_drain_close_budget, reap_draining_writers};
|
||||||
|
use super::pool::{MePool, MeWriter, WriterContour};
|
||||||
|
use super::registry::ConnMeta;
|
||||||
|
use super::me_health_monitor;
|
||||||
|
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::network::probe::NetworkDecision;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
|
||||||
|
async fn make_pool(
|
||||||
|
me_pool_drain_threshold: u64,
|
||||||
|
me_health_interval_ms_unhealthy: u64,
|
||||||
|
me_health_interval_ms_healthy: u64,
|
||||||
|
) -> (Arc<MePool>, Arc<SecureRandom>) {
|
||||||
|
let general = GeneralConfig {
|
||||||
|
me_pool_drain_threshold,
|
||||||
|
me_health_interval_ms_unhealthy,
|
||||||
|
me_health_interval_ms_healthy,
|
||||||
|
..GeneralConfig::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let pool = MePool::new(
|
||||||
|
None,
|
||||||
|
vec![1u8; 32],
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
Vec::new(),
|
||||||
|
1,
|
||||||
|
None,
|
||||||
|
12,
|
||||||
|
1200,
|
||||||
|
HashMap::new(),
|
||||||
|
HashMap::new(),
|
||||||
|
None,
|
||||||
|
NetworkDecision::default(),
|
||||||
|
None,
|
||||||
|
rng.clone(),
|
||||||
|
Arc::new(Stats::default()),
|
||||||
|
general.me_keepalive_enabled,
|
||||||
|
general.me_keepalive_interval_secs,
|
||||||
|
general.me_keepalive_jitter_secs,
|
||||||
|
general.me_keepalive_payload_random,
|
||||||
|
general.rpc_proxy_req_every,
|
||||||
|
general.me_warmup_stagger_enabled,
|
||||||
|
general.me_warmup_step_delay_ms,
|
||||||
|
general.me_warmup_step_jitter_ms,
|
||||||
|
general.me_reconnect_max_concurrent_per_dc,
|
||||||
|
general.me_reconnect_backoff_base_ms,
|
||||||
|
general.me_reconnect_backoff_cap_ms,
|
||||||
|
general.me_reconnect_fast_retry_count,
|
||||||
|
general.me_single_endpoint_shadow_writers,
|
||||||
|
general.me_single_endpoint_outage_mode_enabled,
|
||||||
|
general.me_single_endpoint_outage_disable_quarantine,
|
||||||
|
general.me_single_endpoint_outage_backoff_min_ms,
|
||||||
|
general.me_single_endpoint_outage_backoff_max_ms,
|
||||||
|
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||||
|
general.me_floor_mode,
|
||||||
|
general.me_adaptive_floor_idle_secs,
|
||||||
|
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||||
|
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||||
|
general.me_adaptive_floor_recover_grace_secs,
|
||||||
|
general.me_adaptive_floor_writers_per_core_total,
|
||||||
|
general.me_adaptive_floor_cpu_cores_override,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_global,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_global,
|
||||||
|
general.hardswap,
|
||||||
|
general.me_pool_drain_ttl_secs,
|
||||||
|
general.me_pool_drain_threshold,
|
||||||
|
general.effective_me_pool_force_close_secs(),
|
||||||
|
general.me_pool_min_fresh_ratio,
|
||||||
|
general.me_hardswap_warmup_delay_min_ms,
|
||||||
|
general.me_hardswap_warmup_delay_max_ms,
|
||||||
|
general.me_hardswap_warmup_extra_passes,
|
||||||
|
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||||
|
general.me_bind_stale_mode,
|
||||||
|
general.me_bind_stale_ttl_secs,
|
||||||
|
general.me_secret_atomic_snapshot,
|
||||||
|
general.me_deterministic_writer_sort,
|
||||||
|
MeWriterPickMode::default(),
|
||||||
|
general.me_writer_pick_sample_size,
|
||||||
|
MeSocksKdfPolicy::default(),
|
||||||
|
general.me_writer_cmd_channel_capacity,
|
||||||
|
general.me_route_channel_capacity,
|
||||||
|
general.me_route_backpressure_base_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_watermark_pct,
|
||||||
|
general.me_reader_route_data_wait_ms,
|
||||||
|
general.me_health_interval_ms_unhealthy,
|
||||||
|
general.me_health_interval_ms_healthy,
|
||||||
|
general.me_warn_rate_limit_ms,
|
||||||
|
MeRouteNoWriterMode::default(),
|
||||||
|
general.me_route_no_writer_wait_ms,
|
||||||
|
general.me_route_inline_recovery_attempts,
|
||||||
|
general.me_route_inline_recovery_wait_ms,
|
||||||
|
);
|
||||||
|
|
||||||
|
(pool, rng)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn insert_draining_writer(
|
||||||
|
pool: &Arc<MePool>,
|
||||||
|
writer_id: u64,
|
||||||
|
drain_started_at_epoch_secs: u64,
|
||||||
|
bound_clients: usize,
|
||||||
|
drain_deadline_epoch_secs: u64,
|
||||||
|
) {
|
||||||
|
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||||
|
let writer = MeWriter {
|
||||||
|
id: writer_id,
|
||||||
|
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6000 + writer_id as u16),
|
||||||
|
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
writer_dc: 2,
|
||||||
|
generation: 1,
|
||||||
|
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
|
||||||
|
created_at: Instant::now() - Duration::from_secs(writer_id),
|
||||||
|
tx: tx.clone(),
|
||||||
|
cancel: CancellationToken::new(),
|
||||||
|
degraded: Arc::new(AtomicBool::new(false)),
|
||||||
|
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||||
|
draining: Arc::new(AtomicBool::new(true)),
|
||||||
|
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
|
||||||
|
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
|
||||||
|
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
|
||||||
|
pool.writers.write().await.push(writer);
|
||||||
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
|
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
|
for idx in 0..bound_clients {
|
||||||
|
let (conn_id, _rx) = pool.registry.register().await;
|
||||||
|
assert!(
|
||||||
|
pool.registry
|
||||||
|
.bind_writer(
|
||||||
|
conn_id,
|
||||||
|
writer_id,
|
||||||
|
ConnMeta {
|
||||||
|
target_dc: 2,
|
||||||
|
client_addr: SocketAddr::new(
|
||||||
|
IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
8000 + idx as u16,
|
||||||
|
),
|
||||||
|
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||||
|
proto_flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn writer_count(pool: &Arc<MePool>) -> usize {
|
||||||
|
pool.writers.read().await.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sorted_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
|
||||||
|
let mut ids = pool
|
||||||
|
.writers
|
||||||
|
.read()
|
||||||
|
.await
|
||||||
|
.iter()
|
||||||
|
.map(|writer| writer.id)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
ids.sort_unstable();
|
||||||
|
ids
|
||||||
|
}
|
||||||
|
|
||||||
|
fn lcg_next(state: &mut u64) -> u64 {
|
||||||
|
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||||
|
*state
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn draining_writer_ids(pool: &Arc<MePool>) -> HashSet<u64> {
|
||||||
|
pool.writers
|
||||||
|
.read()
|
||||||
|
.await
|
||||||
|
.iter()
|
||||||
|
.filter(|writer| writer.draining.load(Ordering::Relaxed))
|
||||||
|
.map(|writer| writer.id)
|
||||||
|
.collect::<HashSet<u64>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_writer_runtime_state(
|
||||||
|
pool: &Arc<MePool>,
|
||||||
|
writer_id: u64,
|
||||||
|
draining: bool,
|
||||||
|
drain_started_at_epoch_secs: u64,
|
||||||
|
drain_deadline_epoch_secs: u64,
|
||||||
|
) {
|
||||||
|
let writers = pool.writers.read().await;
|
||||||
|
if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
|
||||||
|
writer.draining.store(draining, Ordering::Relaxed);
|
||||||
|
writer
|
||||||
|
.draining_started_at_epoch_secs
|
||||||
|
.store(drain_started_at_epoch_secs, Ordering::Relaxed);
|
||||||
|
writer
|
||||||
|
.drain_deadline_epoch_secs
|
||||||
|
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_clears_warn_state_when_pool_empty() {
|
||||||
|
let (pool, _rng) = make_pool(128, 1, 1).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
warn_next_allowed.insert(11, Instant::now() + Duration::from_secs(5));
|
||||||
|
warn_next_allowed.insert(22, Instant::now() + Duration::from_secs(5));
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert!(warn_next_allowed.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_respects_threshold_across_multiple_overflow_cycles() {
|
||||||
|
let threshold = 3u64;
|
||||||
|
let (pool, _rng) = make_pool(threshold, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
|
||||||
|
for writer_id in 1..=60u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(600).saturating_add(writer_id),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
for _ in 0..64 {
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
if writer_count(&pool).await <= threshold as usize {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(writer_count(&pool).await, threshold as usize);
|
||||||
|
assert_eq!(sorted_writer_ids(&pool).await, vec![58, 59, 60]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_handles_large_empty_writer_population() {
|
||||||
|
let (pool, _rng) = make_pool(128, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let total = health_drain_close_budget().saturating_mul(3).saturating_add(27);
|
||||||
|
|
||||||
|
for writer_id in 1..=total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(120),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
for _ in 0..24 {
|
||||||
|
if writer_count(&pool).await == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(writer_count(&pool).await, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_processes_mass_deadline_expiry_without_unbounded_growth() {
|
||||||
|
let (pool, _rng) = make_pool(128, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let total = health_drain_close_budget().saturating_mul(4).saturating_add(31);
|
||||||
|
|
||||||
|
for writer_id in 1..=total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(180),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
for _ in 0..40 {
|
||||||
|
if writer_count(&pool).await == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(writer_count(&pool).await, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_churn() {
|
||||||
|
let (pool, _rng) = make_pool(128, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
for wave in 0..40u64 {
|
||||||
|
for offset in 0..8u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
wave * 100 + offset,
|
||||||
|
now_epoch_secs.saturating_sub(400 + offset),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(warn_next_allowed.len() <= writer_count(&pool).await);
|
||||||
|
|
||||||
|
let ids = sorted_writer_ids(&pool).await;
|
||||||
|
for writer_id in ids.into_iter().take(3) {
|
||||||
|
let _ = pool.remove_writer_and_close_clients(writer_id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(warn_next_allowed.len() <= writer_count(&pool).await);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_budgeted_cleanup_never_increases_pool_size() {
|
||||||
|
let (pool, _rng) = make_pool(5, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
|
||||||
|
for writer_id in 1..=200u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(240).saturating_add(writer_id),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
let mut previous = writer_count(&pool).await;
|
||||||
|
for _ in 0..32 {
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
let current = writer_count(&pool).await;
|
||||||
|
assert!(current <= previous);
|
||||||
|
previous = current;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_health_monitor_converges_to_threshold_under_live_injection_churn() {
|
||||||
|
let threshold = 7u64;
|
||||||
|
let (pool, rng) = make_pool(threshold, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
|
||||||
|
for writer_id in 1..=40u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||||
|
|
||||||
|
for wave in 0..8u64 {
|
||||||
|
for offset in 0..10u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
1000 + wave * 100 + offset,
|
||||||
|
now_epoch_secs.saturating_sub(120).saturating_add(offset),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(120)).await;
|
||||||
|
monitor.abort();
|
||||||
|
let _ = monitor.await;
|
||||||
|
|
||||||
|
assert!(writer_count(&pool).await <= threshold as usize);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_health_monitor_drains_deadline_storm_with_budgeted_progress() {
|
||||||
|
let (pool, rng) = make_pool(128, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
|
||||||
|
for writer_id in 1..=220u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(120),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||||
|
tokio::time::sleep(Duration::from_millis(120)).await;
|
||||||
|
monitor.abort();
|
||||||
|
let _ = monitor.await;
|
||||||
|
|
||||||
|
assert_eq!(writer_count(&pool).await, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() {
|
||||||
|
let threshold = 12u64;
|
||||||
|
let (pool, rng) = make_pool(threshold, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
|
||||||
|
for writer_id in 1..=180u64 {
|
||||||
|
let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
|
||||||
|
let deadline = if writer_id % 2 == 0 {
|
||||||
|
now_epoch_secs.saturating_sub(1)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(250).saturating_add(writer_id),
|
||||||
|
bound_clients,
|
||||||
|
deadline,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||||
|
tokio::time::sleep(Duration::from_millis(140)).await;
|
||||||
|
monitor.abort();
|
||||||
|
let _ = monitor.await;
|
||||||
|
|
||||||
|
assert!(writer_count(&pool).await <= threshold as usize);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() {
|
||||||
|
let threshold = 9u64;
|
||||||
|
let (pool, _rng) = make_pool(threshold, 1, 1).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
let mut seed = 0x9E37_79B9_7F4A_7C15u64;
|
||||||
|
let mut next_writer_id = 20_000u64;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
|
||||||
|
for writer_id in 1..=72u64 {
|
||||||
|
let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 };
|
||||||
|
let deadline = if writer_id % 5 == 0 {
|
||||||
|
now_epoch_secs.saturating_sub(1)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(500).saturating_add(writer_id),
|
||||||
|
bound_clients,
|
||||||
|
deadline,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
for _round in 0..90 {
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
let draining_ids = draining_writer_ids(&pool).await;
|
||||||
|
assert!(
|
||||||
|
warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
|
||||||
|
"warn-state keys must always be a subset of live draining writers"
|
||||||
|
);
|
||||||
|
|
||||||
|
let writer_ids = sorted_writer_ids(&pool).await;
|
||||||
|
if writer_ids.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let remove_n = (lcg_next(&mut seed) % 3) as usize;
|
||||||
|
for writer_id in writer_ids.iter().copied().take(remove_n) {
|
||||||
|
let _ = pool.remove_writer_and_close_clients(writer_id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let survivors = sorted_writer_ids(&pool).await;
|
||||||
|
if !survivors.is_empty() {
|
||||||
|
let idx = (lcg_next(&mut seed) as usize) % survivors.len();
|
||||||
|
let target = survivors[idx];
|
||||||
|
set_writer_runtime_state(&pool, target, false, 0, 0).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let survivors = sorted_writer_ids(&pool).await;
|
||||||
|
if survivors.len() > 1 {
|
||||||
|
let idx = (lcg_next(&mut seed) as usize) % survivors.len();
|
||||||
|
let target = survivors[idx];
|
||||||
|
let expired_deadline = if lcg_next(&mut seed) & 1 == 0 {
|
||||||
|
now_epoch_secs.saturating_sub(1)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
set_writer_runtime_state(
|
||||||
|
&pool,
|
||||||
|
target,
|
||||||
|
true,
|
||||||
|
now_epoch_secs.saturating_sub(120),
|
||||||
|
expired_deadline,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let inject_n = (lcg_next(&mut seed) % 4) as usize;
|
||||||
|
for _ in 0..inject_n {
|
||||||
|
let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 };
|
||||||
|
let deadline = if lcg_next(&mut seed) & 1 == 0 {
|
||||||
|
now_epoch_secs.saturating_sub(1)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
next_writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(240),
|
||||||
|
bound_clients,
|
||||||
|
deadline,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
next_writer_id = next_writer_id.saturating_add(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..64 {
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
if writer_count(&pool).await <= threshold as usize {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(writer_count(&pool).await <= threshold as usize);
|
||||||
|
let draining_ids = draining_writer_ids(&pool).await;
|
||||||
|
assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() {
|
||||||
|
let (pool, _rng) = make_pool(64, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
|
||||||
|
for writer_id in 1..=24u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(240),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
for _round in 0..48u64 {
|
||||||
|
for writer_id in 1..=24u64 {
|
||||||
|
let draining = (writer_id + _round) % 3 != 0;
|
||||||
|
set_writer_runtime_state(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
draining,
|
||||||
|
now_epoch_secs.saturating_sub(120),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
let draining_ids = draining_writer_ids(&pool).await;
|
||||||
|
assert!(
|
||||||
|
warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
|
||||||
|
"warn-state map must not retain entries for writers outside draining set"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn health_drain_close_budget_is_within_expected_bounds() {
|
||||||
|
let budget = health_drain_close_budget();
|
||||||
|
assert!((16..=256).contains(&budget));
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,241 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use super::codec::WriterCommand;
|
||||||
|
use super::health::health_drain_close_budget;
|
||||||
|
use super::pool::{MePool, MeWriter, WriterContour};
|
||||||
|
use super::registry::ConnMeta;
|
||||||
|
use super::me_health_monitor;
|
||||||
|
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::network::probe::NetworkDecision;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
|
||||||
|
async fn make_pool(
|
||||||
|
me_pool_drain_threshold: u64,
|
||||||
|
me_health_interval_ms_unhealthy: u64,
|
||||||
|
me_health_interval_ms_healthy: u64,
|
||||||
|
) -> (Arc<MePool>, Arc<SecureRandom>) {
|
||||||
|
let general = GeneralConfig {
|
||||||
|
me_pool_drain_threshold,
|
||||||
|
me_health_interval_ms_unhealthy,
|
||||||
|
me_health_interval_ms_healthy,
|
||||||
|
..GeneralConfig::default()
|
||||||
|
};
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let pool = MePool::new(
|
||||||
|
None,
|
||||||
|
vec![1u8; 32],
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
Vec::new(),
|
||||||
|
1,
|
||||||
|
None,
|
||||||
|
12,
|
||||||
|
1200,
|
||||||
|
HashMap::new(),
|
||||||
|
HashMap::new(),
|
||||||
|
None,
|
||||||
|
NetworkDecision::default(),
|
||||||
|
None,
|
||||||
|
rng.clone(),
|
||||||
|
Arc::new(Stats::default()),
|
||||||
|
general.me_keepalive_enabled,
|
||||||
|
general.me_keepalive_interval_secs,
|
||||||
|
general.me_keepalive_jitter_secs,
|
||||||
|
general.me_keepalive_payload_random,
|
||||||
|
general.rpc_proxy_req_every,
|
||||||
|
general.me_warmup_stagger_enabled,
|
||||||
|
general.me_warmup_step_delay_ms,
|
||||||
|
general.me_warmup_step_jitter_ms,
|
||||||
|
general.me_reconnect_max_concurrent_per_dc,
|
||||||
|
general.me_reconnect_backoff_base_ms,
|
||||||
|
general.me_reconnect_backoff_cap_ms,
|
||||||
|
general.me_reconnect_fast_retry_count,
|
||||||
|
general.me_single_endpoint_shadow_writers,
|
||||||
|
general.me_single_endpoint_outage_mode_enabled,
|
||||||
|
general.me_single_endpoint_outage_disable_quarantine,
|
||||||
|
general.me_single_endpoint_outage_backoff_min_ms,
|
||||||
|
general.me_single_endpoint_outage_backoff_max_ms,
|
||||||
|
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||||
|
general.me_floor_mode,
|
||||||
|
general.me_adaptive_floor_idle_secs,
|
||||||
|
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||||
|
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||||
|
general.me_adaptive_floor_recover_grace_secs,
|
||||||
|
general.me_adaptive_floor_writers_per_core_total,
|
||||||
|
general.me_adaptive_floor_cpu_cores_override,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_global,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_global,
|
||||||
|
general.hardswap,
|
||||||
|
general.me_pool_drain_ttl_secs,
|
||||||
|
general.me_pool_drain_threshold,
|
||||||
|
general.effective_me_pool_force_close_secs(),
|
||||||
|
general.me_pool_min_fresh_ratio,
|
||||||
|
general.me_hardswap_warmup_delay_min_ms,
|
||||||
|
general.me_hardswap_warmup_delay_max_ms,
|
||||||
|
general.me_hardswap_warmup_extra_passes,
|
||||||
|
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||||
|
general.me_bind_stale_mode,
|
||||||
|
general.me_bind_stale_ttl_secs,
|
||||||
|
general.me_secret_atomic_snapshot,
|
||||||
|
general.me_deterministic_writer_sort,
|
||||||
|
MeWriterPickMode::default(),
|
||||||
|
general.me_writer_pick_sample_size,
|
||||||
|
MeSocksKdfPolicy::default(),
|
||||||
|
general.me_writer_cmd_channel_capacity,
|
||||||
|
general.me_route_channel_capacity,
|
||||||
|
general.me_route_backpressure_base_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_watermark_pct,
|
||||||
|
general.me_reader_route_data_wait_ms,
|
||||||
|
general.me_health_interval_ms_unhealthy,
|
||||||
|
general.me_health_interval_ms_healthy,
|
||||||
|
general.me_warn_rate_limit_ms,
|
||||||
|
MeRouteNoWriterMode::default(),
|
||||||
|
general.me_route_no_writer_wait_ms,
|
||||||
|
general.me_route_inline_recovery_attempts,
|
||||||
|
general.me_route_inline_recovery_wait_ms,
|
||||||
|
);
|
||||||
|
(pool, rng)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn insert_draining_writer(
|
||||||
|
pool: &Arc<MePool>,
|
||||||
|
writer_id: u64,
|
||||||
|
drain_started_at_epoch_secs: u64,
|
||||||
|
bound_clients: usize,
|
||||||
|
drain_deadline_epoch_secs: u64,
|
||||||
|
) {
|
||||||
|
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||||
|
let writer = MeWriter {
|
||||||
|
id: writer_id,
|
||||||
|
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5500 + writer_id as u16),
|
||||||
|
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
writer_dc: 2,
|
||||||
|
generation: 1,
|
||||||
|
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
|
||||||
|
created_at: Instant::now() - Duration::from_secs(writer_id),
|
||||||
|
tx: tx.clone(),
|
||||||
|
cancel: CancellationToken::new(),
|
||||||
|
degraded: Arc::new(AtomicBool::new(false)),
|
||||||
|
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||||
|
draining: Arc::new(AtomicBool::new(true)),
|
||||||
|
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
|
||||||
|
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
|
||||||
|
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
pool.writers.write().await.push(writer);
|
||||||
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
|
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
for idx in 0..bound_clients {
|
||||||
|
let (conn_id, _rx) = pool.registry.register().await;
|
||||||
|
assert!(
|
||||||
|
pool.registry
|
||||||
|
.bind_writer(
|
||||||
|
conn_id,
|
||||||
|
writer_id,
|
||||||
|
ConnMeta {
|
||||||
|
target_dc: 2,
|
||||||
|
client_addr: SocketAddr::new(
|
||||||
|
IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
7200 + idx as u16,
|
||||||
|
),
|
||||||
|
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||||
|
proto_flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_for_pool_empty(pool: &Arc<MePool>, timeout: Duration) {
|
||||||
|
let start = Instant::now();
|
||||||
|
loop {
|
||||||
|
if pool.writers.read().await.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
assert!(
|
||||||
|
start.elapsed() < timeout,
|
||||||
|
"timed out waiting for pool.writers to become empty"
|
||||||
|
);
|
||||||
|
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() {
|
||||||
|
let (pool, rng) = make_pool(128, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let writer_total = health_drain_close_budget().saturating_mul(2).saturating_add(9);
|
||||||
|
for writer_id in 1..=writer_total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(120),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||||
|
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
|
||||||
|
monitor.abort();
|
||||||
|
let _ = monitor.await;
|
||||||
|
|
||||||
|
assert!(pool.writers.read().await.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() {
|
||||||
|
let (pool, rng) = make_pool(128, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
for writer_id in 1..=24u64 {
|
||||||
|
insert_draining_writer(&pool, writer_id, now_epoch_secs.saturating_sub(60), 0, 0).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||||
|
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
|
||||||
|
monitor.abort();
|
||||||
|
let _ = monitor.await;
|
||||||
|
|
||||||
|
assert!(pool.writers.read().await.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() {
|
||||||
|
let threshold = 4u64;
|
||||||
|
let (pool, rng) = make_pool(threshold, 1, 1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let writer_total = threshold as usize + health_drain_close_budget().saturating_add(11);
|
||||||
|
for writer_id in 1..=writer_total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
|
||||||
|
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
|
||||||
|
monitor.abort();
|
||||||
|
let _ = monitor.await;
|
||||||
|
|
||||||
|
assert!(pool.writers.read().await.is_empty());
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,658 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use super::codec::WriterCommand;
|
||||||
|
use super::health::{health_drain_close_budget, reap_draining_writers};
|
||||||
|
use super::pool::{MePool, MeWriter, WriterContour};
|
||||||
|
use super::registry::ConnMeta;
|
||||||
|
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::network::probe::NetworkDecision;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
|
||||||
|
async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
|
||||||
|
let general = GeneralConfig {
|
||||||
|
me_pool_drain_threshold,
|
||||||
|
..GeneralConfig::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
MePool::new(
|
||||||
|
None,
|
||||||
|
vec![1u8; 32],
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
Vec::new(),
|
||||||
|
1,
|
||||||
|
None,
|
||||||
|
12,
|
||||||
|
1200,
|
||||||
|
HashMap::new(),
|
||||||
|
HashMap::new(),
|
||||||
|
None,
|
||||||
|
NetworkDecision::default(),
|
||||||
|
None,
|
||||||
|
Arc::new(SecureRandom::new()),
|
||||||
|
Arc::new(Stats::default()),
|
||||||
|
general.me_keepalive_enabled,
|
||||||
|
general.me_keepalive_interval_secs,
|
||||||
|
general.me_keepalive_jitter_secs,
|
||||||
|
general.me_keepalive_payload_random,
|
||||||
|
general.rpc_proxy_req_every,
|
||||||
|
general.me_warmup_stagger_enabled,
|
||||||
|
general.me_warmup_step_delay_ms,
|
||||||
|
general.me_warmup_step_jitter_ms,
|
||||||
|
general.me_reconnect_max_concurrent_per_dc,
|
||||||
|
general.me_reconnect_backoff_base_ms,
|
||||||
|
general.me_reconnect_backoff_cap_ms,
|
||||||
|
general.me_reconnect_fast_retry_count,
|
||||||
|
general.me_single_endpoint_shadow_writers,
|
||||||
|
general.me_single_endpoint_outage_mode_enabled,
|
||||||
|
general.me_single_endpoint_outage_disable_quarantine,
|
||||||
|
general.me_single_endpoint_outage_backoff_min_ms,
|
||||||
|
general.me_single_endpoint_outage_backoff_max_ms,
|
||||||
|
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||||
|
general.me_floor_mode,
|
||||||
|
general.me_adaptive_floor_idle_secs,
|
||||||
|
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||||
|
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||||
|
general.me_adaptive_floor_recover_grace_secs,
|
||||||
|
general.me_adaptive_floor_writers_per_core_total,
|
||||||
|
general.me_adaptive_floor_cpu_cores_override,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_global,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_global,
|
||||||
|
general.hardswap,
|
||||||
|
general.me_pool_drain_ttl_secs,
|
||||||
|
general.me_pool_drain_threshold,
|
||||||
|
general.effective_me_pool_force_close_secs(),
|
||||||
|
general.me_pool_min_fresh_ratio,
|
||||||
|
general.me_hardswap_warmup_delay_min_ms,
|
||||||
|
general.me_hardswap_warmup_delay_max_ms,
|
||||||
|
general.me_hardswap_warmup_extra_passes,
|
||||||
|
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||||
|
general.me_bind_stale_mode,
|
||||||
|
general.me_bind_stale_ttl_secs,
|
||||||
|
general.me_secret_atomic_snapshot,
|
||||||
|
general.me_deterministic_writer_sort,
|
||||||
|
MeWriterPickMode::default(),
|
||||||
|
general.me_writer_pick_sample_size,
|
||||||
|
MeSocksKdfPolicy::default(),
|
||||||
|
general.me_writer_cmd_channel_capacity,
|
||||||
|
general.me_route_channel_capacity,
|
||||||
|
general.me_route_backpressure_base_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_watermark_pct,
|
||||||
|
general.me_reader_route_data_wait_ms,
|
||||||
|
general.me_health_interval_ms_unhealthy,
|
||||||
|
general.me_health_interval_ms_healthy,
|
||||||
|
general.me_warn_rate_limit_ms,
|
||||||
|
MeRouteNoWriterMode::default(),
|
||||||
|
general.me_route_no_writer_wait_ms,
|
||||||
|
general.me_route_inline_recovery_attempts,
|
||||||
|
general.me_route_inline_recovery_wait_ms,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn insert_draining_writer(
|
||||||
|
pool: &Arc<MePool>,
|
||||||
|
writer_id: u64,
|
||||||
|
drain_started_at_epoch_secs: u64,
|
||||||
|
bound_clients: usize,
|
||||||
|
drain_deadline_epoch_secs: u64,
|
||||||
|
) -> Vec<u64> {
|
||||||
|
let mut conn_ids = Vec::with_capacity(bound_clients);
|
||||||
|
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||||
|
let writer = MeWriter {
|
||||||
|
id: writer_id,
|
||||||
|
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4500 + writer_id as u16),
|
||||||
|
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
writer_dc: 2,
|
||||||
|
generation: 1,
|
||||||
|
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
|
||||||
|
created_at: Instant::now() - Duration::from_secs(writer_id),
|
||||||
|
tx: tx.clone(),
|
||||||
|
cancel: CancellationToken::new(),
|
||||||
|
degraded: Arc::new(AtomicBool::new(false)),
|
||||||
|
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||||
|
draining: Arc::new(AtomicBool::new(true)),
|
||||||
|
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
|
||||||
|
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
|
||||||
|
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
pool.writers.write().await.push(writer);
|
||||||
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
|
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
for idx in 0..bound_clients {
|
||||||
|
let (conn_id, _rx) = pool.registry.register().await;
|
||||||
|
assert!(
|
||||||
|
pool.registry
|
||||||
|
.bind_writer(
|
||||||
|
conn_id,
|
||||||
|
writer_id,
|
||||||
|
ConnMeta {
|
||||||
|
target_dc: 2,
|
||||||
|
client_addr: SocketAddr::new(
|
||||||
|
IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
6200 + idx as u16,
|
||||||
|
),
|
||||||
|
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||||
|
proto_flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
);
|
||||||
|
conn_ids.push(conn_id);
|
||||||
|
}
|
||||||
|
conn_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn current_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
|
||||||
|
let mut writer_ids = pool
|
||||||
|
.writers
|
||||||
|
.read()
|
||||||
|
.await
|
||||||
|
.iter()
|
||||||
|
.map(|writer| writer.id)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
writer_ids.sort_unstable();
|
||||||
|
writer_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn writer_exists(pool: &Arc<MePool>, writer_id: u64) -> bool {
|
||||||
|
pool.writers
|
||||||
|
.read()
|
||||||
|
.await
|
||||||
|
.iter()
|
||||||
|
.any(|writer| writer.id == writer_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_writer_draining(pool: &Arc<MePool>, writer_id: u64, draining: bool) {
|
||||||
|
let writers = pool.writers.read().await;
|
||||||
|
if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
|
||||||
|
writer.draining.store(draining, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_drops_warn_state_for_removed_writer() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let conn_ids =
|
||||||
|
insert_draining_writer(&pool, 7, now_epoch_secs.saturating_sub(180), 1, 0).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(warn_next_allowed.contains_key(&7));
|
||||||
|
|
||||||
|
let _ = pool.remove_writer_and_close_clients(7).await;
|
||||||
|
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(!warn_next_allowed.contains_key(&7));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_removes_empty_draining_writers() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(40), 0, 0).await;
|
||||||
|
insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(30), 0, 0).await;
|
||||||
|
insert_draining_writer(&pool, 3, now_epoch_secs.saturating_sub(20), 1, 0).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert_eq!(current_writer_ids(&pool).await, vec![3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() {
|
||||||
|
let pool = make_pool(2).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
insert_draining_writer(&pool, 11, now_epoch_secs.saturating_sub(40), 1, 0).await;
|
||||||
|
insert_draining_writer(&pool, 22, now_epoch_secs.saturating_sub(30), 1, 0).await;
|
||||||
|
insert_draining_writer(&pool, 33, now_epoch_secs.saturating_sub(20), 1, 0).await;
|
||||||
|
insert_draining_writer(&pool, 44, now_epoch_secs.saturating_sub(10), 1, 0).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert_eq!(current_writer_ids(&pool).await, vec![33, 44]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_deadline_force_close_applies_under_threshold() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
50,
|
||||||
|
now_epoch_secs.saturating_sub(15),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert!(current_writer_ids(&pool).await.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_limits_closes_per_health_tick() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let close_budget = health_drain_close_budget();
|
||||||
|
let writer_total = close_budget.saturating_add(19);
|
||||||
|
for writer_id in 1..=writer_total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(20),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert_eq!(pool.writers.read().await.len(), writer_total - close_budget);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() {
|
||||||
|
let pool = make_pool(0).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let close_budget = health_drain_close_budget();
|
||||||
|
let writer_total = close_budget.saturating_add(5);
|
||||||
|
for writer_id in 1..=writer_total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(60),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
let target_writer_id = writer_total as u64;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
warn_next_allowed.insert(
|
||||||
|
target_writer_id,
|
||||||
|
Instant::now() + Duration::from_secs(300),
|
||||||
|
);
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert!(writer_exists(&pool, target_writer_id).await);
|
||||||
|
assert!(warn_next_allowed.contains_key(&target_writer_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() {
|
||||||
|
let pool = make_pool(1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let close_budget = health_drain_close_budget();
|
||||||
|
let writer_total = close_budget.saturating_add(6);
|
||||||
|
for writer_id in 1..=writer_total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
let target_writer_id = writer_total.saturating_sub(1) as u64;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
warn_next_allowed.insert(
|
||||||
|
target_writer_id,
|
||||||
|
Instant::now() + Duration::from_secs(300),
|
||||||
|
);
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert!(writer_exists(&pool, target_writer_id).await);
|
||||||
|
assert!(warn_next_allowed.contains_key(&target_writer_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await;
|
||||||
|
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300));
|
||||||
|
|
||||||
|
set_writer_draining(&pool, 71, false).await;
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert!(writer_exists(&pool, 71).await);
|
||||||
|
assert!(
|
||||||
|
!warn_next_allowed.contains_key(&71),
|
||||||
|
"warn cooldown state must be dropped after writer leaves draining state"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() {
|
||||||
|
let pool = make_pool(0).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let close_budget = health_drain_close_budget();
|
||||||
|
let writer_total = close_budget.saturating_mul(2).saturating_add(1);
|
||||||
|
for writer_id in 1..=writer_total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(120),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let tail_writer_id = writer_total as u64;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
warn_next_allowed.insert(
|
||||||
|
tail_writer_id,
|
||||||
|
Instant::now() + Duration::from_secs(300),
|
||||||
|
);
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(writer_exists(&pool, tail_writer_id).await);
|
||||||
|
assert!(warn_next_allowed.contains_key(&tail_writer_id));
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(writer_exists(&pool, tail_writer_id).await);
|
||||||
|
assert!(warn_next_allowed.contains_key(&tail_writer_id));
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(!writer_exists(&pool, tail_writer_id).await);
|
||||||
|
assert!(
|
||||||
|
!warn_next_allowed.contains_key(&tail_writer_id),
|
||||||
|
"warn cooldown state must clear once writer is actually removed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_backlog_drains_across_ticks() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let close_budget = health_drain_close_budget();
|
||||||
|
let writer_total = close_budget.saturating_mul(2).saturating_add(7);
|
||||||
|
for writer_id in 1..=writer_total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(20),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
for _ in 0..8 {
|
||||||
|
if pool.writers.read().await.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(pool.writers.read().await.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_threshold_backlog_converges_to_threshold() {
|
||||||
|
let threshold = 5u64;
|
||||||
|
let pool = make_pool(threshold).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let close_budget = health_drain_close_budget();
|
||||||
|
let writer_total = threshold as usize + close_budget.saturating_add(12);
|
||||||
|
for writer_id in 1..=writer_total as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(200).saturating_add(writer_id),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
for _ in 0..16 {
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
if pool.writers.read().await.len() <= threshold as usize {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(pool.writers.read().await.len(), threshold as usize);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_threshold_zero_preserves_non_expired_non_empty_writers() {
|
||||||
|
let pool = make_pool(0).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(40), 1, 0).await;
|
||||||
|
insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(30), 1, 0).await;
|
||||||
|
insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(20), 1, 0).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert_eq!(current_writer_ids(&pool).await, vec![10, 20, 30]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_prioritizes_force_close_before_empty_cleanup() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let close_budget = health_drain_close_budget();
|
||||||
|
for writer_id in 1..=close_budget as u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(20),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
let empty_writer_id = close_budget as u64 + 1;
|
||||||
|
insert_draining_writer(&pool, empty_writer_id, now_epoch_secs.saturating_sub(20), 0, 0).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert_eq!(current_writer_ids(&pool).await, vec![empty_writer_id]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_empty_cleanup_does_not_increment_force_close_metric() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(60), 0, 0).await;
|
||||||
|
insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(50), 0, 0).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert!(current_writer_ids(&pool).await.is_empty());
|
||||||
|
assert_eq!(pool.stats.get_pool_force_close_total(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_handles_duplicate_force_close_requests_for_same_writer() {
|
||||||
|
let pool = make_pool(1).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
10,
|
||||||
|
now_epoch_secs.saturating_sub(30),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
20,
|
||||||
|
now_epoch_secs.saturating_sub(20),
|
||||||
|
1,
|
||||||
|
now_epoch_secs.saturating_sub(1),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
assert!(current_writer_ids(&pool).await.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population_under_churn() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
for wave in 0..12u64 {
|
||||||
|
for offset in 0..9u64 {
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
wave * 100 + offset,
|
||||||
|
now_epoch_secs.saturating_sub(120 + offset),
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
|
||||||
|
|
||||||
|
let existing_writer_ids = current_writer_ids(&pool).await;
|
||||||
|
for writer_id in existing_writer_ids.into_iter().take(4) {
|
||||||
|
let _ = pool.remove_writer_and_close_clients(writer_id).await;
|
||||||
|
}
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_state() {
|
||||||
|
let pool = make_pool(6).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
for writer_id in 1..=18u64 {
|
||||||
|
let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
|
||||||
|
let deadline = if writer_id % 2 == 0 {
|
||||||
|
now_epoch_secs.saturating_sub(1)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
|
||||||
|
bound_clients,
|
||||||
|
deadline,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..16 {
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
if pool.writers.read().await.len() <= 6 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(pool.writers.read().await.len() <= 6);
|
||||||
|
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn general_config_default_drain_threshold_remains_enabled() {
|
||||||
|
assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
|
||||||
|
let empty_writer_id = 700u64;
|
||||||
|
insert_draining_writer(
|
||||||
|
&pool,
|
||||||
|
empty_writer_id,
|
||||||
|
now_epoch_secs.saturating_sub(60),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let stale_empty_snapshot = vec![empty_writer_id];
|
||||||
|
let (rebound_conn_id, _rx) = pool.registry.register().await;
|
||||||
|
assert!(
|
||||||
|
pool.registry
|
||||||
|
.bind_writer(
|
||||||
|
rebound_conn_id,
|
||||||
|
empty_writer_id,
|
||||||
|
ConnMeta {
|
||||||
|
target_dc: 2,
|
||||||
|
client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050),
|
||||||
|
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
|
||||||
|
proto_flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await,
|
||||||
|
"writer should accept a new bind after stale empty snapshot"
|
||||||
|
);
|
||||||
|
|
||||||
|
for writer_id in stale_empty_snapshot {
|
||||||
|
assert!(
|
||||||
|
!pool.remove_writer_if_empty(writer_id).await,
|
||||||
|
"atomic empty cleanup must reject writers that gained bound clients"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
writer_exists(&pool, empty_writer_id).await,
|
||||||
|
"empty-path cleanup must not remove a writer that gained a bound client"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id),
|
||||||
|
Some(empty_writer_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = pool.registry.unregister(rebound_conn_id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() {
|
||||||
|
let pool = make_pool(128).await;
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await;
|
||||||
|
|
||||||
|
pool.prune_closed_writers().await;
|
||||||
|
|
||||||
|
assert!(!writer_exists(&pool, 910).await);
|
||||||
|
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
|
||||||
|
}
|
||||||
|
|
@ -21,6 +21,12 @@ mod secret;
|
||||||
mod selftest;
|
mod selftest;
|
||||||
mod wire;
|
mod wire;
|
||||||
mod pool_status;
|
mod pool_status;
|
||||||
|
#[cfg(test)]
|
||||||
|
mod health_regression_tests;
|
||||||
|
#[cfg(test)]
|
||||||
|
mod health_integration_tests;
|
||||||
|
#[cfg(test)]
|
||||||
|
mod health_adversarial_tests;
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,11 +42,10 @@ impl MePool {
|
||||||
}
|
}
|
||||||
|
|
||||||
for writer_id in closed_writer_ids {
|
for writer_id in closed_writer_ids {
|
||||||
if self.registry.is_writer_empty(writer_id).await {
|
if self.remove_writer_if_empty(writer_id).await {
|
||||||
let _ = self.remove_writer_only(writer_id).await;
|
continue;
|
||||||
} else {
|
|
||||||
let _ = self.remove_writer_and_close_clients(writer_id).await;
|
|
||||||
}
|
}
|
||||||
|
let _ = self.remove_writer_and_close_clients(writer_id).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -501,6 +500,17 @@ impl MePool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn remove_writer_if_empty(self: &Arc<Self>, writer_id: u64) -> bool {
|
||||||
|
if !self.registry.unregister_writer_if_empty(writer_id).await {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The registry empty-check and unregister are atomic with respect to binds,
|
||||||
|
// so remove_writer_only cannot return active bound sessions here.
|
||||||
|
let _ = self.remove_writer_only(writer_id).await;
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
|
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
|
||||||
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
|
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
|
||||||
let mut removed_addr: Option<SocketAddr> = None;
|
let mut removed_addr: Option<SocketAddr> = None;
|
||||||
|
|
|
||||||
|
|
@ -437,6 +437,23 @@ impl ConnRegistry {
|
||||||
.unwrap_or(true)
|
.unwrap_or(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
|
||||||
|
let mut inner = self.inner.write().await;
|
||||||
|
let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else {
|
||||||
|
// Writer is already absent from the registry.
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
if !conn_ids.is_empty() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
inner.writers.remove(&writer_id);
|
||||||
|
inner.last_meta_for_writer.remove(&writer_id);
|
||||||
|
inner.writer_idle_since_epoch_secs.remove(&writer_id);
|
||||||
|
inner.conns_for_writer.remove(&writer_id);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
|
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
|
||||||
let inner = self.inner.read().await;
|
let inner = self.inner.read().await;
|
||||||
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
|
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue