Merge pull request #463 from DavidOsipov/pr-sec-1

[WIP] Enhance metrics configuration, add health monitoring tests, security hardening, perf optimizations & loads of tests
This commit is contained in:
Alexey 2026-03-18 23:02:58 +03:00 committed by GitHub
commit 44376b5652
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 15357 additions and 436 deletions

1
.gitignore vendored
View File

@ -21,3 +21,4 @@ target
#.idea/
proxy-secret
coverage-html/

View File

@ -390,6 +390,12 @@ you MUST explain why existing invariants remain valid.
- Do not modify existing tests unless the task explicitly requires it.
- Do not weaken assertions.
- Preserve determinism in testable components.
- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new.
- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code.
- Differential/model catches logic drift over time.
- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find these subtle bugs; 10,000 concurrent calls will find them on the first run.
- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly.
- Dead parameter is a code smell rule.
### 15. Security Constraints

59
Cargo.lock generated
View File

@ -425,6 +425,32 @@ dependencies = [
"cipher",
]
[[package]]
name = "curve25519-dalek"
version = "4.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be"
dependencies = [
"cfg-if",
"cpufeatures",
"curve25519-dalek-derive",
"fiat-crypto",
"rustc_version",
"subtle",
"zeroize",
]
[[package]]
name = "curve25519-dalek-derive"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "dashmap"
version = "5.5.3"
@ -517,6 +543,12 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fiat-crypto"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
[[package]]
name = "filetime"
version = "0.2.27"
@ -1609,7 +1641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha",
"rand_core",
"rand_core 0.9.5",
]
[[package]]
@ -1619,9 +1651,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
"rand_core 0.9.5",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
[[package]]
name = "rand_core"
version = "0.9.5"
@ -1637,7 +1675,7 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a"
dependencies = [
"rand_core",
"rand_core 0.9.5",
]
[[package]]
@ -2093,7 +2131,7 @@ dependencies = [
[[package]]
name = "telemt"
version = "3.3.19"
version = "3.3.20"
dependencies = [
"aes",
"anyhow",
@ -2145,6 +2183,7 @@ dependencies = [
"tracing-subscriber",
"url",
"webpki-roots 0.26.11",
"x25519-dalek",
"x509-parser",
"zeroize",
]
@ -3144,6 +3183,18 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
[[package]]
name = "x25519-dalek"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277"
dependencies = [
"curve25519-dalek",
"rand_core 0.6.4",
"serde",
"zeroize",
]
[[package]]
name = "x509-parser"
version = "0.15.1"

View File

@ -52,6 +52,7 @@ regex = "1.11"
crossbeam-queue = "0.3"
num-bigint = "0.4"
num-traits = "0.2"
x25519-dalek = "2"
anyhow = "1.0"
# HTTP

View File

@ -32,6 +32,7 @@ show = "*"
port = 443
# proxy_protocol = false # Enable if behind HAProxy/nginx with PROXY protocol
# metrics_port = 9090
# metrics_listen = "0.0.0.0:9090" # Listen address for metrics (overrides metrics_port)
# metrics_whitelist = ["127.0.0.1", "::1", "0.0.0.0/0"]
[server.api]

View File

@ -38,8 +38,9 @@ umweltschutz.de -> A-запись 198.18.88.88
В конфигурации Telemt:
```
tls_domain = umweltschutz.de
```toml
[censorship]
tls_domain = "umweltschutz.de"
```
Этот домен используется клиентом как SNI в ClientHello
@ -56,8 +57,9 @@ tls_domain = umweltschutz.de
В конфигурации Telemt:
```
mask_host = 127.0.0.1
```toml
[censorship]
mask_host = "127.0.0.1"
mask_port = 8443
```
@ -151,16 +153,18 @@ mask_host:mask_port
Например:
```
tls_domain = github.com
mask_host = github.com
```toml
[censorship]
tls_domain = "github.com"
mask_host = "github.com"
mask_port = 443
```
или
```
mask_host = 140.82.121.4
```toml
[censorship]
mask_host = "140.82.121.4"
```
В этом случае:

View File

@ -239,7 +239,7 @@ tls_full_cert_ttl_secs = 90
[access]
replay_check_len = 65536
replay_window_secs = 1800
replay_window_secs = 120
ignore_time_skew = false
[access.users]

View File

@ -73,7 +73,9 @@ pub(crate) fn default_replay_check_len() -> usize {
}
pub(crate) fn default_replay_window_secs() -> u64 {
1800
// Keep replay cache TTL tight by default to reduce replay surface.
// Deployments with higher RTT or longer reconnect jitter can override this in config.
120
}
pub(crate) fn default_handshake_timeout() -> u64 {
@ -456,11 +458,11 @@ pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 {
}
pub(crate) fn default_server_hello_delay_min_ms() -> u64 {
0
8
}
pub(crate) fn default_server_hello_delay_max_ms() -> u64 {
0
24
}
pub(crate) fn default_alpn_enforce() -> bool {

View File

@ -1163,9 +1163,17 @@ pub struct ServerConfig {
#[serde(default)]
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
/// Port for the Prometheus-compatible metrics endpoint.
/// Enables metrics when set; binds on all interfaces (dual-stack) by default.
#[serde(default)]
pub metrics_port: Option<u16>,
/// Listen address for metrics in `IP:PORT` format (e.g. `"127.0.0.1:9090"`).
/// When set, takes precedence over `metrics_port` and binds on the specified address only.
#[serde(default)]
pub metrics_listen: Option<String>,
/// CIDR whitelist for the metrics endpoint.
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpNetwork>,
@ -1194,6 +1202,7 @@ impl Default for ServerConfig {
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
proxy_protocol_trusted_cidrs: Vec::new(),
metrics_port: None,
metrics_listen: None,
metrics_whitelist: default_metrics_whitelist(),
api: ApiConfig::default(),
listeners: Vec::new(),

View File

@ -7,8 +7,9 @@ use std::net::IpAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use std::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::{Mutex as AsyncMutex, RwLock};
use crate::config::UserMaxUniqueIpsMode;
@ -21,6 +22,8 @@ pub struct UserIpTracker {
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>,
pub(crate) cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
cleanup_drain_lock: Arc<AsyncMutex<()>>,
}
impl UserIpTracker {
@ -33,6 +36,67 @@ impl UserIpTracker {
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
cleanup_queue: Arc::new(Mutex::new(Vec::new())),
cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
}
}
/// Queue a (user, ip) pair for later removal from the active-IP map.
///
/// Synchronous (std `Mutex`, no await); the queued entry is applied by
/// `drain_cleanup_queue`.
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
    match self.cleanup_queue.lock() {
        Ok(mut queue) => queue.push((user, ip)),
        Err(poisoned) => {
            // Recover from a poisoned lock instead of dropping the entry:
            // losing a cleanup record would leak an active-IP reservation.
            let mut queue = poisoned.into_inner();
            queue.push((user.clone(), ip));
            // Clear the poison flag only after the entry is safely queued.
            self.cleanup_queue.clear_poison();
            tracing::warn!(
                "UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
                user,
                ip
            );
        }
    }
}
/// Apply all queued (user, ip) cleanup entries to the active-IP map.
///
/// Lock order: `cleanup_drain_lock` (async) -> `cleanup_queue` (std Mutex,
/// released before any await) -> `active_ips` write lock.
pub(crate) async fn drain_cleanup_queue(&self) {
    // Serialize queue draining and active-IP mutation so check-and-add cannot
    // observe stale active entries that are already queued for removal.
    let _drain_guard = self.cleanup_drain_lock.lock().await;
    let to_remove = {
        match self.cleanup_queue.lock() {
            Ok(mut queue) => {
                if queue.is_empty() {
                    return;
                }
                // Take the whole queue while holding the lock; entries are
                // applied below, after the std mutex guard has been dropped.
                std::mem::take(&mut *queue)
            }
            Err(poisoned) => {
                // A poisoned queue lock is recovered rather than propagated:
                // the Vec inside is still structurally valid.
                let mut queue = poisoned.into_inner();
                if queue.is_empty() {
                    self.cleanup_queue.clear_poison();
                    return;
                }
                let drained = std::mem::take(&mut *queue);
                self.cleanup_queue.clear_poison();
                drained
            }
        }
    };
    // The std mutex guard is dropped here; only now await the RwLock, so the
    // synchronous queue lock is never held across an await point.
    let mut active_ips = self.active_ips.write().await;
    for (user, ip) in to_remove {
        if let Some(user_ips) = active_ips.get_mut(&user) {
            if let Some(count) = user_ips.get_mut(&ip) {
                if *count > 1 {
                    // More than one live connection from this IP: drop one ref.
                    *count -= 1;
                } else {
                    user_ips.remove(&ip);
                }
            }
            // Drop the user entry entirely once its last active IP is gone.
            if user_ips.is_empty() {
                active_ips.remove(&user);
            }
        }
    }
}
@ -118,6 +182,7 @@ impl UserIpTracker {
}
pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
self.drain_cleanup_queue().await;
self.maybe_compact_empty_users().await;
let default_max_ips = *self.default_max_ips.read().await;
let limit = {
@ -194,6 +259,7 @@ impl UserIpTracker {
}
pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap<String, usize> {
self.drain_cleanup_queue().await;
let window = *self.limit_window.read().await;
let now = Instant::now();
let recent_ips = self.recent_ips.read().await;
@ -214,6 +280,7 @@ impl UserIpTracker {
}
pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
let mut out = HashMap::with_capacity(users.len());
for user in users {
@ -228,6 +295,7 @@ impl UserIpTracker {
}
pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
self.drain_cleanup_queue().await;
let window = *self.limit_window.read().await;
let now = Instant::now();
let recent_ips = self.recent_ips.read().await;
@ -250,11 +318,13 @@ impl UserIpTracker {
}
pub async fn get_active_ip_count(&self, username: &str) -> usize {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips.get(username).map(|ips| ips.len()).unwrap_or(0)
}
pub async fn get_active_ips(&self, username: &str) -> Vec<IpAddr> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips
.get(username)
@ -263,6 +333,7 @@ impl UserIpTracker {
}
pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
let max_ips = self.max_ips.read().await;
let default_max_ips = *self.default_max_ips.read().await;
@ -301,6 +372,7 @@ impl UserIpTracker {
}
pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips
.get(username)

View File

@ -0,0 +1,619 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::Duration;
use crate::config::UserMaxUniqueIpsMode;
use crate::ip_tracker::UserIpTracker;
/// Deterministically map a test index onto an address in 10.0.0.0/8.
/// The low 24 bits of `idx` become the last three octets.
fn ip_from_idx(idx: u32) -> IpAddr {
    let octets = [
        10u8,
        ((idx >> 16) & 0xFF) as u8,
        ((idx >> 8) & 0xFF) as u8,
        (idx & 0xFF) as u8,
    ];
    IpAddr::V4(Ipv4Addr::from(octets))
}
// ActiveWindow mode: a burst of unique IPs fills the per-user limit exactly;
// the next unique IP is rejected and the active count stays at the limit.
#[tokio::test]
async fn active_window_enforces_large_unique_ip_burst() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("burst_user", 64).await;
    tracker
        .set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
        .await;
    for idx in 0..64 {
        assert!(tracker.check_and_add("burst_user", ip_from_idx(idx)).await.is_ok());
    }
    assert!(tracker.check_and_add("burst_user", ip_from_idx(9_999)).await.is_err());
    assert_eq!(tracker.get_active_ip_count("burst_user").await, 64);
}

// A global default limit (3 here) is enforced independently per user: the
// fourth unique IP fails for every one of 150 users.
#[tokio::test]
async fn global_limit_applies_across_many_users() {
    let tracker = UserIpTracker::new();
    tracker.load_limits(3, &HashMap::new()).await;
    for user_idx in 0..150u32 {
        let user = format!("u{}", user_idx);
        assert!(tracker.check_and_add(&user, ip_from_idx(user_idx * 10)).await.is_ok());
        assert!(tracker
            .check_and_add(&user, ip_from_idx(user_idx * 10 + 1))
            .await
            .is_ok());
        assert!(tracker
            .check_and_add(&user, ip_from_idx(user_idx * 10 + 2))
            .await
            .is_ok());
        assert!(tracker
            .check_and_add(&user, ip_from_idx(user_idx * 10 + 3))
            .await
            .is_err());
    }
    assert_eq!(tracker.get_stats().await.len(), 150);
}

// A per-user override of 0 falls back to the global limit: enforcement and
// get_user_limit both reflect the effective global value (2), not "unlimited".
#[tokio::test]
async fn user_zero_override_falls_back_to_global_limit() {
    let tracker = UserIpTracker::new();
    let mut limits = HashMap::new();
    limits.insert("target".to_string(), 0);
    tracker.load_limits(2, &limits).await;
    assert!(tracker.check_and_add("target", ip_from_idx(1)).await.is_ok());
    assert!(tracker.check_and_add("target", ip_from_idx(2)).await.is_ok());
    assert!(tracker.check_and_add("target", ip_from_idx(3)).await.is_err());
    assert_eq!(tracker.get_user_limit("target").await, Some(2));
}

// remove_ip must be a no-op once the per-IP refcount reaches zero: repeated
// removals never drive counts negative or resurrect state.
#[tokio::test]
async fn remove_ip_is_idempotent_after_counter_reaches_zero() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("u", 2).await;
    let ip = ip_from_idx(42);
    tracker.check_and_add("u", ip).await.unwrap();
    tracker.remove_ip("u", ip).await;
    tracker.remove_ip("u", ip).await;
    tracker.remove_ip("u", ip).await;
    assert_eq!(tracker.get_active_ip_count("u").await, 0);
    assert!(!tracker.is_ip_active("u", ip).await);
}
// clear_user_ips wipes both the active set and the recent-window counts for
// that single user.
#[tokio::test]
async fn clear_user_ips_resets_active_and_recent() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("u", 10).await;
    for idx in 0..6 {
        tracker.check_and_add("u", ip_from_idx(idx)).await.unwrap();
    }
    tracker.clear_user_ips("u").await;
    assert_eq!(tracker.get_active_ip_count("u").await, 0);
    let counts = tracker
        .get_recent_counts_for_users(&["u".to_string()])
        .await;
    assert_eq!(counts.get("u").copied().unwrap_or(0), 0);
}

// clear_all resets every user at once: stats become empty and recent counts
// for all 80 users drop to zero.
#[tokio::test]
async fn clear_all_resets_multi_user_state() {
    let tracker = UserIpTracker::new();
    for user_idx in 0..80u32 {
        let user = format!("u{}", user_idx);
        for ip_idx in 0..3 {
            tracker
                .check_and_add(&user, ip_from_idx(user_idx * 100 + ip_idx))
                .await
                .unwrap();
        }
    }
    tracker.clear_all().await;
    assert!(tracker.get_stats().await.is_empty());
    let users = (0..80u32)
        .map(|idx| format!("u{}", idx))
        .collect::<Vec<_>>();
    let recent = tracker.get_recent_counts_for_users(&users).await;
    assert!(recent.values().all(|count| *count == 0));
}

// Active-IP listings are returned in sorted order regardless of insertion
// order (.9, .1, .5 inserted -> .1, .5, .9 returned).
#[tokio::test]
async fn get_active_ips_for_users_are_sorted() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("user", 10).await;
    tracker
        .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
        .await
        .unwrap();
    tracker
        .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
        .await
        .unwrap();
    tracker
        .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
        .await
        .unwrap();
    let map = tracker
        .get_active_ips_for_users(&["user".to_string()])
        .await;
    let ips = map.get("user").cloned().unwrap_or_default();
    assert_eq!(
        ips,
        vec![
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)),
        ]
    );
}
// Recent-window IP listings are also returned sorted, mirroring the active
// listing behavior above.
#[tokio::test]
async fn get_recent_ips_for_users_are_sorted() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("user", 10).await;
    tracker
        .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)))
        .await
        .unwrap();
    tracker
        .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)))
        .await
        .unwrap();
    tracker
        .check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)))
        .await
        .unwrap();
    let map = tracker
        .get_recent_ips_for_users(&["user".to_string()])
        .await;
    let ips = map.get("user").cloned().unwrap_or_default();
    assert_eq!(
        ips,
        vec![
            IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)),
            IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)),
        ]
    );
}

// TimeWindow mode (1s window): a disconnected IP still counts against the
// limit until the window elapses; afterwards a new IP is admitted.
// NOTE: uses a real 1.1s sleep, so this test is intentionally wall-clock slow.
#[tokio::test]
async fn time_window_expires_for_large_rotation() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("tw", 1).await;
    tracker
        .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
        .await;
    tracker.check_and_add("tw", ip_from_idx(1)).await.unwrap();
    tracker.remove_ip("tw", ip_from_idx(1)).await;
    assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_err());
    tokio::time::sleep(Duration::from_millis(1_100)).await;
    assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_ok());
}

// Combined mode: even after disconnect, the recently-seen IP blocks a second
// unique IP while the window is still open.
#[tokio::test]
async fn combined_mode_blocks_recent_after_disconnect() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("cmb", 1).await;
    tracker
        .set_limit_policy(UserMaxUniqueIpsMode::Combined, 2)
        .await;
    tracker.check_and_add("cmb", ip_from_idx(11)).await.unwrap();
    tracker.remove_ip("cmb", ip_from_idx(11)).await;
    assert!(tracker.check_and_add("cmb", ip_from_idx(12)).await.is_err());
}

