David Osipov 2026-03-18 17:05:13 +04:00 committed by GitHub
commit c5b8ab68bc
GPG Key ID: B5690EEEBB952194
43 changed files with 13476 additions and 402 deletions

View File

@ -390,6 +390,12 @@ you MUST explain why existing invariants remain valid.
- Do not modify existing tests unless the task explicitly requires it.
- Do not weaken assertions.
- Preserve determinism in testable components.
- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new.
- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code (see the sketch after this list).
- Differential/model catches logic drift over time.
- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find subtle bugs; 10,000 concurrent calls will find them on the first run.
- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly.
- Dead parameter is a code-smell rule: a parameter whose removal changes nothing observable signals dead or untested code.
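
The invariants bullet above references a two-line property test; a minimal sketch of that shape, using `proptest`, with hypothetical `Router`, `Mode`, `set_mode`, `snapshot`, and `watch_state` names standing in for the real route_mode API:

```rust
use proptest::prelude::*;

proptest! {
    // Invariant: after any burst of transitions, the snapshot view and the
    // watch-channel view of the mode must agree. `Router` and `Mode` are
    // hypothetical stand-ins, not types from this diff.
    #[test]
    fn snapshot_matches_watch_after_any_transition_burst(
        modes in proptest::collection::vec(0u8..3, 1..256)
    ) {
        let router = Router::default();
        for m in modes {
            router.set_mode(Mode::from(m));
        }
        prop_assert_eq!(router.snapshot(), router.watch_state());
    }
}
```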
### 15. Security Constraints

Cargo.lock (generated)
View File

@ -425,6 +425,32 @@ dependencies = [
"cipher",
]
[[package]]
name = "curve25519-dalek"
version = "4.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be"
dependencies = [
"cfg-if",
"cpufeatures",
"curve25519-dalek-derive",
"fiat-crypto",
"rustc_version",
"subtle",
"zeroize",
]
[[package]]
name = "curve25519-dalek-derive"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "dashmap"
version = "5.5.3"
@ -517,6 +543,12 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fiat-crypto"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
[[package]]
name = "filetime"
version = "0.2.27"
@ -1609,7 +1641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha",
"rand_core",
"rand_core 0.9.5",
]
[[package]]
@ -1619,9 +1651,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
"rand_core 0.9.5",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
[[package]]
name = "rand_core"
version = "0.9.5"
@ -1637,7 +1675,7 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a"
dependencies = [
"rand_core",
"rand_core 0.9.5",
]
[[package]]
@ -2093,7 +2131,7 @@ dependencies = [
[[package]]
name = "telemt"
version = "3.3.19"
version = "3.3.20"
dependencies = [
"aes",
"anyhow",
@ -2145,6 +2183,7 @@ dependencies = [
"tracing-subscriber",
"url",
"webpki-roots 0.26.11",
"x25519-dalek",
"x509-parser",
"zeroize",
]
@ -3144,6 +3183,18 @@ version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
[[package]]
name = "x25519-dalek"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277"
dependencies = [
"curve25519-dalek",
"rand_core 0.6.4",
"serde",
"zeroize",
]
[[package]]
name = "x509-parser"
version = "0.15.1"

View File

@ -52,6 +52,7 @@ regex = "1.11"
crossbeam-queue = "0.3"
num-bigint = "0.4"
num-traits = "0.2"
x25519-dalek = "2"
anyhow = "1.0"
# HTTP

View File

@ -32,6 +32,7 @@ show = "*"
port = 443
# proxy_protocol = false # Enable if behind HAProxy/nginx with PROXY protocol
# metrics_port = 9090
# metrics_listen = "0.0.0.0:9090" # Listen address for metrics (overrides metrics_port)
# metrics_whitelist = ["127.0.0.1", "::1", "0.0.0.0/0"]
[server.api]

View File

@ -38,8 +38,9 @@ umweltschutz.de -> A-record 198.18.88.88
In the Telemt configuration:
```
tls_domain = umweltschutz.de
```toml
[censorship]
tls_domain = "umweltschutz.de"
```
This domain is used by the client as the SNI in the ClientHello
@ -56,8 +57,9 @@ tls_domain = umweltschutz.de
In the Telemt configuration:
```
mask_host = 127.0.0.1
```toml
[censorship]
mask_host = "127.0.0.1"
mask_port = 8443
```
@ -151,16 +153,18 @@ mask_host:mask_port
For example:
```
tls_domain = github.com
mask_host = github.com
```toml
[censorship]
tls_domain = "github.com"
mask_host = "github.com"
mask_port = 443
```
or
```
mask_host = 140.82.121.4
```toml
[censorship]
mask_host = "140.82.121.4"
```
In this case:

View File

@ -239,7 +239,7 @@ tls_full_cert_ttl_secs = 90
[access]
replay_check_len = 65536
replay_window_secs = 1800
replay_window_secs = 120
ignore_time_skew = false
[access.users]

View File

@ -73,7 +73,9 @@ pub(crate) fn default_replay_check_len() -> usize {
}
pub(crate) fn default_replay_window_secs() -> u64 {
1800
// Keep replay cache TTL tight by default to reduce replay surface.
// Deployments with higher RTT or longer reconnect jitter can override this in config.
120
}
pub(crate) fn default_handshake_timeout() -> u64 {
@ -456,11 +458,11 @@ pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 {
}
pub(crate) fn default_server_hello_delay_min_ms() -> u64 {
0
8
}
pub(crate) fn default_server_hello_delay_max_ms() -> u64 {
0
24
}
pub(crate) fn default_alpn_enforce() -> bool {

View File

@ -1163,9 +1163,17 @@ pub struct ServerConfig {
#[serde(default)]
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
/// Port for the Prometheus-compatible metrics endpoint.
/// Enables metrics when set; binds on all interfaces (dual-stack) by default.
#[serde(default)]
pub metrics_port: Option<u16>,
/// Listen address for metrics in `IP:PORT` format (e.g. `"127.0.0.1:9090"`).
/// When set, takes precedence over `metrics_port` and binds on the specified address only.
#[serde(default)]
pub metrics_listen: Option<String>,
/// CIDR whitelist for the metrics endpoint.
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpNetwork>,
@ -1194,6 +1202,7 @@ impl Default for ServerConfig {
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
proxy_protocol_trusted_cidrs: Vec::new(),
metrics_port: None,
metrics_listen: None,
metrics_whitelist: default_metrics_whitelist(),
api: ApiConfig::default(),
listeners: Vec::new(),

View File

@ -7,8 +7,9 @@ use std::net::IpAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use std::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::{Mutex as AsyncMutex, RwLock};
use crate::config::UserMaxUniqueIpsMode;
@ -21,6 +22,8 @@ pub struct UserIpTracker {
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>,
pub(crate) cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
cleanup_drain_lock: Arc<AsyncMutex<()>>,
}
impl UserIpTracker {
@ -33,6 +36,67 @@ impl UserIpTracker {
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
cleanup_queue: Arc::new(Mutex::new(Vec::new())),
cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
}
}
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
match self.cleanup_queue.lock() {
Ok(mut queue) => queue.push((user, ip)),
Err(poisoned) => {
let mut queue = poisoned.into_inner();
queue.push((user.clone(), ip));
self.cleanup_queue.clear_poison();
tracing::warn!(
"UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
user,
ip
);
}
}
}
pub(crate) async fn drain_cleanup_queue(&self) {
// Serialize queue draining and active-IP mutation so check-and-add cannot
// observe stale active entries that are already queued for removal.
let _drain_guard = self.cleanup_drain_lock.lock().await;
let to_remove = {
match self.cleanup_queue.lock() {
Ok(mut queue) => {
if queue.is_empty() {
return;
}
std::mem::take(&mut *queue)
}
Err(poisoned) => {
let mut queue = poisoned.into_inner();
if queue.is_empty() {
self.cleanup_queue.clear_poison();
return;
}
let drained = std::mem::take(&mut *queue);
self.cleanup_queue.clear_poison();
drained
}
}
};
let mut active_ips = self.active_ips.write().await;
for (user, ip) in to_remove {
if let Some(user_ips) = active_ips.get_mut(&user) {
if let Some(count) = user_ips.get_mut(&ip) {
if *count > 1 {
*count -= 1;
} else {
user_ips.remove(&ip);
}
}
if user_ips.is_empty() {
active_ips.remove(&user);
}
}
}
}
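
A plausible motivation for the synchronous `std::sync::Mutex` queue above is that disconnect cleanup can fire from a `Drop` impl, where `remove_ip(...).await` is unavailable. A minimal sketch under that assumption, with a hypothetical `ConnectionGuard` type that is not part of this diff:

```rust
use std::net::IpAddr;
use std::sync::Arc;

// Hypothetical guard: dropping it must not block or await, so it only
// enqueues the cleanup; the next async caller drains the queue before
// re-checking limits.
struct ConnectionGuard {
    tracker: Arc<UserIpTracker>,
    user: String,
    ip: IpAddr,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.tracker.enqueue_cleanup(self.user.clone(), self.ip);
    }
}
```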
@ -118,6 +182,7 @@ impl UserIpTracker {
}
pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
self.drain_cleanup_queue().await;
self.maybe_compact_empty_users().await;
let default_max_ips = *self.default_max_ips.read().await;
let limit = {
@ -194,6 +259,7 @@ impl UserIpTracker {
}
pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap<String, usize> {
self.drain_cleanup_queue().await;
let window = *self.limit_window.read().await;
let now = Instant::now();
let recent_ips = self.recent_ips.read().await;
@ -214,6 +280,7 @@ impl UserIpTracker {
}
pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
let mut out = HashMap::with_capacity(users.len());
for user in users {
@ -228,6 +295,7 @@ impl UserIpTracker {
}
pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
self.drain_cleanup_queue().await;
let window = *self.limit_window.read().await;
let now = Instant::now();
let recent_ips = self.recent_ips.read().await;
@ -250,11 +318,13 @@ impl UserIpTracker {
}
pub async fn get_active_ip_count(&self, username: &str) -> usize {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips.get(username).map(|ips| ips.len()).unwrap_or(0)
}
pub async fn get_active_ips(&self, username: &str) -> Vec<IpAddr> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips
.get(username)
@ -263,6 +333,7 @@ impl UserIpTracker {
}
pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
let max_ips = self.max_ips.read().await;
let default_max_ips = *self.default_max_ips.read().await;
@ -301,6 +372,7 @@ impl UserIpTracker {
}
pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips
.get(username)

View File

@ -0,0 +1,619 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::Duration;
use crate::config::UserMaxUniqueIpsMode;
use crate::ip_tracker::UserIpTracker;
fn ip_from_idx(idx: u32) -> IpAddr {
let a = 10u8;
let b = ((idx / 65_536) % 256) as u8;
let c = ((idx / 256) % 256) as u8;
let d = (idx % 256) as u8;
IpAddr::V4(Ipv4Addr::new(a, b, c, d))
}
#[tokio::test]
async fn active_window_enforces_large_unique_ip_burst() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("burst_user", 64).await;
tracker
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
.await;
for idx in 0..64 {
assert!(tracker.check_and_add("burst_user", ip_from_idx(idx)).await.is_ok());
}
assert!(tracker.check_and_add("burst_user", ip_from_idx(9_999)).await.is_err());
assert_eq!(tracker.get_active_ip_count("burst_user").await, 64);
}
#[tokio::test]
async fn global_limit_applies_across_many_users() {
let tracker = UserIpTracker::new();
tracker.load_limits(3, &HashMap::new()).await;
for user_idx in 0..150u32 {
let user = format!("u{}", user_idx);
assert!(tracker.check_and_add(&user, ip_from_idx(user_idx * 10)).await.is_ok());
assert!(tracker
.check_and_add(&user, ip_from_idx(user_idx * 10 + 1))
.await
.is_ok());
assert!(tracker
.check_and_add(&user, ip_from_idx(user_idx * 10 + 2))
.await
.is_ok());
assert!(tracker
.check_and_add(&user, ip_from_idx(user_idx * 10 + 3))
.await
.is_err());
}
assert_eq!(tracker.get_stats().await.len(), 150);
}
#[tokio::test]
async fn user_zero_override_falls_back_to_global_limit() {
let tracker = UserIpTracker::new();
let mut limits = HashMap::new();
limits.insert("target".to_string(), 0);
tracker.load_limits(2, &limits).await;
assert!(tracker.check_and_add("target", ip_from_idx(1)).await.is_ok());
assert!(tracker.check_and_add("target", ip_from_idx(2)).await.is_ok());
assert!(tracker.check_and_add("target", ip_from_idx(3)).await.is_err());
assert_eq!(tracker.get_user_limit("target").await, Some(2));
}
#[tokio::test]
async fn remove_ip_is_idempotent_after_counter_reaches_zero() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("u", 2).await;
let ip = ip_from_idx(42);
tracker.check_and_add("u", ip).await.unwrap();
tracker.remove_ip("u", ip).await;
tracker.remove_ip("u", ip).await;
tracker.remove_ip("u", ip).await;
assert_eq!(tracker.get_active_ip_count("u").await, 0);
assert!(!tracker.is_ip_active("u", ip).await);
}
#[tokio::test]
async fn clear_user_ips_resets_active_and_recent() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("u", 10).await;
for idx in 0..6 {
tracker.check_and_add("u", ip_from_idx(idx)).await.unwrap();
}
tracker.clear_user_ips("u").await;
assert_eq!(tracker.get_active_ip_count("u").await, 0);
let counts = tracker
.get_recent_counts_for_users(&["u".to_string()])
.await;
assert_eq!(counts.get("u").copied().unwrap_or(0), 0);
}
#[tokio::test]
async fn clear_all_resets_multi_user_state() {
let tracker = UserIpTracker::new();
for user_idx in 0..80u32 {
let user = format!("u{}", user_idx);
for ip_idx in 0..3 {
tracker
.check_and_add(&user, ip_from_idx(user_idx * 100 + ip_idx))
.await
.unwrap();
}
}
tracker.clear_all().await;
assert!(tracker.get_stats().await.is_empty());
let users = (0..80u32)
.map(|idx| format!("u{}", idx))
.collect::<Vec<_>>();
let recent = tracker.get_recent_counts_for_users(&users).await;
assert!(recent.values().all(|count| *count == 0));
}
#[tokio::test]
async fn get_active_ips_for_users_are_sorted() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("user", 10).await;
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
.await
.unwrap();
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
.await
.unwrap();
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
.await
.unwrap();
let map = tracker
.get_active_ips_for_users(&["user".to_string()])
.await;
let ips = map.get("user").cloned().unwrap_or_default();
assert_eq!(
ips,
vec![
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)),
]
);
}
#[tokio::test]
async fn get_recent_ips_for_users_are_sorted() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("user", 10).await;
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)))
.await
.unwrap();
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)))
.await
.unwrap();
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)))
.await
.unwrap();
let map = tracker
.get_recent_ips_for_users(&["user".to_string()])
.await;
let ips = map.get("user").cloned().unwrap_or_default();
assert_eq!(
ips,
vec![
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)),
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)),
]
);
}
#[tokio::test]
async fn time_window_expires_for_large_rotation() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("tw", 1).await;
tracker
.set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
.await;
tracker.check_and_add("tw", ip_from_idx(1)).await.unwrap();
tracker.remove_ip("tw", ip_from_idx(1)).await;
assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_err());
tokio::time::sleep(Duration::from_millis(1_100)).await;
assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_ok());
}
#[tokio::test]
async fn combined_mode_blocks_recent_after_disconnect() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("cmb", 1).await;
tracker
.set_limit_policy(UserMaxUniqueIpsMode::Combined, 2)
.await;
tracker.check_and_add("cmb", ip_from_idx(11)).await.unwrap();
tracker.remove_ip("cmb", ip_from_idx(11)).await;
assert!(tracker.check_and_add("cmb", ip_from_idx(12)).await.is_err());
}
#[tokio::test]
async fn load_limits_replaces_large_limit_map() {
let tracker = UserIpTracker::new();
let mut first = HashMap::new();
let mut second = HashMap::new();
for idx in 0..300usize {
first.insert(format!("u{}", idx), 2usize);
}
for idx in 150..450usize {
second.insert(format!("u{}", idx), 4usize);
}
tracker.load_limits(0, &first).await;
tracker.load_limits(0, &second).await;
assert_eq!(tracker.get_user_limit("u20").await, None);
assert_eq!(tracker.get_user_limit("u200").await, Some(4));
assert_eq!(tracker.get_user_limit("u420").await, Some(4));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_same_user_unique_ip_pressure_stays_bounded() {
let tracker = Arc::new(UserIpTracker::new());
tracker.set_user_limit("hot", 32).await;
tracker
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
.await;
let mut handles = Vec::new();
for worker in 0..16u32 {
let tracker_cloned = tracker.clone();
handles.push(tokio::spawn(async move {
let base = worker * 200;
for step in 0..200u32 {
let _ = tracker_cloned
.check_and_add("hot", ip_from_idx(base + step))
.await;
}
}));
}
for handle in handles {
handle.await.unwrap();
}
assert!(tracker.get_active_ip_count("hot").await <= 32);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_many_users_isolate_limits() {
let tracker = Arc::new(UserIpTracker::new());
tracker.load_limits(4, &HashMap::new()).await;
let mut handles = Vec::new();
for user_idx in 0..120u32 {
let tracker_cloned = tracker.clone();
handles.push(tokio::spawn(async move {
let user = format!("u{}", user_idx);
for ip_idx in 0..10u32 {
let _ = tracker_cloned
.check_and_add(&user, ip_from_idx(user_idx * 1_000 + ip_idx))
.await;
}
}));
}
for handle in handles {
handle.await.unwrap();
}
let stats = tracker.get_stats().await;
assert_eq!(stats.len(), 120);
assert!(stats.iter().all(|(_, active, limit)| *active <= 4 && *limit == 4));
}
#[tokio::test]
async fn same_ip_reconnect_high_frequency_keeps_single_unique() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("same", 2).await;
let ip = ip_from_idx(9);
for _ in 0..2_000 {
tracker.check_and_add("same", ip).await.unwrap();
}
assert_eq!(tracker.get_active_ip_count("same").await, 1);
assert!(tracker.is_ip_active("same", ip).await);
}
#[tokio::test]
async fn format_stats_contains_expected_limited_and_unlimited_markers() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("limited", 2).await;
tracker.check_and_add("limited", ip_from_idx(1)).await.unwrap();
tracker.check_and_add("open", ip_from_idx(2)).await.unwrap();
let text = tracker.format_stats().await;
assert!(text.contains("limited"));
assert!(text.contains("open"));
assert!(text.contains("unlimited"));
}
#[tokio::test]
async fn stats_report_global_default_for_users_without_override() {
let tracker = UserIpTracker::new();
tracker.load_limits(5, &HashMap::new()).await;
tracker.check_and_add("a", ip_from_idx(1)).await.unwrap();
tracker.check_and_add("b", ip_from_idx(2)).await.unwrap();
let stats = tracker.get_stats().await;
assert!(stats.iter().any(|(user, _, limit)| user == "a" && *limit == 5));
assert!(stats.iter().any(|(user, _, limit)| user == "b" && *limit == 5));
}
#[tokio::test]
async fn stress_cycle_add_remove_clear_preserves_empty_end_state() {
let tracker = UserIpTracker::new();
for cycle in 0..50u32 {
let user = format!("cycle{}", cycle);
tracker.set_user_limit(&user, 128).await;
for ip_idx in 0..128u32 {
tracker
.check_and_add(&user, ip_from_idx(cycle * 10_000 + ip_idx))
.await
.unwrap();
}
for ip_idx in 0..128u32 {
tracker
.remove_ip(&user, ip_from_idx(cycle * 10_000 + ip_idx))
.await;
}
tracker.clear_user_ips(&user).await;
}
assert!(tracker.get_stats().await.is_empty());
}
#[tokio::test]
async fn remove_unknown_user_or_ip_does_not_corrupt_state() {
let tracker = UserIpTracker::new();
tracker.remove_ip("no_user", ip_from_idx(1)).await;
tracker.check_and_add("x", ip_from_idx(2)).await.unwrap();
tracker.remove_ip("x", ip_from_idx(3)).await;
assert_eq!(tracker.get_active_ip_count("x").await, 1);
assert!(tracker.is_ip_active("x", ip_from_idx(2)).await);
}
#[tokio::test]
async fn active_and_recent_views_match_after_mixed_workload() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("mix", 16).await;
for ip_idx in 0..12u32 {
tracker.check_and_add("mix", ip_from_idx(ip_idx)).await.unwrap();
}
for ip_idx in 0..6u32 {
tracker.remove_ip("mix", ip_from_idx(ip_idx)).await;
}
let active = tracker
.get_active_ips_for_users(&["mix".to_string()])
.await
.get("mix")
.cloned()
.unwrap_or_default();
let recent_count = tracker
.get_recent_counts_for_users(&["mix".to_string()])
.await
.get("mix")
.copied()
.unwrap_or(0);
assert_eq!(active.len(), 6);
assert!(recent_count >= active.len());
assert!(recent_count <= 12);
}
#[tokio::test]
async fn global_limit_switch_updates_enforcement_immediately() {
let tracker = UserIpTracker::new();
tracker.load_limits(2, &HashMap::new()).await;
assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_err());
tracker.clear_user_ips("u").await;
tracker.load_limits(4, &HashMap::new()).await;
assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(4)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(5)).await.is_err());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() {
let tracker = Arc::new(UserIpTracker::new());
tracker.set_user_limit("cc", 8).await;
let mut handles = Vec::new();
for worker in 0..8u32 {
let tracker_cloned = tracker.clone();
handles.push(tokio::spawn(async move {
let ip = ip_from_idx(50 + worker);
for _ in 0..500u32 {
let _ = tracker_cloned.check_and_add("cc", ip).await;
tracker_cloned.remove_ip("cc", ip).await;
}
}));
}
for handle in handles {
handle.await.unwrap();
}
assert!(tracker.get_active_ip_count("cc").await <= 8);
}
#[tokio::test]
async fn enqueue_cleanup_recovers_from_poisoned_mutex() {
let tracker = UserIpTracker::new();
let ip = ip_from_idx(99);
// Poison the lock by panicking while holding it
let result = std::panic::catch_unwind(|| {
let _guard = tracker.cleanup_queue.lock().unwrap();
panic!("Intentional poison panic");
});
assert!(result.is_err(), "Expected panic to poison mutex");
// Attempt to enqueue anyway; should hit the poison catch arm and still insert
tracker.enqueue_cleanup("poison-user".to_string(), ip);
tracker.drain_cleanup_queue().await;
assert_eq!(tracker.get_active_ip_count("poison-user").await, 0);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn mass_reconnect_sync_cleanup_prevents_temporary_reservation_bloat() {
// Tests that the synchronous M-01 drop mechanism protects against starvation
let tracker = Arc::new(UserIpTracker::new());
tracker.set_user_limit("mass", 5).await;
let ip = ip_from_idx(42);
let mut join_handles = Vec::new();
// 10,000 rapid concurrent requests hitting the same IP limit
for _ in 0..10_000 {
let tracker_clone = tracker.clone();
join_handles.push(tokio::spawn(async move {
if tracker_clone.check_and_add("mass", ip).await.is_ok() {
// Instantly enqueue cleanup, simulating synchronous reservation drop
tracker_clone.enqueue_cleanup("mass".to_string(), ip);
// The next caller will drain it before acquiring again
}
}));
}
for handle in join_handles {
let _ = handle.await;
}
// Force flush
tracker.drain_cleanup_queue().await;
assert_eq!(tracker.get_active_ip_count("mass").await, 0, "No leaked footprints");
}
#[tokio::test]
async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() {
// Regression guard: concurrent cleanup draining must not produce false
// limit denials for a new IP when the previous IP is already queued.
let tracker = Arc::new(UserIpTracker::new());
tracker.set_user_limit("racer", 1).await;
let ip1 = ip_from_idx(1);
let ip2 = ip_from_idx(2);
// Initial state: add ip1
tracker.check_and_add("racer", ip1).await.unwrap();
// User disconnects from ip1, queuing it
tracker.enqueue_cleanup("racer".to_string(), ip1);
let mut saw_false_rejection = false;
for _ in 0..100 {
// Queue cleanup then race explicit drain and check-and-add on the alternative IP.
tracker.enqueue_cleanup("racer".to_string(), ip1);
let tracker_a = tracker.clone();
let tracker_b = tracker.clone();
let drain_handle = tokio::spawn(async move {
tracker_a.drain_cleanup_queue().await;
});
let handle = tokio::spawn(async move {
tracker_b.check_and_add("racer", ip2).await
});
drain_handle.await.unwrap();
let res = handle.await.unwrap();
if res.is_err() {
saw_false_rejection = true;
break;
}
// Restore baseline for next iteration.
tracker.remove_ip("racer", ip2).await;
tracker.check_and_add("racer", ip1).await.unwrap();
}
assert!(
!saw_false_rejection,
"Concurrent cleanup draining must not cause false-positive IP denials"
);
}
#[tokio::test]
async fn poisoned_cleanup_queue_still_releases_slot_for_next_ip() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("poison-slot", 1).await;
let ip1 = ip_from_idx(7001);
let ip2 = ip_from_idx(7002);
tracker.check_and_add("poison-slot", ip1).await.unwrap();
// Poison the queue lock as an adversarial condition.
let _ = std::panic::catch_unwind(|| {
let _guard = tracker.cleanup_queue.lock().unwrap();
panic!("intentional queue poison");
});
// Disconnect path must still queue cleanup so the next IP can be admitted.
tracker.enqueue_cleanup("poison-slot".to_string(), ip1);
let admitted = tracker.check_and_add("poison-slot", ip2).await;
assert!(
admitted.is_ok(),
"cleanup queue poison must not permanently block slot release for the next IP"
);
}
#[tokio::test]
async fn duplicate_cleanup_entries_do_not_break_future_admission() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("dup-cleanup", 1).await;
let ip1 = ip_from_idx(7101);
let ip2 = ip_from_idx(7102);
tracker.check_and_add("dup-cleanup", ip1).await.unwrap();
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1);
tracker.drain_cleanup_queue().await;
assert_eq!(tracker.get_active_ip_count("dup-cleanup").await, 0);
assert!(
tracker.check_and_add("dup-cleanup", ip2).await.is_ok(),
"extra queued cleanup entries must not leave user stuck in denied state"
);
}
#[tokio::test]
async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("poison-stress", 1).await;
let ip_primary = ip_from_idx(7201);
let ip_alt = ip_from_idx(7202);
tracker.check_and_add("poison-stress", ip_primary).await.unwrap();
for _ in 0..64 {
let _ = std::panic::catch_unwind(|| {
let _guard = tracker.cleanup_queue.lock().unwrap();
panic!("intentional queue poison in stress loop");
});
tracker.enqueue_cleanup("poison-stress".to_string(), ip_primary);
assert!(
tracker.check_and_add("poison-stress", ip_alt).await.is_ok(),
"poison recovery must preserve admission progress under repeated queue poisoning"
);
tracker.remove_ip("poison-stress", ip_alt).await;
tracker.check_and_add("poison-stress", ip_primary).await.unwrap();
}
}

View File

@ -279,11 +279,32 @@ pub(crate) async fn spawn_metrics_if_configured(
ip_tracker: Arc<UserIpTracker>,
config_rx: watch::Receiver<Arc<ProxyConfig>>,
) {
if let Some(port) = config.server.metrics_port {
// metrics_listen takes precedence; fall back to metrics_port for backward compat.
let metrics_target: Option<(u16, Option<String>)> =
if let Some(ref listen) = config.server.metrics_listen {
match listen.parse::<std::net::SocketAddr>() {
Ok(addr) => Some((addr.port(), Some(listen.clone()))),
Err(e) => {
startup_tracker
.skip_component(
COMPONENT_METRICS_START,
Some(format!("invalid metrics_listen \"{}\": {}", listen, e)),
)
.await;
None
}
}
} else {
config.server.metrics_port.map(|p| (p, None))
};
if let Some((port, listen)) = metrics_target {
let fallback_label = format!("port {}", port);
let label = listen.as_deref().unwrap_or(&fallback_label);
startup_tracker
.start_component(
COMPONENT_METRICS_START,
Some(format!("spawn metrics endpoint on {}", port)),
Some(format!("spawn metrics endpoint on {}", label)),
)
.await;
let stats = stats.clone();
@ -294,6 +315,7 @@ pub(crate) async fn spawn_metrics_if_configured(
tokio::spawn(async move {
metrics::serve(
port,
listen,
stats,
beobachten,
ip_tracker_metrics,
@ -308,7 +330,7 @@ pub(crate) async fn spawn_metrics_if_configured(
Some("metrics task spawned".to_string()),
)
.await;
} else {
} else if config.server.metrics_listen.is_none() {
startup_tracker
.skip_component(
COMPONENT_METRICS_START,

View File

@ -6,6 +6,8 @@ mod config;
mod crypto;
mod error;
mod ip_tracker;
#[cfg(test)]
mod ip_tracker_regression_tests;
mod maestro;
mod metrics;
mod network;

View File

@ -21,6 +21,7 @@ use crate::transport::{ListenOptions, create_listener};
pub async fn serve(
port: u16,
listen: Option<String>,
stats: Arc<Stats>,
beobachten: Arc<BeobachtenStore>,
ip_tracker: Arc<UserIpTracker>,
@ -28,6 +29,33 @@ pub async fn serve(
whitelist: Vec<IpNetwork>,
) {
let whitelist = Arc::new(whitelist);
// If `metrics_listen` is set, bind on that single address only.
if let Some(ref listen_addr) = listen {
let addr: SocketAddr = match listen_addr.parse() {
Ok(a) => a,
Err(e) => {
warn!(error = %e, "Invalid metrics_listen address: {}", listen_addr);
return;
}
};
let is_ipv6 = addr.is_ipv6();
match bind_metrics_listener(addr, is_ipv6) {
Ok(listener) => {
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr);
serve_listener(
listener, stats, beobachten, ip_tracker, config_rx, whitelist,
)
.await;
}
Err(e) => {
warn!(error = %e, "Failed to bind metrics on {}", addr);
}
}
return;
}
// Fallback: bind on 0.0.0.0 and [::] using metrics_port.
let mut listener_v4 = None;
let mut listener_v6 = None;

View File

@ -11,9 +11,8 @@ use crate::crypto::{sha256_hmac, SecureRandom};
use crate::error::ProxyError;
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH};
use num_bigint::BigUint;
use num_traits::One;
use subtle::ConstantTimeEq;
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
// ============= Public Constants =============
@ -27,10 +26,17 @@ pub const TLS_DIGEST_POS: usize = 11;
pub const TLS_DIGEST_HALF_LEN: usize = 16;
/// Time skew limits for anti-replay (in seconds)
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
///
/// The default window is intentionally narrow to reduce replay acceptance.
/// Operators with known clock-drifted clients should tune deployment config
/// (for example replay-window policy) to match their environment.
pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance
/// windows when replay TTL is configured very large.
pub const BOOT_TIME_COMPAT_MAX_SECS: u32 = 2 * 60;
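// Worked example (illustrative): with TIME_SKEW_MIN = -120 and
// TIME_SKEW_MAX = 120, a client whose clock runs 3 minutes ahead sends
// timestamp = now + 180, so time_diff = now - timestamp = -180, which is
// outside [-120, 120]; the handshake is rejected unless the timestamp
// falls under the boot-time compatibility cap.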
// ============= Private Constants =============
@ -63,6 +69,7 @@ pub struct TlsValidation {
/// Client digest for response generation
pub digest: [u8; TLS_DIGEST_LEN],
/// Timestamp extracted from digest
pub timestamp: u32,
}
@ -117,28 +124,8 @@ impl TlsExtensionBuilder {
self
}
/// Add ALPN extension with a single selected protocol.
fn add_alpn(&mut self, proto: &[u8]) -> &mut Self {
// Extension type: ALPN (0x0010)
self.extensions.extend_from_slice(&extension_type::ALPN.to_be_bytes());
// ALPN extension format:
// extension_data length (2 bytes)
// protocols length (2 bytes)
// protocol name length (1 byte)
// protocol name bytes
let proto_len = proto.len() as u8;
let list_len: u16 = 1 + u16::from(proto_len);
let ext_len: u16 = 2 + list_len;
self.extensions.extend_from_slice(&ext_len.to_be_bytes());
self.extensions.extend_from_slice(&list_len.to_be_bytes());
self.extensions.push(proto_len);
self.extensions.extend_from_slice(proto);
self
}
/// Build final extensions with length prefix
fn build(self) -> Vec<u8> {
let mut result = Vec::with_capacity(2 + self.extensions.len());
@ -153,7 +140,7 @@ impl TlsExtensionBuilder {
}
/// Get current extensions without length prefix (for calculation)
#[allow(dead_code)]
fn as_bytes(&self) -> &[u8] {
&self.extensions
}
@ -173,8 +160,6 @@ struct ServerHelloBuilder {
compression: u8,
/// Extensions
extensions: TlsExtensionBuilder,
/// Selected ALPN protocol (if any)
alpn: Option<Vec<u8>>,
}
impl ServerHelloBuilder {
@ -185,7 +170,6 @@ impl ServerHelloBuilder {
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
compression: 0x00,
extensions: TlsExtensionBuilder::new(),
alpn: None,
}
}
@ -200,18 +184,9 @@ impl ServerHelloBuilder {
self
}
fn with_alpn(mut self, proto: Option<Vec<u8>>) -> Self {
self.alpn = proto;
self
}
/// Build ServerHello message (without record header)
fn build_message(&self) -> Vec<u8> {
let mut ext_builder = self.extensions.clone();
if let Some(ref alpn) = self.alpn {
ext_builder.add_alpn(alpn);
}
let extensions = ext_builder.extensions.clone();
let extensions = self.extensions.extensions.clone();
let extensions_len = extensions.len() as u16;
// Calculate total length
@ -281,6 +256,7 @@ impl ServerHelloBuilder {
/// Returns validation result if a matching user is found.
/// The result **must** be used — ignoring it silently bypasses authentication.
#[must_use]
pub fn validate_tls_handshake(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
@ -296,9 +272,9 @@ pub fn validate_tls_handshake(
/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL.
///
/// A boot-time timestamp is only accepted when it falls below both
/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp
/// reuse outside replay cache coverage.
/// A boot-time timestamp is only accepted when it falls below all three
/// bounds: `BOOT_TIME_MAX_SECS`, configured replay window, and
/// `BOOT_TIME_COMPAT_MAX_SECS`, preventing oversized compatibility windows.
#[must_use]
pub fn validate_tls_handshake_with_replay_window(
handshake: &[u8],
@ -316,7 +292,16 @@ pub fn validate_tls_handshake_with_replay_window(
};
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32);
// Boot-time bypass and ignore_time_skew serve different compatibility paths.
// When skew checks are disabled, force boot-time cap to zero to prevent
// accidental future coupling of boot-time logic into the ignore-skew path.
let boot_time_cap_secs = if ignore_time_skew {
0
} else {
BOOT_TIME_MAX_SECS
.min(replay_window_u32)
.min(BOOT_TIME_COMPAT_MAX_SECS)
};
validate_tls_handshake_at_time_with_boot_cap(
handshake,
@ -335,6 +320,7 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
i64::try_from(d.as_secs()).ok()
}
fn validate_tls_handshake_at_time(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
@ -369,6 +355,9 @@ fn validate_tls_handshake_at_time_with_boot_cap(
// Extract session ID
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
if session_id_len > 32 {
return None;
}
let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len {
@ -411,7 +400,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
if !ignore_time_skew {
// Allow very small timestamps (boot time instead of unix time)
// This is a quirk in some clients that use uptime instead of real time
let is_boot_time = timestamp < boot_time_cap_secs;
let is_boot_time = boot_time_cap_secs > 0 && timestamp < boot_time_cap_secs;
if !is_boot_time {
let time_diff = now - i64::from(timestamp);
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
@ -433,27 +422,14 @@ fn validate_tls_handshake_at_time_with_boot_cap(
})
}
fn curve25519_prime() -> BigUint {
(BigUint::one() << 255) - BigUint::from(19u32)
}
/// Generate a fake X25519 public key for TLS
///
/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p,
/// which matches Python/C behavior and avoids DPI fingerprinting.
/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint,
/// yielding distribution-consistent public keys for anti-fingerprinting.
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
let mut n_bytes = [0u8; 32];
n_bytes.copy_from_slice(&rng.bytes(32));
let n = BigUint::from_bytes_le(&n_bytes);
let p = curve25519_prime();
let pk = (&n * &n) % &p;
let mut out = pk.to_bytes_le();
out.resize(32, 0);
let mut result = [0u8; 32];
result.copy_from_slice(&out[..32]);
result
let mut scalar = [0u8; 32];
scalar.copy_from_slice(&rng.bytes(32));
x25519(scalar, X25519_BASEPOINT_BYTES)
}
/// Build TLS ServerHello response
@ -482,7 +458,6 @@ pub fn build_server_hello(
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
.with_x25519_key(&x25519_key)
.with_tls13_version()
.with_alpn(alpn)
.build_record();
// Build Change Cipher Spec record
@ -493,8 +468,27 @@ pub fn build_server_hello(
0x01, // CCS byte
];
// Build fake certificate (Application Data record)
let fake_cert = rng.bytes(fake_cert_len);
// Build first encrypted flight mimic as opaque ApplicationData bytes.
// Embed a compact EncryptedExtensions-like ALPN block when selected.
let mut fake_cert = Vec::with_capacity(fake_cert_len);
if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) {
let proto_list_len = 1usize + proto.len();
let ext_data_len = 2usize + proto_list_len;
let marker_len = 4usize + ext_data_len;
if marker_len <= fake_cert_len {
fake_cert.extend_from_slice(&0x0010u16.to_be_bytes());
fake_cert.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
fake_cert.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
fake_cert.push(proto.len() as u8);
fake_cert.extend_from_slice(proto);
}
}
if fake_cert.len() < fake_cert_len {
fake_cert.extend_from_slice(&rng.bytes(fake_cert_len - fake_cert.len()));
} else if fake_cert.len() > fake_cert_len {
fake_cert.truncate(fake_cert_len);
}
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
app_data_record.push(TLS_RECORD_APPLICATION);
app_data_record.extend_from_slice(&TLS_VERSION);
@ -506,8 +500,9 @@ pub fn build_server_hello(
// Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted;
// here we mimic with opaque ApplicationData records of plausible size).
let mut tickets = Vec::new();
if new_session_tickets > 0 {
for _ in 0..new_session_tickets {
let ticket_count = new_session_tickets.min(4);
if ticket_count > 0 {
for _ in 0..ticket_count {
let ticket_len: usize = rng.range(48) + 48; // 48-95 bytes
let mut record = Vec::with_capacity(5 + ticket_len);
record.push(TLS_RECORD_APPLICATION);
@ -705,13 +700,14 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
return false;
}
// TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0)
// TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303.
first_bytes[0] == TLS_RECORD_HANDSHAKE
&& first_bytes[1] == 0x03
&& first_bytes[2] == 0x01
&& (first_bytes[2] == 0x01 || first_bytes[2] == 0x03)
}
/// Parse TLS record header, returns (record_type, length)
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
let record_type = header[0];
let version = [header[1], header[2]];

View File

@ -1,5 +1,8 @@
use super::*;
use crate::crypto::sha256_hmac;
use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource};
use std::time::SystemTime;
/// Build a TLS-handshake-like buffer that contains a valid HMAC digest
/// for the given `secret` and `timestamp`.
@ -297,8 +300,8 @@ fn boot_time_timestamp_accepted_without_ignore_flag() {
// Timestamps below the boot-time threshold are treated as client uptime,
// not real wall-clock time. The proxy allows them regardless of skew.
let secret = b"boot_time_test";
// Keep this safely below BOOT_TIME_MAX_SECS to assert bypass behavior.
let boot_ts: u32 = BOOT_TIME_MAX_SECS / 2;
// Keep this safely below compatibility cap to assert bypass behavior.
let boot_ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1);
let handshake = make_valid_tls_handshake(secret, boot_ts);
let secrets = vec![("u".to_string(), secret.to_vec())];
assert!(
@ -369,16 +372,16 @@ fn one_byte_session_id_validates_and_is_preserved() {
}
#[test]
fn max_session_id_len_255_with_valid_digest_is_accepted() {
fn max_session_id_len_255_with_valid_digest_is_rejected_by_rfc_cap() {
let secret = b"sid_len_255_test";
let session_id = vec![0xCCu8; 255];
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id);
let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true)
.expect("session_id_len=255 with valid digest must validate");
assert_eq!(result.session_id.len(), 255);
assert_eq!(result.session_id, session_id);
assert!(
validate_tls_handshake(&handshake, &secrets, true).is_none(),
"legacy_session_id length > 32 must be rejected even with valid digest"
);
}
// ------------------------------------------------------------------
@ -660,13 +663,14 @@ fn zero_length_session_id_accepted() {
// Boot-time threshold — exact boundary precision
// ------------------------------------------------------------------
/// timestamp = BOOT_TIME_MAX_SECS - 1 is the last value inside the boot-time window.
/// timestamp = BOOT_TIME_COMPAT_MAX_SECS - 1 is the last value inside
/// the runtime boot-time compatibility window.
/// is_boot_time = true → skew check is skipped entirely → accepted even
/// when `now` is far from the timestamp.
#[test]
fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
let secret = b"boot_last_value_test";
let ts: u32 = BOOT_TIME_MAX_SECS - 1;
let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS - 1;
let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())];
@ -674,32 +678,48 @@ fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
// Boot-time bypass must prevent the skew check from running.
assert!(
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(),
"ts=BOOT_TIME_MAX_SECS-1 must bypass skew check regardless of now"
"ts=BOOT_TIME_COMPAT_MAX_SECS-1 must bypass skew check regardless of now"
);
}
/// timestamp = BOOT_TIME_MAX_SECS is the first value outside the boot-time window.
/// timestamp = BOOT_TIME_COMPAT_MAX_SECS is the first value outside the
/// runtime boot-time compatibility window.
/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this:
/// once with now chosen so the skew passes (accepted) and once where it fails.
#[test]
fn timestamp_at_boot_threshold_triggers_skew_check() {
let secret = b"boot_exact_value_test";
let ts: u32 = BOOT_TIME_MAX_SECS;
let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS;
let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())];
// now = ts + 50 → time_diff = 50, within [-1200, 600] → accepted.
let now_valid: i64 = ts as i64 + 50;
assert!(
validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(),
"ts=BOOT_TIME_MAX_SECS within skew window must be accepted via skew check"
validate_tls_handshake_at_time_with_boot_cap(
&h,
&secrets,
false,
now_valid,
BOOT_TIME_COMPAT_MAX_SECS,
)
.is_some(),
"ts=BOOT_TIME_COMPAT_MAX_SECS within skew window must be accepted via skew check"
);
// now = 0 → time_diff = -86_400_000, outside window → rejected.
// If the boot-time bypass were wrongly applied here this would pass.
// now = -1 → time_diff = -121 at the 120-second threshold, outside window
// for TIME_SKEW_MIN=-120. If boot-time bypass were wrongly applied this
// would pass.
assert!(
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(),
"ts=BOOT_TIME_MAX_SECS far from now must be rejected — no boot-time bypass"
validate_tls_handshake_at_time_with_boot_cap(
&h,
&secrets,
false,
-1,
BOOT_TIME_COMPAT_MAX_SECS,
)
.is_none(),
"ts=BOOT_TIME_COMPAT_MAX_SECS far from now must be rejected — no boot-time bypass"
);
}
@ -720,7 +740,7 @@ fn replay_window_cap_disables_boot_bypass_for_old_timestamps() {
#[test]
fn replay_window_cap_still_allows_small_boot_timestamp() {
let secret = b"boot_cap_enabled_test";
let ts: u32 = 120;
let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1);
let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())];
@ -731,6 +751,260 @@ fn replay_window_cap_still_allows_small_boot_timestamp() {
);
}
#[test]
fn large_replay_window_is_hard_capped_for_boot_compatibility() {
let secret = b"boot_cap_hard_limit_test";
let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS + 1;
let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, u64::MAX);
assert!(
result.is_none(),
"very large replay window must not expand boot-time bypass beyond hard compatibility cap"
);
}
#[test]
fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() {
let secret = b"ignore_skew_boot_cap_decouple_test";
let ts: u32 = 1;
let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())];
let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0);
let cap_nonzero =
validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_COMPAT_MAX_SECS);
assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC");
assert!(
cap_nonzero.is_some(),
"ignore_time_skew path must not depend on boot-time cap"
);
let a = cap_zero.unwrap();
let b = cap_nonzero.unwrap();
assert_eq!(a.user, b.user);
assert_eq!(a.timestamp, b.timestamp);
}
#[test]
fn adversarial_small_boot_timestamp_matrix_rejected_when_boot_cap_forced_zero() {
let secret = b"boot_cap_zero_matrix_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_700_000_000;
for ts in 0u32..1024u32 {
let h = make_valid_tls_handshake(secret, ts);
let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0);
assert!(
result.is_none(),
"boot cap=0 must reject timestamp {ts} when skew checks are active"
);
}
}
#[test]
fn light_fuzz_boot_cap_zero_rejects_small_timestamp_space() {
let secret = b"boot_cap_zero_fuzz_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_700_000_000;
let mut s: u64 = 0x9E37_79B9_7F4A_7C15;
for _ in 0..4096 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let ts = (s as u32) % 2048;
let h = make_valid_tls_handshake(secret, ts);
let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0);
assert!(
result.is_none(),
"fuzzed boot-range timestamp {ts} must be rejected when cap=0"
);
}
}
#[test]
fn stress_boot_cap_zero_rejection_is_deterministic_under_high_iteration_count() {
let secret = b"boot_cap_zero_stress_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_700_000_000;
for i in 0u32..20_000u32 {
let ts = i % 4096;
let h = make_valid_tls_handshake(secret, ts);
let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0);
assert!(
result.is_none(),
"iteration {i}: timestamp {ts} must be rejected with cap=0"
);
}
}
#[test]
fn replay_window_one_allows_only_zero_timestamp_boot_bypass() {
let secret = b"replay_window_one_boot_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let ts0 = make_valid_tls_handshake(secret, 0);
let ts1 = make_valid_tls_handshake(secret, 1);
assert!(
validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 1).is_some(),
"replay_window=1 must allow timestamp 0 via boot-time compatibility"
);
assert!(
validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 1).is_none(),
"replay_window=1 must reject timestamp 1 on normal wall-clock systems"
);
}
#[test]
fn replay_window_two_allows_ts0_ts1_but_rejects_ts2() {
let secret = b"replay_window_two_boot_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let ts0 = make_valid_tls_handshake(secret, 0);
let ts1 = make_valid_tls_handshake(secret, 1);
let ts2 = make_valid_tls_handshake(secret, 2);
assert!(validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 2).is_some());
assert!(validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 2).is_some());
assert!(
validate_tls_handshake_with_replay_window(&ts2, &secrets, false, 2).is_none(),
"timestamp equal to replay-window cap must not use boot-time bypass"
);
}
#[test]
fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disabled() {
let secret = b"skew_boundary_matrix_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_700_000_000;
for offset in -1500i64..=1500i64 {
let ts_i64 = now - offset;
let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix");
let h = make_valid_tls_handshake(secret, ts);
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
.is_some();
let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset);
assert_eq!(
accepted, expected,
"offset {offset} must match inclusive skew window when boot bypass is disabled"
);
}
}
#[test]
fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() {
let secret = b"skew_outside_fuzz_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_700_000_000;
let mut s: u64 = 0x0123_4567_89AB_CDEF;
for _ in 0..4096 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let magnitude = 1300i64 + ((s % 2000u64) as i64);
let sign = if (s & 1) == 0 { 1i64 } else { -1i64 };
let offset = sign * magnitude;
let ts_i64 = now - offset;
let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test");
let h = make_valid_tls_handshake(secret, ts);
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
.is_some();
assert!(
!accepted,
"offset {offset} must be rejected outside strict skew window"
);
}
}
#[test]
fn stress_boot_disabled_validation_matches_time_diff_oracle() {
let secret = b"boot_disabled_oracle_stress_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_700_000_000;
let mut s: u64 = 0xBADC_0FFE_EE11_2233;
for _ in 0..25_000 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let ts = s as u32;
let h = make_valid_tls_handshake(secret, ts);
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
.is_some();
let time_diff = now - i64::from(ts);
let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff);
assert_eq!(
accepted, expected,
"boot-disabled validation must match pure time-diff oracle"
);
}
}
#[test]
fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() {
let now: i64 = 1_700_000_000;
let target_secret = b"target_user_secret";
let target_ts = (now - 1) as u32;
let handshake = make_valid_tls_handshake(target_secret, target_ts);
let mut secrets = Vec::new();
for i in 0..512u32 {
secrets.push((format!("noise-{i}"), format!("noise-secret-{i}").into_bytes()));
}
secrets.push(("target-user".to_string(), target_secret.to_vec()));
let result = validate_tls_handshake_at_time_with_boot_cap(&handshake, &secrets, false, now, 0)
.expect("matching user should validate within strict skew window");
assert_eq!(result.user, "target-user");
}
#[test]
fn light_fuzz_ignore_time_skew_accepts_wide_timestamp_range_with_valid_hmac() {
let secret = b"ignore_skew_fuzz_accept_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let mut s: u64 = 0xC0FF_EE11_2233_4455;
for _ in 0..2048 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let ts = s as u32;
let h = make_valid_tls_handshake(secret, ts);
let result = validate_tls_handshake_with_replay_window(&h, &secrets, true, 60);
assert!(
result.is_some(),
"ignore_time_skew=true must accept valid HMAC for arbitrary timestamp"
);
}
}
#[test]
fn light_fuzz_small_replay_window_rejects_far_timestamps_when_skew_enabled() {
let secret = b"replay_window_reject_fuzz_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
for ts in 300u32..=1323u32 {
let h = make_valid_tls_handshake(secret, ts);
let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, 0, 300);
assert!(
result.is_none(),
"with skew checks enabled and boot cap=300, timestamp >=300 at now=0 must be rejected"
);
}
}
// ------------------------------------------------------------------
// Extreme timestamp values
// ------------------------------------------------------------------
@ -897,7 +1171,9 @@ fn first_matching_user_wins_over_later_duplicate_secret() {
#[test]
fn test_is_tls_handshake() {
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
assert!(is_tls_handshake(&[0x16, 0x03, 0x03]));
assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00]));
assert!(is_tls_handshake(&[0x16, 0x03, 0x03, 0x02, 0x00]));
assert!(!is_tls_handshake(&[0x17, 0x03, 0x01]));
assert!(!is_tls_handshake(&[0x16, 0x03, 0x02]));
assert!(!is_tls_handshake(&[0x16, 0x03]));
@ -945,17 +1221,158 @@ fn test_gen_fake_x25519_key() {
}
#[test]
fn test_fake_x25519_key_is_quadratic_residue() {
use num_bigint::BigUint;
use num_traits::One;
fn test_fake_x25519_key_is_nonzero_and_varies() {
let rng = crate::crypto::SecureRandom::new();
let key = gen_fake_x25519_key(&rng);
let p = curve25519_prime();
let k_num = BigUint::from_bytes_le(&key);
let exponent = (&p - BigUint::one()) >> 1;
let legendre = k_num.modpow(&exponent, &p);
assert_eq!(legendre, BigUint::one());
let mut unique = std::collections::HashSet::new();
let mut saw_non_zero = false;
for _ in 0..64 {
let key = gen_fake_x25519_key(&rng);
if key != [0u8; 32] {
saw_non_zero = true;
}
unique.insert(key);
}
assert!(
saw_non_zero,
"generated X25519 public keys must not collapse to all-zero output"
);
assert!(
unique.len() > 1,
"generated X25519 public keys must vary across invocations"
);
}
#[test]
fn validate_tls_handshake_rejects_session_id_longer_than_rfc_cap() {
let secret = b"session_id_cap_secret";
let oversized_sid = vec![0x42u8; 33];
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &oversized_sid);
let secrets = vec![("u".to_string(), secret.to_vec())];
assert!(
validate_tls_handshake(&handshake, &secrets, true).is_none(),
"legacy_session_id length > 32 must be rejected"
);
}
fn server_hello_extension_types(record: &[u8]) -> Vec<u16> {
if record.len() < 9 || record[0] != TLS_RECORD_HANDSHAKE || record[5] != 0x02 {
return Vec::new();
}
let record_len = u16::from_be_bytes([record[3], record[4]]) as usize;
if record.len() < 5 + record_len {
return Vec::new();
}
let hs_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
let hs_start = 5;
let hs_end = hs_start + 4 + hs_len;
if hs_end > record.len() {
return Vec::new();
}
let mut pos = hs_start + 4 + 2 + 32;
if pos >= hs_end {
return Vec::new();
}
let sid_len = record[pos] as usize;
pos += 1 + sid_len;
if pos + 2 + 1 + 2 > hs_end {
return Vec::new();
}
pos += 2 + 1;
let ext_len = u16::from_be_bytes([record[pos], record[pos + 1]]) as usize;
pos += 2;
let ext_end = pos + ext_len;
if ext_end > hs_end {
return Vec::new();
}
let mut out = Vec::new();
while pos + 4 <= ext_end {
let etype = u16::from_be_bytes([record[pos], record[pos + 1]]);
let elen = u16::from_be_bytes([record[pos + 2], record[pos + 3]]) as usize;
pos += 4;
if pos + elen > ext_end {
break;
}
out.push(etype);
pos += elen;
}
out
}
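// Hand-rolled usage sketch for the helper above (not part of the diff): build a
// minimal ServerHello carrying a single supported_versions (0x002B) extension
// and confirm the walker reports exactly that type. All offsets follow the
// layout the parser assumes: 5-byte record header, 4-byte handshake header,
// version + random + session id + cipher + compression, then the extension block.
#[test]
fn server_hello_extension_types_walks_minimal_record() {
let mut hs_body = Vec::new();
hs_body.extend_from_slice(&[0x03, 0x03]); // legacy_version
hs_body.extend_from_slice(&[0u8; 32]); // random
hs_body.push(0); // legacy_session_id length
hs_body.extend_from_slice(&[0x13, 0x01]); // cipher_suite
hs_body.push(0); // compression
hs_body.extend_from_slice(&6u16.to_be_bytes()); // extensions length
hs_body.extend_from_slice(&0x002Bu16.to_be_bytes()); // supported_versions
hs_body.extend_from_slice(&2u16.to_be_bytes());
hs_body.extend_from_slice(&[0x03, 0x04]);
let mut record = vec![TLS_RECORD_HANDSHAKE, 0x03, 0x03];
record.extend_from_slice(&((hs_body.len() + 4) as u16).to_be_bytes());
record.push(0x02); // handshake type: ServerHello
record.extend_from_slice(&(hs_body.len() as u32).to_be_bytes()[1..]);
record.extend_from_slice(&hs_body);
assert_eq!(server_hello_extension_types(&record), vec![0x002B]);
}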
#[test]
fn build_server_hello_never_places_alpn_in_server_hello_extensions() {
let secret = b"alpn_sh_forbidden";
let client_digest = [0x11u8; 32];
let session_id = vec![0xAA; 32];
let rng = crate::crypto::SecureRandom::new();
let response = build_server_hello(
secret,
&client_digest,
&session_id,
1024,
&rng,
Some(b"h2".to_vec()),
0,
);
let exts = server_hello_extension_types(&response);
assert!(
!exts.contains(&0x0010),
"ALPN extension must not appear in ServerHello"
);
}
#[test]
fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() {
let secret = b"alpn_emulated_forbidden";
let client_digest = [0x22u8; 32];
let session_id = vec![0xAB; 32];
let rng = crate::crypto::SecureRandom::new();
let cached = CachedTlsData {
server_hello_template: ParsedServerHello {
version: TLS_VERSION,
random: [0u8; 32],
session_id: Vec::new(),
cipher_suite: [0x13, 0x01],
compression: 0,
extensions: Vec::new(),
},
cert_info: None,
cert_payload: None,
app_data_records_sizes: vec![1024],
total_app_data_len: 1024,
behavior_profile: TlsBehaviorProfile {
change_cipher_spec_count: 1,
app_data_record_sizes: vec![1024],
ticket_record_sizes: Vec::new(),
source: TlsProfileSource::Default,
},
fetched_at: SystemTime::now(),
domain: "example.com".to_string(),
};
let response = build_emulated_server_hello(
secret,
&client_digest,
&session_id,
&cached,
false,
&rng,
Some(b"h2".to_vec()),
0,
);
let exts = server_hello_extension_types(&response);
assert!(
!exts.contains(&0x0010),
"ALPN extension must not appear in emulated ServerHello"
);
}
#[test]
@ -1394,3 +1811,413 @@ fn server_hello_application_data_payload_varies_across_runs() {
"ApplicationData payload should vary across runs to reduce fingerprintability"
);
}
#[test]
fn replay_window_zero_disables_boot_bypass_for_any_nonzero_timestamp() {
let secret = b"window_zero_boot_bypass_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let ts1 = make_valid_tls_handshake(secret, 1);
assert!(
validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 0).is_none(),
"replay_window_secs=0 must reject nonzero timestamps even in boot-time range"
);
let ts0 = make_valid_tls_handshake(secret, 0);
assert!(
validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 0).is_none(),
"replay_window_secs=0 enforces strict skew check and rejects timestamp=0 on normal wall-clock systems"
);
}
#[test]
fn large_replay_window_does_not_expand_time_skew_acceptance() {
let secret = b"large_replay_window_skew_bound_test";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_700_000_000;
let ts_far_past = (now - 600) as u32;
let valid = make_valid_tls_handshake(secret, ts_far_past);
assert!(
validate_tls_handshake_with_replay_window(&valid, &secrets, false, 86_400).is_none(),
"large replay window must not relax strict skew check once boot-time bypass is not in play"
);
}
#[test]
fn parse_tls_record_header_accepts_tls_version_constant() {
let header = [TLS_RECORD_HANDSHAKE, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x2A];
let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted");
assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE);
assert_eq!(parsed.1, 42);
}
#[test]
fn server_hello_clamps_fake_cert_len_lower_bound() {
let secret = b"fake_cert_lower_bound_test";
let client_digest = [0x11u8; TLS_DIGEST_LEN];
let session_id = vec![0x77; 32];
let rng = crate::crypto::SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 1, &rng, None, 0);
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_pos = 5 + sh_len;
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
let app_pos = ccs_pos + 5 + ccs_len;
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
assert_eq!(app_len, 64, "fake cert payload must be clamped to minimum 64 bytes");
}
#[test]
fn server_hello_clamps_fake_cert_len_upper_bound() {
let secret = b"fake_cert_upper_bound_test";
let client_digest = [0x22u8; TLS_DIGEST_LEN];
let session_id = vec![0x66; 32];
let rng = crate::crypto::SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 65_535, &rng, None, 0);
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_pos = 5 + sh_len;
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
let app_pos = ccs_pos + 5 + ccs_len;
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
assert_eq!(app_len, 16_640, "fake cert payload must be clamped to TLS record max bound");
}
#[test]
fn server_hello_new_session_ticket_count_matches_configuration() {
let secret = b"ticket_count_surface_test";
let client_digest = [0x33u8; TLS_DIGEST_LEN];
let session_id = vec![0x55; 32];
let rng = crate::crypto::SecureRandom::new();
let tickets: u8 = 3;
let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, tickets);
let mut pos = 0usize;
let mut app_records = 0usize;
while pos + 5 <= response.len() {
let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen;
assert!(next <= response.len(), "TLS record must stay inside response bounds");
if rtype == TLS_RECORD_APPLICATION {
app_records += 1;
}
pos = next;
}
assert_eq!(
app_records,
1 + tickets as usize,
"response must contain one main application record plus configured ticket-like tail records"
);
}
#[test]
fn server_hello_new_session_ticket_count_is_safely_capped() {
let secret = b"ticket_count_cap_test";
let client_digest = [0x44u8; TLS_DIGEST_LEN];
let session_id = vec![0x54; 32];
let rng = crate::crypto::SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, u8::MAX);
let mut pos = 0usize;
let mut app_records = 0usize;
while pos + 5 <= response.len() {
let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen;
assert!(next <= response.len(), "TLS record must stay inside response bounds");
if rtype == TLS_RECORD_APPLICATION {
app_records += 1;
}
pos = next;
}
assert_eq!(
app_records,
5,
"response must cap ticket-like tail records to four plus one main application record"
);
}
#[test]
fn server_hello_application_data_contains_alpn_marker_when_selected() {
let secret = b"alpn_marker_test";
let client_digest = [0x55u8; TLS_DIGEST_LEN];
let session_id = vec![0xAB; 32];
let rng = crate::crypto::SecureRandom::new();
let response = build_server_hello(
secret,
&client_digest,
&session_id,
512,
&rng,
Some(b"h2".to_vec()),
0,
);
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_pos = 5 + sh_len;
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
let app_pos = ccs_pos + 5 + ccs_len;
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];
let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
assert!(
app_payload.windows(expected.len()).any(|window| window == expected),
"first application payload must carry ALPN marker for selected protocol"
);
}
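// Marker layout note (not in the diff): the nine expected bytes decode as ALPN
// extension type 0x0010, extension length 0x0005, protocol-list length 0x0003,
// a one-byte protocol length 0x02, and the protocol bytes "h2"; that is,
// marker_len = 4 + (2 + (1 + proto_len)), the same arithmetic the exact-fit
// tests below rely on.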
#[test]
fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() {
let secret = b"alpn_oversize_ignore_test";
let client_digest = [0x56u8; TLS_DIGEST_LEN];
let session_id = vec![0xCD; 32];
let rng = crate::crypto::SecureRandom::new();
let oversized_alpn = vec![b'x'; u8::MAX as usize + 1];
let response = build_server_hello(
secret,
&client_digest,
&session_id,
512,
&rng,
Some(oversized_alpn),
u8::MAX,
);
let mut pos = 0usize;
let mut app_records = 0usize;
let mut first_app_payload: Option<&[u8]> = None;
while pos + 5 <= response.len() {
let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen;
assert!(next <= response.len(), "TLS record must stay inside response bounds");
if rtype == TLS_RECORD_APPLICATION {
app_records += 1;
if first_app_payload.is_none() {
first_app_payload = Some(&response[pos + 5..next]);
}
}
pos = next;
}
let marker = [0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x'];
assert_eq!(
app_records, 5,
"oversized ALPN must not change the four-ticket cap on tail records"
);
assert!(
!first_app_payload
.expect("response must contain an application record")
.windows(marker.len())
.any(|window| window == marker),
"oversized ALPN must be ignored rather than embedded into the first application payload"
);
}
#[test]
fn server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
let secret = b"alpn_too_large_to_fit_test";
let client_digest = [0x57u8; TLS_DIGEST_LEN];
let session_id = vec![0xEF; 32];
let rng = crate::crypto::SecureRandom::new();
let oversized_alpn = vec![0xAB; u8::MAX as usize];
let response = build_server_hello(
secret,
&client_digest,
&session_id,
64,
&rng,
Some(oversized_alpn),
0,
);
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_pos = 5 + sh_len;
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
let app_pos = ccs_pos + 5 + ccs_len;
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];
let mut marker_prefix = Vec::new();
marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
marker_prefix.push(0xff);
marker_prefix.extend_from_slice(&[0xab; 8]);
assert!(
!app_payload.starts_with(&marker_prefix),
"oversized ALPN must not be partially embedded into the ServerHello application record"
);
}
#[test]
fn server_hello_embeds_full_alpn_marker_when_it_exactly_fits_fake_cert_len() {
let secret = b"alpn_exact_fit_test";
let client_digest = [0x58u8; TLS_DIGEST_LEN];
let session_id = vec![0xA5; 32];
let rng = crate::crypto::SecureRandom::new();
let proto = vec![b'z'; 57];
// marker_len = 4 + (2 + (1 + proto_len)) = 7 + proto_len = 64
let response = build_server_hello(
secret,
&client_digest,
&session_id,
64,
&rng,
Some(proto.clone()),
0,
);
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_pos = 5 + sh_len;
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
let app_pos = ccs_pos + 5 + ccs_len;
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];
let mut expected_marker = Vec::new();
expected_marker.extend_from_slice(&0x0010u16.to_be_bytes());
expected_marker.extend_from_slice(&0x003Cu16.to_be_bytes());
expected_marker.extend_from_slice(&0x003Au16.to_be_bytes());
expected_marker.push(57u8);
expected_marker.extend_from_slice(&proto);
assert_eq!(app_payload.len(), expected_marker.len());
assert_eq!(app_payload, expected_marker.as_slice());
}
#[test]
fn server_hello_does_not_embed_partial_alpn_marker_when_one_byte_short() {
let secret = b"alpn_one_byte_short_test";
let client_digest = [0x59u8; TLS_DIGEST_LEN];
let session_id = vec![0xA6; 32];
let rng = crate::crypto::SecureRandom::new();
let proto = vec![0xAB; 58];
// marker_len = 65, fake_cert_len = 64 => marker must be fully skipped.
let response = build_server_hello(
secret,
&client_digest,
&session_id,
64,
&rng,
Some(proto),
0,
);
let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_pos = 5 + sh_len;
let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
let app_pos = ccs_pos + 5 + ccs_len;
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];
let mut marker_prefix = Vec::new();
marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
marker_prefix.extend_from_slice(&0x003Du16.to_be_bytes());
marker_prefix.extend_from_slice(&0x003Bu16.to_be_bytes());
marker_prefix.push(58u8);
marker_prefix.extend_from_slice(&[0xAB; 8]);
assert!(
!app_payload.starts_with(&marker_prefix),
"one-byte-short ALPN marker must be skipped entirely, not partially embedded"
);
}
#[test]
fn exhaustive_tls_minor_version_classification_matches_policy() {
for minor in 0u8..=u8::MAX {
let first = [TLS_RECORD_HANDSHAKE, 0x03, minor];
let expected = minor == 0x01 || minor == 0x03;
assert_eq!(
is_tls_handshake(&first),
expected,
"minor version {minor:#04x} classification mismatch"
);
}
}
#[test]
fn light_fuzz_tls_header_classifier_and_parser_policy_consistency() {
// Deterministic xorshift state keeps this fuzz test reproducible.
let mut s: u64 = 0x9E37_79B9_AA95_5A5D;
for _ in 0..10_000 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let header = [
(s & 0xff) as u8,
((s >> 8) & 0xff) as u8,
((s >> 16) & 0xff) as u8,
((s >> 24) & 0xff) as u8,
((s >> 32) & 0xff) as u8,
];
let classified = is_tls_handshake(&header[..3]);
let expected_classified = header[0] == TLS_RECORD_HANDSHAKE
&& header[1] == 0x03
&& (header[2] == 0x01 || header[2] == 0x03);
assert_eq!(
classified,
expected_classified,
"classifier policy mismatch for header {header:02x?}"
);
let parsed = parse_tls_record_header(&header);
let expected_parsed = header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]);
assert_eq!(
parsed.is_some(),
expected_parsed,
"parser policy mismatch for header {header:02x?}"
);
}
}
#[test]
fn stress_random_noise_handshakes_never_authenticate() {
let secret = b"stress_noise_secret";
let secrets = vec![("noise-user".to_string(), secret.to_vec())];
// Deterministic xorshift state keeps this stress test reproducible.
let mut s: u64 = 0xD1B5_4A32_9C6E_77F1;
for _ in 0..5_000 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let len = 1 + ((s as usize) % 196);
let mut buf = vec![0u8; len];
for b in &mut buf {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
*b = (s & 0xff) as u8;
}
assert!(
validate_tls_handshake(&buf, &secrets, true).is_none(),
"random noise must never authenticate"
);
}
}

View File

@ -24,6 +24,47 @@ enum HandshakeOutcome {
Handled,
}
#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"]
struct UserConnectionReservation {
stats: Arc<Stats>,
ip_tracker: Arc<UserIpTracker>,
user: String,
ip: IpAddr,
active: bool,
}
impl UserConnectionReservation {
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
Self {
stats,
ip_tracker,
user,
ip,
active: true,
}
}
async fn release(mut self) {
if !self.active {
return;
}
self.ip_tracker.remove_ip(&self.user, self.ip).await;
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
}
}
impl Drop for UserConnectionReservation {
fn drop(&mut self) {
if !self.active {
return;
}
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
}
}
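// Illustrative lifecycle sketch (not part of the diff; assumes only the types
// above). The orderly path calls release() so the async IP-map removal can be
// awaited; any early return or panic after acquisition falls back to Drop,
// which cannot await and therefore defers IP cleanup via enqueue_cleanup while
// still decrementing the connection counter synchronously.
async fn _reservation_lifecycle_sketch(
stats: Arc<Stats>,
ip_tracker: Arc<UserIpTracker>,
user: String,
ip: IpAddr,
) -> Result<()> {
let reservation = UserConnectionReservation::new(stats, ip_tracker, user, ip);
// ... relay traffic; an early `return Err(..)` here would rely on Drop ...
reservation.release().await; // explicit async release on the happy path
Ok(())
}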
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
@ -45,7 +86,19 @@ use crate::proxy::middle_relay::handle_via_middle_proxy;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
fn beobachten_ttl(config: &ProxyConfig) -> Duration {
Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60))
let minutes = config.general.beobachten_minutes;
if minutes == 0 {
static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock<AtomicBool> = OnceLock::new();
let warned = BEOBACHTEN_ZERO_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false));
if !warned.swap(true, Ordering::Relaxed) {
warn!(
"general.beobachten_minutes=0 is insecure because entries expire immediately; forcing minimum TTL to 1 minute"
);
}
return Duration::from_secs(60);
}
Duration::from_secs(minutes.saturating_mul(60))
}
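// Hedged sketch (not part of the diff): the zero-minutes clamp above in one
// assertion. ProxyConfig::default() is assumed, as in the existing tests.
#[test]
fn beobachten_ttl_clamps_zero_minutes_to_one_minute() {
let mut config = ProxyConfig::default();
config.general.beobachten_minutes = 0;
assert_eq!(beobachten_ttl(&config), Duration::from_secs(60));
config.general.beobachten_minutes = 5;
assert_eq!(beobachten_ttl(&config), Duration::from_secs(300));
}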
fn record_beobachten_class(
@ -90,6 +143,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
trusted.iter().any(|cidr| cidr.contains(peer_ip))
}
fn synthetic_local_addr(port: u16) -> SocketAddr {
SocketAddr::from(([0, 0, 0, 0], port))
}
pub async fn handle_client_stream<S>(
mut stream: S,
peer: SocketAddr,
@ -113,9 +170,7 @@ where
let mut real_peer = normalize_ip(peer);
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
let mut local_addr = synthetic_local_addr(config.server.port);
if proxy_protocol_enabled {
let proxy_header_timeout = Duration::from_millis(
@ -426,7 +481,6 @@ impl RunningClientHandler {
pub async fn run(self) -> Result<()> {
self.stats.increment_connects_all();
let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone();
debug!(peer = %peer, "New connection");
if let Err(e) = configure_client_socket(
@ -557,7 +611,6 @@ impl RunningClientHandler {
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone();
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
@ -570,7 +623,6 @@ impl RunningClientHandler {
async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone();
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
@ -694,7 +746,6 @@ impl RunningClientHandler {
async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone();
if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
@ -798,10 +849,22 @@ impl RunningClientHandler {
{
let user = success.user.clone();
if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await {
warn!(user = %user, error = %e, "User limit exceeded");
return Err(e);
}
let user_limit_reservation =
match Self::acquire_user_connection_reservation_static(
&user,
&config,
stats.clone(),
peer_addr,
ip_tracker,
)
.await
{
Ok(reservation) => reservation,
Err(e) => {
warn!(user = %user, error = %e, "User admission check failed");
return Err(e);
}
};
let route_snapshot = route_runtime.snapshot();
let session_id = rng.u64();
@ -858,15 +921,68 @@ impl RunningClientHandler {
)
.await
};
stats.decrement_user_curr_connects(&user);
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
user_limit_reservation.release().await;
relay_result
}
async fn acquire_user_connection_reservation_static(
user: &str,
config: &ProxyConfig,
stats: Arc<Stats>,
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
) -> Result<UserConnectionReservation> {
if let Some(expiration) = config.access.user_expirations.get(user)
&& chrono::Utc::now() > *expiration
{
return Err(ProxyError::UserExpired {
user: user.to_string(),
});
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
if !stats.try_acquire_user_curr_connects(user, limit) {
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {}
Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!(
user = %user,
ip = %peer_addr.ip(),
reason = %reason,
"IP limit exceeded"
);
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
}
Ok(UserConnectionReservation::new(
stats,
ip_tracker,
user.to_string(),
peer_addr.ip(),
))
}
#[cfg(test)]
async fn check_user_limits_static(
user: &str,
config: &ProxyConfig,
stats: &Stats,
peer_addr: SocketAddr,
ip_tracker: &UserIpTracker,
@ -899,7 +1015,10 @@ impl RunningClientHandler {
}
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {}
Ok(()) => {
ip_tracker.remove_ip(user, peer_addr.ip()).await;
stats.decrement_user_curr_connects(user);
}
Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!(

File diff suppressed because it is too large

View File

@ -1,6 +1,8 @@
use std::ffi::OsString;
use std::fs::OpenOptions;
use std::io::Write;
use std::net::SocketAddr;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use std::collections::HashSet;
use std::sync::{Mutex, OnceLock};
@ -24,14 +26,28 @@ use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
#[derive(Clone)]
struct SanitizedUnknownDcLogPath {
resolved_path: PathBuf,
allowed_parent: PathBuf,
file_name: OsString,
}
// In tests, this function shares global mutable state. Callers that also use
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
// deterministic under parallel execution.
fn should_log_unknown_dc(dc_idx: i16) -> bool {
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
should_log_unknown_dc_with_set(set, dc_idx)
}
fn should_log_unknown_dc_with_set(set: &Mutex<HashSet<i16>>, dc_idx: i16) -> bool {
match set.lock() {
Ok(mut guard) => {
if guard.contains(&dc_idx) {
@ -42,9 +58,85 @@ fn should_log_unknown_dc(dc_idx: i16) -> bool {
}
guard.insert(dc_idx)
}
// If the lock is poisoned, keep logging rather than silently dropping
// operator-visible diagnostics.
Err(_) => true,
// Fail closed on poisoned state to avoid unbounded blocking log writes.
Err(_) => false,
}
}
fn sanitize_unknown_dc_log_path(path: &str) -> Option<SanitizedUnknownDcLogPath> {
let candidate = Path::new(path);
if candidate.as_os_str().is_empty() {
return None;
}
if candidate
.components()
.any(|component| matches!(component, Component::ParentDir))
{
return None;
}
let cwd = std::env::current_dir().ok()?;
let file_name = candidate.file_name()?;
let parent = candidate.parent().unwrap_or_else(|| Path::new("."));
let parent_path = if parent.is_absolute() {
parent.to_path_buf()
} else {
cwd.join(parent)
};
let canonical_parent = parent_path.canonicalize().ok()?;
if !canonical_parent.is_dir() {
return None;
}
Some(SanitizedUnknownDcLogPath {
resolved_path: canonical_parent.join(file_name),
allowed_parent: canonical_parent,
file_name: file_name.to_os_string(),
})
}
fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
let Some(parent) = path.resolved_path.parent() else {
return false;
};
let Ok(current_parent) = parent.canonicalize() else {
return false;
};
if current_parent != path.allowed_parent {
return false;
}
if let Ok(canonical_target) = path.resolved_path.canonicalize() {
let Some(target_parent) = canonical_target.parent() else {
return false;
};
let Some(target_name) = canonical_target.file_name() else {
return false;
};
if target_parent != path.allowed_parent || target_name != path.file_name {
return false;
}
}
true
}
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
#[cfg(unix)]
{
OpenOptions::new()
.create(true)
.append(true)
.custom_flags(libc::O_NOFOLLOW)
.open(path)
}
#[cfg(not(unix))]
{
let _ = path;
Err(std::io::Error::new(
std::io::ErrorKind::PermissionDenied,
"unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
))
}
}
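// Hedged sketch (not in the diff): the sanitizer must refuse parent-dir
// traversal and empty paths outright, and a bare file name should resolve
// under the current working directory.
#[test]
fn sanitize_unknown_dc_log_path_rejects_traversal_and_empty() {
assert!(sanitize_unknown_dc_log_path("../escape.log").is_none());
assert!(sanitize_unknown_dc_log_path("").is_none());
if let Some(sanitized) = sanitize_unknown_dc_log_path("unknown_dc.log") {
assert_eq!(sanitized.file_name, OsString::from("unknown_dc.log"));
}
}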
@ -105,7 +197,7 @@ where
debug!(peer = %success.peer, "TG handshake complete, starting relay");
stats.increment_user_connects(user);
stats.increment_current_connections_direct();
let _direct_connection_lease = stats.acquire_direct_connection_lease();
let relay_result = relay_bidirectional(
client_reader,
@ -116,6 +208,7 @@ where
config.general.direct_relay_copy_buf_s2c_bytes,
user,
Arc::clone(&stats),
config.access.user_data_quota.get(user).copied(),
buffer_pool,
);
tokio::pin!(relay_result);
@ -148,8 +241,6 @@ where
}
};
stats.decrement_current_connections_direct();
match &relay_result {
Ok(()) => debug!(user = %user, "Direct relay completed"),
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
@ -202,12 +293,17 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
&& should_log_unknown_dc(dc_idx)
&& let Ok(handle) = tokio::runtime::Handle::try_current()
{
let path = path.clone();
handle.spawn_blocking(move || {
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
let _ = writeln!(file, "dc_idx={dc_idx}");
}
});
if let Some(path) = sanitize_unknown_dc_log_path(path) {
handle.spawn_blocking(move || {
if unknown_dc_log_path_is_still_safe(&path)
&& let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path)
{
let _ = writeln!(file, "dc_idx={dc_idx}");
}
});
} else {
warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path");
}
}
}

File diff suppressed because it is too large

View File

@ -4,11 +4,11 @@
use std::net::SocketAddr;
use std::collections::HashSet;
use std::collections::hash_map::RandomState;
use std::net::{IpAddr, Ipv6Addr};
use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::hash::{BuildHasher, Hash, Hasher};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
@ -36,6 +36,7 @@ const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256;
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024;
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2;
#[cfg(test)]
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
@ -54,12 +55,25 @@ struct AuthProbeState {
last_seen: Instant,
}
#[derive(Clone, Copy)]
struct AuthProbeSaturationState {
fail_streak: u32,
blocked_until: Instant,
last_seen: Instant,
}
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> = OnceLock::new();
static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new();
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
AUTH_PROBE_STATE.get_or_init(DashMap::new)
}
fn auth_probe_saturation_state() -> &'static Mutex<Option<AuthProbeSaturationState>> {
AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None))
}
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
match peer_ip {
IpAddr::V4(ip) => IpAddr::V4(ip),
@ -88,7 +102,8 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
}
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
let mut hasher = DefaultHasher::new();
let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new);
let mut hasher = hasher_state.build_hasher();
peer_ip.hash(&mut hasher);
now.hash(&mut hasher);
hasher.finish() as usize
@ -108,6 +123,83 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
now < entry.blocked_until
}
fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
let Some(entry) = state.get(&peer_ip) else {
return false;
};
if auth_probe_state_expired(&entry, now) {
drop(entry);
state.remove(&peer_ip);
return false;
}
entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
}
fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool {
if !auth_probe_is_throttled(peer_ip, now) {
return false;
}
if !auth_probe_saturation_is_throttled(now) {
return true;
}
auth_probe_saturation_grace_exhausted(peer_ip, now)
}
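// Decision table for the gate above (explanatory note, not in the diff):
//
//   ip_throttled  saturation_active  grace_exhausted  => pre-auth throttle?
//   false         any                any                 no
//   true          false              any                 yes
//   true          true               false               no
//   true          true               true                yes
//
// Under global saturation, per-IP throttling only re-engages once a peer's
// fail streak passes AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS.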
fn auth_probe_saturation_is_throttled(now: Instant) -> bool {
let saturation = auth_probe_saturation_state();
let mut guard = match saturation.lock() {
Ok(guard) => guard,
Err(_) => return false,
};
let Some(state) = guard.as_mut() else {
return false;
};
if now.duration_since(state.last_seen) > Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) {
*guard = None;
return false;
}
if now < state.blocked_until {
return true;
}
false
}
fn auth_probe_note_saturation(now: Instant) {
let saturation = auth_probe_saturation_state();
let mut guard = match saturation.lock() {
Ok(guard) => guard,
Err(_) => return,
};
match guard.as_mut() {
Some(state)
if now.duration_since(state.last_seen)
<= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) =>
{
state.fail_streak = state.fail_streak.saturating_add(1);
state.last_seen = now;
state.blocked_until = now + auth_probe_backoff(state.fail_streak);
}
_ => {
let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS;
*guard = Some(AuthProbeSaturationState {
fail_streak,
blocked_until: now + auth_probe_backoff(fail_streak),
last_seen: now,
});
}
}
}
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
@ -144,24 +236,98 @@ fn auth_probe_record_failure_with_state(
}
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
let mut stale_keys = Vec::new();
let mut eviction_candidates = Vec::new();
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
eviction_candidates.push(*entry.key());
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(*entry.key());
let mut rounds = 0usize;
while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
rounds += 1;
if rounds > 8 {
auth_probe_note_saturation(now);
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
}
let Some((evict_key, _, _)) = eviction_candidate else {
return;
};
state.remove(&evict_key);
break;
}
}
for stale_key in stale_keys {
state.remove(&stale_key);
}
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
if eviction_candidates.is_empty() {
let mut stale_keys = Vec::new();
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = if state_len == 0 {
0
} else {
auth_probe_eviction_offset(peer_ip, now) % state_len
};
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
scanned += 1;
if scanned >= scan_limit {
break;
}
}
if scanned < scan_limit {
for entry in state.iter().take(scan_limit - scanned) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
}
}
for stale_key in stale_keys {
state.remove(&stale_key);
}
if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES {
break;
}
let Some((evict_key, _, _)) = eviction_candidate else {
auth_probe_note_saturation(now);
return;
}
let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len();
let evict_key = eviction_candidates[idx];
};
state.remove(&evict_key);
auth_probe_note_saturation(now);
}
}
@ -186,6 +352,11 @@ fn clear_auth_probe_state_for_testing() {
if let Some(state) = AUTH_PROBE_STATE.get() {
state.clear();
}
if let Some(saturation) = AUTH_PROBE_SATURATION_STATE.get()
&& let Ok(mut guard) = saturation.lock()
{
*guard = None;
}
}
#[cfg(test)]
@ -200,6 +371,16 @@ fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
auth_probe_is_throttled(peer_ip, Instant::now())
}
#[cfg(test)]
fn auth_probe_saturation_is_throttled_for_testing() -> bool {
auth_probe_saturation_is_throttled(Instant::now())
}
#[cfg(test)]
fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool {
auth_probe_saturation_is_throttled(now)
}
#[cfg(test)]
fn auth_probe_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@ -317,6 +498,24 @@ fn decode_user_secrets(
secrets
}
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
if config.censorship.server_hello_delay_max_ms == 0 {
return;
}
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms.max(min);
let delay_ms = if max == min {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
}
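// Hedged sketch (not part of the diff): the max(min) clamp means an inverted
// window degenerates to a fixed delay of `min` rather than sampling an
// inverted range. ProxyConfig::default() is assumed, as elsewhere in the tests.
#[tokio::test]
async fn server_hello_delay_clamps_inverted_window_to_fixed_min() {
let mut config = ProxyConfig::default();
config.censorship.server_hello_delay_min_ms = 20;
config.censorship.server_hello_delay_max_ms = 10;
let started = std::time::Instant::now();
maybe_apply_server_hello_delay(&config).await;
assert!(started.elapsed() >= Duration::from_millis(20));
}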
/// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
@ -338,6 +537,7 @@ pub struct HandshakeSuccess {
/// Client address
pub peer: SocketAddr,
/// Whether TLS was used
pub is_tls: bool,
}
@ -367,17 +567,21 @@ where
{
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer };
}
let secrets = decode_user_secrets(config, None);
let client_sni = tls::extract_sni_from_client_hello(handshake);
let secrets = decode_user_secrets(config, client_sni.as_deref());
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
@ -388,6 +592,7 @@ where
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
@ -402,20 +607,24 @@ where
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient { reader, writer },
None => {
maybe_apply_server_hello_delay(config).await;
return HandshakeResult::BadClient { reader, writer };
}
};
let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() {
let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) {
let selected_domain = if let Some(sni) = client_sni.as_ref() {
if cache.contains_domain(&sni).await {
sni
sni.clone()
} else {
config.censorship.tls_domain.clone()
}
@ -448,6 +657,7 @@ where
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else {
@ -480,19 +690,9 @@ where
)
};
// Optional anti-fingerprint delay before sending ServerHello.
if config.censorship.server_hello_delay_max_ms > 0 {
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms.max(min);
let delay_ms = if max == min {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
}
// Apply the same optional delay budget used by reject paths to reduce
// distinguishability between success and fail-closed handshakes.
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@ -536,9 +736,15 @@ where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
trace!(
peer = %peer,
handshake_head = %hex::encode(&handshake[..8]),
"MTProto handshake prefix"
);
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
@ -609,6 +815,7 @@ where
// authentication check first to avoid poisoning the replay cache.
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
}
@ -645,6 +852,7 @@ where
}
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient { reader, writer }
}
@ -732,6 +940,7 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
}
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
encrypted

File diff suppressed because it is too large

View File

@ -7,7 +7,7 @@ use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tokio::time::{Instant, timeout};
use tracing::debug;
use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
@ -24,8 +24,36 @@ const MASK_TIMEOUT: Duration = Duration::from_millis(50);
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
#[cfg(test)]
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
#[cfg(not(test))]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
const MASK_BUFFER_SIZE: usize = 8192;
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W)
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut buf = [0u8; MASK_BUFFER_SIZE];
loop {
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
let n = match read_res {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
break;
}
let write_res = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&buf[..n])).await;
match write_res {
Ok(Ok(())) => {}
Ok(Err(_)) | Err(_) => break,
}
}
}
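// Usage sketch (not part of the diff): a reader that never produces data trips
// MASK_RELAY_IDLE_TIMEOUT and the loop exits instead of pending forever.
#[tokio::test]
async fn copy_with_idle_timeout_exits_on_idle_reader() {
let (_keep_peer_open, mut idle_reader) = tokio::io::duplex(64);
let (mut sink_writer, _sink_reader) = tokio::io::duplex(64);
let started = tokio::time::Instant::now();
copy_with_idle_timeout(&mut idle_reader, &mut sink_writer).await;
assert!(started.elapsed() >= MASK_RELAY_IDLE_TIMEOUT);
}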
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
where
W: AsyncWrite + Unpin,
@ -49,6 +77,20 @@ where
}
}
async fn wait_mask_connect_budget(started: Instant) {
let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT {
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
}
}
async fn wait_mask_outcome_budget(started: Instant) {
let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT {
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
}
}
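// Note (not part of the diff): both budget helpers pad a phase's observable
// latency up to the same MASK_TIMEOUT floor, so a connect refused after 3 ms
// costs the full budget and looks like a slow connect timeout from outside.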
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
@ -107,6 +149,8 @@ where
// Connect via Unix socket or TCP
#[cfg(unix)]
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
let outcome_started = Instant::now();
let connect_started = Instant::now();
debug!(
client_type = client_type,
sock = %sock_path,
@ -143,14 +187,18 @@ where
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
debug!("Mask relay timed out (unix socket)");
}
wait_mask_outcome_budget(outcome_started).await;
}
Ok(Err(e)) => {
wait_mask_connect_budget(connect_started).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
}
return;
@ -172,6 +220,8 @@ where
let mask_addr = resolve_socket_addr(mask_host, mask_port)
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let outcome_started = Instant::now();
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
Ok(Ok(stream)) => {
@ -202,14 +252,18 @@ where
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
debug!("Mask relay timed out");
}
wait_mask_outcome_budget(outcome_started).await;
}
Ok(Err(e)) => {
wait_mask_connect_budget(connect_started).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
}
}
@ -238,11 +292,11 @@ where
let _ = tokio::join!(
async {
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
copy_with_idle_timeout(&mut reader, &mut mask_write).await;
let _ = mask_write.shutdown().await;
},
async {
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
copy_with_idle_timeout(&mut mask_read, &mut writer).await;
let _ = writer.shutdown().await;
}
);

View File

@ -8,7 +8,7 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader};
use tokio::net::TcpListener;
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio::time::{sleep, timeout, Duration};
use tokio::time::{Instant, sleep, timeout, Duration};
#[tokio::test]
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
@ -216,6 +216,373 @@ async fn backend_unavailable_falls_back_to_silent_consume() {
assert_eq!(n, 0);
}
#[tokio::test]
async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n";
// Drop the peer end of the duplex immediately so the client reader sees EOF and the refusal path timing is shaped only by the masking budget.
let (client_reader_side, client_reader) = duplex(256);
drop(client_reader_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
timeout(Duration::from_millis(35), task)
.await
.expect_err("masking fallback must not complete before connect budget elapses");
assert!(
started.elapsed() >= Duration::from_millis(35),
"fallback path must absorb immediate connect refusal into connect budget"
);
}
#[tokio::test]
async fn backend_reachable_fast_response_waits_mask_outcome_budget() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
assert!(
started.elapsed() >= Duration::from_millis(45),
"reachable mask path must also satisfy coarse outcome budget"
);
accept_task.await.unwrap();
}
#[tokio::test]
async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = false;
let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
b"x",
peer,
local_addr,
&config,
&beobachten,
)
.await;
assert!(
started.elapsed() < Duration::from_millis(20),
"mask-disabled fallback should keep immediate EOF behavior"
);
}
#[tokio::test]
async fn backend_reachable_slow_response_not_padded_twice() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
sleep(Duration::from_millis(90)).await;
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(elapsed >= Duration::from_millis(85));
assert!(
elapsed < Duration::from_millis(170),
"slow reachable backend should not incur an extra full budget after already exceeding it"
);
accept_task.await.unwrap();
}
#[tokio::test]
async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() {
const ITER: usize = 20;
const BUCKET_MS: u128 = 10;
let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n";
let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let mut refused = Vec::with_capacity(ITER);
for _ in 0..ITER {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
refused.push(started.elapsed().as_millis());
}
let mut reachable = Vec::with_capacity(ITER);
for _ in 0..ITER {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe_vec = probe.to_vec();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe_vec.len()];
stream.read_exact(&mut received).await.unwrap();
stream.write_all(&backend_reply).await.unwrap();
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
reachable.push(started.elapsed().as_millis());
accept_task.await.unwrap();
}
let refused_mean = refused.iter().copied().sum::<u128>() as f64 / refused.len() as f64;
let reachable_mean = reachable.iter().copied().sum::<u128>() as f64 / reachable.len() as f64;
let refused_bucket = (refused_mean as u128) / BUCKET_MS;
let reachable_bucket = (reachable_mean as u128) / BUCKET_MS;
assert!(
refused_bucket.abs_diff(reachable_bucket) <= 1,
"enabled refused and reachable paths must collapse into the same coarse latency bucket"
);
}
#[tokio::test]
async fn light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() {
let mut seed: u64 = 0xA5A5_5A5A_1337_4242;
let mut next = || {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
seed
};
let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
for _ in 0..40 {
let probe_len = (next() as usize % 96).saturating_add(8);
let mut probe = vec![0u8; probe_len];
for byte in &mut probe {
*byte = next() as u8;
}
let use_reachable = (next() & 1) == 0;
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(512);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
if use_reachable {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
let probe_vec = probe.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; probe_vec.len()];
stream.read_exact(&mut observed).await.unwrap();
});
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
accept_task.await.unwrap();
} else {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
}
assert!(
started.elapsed() >= Duration::from_millis(45),
"mask-enabled fallback must preserve coarse timing budget under varied probe shapes"
);
}
}
#[tokio::test]
async fn mask_disabled_consumes_client_data_without_response() {
let mut config = ProxyConfig::default();
@ -524,6 +891,59 @@ async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() {
timeout(Duration::from_secs(1), task).await.unwrap().unwrap();
}
#[tokio::test]
async fn mask_enabled_idle_relay_is_closed_by_idle_timeout_before_global_relay_timeout() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /idle HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
sleep(Duration::from_millis(300)).await;
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "198.51.100.34:45456".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (_client_reader_side, client_reader) = duplex(512);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed < Duration::from_millis(150),
"idle unauth relay must terminate on idle timeout instead of waiting for full relay timeout"
);
accept_task.await.unwrap();
}
struct PendingWriter;
impl tokio::io::AsyncWrite for PendingWriter {
@ -729,3 +1149,321 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
assert!(mask_reader_dropped.load(Ordering::SeqCst));
assert!(mask_writer_dropped.load(Ordering::SeqCst));
}
#[tokio::test]
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
async fn timing_matrix_masking_classes_under_controlled_inputs() {
const ITER: usize = 24;
const BUCKET_MS: u128 = 10;
let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n";
let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
// Class 1: masking disabled with immediate EOF (fast fail-closed consume path).
let mut disabled_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = false;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
disabled_samples.push(started.elapsed().as_millis());
}
// Class 2: masking enabled, backend connect refused.
let mut refused_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
refused_samples.push(started.elapsed().as_millis());
}
// Class 3: masking enabled, backend reachable and immediately responds.
let mut reachable_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let probe_vec = probe.to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe_vec.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe_vec);
stream.write_all(&backend_reply).await.unwrap();
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
reachable_samples.push(started.elapsed().as_millis());
accept_task.await.unwrap();
}
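// Nearest-rank summary over the collected samples: sorts in place and
// returns (mean, min, p95, max); p95 uses the floor(0.95 * len) rank.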
fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) {
samples_ms.sort_unstable();
let sum: u128 = samples_ms.iter().copied().sum();
let mean = sum as f64 / samples_ms.len() as f64;
let min = samples_ms[0];
let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
let max = samples_ms[samples_ms.len() - 1];
(mean, min, p95, max)
}
let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples);
let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples);
let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples);
println!(
"TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
disabled_mean,
disabled_min,
disabled_p95,
disabled_max,
(disabled_mean as u128) / BUCKET_MS
);
println!(
"TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
refused_mean,
refused_min,
refused_p95,
refused_max,
(refused_mean as u128) / BUCKET_MS
);
println!(
"TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
reachable_mean,
reachable_min,
reachable_p95,
reachable_max,
(reachable_mean as u128) / BUCKET_MS
);
}
#[tokio::test]
async fn backend_connect_refusal_completes_within_bounded_mask_budget() {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.41:51001".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let probe = b"GET /bounded HTTP/1.1\r\nHost: x\r\n\r\n";
let (_client_reader_side, client_reader) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(45),
"connect refusal path must respect minimum masking budget"
);
assert!(
elapsed < Duration::from_millis(500),
"connect refusal path must stay bounded and avoid unbounded stall"
);
}
#[tokio::test]
async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /oneshot HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let response = response.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
stream.write_all(&response).await.unwrap();
sleep(Duration::from_millis(300)).await;
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.42:51002".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (_client_reader_side, client_reader) = duplex(256);
let (mut client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
let mut observed = vec![0u8; response.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, response);
assert!(
elapsed < Duration::from_millis(190),
"idle backend silence after first response must be cut by relay idle timeout"
);
accept_task.await.unwrap();
}
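// Adversarial model: the client drips one extra byte after the probe, slower
// than the relay idle timeout, trying to keep the masked session pinned open.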
#[tokio::test]
async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let initial = b"GET /drip HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let initial = initial.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; initial.len()];
stream.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, initial);
let mut extra = [0u8; 1];
let read_res = timeout(Duration::from_millis(220), stream.read_exact(&mut extra)).await;
assert!(
read_res.is_err() || read_res.unwrap().is_err(),
"drip-fed post-probe byte arriving after idle timeout should not be forwarded"
);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.43:51003".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (mut client_writer_side, client_reader) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let relay_task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
&initial,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
sleep(Duration::from_millis(160)).await;
let _ = client_writer_side.write_all(b"X").await;
drop(client_writer_side);
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap();
accept_task.await.unwrap();
}

View File

@ -1,15 +1,14 @@
use std::collections::hash_map::DefaultHasher;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
#[cfg(test)]
use std::sync::Mutex;
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex};
use tokio::time::timeout;
use tracing::{debug, trace, warn};
@ -34,13 +33,22 @@ enum C2MeCommand {
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000);
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
#[cfg(test)]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
#[cfg(not(test))]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
struct RelayForensicsState {
trace_id: u64,
@ -80,7 +88,8 @@ impl MeD2cFlushPolicy {
}
fn hash_value<T: Hash>(value: &T) -> u64 {
let mut hasher = DefaultHasher::new();
let state = DESYNC_HASHER.get_or_init(RandomState::new);
let mut hasher = state.build_hasher();
value.hash(&mut hasher);
hasher.finish()
}
@ -95,6 +104,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
}
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false));
if saturated_before {
ever_saturated.store(true, Ordering::Relaxed);
}
if let Some(mut seen_at) = dedup.get_mut(&key) {
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
@ -106,12 +120,17 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
let mut stale_keys = Vec::new();
let mut eviction_candidate = None;
let mut oldest_candidate: Option<(u64, Instant)> = None;
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
if eviction_candidate.is_none() {
eviction_candidate = Some(*entry.key());
let key = *entry.key();
let seen_at = *entry.value();
match oldest_candidate {
Some((_, oldest_seen)) if seen_at >= oldest_seen => {}
_ => oldest_candidate = Some((key, seen_at)),
}
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW {
stale_keys.push(*entry.key());
}
}
@ -119,17 +138,57 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
dedup.remove(&stale_key);
}
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
let Some(evict_key) = eviction_candidate else {
let Some((evict_key, _)) = oldest_candidate else {
return false;
};
dedup.remove(&evict_key);
dedup.insert(key, now);
return false;
return should_emit_full_desync_full_cache(now);
}
}
dedup.insert(key, now);
true
let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
// Preserve the first sequential insert that reaches capacity as a normal
// emit, while still gating concurrent newcomer churn after the cache has
// ever been observed at saturation.
let was_ever_saturated = if saturated_after {
ever_saturated.swap(true, Ordering::Relaxed)
} else {
ever_saturated.load(Ordering::Relaxed)
};
if saturated_before || (saturated_after && was_ever_saturated) {
should_emit_full_desync_full_cache(now)
} else {
true
}
}
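// Saturation fallback: once the dedup cache is pinned at capacity, emission
// collapses to a single global once-per-interval gate, so attacker-controlled
// key churn cannot restore per-key emit rates.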
fn should_emit_full_desync_full_cache(now: Instant) -> bool {
let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
let Ok(mut last_emit_at) = gate.lock() else {
return false;
};
match *last_emit_at {
None => {
*last_emit_at = Some(now);
true
}
Some(last) => {
let Some(elapsed) = now.checked_duration_since(last) else {
*last_emit_at = Some(now);
return true;
};
if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL {
*last_emit_at = Some(now);
true
} else {
false
}
}
}
}
#[cfg(test)]
@ -137,6 +196,21 @@ fn clear_desync_dedup_for_testing() {
if let Some(dedup) = DESYNC_DEDUP.get() {
dedup.clear();
}
if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() {
ever_saturated.store(false, Ordering::Relaxed);
}
if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() {
match last_emit_at.lock() {
Ok(mut guard) => {
*guard = None;
}
Err(poisoned) => {
let mut guard = poisoned.into_inner();
*guard = None;
last_emit_at.clear_poison();
}
}
}
}
#[cfg(test)]
@ -240,6 +314,38 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
}
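// Pre-flight variant: rejects a pending transfer of `bytes` that would cross
// the limit, using saturating subtraction so `quota - used` cannot underflow.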
fn quota_would_be_exceeded_for_user(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let used = stats.get_user_total_octets(user);
used >= quota || bytes > quota.saturating_sub(used)
})
}
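// Per-user lock registry with a race-free two-step lookup: an optimistic read
// first, then an entry() upsert that adopts the winner's Arc if another task
// inserted the same user concurrently.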
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
let created = Arc::new(AsyncMutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@ -252,7 +358,14 @@ async fn enqueue_c2me_command(
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
tokio::task::yield_now().await;
}
tx.send(cmd).await
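// Bounded enqueue: reserve a permit under a timeout so a stalled consumer
// surfaces as a send error instead of an unbounded await on a full channel.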
match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await {
Ok(Ok(permit)) => {
permit.send(cmd);
Ok(())
}
Ok(Err(_)) => Err(mpsc::error::SendError(cmd)),
Err(_) => Err(mpsc::error::SendError(cmd)),
}
}
}
}
@ -276,6 +389,7 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@ -306,7 +420,7 @@ where
};
stats.increment_user_connects(&user);
stats.increment_current_connections_me();
let _me_connection_lease = stats.acquire_me_connection_lease();
if let Some(cutover) = affected_cutover_state(
&route_rx,
@ -324,7 +438,6 @@ where
tokio::time::sleep(delay).await;
let _ = me_pool.send_close(conn_id).await;
me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me();
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
}
@ -425,6 +538,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -457,6 +571,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -489,6 +604,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -521,6 +637,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -602,7 +719,19 @@ where
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
stats.add_user_octets_from(&user, payload.len() as u64);
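// Serialize same-user accounting behind a per-user async lock so two relays
// for one user cannot both pass the quota check and overshoot the limit.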
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
main_result = Err(ProxyError::DataQuotaExceeded {
user: user.clone(),
});
break;
}
} else {
stats.add_user_octets_from(&user, payload.len() as u64);
}
let mut flags = proto_flags;
if quickack {
flags |= RPC_FLAG_QUICKACK;
@ -672,7 +801,6 @@ where
"ME relay cleanup"
);
me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me();
result
}
@ -827,6 +955,7 @@ async fn process_me_writer_response<W>(
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes_me2c: &AtomicU64,
conn_id: u64,
ack_flush_immediate: bool,
@ -842,17 +971,47 @@ where
} else {
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
let data_len = data.len() as u64;
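// Ordering matters: pre-check that the frame fits the remaining quota before
// writing (no partial leak), account after a successful write, then re-check
// so the session closes exactly at the boundary.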
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(user);
let _quota_guard = quota_lock.lock().await;
if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
if quota_exceeded_for_user(stats, user, Some(limit)) {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
} else {
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
}
Ok(MeWriterResponseOutcome::Continue {
frames: 1,

File diff suppressed because it is too large

View File

@ -53,16 +53,17 @@
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use dashmap::DashMap;
use tokio::io::{
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
use crate::error::Result;
use crate::error::{ProxyError, Result};
use crate::stats::Stats;
use crate::stream::BufferPool;
@ -205,6 +206,8 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
}
@ -214,11 +217,62 @@ impl<S> StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
Self { inner, counters, stats, user, epoch }
Self {
inner,
counters,
stats,
user,
quota_limit,
quota_exceeded,
epoch,
}
}
}
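// Quota exits are signaled in-band through io::Error with a typed sentinel
// payload. Classification downcasts the error source instead of matching on
// message text, so an attacker-controlled PermissionDenied message cannot
// masquerade as a quota cutoff.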
#[derive(Debug)]
struct QuotaIoSentinel;
impl std::fmt::Display for QuotaIoSentinel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("user data quota exceeded")
}
}
impl std::error::Error for QuotaIoSentinel {}
fn quota_io_error() -> io::Error {
io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
}
fn is_quota_io_error(err: &io::Error) -> bool {
err.kind() == io::ErrorKind::PermissionDenied
&& err
.get_ref()
.and_then(|source| source.downcast_ref::<QuotaIoSentinel>())
.is_some()
}
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
let created = Arc::new(Mutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
@ -229,6 +283,32 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
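// try_lock keeps poll_read non-blocking: if another same-user stream holds
// the quota lock, self-wake and retry on a later poll instead of blocking
// the executor thread.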
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
} else {
None
};
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
@ -243,6 +323,13 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
trace!(user = %this.user, bytes = n, "C->S");
}
Poll::Ready(Ok(()))
@ -259,8 +346,46 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
return Poll::Ready(Err(quota_io_error()));
}
match Pin::new(&mut this.inner).poll_write(cx, buf) {
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
} else {
None
};
let write_buf = if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let remaining = (limit - used) as usize;
if buf.len() > remaining {
// Fail closed: do not emit partial S->C payload when remaining
// quota cannot accommodate the pending write request.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
buf
} else {
buf
};
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
// S→C: data written to client
@ -271,6 +396,13 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
trace!(user = %this.user, bytes = n, "S->C");
}
Poll::Ready(Ok(n))
@ -307,7 +439,8 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
/// - Per-user stats: bytes and ops counted per direction
/// - Periodic rate logging: every 10 seconds when active
/// - Clean shutdown: both write sides are shut down on exit
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
/// other I/O failures are returned as `ProxyError::Io`
pub async fn relay_bidirectional<CR, CW, SR, SW>(
client_reader: CR,
client_writer: CW,
@ -317,6 +450,7 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
s2c_buf_size: usize,
user: &str,
stats: Arc<Stats>,
quota_limit: Option<u64>,
_buffer_pool: Arc<BufferPool>,
) -> Result<()>
where
@ -327,6 +461,7 @@ where
{
let epoch = Instant::now();
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let user_owned = user.to_string();
// ── Combine split halves into bidirectional streams ──────────────
@ -339,12 +474,15 @@ where
Arc::clone(&counters),
Arc::clone(&stats),
user_owned.clone(),
quota_limit,
Arc::clone(&quota_exceeded),
epoch,
);
// ── Watchdog: activity timeout + periodic rate logging ──────────
let wd_counters = Arc::clone(&counters);
let wd_user = user_owned.clone();
let wd_quota_exceeded = Arc::clone(&quota_exceeded);
let watchdog = async {
let mut prev_c2s: u64 = 0;
@ -356,6 +494,11 @@ where
let now = Instant::now();
let idle = wd_counters.idle_duration(now, epoch);
if wd_quota_exceeded.load(Ordering::Relaxed) {
warn!(user = %wd_user, "User data quota reached, closing relay");
return;
}
// ── Activity timeout ────────────────────────────────────
if idle >= ACTIVITY_TIMEOUT {
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
@ -439,6 +582,22 @@ where
);
Ok(())
}
Some(Err(e)) if is_quota_io_error(&e) => {
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
warn!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
"Data quota reached, closing relay"
);
Err(ProxyError::DataQuotaExceeded {
user: user_owned.clone(),
})
}
Some(Err(e)) => {
// I/O error in one of the directions
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
@ -472,3 +631,7 @@ where
}
}
}
#[cfg(test)]
#[path = "relay_security_tests.rs"]
mod security_tests;

View File

@ -0,0 +1,972 @@
use super::relay_bidirectional;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::task::Waker;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
#[tokio::test]
async fn relay_bidirectional_enforces_live_user_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-user";
stats.add_user_octets_from(user, 6);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
user,
Arc::clone(&stats),
Some(8),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0x10, 0x20, 0x30, 0x40])
.await
.expect("client write must succeed");
let mut forwarded = [0u8; 4];
let _ = timeout(
Duration::from_millis(200),
server_peer.read_exact(&mut forwarded),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"),
"relay must surface a typed quota error once live quota is exceeded"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() {
let stats = Arc::new(Stats::new());
let quota_user = "quota-exhausted-user";
stats.add_user_octets_from(quota_user, 1);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
server_peer
.write_all(&[0xde, 0xad, 0xbe, 0xef])
.await
.expect("server write must succeed");
let mut observed = [0u8; 4];
let forwarded = timeout(
Duration::from_millis(200),
client_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
"no full server payload should be forwarded once quota is already exhausted"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() {
let stats = Arc::new(Stats::new());
let quota_user = "partial-leak-user";
stats.add_user_octets_from(quota_user, 3);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(4),
Arc::new(BufferPool::new()),
));
server_peer
.write_all(&[0x11, 0x22, 0x33, 0x44])
.await
.expect("server write must succeed");
let mut observed = [0u8; 8];
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n > 0),
"quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() {
let stats = Arc::new(Stats::new());
let quota_user = "zero-quota-user";
for payload_len in [1usize, 16, 512, 4096] {
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(0),
Arc::new(BufferPool::new()),
));
let payload = vec![0x7f; payload_len];
let _ = server_peer.write_all(&payload).await;
let mut observed = vec![0u8; payload_len];
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under zero-quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n > 0),
"zero quota must not forward any server bytes for payload_len={payload_len}"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"zero quota must terminate with the typed quota error for payload_len={payload_len}"
);
}
}
#[tokio::test]
async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() {
let stats = Arc::new(Stats::new());
let quota_user = "exact-boundary-user";
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(4),
Arc::new(BufferPool::new()),
));
server_peer
.write_all(&[0x91, 0x92, 0x93, 0x94])
.await
.expect("server write must succeed at exact quota boundary");
let mut observed = [0u8; 4];
client_peer
.read_exact(&mut observed)
.await
.expect("client must receive the full payload at the exact quota boundary");
assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]);
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish after exact boundary delivery")
.expect("relay task must not panic");
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must close with a typed quota error after reaching the exact boundary"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() {
let stats = Arc::new(Stats::new());
let quota_user = "client-exhausted-user";
stats.add_user_octets_from(quota_user, 1);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0x51, 0x52, 0x53, 0x54])
.await
.expect("client write must succeed even when quota is already exhausted");
let mut observed = [0u8; 4];
let forwarded = timeout(
Duration::from_millis(200),
server_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
"client payload must not be fully forwarded once quota is already exhausted"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() {
let stats = Arc::new(Stats::new());
let quota_user = "quota-fuzz-user";
stats.add_user_octets_from(quota_user, 2);
for payload_len in [1usize, 32, 1024, 8192] {
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(2),
Arc::new(BufferPool::new()),
));
let payload = vec![0xaa; payload_len];
let _ = server_peer.write_all(&payload).await;
let mut observed = vec![0u8; payload_len];
let forwarded = timeout(
Duration::from_millis(200),
client_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == payload_len),
"quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must keep returning the typed quota error for payload_len={payload_len}"
);
}
}
#[tokio::test]
async fn relay_bidirectional_terminates_on_activity_timeout() {
tokio::time::pause();
let stats = Arc::new(Stats::new());
let user = "timeout-user";
let (client_peer, relay_client) = duplex(4096);
let (relay_server, server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
user,
Arc::clone(&stats),
None, // No quota
Arc::new(BufferPool::new()),
));
// Wait past the activity timeout threshold (1800 seconds) + buffer
tokio::time::sleep(Duration::from_secs(1805)).await;
// Resume time to process timeouts
tokio::time::resume();
let relay_result = timeout(Duration::from_secs(1), relay_task)
.await
.expect("relay task must finish inside bounded timeout due to inactivity cutoff")
.expect("relay task must not panic");
assert!(
relay_result.is_ok(),
"relay should complete successfully on scheduled inactivity timeout"
);
// Drop the peer halves; the relay has already exited on its own
drop(client_peer);
drop(server_peer);
}
#[tokio::test]
async fn relay_bidirectional_watchdog_resists_premature_execution() {
tokio::time::pause();
let stats = Arc::new(Stats::new());
let user = "activity-user";
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let mut relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
user,
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
));
// Advance by half the timeout
tokio::time::sleep(Duration::from_secs(900)).await;
// Provide activity
client_peer
.write_all(&[0xaa, 0xbb])
.await
.expect("client write must succeed");
client_peer.flush().await.unwrap();
// Advance another 900s: 1800s since start, but only 900s since the last activity
tokio::time::sleep(Duration::from_secs(900)).await;
tokio::time::resume();
// The relay must NOT have timed out yet; a brief poll should still find it pending
let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await;
assert!(
relay_result.is_err(),
"Relay must not exit prematurely as long as activity was received before timeout"
);
// Explicitly drop sockets to cleanly shut down relay loop
drop(client_peer);
drop(server_peer);
let completion = timeout(Duration::from_secs(1), relay_task).await
.expect("relay task must complete securely after client disconnection")
.expect("relay task must not panic");
assert!(completion.is_ok(), "relay exits clean");
}
#[tokio::test]
async fn relay_bidirectional_half_closure_terminates_cleanly() {
let stats = Arc::new(Stats::new());
let (client_peer, relay_client) = duplex(4096);
let (relay_server, server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "half-close", stats, None, Arc::new(BufferPool::new()),
));
// Half closure: drop the client completely but leave the server active.
drop(client_peer);
// Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush.
// Eventually dropping the server cleanly closes the task.
drop(server_peer);
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
}
#[tokio::test]
async fn relay_bidirectional_zero_length_noise_fuzzing() {
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz", stats, None, Arc::new(BufferPool::new()),
));
// Flood with zero-length payloads (edge cases in stream framing logic sometimes loop)
for _ in 0..100 {
client_peer.write_all(&[]).await.unwrap();
}
client_peer.write_all(&[1, 2, 3]).await.unwrap();
client_peer.flush().await.unwrap();
let mut buf = [0u8; 3];
server_peer.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, &[1, 2, 3]);
drop(client_peer);
drop(server_peer);
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
}
#[tokio::test]
async fn relay_bidirectional_asymmetric_backpressure() {
let stats = Arc::new(Stats::new());
// Give the client pipe a deliberately tiny 1 KiB buffer so backpressure engages quickly
let (client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "slowloris", stats, None, Arc::new(BufferPool::new()),
));
let payload = vec![0xba; 65536]; // 64k payload
// Server attempts to shove 64KB into a relay whose client pipe only holds 1KB!
let write_res = tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await;
assert!(
write_res.is_err(),
"Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!"
);
drop(client_peer);
drop(server_peer);
let completion = timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap();
// No assertion on the result: either Ok or a broken-pipe Err is acceptable.
// The bounded join above is the real proof that the task unwound despite
// the active backpressure.
let _ = completion;
}
use rand::{Rng, SeedableRng, rngs::StdRng};
#[tokio::test]
async fn relay_bidirectional_light_fuzzing_temporal_jitter() {
tokio::time::pause();
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let mut relay_task = tokio::spawn(relay_bidirectional(
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz-user", stats, None, Arc::new(BufferPool::new()),
));
let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
for _ in 0..10 {
// Vary timing significantly up to 1600 seconds (limit is 1800s)
let jitter = rng.random_range(100..1600);
tokio::time::sleep(Duration::from_secs(jitter)).await;
client_peer.write_all(&[0x11]).await.unwrap();
client_peer.flush().await.unwrap();
// Ensure task has not died
let res = timeout(Duration::from_millis(10), &mut relay_task).await;
assert!(res.is_err(), "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses");
}
drop(client_peer);
drop(server_peer);
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
}
struct FaultyReader {
error_once: Option<io::Error>,
}
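// Two-party rendezvous barrier: both racing I/O halves must arrive inside
// their poll before either may complete, forcing the quota check and byte
// transfer of two same-user streams to interleave.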
struct TwoPartyGate {
arrivals: AtomicUsize,
total_bytes: AtomicUsize,
wakers: Mutex<Vec<Waker>>,
}
impl TwoPartyGate {
fn new() -> Self {
Self {
arrivals: AtomicUsize::new(0),
total_bytes: AtomicUsize::new(0),
wakers: Mutex::new(Vec::new()),
}
}
// Record this party's first arrival exactly once, then open the gate when
// both parties are present. Returns true once the gate is open.
fn arrive_or_park(&self, cx: &mut Context<'_>, first_arrival: bool) -> bool {
if first_arrival {
self.arrivals.fetch_add(1, Ordering::AcqRel);
}
if self.arrivals.load(Ordering::Acquire) >= 2 {
let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner());
for waker in wakers.drain(..) {
waker.wake();
}
return true;
}
let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner());
// Re-check under the lock: the other party may have arrived (and drained
// the waker list) between the load above and acquiring the lock.
if self.arrivals.load(Ordering::Acquire) >= 2 {
return true;
}
wakers.push(cx.waker().clone());
false
}
fn total_bytes(&self) -> usize {
self.total_bytes.load(Ordering::Relaxed)
}
}
struct GateWriter {
gate: Arc<TwoPartyGate>,
entered: bool,
}
impl GateWriter {
fn new(gate: Arc<TwoPartyGate>) -> Self {
Self {
gate,
entered: false,
}
}
}
impl AsyncWrite for GateWriter {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
// Count this writer's arrival only on its first poll; re-polls must not
// double-count a single party toward the two-party gate.
let first_arrival = !self.entered;
self.entered = true;
if !self.gate.arrive_or_park(cx, first_arrival) {
return Poll::Pending;
}
self.gate
.total_bytes
.fetch_add(buf.len(), Ordering::Relaxed);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
struct GateReader {
gate: Arc<TwoPartyGate>,
entered: bool,
emitted: bool,
}
impl GateReader {
fn new(gate: Arc<TwoPartyGate>) -> Self {
Self {
gate,
entered: false,
emitted: false,
}
}
}
impl AsyncRead for GateReader {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if self.emitted {
return Poll::Ready(Ok(()));
}
// Same first-arrival bookkeeping as the writer side.
let first_arrival = !self.entered;
self.entered = true;
if !self.gate.arrive_or_park(cx, first_arrival) {
return Poll::Pending;
}
buf.put_slice(&[0x42]);
self.gate.total_bytes.fetch_add(1, Ordering::Relaxed);
self.emitted = true;
Poll::Ready(Ok(()))
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
let stats = Arc::new(Stats::new());
let gate = Arc::new(TwoPartyGate::new());
let user = "concurrent-quota-write".to_string();
let writer_a = super::StatsIo::new(
GateWriter::new(Arc::clone(&gate)),
Arc::new(super::SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let writer_b = super::StatsIo::new(
GateWriter::new(Arc::clone(&gate)),
Arc::new(super::SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let task_a = tokio::spawn(async move {
let mut w = writer_a;
AsyncWriteExt::write_all(&mut w, &[0x01]).await
});
let task_b = tokio::spawn(async move {
let mut w = writer_b;
AsyncWriteExt::write_all(&mut w, &[0x02]).await
});
let (res_a, res_b) = tokio::join!(task_a, task_b);
let _ = res_a.expect("task a must join");
let _ = res_b.expect("task b must join");
assert!(
gate.total_bytes() <= 1,
"concurrent same-user writes must not forward more than one byte under quota=1"
);
assert!(
stats.get_user_total_octets(&user) <= 1,
"concurrent same-user writes must not account over limit"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
let stats = Arc::new(Stats::new());
let gate = Arc::new(TwoPartyGate::new());
let user = "concurrent-quota-read".to_string();
let reader_a = super::StatsIo::new(
GateReader::new(Arc::clone(&gate)),
Arc::new(super::SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let reader_b = super::StatsIo::new(
GateReader::new(Arc::clone(&gate)),
Arc::new(super::SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let task_a = tokio::spawn(async move {
let mut r = reader_a;
let mut one = [0u8; 1];
AsyncReadExt::read_exact(&mut r, &mut one).await
});
let task_b = tokio::spawn(async move {
let mut r = reader_b;
let mut one = [0u8; 1];
AsyncReadExt::read_exact(&mut r, &mut one).await
});
let (res_a, res_b) = tokio::join!(task_a, task_b);
let _ = res_a.expect("task a must join");
let _ = res_b.expect("task b must join");
assert!(
gate.total_bytes() <= 1,
"concurrent same-user reads must not consume more than one byte under quota=1"
);
assert!(
stats.get_user_total_octets(&user) <= 1,
"concurrent same-user reads must not account over limit"
);
}
#[tokio::test]
async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
let stats = Arc::new(Stats::new());
let user = "parallel-quota-user";
for _ in 0..128 {
let (mut client_peer_a, relay_client_a) = duplex(256);
let (relay_server_a, mut server_peer_a) = duplex(256);
let (mut client_peer_b, relay_client_b) = duplex(256);
let (relay_server_b, mut server_peer_b) = duplex(256);
let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a);
let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a);
let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b);
let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b);
let relay_a = tokio::spawn(relay_bidirectional(
client_reader_a,
client_writer_a,
server_reader_a,
server_writer_a,
64,
64,
user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
let relay_b = tokio::spawn(relay_bidirectional(
client_reader_b,
client_writer_b,
server_reader_b,
server_writer_b,
64,
64,
user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
let _ = tokio::join!(
client_peer_a.write_all(&[0x01]),
server_peer_a.write_all(&[0x02]),
client_peer_b.write_all(&[0x03]),
server_peer_b.write_all(&[0x04]),
);
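// Best-effort single poll of the client read side so an already-forwarded
// byte can drain before teardown; the outcome is intentionally ignored.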
let _ = timeout(Duration::from_millis(50), poll_fn(|cx| {
let mut one = [0u8; 1];
let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one));
Poll::Ready(())
}))
.await;
drop(client_peer_a);
drop(server_peer_a);
drop(client_peer_b);
drop(server_peer_b);
let _ = timeout(Duration::from_secs(1), relay_a).await;
let _ = timeout(Duration::from_secs(1), relay_b).await;
assert!(
stats.get_user_total_octets(user) <= 1,
"parallel relays must not exceed configured quota"
);
}
}
impl FaultyReader {
fn permission_denied_with_message(message: impl Into<String>) -> Self {
Self {
error_once: Some(io::Error::new(io::ErrorKind::PermissionDenied, message.into())),
}
}
}
impl AsyncRead for FaultyReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(err) = self.error_once.take() {
return Poll::Ready(Err(err));
}
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() {
let stats = Arc::new(Stats::new());
let (client_peer, relay_client) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let relay_result = relay_bidirectional(
client_reader,
client_writer,
FaultyReader::permission_denied_with_message("user data quota exceeded"),
tokio::io::sink(),
1024,
1024,
"non-quota-permission-denied",
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
)
.await;
drop(client_peer);
assert!(
matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied),
"non-quota transport PermissionDenied errors must remain IO errors"
);
}
#[tokio::test]
async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() {
let mut rng = StdRng::seed_from_u64(0xA11CE0B5);
for i in 0..128u64 {
let stats = Arc::new(Stats::new());
let (client_peer, relay_client) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let random_len = rng.random_range(1..=48);
let mut msg = String::with_capacity(random_len);
for _ in 0..random_len {
let ch = (b'a' + (rng.random::<u8>() % 26)) as char;
msg.push(ch);
}
// Include the legacy quota string in a subset of fuzz cases to validate
// collision resistance against message-based classification.
if i % 7 == 0 {
msg = "user data quota exceeded".to_string();
}
let relay_result = relay_bidirectional(
client_reader,
client_writer,
FaultyReader::permission_denied_with_message(msg),
tokio::io::sink(),
1024,
1024,
"fuzz-perm-denied",
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
)
.await;
drop(client_peer);
assert!(
matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied),
"transport PermissionDenied case must stay typed as IO regardless of message content"
);
}
}

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::watch;
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover";
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated";
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode {
}
impl RelayRouteMode {
pub(crate) fn as_u8(self) -> u8 {
self as u8
}
pub(crate) fn from_u8(value: u8) -> Self {
match value {
1 => Self::Middle,
_ => Self::Direct,
}
}
pub(crate) fn as_str(self) -> &'static str {
match self {
Self::Direct => "direct",
@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState {
#[derive(Clone)]
pub(crate) struct RouteRuntimeController {
mode: Arc<AtomicU8>,
generation: Arc<AtomicU64>,
direct_since_epoch_secs: Arc<AtomicU64>,
tx: watch::Sender<RouteCutoverState>,
}
@ -60,18 +47,13 @@ impl RouteRuntimeController {
0
};
Self {
mode: Arc::new(AtomicU8::new(initial_mode.as_u8())),
generation: Arc::new(AtomicU64::new(0)),
direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)),
tx,
}
}
pub(crate) fn snapshot(&self) -> RouteCutoverState {
RouteCutoverState {
mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)),
generation: self.generation.load(Ordering::Relaxed),
}
*self.tx.borrow()
}
pub(crate) fn subscribe(&self) -> watch::Receiver<RouteCutoverState> {
@ -84,20 +66,29 @@ impl RouteRuntimeController {
}
pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed);
if previous == mode.as_u8() {
let mut next = None;
let changed = self.tx.send_if_modified(|state| {
if state.mode == mode {
return false;
}
state.mode = mode;
state.generation = state.generation.saturating_add(1);
next = Some(*state);
true
});
if !changed {
return None;
}
if matches!(mode, RelayRouteMode::Direct) {
self.direct_since_epoch_secs
.store(now_epoch_secs(), Ordering::Relaxed);
} else {
self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
}
let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
let next = RouteCutoverState { mode, generation };
self.tx.send_replace(next);
Some(next)
next
}
}
@ -110,10 +101,10 @@ fn now_epoch_secs() -> u64 {
pub(crate) fn is_session_affected_by_cutover(
current: RouteCutoverState,
_session_mode: RelayRouteMode,
session_mode: RelayRouteMode,
session_generation: u64,
) -> bool {
current.generation > session_generation
current.generation > session_generation && current.mode != session_mode
}
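// A cutover now requires both a newer generation and a different final mode,
// so oscillations that land back on a session's own mode no longer cut it.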
pub(crate) fn affected_cutover_state(
@ -140,3 +131,7 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio
let ms = 1000 + (value % 1000);
Duration::from_millis(ms)
}
#[cfg(test)]
#[path = "route_mode_security_tests.rs"]
mod security_tests;

View File

@ -0,0 +1,340 @@
use super::*;
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
#[test]
fn cutover_stagger_delay_is_deterministic_for_same_inputs() {
let d1 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42);
let d2 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42);
assert_eq!(
d1, d2,
"stagger delay must be deterministic for identical session/generation inputs"
);
}
#[test]
fn cutover_stagger_delay_stays_within_budget_bounds() {
// Black-hat model: censors trigger many cutovers and correlate disconnect timing.
// Keep delay inside a narrow coarse window to avoid long-tail spikes.
for generation in [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX] {
for session_id in [
0u64,
1,
2,
0xdead_beef,
0xfeed_face_cafe_babe,
u64::MAX,
] {
let delay = cutover_stagger_delay(session_id, generation);
assert!(
(1000..=1999).contains(&delay.as_millis()),
"stagger delay must remain in fixed 1000..=1999ms budget"
);
}
}
}
#[test]
fn cutover_stagger_delay_changes_with_generation_for_same_session() {
let session_id = 0x0123_4567_89ab_cdef;
let first = cutover_stagger_delay(session_id, 100);
let second = cutover_stagger_delay(session_id, 101);
assert_ne!(
first, second,
"adjacent cutover generations should decorrelate disconnect delays"
);
}
#[test]
fn route_runtime_set_mode_is_idempotent_for_same_mode() {
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
let first = runtime.snapshot();
let changed = runtime.set_mode(RelayRouteMode::Direct);
let second = runtime.snapshot();
assert!(
changed.is_none(),
"setting already-active mode must not produce a cutover event"
);
assert_eq!(
first.generation, second.generation,
"idempotent mode set must not bump generation"
);
}
#[test]
fn affected_cutover_state_triggers_only_for_newer_generation() {
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
let rx = runtime.subscribe();
let initial = runtime.snapshot();
assert!(
affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation).is_none(),
"current generation must not be considered a cutover for existing session"
);
let next = runtime
.set_mode(RelayRouteMode::Middle)
.expect("mode change must produce cutover state");
let seen = affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation)
.expect("newer generation must be observed as cutover");
assert_eq!(seen.generation, next.generation);
assert_eq!(seen.mode, RelayRouteMode::Middle);
}
#[test]
fn integration_watch_and_snapshot_follow_same_transition_sequence() {
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
let rx = runtime.subscribe();
let sequence = [
RelayRouteMode::Middle,
RelayRouteMode::Middle,
RelayRouteMode::Direct,
RelayRouteMode::Direct,
RelayRouteMode::Middle,
];
let mut expected_generation = 0u64;
let mut expected_mode = RelayRouteMode::Direct;
for target in sequence {
let changed = runtime.set_mode(target);
if target == expected_mode {
assert!(changed.is_none(), "idempotent transition must return none");
} else {
expected_mode = target;
expected_generation = expected_generation.saturating_add(1);
let emitted = changed.expect("real transition must emit cutover state");
assert_eq!(emitted.mode, expected_mode);
assert_eq!(emitted.generation, expected_generation);
}
let snap = runtime.snapshot();
let watched = *rx.borrow();
assert_eq!(snap, watched, "snapshot and watch state must stay aligned");
assert_eq!(snap.mode, expected_mode);
assert_eq!(snap.generation, expected_generation);
}
}
#[test]
fn session_is_not_affected_when_mode_matches_even_if_generation_advanced() {
let session_mode = RelayRouteMode::Direct;
let current = RouteCutoverState {
mode: RelayRouteMode::Direct,
generation: 2,
};
let session_generation = 0;
assert!(
!is_session_affected_by_cutover(current, session_mode, session_generation),
"session on matching final route mode should not be force-cut over on intermediate generation bumps"
);
}
#[test]
fn cutover_predicate_rejects_equal_generation_even_if_mode_differs() {
let current = RouteCutoverState {
mode: RelayRouteMode::Middle,
generation: 77,
};
assert!(
!is_session_affected_by_cutover(current, RelayRouteMode::Direct, 77),
"equal generation must never trigger cutover regardless of mode mismatch"
);
}
#[test]
fn adversarial_route_oscillation_only_cuts_over_sessions_with_different_final_mode() {
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
let rx = runtime.subscribe();
let session_generation = runtime.snapshot().generation;
runtime
.set_mode(RelayRouteMode::Middle)
.expect("direct->middle must transition");
runtime
.set_mode(RelayRouteMode::Direct)
.expect("middle->direct must transition");
assert!(
affected_cutover_state(&rx, RelayRouteMode::Direct, session_generation).is_none(),
"direct session should survive when final mode returns to direct"
);
assert!(
affected_cutover_state(&rx, RelayRouteMode::Middle, session_generation).is_some(),
"middle session should be cut over when final mode is direct"
);
}
#[test]
fn light_fuzz_cutover_predicate_matches_reference_oracle() {
let mut rng = StdRng::seed_from_u64(0xC0DEC0DE5EED);
for _ in 0..20_000 {
let current = RouteCutoverState {
mode: if rng.random::<bool>() {
RelayRouteMode::Direct
} else {
RelayRouteMode::Middle
},
generation: rng.random_range(0u64..1_000_000),
};
let session_mode = if rng.random::<bool>() {
RelayRouteMode::Direct
} else {
RelayRouteMode::Middle
};
let session_generation = rng.random_range(0u64..1_000_000);
let expected = current.generation > session_generation && current.mode != session_mode;
let actual = is_session_affected_by_cutover(current, session_mode, session_generation);
assert_eq!(
actual, expected,
"cutover predicate must match mode-aware generation oracle"
);
}
}
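// The inline `expected` oracle restates the mode-aware generation rule; if a
// later refactor of `is_session_affected_by_cutover` drifts from that rule,
// the fixed seed makes the first failing input reproducible.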
#[test]
fn light_fuzz_set_mode_generation_tracks_only_real_transitions() {
let runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
let mut rng = StdRng::seed_from_u64(0x0DDC0FFE);
let mut expected_mode = RelayRouteMode::Direct;
let mut expected_generation = 0u64;
for _ in 0..10_000 {
let candidate = if rng.random::<bool>() {
RelayRouteMode::Direct
} else {
RelayRouteMode::Middle
};
let changed = runtime.set_mode(candidate);
if candidate == expected_mode {
assert!(changed.is_none(), "idempotent set_mode must not emit cutover state");
} else {
expected_mode = candidate;
expected_generation = expected_generation.saturating_add(1);
let next = changed.expect("mode transition must emit cutover state");
assert_eq!(next.mode, expected_mode);
assert_eq!(next.generation, expected_generation);
}
}
let final_state = runtime.snapshot();
assert_eq!(final_state.mode, expected_mode);
assert_eq!(final_state.generation, expected_generation);
}
#[test]
fn stress_snapshot_and_watch_state_remain_consistent_under_concurrent_switch_storm() {
let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
std::thread::scope(|scope| {
let mut writers = Vec::new();
for worker in 0..4usize {
let runtime = Arc::clone(&runtime);
writers.push(scope.spawn(move || {
for step in 0..20_000usize {
let mode = if (worker + step) % 2 == 0 {
RelayRouteMode::Direct
} else {
RelayRouteMode::Middle
};
let _ = runtime.set_mode(mode);
}
}));
}
for writer in writers {
writer
.join()
.expect("route mode writer thread must not panic");
}
let rx = runtime.subscribe();
for _ in 0..128 {
assert_eq!(
runtime.snapshot(),
*rx.borrow(),
"snapshot and watch state must converge after concurrent set_mode churn"
);
std::thread::yield_now();
}
});
}
#[test]
fn stress_concurrent_transition_count_matches_final_generation() {
let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let successful_transitions = Arc::new(AtomicU64::new(0));
std::thread::scope(|scope| {
let mut workers = Vec::new();
for worker in 0..6usize {
let runtime = Arc::clone(&runtime);
let successful_transitions = Arc::clone(&successful_transitions);
workers.push(scope.spawn(move || {
let mut state = (worker as u64 + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15);
for _ in 0..25_000usize {
state ^= state << 7;
state ^= state >> 9;
state ^= state << 8;
let mode = if (state & 1) == 0 {
RelayRouteMode::Direct
} else {
RelayRouteMode::Middle
};
if runtime.set_mode(mode).is_some() {
successful_transitions.fetch_add(1, Ordering::Relaxed);
}
}
}));
}
for worker in workers {
worker.join().expect("route mode transition worker must not panic");
}
});
let final_state = runtime.snapshot();
assert_eq!(
final_state.generation,
successful_transitions.load(Ordering::Relaxed),
"final generation must equal number of accepted mode transitions"
);
assert_eq!(
final_state,
*runtime.subscribe().borrow(),
"watch and snapshot state must match after concurrent transition accounting"
);
}
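// The counting invariant above relies on `watch::Sender::send_if_modified`
// running each closure under the channel's internal lock: every accepted
// transition bumps the generation exactly once, and concurrent callers cannot
// interleave inside a single update.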
#[test]
fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() {
// Deterministic xorshift fuzzing keeps this test stable across runs.
let mut s: u64 = 0x9E37_79B9_7F4A_7C15;
for _ in 0..20_000 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let session_id = s;
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let generation = s;
let delay = cutover_stagger_delay(session_id, generation);
assert!(
(1000..=1999).contains(&delay.as_millis()),
"fuzzed inputs must always map into fixed stagger window"
);
}
}

View File

@ -0,0 +1,265 @@
use super::*;
use std::panic::{self, AssertUnwindSafe};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Barrier;
#[test]
fn direct_connection_lease_balances_on_drop() {
let stats = Arc::new(Stats::new());
assert_eq!(stats.get_current_connections_direct(), 0);
{
let _lease = stats.acquire_direct_connection_lease();
assert_eq!(stats.get_current_connections_direct(), 1);
}
assert_eq!(stats.get_current_connections_direct(), 0);
}
#[test]
fn middle_connection_lease_balances_on_drop() {
let stats = Arc::new(Stats::new());
assert_eq!(stats.get_current_connections_me(), 0);
{
let _lease = stats.acquire_me_connection_lease();
assert_eq!(stats.get_current_connections_me(), 1);
}
assert_eq!(stats.get_current_connections_me(), 0);
}
#[test]
fn connection_lease_disarm_prevents_double_release() {
let stats = Arc::new(Stats::new());
let mut lease = stats.acquire_direct_connection_lease();
assert_eq!(stats.get_current_connections_direct(), 1);
stats.decrement_current_connections_direct();
assert_eq!(stats.get_current_connections_direct(), 0);
lease.disarm();
drop(lease);
assert_eq!(stats.get_current_connections_direct(), 0);
}
#[test]
fn direct_connection_lease_balances_on_panic_unwind() {
let stats = Arc::new(Stats::new());
let stats_for_panic = stats.clone();
let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
let _lease = stats_for_panic.acquire_direct_connection_lease();
panic!("intentional panic to verify lease drop path");
}));
assert!(panic_result.is_err(), "panic must propagate from test closure");
assert_eq!(
stats.get_current_connections_direct(),
0,
"panic unwind must release direct route gauge"
);
}
#[test]
fn middle_connection_lease_balances_on_panic_unwind() {
let stats = Arc::new(Stats::new());
let stats_for_panic = stats.clone();
let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
let _lease = stats_for_panic.acquire_me_connection_lease();
panic!("intentional panic to verify middle lease drop path");
}));
assert!(panic_result.is_err(), "panic must propagate from test closure");
assert_eq!(
stats.get_current_connections_me(),
0,
"panic unwind must release middle route gauge"
);
}
#[tokio::test]
async fn concurrent_mixed_route_lease_churn_balances_to_zero() {
const TASKS: usize = 48;
const ITERATIONS_PER_TASK: usize = 256;
let stats = Arc::new(Stats::new());
let barrier = Arc::new(Barrier::new(TASKS));
let mut workers = Vec::with_capacity(TASKS);
for task_idx in 0..TASKS {
let stats_for_task = stats.clone();
let barrier_for_task = barrier.clone();
workers.push(tokio::spawn(async move {
barrier_for_task.wait().await;
for iter in 0..ITERATIONS_PER_TASK {
if (task_idx + iter) % 2 == 0 {
let _lease = stats_for_task.acquire_direct_connection_lease();
tokio::task::yield_now().await;
} else {
let _lease = stats_for_task.acquire_me_connection_lease();
tokio::task::yield_now().await;
}
}
}));
}
for worker in workers {
worker
.await
.expect("lease churn worker must not panic");
}
assert_eq!(
stats.get_current_connections_direct(),
0,
"direct route gauge must return to zero after concurrent lease churn"
);
assert_eq!(
stats.get_current_connections_me(),
0,
"middle route gauge must return to zero after concurrent lease churn"
);
}
#[tokio::test]
async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() {
const TASKS: usize = 64;
let stats = Arc::new(Stats::new());
let mut workers = Vec::with_capacity(TASKS);
for task_idx in 0..TASKS {
let stats_for_task = stats.clone();
workers.push(tokio::spawn(async move {
if task_idx % 2 == 0 {
let _lease = stats_for_task.acquire_direct_connection_lease();
tokio::time::sleep(Duration::from_secs(60)).await;
} else {
let _lease = stats_for_task.acquire_me_connection_lease();
tokio::time::sleep(Duration::from_secs(60)).await;
}
}));
}
tokio::time::timeout(Duration::from_secs(2), async {
loop {
let total = stats.get_current_connections_direct() + stats.get_current_connections_me();
if total == TASKS as u64 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await
.expect("all storm tasks must acquire route leases before abort");
for worker in &workers {
worker.abort();
}
for worker in workers {
let joined = worker.await;
assert!(joined.is_err(), "aborted worker must return join error");
}
tokio::time::timeout(Duration::from_secs(2), async {
loop {
if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await
.expect("all route gauges must drain to zero after abort storm");
}
#[test]
fn saturating_route_decrements_do_not_underflow_under_race() {
const THREADS: usize = 16;
const DECREMENTS_PER_THREAD: usize = 4096;
let stats = Arc::new(Stats::new());
let mut workers = Vec::with_capacity(THREADS);
for _ in 0..THREADS {
let stats_for_thread = stats.clone();
workers.push(std::thread::spawn(move || {
for _ in 0..DECREMENTS_PER_THREAD {
stats_for_thread.decrement_current_connections_direct();
stats_for_thread.decrement_current_connections_me();
}
}));
}
for worker in workers {
worker
.join()
.expect("decrement race worker must not panic");
}
assert_eq!(
stats.get_current_connections_direct(),
0,
"direct route decrement races must never underflow"
);
assert_eq!(
stats.get_current_connections_me(),
0,
"middle route decrement races must never underflow"
);
}
#[tokio::test]
async fn direct_connection_lease_balances_on_task_abort() {
let stats = Arc::new(Stats::new());
let stats_for_task = stats.clone();
let task = tokio::spawn(async move {
let _lease = stats_for_task.acquire_direct_connection_lease();
tokio::time::sleep(Duration::from_secs(60)).await;
});
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(stats.get_current_connections_direct(), 1);
task.abort();
let joined = task.await;
assert!(joined.is_err(), "aborted task must return a join error");
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(
stats.get_current_connections_direct(),
0,
"aborted task must release direct route gauge"
);
}
#[tokio::test]
async fn middle_connection_lease_balances_on_task_abort() {
let stats = Arc::new(Stats::new());
let stats_for_task = stats.clone();
let task = tokio::spawn(async move {
let _lease = stats_for_task.acquire_me_connection_lease();
tokio::time::sleep(Duration::from_secs(60)).await;
});
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(stats.get_current_connections_me(), 1);
task.abort();
let joined = task.await;
assert!(joined.is_err(), "aborted task must return a join error");
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(
stats.get_current_connections_me(),
0,
"aborted task must release middle route gauge"
);
}

View File

@ -6,6 +6,7 @@ pub mod beobachten;
pub mod telemetry;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use parking_lot::Mutex;
@ -19,6 +20,46 @@ use tracing::debug;
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
use self::telemetry::TelemetryPolicy;
#[derive(Clone, Copy)]
enum RouteConnectionGauge {
Direct,
Middle,
}
#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"]
pub struct RouteConnectionLease {
stats: Arc<Stats>,
gauge: RouteConnectionGauge,
active: bool,
}
impl RouteConnectionLease {
fn new(stats: Arc<Stats>, gauge: RouteConnectionGauge) -> Self {
Self {
stats,
gauge,
active: true,
}
}
#[cfg(test)]
fn disarm(&mut self) {
self.active = false;
}
}
impl Drop for RouteConnectionLease {
fn drop(&mut self) {
if !self.active {
return;
}
match self.gauge {
RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(),
RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(),
}
}
}
// ============= Stats =============
#[derive(Default)]
@ -285,6 +326,16 @@ impl Stats {
pub fn decrement_current_connections_me(&self) {
Self::decrement_atomic_saturating(&self.current_connections_me);
}
pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
self.increment_current_connections_direct();
RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct)
}
pub fn acquire_me_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
self.increment_current_connections_me();
RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle)
}
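// A minimal usage sketch, assuming a lease is held for the lifetime of one
// proxied connection:
//
//     let stats = Arc::new(Stats::new());
//     let lease = stats.acquire_direct_connection_lease();
//     // ... serve the connection; the direct gauge reads 1 ...
//     drop(lease); // gauge returns to 0, also on panic unwind or task abort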
pub fn increment_handshake_timeouts(&self) {
if self.telemetry_core_enabled() {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
@ -1772,3 +1823,7 @@ mod tests {
assert_eq!(checker.stats().total_entries, 500);
}
}
#[cfg(test)]
#[path = "connection_lease_security_tests.rs"]
mod connection_lease_security_tests;

View File

@ -117,15 +117,6 @@ pub fn build_emulated_server_hello(
extensions.extend_from_slice(&0x002bu16.to_be_bytes());
extensions.extend_from_slice(&(2u16).to_be_bytes());
extensions.extend_from_slice(&0x0304u16.to_be_bytes());
if let Some(alpn_proto) = &alpn {
extensions.extend_from_slice(&0x0010u16.to_be_bytes());
let list_len: u16 = 1 + alpn_proto.len() as u16;
let ext_len: u16 = 2 + list_len;
extensions.extend_from_slice(&ext_len.to_be_bytes());
extensions.extend_from_slice(&list_len.to_be_bytes());
extensions.push(alpn_proto.len() as u8);
extensions.extend_from_slice(alpn_proto);
}
let extensions_len = extensions.len() as u16;
let body_len = 2 + // version
@ -207,8 +198,22 @@ pub fn build_emulated_server_hello(
}
let mut app_data = Vec::new();
let alpn_marker = alpn
.as_ref()
.filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize)
.map(|proto| {
let proto_list_len = 1usize + proto.len();
let ext_data_len = 2usize + proto_list_len;
let mut marker = Vec::with_capacity(4 + ext_data_len);
marker.extend_from_slice(&0x0010u16.to_be_bytes());
marker.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
marker.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
marker.push(proto.len() as u8);
marker.extend_from_slice(proto);
marker
});
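// Layout sketch for ALPN protocol "h2": 00 10 (extension type 16) |
// 00 05 (extension data length) | 00 03 (protocol name list length) |
// 02 (name length) | 68 32 ("h2").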
let mut payload_offset = 0usize;
for size in sizes {
for (idx, size) in sizes.into_iter().enumerate() {
let mut rec = Vec::with_capacity(5 + size);
rec.push(TLS_RECORD_APPLICATION);
rec.extend_from_slice(&TLS_VERSION);
@ -233,7 +238,20 @@ pub fn build_emulated_server_hello(
}
} else if size > 17 {
let body_len = size - 17;
rec.extend_from_slice(&rng.bytes(body_len));
let mut body = Vec::with_capacity(body_len);
if idx == 0 && let Some(marker) = &alpn_marker {
if marker.len() <= body_len {
body.extend_from_slice(marker);
if body_len > marker.len() {
body.extend_from_slice(&rng.bytes(body_len - marker.len()));
}
} else {
body.extend_from_slice(&rng.bytes(body_len));
}
} else {
body.extend_from_slice(&rng.bytes(body_len));
}
rec.extend_from_slice(&body);
rec.push(0x16); // inner content type marker (handshake)
rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag
} else {
@ -245,8 +263,9 @@ pub fn build_emulated_server_hello(
// --- Combine ---
// Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
let mut tickets = Vec::new();
if new_session_tickets > 0 {
for _ in 0..new_session_tickets {
let ticket_count = new_session_tickets.min(4);
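// Cap mimic tickets at four: an untrusted behavior profile must not be able
// to inflate the emulated response with unbounded ticket records.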
if ticket_count > 0 {
for _ in 0..ticket_count {
let ticket_len: usize = rng.range(48) + 48;
let mut rec = Vec::with_capacity(5 + ticket_len);
rec.push(TLS_RECORD_APPLICATION);
@ -273,6 +292,10 @@ pub fn build_emulated_server_hello(
response
}
#[cfg(test)]
#[path = "emulator_security_tests.rs"]
mod security_tests;
#[cfg(test)]
mod tests {
use std::time::SystemTime;

View File

@ -0,0 +1,136 @@
use std::time::SystemTime;
use crate::crypto::SecureRandom;
use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE};
use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
};
fn make_cached(cert_payload: Option<TlsCertPayload>) -> CachedTlsData {
CachedTlsData {
server_hello_template: ParsedServerHello {
version: [0x03, 0x03],
random: [0u8; 32],
session_id: Vec::new(),
cipher_suite: [0x13, 0x01],
compression: 0,
extensions: Vec::new(),
},
cert_info: None,
cert_payload,
app_data_records_sizes: vec![64],
total_app_data_len: 64,
behavior_profile: TlsBehaviorProfile {
change_cipher_spec_count: 1,
app_data_record_sizes: vec![64],
ticket_record_sizes: Vec::new(),
source: TlsProfileSource::Default,
},
fetched_at: SystemTime::now(),
domain: "example.com".to_string(),
}
}
fn first_app_data_payload(response: &[u8]) -> &[u8] {
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = 5 + hello_len;
let ccs_len = u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
let app_start = ccs_start + 5 + ccs_len;
let app_len = u16::from_be_bytes([response[app_start + 3], response[app_start + 4]]) as usize;
&response[app_start + 5..app_start + 5 + app_len]
}
#[test]
fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
let cached = make_cached(None);
let rng = SecureRandom::new();
let oversized_alpn = vec![0xAB; u8::MAX as usize + 1];
let response = build_emulated_server_hello(
b"secret",
&[0x11; 32],
&[0x22; 16],
&cached,
true,
&rng,
Some(oversized_alpn),
0,
);
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = 5 + hello_len;
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
let app_start = ccs_start + 6;
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
let payload = first_app_data_payload(&response);
let mut marker_prefix = Vec::new();
marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
marker_prefix.push(0xff);
marker_prefix.extend_from_slice(&[0xab; 8]);
assert!(
!payload.starts_with(&marker_prefix),
"oversized ALPN must not be partially embedded into the emulated first application record"
);
}
#[test]
fn emulated_server_hello_embeds_full_alpn_marker_when_body_can_fit() {
let cached = make_cached(None);
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x31; 32],
&[0x41; 16],
&cached,
true,
&rng,
Some(b"h2".to_vec()),
0,
);
let payload = first_app_data_payload(&response);
let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
assert!(
payload.starts_with(&expected),
"when body has enough capacity, emulated first application record must include full ALPN marker"
);
}
#[test]
fn emulated_server_hello_prefers_cert_payload_over_alpn_marker() {
let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
let cached = make_cached(Some(TlsCertPayload {
cert_chain_der: vec![vec![0x30, 0x01, 0x00]],
certificate_message: cert_msg.clone(),
}));
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x32; 32],
&[0x42; 16],
&cached,
true,
&rng,
Some(b"h2".to_vec()),
0,
);
let payload = first_app_data_payload(&response);
let alpn_marker = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
assert!(
payload.starts_with(&cert_msg),
"when certificate payload is available, first record must start with cert payload bytes"
);
assert!(
!payload.starts_with(&alpn_marker),
"ALPN marker must not displace selected certificate payload"
);
}

View File

@ -25,6 +25,9 @@ const HEALTH_RECONNECT_BUDGET_PER_CORE: usize = 2;
const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1;
const HEALTH_RECONNECT_BUDGET_MIN: usize = 4;
const HEALTH_RECONNECT_BUDGET_MAX: usize = 128;
const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16;
const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16;
const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256;
#[derive(Debug, Clone)]
struct DcFloorPlanEntry {
@ -111,106 +114,75 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
}
}
async fn reap_draining_writers(
pub(super) async fn reap_draining_writers(
pool: &Arc<MePool>,
warn_next_allowed: &mut HashMap<u64, Instant>,
) {
if pool.draining_active_runtime() == 0 {
return;
}
let now_epoch_secs = MePool::now_epoch_secs();
let now = Instant::now();
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
let drain_threshold = pool
.me_pool_drain_threshold
.load(std::sync::atomic::Ordering::Relaxed);
let mut draining_writers = {
let writers = pool.writers.read().await;
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
for writer in writers.iter() {
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
continue;
}
draining_writers.push(DrainingWriterSnapshot {
id: writer.id,
writer_dc: writer.writer_dc,
addr: writer.addr,
generation: writer.generation,
created_at: writer.created_at,
draining_started_at_epoch_secs: writer
.draining_started_at_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed),
drain_deadline_epoch_secs: writer
.drain_deadline_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed),
allow_drain_fallback: writer
.allow_drain_fallback
.load(std::sync::atomic::Ordering::Relaxed),
});
let activity = pool.registry.writer_activity_snapshot().await;
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
let mut empty_writer_ids = Vec::<u64>::new();
let mut force_close_writer_ids = Vec::<u64>::new();
let writers = pool.writers.read().await;
for writer in writers.iter() {
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
continue;
}
draining_writers
};
if draining_writers.is_empty() {
return;
}
let draining_ids: Vec<u64> = draining_writers.iter().map(|writer| writer.id).collect();
let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await;
let mut non_empty_draining_writers =
Vec::<DrainingWriterSnapshot>::with_capacity(draining_writers.len());
for writer in draining_writers.drain(..) {
if non_empty_writer_ids.contains(&writer.id) {
non_empty_draining_writers.push(writer);
} else {
pool.remove_writer_and_close_clients(writer.id).await;
if activity
.bound_clients_by_writer
.get(&writer.id)
.copied()
.unwrap_or(0)
== 0
{
empty_writer_ids.push(writer.id);
continue;
}
draining_writers.push(DrainingWriterSnapshot {
id: writer.id,
writer_dc: writer.writer_dc,
addr: writer.addr,
generation: writer.generation,
created_at: writer.created_at,
draining_started_at_epoch_secs: writer
.draining_started_at_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed),
drain_deadline_epoch_secs: writer
.drain_deadline_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed),
allow_drain_fallback: writer
.allow_drain_fallback
.load(std::sync::atomic::Ordering::Relaxed),
});
}
draining_writers = non_empty_draining_writers;
if draining_writers.is_empty() {
return;
}
drop(writers);
let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
draining_writers.len().saturating_sub(drain_threshold as usize)
} else {
0
};
let has_deadline_expired = draining_writers.iter().any(|writer| {
writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
});
let can_drop_with_replacement = if overflow > 0 || has_deadline_expired {
pool.has_non_draining_writer_per_desired_dc_group().await
} else {
false
};
if overflow > 0 {
if can_drop_with_replacement {
draining_writers.sort_by(|left, right| {
left.draining_started_at_epoch_secs
.cmp(&right.draining_started_at_epoch_secs)
.then_with(|| left.created_at.cmp(&right.created_at))
.then_with(|| left.id.cmp(&right.id))
});
warn!(
draining_writers = draining_writers.len(),
me_pool_drain_threshold = drain_threshold,
removing_writers = overflow,
"ME draining writer threshold exceeded, force-closing oldest draining writers"
);
for writer in draining_writers.drain(..overflow) {
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer.id).await;
}
} else {
warn!(
draining_writers = draining_writers.len(),
me_pool_drain_threshold = drain_threshold,
overflow,
"ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers"
);
draining_writers.sort_by(|left, right| {
left.draining_started_at_epoch_secs
.cmp(&right.draining_started_at_epoch_secs)
.then_with(|| left.created_at.cmp(&right.created_at))
.then_with(|| left.id.cmp(&right.id))
});
warn!(
draining_writers = draining_writers.len(),
me_pool_drain_threshold = drain_threshold,
removing_writers = overflow,
"ME draining writer threshold exceeded, force-closing oldest draining writers"
);
for writer in draining_writers.drain(..overflow) {
force_close_writer_ids.push(writer.id);
}
}
@ -238,25 +210,71 @@ async fn reap_draining_writers(
}
if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
{
if can_drop_with_replacement {
warn!(writer_id = writer.id, "Drain timeout, force-closing");
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer.id).await;
} else if should_emit_writer_warn(
warn_next_allowed,
writer.id,
now,
pool.warn_rate_limit_duration(),
) {
warn!(
writer_id = writer.id,
writer_dc = writer.writer_dc,
endpoint = %writer.addr,
"Drain timeout reached, but replacement coverage is incomplete; keeping draining writer"
);
}
warn!(writer_id = writer.id, "Drain timeout, force-closing");
force_close_writer_ids.push(writer.id);
}
}
let close_budget = health_drain_close_budget();
let requested_force_close = force_close_writer_ids.len();
let requested_empty_close = empty_writer_ids.len();
let requested_close_total = requested_force_close.saturating_add(requested_empty_close);
let mut closed_writer_ids = HashSet::<u64>::new();
let mut closed_total = 0usize;
for writer_id in force_close_writer_ids {
if closed_total >= close_budget {
break;
}
if !closed_writer_ids.insert(writer_id) {
continue;
}
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer_id).await;
closed_total = closed_total.saturating_add(1);
}
for writer_id in empty_writer_ids {
if closed_total >= close_budget {
break;
}
if !closed_writer_ids.insert(writer_id) {
continue;
}
if !pool.remove_writer_if_empty(writer_id).await {
continue;
}
closed_total = closed_total.saturating_add(1);
}
let pending_close_total = requested_close_total.saturating_sub(closed_total);
if pending_close_total > 0 {
warn!(
close_budget,
closed_total,
pending_close_total,
"ME draining close backlog deferred to next health cycle"
);
}
// Keep warn cooldown state for draining writers still present in the pool;
// drop state only once a writer is actually removed.
let active_draining_writer_ids = {
let writers = pool.writers.read().await;
writers
.iter()
.filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed))
.map(|writer| writer.id)
.collect::<HashSet<u64>>()
};
warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id));
}
pub(super) fn health_drain_close_budget() -> usize {
let cpu_cores = std::thread::available_parallelism()
.map(std::num::NonZeroUsize::get)
.unwrap_or(1);
cpu_cores
.saturating_mul(HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE)
.clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX)
}
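// Worked example: an 8-core host yields 8 * 16 = 128, already inside the
// 16..=256 clamp; a 1-core fallback hits the floor of 16, and 16 or more
// cores saturate at the 256 ceiling.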
#[derive(Debug, Clone)]
@ -1521,7 +1539,6 @@ mod tests {
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
pool.increment_draining_active_runtime();
assert!(
pool.registry
.bind_writer(
@ -1570,7 +1587,6 @@ mod tests {
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
let pool = make_pool(2).await;
insert_live_writer(&pool, 1, 2).await;
assert!(pool.has_non_draining_writer_per_desired_dc_group().await);
let now_epoch_secs = MePool::now_epoch_secs();
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await;
@ -1588,7 +1604,7 @@ mod tests {
}
#[tokio::test]
async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() {
async fn reap_draining_writers_force_closes_overflow_without_replacement() {
let pool = make_pool(2).await;
let now_epoch_secs = MePool::now_epoch_secs();
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
@ -1600,8 +1616,8 @@ mod tests {
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
writer_ids.sort_unstable();
assert_eq!(writer_ids, vec![10, 20, 30]);
assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10);
assert_eq!(writer_ids, vec![20, 30]);
assert!(pool.registry.get_writer(conn_a).await.is_none());
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
}

View File

@ -0,0 +1,615 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::{health_drain_close_budget, reap_draining_writers};
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use super::me_health_monitor;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
async fn make_pool(
me_pool_drain_threshold: u64,
me_health_interval_ms_unhealthy: u64,
me_health_interval_ms_healthy: u64,
) -> (Arc<MePool>, Arc<SecureRandom>) {
let general = GeneralConfig {
me_pool_drain_threshold,
me_health_interval_ms_unhealthy,
me_health_interval_ms_healthy,
..GeneralConfig::default()
};
let rng = Arc::new(SecureRandom::new());
let pool = MePool::new(
None,
vec![1u8; 32],
None,
false,
None,
Vec::new(),
1,
None,
12,
1200,
HashMap::new(),
HashMap::new(),
None,
NetworkDecision::default(),
None,
rng.clone(),
Arc::new(Stats::default()),
general.me_keepalive_enabled,
general.me_keepalive_interval_secs,
general.me_keepalive_jitter_secs,
general.me_keepalive_payload_random,
general.rpc_proxy_req_every,
general.me_warmup_stagger_enabled,
general.me_warmup_step_delay_ms,
general.me_warmup_step_jitter_ms,
general.me_reconnect_max_concurrent_per_dc,
general.me_reconnect_backoff_base_ms,
general.me_reconnect_backoff_cap_ms,
general.me_reconnect_fast_retry_count,
general.me_single_endpoint_shadow_writers,
general.me_single_endpoint_outage_mode_enabled,
general.me_single_endpoint_outage_disable_quarantine,
general.me_single_endpoint_outage_backoff_min_ms,
general.me_single_endpoint_outage_backoff_max_ms,
general.me_single_endpoint_shadow_rotate_every_secs,
general.me_floor_mode,
general.me_adaptive_floor_idle_secs,
general.me_adaptive_floor_min_writers_single_endpoint,
general.me_adaptive_floor_min_writers_multi_endpoint,
general.me_adaptive_floor_recover_grace_secs,
general.me_adaptive_floor_writers_per_core_total,
general.me_adaptive_floor_cpu_cores_override,
general.me_adaptive_floor_max_extra_writers_single_per_core,
general.me_adaptive_floor_max_extra_writers_multi_per_core,
general.me_adaptive_floor_max_active_writers_per_core,
general.me_adaptive_floor_max_warm_writers_per_core,
general.me_adaptive_floor_max_active_writers_global,
general.me_adaptive_floor_max_warm_writers_global,
general.hardswap,
general.me_pool_drain_ttl_secs,
general.me_pool_drain_threshold,
general.effective_me_pool_force_close_secs(),
general.me_pool_min_fresh_ratio,
general.me_hardswap_warmup_delay_min_ms,
general.me_hardswap_warmup_delay_max_ms,
general.me_hardswap_warmup_extra_passes,
general.me_hardswap_warmup_pass_backoff_base_ms,
general.me_bind_stale_mode,
general.me_bind_stale_ttl_secs,
general.me_secret_atomic_snapshot,
general.me_deterministic_writer_sort,
MeWriterPickMode::default(),
general.me_writer_pick_sample_size,
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,
general.me_reader_route_data_wait_ms,
general.me_health_interval_ms_unhealthy,
general.me_health_interval_ms_healthy,
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
);
(pool, rng)
}
async fn insert_draining_writer(
pool: &Arc<MePool>,
writer_id: u64,
drain_started_at_epoch_secs: u64,
bound_clients: usize,
drain_deadline_epoch_secs: u64,
) {
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
let writer = MeWriter {
id: writer_id,
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6000 + writer_id as u16),
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
writer_dc: 2,
generation: 1,
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
created_at: Instant::now() - Duration::from_secs(writer_id),
tx: tx.clone(),
cancel: CancellationToken::new(),
degraded: Arc::new(AtomicBool::new(false)),
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
draining: Arc::new(AtomicBool::new(true)),
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
};
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
for idx in 0..bound_clients {
let (conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
conn_id,
writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
8000 + idx as u16,
),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await
);
}
}
async fn writer_count(pool: &Arc<MePool>) -> usize {
pool.writers.read().await.len()
}
async fn sorted_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
let mut ids = pool
.writers
.read()
.await
.iter()
.map(|writer| writer.id)
.collect::<Vec<_>>();
ids.sort_unstable();
ids
}
fn lcg_next(state: &mut u64) -> u64 {
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
*state
}
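// Note: with this multiplier (a ≡ 1 mod 4) and increment 1, the low two bits
// of the state simply count upward, so `& 1` draws alternate and `% 4` cycles
// through 0..4; the churn stays fully deterministic, while `% 3` and `% len`
// reductions mix the whole state.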
async fn draining_writer_ids(pool: &Arc<MePool>) -> HashSet<u64> {
pool.writers
.read()
.await
.iter()
.filter(|writer| writer.draining.load(Ordering::Relaxed))
.map(|writer| writer.id)
.collect::<HashSet<u64>>()
}
async fn set_writer_runtime_state(
pool: &Arc<MePool>,
writer_id: u64,
draining: bool,
drain_started_at_epoch_secs: u64,
drain_deadline_epoch_secs: u64,
) {
let writers = pool.writers.read().await;
if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
writer.draining.store(draining, Ordering::Relaxed);
writer
.draining_started_at_epoch_secs
.store(drain_started_at_epoch_secs, Ordering::Relaxed);
writer
.drain_deadline_epoch_secs
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
}
}
#[tokio::test]
async fn reap_draining_writers_clears_warn_state_when_pool_empty() {
let (pool, _rng) = make_pool(128, 1, 1).await;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(11, Instant::now() + Duration::from_secs(5));
warn_next_allowed.insert(22, Instant::now() + Duration::from_secs(5));
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.is_empty());
}
#[tokio::test]
async fn reap_draining_writers_respects_threshold_across_multiple_overflow_cycles() {
let threshold = 3u64;
let (pool, _rng) = make_pool(threshold, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=60u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(600).saturating_add(writer_id),
1,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..64 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
if writer_count(&pool).await <= threshold as usize {
break;
}
}
assert_eq!(writer_count(&pool).await, threshold as usize);
assert_eq!(sorted_writer_ids(&pool).await, vec![58, 59, 60]);
}
#[tokio::test]
async fn reap_draining_writers_handles_large_empty_writer_population() {
let (pool, _rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let total = health_drain_close_budget().saturating_mul(3).saturating_add(27);
for writer_id in 1..=total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(120),
0,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..24 {
if writer_count(&pool).await == 0 {
break;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
}
assert_eq!(writer_count(&pool).await, 0);
}
#[tokio::test]
async fn reap_draining_writers_processes_mass_deadline_expiry_without_unbounded_growth() {
let (pool, _rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let total = health_drain_close_budget().saturating_mul(4).saturating_add(31);
for writer_id in 1..=total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(180),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..40 {
if writer_count(&pool).await == 0 {
break;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
}
assert_eq!(writer_count(&pool).await, 0);
}
#[tokio::test]
async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_churn() {
let (pool, _rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let mut warn_next_allowed = HashMap::new();
for wave in 0..40u64 {
for offset in 0..8u64 {
insert_draining_writer(
&pool,
wave * 100 + offset,
now_epoch_secs.saturating_sub(400 + offset),
1,
0,
)
.await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.len() <= writer_count(&pool).await);
let ids = sorted_writer_ids(&pool).await;
for writer_id in ids.into_iter().take(3) {
let _ = pool.remove_writer_and_close_clients(writer_id).await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.len() <= writer_count(&pool).await);
}
}
#[tokio::test]
async fn reap_draining_writers_budgeted_cleanup_never_increases_pool_size() {
let (pool, _rng) = make_pool(5, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=200u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(240).saturating_add(writer_id),
1,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
let mut previous = writer_count(&pool).await;
for _ in 0..32 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
let current = writer_count(&pool).await;
assert!(current <= previous);
previous = current;
}
}
#[tokio::test]
async fn me_health_monitor_converges_to_threshold_under_live_injection_churn() {
let threshold = 7u64;
let (pool, rng) = make_pool(threshold, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=40u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
1,
0,
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
for wave in 0..8u64 {
for offset in 0..10u64 {
insert_draining_writer(
&pool,
1000 + wave * 100 + offset,
now_epoch_secs.saturating_sub(120).saturating_add(offset),
1,
0,
)
.await;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
tokio::time::sleep(Duration::from_millis(120)).await;
monitor.abort();
let _ = monitor.await;
assert!(writer_count(&pool).await <= threshold as usize);
}
#[tokio::test]
async fn me_health_monitor_drains_deadline_storm_with_budgeted_progress() {
let (pool, rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=220u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(120),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
tokio::time::sleep(Duration::from_millis(120)).await;
monitor.abort();
let _ = monitor.await;
assert_eq!(writer_count(&pool).await, 0);
}
#[tokio::test]
async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() {
let threshold = 12u64;
let (pool, rng) = make_pool(threshold, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=180u64 {
let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
let deadline = if writer_id % 2 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(250).saturating_add(writer_id),
bound_clients,
deadline,
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
tokio::time::sleep(Duration::from_millis(140)).await;
monitor.abort();
let _ = monitor.await;
assert!(writer_count(&pool).await <= threshold as usize);
}
#[tokio::test]
async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() {
let threshold = 9u64;
let (pool, _rng) = make_pool(threshold, 1, 1).await;
let mut warn_next_allowed = HashMap::new();
let mut seed = 0x9E37_79B9_7F4A_7C15u64;
let mut next_writer_id = 20_000u64;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=72u64 {
let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 };
let deadline = if writer_id % 5 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(500).saturating_add(writer_id),
bound_clients,
deadline,
)
.await;
}
for _round in 0..90 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
let draining_ids = draining_writer_ids(&pool).await;
assert!(
warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
"warn-state keys must always be a subset of live draining writers"
);
let writer_ids = sorted_writer_ids(&pool).await;
if writer_ids.is_empty() {
continue;
}
let remove_n = (lcg_next(&mut seed) % 3) as usize;
for writer_id in writer_ids.iter().copied().take(remove_n) {
let _ = pool.remove_writer_and_close_clients(writer_id).await;
}
let survivors = sorted_writer_ids(&pool).await;
if !survivors.is_empty() {
let idx = (lcg_next(&mut seed) as usize) % survivors.len();
let target = survivors[idx];
set_writer_runtime_state(&pool, target, false, 0, 0).await;
}
let survivors = sorted_writer_ids(&pool).await;
if survivors.len() > 1 {
let idx = (lcg_next(&mut seed) as usize) % survivors.len();
let target = survivors[idx];
let expired_deadline = if lcg_next(&mut seed) & 1 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
set_writer_runtime_state(
&pool,
target,
true,
now_epoch_secs.saturating_sub(120),
expired_deadline,
)
.await;
}
let inject_n = (lcg_next(&mut seed) % 4) as usize;
for _ in 0..inject_n {
let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 };
let deadline = if lcg_next(&mut seed) & 1 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
insert_draining_writer(
&pool,
next_writer_id,
now_epoch_secs.saturating_sub(240),
bound_clients,
deadline,
)
.await;
next_writer_id = next_writer_id.saturating_add(1);
}
}
for _ in 0..64 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
if writer_count(&pool).await <= threshold as usize {
break;
}
}
assert!(writer_count(&pool).await <= threshold as usize);
let draining_ids = draining_writer_ids(&pool).await;
assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id)));
}
#[tokio::test]
async fn reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() {
let (pool, _rng) = make_pool(64, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=24u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(240),
1,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _round in 0..48u64 {
for writer_id in 1..=24u64 {
let draining = (writer_id + _round) % 3 != 0;
set_writer_runtime_state(
&pool,
writer_id,
draining,
now_epoch_secs.saturating_sub(120),
0,
)
.await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
let draining_ids = draining_writer_ids(&pool).await;
assert!(
warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
"warn-state map must not retain entries for writers outside draining set"
);
}
}
#[test]
fn health_drain_close_budget_is_within_expected_bounds() {
let budget = health_drain_close_budget();
assert!((16..=256).contains(&budget));
}

View File

@ -0,0 +1,241 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::health_drain_close_budget;
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use super::me_health_monitor;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
async fn make_pool(
me_pool_drain_threshold: u64,
me_health_interval_ms_unhealthy: u64,
me_health_interval_ms_healthy: u64,
) -> (Arc<MePool>, Arc<SecureRandom>) {
let general = GeneralConfig {
me_pool_drain_threshold,
me_health_interval_ms_unhealthy,
me_health_interval_ms_healthy,
..GeneralConfig::default()
};
let rng = Arc::new(SecureRandom::new());
let pool = MePool::new(
None,
vec![1u8; 32],
None,
false,
None,
Vec::new(),
1,
None,
12,
1200,
HashMap::new(),
HashMap::new(),
None,
NetworkDecision::default(),
None,
rng.clone(),
Arc::new(Stats::default()),
general.me_keepalive_enabled,
general.me_keepalive_interval_secs,
general.me_keepalive_jitter_secs,
general.me_keepalive_payload_random,
general.rpc_proxy_req_every,
general.me_warmup_stagger_enabled,
general.me_warmup_step_delay_ms,
general.me_warmup_step_jitter_ms,
general.me_reconnect_max_concurrent_per_dc,
general.me_reconnect_backoff_base_ms,
general.me_reconnect_backoff_cap_ms,
general.me_reconnect_fast_retry_count,
general.me_single_endpoint_shadow_writers,
general.me_single_endpoint_outage_mode_enabled,
general.me_single_endpoint_outage_disable_quarantine,
general.me_single_endpoint_outage_backoff_min_ms,
general.me_single_endpoint_outage_backoff_max_ms,
general.me_single_endpoint_shadow_rotate_every_secs,
general.me_floor_mode,
general.me_adaptive_floor_idle_secs,
general.me_adaptive_floor_min_writers_single_endpoint,
general.me_adaptive_floor_min_writers_multi_endpoint,
general.me_adaptive_floor_recover_grace_secs,
general.me_adaptive_floor_writers_per_core_total,
general.me_adaptive_floor_cpu_cores_override,
general.me_adaptive_floor_max_extra_writers_single_per_core,
general.me_adaptive_floor_max_extra_writers_multi_per_core,
general.me_adaptive_floor_max_active_writers_per_core,
general.me_adaptive_floor_max_warm_writers_per_core,
general.me_adaptive_floor_max_active_writers_global,
general.me_adaptive_floor_max_warm_writers_global,
general.hardswap,
general.me_pool_drain_ttl_secs,
general.me_pool_drain_threshold,
general.effective_me_pool_force_close_secs(),
general.me_pool_min_fresh_ratio,
general.me_hardswap_warmup_delay_min_ms,
general.me_hardswap_warmup_delay_max_ms,
general.me_hardswap_warmup_extra_passes,
general.me_hardswap_warmup_pass_backoff_base_ms,
general.me_bind_stale_mode,
general.me_bind_stale_ttl_secs,
general.me_secret_atomic_snapshot,
general.me_deterministic_writer_sort,
MeWriterPickMode::default(),
general.me_writer_pick_sample_size,
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,
general.me_reader_route_data_wait_ms,
general.me_health_interval_ms_unhealthy,
general.me_health_interval_ms_healthy,
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
);
(pool, rng)
}
async fn insert_draining_writer(
pool: &Arc<MePool>,
writer_id: u64,
drain_started_at_epoch_secs: u64,
bound_clients: usize,
drain_deadline_epoch_secs: u64,
) {
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
let writer = MeWriter {
id: writer_id,
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5500 + writer_id as u16),
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
writer_dc: 2,
generation: 1,
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
created_at: Instant::now() - Duration::from_secs(writer_id),
tx: tx.clone(),
cancel: CancellationToken::new(),
degraded: Arc::new(AtomicBool::new(false)),
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
draining: Arc::new(AtomicBool::new(true)),
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
};
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
for idx in 0..bound_clients {
let (conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
conn_id,
writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
7200 + idx as u16,
),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await
);
}
}
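// Polls every 5 ms and fails with an explicit assertion once `timeout`
// elapses, so a stalled drain surfaces as a readable failure rather than a
// hung test.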
async fn wait_for_pool_empty(pool: &Arc<MePool>, timeout: Duration) {
let start = Instant::now();
loop {
if pool.writers.read().await.is_empty() {
return;
}
assert!(
start.elapsed() < timeout,
"timed out waiting for pool.writers to become empty"
);
tokio::time::sleep(Duration::from_millis(5)).await;
}
}
#[tokio::test]
async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() {
let (pool, rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let writer_total = health_drain_close_budget().saturating_mul(2).saturating_add(9);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(120),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
monitor.abort();
let _ = monitor.await;
assert!(pool.writers.read().await.is_empty());
}
#[tokio::test]
async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() {
let (pool, rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=24u64 {
insert_draining_writer(&pool, writer_id, now_epoch_secs.saturating_sub(60), 0, 0).await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
monitor.abort();
let _ = monitor.await;
assert!(pool.writers.read().await.is_empty());
}
#[tokio::test]
async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() {
let threshold = 4u64;
let (pool, rng) = make_pool(threshold, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let writer_total = threshold as usize + health_drain_close_budget().saturating_add(11);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
1,
0,
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
monitor.abort();
let _ = monitor.await;
assert!(pool.writers.read().await.is_empty());
}

View File

@ -0,0 +1,658 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::{health_drain_close_budget, reap_draining_writers};
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
let general = GeneralConfig {
me_pool_drain_threshold,
..GeneralConfig::default()
};
MePool::new(
None,
vec![1u8; 32],
None,
false,
None,
Vec::new(),
1,
None,
12,
1200,
HashMap::new(),
HashMap::new(),
None,
NetworkDecision::default(),
None,
Arc::new(SecureRandom::new()),
Arc::new(Stats::default()),
general.me_keepalive_enabled,
general.me_keepalive_interval_secs,
general.me_keepalive_jitter_secs,
general.me_keepalive_payload_random,
general.rpc_proxy_req_every,
general.me_warmup_stagger_enabled,
general.me_warmup_step_delay_ms,
general.me_warmup_step_jitter_ms,
general.me_reconnect_max_concurrent_per_dc,
general.me_reconnect_backoff_base_ms,
general.me_reconnect_backoff_cap_ms,
general.me_reconnect_fast_retry_count,
general.me_single_endpoint_shadow_writers,
general.me_single_endpoint_outage_mode_enabled,
general.me_single_endpoint_outage_disable_quarantine,
general.me_single_endpoint_outage_backoff_min_ms,
general.me_single_endpoint_outage_backoff_max_ms,
general.me_single_endpoint_shadow_rotate_every_secs,
general.me_floor_mode,
general.me_adaptive_floor_idle_secs,
general.me_adaptive_floor_min_writers_single_endpoint,
general.me_adaptive_floor_min_writers_multi_endpoint,
general.me_adaptive_floor_recover_grace_secs,
general.me_adaptive_floor_writers_per_core_total,
general.me_adaptive_floor_cpu_cores_override,
general.me_adaptive_floor_max_extra_writers_single_per_core,
general.me_adaptive_floor_max_extra_writers_multi_per_core,
general.me_adaptive_floor_max_active_writers_per_core,
general.me_adaptive_floor_max_warm_writers_per_core,
general.me_adaptive_floor_max_active_writers_global,
general.me_adaptive_floor_max_warm_writers_global,
general.hardswap,
general.me_pool_drain_ttl_secs,
general.me_pool_drain_threshold,
general.effective_me_pool_force_close_secs(),
general.me_pool_min_fresh_ratio,
general.me_hardswap_warmup_delay_min_ms,
general.me_hardswap_warmup_delay_max_ms,
general.me_hardswap_warmup_extra_passes,
general.me_hardswap_warmup_pass_backoff_base_ms,
general.me_bind_stale_mode,
general.me_bind_stale_ttl_secs,
general.me_secret_atomic_snapshot,
general.me_deterministic_writer_sort,
MeWriterPickMode::default(),
general.me_writer_pick_sample_size,
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,
general.me_reader_route_data_wait_ms,
general.me_health_interval_ms_unhealthy,
general.me_health_interval_ms_healthy,
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
)
}
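// Helper commentary (added): fabricates a writer already in the Draining
// contour, pushes it into the pool, registers its command channel, and
// binds `bound_clients` synthetic conns to it; the returned conn ids let
// tests check registry state for those clients later.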
async fn insert_draining_writer(
pool: &Arc<MePool>,
writer_id: u64,
drain_started_at_epoch_secs: u64,
bound_clients: usize,
drain_deadline_epoch_secs: u64,
) -> Vec<u64> {
let mut conn_ids = Vec::with_capacity(bound_clients);
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
let writer = MeWriter {
id: writer_id,
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4500 + writer_id as u16),
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
writer_dc: 2,
generation: 1,
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
created_at: Instant::now() - Duration::from_secs(writer_id),
tx: tx.clone(),
cancel: CancellationToken::new(),
degraded: Arc::new(AtomicBool::new(false)),
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
draining: Arc::new(AtomicBool::new(true)),
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
};
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
for idx in 0..bound_clients {
let (conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
conn_id,
writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
6200 + idx as u16,
),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await
);
conn_ids.push(conn_id);
}
conn_ids
}
async fn current_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
let mut writer_ids = pool
.writers
.read()
.await
.iter()
.map(|writer| writer.id)
.collect::<Vec<_>>();
writer_ids.sort_unstable();
writer_ids
}
async fn writer_exists(pool: &Arc<MePool>, writer_id: u64) -> bool {
pool.writers
.read()
.await
.iter()
.any(|writer| writer.id == writer_id)
}
async fn set_writer_draining(pool: &Arc<MePool>, writer_id: u64, draining: bool) {
let writers = pool.writers.read().await;
if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
writer.draining.store(draining, Ordering::Relaxed);
}
}
#[tokio::test]
async fn reap_draining_writers_drops_warn_state_for_removed_writer() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let conn_ids =
insert_draining_writer(&pool, 7, now_epoch_secs.saturating_sub(180), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.contains_key(&7));
let _ = pool.remove_writer_and_close_clients(7).await;
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(!warn_next_allowed.contains_key(&7));
}
#[tokio::test]
async fn reap_draining_writers_removes_empty_draining_writers() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(40), 0, 0).await;
insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(30), 0, 0).await;
insert_draining_writer(&pool, 3, now_epoch_secs.saturating_sub(20), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(current_writer_ids(&pool).await, vec![3]);
}
#[tokio::test]
async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() {
let pool = make_pool(2).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 11, now_epoch_secs.saturating_sub(40), 1, 0).await;
insert_draining_writer(&pool, 22, now_epoch_secs.saturating_sub(30), 1, 0).await;
insert_draining_writer(&pool, 33, now_epoch_secs.saturating_sub(20), 1, 0).await;
insert_draining_writer(&pool, 44, now_epoch_secs.saturating_sub(10), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(current_writer_ids(&pool).await, vec![33, 44]);
}
#[tokio::test]
async fn reap_draining_writers_deadline_force_close_applies_under_threshold() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(
&pool,
50,
now_epoch_secs.saturating_sub(15),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(current_writer_ids(&pool).await.is_empty());
}
#[tokio::test]
async fn reap_draining_writers_limits_closes_per_health_tick() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_add(19);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(20),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(pool.writers.read().await.len(), writer_total - close_budget);
}
#[tokio::test]
async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() {
let pool = make_pool(0).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_add(5);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(60),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let target_writer_id = writer_total as u64;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(
target_writer_id,
Instant::now() + Duration::from_secs(300),
);
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, target_writer_id).await);
assert!(warn_next_allowed.contains_key(&target_writer_id));
}
#[tokio::test]
async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() {
let pool = make_pool(1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_add(6);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
1,
0,
)
.await;
}
let target_writer_id = writer_total.saturating_sub(1) as u64;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(
target_writer_id,
Instant::now() + Duration::from_secs(300),
);
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, target_writer_id).await);
assert!(warn_next_allowed.contains_key(&target_writer_id));
}
#[tokio::test]
async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300));
set_writer_draining(&pool, 71, false).await;
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, 71).await);
assert!(
!warn_next_allowed.contains_key(&71),
"warn cooldown state must be dropped after writer leaves draining state"
);
}
#[tokio::test]
async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() {
let pool = make_pool(0).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_mul(2).saturating_add(1);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(120),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let tail_writer_id = writer_total as u64;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(
tail_writer_id,
Instant::now() + Duration::from_secs(300),
);
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, tail_writer_id).await);
assert!(warn_next_allowed.contains_key(&tail_writer_id));
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, tail_writer_id).await);
assert!(warn_next_allowed.contains_key(&tail_writer_id));
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(!writer_exists(&pool, tail_writer_id).await);
assert!(
!warn_next_allowed.contains_key(&tail_writer_id),
"warn cooldown state must clear once writer is actually removed"
);
}
#[tokio::test]
async fn reap_draining_writers_backlog_drains_across_ticks() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_mul(2).saturating_add(7);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(20),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..8 {
if pool.writers.read().await.is_empty() {
break;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
}
assert!(pool.writers.read().await.is_empty());
}
#[tokio::test]
async fn reap_draining_writers_threshold_backlog_converges_to_threshold() {
let threshold = 5u64;
let pool = make_pool(threshold).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = threshold as usize + close_budget.saturating_add(12);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(200).saturating_add(writer_id),
1,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..16 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
if pool.writers.read().await.len() <= threshold as usize {
break;
}
}
assert_eq!(pool.writers.read().await.len(), threshold as usize);
}
#[tokio::test]
async fn reap_draining_writers_threshold_zero_preserves_non_expired_non_empty_writers() {
let pool = make_pool(0).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(40), 1, 0).await;
insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(30), 1, 0).await;
insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(20), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(current_writer_ids(&pool).await, vec![10, 20, 30]);
}
#[tokio::test]
async fn reap_draining_writers_prioritizes_force_close_before_empty_cleanup() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
for writer_id in 1..=close_budget as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(20),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let empty_writer_id = close_budget as u64 + 1;
insert_draining_writer(&pool, empty_writer_id, now_epoch_secs.saturating_sub(20), 0, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(current_writer_ids(&pool).await, vec![empty_writer_id]);
}
#[tokio::test]
async fn reap_draining_writers_empty_cleanup_does_not_increment_force_close_metric() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(60), 0, 0).await;
insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(50), 0, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(current_writer_ids(&pool).await.is_empty());
assert_eq!(pool.stats.get_pool_force_close_total(), 0);
}
#[tokio::test]
async fn reap_draining_writers_handles_duplicate_force_close_requests_for_same_writer() {
let pool = make_pool(1).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(
&pool,
10,
now_epoch_secs.saturating_sub(30),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
insert_draining_writer(
&pool,
20,
now_epoch_secs.saturating_sub(20),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(current_writer_ids(&pool).await.is_empty());
}
#[tokio::test]
async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population_under_churn() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let mut warn_next_allowed = HashMap::new();
for wave in 0..12u64 {
for offset in 0..9u64 {
insert_draining_writer(
&pool,
wave * 100 + offset,
now_epoch_secs.saturating_sub(120 + offset),
1,
0,
)
.await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
let existing_writer_ids = current_writer_ids(&pool).await;
for writer_id in existing_writer_ids.into_iter().take(4) {
let _ = pool.remove_writer_and_close_clients(writer_id).await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
}
}
#[tokio::test]
async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_state() {
let pool = make_pool(6).await;
let now_epoch_secs = MePool::now_epoch_secs();
let mut warn_next_allowed = HashMap::new();
for writer_id in 1..=18u64 {
let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
let deadline = if writer_id % 2 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
bound_clients,
deadline,
)
.await;
}
for _ in 0..16 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
if pool.writers.read().await.len() <= 6 {
break;
}
}
assert!(pool.writers.read().await.len() <= 6);
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
}
#[test]
fn general_config_default_drain_threshold_remains_enabled() {
assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128);
}
#[tokio::test]
async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let empty_writer_id = 700u64;
insert_draining_writer(
&pool,
empty_writer_id,
now_epoch_secs.saturating_sub(60),
0,
0,
)
.await;
let stale_empty_snapshot = vec![empty_writer_id];
let (rebound_conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
rebound_conn_id,
empty_writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await,
"writer should accept a new bind after stale empty snapshot"
);
for writer_id in stale_empty_snapshot {
assert!(
!pool.remove_writer_if_empty(writer_id).await,
"atomic empty cleanup must reject writers that gained bound clients"
);
}
assert!(
writer_exists(&pool, empty_writer_id).await,
"empty-path cleanup must not remove a writer that gained a bound client"
);
assert_eq!(
pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id),
Some(empty_writer_id)
);
let _ = pool.registry.unregister(rebound_conn_id).await;
}
#[tokio::test]
async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await;
pool.prune_closed_writers().await;
assert!(!writer_exists(&pool, 910).await);
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
}
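// Hedged companion sketch (added; not part of the original diff): the
// positive path of `remove_writer_if_empty` — a draining writer with no
// bound clients is unregistered and removed atomically, and the call
// reports success.
#[tokio::test]
async fn remove_writer_if_empty_removes_genuinely_empty_writer_sketch() {
    let pool = make_pool(128).await;
    let now_epoch_secs = MePool::now_epoch_secs();
    insert_draining_writer(&pool, 800, now_epoch_secs.saturating_sub(60), 0, 0).await;
    assert!(
        pool.remove_writer_if_empty(800).await,
        "empty draining writer should be removed on first attempt"
    );
    assert!(!writer_exists(&pool, 800).await);
}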

View File

@ -21,6 +21,14 @@ mod secret;
mod selftest;
mod wire;
mod pool_status;
#[cfg(test)]
mod health_regression_tests;
#[cfg(test)]
mod health_integration_tests;
#[cfg(test)]
mod health_adversarial_tests;
#[cfg(test)]
mod send_adversarial_tests;
use bytes::Bytes;

View File

@ -692,6 +692,7 @@ impl MePool {
}
}
#[allow(dead_code)]
pub(super) fn draining_active_runtime(&self) -> u64 {
self.draining_active_runtime.load(Ordering::Relaxed)
}

View File

@ -42,11 +42,10 @@ impl MePool {
}
for writer_id in closed_writer_ids {
if self.registry.is_writer_empty(writer_id).await {
let _ = self.remove_writer_only(writer_id).await;
} else {
let _ = self.remove_writer_and_close_clients(writer_id).await;
if self.remove_writer_if_empty(writer_id).await {
continue;
}
let _ = self.remove_writer_and_close_clients(writer_id).await;
}
}
@ -501,6 +500,17 @@ impl MePool {
}
}
pub(crate) async fn remove_writer_if_empty(self: &Arc<Self>, writer_id: u64) -> bool {
if !self.registry.unregister_writer_if_empty(writer_id).await {
return false;
}
// The registry empty-check and unregister are atomic with respect to binds,
// so remove_writer_only cannot return active bound sessions here.
let _ = self.remove_writer_only(writer_id).await;
true
}
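    // Added commentary (illustrative, not part of the original diff): the
    // check-then-act sequence this replaces was racy. With a separate
    // `is_writer_empty` check, a concurrent `bind_writer` could land
    // between the check and the removal:
    //
    //     if self.registry.is_writer_empty(writer_id).await { // sees empty
    //         // <- bind_writer(conn, writer_id, ..) commits here
    //         let _ = self.remove_writer_only(writer_id).await; // strands the conn
    //     }
    //
    // Folding the check and the unregister into one registry write-lock
    // critical section means a concurrent bind either serializes before
    // the removal (removal is refused) or after it (the bind is refused).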
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
let mut removed_addr: Option<SocketAddr> = None;

View File

@ -437,6 +437,24 @@ impl ConnRegistry {
.unwrap_or(true)
}
pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
let mut inner = self.inner.write().await;
let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else {
// Writer is already absent from the registry.
return true;
};
if !conn_ids.is_empty() {
return false;
}
inner.writers.remove(&writer_id);
inner.last_meta_for_writer.remove(&writer_id);
inner.writer_idle_since_epoch_secs.remove(&writer_id);
inner.conns_for_writer.remove(&writer_id);
true
}
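    // Added note: all four writer-keyed maps are purged while the single
    // `inner` write lock is held, and `bind_writer` must take the same
    // lock to mutate `conns_for_writer`, so a concurrent bind can never
    // observe a half-removed writer. Returning `true` for an
    // already-absent writer keeps the caller idempotent across repeated
    // health ticks.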
#[allow(dead_code)]
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
let inner = self.inner.read().await;
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());

View File

@ -372,17 +372,20 @@ impl MePool {
}
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
match w.tx.try_send(WriterCommand::Data(payload.clone())) {
Ok(()) => {
self.stats.increment_me_writer_pick_success_try_total(pick_mode);
match w.tx.clone().try_reserve_owned() {
Ok(permit) => {
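                // Added note on the reserve-then-commit ordering:
                // `try_reserve_owned` claims channel capacity without
                // enqueueing anything, so if the bind below fails the
                // payload was never handed to the stale writer and can be
                // retried on the next candidate without duplication; the
                // frame is sent through the permit only after
                // `bind_writer` commits.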
if !self.registry.bind_writer(conn_id, w.id, meta).await {
debug!(
conn_id,
writer_id = w.id,
"ME writer disappeared before bind commit, retrying"
"ME writer disappeared before bind commit, pruning stale writer"
);
drop(permit);
self.remove_writer_and_close_clients(w.id).await;
continue;
}
permit.send(WriterCommand::Data(payload.clone()));
self.stats.increment_me_writer_pick_success_try_total(pick_mode);
if w.generation < self.current_generation() {
self.stats.increment_pool_stale_pick_total();
debug!(
@ -422,18 +425,21 @@ impl MePool {
self.stats.increment_me_writer_pick_blocking_fallback_total();
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
match w.tx.send(WriterCommand::Data(payload.clone())).await {
Ok(()) => {
self.stats
.increment_me_writer_pick_success_fallback_total(pick_mode);
match w.tx.clone().reserve_owned().await {
Ok(permit) => {
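                // Added note: same ordering as the try path above —
                // capacity is held across the bind commit, and the frame
                // is sent only after a successful bind, so nothing is
                // replayed to a stale writer.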
if !self.registry.bind_writer(conn_id, w.id, meta).await {
debug!(
conn_id,
writer_id = w.id,
"ME writer disappeared before fallback bind commit, retrying"
"ME writer disappeared before fallback bind commit, pruning stale writer"
);
drop(permit);
self.remove_writer_and_close_clients(w.id).await;
continue;
}
permit.send(WriterCommand::Data(payload.clone()));
self.stats
.increment_me_writer_pick_success_fallback_total(pick_mode);
if w.generation < self.current_generation() {
self.stats.increment_pool_stale_pick_total();
}

View File

@ -0,0 +1,263 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::pool::{MePool, MeWriter, WriterContour};
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) {
let general = GeneralConfig {
me_route_no_writer_mode: MeRouteNoWriterMode::AsyncRecoveryFailfast,
me_route_no_writer_wait_ms: 50,
me_writer_pick_mode: MeWriterPickMode::SortedRr,
me_deterministic_writer_sort: true,
..GeneralConfig::default()
};
let rng = Arc::new(SecureRandom::new());
let pool = MePool::new(
None,
vec![1u8; 32],
None,
false,
None,
Vec::new(),
1,
None,
12,
1200,
HashMap::new(),
HashMap::new(),
None,
NetworkDecision::default(),
None,
rng.clone(),
Arc::new(Stats::default()),
general.me_keepalive_enabled,
general.me_keepalive_interval_secs,
general.me_keepalive_jitter_secs,
general.me_keepalive_payload_random,
general.rpc_proxy_req_every,
general.me_warmup_stagger_enabled,
general.me_warmup_step_delay_ms,
general.me_warmup_step_jitter_ms,
general.me_reconnect_max_concurrent_per_dc,
general.me_reconnect_backoff_base_ms,
general.me_reconnect_backoff_cap_ms,
general.me_reconnect_fast_retry_count,
general.me_single_endpoint_shadow_writers,
general.me_single_endpoint_outage_mode_enabled,
general.me_single_endpoint_outage_disable_quarantine,
general.me_single_endpoint_outage_backoff_min_ms,
general.me_single_endpoint_outage_backoff_max_ms,
general.me_single_endpoint_shadow_rotate_every_secs,
general.me_floor_mode,
general.me_adaptive_floor_idle_secs,
general.me_adaptive_floor_min_writers_single_endpoint,
general.me_adaptive_floor_min_writers_multi_endpoint,
general.me_adaptive_floor_recover_grace_secs,
general.me_adaptive_floor_writers_per_core_total,
general.me_adaptive_floor_cpu_cores_override,
general.me_adaptive_floor_max_extra_writers_single_per_core,
general.me_adaptive_floor_max_extra_writers_multi_per_core,
general.me_adaptive_floor_max_active_writers_per_core,
general.me_adaptive_floor_max_warm_writers_per_core,
general.me_adaptive_floor_max_active_writers_global,
general.me_adaptive_floor_max_warm_writers_global,
general.hardswap,
general.me_pool_drain_ttl_secs,
general.me_pool_drain_threshold,
general.effective_me_pool_force_close_secs(),
general.me_pool_min_fresh_ratio,
general.me_hardswap_warmup_delay_min_ms,
general.me_hardswap_warmup_delay_max_ms,
general.me_hardswap_warmup_extra_passes,
general.me_hardswap_warmup_pass_backoff_base_ms,
general.me_bind_stale_mode,
general.me_bind_stale_ttl_secs,
general.me_secret_atomic_snapshot,
general.me_deterministic_writer_sort,
general.me_writer_pick_mode,
general.me_writer_pick_sample_size,
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,
general.me_reader_route_data_wait_ms,
general.me_health_interval_ms_unhealthy,
general.me_health_interval_ms_healthy,
general.me_warn_rate_limit_ms,
general.me_route_no_writer_mode,
general.me_route_no_writer_wait_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
);
(pool, rng)
}
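// Helper commentary (added): installs a synthetic Active-contour writer
// and publishes its endpoint; passing `register_in_registry = false`
// models a stale pool entry whose registry registration is already gone,
// so a later `bind_writer` against it must fail and the writer should be
// pruned.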
async fn insert_writer(
pool: &Arc<MePool>,
writer_id: u64,
writer_dc: i32,
addr: SocketAddr,
register_in_registry: bool,
) -> mpsc::Receiver<WriterCommand> {
let (tx, rx) = mpsc::channel::<WriterCommand>(8);
let writer = MeWriter {
id: writer_id,
addr,
source_ip: addr.ip(),
writer_dc,
generation: pool.current_generation(),
contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())),
created_at: Instant::now(),
tx: tx.clone(),
cancel: CancellationToken::new(),
degraded: Arc::new(AtomicBool::new(false)),
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
draining: Arc::new(AtomicBool::new(false)),
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
};
pool.writers.write().await.push(writer);
{
let mut map = pool.proxy_map_v4.write().await;
map.entry(writer_dc)
.or_insert_with(Vec::new)
.push((addr.ip(), addr.port()));
}
pool.rebuild_endpoint_dc_map().await;
if register_in_registry {
pool.registry.register_writer(writer_id, tx).await;
}
rx
}
// Counts Data/DataAndFlush frames received on `rx`, polling in 10 ms
// slices. Stops when the budget elapses, the channel closes, or a poll
// slice times out with no traffic (i.e. the burst has gone quiet).
async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duration) -> usize {
    let start = Instant::now();
    let mut data_count = 0usize;
    while start.elapsed() < budget {
        let remaining = budget.saturating_sub(start.elapsed());
        match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
            Ok(Some(WriterCommand::Data(_))) => data_count += 1,
            Ok(Some(WriterCommand::DataAndFlush(_))) => data_count += 1,
            Ok(Some(WriterCommand::Close)) => {}
            Ok(None) => break,
            Err(_) => break,
        }
    }
    data_count
}
#[tokio::test]
async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
let (pool, _rng) = make_pool().await;
pool.rr.store(0, Ordering::Relaxed);
let (conn_id, _rx) = pool.registry.register().await;
let mut stale_rx = insert_writer(
&pool,
10,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 10)), 443),
false,
)
.await;
let mut live_rx = insert_writer(
&pool,
11,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 11)), 443),
true,
)
.await;
let result = pool
.send_proxy_req(
conn_id,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30000),
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
b"hello",
0,
None,
)
.await;
assert!(result.is_ok());
assert_eq!(recv_data_count(&mut stale_rx, Duration::from_millis(50)).await, 0);
assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
let bound = pool.registry.get_writer(conn_id).await;
assert!(bound.is_some());
assert_eq!(bound.expect("writer should be bound").writer_id, 11);
}
#[tokio::test]
async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay() {
let (pool, _rng) = make_pool().await;
pool.rr.store(0, Ordering::Relaxed);
let (conn_id, _rx) = pool.registry.register().await;
let mut stale_rx_1 = insert_writer(
&pool,
21,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 21)), 443),
false,
)
.await;
let mut stale_rx_2 = insert_writer(
&pool,
22,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 22)), 443),
false,
)
.await;
let mut live_rx = insert_writer(
&pool,
23,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 23)), 443),
true,
)
.await;
let result = pool
.send_proxy_req(
conn_id,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30001),
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
b"storm",
0,
None,
)
.await;
assert!(result.is_ok());
assert_eq!(recv_data_count(&mut stale_rx_1, Duration::from_millis(50)).await, 0);
assert_eq!(recv_data_count(&mut stale_rx_2, Duration::from_millis(50)).await, 0);
assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
let writers = pool.writers.read().await;
let writer_ids = writers.iter().map(|w| w.id).collect::<Vec<_>>();
drop(writers);
assert_eq!(writer_ids, vec![23]);
}
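// Hedged baseline sketch (added; not part of the original diff): a
// control case for the two adversarial tests above — with a single
// healthy, registered writer the request binds to it and delivers
// exactly one data frame.
#[tokio::test]
async fn send_proxy_req_single_live_writer_binds_and_delivers_once_sketch() {
    let (pool, _rng) = make_pool().await;
    pool.rr.store(0, Ordering::Relaxed);
    let (conn_id, _rx) = pool.registry.register().await;
    let mut live_rx = insert_writer(
        &pool,
        31,
        2,
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 2, 31)), 443),
        true,
    )
    .await;
    let result = pool
        .send_proxy_req(
            conn_id,
            2,
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30002),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
            b"ping",
            0,
            None,
        )
        .await;
    assert!(result.is_ok());
    assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
    assert_eq!(
        pool.registry.get_writer(conn_id).await.map(|w| w.writer_id),
        Some(31)
    );
}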