// load_limits fully replaces the previous override map: users only in the
// first map lose their override (None); users in the second map get the new
// value, including users present in both.
#[tokio::test]
async fn load_limits_replaces_large_limit_map() {
    let tracker = UserIpTracker::new();
    let mut first = HashMap::new();
    let mut second = HashMap::new();
    for idx in 0..300usize {
        first.insert(format!("u{}", idx), 2usize);
    }
    for idx in 150..450usize {
        second.insert(format!("u{}", idx), 4usize);
    }
    tracker.load_limits(0, &first).await;
    tracker.load_limits(0, &second).await;
    assert_eq!(tracker.get_user_limit("u20").await, None);
    assert_eq!(tracker.get_user_limit("u200").await, Some(4));
    assert_eq!(tracker.get_user_limit("u420").await, Some(4));
}
// 16 tasks hammer one user with 3,200 distinct IPs on a multi-thread runtime;
// the active count must never exceed the limit of 32.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_same_user_unique_ip_pressure_stays_bounded() {
    let tracker = Arc::new(UserIpTracker::new());
    tracker.set_user_limit("hot", 32).await;
    tracker
        .set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
        .await;
    let mut handles = Vec::new();
    for worker in 0..16u32 {
        let tracker_cloned = tracker.clone();
        handles.push(tokio::spawn(async move {
            let base = worker * 200;
            for step in 0..200u32 {
                // Result intentionally ignored: rejections past the limit are expected.
                let _ = tracker_cloned
                    .check_and_add("hot", ip_from_idx(base + step))
                    .await;
            }
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
    assert!(tracker.get_active_ip_count("hot").await <= 32);
}

// 120 users concurrently add IPs under a global limit of 4; each user's
// active count stays bounded and each reports the global limit.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_many_users_isolate_limits() {
    let tracker = Arc::new(UserIpTracker::new());
    tracker.load_limits(4, &HashMap::new()).await;
    let mut handles = Vec::new();
    for user_idx in 0..120u32 {
        let tracker_cloned = tracker.clone();
        handles.push(tokio::spawn(async move {
            let user = format!("u{}", user_idx);
            for ip_idx in 0..10u32 {
                let _ = tracker_cloned
                    .check_and_add(&user, ip_from_idx(user_idx * 1_000 + ip_idx))
                    .await;
            }
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
    let stats = tracker.get_stats().await;
    assert_eq!(stats.len(), 120);
    assert!(stats.iter().all(|(_, active, limit)| *active <= 4 && *limit == 4));
}

// Re-adding the same IP 2,000 times counts as one unique active IP, not many.
#[tokio::test]
async fn same_ip_reconnect_high_frequency_keeps_single_unique() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("same", 2).await;
    let ip = ip_from_idx(9);
    for _ in 0..2_000 {
        tracker.check_and_add("same", ip).await.unwrap();
    }
    assert_eq!(tracker.get_active_ip_count("same").await, 1);
    assert!(tracker.is_ip_active("same", ip).await);
}

// format_stats text output mentions both a limited and an unlimited user,
// and contains the literal "unlimited" marker.
#[tokio::test]
async fn format_stats_contains_expected_limited_and_unlimited_markers() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("limited", 2).await;
    tracker.check_and_add("limited", ip_from_idx(1)).await.unwrap();
    tracker.check_and_add("open", ip_from_idx(2)).await.unwrap();
    let text = tracker.format_stats().await;
    assert!(text.contains("limited"));
    assert!(text.contains("open"));
    assert!(text.contains("unlimited"));
}
// Users without a per-user override report the global default (5) in stats.
#[tokio::test]
async fn stats_report_global_default_for_users_without_override() {
    let tracker = UserIpTracker::new();
    tracker.load_limits(5, &HashMap::new()).await;
    tracker.check_and_add("a", ip_from_idx(1)).await.unwrap();
    tracker.check_and_add("b", ip_from_idx(2)).await.unwrap();
    let stats = tracker.get_stats().await;
    assert!(stats.iter().any(|(user, _, limit)| user == "a" && *limit == 5));
    assert!(stats.iter().any(|(user, _, limit)| user == "b" && *limit == 5));
}

// 50 full add -> remove -> clear cycles per distinct user must leave the
// tracker with no residual per-user stats entries.
#[tokio::test]
async fn stress_cycle_add_remove_clear_preserves_empty_end_state() {
    let tracker = UserIpTracker::new();
    for cycle in 0..50u32 {
        let user = format!("cycle{}", cycle);
        tracker.set_user_limit(&user, 128).await;
        for ip_idx in 0..128u32 {
            tracker
                .check_and_add(&user, ip_from_idx(cycle * 10_000 + ip_idx))
                .await
                .unwrap();
        }
        for ip_idx in 0..128u32 {
            tracker
                .remove_ip(&user, ip_from_idx(cycle * 10_000 + ip_idx))
                .await;
        }
        tracker.clear_user_ips(&user).await;
    }
    assert!(tracker.get_stats().await.is_empty());
}

// Removing an unknown user or an IP the user never had must not disturb the
// user's existing active entries.
#[tokio::test]
async fn remove_unknown_user_or_ip_does_not_corrupt_state() {
    let tracker = UserIpTracker::new();
    tracker.remove_ip("no_user", ip_from_idx(1)).await;
    tracker.check_and_add("x", ip_from_idx(2)).await.unwrap();
    tracker.remove_ip("x", ip_from_idx(3)).await;
    assert_eq!(tracker.get_active_ip_count("x").await, 1);
    assert!(tracker.is_ip_active("x", ip_from_idx(2)).await);
}

// After 12 adds and 6 removals: exactly 6 active; the recent-window count sits
// between the active count and the total ever seen (6..=12).
#[tokio::test]
async fn active_and_recent_views_match_after_mixed_workload() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("mix", 16).await;
    for ip_idx in 0..12u32 {
        tracker.check_and_add("mix", ip_from_idx(ip_idx)).await.unwrap();
    }
    for ip_idx in 0..6u32 {
        tracker.remove_ip("mix", ip_from_idx(ip_idx)).await;
    }
    let active = tracker
        .get_active_ips_for_users(&["mix".to_string()])
        .await
        .get("mix")
        .cloned()
        .unwrap_or_default();
    let recent_count = tracker
        .get_recent_counts_for_users(&["mix".to_string()])
        .await
        .get("mix")
        .copied()
        .unwrap_or(0);
    assert_eq!(active.len(), 6);
    assert!(recent_count >= active.len());
    assert!(recent_count <= 12);
}

// Reloading a larger global limit (2 -> 4) takes effect on the very next
// check_and_add calls after the user's state is cleared.
#[tokio::test]
async fn global_limit_switch_updates_enforcement_immediately() {
    let tracker = UserIpTracker::new();
    tracker.load_limits(2, &HashMap::new()).await;
    assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
    assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
    assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_err());
    tracker.clear_user_ips("u").await;
    tracker.load_limits(4, &HashMap::new()).await;
    assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
    assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
    assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_ok());
    assert!(tracker.check_and_add("u", ip_from_idx(4)).await.is_ok());
    assert!(tracker.check_and_add("u", ip_from_idx(5)).await.is_err());
}
// 8 tasks each loop add+remove on their own IP 500 times; the final active
// count must stay within the limit (and, implicitly, never underflow).
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() {
    let tracker = Arc::new(UserIpTracker::new());
    tracker.set_user_limit("cc", 8).await;
    let mut handles = Vec::new();
    for worker in 0..8u32 {
        let tracker_cloned = tracker.clone();
        handles.push(tokio::spawn(async move {
            let ip = ip_from_idx(50 + worker);
            for _ in 0..500u32 {
                let _ = tracker_cloned.check_and_add("cc", ip).await;
                tracker_cloned.remove_ip("cc", ip).await;
            }
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
    assert!(tracker.get_active_ip_count("cc").await <= 8);
}

// Exercises enqueue_cleanup's Err(poisoned) arm: a deliberately poisoned queue
// lock must be recovered and the entry still enqueued and drainable.
#[tokio::test]
async fn enqueue_cleanup_recovers_from_poisoned_mutex() {
    let tracker = UserIpTracker::new();
    let ip = ip_from_idx(99);
    // Poison the lock by panicking while holding it
    let result = std::panic::catch_unwind(|| {
        let _guard = tracker.cleanup_queue.lock().unwrap();
        panic!("Intentional poison panic");
    });
    assert!(result.is_err(), "Expected panic to poison mutex");
    // Attempt to enqueue anyway; should hit the poison catch arm and still insert
    tracker.enqueue_cleanup("poison-user".to_string(), ip);
    tracker.drain_cleanup_queue().await;
    assert_eq!(tracker.get_active_ip_count("poison-user").await, 0);
}

// 10,000 concurrent add+enqueue-cleanup cycles on one IP; after a final drain
// the active count must be exactly zero (no leaked reservations).
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn mass_reconnect_sync_cleanup_prevents_temporary_reservation_bloat() {
    // Tests that synchronous M-01 drop mechanism protects against starvation
    let tracker = Arc::new(UserIpTracker::new());
    tracker.set_user_limit("mass", 5).await;
    let ip = ip_from_idx(42);
    let mut join_handles = Vec::new();
    // 10,000 rapid concurrent requests hitting the same IP limit
    for _ in 0..10_000 {
        let tracker_clone = tracker.clone();
        join_handles.push(tokio::spawn(async move {
            if tracker_clone.check_and_add("mass", ip).await.is_ok() {
                // Instantly enqueue cleanup, simulating synchronous reservation drop
                tracker_clone.enqueue_cleanup("mass".to_string(), ip);
                // The next caller will drain it before acquiring again
            }
        }));
    }
    for handle in join_handles {
        let _ = handle.await;
    }
    // Force flush
    tracker.drain_cleanup_queue().await;
    assert_eq!(tracker.get_active_ip_count("mass").await, 0, "No leaked footprints");
}
// Races an explicit drain against check_and_add for 100 rounds; the queued
// cleanup of the previous IP must never cause the new IP to be denied.
#[tokio::test]
async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() {
    // Regression guard: concurrent cleanup draining must not produce false
    // limit denials for a new IP when the previous IP is already queued.
    let tracker = Arc::new(UserIpTracker::new());
    tracker.set_user_limit("racer", 1).await;
    let ip1 = ip_from_idx(1);
    let ip2 = ip_from_idx(2);
    // Initial state: add ip1
    tracker.check_and_add("racer", ip1).await.unwrap();
    // User disconnects from ip1, queuing it
    tracker.enqueue_cleanup("racer".to_string(), ip1);
    let mut saw_false_rejection = false;
    for _ in 0..100 {
        // Queue cleanup then race explicit drain and check-and-add on the alternative IP.
        tracker.enqueue_cleanup("racer".to_string(), ip1);
        let tracker_a = tracker.clone();
        let tracker_b = tracker.clone();
        let drain_handle = tokio::spawn(async move {
            tracker_a.drain_cleanup_queue().await;
        });
        let handle = tokio::spawn(async move {
            tracker_b.check_and_add("racer", ip2).await
        });
        drain_handle.await.unwrap();
        let res = handle.await.unwrap();
        if res.is_err() {
            saw_false_rejection = true;
            break;
        }
        // Restore baseline for next iteration.
        tracker.remove_ip("racer", ip2).await;
        tracker.check_and_add("racer", ip1).await.unwrap();
    }
    assert!(
        !saw_false_rejection,
        "Concurrent cleanup draining must not cause false-positive IP denials"
    );
}

// Even with a poisoned queue lock, the disconnect path must still free the
// single slot so a different IP can be admitted next.
#[tokio::test]
async fn poisoned_cleanup_queue_still_releases_slot_for_next_ip() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("poison-slot", 1).await;
    let ip1 = ip_from_idx(7001);
    let ip2 = ip_from_idx(7002);
    tracker.check_and_add("poison-slot", ip1).await.unwrap();
    // Poison the queue lock as an adversarial condition.
    let _ = std::panic::catch_unwind(|| {
        let _guard = tracker.cleanup_queue.lock().unwrap();
        panic!("intentional queue poison");
    });
    // Disconnect path must still queue cleanup so the next IP can be admitted.
    tracker.enqueue_cleanup("poison-slot".to_string(), ip1);
    let admitted = tracker.check_and_add("poison-slot", ip2).await;
    assert!(
        admitted.is_ok(),
        "cleanup queue poison must not permanently block slot release for the next IP"
    );
}

// Triple-queuing the same cleanup entry must not over-decrement: the slot is
// freed once and future admissions still work.
#[tokio::test]
async fn duplicate_cleanup_entries_do_not_break_future_admission() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("dup-cleanup", 1).await;
    let ip1 = ip_from_idx(7101);
    let ip2 = ip_from_idx(7102);
    tracker.check_and_add("dup-cleanup", ip1).await.unwrap();
    tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
    tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
    tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
    tracker.drain_cleanup_queue().await;
    assert_eq!(tracker.get_active_ip_count("dup-cleanup").await, 0);
    assert!(
        tracker.check_and_add("dup-cleanup", ip2).await.is_ok(),
        "extra queued cleanup entries must not leave user stuck in denied state"
    );
}

// 64 rounds of poison -> enqueue -> admit-alternative -> restore; recovery
// from repeated poisoning must never stall admission progress.
#[tokio::test]
async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("poison-stress", 1).await;
    let ip_primary = ip_from_idx(7201);
    let ip_alt = ip_from_idx(7202);
    tracker.check_and_add("poison-stress", ip_primary).await.unwrap();
    for _ in 0..64 {
        let _ = std::panic::catch_unwind(|| {
            let _guard = tracker.cleanup_queue.lock().unwrap();
            panic!("intentional queue poison in stress loop");
        });
        tracker.enqueue_cleanup("poison-stress".to_string(), ip_primary);
        assert!(
            tracker.check_and_add("poison-stress", ip_alt).await.is_ok(),
            "poison recovery must preserve admission progress under repeated queue poisoning"
        );
        tracker.remove_ip("poison-stress", ip_alt).await;
        tracker.check_and_add("poison-stress", ip_primary).await.unwrap();
    }
}

View File

@ -279,11 +279,32 @@ pub(crate) async fn spawn_metrics_if_configured(
ip_tracker: Arc<UserIpTracker>,
config_rx: watch::Receiver<Arc<ProxyConfig>>,
) {
if let Some(port) = config.server.metrics_port {
// metrics_listen takes precedence; fall back to metrics_port for backward compat.
let metrics_target: Option<(u16, Option<String>)> =
if let Some(ref listen) = config.server.metrics_listen {
match listen.parse::<std::net::SocketAddr>() {
Ok(addr) => Some((addr.port(), Some(listen.clone()))),
Err(e) => {
startup_tracker
.skip_component(
COMPONENT_METRICS_START,
Some(format!("invalid metrics_listen \"{}\": {}", listen, e)),
)
.await;
None
}
}
} else {
config.server.metrics_port.map(|p| (p, None))
};
if let Some((port, listen)) = metrics_target {
let fallback_label = format!("port {}", port);
let label = listen.as_deref().unwrap_or(&fallback_label);
startup_tracker
.start_component(
COMPONENT_METRICS_START,
Some(format!("spawn metrics endpoint on {}", port)),
Some(format!("spawn metrics endpoint on {}", label)),
)
.await;
let stats = stats.clone();
@ -294,6 +315,7 @@ pub(crate) async fn spawn_metrics_if_configured(
tokio::spawn(async move {
metrics::serve(
port,
listen,
stats,
beobachten,
ip_tracker_metrics,
@ -308,7 +330,7 @@ pub(crate) async fn spawn_metrics_if_configured(
Some("metrics task spawned".to_string()),
)
.await;
} else {
} else if config.server.metrics_listen.is_none() {
startup_tracker
.skip_component(
COMPONENT_METRICS_START,

View File

@ -6,6 +6,8 @@ mod config;
mod crypto;
mod error;
mod ip_tracker;
#[cfg(test)]
mod ip_tracker_regression_tests;
mod maestro;
mod metrics;
mod network;

View File

@ -21,6 +21,7 @@ use crate::transport::{ListenOptions, create_listener};
pub async fn serve(
port: u16,
listen: Option<String>,
stats: Arc<Stats>,
beobachten: Arc<BeobachtenStore>,
ip_tracker: Arc<UserIpTracker>,
@ -28,6 +29,33 @@ pub async fn serve(
whitelist: Vec<IpNetwork>,
) {
let whitelist = Arc::new(whitelist);
// If `metrics_listen` is set, bind on that single address only.
if let Some(ref listen_addr) = listen {
let addr: SocketAddr = match listen_addr.parse() {
Ok(a) => a,
Err(e) => {
warn!(error = %e, "Invalid metrics_listen address: {}", listen_addr);
return;
}
};
let is_ipv6 = addr.is_ipv6();
match bind_metrics_listener(addr, is_ipv6) {
Ok(listener) => {
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr);
serve_listener(
listener, stats, beobachten, ip_tracker, config_rx, whitelist,
)
.await;
}
Err(e) => {
warn!(error = %e, "Failed to bind metrics on {}", addr);
}
}
return;
}
// Fallback: bind on 0.0.0.0 and [::] using metrics_port.
let mut listener_v4 = None;
let mut listener_v6 = None;

View File

@ -11,9 +11,8 @@ use crate::crypto::{sha256_hmac, SecureRandom};
use crate::error::ProxyError;
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH};
use num_bigint::BigUint;
use num_traits::One;
use subtle::ConstantTimeEq;
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
// ============= Public Constants =============
@ -27,10 +26,17 @@ pub const TLS_DIGEST_POS: usize = 11;
pub const TLS_DIGEST_HALF_LEN: usize = 16;
/// Time skew limits for anti-replay (in seconds)
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
///
/// The default window is intentionally narrow to reduce replay acceptance.
/// Operators with known clock-drifted clients should tune deployment config
/// (for example replay-window policy) to match their environment.
pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance
/// windows when replay TTL is configured very large.
pub const BOOT_TIME_COMPAT_MAX_SECS: u32 = 2 * 60;
// ============= Private Constants =============
@ -63,6 +69,7 @@ pub struct TlsValidation {
/// Client digest for response generation
pub digest: [u8; TLS_DIGEST_LEN],
/// Timestamp extracted from digest
pub timestamp: u32,
}
@ -117,28 +124,8 @@ impl TlsExtensionBuilder {
self
}
/// Add ALPN extension with a single selected protocol.
fn add_alpn(&mut self, proto: &[u8]) -> &mut Self {
// Extension type: ALPN (0x0010)
self.extensions.extend_from_slice(&extension_type::ALPN.to_be_bytes());
// ALPN extension format:
// extension_data length (2 bytes)
// protocols length (2 bytes)
// protocol name length (1 byte)
// protocol name bytes
let proto_len = proto.len() as u8;
let list_len: u16 = 1 + u16::from(proto_len);
let ext_len: u16 = 2 + list_len;
self.extensions.extend_from_slice(&ext_len.to_be_bytes());
self.extensions.extend_from_slice(&list_len.to_be_bytes());
self.extensions.push(proto_len);
self.extensions.extend_from_slice(proto);
self
}
/// Build final extensions with length prefix
fn build(self) -> Vec<u8> {
let mut result = Vec::with_capacity(2 + self.extensions.len());
@ -153,7 +140,7 @@ impl TlsExtensionBuilder {
}
/// Get current extensions without length prefix (for calculation)
#[allow(dead_code)]
fn as_bytes(&self) -> &[u8] {
&self.extensions
}
@ -173,8 +160,6 @@ struct ServerHelloBuilder {
compression: u8,
/// Extensions
extensions: TlsExtensionBuilder,
/// Selected ALPN protocol (if any)
alpn: Option<Vec<u8>>,
}
impl ServerHelloBuilder {
@ -185,7 +170,6 @@ impl ServerHelloBuilder {
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
compression: 0x00,
extensions: TlsExtensionBuilder::new(),
alpn: None,
}
}
@ -200,18 +184,9 @@ impl ServerHelloBuilder {
self
}
fn with_alpn(mut self, proto: Option<Vec<u8>>) -> Self {
self.alpn = proto;
self
}
/// Build ServerHello message (without record header)
fn build_message(&self) -> Vec<u8> {
let mut ext_builder = self.extensions.clone();
if let Some(ref alpn) = self.alpn {
ext_builder.add_alpn(alpn);
}
let extensions = ext_builder.extensions.clone();
let extensions = self.extensions.extensions.clone();
let extensions_len = extensions.len() as u16;
// Calculate total length
@ -281,6 +256,7 @@ impl ServerHelloBuilder {
/// Returns validation result if a matching user is found.
/// The result **must** be used — ignoring it silently bypasses authentication.
#[must_use]
pub fn validate_tls_handshake(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
@ -296,9 +272,9 @@ pub fn validate_tls_handshake(
/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL.
///
/// A boot-time timestamp is only accepted when it falls below both
/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp
/// reuse outside replay cache coverage.
/// A boot-time timestamp is only accepted when it falls below all three
/// bounds: `BOOT_TIME_MAX_SECS`, configured replay window, and
/// `BOOT_TIME_COMPAT_MAX_SECS`, preventing oversized compatibility windows.
#[must_use]
pub fn validate_tls_handshake_with_replay_window(
handshake: &[u8],
@ -316,7 +292,16 @@ pub fn validate_tls_handshake_with_replay_window(
};
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32);
// Boot-time bypass and ignore_time_skew serve different compatibility paths.
// When skew checks are disabled, force boot-time cap to zero to prevent
// accidental future coupling of boot-time logic into the ignore-skew path.
let boot_time_cap_secs = if ignore_time_skew {
0
} else {
BOOT_TIME_MAX_SECS
.min(replay_window_u32)
.min(BOOT_TIME_COMPAT_MAX_SECS)
};
validate_tls_handshake_at_time_with_boot_cap(
handshake,
@ -335,6 +320,7 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
i64::try_from(d.as_secs()).ok()
}
fn validate_tls_handshake_at_time(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
@ -369,6 +355,9 @@ fn validate_tls_handshake_at_time_with_boot_cap(
// Extract session ID
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
if session_id_len > 32 {
return None;
}
let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len {
@ -411,7 +400,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
if !ignore_time_skew {
// Allow very small timestamps (boot time instead of unix time)
// This is a quirk in some clients that use uptime instead of real time
let is_boot_time = timestamp < boot_time_cap_secs;
let is_boot_time = boot_time_cap_secs > 0 && timestamp < boot_time_cap_secs;
if !is_boot_time {
let time_diff = now - i64::from(timestamp);
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
@ -433,27 +422,14 @@ fn validate_tls_handshake_at_time_with_boot_cap(
})
}
fn curve25519_prime() -> BigUint {
(BigUint::one() << 255) - BigUint::from(19u32)
}
/// Generate a fake X25519 public key for TLS
///
/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p,
/// which matches Python/C behavior and avoids DPI fingerprinting.
/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint,
/// yielding distribution-consistent public keys for anti-fingerprinting.
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
    // RFC 7748 scalar multiplication over the canonical basepoint. Clamping is
    // performed inside `x25519`, so any 32 random bytes form a valid scalar and
    // the resulting public key is distributed like a genuine client share,
    // avoiding DPI fingerprinting of the key material.
    let mut scalar = [0u8; 32];
    scalar.copy_from_slice(&rng.bytes(32));
    x25519(scalar, X25519_BASEPOINT_BYTES)
}
/// Build TLS ServerHello response
@ -482,7 +458,6 @@ pub fn build_server_hello(
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
.with_x25519_key(&x25519_key)
.with_tls13_version()
.with_alpn(alpn)
.build_record();
// Build Change Cipher Spec record
@ -493,8 +468,27 @@ pub fn build_server_hello(
0x01, // CCS byte
];
// Build fake certificate (Application Data record)
let fake_cert = rng.bytes(fake_cert_len);
// Build first encrypted flight mimic as opaque ApplicationData bytes.
// Embed a compact EncryptedExtensions-like ALPN block when selected.
let mut fake_cert = Vec::with_capacity(fake_cert_len);
if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) {
let proto_list_len = 1usize + proto.len();
let ext_data_len = 2usize + proto_list_len;
let marker_len = 4usize + ext_data_len;
if marker_len <= fake_cert_len {
fake_cert.extend_from_slice(&0x0010u16.to_be_bytes());
fake_cert.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
fake_cert.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
fake_cert.push(proto.len() as u8);
fake_cert.extend_from_slice(proto);
}
}
if fake_cert.len() < fake_cert_len {
fake_cert.extend_from_slice(&rng.bytes(fake_cert_len - fake_cert.len()));
} else if fake_cert.len() > fake_cert_len {
fake_cert.truncate(fake_cert_len);
}
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
app_data_record.push(TLS_RECORD_APPLICATION);
app_data_record.extend_from_slice(&TLS_VERSION);
@ -506,8 +500,9 @@ pub fn build_server_hello(
// Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted;
// here we mimic with opaque ApplicationData records of plausible size).
let mut tickets = Vec::new();
if new_session_tickets > 0 {
for _ in 0..new_session_tickets {
let ticket_count = new_session_tickets.min(4);
if ticket_count > 0 {
for _ in 0..ticket_count {
let ticket_len: usize = rng.range(48) + 48; // 48-95 bytes
let mut record = Vec::with_capacity(5 + ticket_len);
record.push(TLS_RECORD_APPLICATION);
@ -705,13 +700,14 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
return false;
}
// TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0)
// TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303.
first_bytes[0] == TLS_RECORD_HANDSHAKE
&& first_bytes[1] == 0x03
&& first_bytes[2] == 0x01
&& (first_bytes[2] == 0x01 || first_bytes[2] == 0x03)
}
/// Parse TLS record header, returns (record_type, length)
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
let record_type = header[0];
let version = [header[1], header[2]];

File diff suppressed because it is too large Load Diff

View File

@ -24,6 +24,47 @@ enum HandshakeOutcome {
Handled,
}
/// RAII guard for one admitted user connection.
///
/// Created only after the per-user connection counter was incremented and the
/// peer IP was registered with the tracker; it owns the responsibility of
/// undoing both. `release()` is the preferred async teardown; `Drop` is the
/// fallback for abnormal paths and cannot await, so it enqueues the IP
/// cleanup instead of performing it inline.
#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"]
struct UserConnectionReservation {
    stats: Arc<Stats>,
    ip_tracker: Arc<UserIpTracker>,
    // User name whose counters/IP entry this reservation holds.
    user: String,
    // Peer IP registered in the tracker for this connection.
    ip: IpAddr,
    // False once released, so Drop becomes a no-op.
    active: bool,
}

impl UserConnectionReservation {
    /// Wrap an already-acquired counter increment + IP registration.
    fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
        Self {
            stats,
            ip_tracker,
            user,
            ip,
            active: true,
        }
    }

    /// Async teardown: remove the IP entry, then mark inactive and decrement
    /// the user's connection counter. Consumes self so it cannot be reused.
    async fn release(mut self) {
        if !self.active {
            return;
        }
        self.ip_tracker.remove_ip(&self.user, self.ip).await;
        // Marked inactive only after the await, so Drop still cleans up if
        // this future is cancelled mid-removal.
        self.active = false;
        self.stats.decrement_user_curr_connects(&self.user);
    }
}

impl Drop for UserConnectionReservation {
    fn drop(&mut self) {
        if !self.active {
            return;
        }
        self.active = false;
        self.stats.decrement_user_curr_connects(&self.user);
        // Drop cannot await; hand the IP removal to the tracker's queue.
        self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
    }
}
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
@ -45,7 +86,19 @@ use crate::proxy::middle_relay::handle_via_middle_proxy;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
/// TTL for beobachten entries, derived from `general.beobachten_minutes`.
///
/// A configured value of 0 would make entries expire immediately, silently
/// disabling the store; it is clamped to one minute and warned about exactly
/// once per process. Large values are multiplied with saturation to avoid
/// overflow.
fn beobachten_ttl(config: &ProxyConfig) -> Duration {
    let minutes = config.general.beobachten_minutes;
    if minutes == 0 {
        static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock<AtomicBool> = OnceLock::new();
        let warned = BEOBACHTEN_ZERO_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false));
        // `swap` returns the previous value, so only the first caller logs.
        if !warned.swap(true, Ordering::Relaxed) {
            warn!(
                "general.beobachten_minutes=0 is insecure because entries expire immediately; forcing minimum TTL to 1 minute"
            );
        }
        return Duration::from_secs(60);
    }
    Duration::from_secs(minutes.saturating_mul(60))
}
fn record_beobachten_class(
@ -90,6 +143,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
trusted.iter().any(|cidr| cidr.contains(peer_ip))
}
/// Placeholder local address (`0.0.0.0:port`) for streams that have no real
/// local socket address; may be overridden later (e.g. by PROXY protocol).
fn synthetic_local_addr(port: u16) -> SocketAddr {
    SocketAddr::new(
        std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
        port,
    )
}
pub async fn handle_client_stream<S>(
mut stream: S,
peer: SocketAddr,
@ -113,9 +170,7 @@ where
let mut real_peer = normalize_ip(peer);
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
let mut local_addr = synthetic_local_addr(config.server.port);
if proxy_protocol_enabled {
let proxy_header_timeout = Duration::from_millis(
@ -245,7 +300,7 @@ where
handle_bad_client(
reader,
writer,
&mtproto_handshake,
&handshake,
real_peer,
local_addr,
&config,
@ -426,7 +481,6 @@ impl RunningClientHandler {
pub async fn run(self) -> Result<()> {
self.stats.increment_connects_all();
let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone();
debug!(peer = %peer, "New connection");
if let Err(e) = configure_client_socket(
@ -557,7 +611,6 @@ impl RunningClientHandler {
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone();
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
@ -570,7 +623,6 @@ impl RunningClientHandler {
async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone();
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
@ -661,7 +713,7 @@ impl RunningClientHandler {
handle_bad_client(
reader,
writer,
&mtproto_handshake,
&handshake,
peer,
local_addr,
&config,
@ -694,7 +746,6 @@ impl RunningClientHandler {
async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone();
if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
@ -798,10 +849,22 @@ impl RunningClientHandler {
{
let user = success.user.clone();
if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await {
warn!(user = %user, error = %e, "User limit exceeded");
return Err(e);
}
let user_limit_reservation =
match Self::acquire_user_connection_reservation_static(
&user,
&config,
stats.clone(),
peer_addr,
ip_tracker,
)
.await
{
Ok(reservation) => reservation,
Err(e) => {
warn!(user = %user, error = %e, "User admission check failed");
return Err(e);
}
};
let route_snapshot = route_runtime.snapshot();
let session_id = rng.u64();
@ -858,15 +921,68 @@ impl RunningClientHandler {
)
.await
};
stats.decrement_user_curr_connects(&user);
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
user_limit_reservation.release().await;
relay_result
}
/// Admission check for an authenticated user: expiry, data quota, connection
/// cap, then per-IP limit, in that order.
///
/// On success returns a [`UserConnectionReservation`] owning the acquired
/// connection slot and IP registration; dropping or releasing it undoes both.
/// On failure nothing stays acquired — the connection counter incremented by
/// `try_acquire_user_curr_connects` is rolled back if the IP check rejects.
async fn acquire_user_connection_reservation_static(
    user: &str,
    config: &ProxyConfig,
    stats: Arc<Stats>,
    peer_addr: SocketAddr,
    ip_tracker: Arc<UserIpTracker>,
) -> Result<UserConnectionReservation> {
    // Account expired?
    if let Some(expiration) = config.access.user_expirations.get(user)
        && chrono::Utc::now() > *expiration
    {
        return Err(ProxyError::UserExpired {
            user: user.to_string(),
        });
    }
    // Data quota already consumed?
    if let Some(quota) = config.access.user_data_quota.get(user)
        && stats.get_user_total_octets(user) >= *quota
    {
        return Err(ProxyError::DataQuotaExceeded {
            user: user.to_string(),
        });
    }
    // Atomically take a connection slot; `None` limit means unlimited.
    let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
    if !stats.try_acquire_user_curr_connects(user, limit) {
        return Err(ProxyError::ConnectionLimitExceeded {
            user: user.to_string(),
        });
    }
    match ip_tracker.check_and_add(user, peer_addr.ip()).await {
        Ok(()) => {}
        Err(reason) => {
            // Roll back the slot taken above before failing.
            stats.decrement_user_curr_connects(user);
            warn!(
                user = %user,
                ip = %peer_addr.ip(),
                reason = %reason,
                "IP limit exceeded"
            );
            return Err(ProxyError::ConnectionLimitExceeded {
                user: user.to_string(),
            });
        }
    }
    Ok(UserConnectionReservation::new(
        stats,
        ip_tracker,
        user.to_string(),
        peer_addr.ip(),
    ))
}
#[cfg(test)]
async fn check_user_limits_static(
user: &str,
config: &ProxyConfig,
user: &str,
config: &ProxyConfig,
stats: &Stats,
peer_addr: SocketAddr,
ip_tracker: &UserIpTracker,
@ -899,7 +1015,10 @@ impl RunningClientHandler {
}
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {}
Ok(()) => {
ip_tracker.remove_ip(user, peer_addr.ip()).await;
stats.decrement_user_curr_connects(user);
}
Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!(

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,8 @@
use std::ffi::OsString;
use std::fs::OpenOptions;
use std::io::Write;
use std::net::SocketAddr;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use std::collections::HashSet;
use std::sync::{Mutex, OnceLock};
@ -24,14 +26,44 @@ use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
/// Upper bound (in bytes) on the accepted scope-hint suffix.
const MAX_SCOPE_HINT_LEN: usize = 64;

/// Extract a validated upstream-scope hint from a user name.
///
/// A hint is the suffix after a `scope_` prefix. It is accepted only when it
/// is 1..=`MAX_SCOPE_HINT_LEN` bytes consisting solely of ASCII alphanumerics
/// or `-`. Any other user name yields `None`, which callers treat as "no
/// hint" and fall back to default upstream selection.
fn validated_scope_hint(user: &str) -> Option<&str> {
    let hint = user.strip_prefix("scope_")?;
    let length_ok = (1..=MAX_SCOPE_HINT_LEN).contains(&hint.len());
    let charset_ok = hint
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || b == b'-');
    (length_ok && charset_ok).then_some(hint)
}
/// Pre-validated target for the unknown-DC log file, produced by
/// `sanitize_unknown_dc_log_path` and re-checked right before each write.
#[derive(Clone)]
struct SanitizedUnknownDcLogPath {
    // Canonicalized parent directory joined with the original file name.
    resolved_path: PathBuf,
    // Canonicalized directory the log file must remain inside.
    allowed_parent: PathBuf,
    // File name component captured at sanitize time.
    file_name: OsString,
}
// In tests, this function shares global mutable state. Callers that also use
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
// deterministic under parallel execution.
fn should_log_unknown_dc(dc_idx: i16) -> bool {
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
should_log_unknown_dc_with_set(set, dc_idx)
}
fn should_log_unknown_dc_with_set(set: &Mutex<HashSet<i16>>, dc_idx: i16) -> bool {
match set.lock() {
Ok(mut guard) => {
if guard.contains(&dc_idx) {
@ -42,9 +74,85 @@ fn should_log_unknown_dc(dc_idx: i16) -> bool {
}
guard.insert(dc_idx)
}
// If the lock is poisoned, keep logging rather than silently dropping
// operator-visible diagnostics.
Err(_) => true,
// Fail closed on poisoned state to avoid unbounded blocking log writes.
Err(_) => false,
}
}
/// Validate an operator-supplied unknown-DC log path.
///
/// Rejects empty paths and any path containing a `..` component, resolves a
/// relative parent against the current working directory, and canonicalizes
/// the parent (which also proves it exists). Returns `None` whenever any step
/// fails or the canonical parent is not a directory, so the caller refuses
/// the path instead of writing somewhere unexpected.
fn sanitize_unknown_dc_log_path(path: &str) -> Option<SanitizedUnknownDcLogPath> {
    let candidate = Path::new(path);
    if candidate.as_os_str().is_empty() {
        return None;
    }
    // A `..` anywhere could escape the intended directory.
    let escapes = candidate
        .components()
        .any(|component| matches!(component, Component::ParentDir));
    if escapes {
        return None;
    }
    let cwd = std::env::current_dir().ok()?;
    let file_name = candidate.file_name()?;
    let raw_parent = candidate.parent().unwrap_or_else(|| Path::new("."));
    let parent_path = if raw_parent.is_absolute() {
        raw_parent.to_path_buf()
    } else {
        cwd.join(raw_parent)
    };
    // Canonicalization resolves symlinks and fails if the directory is absent.
    let canonical_parent = parent_path.canonicalize().ok()?;
    if !canonical_parent.is_dir() {
        return None;
    }
    let resolved_path = canonical_parent.join(file_name);
    Some(SanitizedUnknownDcLogPath {
        resolved_path,
        allowed_parent: canonical_parent,
        file_name: file_name.to_os_string(),
    })
}
/// Re-check a sanitized log path immediately before opening it.
///
/// Guards against the directory changing between sanitization and the write:
/// the parent must still canonicalize to the recorded `allowed_parent`, and
/// if the target already exists it must still resolve to that directory under
/// the recorded file name. A target that does not exist yet (canonicalize
/// fails) passes — it will be created under the verified parent.
fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
    let parent_ok = path
        .resolved_path
        .parent()
        .and_then(|parent| parent.canonicalize().ok())
        .is_some_and(|canonical| canonical == path.allowed_parent);
    if !parent_ok {
        return false;
    }
    match path.resolved_path.canonicalize() {
        // Not created yet; the subsequent open() creates it in the safe parent.
        Err(_) => true,
        Ok(target) => {
            let same_parent = target.parent() == Some(path.allowed_parent.as_path());
            let same_name = target.file_name() == Some(path.file_name.as_os_str());
            same_parent && same_name
        }
    }
}
/// Open the unknown-DC log for appending, refusing to follow symlinks.
///
/// On unix the file is opened with `O_NOFOLLOW`, so a symlink planted at the
/// log path cannot redirect the write elsewhere. On non-unix targets the
/// feature is unsupported and a `PermissionDenied` error is returned instead
/// of performing an unprotected open.
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
    #[cfg(unix)]
    {
        OpenOptions::new()
            .create(true)
            .append(true)
            .custom_flags(libc::O_NOFOLLOW)
            .open(path)
    }
    #[cfg(not(unix))]
    {
        // Suppress the unused-parameter warning on non-unix builds.
        let _ = path;
        Err(std::io::Error::new(
            std::io::ErrorKind::PermissionDenied,
            "unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
        ))
    }
}
@ -93,8 +201,15 @@ where
"Connecting to Telegram DC"
);
let scope_hint = validated_scope_hint(user);
if user.starts_with("scope_") && scope_hint.is_none() {
warn!(
user = %user,
"Ignoring invalid scope hint and falling back to default upstream selection"
);
}
let tg_stream = upstream_manager
.connect(dc_addr, Some(success.dc_idx), user.strip_prefix("scope_").filter(|s| !s.is_empty()))
.connect(dc_addr, Some(success.dc_idx), scope_hint)
.await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
@ -105,7 +220,7 @@ where
debug!(peer = %success.peer, "TG handshake complete, starting relay");
stats.increment_user_connects(user);
stats.increment_current_connections_direct();
let _direct_connection_lease = stats.acquire_direct_connection_lease();
let relay_result = relay_bidirectional(
client_reader,
@ -116,6 +231,7 @@ where
config.general.direct_relay_copy_buf_s2c_bytes,
user,
Arc::clone(&stats),
config.access.user_data_quota.get(user).copied(),
buffer_pool,
);
tokio::pin!(relay_result);
@ -148,8 +264,6 @@ where
}
};
stats.decrement_current_connections_direct();
match &relay_result {
Ok(()) => debug!(user = %user, "Direct relay completed"),
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
@ -199,15 +313,21 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
if config.general.unknown_dc_file_log_enabled
&& let Some(path) = &config.general.unknown_dc_log_path
&& should_log_unknown_dc(dc_idx)
&& let Ok(handle) = tokio::runtime::Handle::try_current()
{
let path = path.clone();
handle.spawn_blocking(move || {
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
let _ = writeln!(file, "dc_idx={dc_idx}");
if let Some(path) = sanitize_unknown_dc_log_path(path) {
if should_log_unknown_dc(dc_idx) {
handle.spawn_blocking(move || {
if unknown_dc_log_path_is_still_safe(&path)
&& let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path)
{
let _ = writeln!(file, "dc_idx={dc_idx}");
}
});
}
});
} else {
warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path");
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -4,17 +4,17 @@
use std::net::SocketAddr;
use std::collections::HashSet;
use std::collections::hash_map::RandomState;
use std::net::{IpAddr, Ipv6Addr};
use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::hash::{BuildHasher, Hash, Hasher};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace};
use zeroize::Zeroize;
use zeroize::{Zeroize, Zeroizing};
use crate::crypto::{sha256, AesCtr, SecureRandom};
use rand::Rng;
@ -28,6 +28,10 @@ use crate::tls_front::{TlsFrontCache, emulator};
const ACCESS_SECRET_BYTES: usize = 16;
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
#[cfg(test)]
const WARNED_SECRET_MAX_ENTRIES: usize = 64;
#[cfg(not(test))]
const WARNED_SECRET_MAX_ENTRIES: usize = 1_024;
const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60;
#[cfg(test)]
@ -36,6 +40,7 @@ const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256;
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024;
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2;
#[cfg(test)]
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
@ -54,12 +59,25 @@ struct AuthProbeState {
last_seen: Instant,
}
/// Global throttle state entered when the per-IP probe map is saturated.
#[derive(Clone, Copy)]
struct AuthProbeSaturationState {
    // Consecutive saturation events; feeds `auth_probe_backoff`.
    fail_streak: u32,
    // The global throttle is active until this instant.
    blocked_until: Instant,
    // Last saturation event; state expires after the retention window.
    last_seen: Instant,
}
// Per-IP pre-auth failure tracking, keyed by the normalized peer IP.
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
// Global backoff engaged when the per-IP map saturates.
static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> = OnceLock::new();
// Randomly-seeded hasher for eviction offsets, initialized once per process.
static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new();

/// Lazily-initialized per-IP auth-probe failure map.
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
    AUTH_PROBE_STATE.get_or_init(DashMap::new)
}

/// Lazily-initialized global saturation-backoff state.
fn auth_probe_saturation_state() -> &'static Mutex<Option<AuthProbeSaturationState>> {
    AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None))
}
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
match peer_ip {
IpAddr::V4(ip) => IpAddr::V4(ip),
@ -88,7 +106,8 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
}
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
let mut hasher = DefaultHasher::new();
let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new);
let mut hasher = hasher_state.build_hasher();
peer_ip.hash(&mut hasher);
now.hash(&mut hasher);
hasher.finish() as usize
@ -108,6 +127,83 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
now < entry.blocked_until
}
/// Whether a peer has used up its extra grace failures during map saturation.
///
/// While the saturation throttle is active the normal per-IP throttle is
/// relaxed; a peer is still rejected once its fail streak reaches the backoff
/// threshold plus `AUTH_PROBE_SATURATION_GRACE_FAILS`. Expired entries are
/// pruned as a side effect.
fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool {
    let peer_ip = normalize_auth_probe_ip(peer_ip);
    let state = auth_probe_state_map();
    let Some(entry) = state.get(&peer_ip) else {
        return false;
    };
    if auth_probe_state_expired(&entry, now) {
        // Release the read guard before mutating the map.
        drop(entry);
        state.remove(&peer_ip);
        return false;
    }
    entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
}
/// Decide whether to reject a connection before authentication.
///
/// A peer is rejected when its per-IP throttle is active — except while the
/// global saturation throttle is engaged, in which case the peer is rejected
/// only after exhausting its additional grace failures.
fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool {
    let per_ip_throttled = auth_probe_is_throttled(peer_ip, now);
    per_ip_throttled
        && (!auth_probe_saturation_is_throttled(now)
            || auth_probe_saturation_grace_exhausted(peer_ip, now))
}
/// Whether the global saturation throttle is currently engaged.
///
/// Clears the state in place once it has outlived the retention window. A
/// poisoned lock is treated as "not throttled" so connection handling keeps
/// making progress.
fn auth_probe_saturation_is_throttled(now: Instant) -> bool {
    let Ok(mut guard) = auth_probe_saturation_state().lock() else {
        return false;
    };
    let Some(state) = guard.as_mut() else {
        return false;
    };
    let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS);
    if now.duration_since(state.last_seen) > retention {
        // Stale episode: forget it so a fresh saturation event starts clean.
        *guard = None;
        return false;
    }
    now < state.blocked_until
}
/// Record one saturation event and extend the global backoff window.
///
/// If a fresh (non-expired) saturation state exists, its streak is bumped and
/// the block window recomputed; otherwise a new state is seeded at the
/// backoff threshold. A poisoned lock is ignored (best-effort throttle).
fn auth_probe_note_saturation(now: Instant) {
    let saturation = auth_probe_saturation_state();
    let mut guard = match saturation.lock() {
        Ok(guard) => guard,
        Err(_) => return,
    };
    match guard.as_mut() {
        // Existing state still within the retention window: escalate it.
        Some(state)
            if now.duration_since(state.last_seen)
                <= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) =>
        {
            state.fail_streak = state.fail_streak.saturating_add(1);
            state.last_seen = now;
            state.blocked_until = now + auth_probe_backoff(state.fail_streak);
        }
        // No state, or an expired one: start a new episode at the threshold.
        _ => {
            let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS;
            *guard = Some(AuthProbeSaturationState {
                fail_streak,
                blocked_until: now + auth_probe_backoff(fail_streak),
                last_seen: now,
            });
        }
    }
}
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
@ -144,24 +240,98 @@ fn auth_probe_record_failure_with_state(
}
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
let mut stale_keys = Vec::new();
let mut eviction_candidates = Vec::new();
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
eviction_candidates.push(*entry.key());
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(*entry.key());
let mut rounds = 0usize;
while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
rounds += 1;
if rounds > 8 {
auth_probe_note_saturation(now);
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
}
let Some((evict_key, _, _)) = eviction_candidate else {
return;
};
state.remove(&evict_key);
break;
}
}
for stale_key in stale_keys {
state.remove(&stale_key);
}
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
if eviction_candidates.is_empty() {
let mut stale_keys = Vec::new();
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = if state_len == 0 {
0
} else {
auth_probe_eviction_offset(peer_ip, now) % state_len
};
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
scanned += 1;
if scanned >= scan_limit {
break;
}
}
if scanned < scan_limit {
for entry in state.iter().take(scan_limit - scanned) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
}
}
for stale_key in stale_keys {
state.remove(&stale_key);
}
if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES {
break;
}
let Some((evict_key, _, _)) = eviction_candidate else {
auth_probe_note_saturation(now);
return;
}
let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len();
let evict_key = eviction_candidates[idx];
};
state.remove(&evict_key);
auth_probe_note_saturation(now);
}
}
@ -186,6 +356,11 @@ fn clear_auth_probe_state_for_testing() {
if let Some(state) = AUTH_PROBE_STATE.get() {
state.clear();
}
if let Some(saturation) = AUTH_PROBE_SATURATION_STATE.get()
&& let Ok(mut guard) = saturation.lock()
{
*guard = None;
}
}
#[cfg(test)]
@ -200,6 +375,16 @@ fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
auth_probe_is_throttled(peer_ip, Instant::now())
}
#[cfg(test)]
// Test shim: evaluate the global saturation throttle at the current instant.
fn auth_probe_saturation_is_throttled_for_testing() -> bool {
    auth_probe_saturation_is_throttled(Instant::now())
}

#[cfg(test)]
// Test shim: evaluate the global saturation throttle at a supplied instant.
fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool {
    auth_probe_saturation_is_throttled(now)
}
#[cfg(test)]
fn auth_probe_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@ -225,7 +410,13 @@ fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Opti
let key = (name.to_string(), reason.to_string());
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
let should_warn = match warned.lock() {
Ok(mut guard) => guard.insert(key),
Ok(mut guard) => {
if !guard.contains(&key) && guard.len() >= WARNED_SECRET_MAX_ENTRIES {
false
} else {
guard.insert(key)
}
}
Err(_) => true,
};
@ -317,6 +508,24 @@ fn decode_user_secrets(
secrets
}
/// Optionally sleep a randomized interval before emitting the ServerHello.
///
/// Anti-fingerprinting measure; callers also apply it on rejection paths so
/// failures are not distinguishable by response timing. A zero
/// `server_hello_delay_max_ms` disables the delay entirely. If `min` exceeds
/// `max`, `max` is raised to `min` so the range stays valid.
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
    if config.censorship.server_hello_delay_max_ms == 0 {
        return;
    }
    let min = config.censorship.server_hello_delay_min_ms;
    let max = config.censorship.server_hello_delay_max_ms.max(min);
    // Skip the RNG when the range collapses to a single value.
    let delay_ms = if max == min {
        max
    } else {
        rand::rng().random_range(min..=max)
    };
    if delay_ms > 0 {
        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
    }
}
/// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
@ -338,6 +547,7 @@ pub struct HandshakeSuccess {
/// Client address
pub peer: SocketAddr,
/// Whether TLS was used
pub is_tls: bool,
}
@ -367,17 +577,22 @@ where
{
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer };
}
let secrets = decode_user_secrets(config, None);
let client_sni = tls::extract_sni_from_client_hello(handshake);
let secrets = decode_user_secrets(config, client_sni.as_deref());
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
@ -388,6 +603,7 @@ where
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
@ -402,20 +618,24 @@ where
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient { reader, writer },
None => {
maybe_apply_server_hello_delay(config).await;
return HandshakeResult::BadClient { reader, writer };
}
};
let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() {
let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) {
let selected_domain = if let Some(sni) = client_sni.as_ref() {
if cache.contains_domain(&sni).await {
sni
sni.clone()
} else {
config.censorship.tls_domain.clone()
}
@ -448,6 +668,7 @@ where
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else {
@ -480,19 +701,9 @@ where
)
};
// Optional anti-fingerprint delay before sending ServerHello.
if config.censorship.server_hello_delay_max_ms > 0 {
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms.max(min);
let delay_ms = if max == min {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
}
// Apply the same optional delay budget used by reject paths to reduce
// distinguishability between success and fail-closed handshakes.
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@ -536,9 +747,19 @@ where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
let handshake_fingerprint = {
let digest = sha256(&handshake[..8]);
hex::encode(&digest[..4])
};
trace!(
peer = %peer,
handshake_fingerprint = %handshake_fingerprint,
"MTProto handshake prefix"
);
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
@ -554,7 +775,7 @@ where
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
@ -590,7 +811,7 @@ where
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
@ -609,6 +830,7 @@ where
// authentication check first to avoid poisoning the replay cache.
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
}
@ -645,6 +867,7 @@ where
}
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient { reader, writer }
}
@ -677,7 +900,7 @@ pub fn generate_tg_nonce(
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
if fast_mode {
let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN);
let mut key_iv = Zeroizing::new(Vec::with_capacity(KEY_LEN + IV_LEN));
key_iv.extend_from_slice(client_enc_key);
key_iv.extend_from_slice(&client_enc_iv.to_be_bytes());
key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce
@ -685,7 +908,7 @@ pub fn generate_tg_nonce(
}
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::<Vec<u8>>());
let mut tg_enc_key = [0u8; 32];
tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
@ -706,7 +929,7 @@ pub fn generate_tg_nonce(
/// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state
pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, AesCtr, AesCtr) {
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::<Vec<u8>>());
let mut enc_key = [0u8; 32];
enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
@ -727,11 +950,14 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
let decryptor = AesCtr::new(&dec_key, dec_iv);
enc_key.zeroize();
dec_key.zeroize();
(result, encryptor, decryptor)
}
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
encrypted
@ -741,6 +967,10 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
#[path = "handshake_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"]
mod gap_short_tls_probe_throttle_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.

View File

@ -0,0 +1,50 @@
use super::*;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::time::Duration;
/// Build a minimal `ProxyConfig` for handshake tests: a single user named
/// "user" carrying the given hex-encoded secret, with time-skew validation
/// disabled so the test does not depend on the host clock.
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
// Start from a clean user table so only the test user can authenticate.
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg
}
// Gap test: repeated too-short TLS ClientHello probes from a single IP must
// each be rejected AND must accumulate auth-probe failures, so the pre-auth
// throttle engages once the backoff threshold is reached.
#[tokio::test]
async fn gap_t01_short_tls_probe_burst_is_throttled() {
// Serialize tests that touch the global auth-probe state; recover the lock
// even if a previous test panicked while holding it.
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
// Three bytes of a TLS record header — far shorter than any valid ClientHello.
let too_short = vec![0x16, 0x03, 0x01];
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&too_short,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
// Each short probe must be rejected as a bad client.
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
// After the burst, the recorded fail streak must have reached the threshold
// that activates pre-auth throttling.
assert!(
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"short TLS probe bursts must increase auth-probe fail streak"
);
}

File diff suppressed because it is too large Load Diff

View File

@ -7,7 +7,7 @@ use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tokio::time::{Instant, timeout};
use tracing::debug;
use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
@ -24,8 +24,36 @@ const MASK_TIMEOUT: Duration = Duration::from_millis(50);
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
#[cfg(test)]
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
#[cfg(not(test))]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
const MASK_BUFFER_SIZE: usize = 8192;
/// Pump bytes from `reader` to `writer` until EOF, an I/O error, or a full
/// `MASK_RELAY_IDLE_TIMEOUT` elapses without completing a read or a write.
///
/// All failures are swallowed deliberately: in masking mode the only
/// reaction to trouble is closing the relay — nothing is ever surfaced to
/// the (possibly probing) client.
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W)
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    let mut chunk = [0u8; MASK_BUFFER_SIZE];
    loop {
        // Idle-bounded read: timeout and I/O error both terminate the pump.
        let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut chunk)).await {
            Ok(Ok(0)) => break, // clean EOF
            Ok(Ok(n)) => n,
            _ => break,
        };
        // Idle-bounded write of exactly the bytes just read.
        let wrote = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&chunk[..n])).await;
        if !matches!(wrote, Ok(Ok(()))) {
            break;
        }
    }
}
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
where
W: AsyncWrite + Unpin,
@ -49,6 +77,20 @@ where
}
}
/// Sleep until at least `MASK_TIMEOUT` has elapsed since `started`, so a
/// failed mask-backend connect consumes the same wall-clock budget as a
/// slow successful one (timing uniformity against active probes).
async fn wait_mask_connect_budget(started: Instant) {
    let remaining = MASK_TIMEOUT.saturating_sub(started.elapsed());
    if !remaining.is_zero() {
        tokio::time::sleep(remaining).await;
    }
}
/// Pad every masking outcome path (success, header failure, connect error,
/// connect timeout) out to at least `MASK_TIMEOUT` since `started`, keeping
/// the different exits indistinguishable by timing.
async fn wait_mask_outcome_budget(started: Instant) {
    let remaining = MASK_TIMEOUT.saturating_sub(started.elapsed());
    if !remaining.is_zero() {
        tokio::time::sleep(remaining).await;
    }
}
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
@ -107,6 +149,8 @@ where
// Connect via Unix socket or TCP
#[cfg(unix)]
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
let outcome_started = Instant::now();
let connect_started = Instant::now();
debug!(
client_type = client_type,
sock = %sock_path,
@ -137,20 +181,25 @@ where
};
if let Some(header) = proxy_header {
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
wait_mask_outcome_budget(outcome_started).await;
return;
}
}
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
debug!("Mask relay timed out (unix socket)");
}
wait_mask_outcome_budget(outcome_started).await;
}
Ok(Err(e)) => {
wait_mask_connect_budget(connect_started).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
}
return;
@ -172,6 +221,8 @@ where
let mask_addr = resolve_socket_addr(mask_host, mask_port)
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let outcome_started = Instant::now();
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
Ok(Ok(stream)) => {
@ -196,20 +247,25 @@ where
let (mask_read, mut mask_write) = stream.into_split();
if let Some(header) = proxy_header {
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
wait_mask_outcome_budget(outcome_started).await;
return;
}
}
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
debug!("Mask relay timed out");
}
wait_mask_outcome_budget(outcome_started).await;
}
Ok(Err(e)) => {
wait_mask_connect_budget(connect_started).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
}
}
@ -238,11 +294,11 @@ where
let _ = tokio::join!(
async {
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
copy_with_idle_timeout(&mut reader, &mut mask_write).await;
let _ = mask_write.shutdown().await;
},
async {
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
copy_with_idle_timeout(&mut mask_read, &mut writer).await;
let _ = writer.shutdown().await;
}
);

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,14 @@
use std::collections::hash_map::DefaultHasher;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
#[cfg(test)]
use std::sync::Mutex;
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex};
use tokio::time::timeout;
use tracing::{debug, trace, warn};
@ -34,13 +33,26 @@ enum C2MeCommand {
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000);
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
#[cfg(test)]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
#[cfg(not(test))]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
struct RelayForensicsState {
trace_id: u64,
@ -80,7 +92,8 @@ impl MeD2cFlushPolicy {
}
fn hash_value<T: Hash>(value: &T) -> u64 {
let mut hasher = DefaultHasher::new();
let state = DESYNC_HASHER.get_or_init(RandomState::new);
let mut hasher = state.build_hasher();
value.hash(&mut hasher);
hasher.finish()
}
@ -95,6 +108,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
}
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false));
if saturated_before {
ever_saturated.store(true, Ordering::Relaxed);
}
if let Some(mut seen_at) = dedup.get_mut(&key) {
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
@ -106,12 +124,17 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
let mut stale_keys = Vec::new();
let mut eviction_candidate = None;
let mut oldest_candidate: Option<(u64, Instant)> = None;
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
if eviction_candidate.is_none() {
eviction_candidate = Some(*entry.key());
let key = *entry.key();
let seen_at = *entry.value();
match oldest_candidate {
Some((_, oldest_seen)) if seen_at >= oldest_seen => {}
_ => oldest_candidate = Some((key, seen_at)),
}
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW {
stale_keys.push(*entry.key());
}
}
@ -119,17 +142,57 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
dedup.remove(&stale_key);
}
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
let Some(evict_key) = eviction_candidate else {
let Some((evict_key, _)) = oldest_candidate else {
return false;
};
dedup.remove(&evict_key);
dedup.insert(key, now);
return false;
return should_emit_full_desync_full_cache(now);
}
}
dedup.insert(key, now);
true
let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
// Preserve the first sequential insert that reaches capacity as a normal
// emit, while still gating concurrent newcomer churn after the cache has
// ever been observed at saturation.
let was_ever_saturated = if saturated_after {
ever_saturated.swap(true, Ordering::Relaxed)
} else {
ever_saturated.load(Ordering::Relaxed)
};
if saturated_before || (saturated_after && was_ever_saturated) {
should_emit_full_desync_full_cache(now)
} else {
true
}
}
/// Rate-limit "desync dedup cache is full" emissions to at most one per
/// `DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL`, process-wide.
///
/// Fail-closed: if the gate mutex is poisoned, suppress emission rather
/// than risk log flooding. A `now` older than the recorded instant (as
/// detected via `checked_duration_since`) is treated as emit-worthy and
/// resets the gate.
fn should_emit_full_desync_full_cache(now: Instant) -> bool {
    let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
    let Ok(mut last_emit_at) = gate.lock() else {
        return false;
    };
    let emit = match *last_emit_at {
        None => true,
        Some(last) => now
            .checked_duration_since(last)
            .map_or(true, |elapsed| elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL),
    };
    if emit {
        *last_emit_at = Some(now);
    }
    emit
}
#[cfg(test)]
@ -137,6 +200,21 @@ fn clear_desync_dedup_for_testing() {
if let Some(dedup) = DESYNC_DEDUP.get() {
dedup.clear();
}
if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() {
ever_saturated.store(false, Ordering::Relaxed);
}
if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() {
match last_emit_at.lock() {
Ok(mut guard) => {
*guard = None;
}
Err(poisoned) => {
let mut guard = poisoned.into_inner();
*guard = None;
last_emit_at.clear_poison();
}
}
}
}
#[cfg(test)]
@ -240,6 +318,46 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
/// True when `user` has consumed at least their configured quota.
/// A `None` limit means the user is unmetered and never exceeds.
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
    match quota_limit {
        Some(quota) => stats.get_user_total_octets(user) >= quota,
        None => false,
    }
}
/// True when accounting a further `bytes` for `user` would cross (or has
/// already crossed) the configured quota. `None` means unmetered.
fn quota_would_be_exceeded_for_user(
    stats: &Stats,
    user: &str,
    quota_limit: Option<u64>,
    bytes: u64,
) -> bool {
    let Some(quota) = quota_limit else {
        return false;
    };
    let used = stats.get_user_total_octets(user);
    // Already over, or the pending transfer does not fit in the remainder.
    used >= quota || bytes > quota.saturating_sub(used)
}
/// Fetch (or lazily create) the async mutex that serializes quota
/// accounting for `user`.
///
/// The registry is capped at `QUOTA_USER_LOCKS_MAX`: when full, entries no
/// longer held anywhere else are pruned; if still full, a fresh unshared
/// lock is handed out (fail-open) rather than letting the map grow
/// without bound.
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);

    // Fast path: reuse the lock already registered for this user.
    if let Some(existing) = locks.get(user) {
        return Arc::clone(existing.value());
    }

    if locks.len() >= QUOTA_USER_LOCKS_MAX {
        // Drop registrations nobody else currently holds, then re-check.
        locks.retain(|_, lock| Arc::strong_count(lock) > 1);
        if locks.len() >= QUOTA_USER_LOCKS_MAX {
            return Arc::new(AsyncMutex::new(()));
        }
    }

    // Race-safe insert: if another task registered first, adopt its lock so
    // all holders serialize on the same mutex.
    let fresh = Arc::new(AsyncMutex::new(()));
    match locks.entry(user.to_string()) {
        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
        dashmap::mapref::entry::Entry::Vacant(entry) => {
            entry.insert(Arc::clone(&fresh));
            fresh
        }
    }
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@ -252,7 +370,14 @@ async fn enqueue_c2me_command(
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
tokio::task::yield_now().await;
}
tx.send(cmd).await
match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await {
Ok(Ok(permit)) => {
permit.send(cmd);
Ok(())
}
Ok(Err(_)) => Err(mpsc::error::SendError(cmd)),
Err(_) => Err(mpsc::error::SendError(cmd)),
}
}
}
}
@ -276,6 +401,7 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@ -291,7 +417,7 @@ where
);
let (conn_id, me_rx) = me_pool.registry().register().await;
let trace_id = conn_id;
let trace_id = session_id;
let bytes_me2c = Arc::new(AtomicU64::new(0));
let mut forensics = RelayForensicsState {
trace_id,
@ -306,7 +432,7 @@ where
};
stats.increment_user_connects(&user);
stats.increment_current_connections_me();
let _me_connection_lease = stats.acquire_me_connection_lease();
if let Some(cutover) = affected_cutover_state(
&route_rx,
@ -324,7 +450,6 @@ where
tokio::time::sleep(delay).await;
let _ = me_pool.send_close(conn_id).await;
me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me();
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
}
@ -425,6 +550,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -457,6 +583,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -489,6 +616,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -521,6 +649,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -602,7 +731,19 @@ where
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
stats.add_user_octets_from(&user, payload.len() as u64);
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
main_result = Err(ProxyError::DataQuotaExceeded {
user: user.clone(),
});
break;
}
} else {
stats.add_user_octets_from(&user, payload.len() as u64);
}
let mut flags = proto_flags;
if quickack {
flags |= RPC_FLAG_QUICKACK;
@ -672,7 +813,6 @@ where
"ME relay cleanup"
);
me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me();
result
}
@ -827,6 +967,7 @@ async fn process_me_writer_response<W>(
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes_me2c: &AtomicU64,
conn_id: u64,
ack_flush_immediate: bool,
@ -842,17 +983,47 @@ where
} else {
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
let data_len = data.len() as u64;
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(user);
let _quota_guard = quota_lock.lock().await;
if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
if quota_exceeded_for_user(stats, user, Some(limit)) {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
} else {
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
}
Ok(MeWriterResponseOutcome::Continue {
frames: 1,

File diff suppressed because it is too large Load Diff

View File

@ -53,16 +53,17 @@
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use dashmap::DashMap;
use tokio::io::{
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
use crate::error::Result;
use crate::error::{ProxyError, Result};
use crate::stats::Stats;
use crate::stream::BufferPool;
@ -205,6 +206,10 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_read_wake_scheduled: bool,
quota_write_wake_scheduled: bool,
epoch: Instant,
}
@ -214,11 +219,64 @@ impl<S> StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
Self { inner, counters, stats, user, epoch }
Self {
inner,
counters,
stats,
user,
quota_limit,
quota_exceeded,
quota_read_wake_scheduled: false,
quota_write_wake_scheduled: false,
epoch,
}
}
}
/// Zero-sized marker error used to smuggle "quota exceeded" through the
/// `io::Error` channel of `copy_bidirectional`, so the relay can map it
/// back to `ProxyError::DataQuotaExceeded` after the copy loop exits.
#[derive(Debug)]
struct QuotaIoSentinel;

impl std::fmt::Display for QuotaIoSentinel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "user data quota exceeded")
    }
}

impl std::error::Error for QuotaIoSentinel {}

/// Wrap the sentinel in an `io::Error` (kind `PermissionDenied`).
fn quota_io_error() -> io::Error {
    io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
}

/// True only for errors produced by `quota_io_error`: the kind must be
/// `PermissionDenied` AND the inner source must downcast to the sentinel
/// type, so unrelated permission errors are not misclassified.
fn is_quota_io_error(err: &io::Error) -> bool {
    if err.kind() != io::ErrorKind::PermissionDenied {
        return false;
    }
    matches!(
        err.get_ref(),
        Some(source) if source.downcast_ref::<QuotaIoSentinel>().is_some()
    )
}
/// Process-wide registry of per-user quota locks (std `Mutex`, because it
/// is taken with `try_lock` from poll-based `StatsIo` paths where an async
/// mutex guard cannot be held).
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();

/// Fetch (or lazily create) the quota lock serializing accounting for `user`.
///
/// NOTE(review): unlike the relay-side async variant, this registry is
/// never pruned or capped, so it grows with the number of distinct metered
/// users observed — presumably fine while users come from bounded config;
/// confirm if user provisioning ever becomes dynamic.
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);

    // Fast path: a lock is already registered for this user.
    if let Some(existing) = locks.get(user) {
        return Arc::clone(existing.value());
    }

    // Race-safe insert: if another thread registered first, adopt its lock
    // so every holder serializes on the same mutex.
    let fresh = Arc::new(Mutex::new(()));
    match locks.entry(user.to_string()) {
        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
        dashmap::mapref::entry::Entry::Vacant(entry) => {
            entry.insert(Arc::clone(&fresh));
            fresh
        }
    }
}
@ -229,6 +287,42 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_read_wake_scheduled = false;
Some(guard)
}
Err(_) => {
if !this.quota_read_wake_scheduled {
this.quota_read_wake_scheduled = true;
let waker = cx.waker().clone();
tokio::task::spawn(async move {
tokio::task::yield_now().await;
waker.wake();
});
}
return Poll::Pending;
}
}
} else {
None
};
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
@ -243,6 +337,13 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
trace!(user = %this.user, bytes = n, "C->S");
}
Poll::Ready(Ok(()))
@ -259,8 +360,56 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
return Poll::Ready(Err(quota_io_error()));
}
match Pin::new(&mut this.inner).poll_write(cx, buf) {
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_write_wake_scheduled = false;
Some(guard)
}
Err(_) => {
if !this.quota_write_wake_scheduled {
this.quota_write_wake_scheduled = true;
let waker = cx.waker().clone();
tokio::task::spawn(async move {
tokio::task::yield_now().await;
waker.wake();
});
}
return Poll::Pending;
}
}
} else {
None
};
let write_buf = if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let remaining = (limit - used) as usize;
if buf.len() > remaining {
// Fail closed: do not emit partial S->C payload when remaining
// quota cannot accommodate the pending write request.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
buf
} else {
buf
};
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
// S→C: data written to client
@ -271,6 +420,13 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
trace!(user = %this.user, bytes = n, "S->C");
}
Poll::Ready(Ok(n))
@ -307,7 +463,8 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
/// - Per-user stats: bytes and ops counted per direction
/// - Periodic rate logging: every 10 seconds when active
/// - Clean shutdown: both write sides are shut down on exit
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
/// other I/O failures are returned as `ProxyError::Io`
pub async fn relay_bidirectional<CR, CW, SR, SW>(
client_reader: CR,
client_writer: CW,
@ -317,6 +474,7 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
s2c_buf_size: usize,
user: &str,
stats: Arc<Stats>,
quota_limit: Option<u64>,
_buffer_pool: Arc<BufferPool>,
) -> Result<()>
where
@ -327,6 +485,7 @@ where
{
let epoch = Instant::now();
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let user_owned = user.to_string();
// ── Combine split halves into bidirectional streams ──────────────
@ -339,12 +498,15 @@ where
Arc::clone(&counters),
Arc::clone(&stats),
user_owned.clone(),
quota_limit,
Arc::clone(&quota_exceeded),
epoch,
);
// ── Watchdog: activity timeout + periodic rate logging ──────────
let wd_counters = Arc::clone(&counters);
let wd_user = user_owned.clone();
let wd_quota_exceeded = Arc::clone(&quota_exceeded);
let watchdog = async {
let mut prev_c2s: u64 = 0;
@ -356,6 +518,11 @@ where
let now = Instant::now();
let idle = wd_counters.idle_duration(now, epoch);
if wd_quota_exceeded.load(Ordering::Relaxed) {
warn!(user = %wd_user, "User data quota reached, closing relay");
return;
}
// ── Activity timeout ────────────────────────────────────
if idle >= ACTIVITY_TIMEOUT {
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
@ -439,6 +606,22 @@ where
);
Ok(())
}
Some(Err(e)) if is_quota_io_error(&e) => {
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
warn!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
"Data quota reached, closing relay"
);
Err(ProxyError::DataQuotaExceeded {
user: user_owned.clone(),
})
}
Some(Err(e)) => {
// I/O error in one of the directions
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
@ -472,3 +655,7 @@ where
}
}
}
#[cfg(test)]
#[path = "relay_security_tests.rs"]
mod security_tests;

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::watch;
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover";
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated";
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode {
}
impl RelayRouteMode {
pub(crate) fn as_u8(self) -> u8 {
self as u8
}
pub(crate) fn from_u8(value: u8) -> Self {
match value {
1 => Self::Middle,
_ => Self::Direct,
}
}
pub(crate) fn as_str(self) -> &'static str {
match self {
Self::Direct => "direct",
@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState {
#[derive(Clone)]
pub(crate) struct RouteRuntimeController {
mode: Arc<AtomicU8>,
generation: Arc<AtomicU64>,
direct_since_epoch_secs: Arc<AtomicU64>,
tx: watch::Sender<RouteCutoverState>,
}
@ -60,18 +47,13 @@ impl RouteRuntimeController {
0
};
Self {
mode: Arc::new(AtomicU8::new(initial_mode.as_u8())),
generation: Arc::new(AtomicU64::new(0)),
direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)),
tx,
}
}
pub(crate) fn snapshot(&self) -> RouteCutoverState {
RouteCutoverState {
mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)),
generation: self.generation.load(Ordering::Relaxed),
}
*self.tx.borrow()
}
pub(crate) fn subscribe(&self) -> watch::Receiver<RouteCutoverState> {
@ -84,20 +66,29 @@ impl RouteRuntimeController {
}
pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed);
if previous == mode.as_u8() {
let mut next = None;
let changed = self.tx.send_if_modified(|state| {
if state.mode == mode {
return false;
}
state.mode = mode;
state.generation = state.generation.saturating_add(1);
next = Some(*state);
true
});
if !changed {
return None;
}
if matches!(mode, RelayRouteMode::Direct) {
self.direct_since_epoch_secs
.store(now_epoch_secs(), Ordering::Relaxed);
} else {
self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
}
let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
let next = RouteCutoverState { mode, generation };
self.tx.send_replace(next);
Some(next)
next
}
}
@ -110,10 +101,10 @@ fn now_epoch_secs() -> u64 {
pub(crate) fn is_session_affected_by_cutover(
current: RouteCutoverState,
_session_mode: RelayRouteMode,
session_mode: RelayRouteMode,
session_generation: u64,
) -> bool {
current.generation > session_generation
current.generation > session_generation && current.mode != session_mode
}
pub(crate) fn affected_cutover_state(
@ -140,3 +131,7 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio
let ms = 1000 + (value % 1000);
Duration::from_millis(ms)
}
#[cfg(test)]
#[path = "route_mode_security_tests.rs"]
mod security_tests;

View File

@ -0,0 +1,406 @@
use super::*;
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
#[test]
fn cutover_stagger_delay_is_deterministic_for_same_inputs() {
    // The same (session_id, generation) pair must always map to the same
    // delay; otherwise cutover disconnect timing would be non-reproducible.
    let session_id = 0x0123_4567_89ab_cdef;
    let generation = 42;
    let first = cutover_stagger_delay(session_id, generation);
    let second = cutover_stagger_delay(session_id, generation);
    assert_eq!(
        first, second,
        "stagger delay must be deterministic for identical session/generation inputs"
    );
}
#[test]
fn cutover_stagger_delay_stays_within_budget_bounds() {
    // Black-hat model: censors trigger many cutovers and correlate disconnect
    // timing. Keep the delay inside a narrow coarse window (1000..=1999 ms)
    // to avoid long-tail spikes, including at u64 boundary inputs.
    let generations = [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX];
    let session_ids = [0u64, 1, 2, 0xdead_beef, 0xfeed_face_cafe_babe, u64::MAX];
    for &generation in &generations {
        for &session_id in &session_ids {
            let ms = cutover_stagger_delay(session_id, generation).as_millis();
            assert!(
                (1000..=1999).contains(&ms),
                "stagger delay must remain in fixed 1000..=1999ms budget"
            );
        }
    }
}
#[test]
fn cutover_stagger_delay_changes_with_generation_for_same_session() {
    // Adjacent generations must produce different delays for the same session
    // so repeated cutovers do not disconnect a client at a fixed offset.
    let session_id = 0x0123_4567_89ab_cdef;
    let delay_gen_100 = cutover_stagger_delay(session_id, 100);
    let delay_gen_101 = cutover_stagger_delay(session_id, 101);
    assert_ne!(
        delay_gen_100, delay_gen_101,
        "adjacent cutover generations should decorrelate disconnect delays"
    );
}
#[test]
fn route_runtime_set_mode_is_idempotent_for_same_mode() {
    // Re-applying the already-active mode must be a no-op: no cutover event
    // is emitted and the generation counter is untouched.
    let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
    let before = runtime.snapshot();
    assert!(
        runtime.set_mode(RelayRouteMode::Direct).is_none(),
        "setting already-active mode must not produce a cutover event"
    );
    let after = runtime.snapshot();
    assert_eq!(
        before.generation, after.generation,
        "idempotent mode set must not bump generation"
    );
}
#[test]
fn affected_cutover_state_triggers_only_for_newer_generation() {
    let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
    let rx = runtime.subscribe();
    let initial = runtime.snapshot();

    // A session created at the current generation must not see a cutover yet.
    assert!(
        affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation).is_none(),
        "current generation must not be considered a cutover for existing session"
    );

    // After a real transition, the same session must observe the new state.
    let next = runtime
        .set_mode(RelayRouteMode::Middle)
        .expect("mode change must produce cutover state");
    let seen = affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation)
        .expect("newer generation must be observed as cutover");
    assert_eq!(seen.generation, next.generation);
    assert_eq!(seen.mode, RelayRouteMode::Middle);
}
#[test]
fn integration_watch_and_snapshot_follow_same_transition_sequence() {
    // Drive a mixed sequence of real and idempotent transitions and verify
    // that snapshot() and the watch channel stay in lock-step at every step,
    // not just at the end.
    let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
    let rx = runtime.subscribe();
    let sequence = [
        RelayRouteMode::Middle,
        RelayRouteMode::Middle,
        RelayRouteMode::Direct,
        RelayRouteMode::Direct,
        RelayRouteMode::Middle,
    ];
    // Reference model: mode starts Direct at generation 0; only a real mode
    // change bumps the generation.
    let mut expected_generation = 0u64;
    let mut expected_mode = RelayRouteMode::Direct;
    for target in sequence {
        let changed = runtime.set_mode(target);
        if target == expected_mode {
            assert!(changed.is_none(), "idempotent transition must return none");
        } else {
            expected_mode = target;
            expected_generation = expected_generation.saturating_add(1);
            let emitted = changed.expect("real transition must emit cutover state");
            assert_eq!(emitted.mode, expected_mode);
            assert_eq!(emitted.generation, expected_generation);
        }
        // Snapshot and watch must agree after every transition attempt.
        let snap = runtime.snapshot();
        let watched = *rx.borrow();
        assert_eq!(snap, watched, "snapshot and watch state must stay aligned");
        assert_eq!(snap.mode, expected_mode);
        assert_eq!(snap.generation, expected_generation);
    }
}
#[test]
fn session_is_not_affected_when_mode_matches_even_if_generation_advanced() {
    // A session already on the final route mode must survive intermediate
    // generation bumps; only a mode mismatch warrants a forced cutover.
    let current = RouteCutoverState {
        mode: RelayRouteMode::Direct,
        generation: 2,
    };
    let affected = is_session_affected_by_cutover(current, RelayRouteMode::Direct, 0);
    assert!(
        !affected,
        "session on matching final route mode should not be force-cut over on intermediate generation bumps"
    );
}
#[test]
fn cutover_predicate_rejects_equal_generation_even_if_mode_differs() {
    // The generation must be strictly newer to trigger a cutover; an equal
    // generation means the session was created under this very state.
    let current = RouteCutoverState {
        mode: RelayRouteMode::Middle,
        generation: 77,
    };
    let affected = is_session_affected_by_cutover(current, RelayRouteMode::Direct, 77);
    assert!(
        !affected,
        "equal generation must never trigger cutover regardless of mode mismatch"
    );
}
#[test]
fn adversarial_route_oscillation_only_cuts_over_sessions_with_different_final_mode() {
    // Oscillate direct -> middle -> direct. Only sessions whose mode differs
    // from the FINAL mode should be cut over, not every session that saw churn.
    let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
    let rx = runtime.subscribe();
    let session_generation = runtime.snapshot().generation;
    runtime
        .set_mode(RelayRouteMode::Middle)
        .expect("direct->middle must transition");
    runtime
        .set_mode(RelayRouteMode::Direct)
        .expect("middle->direct must transition");
    let direct_session = affected_cutover_state(&rx, RelayRouteMode::Direct, session_generation);
    assert!(
        direct_session.is_none(),
        "direct session should survive when final mode returns to direct"
    );
    let middle_session = affected_cutover_state(&rx, RelayRouteMode::Middle, session_generation);
    assert!(
        middle_session.is_some(),
        "middle session should be cut over when final mode is direct"
    );
}
#[test]
fn light_fuzz_cutover_predicate_matches_reference_oracle() {
    // Differential check: the predicate must agree with the plain
    // "strictly newer generation AND different mode" oracle on random inputs.
    let mut rng = StdRng::seed_from_u64(0xC0DEC0DE5EED);
    let pick_mode = |flag: bool| {
        if flag {
            RelayRouteMode::Direct
        } else {
            RelayRouteMode::Middle
        }
    };
    for _ in 0..20_000 {
        let current = RouteCutoverState {
            mode: pick_mode(rng.random::<bool>()),
            generation: rng.random_range(0u64..1_000_000),
        };
        let session_mode = pick_mode(rng.random::<bool>());
        let session_generation = rng.random_range(0u64..1_000_000);
        let expected = current.generation > session_generation && current.mode != session_mode;
        let actual = is_session_affected_by_cutover(current, session_mode, session_generation);
        assert_eq!(
            actual, expected,
            "cutover predicate must match mode-aware generation oracle"
        );
    }
}
#[test]
fn light_fuzz_set_mode_generation_tracks_only_real_transitions() {
    // Model check: the generation must count exactly the accepted transitions,
    // and idempotent set_mode calls must emit nothing.
    let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
    let mut rng = StdRng::seed_from_u64(0x0DDC0FFE);
    let mut expected_mode = RelayRouteMode::Direct;
    let mut expected_generation = 0u64;
    for _ in 0..10_000 {
        let candidate = if rng.random::<bool>() {
            RelayRouteMode::Direct
        } else {
            RelayRouteMode::Middle
        };
        match runtime.set_mode(candidate) {
            None => assert_eq!(
                candidate, expected_mode,
                "idempotent set_mode must not emit cutover state"
            ),
            Some(next) => {
                assert_ne!(
                    candidate, expected_mode,
                    "cutover state must only be emitted for a real transition"
                );
                expected_mode = candidate;
                expected_generation = expected_generation.saturating_add(1);
                assert_eq!(next.mode, expected_mode);
                assert_eq!(next.generation, expected_generation);
            }
        }
    }
    let final_state = runtime.snapshot();
    assert_eq!(final_state.mode, expected_mode);
    assert_eq!(final_state.generation, expected_generation);
}
#[test]
fn stress_snapshot_and_watch_state_remain_consistent_under_concurrent_switch_storm() {
    // Four writer threads toggle the route mode 20k times each; once they are
    // joined, snapshot() and the watch channel must report identical state.
    let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    std::thread::scope(|scope| {
        let mut writers = Vec::new();
        for worker in 0..4usize {
            let runtime = Arc::clone(&runtime);
            writers.push(scope.spawn(move || {
                for step in 0..20_000usize {
                    // Alternate the target per worker/step so writers collide
                    // on opposing transitions as often as possible.
                    let mode = if (worker + step) % 2 == 0 {
                        RelayRouteMode::Direct
                    } else {
                        RelayRouteMode::Middle
                    };
                    let _ = runtime.set_mode(mode);
                }
            }));
        }
        for writer in writers {
            writer
                .join()
                .expect("route mode writer thread must not panic");
        }
        // State is quiescent after join; sample repeatedly (with yields) to
        // catch any lingering divergence between the two read paths.
        let rx = runtime.subscribe();
        for _ in 0..128 {
            assert_eq!(
                runtime.snapshot(),
                *rx.borrow(),
                "snapshot and watch state must converge after concurrent set_mode churn"
            );
            std::thread::yield_now();
        }
    });
}
#[test]
fn stress_concurrent_transition_count_matches_final_generation() {
    // Six workers drive pseudo-random mode flips; every accepted transition
    // (set_mode returned Some) is counted, and the total must equal the final
    // generation — the generation may only be bumped by accepted transitions.
    let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let successful_transitions = Arc::new(AtomicU64::new(0));
    std::thread::scope(|scope| {
        let mut workers = Vec::new();
        for worker in 0..6usize {
            let runtime = Arc::clone(&runtime);
            let successful_transitions = Arc::clone(&successful_transitions);
            workers.push(scope.spawn(move || {
                // Per-worker xorshift stream, seeded from the worker index so
                // the flips are deterministic per worker but uncorrelated.
                let mut state = (worker as u64 + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15);
                for _ in 0..25_000usize {
                    state ^= state << 7;
                    state ^= state >> 9;
                    state ^= state << 8;
                    let mode = if (state & 1) == 0 {
                        RelayRouteMode::Direct
                    } else {
                        RelayRouteMode::Middle
                    };
                    if runtime.set_mode(mode).is_some() {
                        successful_transitions.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }));
        }
        for worker in workers {
            worker.join().expect("route mode transition worker must not panic");
        }
    });
    let final_state = runtime.snapshot();
    assert_eq!(
        final_state.generation,
        successful_transitions.load(Ordering::Relaxed),
        "final generation must equal number of accepted mode transitions"
    );
    assert_eq!(
        final_state,
        *runtime.subscribe().borrow(),
        "watch and snapshot state must match after concurrent transition accounting"
    );
}
#[test]
fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() {
    // Deterministic xorshift fuzzing keeps this test stable across runs.
    let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
    let mut next = move || {
        state ^= state << 7;
        state ^= state >> 9;
        state ^= state << 8;
        state
    };
    for _ in 0..20_000 {
        let session_id = next();
        let generation = next();
        let delay = cutover_stagger_delay(session_id, generation);
        assert!(
            (1000..=1999).contains(&delay.as_millis()),
            "fuzzed inputs must always map into fixed stagger window"
        );
    }
}
#[test]
fn cutover_stagger_delay_distribution_has_no_empty_buckets_under_sequential_sessions() {
    // Sequential session ids must still spread over all 1000 millisecond
    // buckets; an empty bucket would indicate herd clustering on cutover.
    let generation = 4242u64;
    let mut buckets = [0usize; 1000];
    for session_id in 0..250_000u64 {
        let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize;
        buckets[delay_ms - 1000] += 1;
    }
    let empty_buckets = buckets.iter().filter(|&&hits| hits == 0).count();
    assert_eq!(
        empty_buckets, 0,
        "all 1000 delay buckets must be exercised to avoid cutover herd clustering"
    );
}
#[test]
fn light_fuzz_cutover_stagger_delay_distribution_stays_reasonably_uniform() {
    // Deterministic xorshift stream; the uniformity bound (3x skew) is loose
    // on purpose so the test stays stable while catching real collapse.
    let mut state: u64 = 0x1BAD_B002_CAFE_F00D;
    let mut next = move || {
        state ^= state << 7;
        state ^= state >> 9;
        state ^= state << 8;
        state
    };
    let mut buckets = [0usize; 1000];
    for _ in 0..300_000usize {
        let session_id = next();
        let generation = next();
        let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize;
        buckets[delay_ms - 1000] += 1;
    }
    let min = *buckets.iter().min().unwrap_or(&0);
    let max = *buckets.iter().max().unwrap_or(&0);
    assert!(min > 0, "fuzzed distribution must not leave empty buckets");
    assert!(
        max <= min.saturating_mul(3),
        "bucket skew is too high for anti-herd staggering (max={max}, min={min})"
    );
}
#[test]
fn stress_cutover_stagger_delay_distribution_remains_stable_across_generations() {
    // The distribution must not collapse for any particular generation value,
    // including extremes near the u32/u64 boundaries.
    let generations = [0u64, 1, 7, 31, 255, 1024, u32::MAX as u64, u64::MAX - 1];
    for &generation in &generations {
        let mut buckets = [0usize; 1000];
        for session_id in 0..100_000u64 {
            let delay = cutover_stagger_delay(session_id ^ 0x9E37_79B9, generation);
            buckets[delay.as_millis() as usize - 1000] += 1;
        }
        let min = *buckets.iter().min().unwrap_or(&0);
        let max = *buckets.iter().max().unwrap_or(&0);
        assert!(
            max <= min.saturating_mul(4).max(1),
            "generation={generation}: distribution collapsed (max={max}, min={min})"
        );
    }
}

View File

@ -0,0 +1,265 @@
use super::*;
use std::panic::{self, AssertUnwindSafe};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Barrier;
#[test]
fn direct_connection_lease_balances_on_drop() {
    // RAII check: acquiring bumps the direct gauge, dropping releases it.
    let stats = Arc::new(Stats::new());
    assert_eq!(stats.get_current_connections_direct(), 0);
    let lease = stats.acquire_direct_connection_lease();
    assert_eq!(stats.get_current_connections_direct(), 1);
    drop(lease);
    assert_eq!(stats.get_current_connections_direct(), 0);
}
#[test]
fn middle_connection_lease_balances_on_drop() {
    // RAII check: acquiring bumps the middle gauge, dropping releases it.
    let stats = Arc::new(Stats::new());
    assert_eq!(stats.get_current_connections_me(), 0);
    let lease = stats.acquire_me_connection_lease();
    assert_eq!(stats.get_current_connections_me(), 1);
    drop(lease);
    assert_eq!(stats.get_current_connections_me(), 0);
}
#[test]
fn connection_lease_disarm_prevents_double_release() {
    // If a caller already decremented the gauge manually, disarm() must stop
    // the lease's Drop impl from decrementing a second time (underflow guard).
    let stats = Arc::new(Stats::new());
    let mut lease = stats.acquire_direct_connection_lease();
    assert_eq!(stats.get_current_connections_direct(), 1);
    // Manual release while the lease is still alive.
    stats.decrement_current_connections_direct();
    assert_eq!(stats.get_current_connections_direct(), 0);
    // Disarmed lease must drop without touching the gauge again.
    lease.disarm();
    drop(lease);
    assert_eq!(stats.get_current_connections_direct(), 0);
}
#[test]
fn direct_connection_lease_balances_on_panic_unwind() {
    // The lease must release the gauge even when its holder panics, so a
    // crashed connection task cannot permanently inflate the direct count.
    let stats = Arc::new(Stats::new());
    let stats_inside = Arc::clone(&stats);
    let outcome = panic::catch_unwind(AssertUnwindSafe(move || {
        let _lease = stats_inside.acquire_direct_connection_lease();
        panic!("intentional panic to verify lease drop path");
    }));
    assert!(outcome.is_err(), "panic must propagate from test closure");
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "panic unwind must release direct route gauge"
    );
}
#[test]
fn middle_connection_lease_balances_on_panic_unwind() {
    // Same unwind guarantee as the direct-route variant, for the middle gauge.
    let stats = Arc::new(Stats::new());
    let stats_inside = Arc::clone(&stats);
    let outcome = panic::catch_unwind(AssertUnwindSafe(move || {
        let _lease = stats_inside.acquire_me_connection_lease();
        panic!("intentional panic to verify middle lease drop path");
    }));
    assert!(outcome.is_err(), "panic must propagate from test closure");
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "panic unwind must release middle route gauge"
    );
}
#[tokio::test]
async fn concurrent_mixed_route_lease_churn_balances_to_zero() {
    // 48 tasks repeatedly acquire and drop direct/middle leases (alternating
    // per task and iteration) behind a barrier; both gauges must end at zero.
    const TASKS: usize = 48;
    const ITERATIONS_PER_TASK: usize = 256;
    let stats = Arc::new(Stats::new());
    let barrier = Arc::new(Barrier::new(TASKS));
    let mut workers = Vec::with_capacity(TASKS);
    for task_idx in 0..TASKS {
        let stats_for_task = stats.clone();
        let barrier_for_task = barrier.clone();
        workers.push(tokio::spawn(async move {
            // Release all tasks together to maximize contention on the gauges.
            barrier_for_task.wait().await;
            for iter in 0..ITERATIONS_PER_TASK {
                if (task_idx + iter) % 2 == 0 {
                    let _lease = stats_for_task.acquire_direct_connection_lease();
                    // Yield while holding the lease so acquisitions interleave.
                    tokio::task::yield_now().await;
                } else {
                    let _lease = stats_for_task.acquire_me_connection_lease();
                    tokio::task::yield_now().await;
                }
            }
        }));
    }
    for worker in workers {
        worker
            .await
            .expect("lease churn worker must not panic");
    }
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "direct route gauge must return to zero after concurrent lease churn"
    );
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "middle route gauge must return to zero after concurrent lease churn"
    );
}
#[tokio::test]
async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() {
    // 64 long-sleeping tasks hold leases (half direct, half middle); aborting
    // them all mid-sleep must still drain both gauges via the lease Drop impls.
    const TASKS: usize = 64;
    let stats = Arc::new(Stats::new());
    let mut workers = Vec::with_capacity(TASKS);
    for task_idx in 0..TASKS {
        let stats_for_task = stats.clone();
        workers.push(tokio::spawn(async move {
            if task_idx % 2 == 0 {
                let _lease = stats_for_task.acquire_direct_connection_lease();
                // Sleep far longer than the test runs; abort interrupts this.
                tokio::time::sleep(Duration::from_secs(60)).await;
            } else {
                let _lease = stats_for_task.acquire_me_connection_lease();
                tokio::time::sleep(Duration::from_secs(60)).await;
            }
        }));
    }
    // Bounded wait until every task has acquired its lease before aborting.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            let total = stats.get_current_connections_direct() + stats.get_current_connections_me();
            if total == TASKS as u64 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("all storm tasks must acquire route leases before abort");
    for worker in &workers {
        worker.abort();
    }
    for worker in workers {
        let joined = worker.await;
        assert!(joined.is_err(), "aborted worker must return join error");
    }
    // Abort drops each task's future, which drops its lease; poll (bounded)
    // until both gauges drain back to zero.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .expect("all route gauges must drain to zero after abort storm");
}
#[test]
fn saturating_route_decrements_do_not_underflow_under_race() {
    // Hammer both gauges with decrements from many threads while they sit at
    // zero; saturating semantics must keep them pinned at zero, never wrap.
    const THREADS: usize = 16;
    const DECREMENTS_PER_THREAD: usize = 4096;
    let stats = Arc::new(Stats::new());
    let handles: Vec<_> = (0..THREADS)
        .map(|_| {
            let stats = Arc::clone(&stats);
            std::thread::spawn(move || {
                for _ in 0..DECREMENTS_PER_THREAD {
                    stats.decrement_current_connections_direct();
                    stats.decrement_current_connections_me();
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("decrement race worker must not panic");
    }
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "direct route decrement races must never underflow"
    );
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "middle route decrement races must never underflow"
    );
}
#[tokio::test]
async fn direct_connection_lease_balances_on_task_abort() {
    // Aborting a task mid-sleep must drop its lease and release the gauge.
    let stats = Arc::new(Stats::new());
    let stats_for_task = stats.clone();
    let task = tokio::spawn(async move {
        let _lease = stats_for_task.acquire_direct_connection_lease();
        tokio::time::sleep(Duration::from_secs(60)).await;
    });
    // Give the spawned task time to run and acquire the lease.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(stats.get_current_connections_direct(), 1);
    task.abort();
    let joined = task.await;
    assert!(joined.is_err(), "aborted task must return a join error");
    // Allow the runtime to finish dropping the aborted task's state.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(
        stats.get_current_connections_direct(),
        0,
        "aborted task must release direct route gauge"
    );
}
#[tokio::test]
async fn middle_connection_lease_balances_on_task_abort() {
    // Same abort guarantee as the direct-route variant, for the middle gauge.
    let stats = Arc::new(Stats::new());
    let stats_for_task = stats.clone();
    let task = tokio::spawn(async move {
        let _lease = stats_for_task.acquire_me_connection_lease();
        tokio::time::sleep(Duration::from_secs(60)).await;
    });
    // Give the spawned task time to run and acquire the lease.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(stats.get_current_connections_me(), 1);
    task.abort();
    let joined = task.await;
    assert!(joined.is_err(), "aborted task must return a join error");
    // Allow the runtime to finish dropping the aborted task's state.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(
        stats.get_current_connections_me(),
        0,
        "aborted task must release middle route gauge"
    );
}

View File

@ -6,6 +6,7 @@ pub mod beobachten;
pub mod telemetry;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use parking_lot::Mutex;
@ -19,6 +20,46 @@ use tracing::debug;
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
use self::telemetry::TelemetryPolicy;
#[derive(Clone, Copy)]
enum RouteConnectionGauge {
Direct,
Middle,
}
#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"]
pub struct RouteConnectionLease {
stats: Arc<Stats>,
gauge: RouteConnectionGauge,
active: bool,
}
impl RouteConnectionLease {
fn new(stats: Arc<Stats>, gauge: RouteConnectionGauge) -> Self {
Self {
stats,
gauge,
active: true,
}
}
#[cfg(test)]
fn disarm(&mut self) {
self.active = false;
}
}
impl Drop for RouteConnectionLease {
fn drop(&mut self) {
if !self.active {
return;
}
match self.gauge {
RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(),
RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(),
}
}
}
// ============= Stats =============
#[derive(Default)]
@ -285,6 +326,16 @@ impl Stats {
pub fn decrement_current_connections_me(&self) {
Self::decrement_atomic_saturating(&self.current_connections_me);
}
pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
self.increment_current_connections_direct();
RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct)
}
pub fn acquire_me_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
self.increment_current_connections_me();
RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle)
}
pub fn increment_handshake_timeouts(&self) {
if self.telemetry_core_enabled() {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
@ -1457,9 +1508,11 @@ impl Stats {
// ============= Replay Checker =============
pub struct ReplayChecker {
shards: Vec<Mutex<ReplayShard>>,
handshake_shards: Vec<Mutex<ReplayShard>>,
tls_shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize,
window: Duration,
tls_window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
@ -1536,19 +1589,24 @@ impl ReplayShard {
impl ReplayChecker {
pub fn new(total_capacity: usize, window: Duration) -> Self {
const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120);
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
let mut handshake_shards = Vec::with_capacity(num_shards);
let mut tls_shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(ReplayShard::new(cap)));
handshake_shards.push(Mutex::new(ReplayShard::new(cap)));
tls_shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self {
shards,
handshake_shards,
tls_shards,
shard_mask: num_shards - 1,
window,
tls_window: window.max(MIN_TLS_REPLAY_WINDOW),
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
@ -1562,46 +1620,60 @@ impl ReplayChecker {
(hasher.finish() as usize) & self.shard_mask
}
fn check_and_add_internal(&self, data: &[u8]) -> bool {
fn check_and_add_internal(
&self,
data: &[u8],
shards: &[Mutex<ReplayShard>],
window: Duration,
) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
let mut shard = shards[idx].lock();
let now = Instant::now();
let found = shard.check(data, now, self.window);
let found = shard.check(data, now, window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
} else {
shard.add(data, now, self.window);
shard.add(data, now, window);
self.additions.fetch_add(1, Ordering::Relaxed);
}
found
}
fn add_only(&self, data: &[u8]) {
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window);
let mut shard = shards[idx].lock();
shard.add(data, Instant::now(), window);
}
pub fn check_and_add_handshake(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
self.check_and_add_internal(data, &self.handshake_shards, self.window)
}
pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_internal(data)
self.check_and_add_internal(data, &self.tls_shards, self.tls_window)
}
// Compatibility helpers (non-atomic split operations) — prefer check_and_add_*.
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) }
pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) }
pub fn add_handshake(&self, data: &[u8]) {
self.add_only(data, &self.handshake_shards, self.window)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) }
pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) }
pub fn add_tls_digest(&self, data: &[u8]) {
self.add_only(data, &self.tls_shards, self.tls_window)
}
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.shards {
for shard in &self.handshake_shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
for shard in &self.tls_shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
@ -1614,7 +1686,7 @@ impl ReplayChecker {
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.shards.len(),
num_shards: self.handshake_shards.len() + self.tls_shards.len(),
window_secs: self.window.as_secs(),
}
}
@ -1632,13 +1704,20 @@ impl ReplayChecker {
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.shards {
for shard_mutex in &self.handshake_shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
for shard_mutex in &self.tls_shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.tls_window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
self.cleanups.fetch_add(1, Ordering::Relaxed);
@ -1764,7 +1843,7 @@ mod tests {
fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(10_000, Duration::from_secs(60));
for i in 0..500u32 {
checker.add_only(&i.to_le_bytes());
checker.add_handshake(&i.to_le_bytes());
}
for i in 0..500u32 {
assert!(checker.check_handshake(&i.to_le_bytes()));
@ -1772,3 +1851,11 @@ mod tests {
assert_eq!(checker.stats().total_entries, 500);
}
}
#[cfg(test)]
#[path = "connection_lease_security_tests.rs"]
mod connection_lease_security_tests;
#[cfg(test)]
#[path = "replay_checker_security_tests.rs"]
mod replay_checker_security_tests;

View File

@ -0,0 +1,80 @@
use super::*;
use std::time::Duration;
#[test]
fn replay_checker_keeps_tls_and_handshake_domains_isolated_for_same_key() {
    // The same byte string must be tracked independently per domain: a
    // handshake sighting must not mark the TLS domain, and vice versa.
    let checker = ReplayChecker::new(128, Duration::from_millis(20));
    let key = b"same-key-domain-separation";
    // First sighting in each domain is fresh (returns false = no replay).
    assert!(
        !checker.check_and_add_handshake(key),
        "first handshake use should be fresh"
    );
    assert!(
        !checker.check_and_add_tls_digest(key),
        "same bytes in TLS domain should still be fresh"
    );
    // Second sighting in each domain is an independent replay hit.
    assert!(
        checker.check_and_add_handshake(key),
        "second handshake use should be replay-hit"
    );
    assert!(
        checker.check_and_add_tls_digest(key),
        "second TLS use should be replay-hit independently"
    );
}
#[test]
fn replay_checker_tls_window_is_clamped_beyond_small_handshake_window() {
    // With a 20ms configured window, handshake entries expire quickly, but
    // the TLS-domain window is clamped to a secure minimum and must persist.
    let checker = ReplayChecker::new(128, Duration::from_millis(20));
    let handshake_key = b"short-window-handshake";
    let tls_key = b"short-window-tls";
    assert!(!checker.check_and_add_handshake(handshake_key));
    assert!(!checker.check_and_add_tls_digest(tls_key));
    // Sleep past the configured 20ms window but far under the TLS minimum.
    std::thread::sleep(Duration::from_millis(80));
    assert!(
        !checker.check_and_add_handshake(handshake_key),
        "handshake key should expire under short configured window"
    );
    assert!(
        checker.check_and_add_tls_digest(tls_key),
        "TLS key should still be replay-hit because TLS window is clamped to a secure minimum"
    );
}
#[test]
fn replay_checker_compat_add_paths_do_not_cross_pollute_domains() {
    // The split add_* compatibility helpers must write only into their own
    // domain's shards, never into the other domain.
    let checker = ReplayChecker::new(128, Duration::from_secs(1));
    let key = b"compat-domain-separation";
    checker.add_handshake(key);
    assert!(
        checker.check_and_add_handshake(key),
        "handshake add helper must populate handshake domain"
    );
    assert!(
        !checker.check_and_add_tls_digest(key),
        "handshake add helper must not pollute TLS domain"
    );
    checker.add_tls_digest(key);
    assert!(
        checker.check_and_add_tls_digest(key),
        "TLS add helper must populate TLS domain"
    );
}
#[test]
fn replay_checker_stats_reflect_dual_shard_domains() {
    // Stats must count shards from both domains (handshake + TLS) together.
    let checker = ReplayChecker::new(128, Duration::from_secs(1));
    assert_eq!(
        checker.stats().num_shards,
        128,
        "stats should expose both shard domains (handshake + TLS)"
    );
}

View File

@ -117,15 +117,6 @@ pub fn build_emulated_server_hello(
extensions.extend_from_slice(&0x002bu16.to_be_bytes());
extensions.extend_from_slice(&(2u16).to_be_bytes());
extensions.extend_from_slice(&0x0304u16.to_be_bytes());
if let Some(alpn_proto) = &alpn {
extensions.extend_from_slice(&0x0010u16.to_be_bytes());
let list_len: u16 = 1 + alpn_proto.len() as u16;
let ext_len: u16 = 2 + list_len;
extensions.extend_from_slice(&ext_len.to_be_bytes());
extensions.extend_from_slice(&list_len.to_be_bytes());
extensions.push(alpn_proto.len() as u8);
extensions.extend_from_slice(alpn_proto);
}
let extensions_len = extensions.len() as u16;
let body_len = 2 + // version
@ -207,8 +198,22 @@ pub fn build_emulated_server_hello(
}
let mut app_data = Vec::new();
let alpn_marker = alpn
.as_ref()
.filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize)
.map(|proto| {
let proto_list_len = 1usize + proto.len();
let ext_data_len = 2usize + proto_list_len;
let mut marker = Vec::with_capacity(4 + ext_data_len);
marker.extend_from_slice(&0x0010u16.to_be_bytes());
marker.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
marker.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
marker.push(proto.len() as u8);
marker.extend_from_slice(proto);
marker
});
let mut payload_offset = 0usize;
for size in sizes {
for (idx, size) in sizes.into_iter().enumerate() {
let mut rec = Vec::with_capacity(5 + size);
rec.push(TLS_RECORD_APPLICATION);
rec.extend_from_slice(&TLS_VERSION);
@ -233,7 +238,20 @@ pub fn build_emulated_server_hello(
}
} else if size > 17 {
let body_len = size - 17;
rec.extend_from_slice(&rng.bytes(body_len));
let mut body = Vec::with_capacity(body_len);
if idx == 0 && let Some(marker) = &alpn_marker {
if marker.len() <= body_len {
body.extend_from_slice(marker);
if body_len > marker.len() {
body.extend_from_slice(&rng.bytes(body_len - marker.len()));
}
} else {
body.extend_from_slice(&rng.bytes(body_len));
}
} else {
body.extend_from_slice(&rng.bytes(body_len));
}
rec.extend_from_slice(&body);
rec.push(0x16); // inner content type marker (handshake)
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag
} else {
@ -245,8 +263,9 @@ pub fn build_emulated_server_hello(
// --- Combine ---
// Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
let mut tickets = Vec::new();
if new_session_tickets > 0 {
for _ in 0..new_session_tickets {
let ticket_count = new_session_tickets.min(4);
if ticket_count > 0 {
for _ in 0..ticket_count {
let ticket_len: usize = rng.range(48) + 48;
let mut rec = Vec::with_capacity(5 + ticket_len);
rec.push(TLS_RECORD_APPLICATION);
@ -273,6 +292,10 @@ pub fn build_emulated_server_hello(
response
}
#[cfg(test)]
#[path = "emulator_security_tests.rs"]
mod security_tests;
#[cfg(test)]
mod tests {
use std::time::SystemTime;

View File

@ -0,0 +1,136 @@
use std::time::SystemTime;
use crate::crypto::SecureRandom;
use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE};
use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
};
/// Builds a minimal `CachedTlsData` fixture for emulator tests: a ServerHello
/// template with wire version 0x0303, a single 64-byte application-data
/// record profile, and an optional certificate payload supplied by the caller.
fn make_cached(cert_payload: Option<crate::tls_front::types::TlsCertPayload>) -> CachedTlsData {
    CachedTlsData {
        server_hello_template: ParsedServerHello {
            version: [0x03, 0x03],
            random: [0u8; 32],
            session_id: Vec::new(),
            // 0x1301 = TLS_AES_128_GCM_SHA256
            cipher_suite: [0x13, 0x01],
            compression: 0,
            extensions: Vec::new(),
        },
        cert_info: None,
        cert_payload,
        app_data_records_sizes: vec![64],
        total_app_data_len: 64,
        behavior_profile: TlsBehaviorProfile {
            change_cipher_spec_count: 1,
            app_data_record_sizes: vec![64],
            ticket_record_sizes: Vec::new(),
            source: TlsProfileSource::Default,
        },
        fetched_at: SystemTime::now(),
        domain: "example.com".to_string(),
    }
}
/// Returns the body of the first ApplicationData record in `response`.
///
/// Assumes the layout produced by `build_emulated_server_hello`: a ServerHello
/// record, then a ChangeCipherSpec record, then application-data records.
/// Panics (slice index) if `response` is shorter than that layout implies.
fn first_app_data_payload(response: &[u8]) -> &[u8] {
    // Each TLS record: 1-byte type, 2-byte version, 2-byte big-endian length, body.
    let record_len = |start: usize| -> usize {
        u16::from_be_bytes([response[start + 3], response[start + 4]]) as usize
    };
    // Skip the ServerHello record, then the ChangeCipherSpec record.
    let ccs_start = 5 + record_len(0);
    let app_start = ccs_start + 5 + record_len(ccs_start);
    let body_len = record_len(app_start);
    &response[app_start + 5..app_start + 5 + body_len]
}
/// An ALPN protocol longer than 255 bytes cannot be length-prefixed with a
/// single byte, so the emulator must drop it entirely rather than embed a
/// truncated marker in the first application-data record.
#[test]
fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
    let cached = make_cached(None);
    let rng = SecureRandom::new();
    // 256 bytes: one past the u8 length-prefix limit.
    let oversized_alpn = vec![0xAB; u8::MAX as usize + 1];
    let response = build_emulated_server_hello(
        b"secret",
        &[0x11; 32],
        &[0x22; 16],
        &cached,
        true,
        &rng,
        Some(oversized_alpn),
        0,
    );
    assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
    let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_start = 5 + hello_len;
    assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
    let app_start = ccs_start + 6;
    assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
    let payload = first_app_data_payload(&response);
    // The prefix a (truncated) marker for the oversized protocol would have had.
    let mut marker_prefix = Vec::new();
    marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
    marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
    marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
    marker_prefix.push(0xff);
    marker_prefix.extend_from_slice(&[0xab; 8]);
    assert!(
        !payload.starts_with(&marker_prefix),
        "oversized ALPN must not be partially embedded into the emulated first application record"
    );
}
/// With a short protocol ("h2") and the fixture's 64-byte first record there
/// is room for the complete ALPN marker, so the payload must start with it.
#[test]
fn emulated_server_hello_embeds_full_alpn_marker_when_body_can_fit() {
    let cached = make_cached(None);
    let rng = SecureRandom::new();
    let response = build_emulated_server_hello(
        b"secret",
        &[0x31; 32],
        &[0x41; 16],
        &cached,
        true,
        &rng,
        Some(b"h2".to_vec()),
        0,
    );
    let payload = first_app_data_payload(&response);
    // 0x0010 = ALPN ext type, ext len 5, proto list len 3, proto len 2, "h2".
    let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
    assert!(
        payload.starts_with(&expected),
        "when body has enough capacity, emulated first application record must include full ALPN marker"
    );
}
/// When a cached certificate payload exists, the first application-data
/// record must carry it verbatim; the ALPN marker must not displace it.
#[test]
fn emulated_server_hello_prefers_cert_payload_over_alpn_marker() {
    // Handshake-type 0x0b (Certificate) message used as the cached payload.
    let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
    let cached = make_cached(Some(TlsCertPayload {
        cert_chain_der: vec![vec![0x30, 0x01, 0x00]],
        certificate_message: cert_msg.clone(),
    }));
    let rng = SecureRandom::new();
    let response = build_emulated_server_hello(
        b"secret",
        &[0x32; 32],
        &[0x42; 16],
        &cached,
        true,
        &rng,
        Some(b"h2".to_vec()),
        0,
    );
    let payload = first_app_data_payload(&response);
    let alpn_marker = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
    assert!(
        payload.starts_with(&cert_msg),
        "when certificate payload is available, first record must start with cert payload bytes"
    );
    assert!(
        !payload.starts_with(&alpn_marker),
        "ALPN marker must not displace selected certificate payload"
    );
}

View File

@ -25,6 +25,9 @@ const HEALTH_RECONNECT_BUDGET_PER_CORE: usize = 2;
const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1;
const HEALTH_RECONNECT_BUDGET_MIN: usize = 4;
const HEALTH_RECONNECT_BUDGET_MAX: usize = 128;
const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16;
const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16;
const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256;
#[derive(Debug, Clone)]
struct DcFloorPlanEntry {
@ -111,106 +114,75 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
}
}
async fn reap_draining_writers(
pub(super) async fn reap_draining_writers(
pool: &Arc<MePool>,
warn_next_allowed: &mut HashMap<u64, Instant>,
) {
if pool.draining_active_runtime() == 0 {
return;
}
let now_epoch_secs = MePool::now_epoch_secs();
let now = Instant::now();
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
let drain_threshold = pool
.me_pool_drain_threshold
.load(std::sync::atomic::Ordering::Relaxed);
let mut draining_writers = {
let writers = pool.writers.read().await;
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
for writer in writers.iter() {
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
continue;
}
draining_writers.push(DrainingWriterSnapshot {
id: writer.id,
writer_dc: writer.writer_dc,
addr: writer.addr,
generation: writer.generation,
created_at: writer.created_at,
draining_started_at_epoch_secs: writer
.draining_started_at_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed),
drain_deadline_epoch_secs: writer
.drain_deadline_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed),
allow_drain_fallback: writer
.allow_drain_fallback
.load(std::sync::atomic::Ordering::Relaxed),
});
let activity = pool.registry.writer_activity_snapshot().await;
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
let mut empty_writer_ids = Vec::<u64>::new();
let mut force_close_writer_ids = Vec::<u64>::new();
let writers = pool.writers.read().await;
for writer in writers.iter() {
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
continue;
}
draining_writers
};
if draining_writers.is_empty() {
return;
}
let draining_ids: Vec<u64> = draining_writers.iter().map(|writer| writer.id).collect();
let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await;
let mut non_empty_draining_writers =
Vec::<DrainingWriterSnapshot>::with_capacity(draining_writers.len());
for writer in draining_writers.drain(..) {
if non_empty_writer_ids.contains(&writer.id) {
non_empty_draining_writers.push(writer);
} else {
pool.remove_writer_and_close_clients(writer.id).await;
if activity
.bound_clients_by_writer
.get(&writer.id)
.copied()
.unwrap_or(0)
== 0
{
empty_writer_ids.push(writer.id);
continue;
}
draining_writers.push(DrainingWriterSnapshot {
id: writer.id,
writer_dc: writer.writer_dc,
addr: writer.addr,
generation: writer.generation,
created_at: writer.created_at,
draining_started_at_epoch_secs: writer
.draining_started_at_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed),
drain_deadline_epoch_secs: writer
.drain_deadline_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed),
allow_drain_fallback: writer
.allow_drain_fallback
.load(std::sync::atomic::Ordering::Relaxed),
});
}
draining_writers = non_empty_draining_writers;
if draining_writers.is_empty() {
return;
}
drop(writers);
let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
draining_writers.len().saturating_sub(drain_threshold as usize)
} else {
0
};
let has_deadline_expired = draining_writers.iter().any(|writer| {
writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
});
let can_drop_with_replacement = if overflow > 0 || has_deadline_expired {
pool.has_non_draining_writer_per_desired_dc_group().await
} else {
false
};
if overflow > 0 {
if can_drop_with_replacement {
draining_writers.sort_by(|left, right| {
left.draining_started_at_epoch_secs
.cmp(&right.draining_started_at_epoch_secs)
.then_with(|| left.created_at.cmp(&right.created_at))
.then_with(|| left.id.cmp(&right.id))
});
warn!(
draining_writers = draining_writers.len(),
me_pool_drain_threshold = drain_threshold,
removing_writers = overflow,
"ME draining writer threshold exceeded, force-closing oldest draining writers"
);
for writer in draining_writers.drain(..overflow) {
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer.id).await;
}
} else {
warn!(
draining_writers = draining_writers.len(),
me_pool_drain_threshold = drain_threshold,
overflow,
"ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers"
);
draining_writers.sort_by(|left, right| {
left.draining_started_at_epoch_secs
.cmp(&right.draining_started_at_epoch_secs)
.then_with(|| left.created_at.cmp(&right.created_at))
.then_with(|| left.id.cmp(&right.id))
});
warn!(
draining_writers = draining_writers.len(),
me_pool_drain_threshold = drain_threshold,
removing_writers = overflow,
"ME draining writer threshold exceeded, force-closing oldest draining writers"
);
for writer in draining_writers.drain(..overflow) {
force_close_writer_ids.push(writer.id);
}
}
@ -238,25 +210,71 @@ async fn reap_draining_writers(
}
if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
{
if can_drop_with_replacement {
warn!(writer_id = writer.id, "Drain timeout, force-closing");
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer.id).await;
} else if should_emit_writer_warn(
warn_next_allowed,
writer.id,
now,
pool.warn_rate_limit_duration(),
) {
warn!(
writer_id = writer.id,
writer_dc = writer.writer_dc,
endpoint = %writer.addr,
"Drain timeout reached, but replacement coverage is incomplete; keeping draining writer"
);
}
warn!(writer_id = writer.id, "Drain timeout, force-closing");
force_close_writer_ids.push(writer.id);
}
}
let close_budget = health_drain_close_budget();
let requested_force_close = force_close_writer_ids.len();
let requested_empty_close = empty_writer_ids.len();
let requested_close_total = requested_force_close.saturating_add(requested_empty_close);
let mut closed_writer_ids = HashSet::<u64>::new();
let mut closed_total = 0usize;
for writer_id in force_close_writer_ids {
if closed_total >= close_budget {
break;
}
if !closed_writer_ids.insert(writer_id) {
continue;
}
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer_id).await;
closed_total = closed_total.saturating_add(1);
}
for writer_id in empty_writer_ids {
if closed_total >= close_budget {
break;
}
if !closed_writer_ids.insert(writer_id) {
continue;
}
if !pool.remove_writer_if_empty(writer_id).await {
continue;
}
closed_total = closed_total.saturating_add(1);
}
let pending_close_total = requested_close_total.saturating_sub(closed_total);
if pending_close_total > 0 {
warn!(
close_budget,
closed_total,
pending_close_total,
"ME draining close backlog deferred to next health cycle"
);
}
// Keep warn cooldown state for draining writers still present in the pool;
// drop state only once a writer is actually removed.
let active_draining_writer_ids = {
let writers = pool.writers.read().await;
writers
.iter()
.filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed))
.map(|writer| writer.id)
.collect::<HashSet<u64>>()
};
warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id));
}
/// Per-health-cycle budget for closing draining writers.
///
/// Scales with the detected core count (falling back to a single core when
/// detection fails) and is clamped to the configured min/max bounds.
pub(super) fn health_drain_close_budget() -> usize {
    let cores = match std::thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => 1,
    };
    cores
        .saturating_mul(HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE)
        .clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX)
}
#[derive(Debug, Clone)]
@ -1521,7 +1539,6 @@ mod tests {
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
pool.increment_draining_active_runtime();
assert!(
pool.registry
.bind_writer(
@ -1570,7 +1587,6 @@ mod tests {
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
let pool = make_pool(2).await;
insert_live_writer(&pool, 1, 2).await;
assert!(pool.has_non_draining_writer_per_desired_dc_group().await);
let now_epoch_secs = MePool::now_epoch_secs();
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await;
@ -1588,7 +1604,7 @@ mod tests {
}
#[tokio::test]
async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() {
async fn reap_draining_writers_force_closes_overflow_without_replacement() {
let pool = make_pool(2).await;
let now_epoch_secs = MePool::now_epoch_secs();
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
@ -1600,8 +1616,8 @@ mod tests {
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
writer_ids.sort_unstable();
assert_eq!(writer_ids, vec![10, 20, 30]);
assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10);
assert_eq!(writer_ids, vec![20, 30]);
assert!(pool.registry.get_writer(conn_a).await.is_none());
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
}

View File

@ -0,0 +1,615 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::{health_drain_close_budget, reap_draining_writers};
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use super::me_health_monitor;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
/// Builds a `MePool` plus its RNG for health-monitor tests.
///
/// Only the three parameters under test are taken explicitly; every other
/// knob comes from `GeneralConfig::default()`. The long positional argument
/// list below mirrors the `MePool::new` constructor signature — keep the
/// order in sync when the constructor changes.
async fn make_pool(
    me_pool_drain_threshold: u64,
    me_health_interval_ms_unhealthy: u64,
    me_health_interval_ms_healthy: u64,
) -> (Arc<MePool>, Arc<SecureRandom>) {
    // Override just the fields relevant to draining/health behavior.
    let general = GeneralConfig {
        me_pool_drain_threshold,
        me_health_interval_ms_unhealthy,
        me_health_interval_ms_healthy,
        ..GeneralConfig::default()
    };
    let rng = Arc::new(SecureRandom::new());
    let pool = MePool::new(
        None,
        vec![1u8; 32],
        None,
        false,
        None,
        Vec::new(),
        1,
        None,
        12,
        1200,
        HashMap::new(),
        HashMap::new(),
        None,
        NetworkDecision::default(),
        None,
        rng.clone(),
        Arc::new(Stats::default()),
        general.me_keepalive_enabled,
        general.me_keepalive_interval_secs,
        general.me_keepalive_jitter_secs,
        general.me_keepalive_payload_random,
        general.rpc_proxy_req_every,
        general.me_warmup_stagger_enabled,
        general.me_warmup_step_delay_ms,
        general.me_warmup_step_jitter_ms,
        general.me_reconnect_max_concurrent_per_dc,
        general.me_reconnect_backoff_base_ms,
        general.me_reconnect_backoff_cap_ms,
        general.me_reconnect_fast_retry_count,
        general.me_single_endpoint_shadow_writers,
        general.me_single_endpoint_outage_mode_enabled,
        general.me_single_endpoint_outage_disable_quarantine,
        general.me_single_endpoint_outage_backoff_min_ms,
        general.me_single_endpoint_outage_backoff_max_ms,
        general.me_single_endpoint_shadow_rotate_every_secs,
        general.me_floor_mode,
        general.me_adaptive_floor_idle_secs,
        general.me_adaptive_floor_min_writers_single_endpoint,
        general.me_adaptive_floor_min_writers_multi_endpoint,
        general.me_adaptive_floor_recover_grace_secs,
        general.me_adaptive_floor_writers_per_core_total,
        general.me_adaptive_floor_cpu_cores_override,
        general.me_adaptive_floor_max_extra_writers_single_per_core,
        general.me_adaptive_floor_max_extra_writers_multi_per_core,
        general.me_adaptive_floor_max_active_writers_per_core,
        general.me_adaptive_floor_max_warm_writers_per_core,
        general.me_adaptive_floor_max_active_writers_global,
        general.me_adaptive_floor_max_warm_writers_global,
        general.hardswap,
        general.me_pool_drain_ttl_secs,
        general.me_pool_drain_threshold,
        general.effective_me_pool_force_close_secs(),
        general.me_pool_min_fresh_ratio,
        general.me_hardswap_warmup_delay_min_ms,
        general.me_hardswap_warmup_delay_max_ms,
        general.me_hardswap_warmup_extra_passes,
        general.me_hardswap_warmup_pass_backoff_base_ms,
        general.me_bind_stale_mode,
        general.me_bind_stale_ttl_secs,
        general.me_secret_atomic_snapshot,
        general.me_deterministic_writer_sort,
        MeWriterPickMode::default(),
        general.me_writer_pick_sample_size,
        MeSocksKdfPolicy::default(),
        general.me_writer_cmd_channel_capacity,
        general.me_route_channel_capacity,
        general.me_route_backpressure_base_timeout_ms,
        general.me_route_backpressure_high_timeout_ms,
        general.me_route_backpressure_high_watermark_pct,
        general.me_reader_route_data_wait_ms,
        general.me_health_interval_ms_unhealthy,
        general.me_health_interval_ms_healthy,
        general.me_warn_rate_limit_ms,
        MeRouteNoWriterMode::default(),
        general.me_route_no_writer_wait_ms,
        general.me_route_inline_recovery_attempts,
        general.me_route_inline_recovery_wait_ms,
    );
    (pool, rng)
}
/// Inserts a writer already in the draining state, registers it with the
/// registry, and binds `bound_clients` fake client connections to it.
///
/// A `drain_deadline_epoch_secs` of 0 means "no deadline".
async fn insert_draining_writer(
    pool: &Arc<MePool>,
    writer_id: u64,
    drain_started_at_epoch_secs: u64,
    bound_clients: usize,
    drain_deadline_epoch_secs: u64,
) {
    let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
    let writer = MeWriter {
        id: writer_id,
        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6000 + writer_id as u16),
        source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
        writer_dc: 2,
        generation: 1,
        contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
        // Larger ids map to older creation times, giving a deterministic age order.
        created_at: Instant::now() - Duration::from_secs(writer_id),
        tx: tx.clone(),
        cancel: CancellationToken::new(),
        degraded: Arc::new(AtomicBool::new(false)),
        rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
        draining: Arc::new(AtomicBool::new(true)),
        draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
        drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
        allow_drain_fallback: Arc::new(AtomicBool::new(false)),
    };
    pool.writers.write().await.push(writer);
    pool.registry.register_writer(writer_id, tx).await;
    pool.conn_count.fetch_add(1, Ordering::Relaxed);
    // Bind the requested number of synthetic clients so the writer is non-empty.
    for idx in 0..bound_clients {
        let (conn_id, _rx) = pool.registry.register().await;
        assert!(
            pool.registry
                .bind_writer(
                    conn_id,
                    writer_id,
                    ConnMeta {
                        target_dc: 2,
                        client_addr: SocketAddr::new(
                            IpAddr::V4(Ipv4Addr::LOCALHOST),
                            8000 + idx as u16,
                        ),
                        our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
                        proto_flags: 0,
                    },
                )
                .await
        );
    }
}
/// Number of writers currently tracked by the pool.
async fn writer_count(pool: &Arc<MePool>) -> usize {
    let writers = pool.writers.read().await;
    writers.len()
}
/// Ascending list of all writer ids currently in the pool.
async fn sorted_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
    let guard = pool.writers.read().await;
    let mut ids: Vec<u64> = guard.iter().map(|w| w.id).collect();
    drop(guard);
    ids.sort_unstable();
    ids
}
/// Advances a deterministic 64-bit linear congruential generator and
/// returns the new state. Used for reproducible pseudo-random churn in tests.
fn lcg_next(state: &mut u64) -> u64 {
    let next = state.wrapping_mul(6364136223846793005).wrapping_add(1);
    *state = next;
    next
}
async fn draining_writer_ids(pool: &Arc<MePool>) -> HashSet<u64> {
pool.writers
.read()
.await
.iter()
.filter(|writer| writer.draining.load(Ordering::Relaxed))
.map(|writer| writer.id)
.collect::<HashSet<u64>>()
}
/// Overwrites the draining flag and drain timestamps of the writer with the
/// given id; silently does nothing when no such writer exists.
async fn set_writer_runtime_state(
    pool: &Arc<MePool>,
    writer_id: u64,
    draining: bool,
    drain_started_at_epoch_secs: u64,
    drain_deadline_epoch_secs: u64,
) {
    let writers = pool.writers.read().await;
    let Some(writer) = writers.iter().find(|w| w.id == writer_id) else {
        return;
    };
    writer.draining.store(draining, Ordering::Relaxed);
    writer
        .draining_started_at_epoch_secs
        .store(drain_started_at_epoch_secs, Ordering::Relaxed);
    writer
        .drain_deadline_epoch_secs
        .store(drain_deadline_epoch_secs, Ordering::Relaxed);
}
/// With no writers in the pool, a reap pass must drop all stale
/// warn-cooldown entries instead of leaking them forever.
#[tokio::test]
async fn reap_draining_writers_clears_warn_state_when_pool_empty() {
    let (pool, _rng) = make_pool(128, 1, 1).await;
    let mut warn_next_allowed = HashMap::new();
    warn_next_allowed.insert(11, Instant::now() + Duration::from_secs(5));
    warn_next_allowed.insert(22, Instant::now() + Duration::from_secs(5));
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(warn_next_allowed.is_empty());
}
/// Repeated reap cycles must shrink an overflowing draining population down
/// to exactly the configured threshold, keeping the newest-started writers.
#[tokio::test]
async fn reap_draining_writers_respects_threshold_across_multiple_overflow_cycles() {
    let threshold = 3u64;
    let (pool, _rng) = make_pool(threshold, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    // 60 draining writers with strictly increasing drain-start times.
    for writer_id in 1..=60u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(600).saturating_add(writer_id),
            1,
            0,
        )
        .await;
    }
    let mut warn_next_allowed = HashMap::new();
    for _ in 0..64 {
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        if writer_count(&pool).await <= threshold as usize {
            break;
        }
    }
    assert_eq!(writer_count(&pool).await, threshold as usize);
    // Oldest-started writers are removed first, so the newest ids survive.
    assert_eq!(sorted_writer_ids(&pool).await, vec![58, 59, 60]);
}
/// A backlog of empty (zero-client) draining writers several times larger
/// than one close budget must still be fully removed over repeated cycles.
#[tokio::test]
async fn reap_draining_writers_handles_large_empty_writer_population() {
    let (pool, _rng) = make_pool(128, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let total = health_drain_close_budget().saturating_mul(3).saturating_add(27);
    for writer_id in 1..=total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(120),
            0,
            0,
        )
        .await;
    }
    let mut warn_next_allowed = HashMap::new();
    for _ in 0..24 {
        if writer_count(&pool).await == 0 {
            break;
        }
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
    }
    assert_eq!(writer_count(&pool).await, 0);
}
/// When every writer's drain deadline has already expired, repeated budgeted
/// reap cycles must eventually empty the pool completely.
#[tokio::test]
async fn reap_draining_writers_processes_mass_deadline_expiry_without_unbounded_growth() {
    let (pool, _rng) = make_pool(128, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let total = health_drain_close_budget().saturating_mul(4).saturating_add(31);
    for writer_id in 1..=total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(180),
            1,
            // Deadline already one second in the past.
            now_epoch_secs.saturating_sub(1),
        )
        .await;
    }
    let mut warn_next_allowed = HashMap::new();
    for _ in 0..40 {
        if writer_count(&pool).await == 0 {
            break;
        }
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
    }
    assert_eq!(writer_count(&pool).await, 0);
}
/// Under waves of insertions and external removals, the warn-cooldown map
/// must never hold more entries than there are writers left in the pool.
#[tokio::test]
async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_churn() {
    let (pool, _rng) = make_pool(128, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let mut warn_next_allowed = HashMap::new();
    for wave in 0..40u64 {
        for offset in 0..8u64 {
            insert_draining_writer(
                &pool,
                wave * 100 + offset,
                now_epoch_secs.saturating_sub(400 + offset),
                1,
                0,
            )
            .await;
        }
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        assert!(warn_next_allowed.len() <= writer_count(&pool).await);
        // Remove a few writers out-of-band and re-check the invariant.
        let ids = sorted_writer_ids(&pool).await;
        for writer_id in ids.into_iter().take(3) {
            let _ = pool.remove_writer_and_close_clients(writer_id).await;
        }
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        assert!(warn_next_allowed.len() <= writer_count(&pool).await);
    }
}
/// Budgeted cleanup must be monotone: each reap cycle leaves the pool no
/// larger than the cycle before it.
#[tokio::test]
async fn reap_draining_writers_budgeted_cleanup_never_increases_pool_size() {
    let (pool, _rng) = make_pool(5, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    for writer_id in 1..=200u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(240).saturating_add(writer_id),
            1,
            0,
        )
        .await;
    }
    let mut warn_next_allowed = HashMap::new();
    let mut previous = writer_count(&pool).await;
    for _ in 0..32 {
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        let current = writer_count(&pool).await;
        assert!(current <= previous);
        previous = current;
    }
}
/// With the health monitor task running, live injection of new draining
/// writers must still converge to at most the configured threshold.
#[tokio::test]
async fn me_health_monitor_converges_to_threshold_under_live_injection_churn() {
    let threshold = 7u64;
    let (pool, rng) = make_pool(threshold, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    for writer_id in 1..=40u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
            1,
            0,
        )
        .await;
    }
    let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
    // Keep injecting while the monitor is actively reaping.
    for wave in 0..8u64 {
        for offset in 0..10u64 {
            insert_draining_writer(
                &pool,
                1000 + wave * 100 + offset,
                now_epoch_secs.saturating_sub(120).saturating_add(offset),
                1,
                0,
            )
            .await;
        }
        tokio::time::sleep(Duration::from_millis(5)).await;
    }
    tokio::time::sleep(Duration::from_millis(120)).await;
    monitor.abort();
    let _ = monitor.await;
    assert!(writer_count(&pool).await <= threshold as usize);
}
/// A storm of 220 writers with already-expired drain deadlines must be fully
/// drained by the running monitor within a short wall-clock window.
#[tokio::test]
async fn me_health_monitor_drains_deadline_storm_with_budgeted_progress() {
    let (pool, rng) = make_pool(128, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    for writer_id in 1..=220u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(120),
            1,
            now_epoch_secs.saturating_sub(1),
        )
        .await;
    }
    let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
    tokio::time::sleep(Duration::from_millis(120)).await;
    monitor.abort();
    let _ = monitor.await;
    assert_eq!(writer_count(&pool).await, 0);
}
/// A mixed backlog — some writers empty, some with expired deadlines, some
/// with neither — must be reduced to at most the threshold by the monitor.
#[tokio::test]
async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() {
    let threshold = 12u64;
    let (pool, rng) = make_pool(threshold, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    for writer_id in 1..=180u64 {
        // Every third writer is empty; every second has an expired deadline.
        let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
        let deadline = if writer_id % 2 == 0 {
            now_epoch_secs.saturating_sub(1)
        } else {
            0
        };
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(250).saturating_add(writer_id),
            bound_clients,
            deadline,
        )
        .await;
    }
    let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
    tokio::time::sleep(Duration::from_millis(140)).await;
    monitor.abort();
    let _ = monitor.await;
    assert!(writer_count(&pool).await <= threshold as usize);
}
/// Deterministic (LCG-seeded) churn test: for 90 rounds it removes, flips
/// and injects writers at random, checking after every reap that the warn
/// map only references live draining writers; it then verifies the pool
/// converges to the threshold.
#[tokio::test]
async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() {
    let threshold = 9u64;
    let (pool, _rng) = make_pool(threshold, 1, 1).await;
    let mut warn_next_allowed = HashMap::new();
    let mut seed = 0x9E37_79B9_7F4A_7C15u64;
    let mut next_writer_id = 20_000u64;
    let now_epoch_secs = MePool::now_epoch_secs();
    // Seed population: a mix of empty writers and expired deadlines.
    for writer_id in 1..=72u64 {
        let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 };
        let deadline = if writer_id % 5 == 0 {
            now_epoch_secs.saturating_sub(1)
        } else {
            0
        };
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(500).saturating_add(writer_id),
            bound_clients,
            deadline,
        )
        .await;
    }
    for _round in 0..90 {
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        let draining_ids = draining_writer_ids(&pool).await;
        assert!(
            warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
            "warn-state keys must always be a subset of live draining writers"
        );
        let writer_ids = sorted_writer_ids(&pool).await;
        if writer_ids.is_empty() {
            continue;
        }
        // Remove 0..=2 writers out-of-band.
        let remove_n = (lcg_next(&mut seed) % 3) as usize;
        for writer_id in writer_ids.iter().copied().take(remove_n) {
            let _ = pool.remove_writer_and_close_clients(writer_id).await;
        }
        // Flip one survivor to non-draining.
        let survivors = sorted_writer_ids(&pool).await;
        if !survivors.is_empty() {
            let idx = (lcg_next(&mut seed) as usize) % survivors.len();
            let target = survivors[idx];
            set_writer_runtime_state(&pool, target, false, 0, 0).await;
        }
        // Flip another survivor back to draining, sometimes with an expired deadline.
        let survivors = sorted_writer_ids(&pool).await;
        if survivors.len() > 1 {
            let idx = (lcg_next(&mut seed) as usize) % survivors.len();
            let target = survivors[idx];
            let expired_deadline = if lcg_next(&mut seed) & 1 == 0 {
                now_epoch_secs.saturating_sub(1)
            } else {
                0
            };
            set_writer_runtime_state(
                &pool,
                target,
                true,
                now_epoch_secs.saturating_sub(120),
                expired_deadline,
            )
            .await;
        }
        // Inject 0..=3 fresh draining writers.
        let inject_n = (lcg_next(&mut seed) % 4) as usize;
        for _ in 0..inject_n {
            let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 };
            let deadline = if lcg_next(&mut seed) & 1 == 0 {
                now_epoch_secs.saturating_sub(1)
            } else {
                0
            };
            insert_draining_writer(
                &pool,
                next_writer_id,
                now_epoch_secs.saturating_sub(240),
                bound_clients,
                deadline,
            )
            .await;
            next_writer_id = next_writer_id.saturating_add(1);
        }
    }
    // Settle: the pool must converge to the threshold with a clean warn map.
    for _ in 0..64 {
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        if writer_count(&pool).await <= threshold as usize {
            break;
        }
    }
    assert!(writer_count(&pool).await <= threshold as usize);
    let draining_ids = draining_writer_ids(&pool).await;
    assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id)));
}
/// Flipping writers in and out of the draining state across many rounds must
/// never leave warn-cooldown entries for writers outside the draining set.
#[tokio::test]
async fn reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() {
    let (pool, _rng) = make_pool(64, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    for writer_id in 1..=24u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(240),
            1,
            0,
        )
        .await;
    }
    let mut warn_next_allowed = HashMap::new();
    for _round in 0..48u64 {
        // Rotate which third of the writers is currently non-draining.
        for writer_id in 1..=24u64 {
            let draining = (writer_id + _round) % 3 != 0;
            set_writer_runtime_state(
                &pool,
                writer_id,
                draining,
                now_epoch_secs.saturating_sub(120),
                0,
            )
            .await;
        }
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        let draining_ids = draining_writer_ids(&pool).await;
        assert!(
            warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
            "warn-state map must not retain entries for writers outside draining set"
        );
    }
}
/// The computed budget must always land inside the documented clamp
/// bounds [16, 256], regardless of the host's core count.
#[test]
fn health_drain_close_budget_is_within_expected_bounds() {
    let budget = health_drain_close_budget();
    assert!((16..=256).contains(&budget));
}

View File

@ -0,0 +1,241 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::health_drain_close_budget;
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use super::me_health_monitor;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
/// Builds a `MePool` plus its RNG for this test module.
///
/// Only the three parameters under test are taken explicitly; every other
/// knob comes from `GeneralConfig::default()`. The long positional argument
/// list below mirrors the `MePool::new` constructor signature — keep the
/// order in sync when the constructor changes.
async fn make_pool(
    me_pool_drain_threshold: u64,
    me_health_interval_ms_unhealthy: u64,
    me_health_interval_ms_healthy: u64,
) -> (Arc<MePool>, Arc<SecureRandom>) {
    // Override just the fields relevant to draining/health behavior.
    let general = GeneralConfig {
        me_pool_drain_threshold,
        me_health_interval_ms_unhealthy,
        me_health_interval_ms_healthy,
        ..GeneralConfig::default()
    };
    let rng = Arc::new(SecureRandom::new());
    let pool = MePool::new(
        None,
        vec![1u8; 32],
        None,
        false,
        None,
        Vec::new(),
        1,
        None,
        12,
        1200,
        HashMap::new(),
        HashMap::new(),
        None,
        NetworkDecision::default(),
        None,
        rng.clone(),
        Arc::new(Stats::default()),
        general.me_keepalive_enabled,
        general.me_keepalive_interval_secs,
        general.me_keepalive_jitter_secs,
        general.me_keepalive_payload_random,
        general.rpc_proxy_req_every,
        general.me_warmup_stagger_enabled,
        general.me_warmup_step_delay_ms,
        general.me_warmup_step_jitter_ms,
        general.me_reconnect_max_concurrent_per_dc,
        general.me_reconnect_backoff_base_ms,
        general.me_reconnect_backoff_cap_ms,
        general.me_reconnect_fast_retry_count,
        general.me_single_endpoint_shadow_writers,
        general.me_single_endpoint_outage_mode_enabled,
        general.me_single_endpoint_outage_disable_quarantine,
        general.me_single_endpoint_outage_backoff_min_ms,
        general.me_single_endpoint_outage_backoff_max_ms,
        general.me_single_endpoint_shadow_rotate_every_secs,
        general.me_floor_mode,
        general.me_adaptive_floor_idle_secs,
        general.me_adaptive_floor_min_writers_single_endpoint,
        general.me_adaptive_floor_min_writers_multi_endpoint,
        general.me_adaptive_floor_recover_grace_secs,
        general.me_adaptive_floor_writers_per_core_total,
        general.me_adaptive_floor_cpu_cores_override,
        general.me_adaptive_floor_max_extra_writers_single_per_core,
        general.me_adaptive_floor_max_extra_writers_multi_per_core,
        general.me_adaptive_floor_max_active_writers_per_core,
        general.me_adaptive_floor_max_warm_writers_per_core,
        general.me_adaptive_floor_max_active_writers_global,
        general.me_adaptive_floor_max_warm_writers_global,
        general.hardswap,
        general.me_pool_drain_ttl_secs,
        general.me_pool_drain_threshold,
        general.effective_me_pool_force_close_secs(),
        general.me_pool_min_fresh_ratio,
        general.me_hardswap_warmup_delay_min_ms,
        general.me_hardswap_warmup_delay_max_ms,
        general.me_hardswap_warmup_extra_passes,
        general.me_hardswap_warmup_pass_backoff_base_ms,
        general.me_bind_stale_mode,
        general.me_bind_stale_ttl_secs,
        general.me_secret_atomic_snapshot,
        general.me_deterministic_writer_sort,
        MeWriterPickMode::default(),
        general.me_writer_pick_sample_size,
        MeSocksKdfPolicy::default(),
        general.me_writer_cmd_channel_capacity,
        general.me_route_channel_capacity,
        general.me_route_backpressure_base_timeout_ms,
        general.me_route_backpressure_high_timeout_ms,
        general.me_route_backpressure_high_watermark_pct,
        general.me_reader_route_data_wait_ms,
        general.me_health_interval_ms_unhealthy,
        general.me_health_interval_ms_healthy,
        general.me_warn_rate_limit_ms,
        MeRouteNoWriterMode::default(),
        general.me_route_no_writer_wait_ms,
        general.me_route_inline_recovery_attempts,
        general.me_route_inline_recovery_wait_ms,
    );
    (pool, rng)
}
/// Push a synthetic draining `MeWriter` into `pool` and bind `bound_clients`
/// registry connections to it. Ports 5500+id / 7200+idx keep addresses unique
/// across writers and clients within a test.
async fn insert_draining_writer(
    pool: &Arc<MePool>,
    writer_id: u64,
    drain_started_at_epoch_secs: u64,
    bound_clients: usize,
    drain_deadline_epoch_secs: u64,
) {
    let (cmd_tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
    let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);
    pool.writers.write().await.push(MeWriter {
        id: writer_id,
        addr: SocketAddr::new(localhost, 5500 + writer_id as u16),
        source_ip: localhost,
        writer_dc: 2,
        generation: 1,
        contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
        // Higher ids look older on the monotonic clock, giving stable ordering.
        created_at: Instant::now() - Duration::from_secs(writer_id),
        tx: cmd_tx.clone(),
        cancel: CancellationToken::new(),
        degraded: Arc::new(AtomicBool::new(false)),
        rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
        draining: Arc::new(AtomicBool::new(true)),
        draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
        drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
        allow_drain_fallback: Arc::new(AtomicBool::new(false)),
    });
    pool.registry.register_writer(writer_id, cmd_tx).await;
    pool.conn_count.fetch_add(1, Ordering::Relaxed);
    let mut idx = 0usize;
    while idx < bound_clients {
        let (conn_id, _rx) = pool.registry.register().await;
        let meta = ConnMeta {
            target_dc: 2,
            client_addr: SocketAddr::new(localhost, 7200 + idx as u16),
            our_addr: SocketAddr::new(localhost, 443),
            proto_flags: 0,
        };
        // Binding must succeed: the writer was registered just above.
        let bound = pool.registry.bind_writer(conn_id, writer_id, meta).await;
        assert!(bound);
        idx += 1;
    }
}
/// Poll `pool.writers` every 5 ms until it is empty; panic if `timeout`
/// elapses first so a stuck monitor fails the test instead of hanging it.
async fn wait_for_pool_empty(pool: &Arc<MePool>, timeout: Duration) {
    let started = Instant::now();
    while !pool.writers.read().await.is_empty() {
        assert!(
            started.elapsed() < timeout,
            "timed out waiting for pool.writers to become empty"
        );
        tokio::time::sleep(Duration::from_millis(5)).await;
    }
}
// A backlog of expired-deadline draining writers larger than twice the
// per-tick close budget must still be fully drained by the running health
// monitor within one second, i.e. across multiple monitor cycles.
#[tokio::test]
async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() {
    let (pool, rng) = make_pool(128, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let writer_total = health_drain_close_budget().saturating_mul(2).saturating_add(9);
    for writer_id in 1..=writer_total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(120),
            1,
            // Deadline already in the past: every writer is force-close eligible.
            now_epoch_secs.saturating_sub(1),
        )
        .await;
    }
    let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
    wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
    monitor.abort();
    let _ = monitor.await;
    assert!(pool.writers.read().await.is_empty());
}
// Draining writers with zero bound clients (and no deadline set) must be
// cleaned up by the health monitor via the empty path alone.
#[tokio::test]
async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() {
    let (pool, rng) = make_pool(128, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    for writer_id in 1..=24u64 {
        insert_draining_writer(&pool, writer_id, now_epoch_secs.saturating_sub(60), 0, 0).await;
    }
    let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
    wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
    monitor.abort();
    let _ = monitor.await;
    assert!(pool.writers.read().await.is_empty());
}
// With a small drain threshold and a long-draining backlog (no explicit
// deadlines), the monitor must still converge the pool all the way to empty
// rather than plateauing at the threshold.
#[tokio::test]
async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() {
    let threshold = 4u64;
    let (pool, rng) = make_pool(threshold, 1, 1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let writer_total = threshold as usize + health_drain_close_budget().saturating_add(11);
    for writer_id in 1..=writer_total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            // Staggered drain-start times ~300s ago give distinct ages.
            now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
            1,
            0,
        )
        .await;
    }
    let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
    wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
    monitor.abort();
    let _ = monitor.await;
    assert!(pool.writers.read().await.is_empty());
}

View File

@ -0,0 +1,658 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::{health_drain_close_budget, reap_draining_writers};
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
/// Build a `MePool` for reap/health tests with the given drain threshold;
/// every other knob comes from `GeneralConfig::default()`.
/// NOTE(review): the long positional argument list must stay in the exact
/// order `MePool::new` declares — do not reorder lines here.
async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
    let general = GeneralConfig {
        me_pool_drain_threshold,
        ..GeneralConfig::default()
    };
    MePool::new(
        None,
        vec![1u8; 32],
        None,
        false,
        None,
        Vec::new(),
        1,
        None,
        12,
        1200,
        HashMap::new(),
        HashMap::new(),
        None,
        NetworkDecision::default(),
        None,
        Arc::new(SecureRandom::new()),
        Arc::new(Stats::default()),
        // Remaining args mirror GeneralConfig fields one-to-one (defaults),
        // except the explicit enum defaults noted inline below.
        general.me_keepalive_enabled,
        general.me_keepalive_interval_secs,
        general.me_keepalive_jitter_secs,
        general.me_keepalive_payload_random,
        general.rpc_proxy_req_every,
        general.me_warmup_stagger_enabled,
        general.me_warmup_step_delay_ms,
        general.me_warmup_step_jitter_ms,
        general.me_reconnect_max_concurrent_per_dc,
        general.me_reconnect_backoff_base_ms,
        general.me_reconnect_backoff_cap_ms,
        general.me_reconnect_fast_retry_count,
        general.me_single_endpoint_shadow_writers,
        general.me_single_endpoint_outage_mode_enabled,
        general.me_single_endpoint_outage_disable_quarantine,
        general.me_single_endpoint_outage_backoff_min_ms,
        general.me_single_endpoint_outage_backoff_max_ms,
        general.me_single_endpoint_shadow_rotate_every_secs,
        general.me_floor_mode,
        general.me_adaptive_floor_idle_secs,
        general.me_adaptive_floor_min_writers_single_endpoint,
        general.me_adaptive_floor_min_writers_multi_endpoint,
        general.me_adaptive_floor_recover_grace_secs,
        general.me_adaptive_floor_writers_per_core_total,
        general.me_adaptive_floor_cpu_cores_override,
        general.me_adaptive_floor_max_extra_writers_single_per_core,
        general.me_adaptive_floor_max_extra_writers_multi_per_core,
        general.me_adaptive_floor_max_active_writers_per_core,
        general.me_adaptive_floor_max_warm_writers_per_core,
        general.me_adaptive_floor_max_active_writers_global,
        general.me_adaptive_floor_max_warm_writers_global,
        general.hardswap,
        general.me_pool_drain_ttl_secs,
        general.me_pool_drain_threshold,
        general.effective_me_pool_force_close_secs(),
        general.me_pool_min_fresh_ratio,
        general.me_hardswap_warmup_delay_min_ms,
        general.me_hardswap_warmup_delay_max_ms,
        general.me_hardswap_warmup_extra_passes,
        general.me_hardswap_warmup_pass_backoff_base_ms,
        general.me_bind_stale_mode,
        general.me_bind_stale_ttl_secs,
        general.me_secret_atomic_snapshot,
        general.me_deterministic_writer_sort,
        // Writer-pick and KDF policies pinned to their enum defaults here.
        MeWriterPickMode::default(),
        general.me_writer_pick_sample_size,
        MeSocksKdfPolicy::default(),
        general.me_writer_cmd_channel_capacity,
        general.me_route_channel_capacity,
        general.me_route_backpressure_base_timeout_ms,
        general.me_route_backpressure_high_timeout_ms,
        general.me_route_backpressure_high_watermark_pct,
        general.me_reader_route_data_wait_ms,
        general.me_health_interval_ms_unhealthy,
        general.me_health_interval_ms_healthy,
        general.me_warn_rate_limit_ms,
        MeRouteNoWriterMode::default(),
        general.me_route_no_writer_wait_ms,
        general.me_route_inline_recovery_attempts,
        general.me_route_inline_recovery_wait_ms,
    )
}
/// Insert a synthetic draining `MeWriter` into `pool`, bind `bound_clients`
/// registry connections to it, and return the bound conn ids so tests can
/// inspect the registry afterwards. Ports 4500+id / 6200+idx keep addresses
/// unique within a test.
async fn insert_draining_writer(
    pool: &Arc<MePool>,
    writer_id: u64,
    drain_started_at_epoch_secs: u64,
    bound_clients: usize,
    drain_deadline_epoch_secs: u64,
) -> Vec<u64> {
    let (cmd_tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
    let localhost = IpAddr::V4(Ipv4Addr::LOCALHOST);
    pool.writers.write().await.push(MeWriter {
        id: writer_id,
        addr: SocketAddr::new(localhost, 4500 + writer_id as u16),
        source_ip: localhost,
        writer_dc: 2,
        generation: 1,
        contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
        // Higher ids look older on the monotonic clock, giving stable ordering.
        created_at: Instant::now() - Duration::from_secs(writer_id),
        tx: cmd_tx.clone(),
        cancel: CancellationToken::new(),
        degraded: Arc::new(AtomicBool::new(false)),
        rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
        draining: Arc::new(AtomicBool::new(true)),
        draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
        drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
        allow_drain_fallback: Arc::new(AtomicBool::new(false)),
    });
    pool.registry.register_writer(writer_id, cmd_tx).await;
    pool.conn_count.fetch_add(1, Ordering::Relaxed);
    let mut conn_ids = Vec::with_capacity(bound_clients);
    for idx in 0..bound_clients {
        let (conn_id, _rx) = pool.registry.register().await;
        let meta = ConnMeta {
            target_dc: 2,
            client_addr: SocketAddr::new(localhost, 6200 + idx as u16),
            our_addr: SocketAddr::new(localhost, 443),
            proto_flags: 0,
        };
        // Binding must succeed: the writer was registered just above.
        assert!(pool.registry.bind_writer(conn_id, writer_id, meta).await);
        conn_ids.push(conn_id);
    }
    conn_ids
}
/// Snapshot the ids of all writers currently in the pool, sorted ascending.
async fn current_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
    let guard = pool.writers.read().await;
    let mut ids: Vec<u64> = guard.iter().map(|w| w.id).collect();
    ids.sort_unstable();
    ids
}
/// True when a writer with `writer_id` is still present in the pool.
async fn writer_exists(pool: &Arc<MePool>, writer_id: u64) -> bool {
    let guard = pool.writers.read().await;
    guard.iter().any(|w| w.id == writer_id)
}
/// Flip the `draining` flag on the writer with `writer_id`, if it exists.
/// Silently a no-op for unknown ids (matches the helper's best-effort use).
async fn set_writer_draining(pool: &Arc<MePool>, writer_id: u64, draining: bool) {
    for writer in pool.writers.read().await.iter() {
        if writer.id == writer_id {
            writer.draining.store(draining, Ordering::Relaxed);
            break;
        }
    }
}
// Warn-cooldown state keyed by writer id must be garbage-collected once the
// writer itself has been removed from the pool.
#[tokio::test]
async fn reap_draining_writers_drops_warn_state_for_removed_writer() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let conn_ids =
        insert_draining_writer(&pool, 7, now_epoch_secs.saturating_sub(180), 1, 0).await;
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(warn_next_allowed.contains_key(&7));
    let _ = pool.remove_writer_and_close_clients(7).await;
    assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(!warn_next_allowed.contains_key(&7));
}
// Draining writers with no bound clients are removed on the first reap pass;
// the one writer that still has a client (id 3) must survive.
#[tokio::test]
async fn reap_draining_writers_removes_empty_draining_writers() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(40), 0, 0).await;
    insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(30), 0, 0).await;
    insert_draining_writer(&pool, 3, now_epoch_secs.saturating_sub(20), 1, 0).await;
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert_eq!(current_writer_ids(&pool).await, vec![3]);
}
// When the draining population (4) exceeds the threshold (2), the reap must
// close the writers with the oldest drain-start times first, keeping the
// two most recent (33, 44).
#[tokio::test]
async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() {
    let pool = make_pool(2).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    insert_draining_writer(&pool, 11, now_epoch_secs.saturating_sub(40), 1, 0).await;
    insert_draining_writer(&pool, 22, now_epoch_secs.saturating_sub(30), 1, 0).await;
    insert_draining_writer(&pool, 33, now_epoch_secs.saturating_sub(20), 1, 0).await;
    insert_draining_writer(&pool, 44, now_epoch_secs.saturating_sub(10), 1, 0).await;
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert_eq!(current_writer_ids(&pool).await, vec![33, 44]);
}
// An expired per-writer drain deadline triggers force-close even when the
// draining population is far below the pool threshold.
#[tokio::test]
async fn reap_draining_writers_deadline_force_close_applies_under_threshold() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    insert_draining_writer(
        &pool,
        50,
        now_epoch_secs.saturating_sub(15),
        1,
        // Deadline one second in the past.
        now_epoch_secs.saturating_sub(1),
    )
    .await;
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(current_writer_ids(&pool).await.is_empty());
}
// A single reap pass closes at most `health_drain_close_budget()` writers;
// the surplus (19) must still be present afterwards.
#[tokio::test]
async fn reap_draining_writers_limits_closes_per_health_tick() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let close_budget = health_drain_close_budget();
    let writer_total = close_budget.saturating_add(19);
    for writer_id in 1..=writer_total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(20),
            1,
            now_epoch_secs.saturating_sub(1),
        )
        .await;
    }
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert_eq!(pool.writers.read().await.len(), writer_total - close_budget);
}
// A deadline-expired writer deferred by the per-tick budget is still live,
// so its pre-existing warn-cooldown entry must NOT be dropped yet.
#[tokio::test]
async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() {
    let pool = make_pool(0).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let close_budget = health_drain_close_budget();
    let writer_total = close_budget.saturating_add(5);
    for writer_id in 1..=writer_total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(60),
            1,
            now_epoch_secs.saturating_sub(1),
        )
        .await;
    }
    // Highest id: expected to fall past the close budget and survive the tick.
    let target_writer_id = writer_total as u64;
    let mut warn_next_allowed = HashMap::new();
    warn_next_allowed.insert(
        target_writer_id,
        Instant::now() + Duration::from_secs(300),
    );
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(writer_exists(&pool, target_writer_id).await);
    assert!(warn_next_allowed.contains_key(&target_writer_id));
}
// Same warn-state preservation as above, but for the threshold-overflow
// close path rather than the deadline path.
#[tokio::test]
async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() {
    let pool = make_pool(1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let close_budget = health_drain_close_budget();
    let writer_total = close_budget.saturating_add(6);
    for writer_id in 1..=writer_total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            // Staggered drain-start times give a deterministic age order.
            now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
            1,
            0,
        )
        .await;
    }
    // Second-newest writer: expected to be deferred past the budget.
    let target_writer_id = writer_total.saturating_sub(1) as u64;
    let mut warn_next_allowed = HashMap::new();
    warn_next_allowed.insert(
        target_writer_id,
        Instant::now() + Duration::from_secs(300),
    );
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(writer_exists(&pool, target_writer_id).await);
    assert!(warn_next_allowed.contains_key(&target_writer_id));
}
// A writer that leaves the draining state keeps its pool slot, but its
// warn-cooldown entry must be cleared by the next reap pass.
#[tokio::test]
async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await;
    let mut warn_next_allowed = HashMap::new();
    warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300));
    set_writer_draining(&pool, 71, false).await;
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(writer_exists(&pool, 71).await);
    assert!(
        !warn_next_allowed.contains_key(&71),
        "warn cooldown state must be dropped after writer leaves draining state"
    );
}
// With 2*budget+1 expired writers, the tail writer is deferred for two full
// ticks; its warn state must survive both deferrals and be cleared only on
// the third tick, when the writer is actually removed.
#[tokio::test]
async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() {
    let pool = make_pool(0).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let close_budget = health_drain_close_budget();
    let writer_total = close_budget.saturating_mul(2).saturating_add(1);
    for writer_id in 1..=writer_total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(120),
            1,
            now_epoch_secs.saturating_sub(1),
        )
        .await;
    }
    let tail_writer_id = writer_total as u64;
    let mut warn_next_allowed = HashMap::new();
    warn_next_allowed.insert(
        tail_writer_id,
        Instant::now() + Duration::from_secs(300),
    );
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(writer_exists(&pool, tail_writer_id).await);
    assert!(warn_next_allowed.contains_key(&tail_writer_id));
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(writer_exists(&pool, tail_writer_id).await);
    assert!(warn_next_allowed.contains_key(&tail_writer_id));
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(!writer_exists(&pool, tail_writer_id).await);
    assert!(
        !warn_next_allowed.contains_key(&tail_writer_id),
        "warn cooldown state must clear once writer is actually removed"
    );
}
// 2*budget+7 expired writers need three budget-limited passes; eight passes
// are ample headroom for the backlog to drain fully.
#[tokio::test]
async fn reap_draining_writers_backlog_drains_across_ticks() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let close_budget = health_drain_close_budget();
    let writer_total = close_budget.saturating_mul(2).saturating_add(7);
    for writer_id in 1..=writer_total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(20),
            1,
            now_epoch_secs.saturating_sub(1),
        )
        .await;
    }
    let mut warn_next_allowed = HashMap::new();
    for _ in 0..8 {
        if pool.writers.read().await.is_empty() {
            break;
        }
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
    }
    assert!(pool.writers.read().await.is_empty());
}
// Threshold-overflow closes (no deadlines set) must converge the draining
// population down to exactly the configured threshold, not below it.
#[tokio::test]
async fn reap_draining_writers_threshold_backlog_converges_to_threshold() {
    let threshold = 5u64;
    let pool = make_pool(threshold).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let close_budget = health_drain_close_budget();
    let writer_total = threshold as usize + close_budget.saturating_add(12);
    for writer_id in 1..=writer_total as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(200).saturating_add(writer_id),
            1,
            0,
        )
        .await;
    }
    let mut warn_next_allowed = HashMap::new();
    for _ in 0..16 {
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        if pool.writers.read().await.len() <= threshold as usize {
            break;
        }
    }
    assert_eq!(pool.writers.read().await.len(), threshold as usize);
}
// Threshold 0 acts as "overflow closes disabled": non-empty writers with no
// expired deadline must all survive the reap pass.
#[tokio::test]
async fn reap_draining_writers_threshold_zero_preserves_non_expired_non_empty_writers() {
    let pool = make_pool(0).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(40), 1, 0).await;
    insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(30), 1, 0).await;
    insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(20), 1, 0).await;
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert_eq!(current_writer_ids(&pool).await, vec![10, 20, 30]);
}
// With the whole close budget consumed by deadline force-closes, only the
// empty writer remains after one tick — i.e. force-close work is spent
// before empty-writer cleanup gets a turn.
#[tokio::test]
async fn reap_draining_writers_prioritizes_force_close_before_empty_cleanup() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let close_budget = health_drain_close_budget();
    for writer_id in 1..=close_budget as u64 {
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(20),
            1,
            now_epoch_secs.saturating_sub(1),
        )
        .await;
    }
    let empty_writer_id = close_budget as u64 + 1;
    insert_draining_writer(&pool, empty_writer_id, now_epoch_secs.saturating_sub(20), 0, 0).await;
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert_eq!(current_writer_ids(&pool).await, vec![empty_writer_id]);
}
// Removing client-less draining writers is routine cleanup, not a force
// close — the force-close counter must stay at zero.
#[tokio::test]
async fn reap_draining_writers_empty_cleanup_does_not_increment_force_close_metric() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(60), 0, 0).await;
    insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(50), 0, 0).await;
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(current_writer_ids(&pool).await.is_empty());
    assert_eq!(pool.stats.get_pool_force_close_total(), 0);
}
// Both writers qualify for force-close via the deadline path AND the
// threshold-overflow path (threshold = 1); the overlap must not prevent
// either writer from being removed.
#[tokio::test]
async fn reap_draining_writers_handles_duplicate_force_close_requests_for_same_writer() {
    let pool = make_pool(1).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    insert_draining_writer(
        &pool,
        10,
        now_epoch_secs.saturating_sub(30),
        1,
        now_epoch_secs.saturating_sub(1),
    )
    .await;
    insert_draining_writer(
        &pool,
        20,
        now_epoch_secs.saturating_sub(20),
        1,
        now_epoch_secs.saturating_sub(1),
    )
    .await;
    let mut warn_next_allowed = HashMap::new();
    reap_draining_writers(&pool, &mut warn_next_allowed).await;
    assert!(current_writer_ids(&pool).await.is_empty());
}
// Churn invariant: after any reap pass, the warn-cooldown map never holds
// more entries than there are live writers — even while writers are being
// added in waves and removed out-of-band between passes.
#[tokio::test]
async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population_under_churn() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let mut warn_next_allowed = HashMap::new();
    for wave in 0..12u64 {
        for offset in 0..9u64 {
            insert_draining_writer(
                &pool,
                wave * 100 + offset,
                now_epoch_secs.saturating_sub(120 + offset),
                1,
                0,
            )
            .await;
        }
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
        let existing_writer_ids = current_writer_ids(&pool).await;
        // Simulate out-of-band removals between health ticks.
        for writer_id in existing_writer_ids.into_iter().take(4) {
            let _ = pool.remove_writer_and_close_clients(writer_id).await;
        }
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
    }
}
// Mixed backlog (empty/non-empty, with/without expired deadlines) must
// converge to at most the threshold within a bounded number of passes,
// without the warn map outgrowing the live writer set.
#[tokio::test]
async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_state() {
    let pool = make_pool(6).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let mut warn_next_allowed = HashMap::new();
    for writer_id in 1..=18u64 {
        // Every third writer is empty; every second has an expired deadline.
        let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
        let deadline = if writer_id % 2 == 0 {
            now_epoch_secs.saturating_sub(1)
        } else {
            0
        };
        insert_draining_writer(
            &pool,
            writer_id,
            now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
            bound_clients,
            deadline,
        )
        .await;
    }
    for _ in 0..16 {
        reap_draining_writers(&pool, &mut warn_next_allowed).await;
        if pool.writers.read().await.len() <= 6 {
            break;
        }
    }
    assert!(pool.writers.read().await.len() <= 6);
    assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
}
// Guard against the default drain threshold silently changing to 0
// (which would disable overflow-based draining out of the box).
#[test]
fn general_config_default_drain_threshold_remains_enabled() {
    assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128);
}
// TOCTOU guard: a writer snapshotted as "empty" that gains a bound client
// before cleanup runs must be rejected by `remove_writer_if_empty`, keeping
// both the writer and its fresh bind intact.
#[tokio::test]
async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let empty_writer_id = 700u64;
    insert_draining_writer(
        &pool,
        empty_writer_id,
        now_epoch_secs.saturating_sub(60),
        0,
        0,
    )
    .await;
    // Simulates a stale "this writer is empty" snapshot taken before the bind.
    let stale_empty_snapshot = vec![empty_writer_id];
    let (rebound_conn_id, _rx) = pool.registry.register().await;
    assert!(
        pool.registry
            .bind_writer(
                rebound_conn_id,
                empty_writer_id,
                ConnMeta {
                    target_dc: 2,
                    client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050),
                    our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
                    proto_flags: 0,
                },
            )
            .await,
        "writer should accept a new bind after stale empty snapshot"
    );
    for writer_id in stale_empty_snapshot {
        assert!(
            !pool.remove_writer_if_empty(writer_id).await,
            "atomic empty cleanup must reject writers that gained bound clients"
        );
    }
    assert!(
        writer_exists(&pool, empty_writer_id).await,
        "empty-path cleanup must not remove a writer that gained a bound client"
    );
    assert_eq!(
        pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id),
        Some(empty_writer_id)
    );
    let _ = pool.registry.unregister(rebound_conn_id).await;
}
// Pruning a closed-but-non-empty writer must also evict its bound client
// sessions from the registry, not leave them dangling.
#[tokio::test]
async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await;
    pool.prune_closed_writers().await;
    assert!(!writer_exists(&pool, 910).await);
    assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
}

View File

@ -21,6 +21,14 @@ mod secret;
mod selftest;
mod wire;
mod pool_status;
#[cfg(test)]
mod health_regression_tests;
#[cfg(test)]
mod health_integration_tests;
#[cfg(test)]
mod health_adversarial_tests;
#[cfg(test)]
mod send_adversarial_tests;
use bytes::Bytes;

View File

@ -692,6 +692,7 @@ impl MePool {
}
}
/// Relaxed-ordering read of the `draining_active_runtime` counter.
/// NOTE(review): exact semantics of the counter (presumably cumulative
/// runtime of active draining) are defined where it is written — confirm.
#[allow(dead_code)]
pub(super) fn draining_active_runtime(&self) -> u64 {
    self.draining_active_runtime.load(Ordering::Relaxed)
}

View File

@ -42,11 +42,10 @@ impl MePool {
}
for writer_id in closed_writer_ids {
if self.registry.is_writer_empty(writer_id).await {
let _ = self.remove_writer_only(writer_id).await;
} else {
let _ = self.remove_writer_and_close_clients(writer_id).await;
if self.remove_writer_if_empty(writer_id).await {
continue;
}
let _ = self.remove_writer_and_close_clients(writer_id).await;
}
}
@ -501,6 +500,17 @@ impl MePool {
}
}
/// Remove a writer from the pool only if the registry confirms it has no
/// bound client sessions; returns `false` (and touches nothing) otherwise.
pub(crate) async fn remove_writer_if_empty(self: &Arc<Self>, writer_id: u64) -> bool {
    // The registry does the empty-check and unregister under a single write
    // lock, so a concurrent bind cannot slip in between check and removal.
    if self.registry.unregister_writer_if_empty(writer_id).await {
        // No bound sessions can remain at this point, so the list of bound
        // conns returned by remove_writer_only is safely ignorable.
        let _ = self.remove_writer_only(writer_id).await;
        true
    } else {
        false
    }
}
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
let mut removed_addr: Option<SocketAddr> = None;

View File

@ -437,6 +437,24 @@ impl ConnRegistry {
.unwrap_or(true)
}
/// Atomically remove `writer_id` from every registry map iff it has no bound
/// connections. Returns `true` when the writer is gone afterwards (an already
/// absent writer counts), `false` when live binds exist and nothing changed.
pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
    let mut inner = self.inner.write().await;
    // Copy the emptiness verdict out first so no borrow of `inner` is held
    // across the mutations below.
    let has_bound_conns = inner
        .conns_for_writer
        .get(&writer_id)
        .map(|conns| !conns.is_empty());
    match has_bound_conns {
        // Writer already absent from the registry: report it as gone.
        None => true,
        // Live binds present: leave every map untouched.
        Some(true) => false,
        Some(false) => {
            inner.writers.remove(&writer_id);
            inner.last_meta_for_writer.remove(&writer_id);
            inner.writer_idle_since_epoch_secs.remove(&writer_id);
            inner.conns_for_writer.remove(&writer_id);
            true
        }
    }
}
#[allow(dead_code)]
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
let inner = self.inner.read().await;
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());

View File

@ -372,17 +372,20 @@ impl MePool {
}
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
match w.tx.try_send(WriterCommand::Data(payload.clone())) {
Ok(()) => {
self.stats.increment_me_writer_pick_success_try_total(pick_mode);
match w.tx.clone().try_reserve_owned() {
Ok(permit) => {
if !self.registry.bind_writer(conn_id, w.id, meta).await {
debug!(
conn_id,
writer_id = w.id,
"ME writer disappeared before bind commit, retrying"
"ME writer disappeared before bind commit, pruning stale writer"
);
drop(permit);
self.remove_writer_and_close_clients(w.id).await;
continue;
}
permit.send(WriterCommand::Data(payload.clone()));
self.stats.increment_me_writer_pick_success_try_total(pick_mode);
if w.generation < self.current_generation() {
self.stats.increment_pool_stale_pick_total();
debug!(
@ -422,18 +425,21 @@ impl MePool {
self.stats.increment_me_writer_pick_blocking_fallback_total();
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
match w.tx.send(WriterCommand::Data(payload.clone())).await {
Ok(()) => {
self.stats
.increment_me_writer_pick_success_fallback_total(pick_mode);
match w.tx.clone().reserve_owned().await {
Ok(permit) => {
if !self.registry.bind_writer(conn_id, w.id, meta).await {
debug!(
conn_id,
writer_id = w.id,
"ME writer disappeared before fallback bind commit, retrying"
"ME writer disappeared before fallback bind commit, pruning stale writer"
);
drop(permit);
self.remove_writer_and_close_clients(w.id).await;
continue;
}
permit.send(WriterCommand::Data(payload.clone()));
self.stats
.increment_me_writer_pick_success_fallback_total(pick_mode);
if w.generation < self.current_generation() {
self.stats.increment_pool_stale_pick_total();
}

View File

@ -0,0 +1,263 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::pool::{MePool, MeWriter, WriterContour};
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
/// Build a `MePool` tuned for send-path tests: deterministic sorted-RR writer
/// picking, async-recovery-failfast routing, 50 ms no-writer wait. Returns
/// the pool together with the RNG handle it was constructed with.
/// NOTE(review): the long positional argument list must stay in the exact
/// order `MePool::new` declares — do not reorder lines here.
async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) {
    let general = GeneralConfig {
        me_route_no_writer_mode: MeRouteNoWriterMode::AsyncRecoveryFailfast,
        me_route_no_writer_wait_ms: 50,
        me_writer_pick_mode: MeWriterPickMode::SortedRr,
        me_deterministic_writer_sort: true,
        ..GeneralConfig::default()
    };
    let rng = Arc::new(SecureRandom::new());
    let pool = MePool::new(
        None,
        vec![1u8; 32],
        None,
        false,
        None,
        Vec::new(),
        1,
        None,
        12,
        1200,
        HashMap::new(),
        HashMap::new(),
        None,
        NetworkDecision::default(),
        None,
        rng.clone(),
        Arc::new(Stats::default()),
        // Remaining args mirror GeneralConfig fields one-to-one (defaults
        // except the overrides set above).
        general.me_keepalive_enabled,
        general.me_keepalive_interval_secs,
        general.me_keepalive_jitter_secs,
        general.me_keepalive_payload_random,
        general.rpc_proxy_req_every,
        general.me_warmup_stagger_enabled,
        general.me_warmup_step_delay_ms,
        general.me_warmup_step_jitter_ms,
        general.me_reconnect_max_concurrent_per_dc,
        general.me_reconnect_backoff_base_ms,
        general.me_reconnect_backoff_cap_ms,
        general.me_reconnect_fast_retry_count,
        general.me_single_endpoint_shadow_writers,
        general.me_single_endpoint_outage_mode_enabled,
        general.me_single_endpoint_outage_disable_quarantine,
        general.me_single_endpoint_outage_backoff_min_ms,
        general.me_single_endpoint_outage_backoff_max_ms,
        general.me_single_endpoint_shadow_rotate_every_secs,
        general.me_floor_mode,
        general.me_adaptive_floor_idle_secs,
        general.me_adaptive_floor_min_writers_single_endpoint,
        general.me_adaptive_floor_min_writers_multi_endpoint,
        general.me_adaptive_floor_recover_grace_secs,
        general.me_adaptive_floor_writers_per_core_total,
        general.me_adaptive_floor_cpu_cores_override,
        general.me_adaptive_floor_max_extra_writers_single_per_core,
        general.me_adaptive_floor_max_extra_writers_multi_per_core,
        general.me_adaptive_floor_max_active_writers_per_core,
        general.me_adaptive_floor_max_warm_writers_per_core,
        general.me_adaptive_floor_max_active_writers_global,
        general.me_adaptive_floor_max_warm_writers_global,
        general.hardswap,
        general.me_pool_drain_ttl_secs,
        general.me_pool_drain_threshold,
        general.effective_me_pool_force_close_secs(),
        general.me_pool_min_fresh_ratio,
        general.me_hardswap_warmup_delay_min_ms,
        general.me_hardswap_warmup_delay_max_ms,
        general.me_hardswap_warmup_extra_passes,
        general.me_hardswap_warmup_pass_backoff_base_ms,
        general.me_bind_stale_mode,
        general.me_bind_stale_ttl_secs,
        general.me_secret_atomic_snapshot,
        general.me_deterministic_writer_sort,
        general.me_writer_pick_mode,
        general.me_writer_pick_sample_size,
        MeSocksKdfPolicy::default(),
        general.me_writer_cmd_channel_capacity,
        general.me_route_channel_capacity,
        general.me_route_backpressure_base_timeout_ms,
        general.me_route_backpressure_high_timeout_ms,
        general.me_route_backpressure_high_watermark_pct,
        general.me_reader_route_data_wait_ms,
        general.me_health_interval_ms_unhealthy,
        general.me_health_interval_ms_healthy,
        general.me_warn_rate_limit_ms,
        general.me_route_no_writer_mode,
        general.me_route_no_writer_wait_ms,
        general.me_route_inline_recovery_attempts,
        general.me_route_inline_recovery_wait_ms,
    );
    (pool, rng)
}
/// Add an active `MeWriter` to the pool, publish its endpoint in the per-DC
/// proxy map, and return the receiving end of its command channel.
/// Passing `register_in_registry = false` simulates a pool entry the registry
/// no longer knows about (a "stale" writer).
async fn insert_writer(
    pool: &Arc<MePool>,
    writer_id: u64,
    writer_dc: i32,
    addr: SocketAddr,
    register_in_registry: bool,
) -> mpsc::Receiver<WriterCommand> {
    let (cmd_tx, cmd_rx) = mpsc::channel::<WriterCommand>(8);
    let new_writer = MeWriter {
        id: writer_id,
        addr,
        source_ip: addr.ip(),
        writer_dc,
        generation: pool.current_generation(),
        contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())),
        created_at: Instant::now(),
        tx: cmd_tx.clone(),
        cancel: CancellationToken::new(),
        degraded: Arc::new(AtomicBool::new(false)),
        rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
        draining: Arc::new(AtomicBool::new(false)),
        draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
        drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
        allow_drain_fallback: Arc::new(AtomicBool::new(false)),
    };
    pool.writers.write().await.push(new_writer);
    // Publish the endpoint so DC routing can resolve this writer's address.
    pool.proxy_map_v4
        .write()
        .await
        .entry(writer_dc)
        .or_insert_with(Vec::new)
        .push((addr.ip(), addr.port()));
    pool.rebuild_endpoint_dc_map().await;
    if register_in_registry {
        pool.registry.register_writer(writer_id, cmd_tx).await;
    }
    cmd_rx
}
/// Count Data/DataAndFlush commands received on `rx` within `budget`.
/// Receives in slices of at most 10 ms; a single quiet slice or a closed
/// channel ends the scan early (Close commands are skipped, not counted).
async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duration) -> usize {
    let started = Instant::now();
    let mut seen = 0usize;
    loop {
        let elapsed = Instant::now().duration_since(started);
        if elapsed >= budget {
            break;
        }
        let slice = (budget - elapsed).min(Duration::from_millis(10));
        match tokio::time::timeout(slice, rx.recv()).await {
            Ok(Some(WriterCommand::Data(_) | WriterCommand::DataAndFlush(_))) => seen += 1,
            Ok(Some(WriterCommand::Close)) => {}
            Ok(None) | Err(_) => break,
        }
    }
    seen
}
#[tokio::test]
async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
    // Regression test: when the first selected writer's bind commit fails
    // (the writer was never registered in the registry), the request must be
    // retried on a live writer WITHOUT the payload also reaching the stale
    // writer's channel — i.e. no data replay to the failed candidate.
    let (pool, _rng) = make_pool().await;
    // Pin the round-robin cursor so writer selection is deterministic.
    pool.rr.store(0, Ordering::Relaxed);
    let (conn_id, _rx) = pool.registry.register().await;
    // Writer 10: present in the pool and proxy map but NOT registered
    // (register_in_registry = false) — the "stale" candidate whose bind
    // commit is expected to fail.
    let mut stale_rx = insert_writer(
        &pool,
        10,
        2,
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 10)), 443),
        false,
    )
    .await;
    // Writer 11: fully registered — the request should land here.
    let mut live_rx = insert_writer(
        &pool,
        11,
        2,
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 11)), 443),
        true,
    )
    .await;
    let result = pool
        .send_proxy_req(
            conn_id,
            2, // target DC served by both writers above
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30000),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
            b"hello",
            0,
            None,
        )
        .await;
    assert!(result.is_ok());
    // The stale writer must receive no data frames; the live writer exactly one.
    assert_eq!(recv_data_count(&mut stale_rx, Duration::from_millis(50)).await, 0);
    assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
    // The connection must end up bound to the live writer (id 11).
    let bound = pool.registry.get_writer(conn_id).await;
    assert!(bound.is_some());
    assert_eq!(bound.expect("writer should be bound").writer_id, 11);
}
#[tokio::test]
async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay() {
    // Regression test: with TWO consecutive stale writers ahead of the live
    // one, send_proxy_req must iterate past each failed bind commit, send the
    // payload only to the live writer, and prune the stale writers from the
    // pool — without replaying data to any failed candidate.
    let (pool, _rng) = make_pool().await;
    // Pin the round-robin cursor so writer selection is deterministic.
    pool.rr.store(0, Ordering::Relaxed);
    let (conn_id, _rx) = pool.registry.register().await;
    // Writers 21 and 22: in the pool but NOT registered — both binds are
    // expected to fail in sequence.
    let mut stale_rx_1 = insert_writer(
        &pool,
        21,
        2,
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 21)), 443),
        false,
    )
    .await;
    let mut stale_rx_2 = insert_writer(
        &pool,
        22,
        2,
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 22)), 443),
        false,
    )
    .await;
    // Writer 23: fully registered — the request should land here.
    let mut live_rx = insert_writer(
        &pool,
        23,
        2,
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 23)), 443),
        true,
    )
    .await;
    let result = pool
        .send_proxy_req(
            conn_id,
            2, // target DC served by all three writers above
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30001),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
            b"storm",
            0,
            None,
        )
        .await;
    assert!(result.is_ok());
    // Neither stale writer may receive data; the live writer gets exactly one frame.
    assert_eq!(recv_data_count(&mut stale_rx_1, Duration::from_millis(50)).await, 0);
    assert_eq!(recv_data_count(&mut stale_rx_2, Duration::from_millis(50)).await, 0);
    assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
    // Both stale writers must have been pruned; only the live writer remains.
    let writers = pool.writers.read().await;
    let writer_ids = writers.iter().map(|w| w.id).collect::<Vec<_>>();
    drop(writers);
    assert_eq!(writer_ids, vec![23]);
}