David Osipov 2026-03-17 17:38:23 +00:00 committed by GitHub
commit 59dc95b555
27 changed files with 5364 additions and 155 deletions

Cargo.lock (generated)
View File

@ -2093,7 +2093,7 @@ dependencies = [
[[package]]
name = "telemt"
-version = "3.3.19"
version = "3.3.20"
dependencies = [
 "aes",
 "anyhow",

View File

@ -32,6 +32,7 @@ show = "*"
port = 443
# proxy_protocol = false  # Enable if behind HAProxy/nginx with PROXY protocol
# metrics_port = 9090
# metrics_listen = "0.0.0.0:9090"  # Listen address for metrics (overrides metrics_port)
# metrics_whitelist = ["127.0.0.1", "::1", "0.0.0.0/0"]

[server.api]

View File

@ -38,8 +38,9 @@ umweltschutz.de -> A record 198.18.88.88
In the Telemt configuration:
-```
-tls_domain = umweltschutz.de
```toml
[censorship]
tls_domain = "umweltschutz.de"
```
This domain is used by the client as the SNI in the ClientHello
@ -56,8 +57,9 @@ tls_domain = umweltschutz.de
In the Telemt configuration:
-```
-mask_host = 127.0.0.1
```toml
[censorship]
mask_host = "127.0.0.1"
mask_port = 8443
```
@ -151,16 +153,18 @@ mask_host:mask_port
For example:
-```
-tls_domain = github.com
-mask_host = github.com
```toml
[censorship]
tls_domain = "github.com"
mask_host = "github.com"
mask_port = 443
```
or
-```
-mask_host = 140.82.121.4
```toml
[censorship]
mask_host = "140.82.121.4"
```
In this case:

View File

@ -1163,9 +1163,17 @@ pub struct ServerConfig {
#[serde(default)]
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,

/// Port for the Prometheus-compatible metrics endpoint.
/// Enables metrics when set; binds on all interfaces (dual-stack) by default.
#[serde(default)]
pub metrics_port: Option<u16>,

/// Listen address for metrics in `IP:PORT` format (e.g. `"127.0.0.1:9090"`).
/// When set, takes precedence over `metrics_port` and binds on the specified address only.
#[serde(default)]
pub metrics_listen: Option<String>,

/// CIDR whitelist for the metrics endpoint.
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpNetwork>,
@ -1194,6 +1202,7 @@ impl Default for ServerConfig {
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
proxy_protocol_trusted_cidrs: Vec::new(),
metrics_port: None,
metrics_listen: None,
metrics_whitelist: default_metrics_whitelist(),
api: ApiConfig::default(),
listeners: Vec::new(),
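
A minimal standalone sketch of the bind-target rule these two fields imply (the `resolve_metrics_target` helper below is illustrative, not part of this diff; the real resolution happens in `spawn_metrics_if_configured` further down):

```rust
use std::net::SocketAddr;

/// Illustrative only: an explicit `metrics_listen` address wins; otherwise
/// fall back to the dual-stack behaviour keyed off `metrics_port`.
fn resolve_metrics_target(
    metrics_listen: Option<&str>,
    metrics_port: Option<u16>,
) -> Option<(u16, Option<SocketAddr>)> {
    if let Some(listen) = metrics_listen {
        // An unparsable address disables metrics in this sketch.
        return listen.parse::<SocketAddr>().ok().map(|a| (a.port(), Some(a)));
    }
    metrics_port.map(|p| (p, None)) // None => bind 0.0.0.0 and [::]
}
```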

View File

@ -0,0 +1,450 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::Duration;
use crate::config::UserMaxUniqueIpsMode;
use crate::ip_tracker::UserIpTracker;
fn ip_from_idx(idx: u32) -> IpAddr {
let a = 10u8;
let b = ((idx / 65_536) % 256) as u8;
let c = ((idx / 256) % 256) as u8;
let d = (idx % 256) as u8;
IpAddr::V4(Ipv4Addr::new(a, b, c, d))
}
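// Illustrative mapping (comment added for clarity, not part of the diff):
// ip_from_idx(66_051) == 10.1.2.3, since 66_051 = 1*65_536 + 2*256 + 3.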
#[tokio::test]
async fn active_window_enforces_large_unique_ip_burst() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("burst_user", 64).await;
tracker
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
.await;
for idx in 0..64 {
assert!(tracker.check_and_add("burst_user", ip_from_idx(idx)).await.is_ok());
}
assert!(tracker.check_and_add("burst_user", ip_from_idx(9_999)).await.is_err());
assert_eq!(tracker.get_active_ip_count("burst_user").await, 64);
}
#[tokio::test]
async fn global_limit_applies_across_many_users() {
let tracker = UserIpTracker::new();
tracker.load_limits(3, &HashMap::new()).await;
for user_idx in 0..150u32 {
let user = format!("u{}", user_idx);
assert!(tracker.check_and_add(&user, ip_from_idx(user_idx * 10)).await.is_ok());
assert!(tracker
.check_and_add(&user, ip_from_idx(user_idx * 10 + 1))
.await
.is_ok());
assert!(tracker
.check_and_add(&user, ip_from_idx(user_idx * 10 + 2))
.await
.is_ok());
assert!(tracker
.check_and_add(&user, ip_from_idx(user_idx * 10 + 3))
.await
.is_err());
}
assert_eq!(tracker.get_stats().await.len(), 150);
}
#[tokio::test]
async fn user_zero_override_falls_back_to_global_limit() {
let tracker = UserIpTracker::new();
let mut limits = HashMap::new();
limits.insert("target".to_string(), 0);
tracker.load_limits(2, &limits).await;
assert!(tracker.check_and_add("target", ip_from_idx(1)).await.is_ok());
assert!(tracker.check_and_add("target", ip_from_idx(2)).await.is_ok());
assert!(tracker.check_and_add("target", ip_from_idx(3)).await.is_err());
assert_eq!(tracker.get_user_limit("target").await, Some(2));
}
#[tokio::test]
async fn remove_ip_is_idempotent_after_counter_reaches_zero() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("u", 2).await;
let ip = ip_from_idx(42);
tracker.check_and_add("u", ip).await.unwrap();
tracker.remove_ip("u", ip).await;
tracker.remove_ip("u", ip).await;
tracker.remove_ip("u", ip).await;
assert_eq!(tracker.get_active_ip_count("u").await, 0);
assert!(!tracker.is_ip_active("u", ip).await);
}
#[tokio::test]
async fn clear_user_ips_resets_active_and_recent() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("u", 10).await;
for idx in 0..6 {
tracker.check_and_add("u", ip_from_idx(idx)).await.unwrap();
}
tracker.clear_user_ips("u").await;
assert_eq!(tracker.get_active_ip_count("u").await, 0);
let counts = tracker
.get_recent_counts_for_users(&["u".to_string()])
.await;
assert_eq!(counts.get("u").copied().unwrap_or(0), 0);
}
#[tokio::test]
async fn clear_all_resets_multi_user_state() {
let tracker = UserIpTracker::new();
for user_idx in 0..80u32 {
let user = format!("u{}", user_idx);
for ip_idx in 0..3 {
tracker
.check_and_add(&user, ip_from_idx(user_idx * 100 + ip_idx))
.await
.unwrap();
}
}
tracker.clear_all().await;
assert!(tracker.get_stats().await.is_empty());
let users = (0..80u32)
.map(|idx| format!("u{}", idx))
.collect::<Vec<_>>();
let recent = tracker.get_recent_counts_for_users(&users).await;
assert!(recent.values().all(|count| *count == 0));
}
#[tokio::test]
async fn get_active_ips_for_users_are_sorted() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("user", 10).await;
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
.await
.unwrap();
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
.await
.unwrap();
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
.await
.unwrap();
let map = tracker
.get_active_ips_for_users(&["user".to_string()])
.await;
let ips = map.get("user").cloned().unwrap_or_default();
assert_eq!(
ips,
vec![
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)),
]
);
}
#[tokio::test]
async fn get_recent_ips_for_users_are_sorted() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("user", 10).await;
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)))
.await
.unwrap();
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)))
.await
.unwrap();
tracker
.check_and_add("user", IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)))
.await
.unwrap();
let map = tracker
.get_recent_ips_for_users(&["user".to_string()])
.await;
let ips = map.get("user").cloned().unwrap_or_default();
assert_eq!(
ips,
vec![
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 5)),
IpAddr::V4(Ipv4Addr::new(10, 1, 0, 9)),
]
);
}
#[tokio::test]
async fn time_window_expires_for_large_rotation() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("tw", 1).await;
tracker
.set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
.await;
tracker.check_and_add("tw", ip_from_idx(1)).await.unwrap();
tracker.remove_ip("tw", ip_from_idx(1)).await;
assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_err());
tokio::time::sleep(Duration::from_millis(1_100)).await;
assert!(tracker.check_and_add("tw", ip_from_idx(2)).await.is_ok());
}
#[tokio::test]
async fn combined_mode_blocks_recent_after_disconnect() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("cmb", 1).await;
tracker
.set_limit_policy(UserMaxUniqueIpsMode::Combined, 2)
.await;
tracker.check_and_add("cmb", ip_from_idx(11)).await.unwrap();
tracker.remove_ip("cmb", ip_from_idx(11)).await;
assert!(tracker.check_and_add("cmb", ip_from_idx(12)).await.is_err());
}
#[tokio::test]
async fn load_limits_replaces_large_limit_map() {
let tracker = UserIpTracker::new();
let mut first = HashMap::new();
let mut second = HashMap::new();
for idx in 0..300usize {
first.insert(format!("u{}", idx), 2usize);
}
for idx in 150..450usize {
second.insert(format!("u{}", idx), 4usize);
}
tracker.load_limits(0, &first).await;
tracker.load_limits(0, &second).await;
assert_eq!(tracker.get_user_limit("u20").await, None);
assert_eq!(tracker.get_user_limit("u200").await, Some(4));
assert_eq!(tracker.get_user_limit("u420").await, Some(4));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_same_user_unique_ip_pressure_stays_bounded() {
let tracker = Arc::new(UserIpTracker::new());
tracker.set_user_limit("hot", 32).await;
tracker
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
.await;
let mut handles = Vec::new();
for worker in 0..16u32 {
let tracker_cloned = tracker.clone();
handles.push(tokio::spawn(async move {
let base = worker * 200;
for step in 0..200u32 {
let _ = tracker_cloned
.check_and_add("hot", ip_from_idx(base + step))
.await;
}
}));
}
for handle in handles {
handle.await.unwrap();
}
assert!(tracker.get_active_ip_count("hot").await <= 32);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_many_users_isolate_limits() {
let tracker = Arc::new(UserIpTracker::new());
tracker.load_limits(4, &HashMap::new()).await;
let mut handles = Vec::new();
for user_idx in 0..120u32 {
let tracker_cloned = tracker.clone();
handles.push(tokio::spawn(async move {
let user = format!("u{}", user_idx);
for ip_idx in 0..10u32 {
let _ = tracker_cloned
.check_and_add(&user, ip_from_idx(user_idx * 1_000 + ip_idx))
.await;
}
}));
}
for handle in handles {
handle.await.unwrap();
}
let stats = tracker.get_stats().await;
assert_eq!(stats.len(), 120);
assert!(stats.iter().all(|(_, active, limit)| *active <= 4 && *limit == 4));
}
#[tokio::test]
async fn same_ip_reconnect_high_frequency_keeps_single_unique() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("same", 2).await;
let ip = ip_from_idx(9);
for _ in 0..2_000 {
tracker.check_and_add("same", ip).await.unwrap();
}
assert_eq!(tracker.get_active_ip_count("same").await, 1);
assert!(tracker.is_ip_active("same", ip).await);
}
#[tokio::test]
async fn format_stats_contains_expected_limited_and_unlimited_markers() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("limited", 2).await;
tracker.check_and_add("limited", ip_from_idx(1)).await.unwrap();
tracker.check_and_add("open", ip_from_idx(2)).await.unwrap();
let text = tracker.format_stats().await;
assert!(text.contains("limited"));
assert!(text.contains("open"));
assert!(text.contains("unlimited"));
}
#[tokio::test]
async fn stats_report_global_default_for_users_without_override() {
let tracker = UserIpTracker::new();
tracker.load_limits(5, &HashMap::new()).await;
tracker.check_and_add("a", ip_from_idx(1)).await.unwrap();
tracker.check_and_add("b", ip_from_idx(2)).await.unwrap();
let stats = tracker.get_stats().await;
assert!(stats.iter().any(|(user, _, limit)| user == "a" && *limit == 5));
assert!(stats.iter().any(|(user, _, limit)| user == "b" && *limit == 5));
}
#[tokio::test]
async fn stress_cycle_add_remove_clear_preserves_empty_end_state() {
let tracker = UserIpTracker::new();
for cycle in 0..50u32 {
let user = format!("cycle{}", cycle);
tracker.set_user_limit(&user, 128).await;
for ip_idx in 0..128u32 {
tracker
.check_and_add(&user, ip_from_idx(cycle * 10_000 + ip_idx))
.await
.unwrap();
}
for ip_idx in 0..128u32 {
tracker
.remove_ip(&user, ip_from_idx(cycle * 10_000 + ip_idx))
.await;
}
tracker.clear_user_ips(&user).await;
}
assert!(tracker.get_stats().await.is_empty());
}
#[tokio::test]
async fn remove_unknown_user_or_ip_does_not_corrupt_state() {
let tracker = UserIpTracker::new();
tracker.remove_ip("no_user", ip_from_idx(1)).await;
tracker.check_and_add("x", ip_from_idx(2)).await.unwrap();
tracker.remove_ip("x", ip_from_idx(3)).await;
assert_eq!(tracker.get_active_ip_count("x").await, 1);
assert!(tracker.is_ip_active("x", ip_from_idx(2)).await);
}
#[tokio::test]
async fn active_and_recent_views_match_after_mixed_workload() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("mix", 16).await;
for ip_idx in 0..12u32 {
tracker.check_and_add("mix", ip_from_idx(ip_idx)).await.unwrap();
}
for ip_idx in 0..6u32 {
tracker.remove_ip("mix", ip_from_idx(ip_idx)).await;
}
let active = tracker
.get_active_ips_for_users(&["mix".to_string()])
.await
.get("mix")
.cloned()
.unwrap_or_default();
let recent_count = tracker
.get_recent_counts_for_users(&["mix".to_string()])
.await
.get("mix")
.copied()
.unwrap_or(0);
assert_eq!(active.len(), 6);
assert!(recent_count >= active.len());
assert!(recent_count <= 12);
}
#[tokio::test]
async fn global_limit_switch_updates_enforcement_immediately() {
let tracker = UserIpTracker::new();
tracker.load_limits(2, &HashMap::new()).await;
assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_err());
tracker.clear_user_ips("u").await;
tracker.load_limits(4, &HashMap::new()).await;
assert!(tracker.check_and_add("u", ip_from_idx(1)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(2)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(3)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(4)).await.is_ok());
assert!(tracker.check_and_add("u", ip_from_idx(5)).await.is_err());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() {
let tracker = Arc::new(UserIpTracker::new());
tracker.set_user_limit("cc", 8).await;
let mut handles = Vec::new();
for worker in 0..8u32 {
let tracker_cloned = tracker.clone();
handles.push(tokio::spawn(async move {
let ip = ip_from_idx(50 + worker);
for _ in 0..500u32 {
let _ = tracker_cloned.check_and_add("cc", ip).await;
tracker_cloned.remove_ip("cc", ip).await;
}
}));
}
for handle in handles {
handle.await.unwrap();
}
assert!(tracker.get_active_ip_count("cc").await <= 8);
}

View File

@ -279,11 +279,32 @@ pub(crate) async fn spawn_metrics_if_configured(
ip_tracker: Arc<UserIpTracker>,
config_rx: watch::Receiver<Arc<ProxyConfig>>,
) {
-    if let Some(port) = config.server.metrics_port {
    // metrics_listen takes precedence; fall back to metrics_port for backward compat.
    let metrics_target: Option<(u16, Option<String>)> =
        if let Some(ref listen) = config.server.metrics_listen {
            match listen.parse::<std::net::SocketAddr>() {
                Ok(addr) => Some((addr.port(), Some(listen.clone()))),
                Err(e) => {
                    startup_tracker
                        .skip_component(
                            COMPONENT_METRICS_START,
                            Some(format!("invalid metrics_listen \"{}\": {}", listen, e)),
                        )
                        .await;
                    None
                }
            }
        } else {
            config.server.metrics_port.map(|p| (p, None))
        };
    if let Some((port, listen)) = metrics_target {
        let fallback_label = format!("port {}", port);
        let label = listen.as_deref().unwrap_or(&fallback_label);
        startup_tracker
            .start_component(
                COMPONENT_METRICS_START,
-                Some(format!("spawn metrics endpoint on {}", port)),
                Some(format!("spawn metrics endpoint on {}", label)),
            )
            .await;
        let stats = stats.clone();
@ -294,6 +315,7 @@ pub(crate) async fn spawn_metrics_if_configured(
        tokio::spawn(async move {
            metrics::serve(
                port,
                listen,
                stats,
                beobachten,
                ip_tracker_metrics,
@ -308,7 +330,7 @@ pub(crate) async fn spawn_metrics_if_configured(
Some("metrics task spawned".to_string()), Some("metrics task spawned".to_string()),
) )
.await; .await;
} else { } else if config.server.metrics_listen.is_none() {
startup_tracker startup_tracker
.skip_component( .skip_component(
COMPONENT_METRICS_START, COMPONENT_METRICS_START,

View File

@ -6,6 +6,8 @@ mod config;
mod crypto;
mod error;
mod ip_tracker;
#[cfg(test)]
mod ip_tracker_regression_tests;
mod maestro;
mod metrics;
mod network;

View File

@ -21,6 +21,7 @@ use crate::transport::{ListenOptions, create_listener};
pub async fn serve(
    port: u16,
    listen: Option<String>,
    stats: Arc<Stats>,
    beobachten: Arc<BeobachtenStore>,
    ip_tracker: Arc<UserIpTracker>,
@ -28,6 +29,33 @@ pub async fn serve(
    whitelist: Vec<IpNetwork>,
) {
    let whitelist = Arc::new(whitelist);
// If `metrics_listen` is set, bind on that single address only.
if let Some(ref listen_addr) = listen {
let addr: SocketAddr = match listen_addr.parse() {
Ok(a) => a,
Err(e) => {
warn!(error = %e, "Invalid metrics_listen address: {}", listen_addr);
return;
}
};
let is_ipv6 = addr.is_ipv6();
match bind_metrics_listener(addr, is_ipv6) {
Ok(listener) => {
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr);
serve_listener(
listener, stats, beobachten, ip_tracker, config_rx, whitelist,
)
.await;
}
Err(e) => {
warn!(error = %e, "Failed to bind metrics on {}", addr);
}
}
return;
}
// Fallback: bind on 0.0.0.0 and [::] using metrics_port.
let mut listener_v4 = None;
let mut listener_v6 = None;
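
As an aside, a rough sketch of what the dual-stack fallback amounts to, assuming plain `tokio::net::TcpListener` binds (the real code goes through `bind_metrics_listener` / `create_listener` and may set `IPV6_V6ONLY` so both sockets can coexist on one port):

```rust
use tokio::net::TcpListener;

// Bind each address family independently so that a failure on one
// (e.g. IPv6 disabled on the host) does not disable metrics entirely.
async fn bind_dual_stack(port: u16) -> (Option<TcpListener>, Option<TcpListener>) {
    let v4 = TcpListener::bind(("0.0.0.0", port)).await.ok();
    let v6 = TcpListener::bind(("::", port)).await.ok();
    (v4, v6)
}
```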

View File

@ -24,6 +24,72 @@ enum HandshakeOutcome {
    Handled,
}
#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"]
struct UserConnectionReservation {
stats: Arc<Stats>,
ip_tracker: Arc<UserIpTracker>,
user: String,
ip: IpAddr,
active: bool,
runtime_handle: Option<tokio::runtime::Handle>,
}
impl UserConnectionReservation {
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
let runtime_handle = tokio::runtime::Handle::try_current().ok();
Self {
stats,
ip_tracker,
user,
ip,
active: true,
runtime_handle,
}
}
async fn release(mut self) {
if !self.active {
return;
}
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
self.ip_tracker.remove_ip(&self.user, self.ip).await;
}
}
impl Drop for UserConnectionReservation {
fn drop(&mut self) {
if !self.active {
return;
}
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
if let Some(handle) = &self.runtime_handle {
let ip_tracker = self.ip_tracker.clone();
let user = self.user.clone();
let ip = self.ip;
let handle = handle.clone();
handle.spawn(async move {
ip_tracker.remove_ip(&user, ip).await;
});
} else if let Ok(handle) = tokio::runtime::Handle::try_current() {
let ip_tracker = self.ip_tracker.clone();
let user = self.user.clone();
let ip = self.ip;
handle.spawn(async move {
ip_tracker.remove_ip(&user, ip).await;
});
} else {
warn!(
user = %self.user,
ip = %self.ip,
"UserConnectionReservation dropped without Tokio runtime; IP reservation cleanup skipped"
);
}
}
}
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
@ -90,6 +156,10 @@ fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
    trusted.iter().any(|cidr| cidr.contains(peer_ip))
}
fn synthetic_local_addr(port: u16) -> SocketAddr {
SocketAddr::from(([0, 0, 0, 0], port))
}
pub async fn handle_client_stream<S>(
    mut stream: S,
    peer: SocketAddr,
@ -113,9 +183,7 @@ where
    let mut real_peer = normalize_ip(peer);
    // For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
-    let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
-        .parse()
-        .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
    let mut local_addr = synthetic_local_addr(config.server.port);

    if proxy_protocol_enabled {
        let proxy_header_timeout = Duration::from_millis(
@ -798,10 +866,22 @@ impl RunningClientHandler {
    {
        let user = success.user.clone();
-        if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await {
-            warn!(user = %user, error = %e, "User limit exceeded");
-            return Err(e);
-        }
        let user_limit_reservation =
            match Self::acquire_user_connection_reservation_static(
                &user,
                &config,
                stats.clone(),
                peer_addr,
                ip_tracker,
            )
            .await
            {
                Ok(reservation) => reservation,
                Err(e) => {
                    warn!(user = %user, error = %e, "User admission check failed");
                    return Err(e);
                }
            };

        let route_snapshot = route_runtime.snapshot();
        let session_id = rng.u64();
@ -858,12 +938,65 @@ impl RunningClientHandler {
            )
            .await
        };
        user_limit_reservation.release().await;
-        stats.decrement_user_curr_connects(&user);
-        ip_tracker.remove_ip(&user, peer_addr.ip()).await;

        relay_result
    }
async fn acquire_user_connection_reservation_static(
user: &str,
config: &ProxyConfig,
stats: Arc<Stats>,
peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>,
) -> Result<UserConnectionReservation> {
if let Some(expiration) = config.access.user_expirations.get(user)
&& chrono::Utc::now() > *expiration
{
return Err(ProxyError::UserExpired {
user: user.to_string(),
});
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
if !stats.try_acquire_user_curr_connects(user, limit) {
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {}
Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!(
user = %user,
ip = %peer_addr.ip(),
reason = %reason,
"IP limit exceeded"
);
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
}
Ok(UserConnectionReservation::new(
stats,
ip_tracker,
user.to_string(),
peer_addr.ip(),
))
}
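// Illustrative call-site shape (added for clarity, not part of the diff): the
// reservation is held across the relay and released explicitly on the normal
// path; if the task aborts or returns early, `Drop` still decrements the user
// connection counter and schedules `remove_ip` on the runtime.
//
//     let reservation = Self::acquire_user_connection_reservation_static(
//         &user, &config, stats.clone(), peer_addr, ip_tracker.clone(),
//     ).await?;
//     let relay_result = relay(...).await;  // hypothetical relay call
//     reservation.release().await;
//     relay_result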
#[cfg(test)]
async fn check_user_limits_static(
    user: &str,
    config: &ProxyConfig,

File diff suppressed because it is too large

View File

@ -105,7 +105,7 @@ where
    debug!(peer = %success.peer, "TG handshake complete, starting relay");
    stats.increment_user_connects(user);
-    stats.increment_current_connections_direct();
    let _direct_connection_lease = stats.acquire_direct_connection_lease();

    let relay_result = relay_bidirectional(
        client_reader,
@ -148,8 +148,6 @@ where
        }
    };
-    stats.decrement_current_connections_direct();

    match &relay_result {
        Ok(()) => debug!(user = %user, "Direct relay completed"),
        Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),

View File

@ -1,4 +1,33 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::{AesCtr, SecureRandom};
use crate::protocol::constants::ProtoTag;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::duplex;
use tokio::net::TcpListener;
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where
R: tokio::io::AsyncRead + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
#[test]
fn unknown_dc_log_is_deduplicated_per_dc_idx() {
@ -49,3 +78,212 @@ fn fallback_dc_never_panics_with_single_dc_list() {
    let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
    assert_eq!(addr, expected);
}
#[tokio::test]
async fn direct_relay_abort_midflight_releases_route_gauge() {
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let tg_addr = tg_listener.local_addr().unwrap();
let tg_accept_task = tokio::spawn(async move {
let (stream, _) = tg_listener.accept().await.unwrap();
let _hold_stream = stream;
tokio::time::sleep(Duration::from_secs(60)).await;
});
let stats = Arc::new(Stats::new());
let mut config = ProxyConfig::default();
config
.dc_overrides
.insert("2".to_string(), vec![tg_addr.to_string()]);
let config = Arc::new(config);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let rng = Arc::new(SecureRandom::new());
let buffer_pool = Arc::new(BufferPool::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let route_snapshot = route_runtime.snapshot();
let (server_side, client_side) = duplex(64 * 1024);
let (server_reader, server_writer) = tokio::io::split(server_side);
let client_reader = make_crypto_reader(server_reader);
let client_writer = make_crypto_writer(server_writer);
let success = HandshakeSuccess {
user: "abort-direct-user".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Intermediate,
dec_key: [0u8; 32],
dec_iv: 0,
enc_key: [0u8; 32],
enc_iv: 0,
peer: "127.0.0.1:50000".parse().unwrap(),
is_tls: false,
};
let relay_task = tokio::spawn(handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats.clone(),
config,
buffer_pool,
rng,
route_runtime.subscribe(),
route_snapshot,
0xabad1dea,
));
let started = tokio::time::timeout(Duration::from_secs(2), async {
loop {
if stats.get_current_connections_direct() == 1 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await;
assert!(started.is_ok(), "direct relay must increment route gauge before abort");
relay_task.abort();
let joined = relay_task.await;
assert!(joined.is_err(), "aborted direct relay task must return join error");
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(
stats.get_current_connections_direct(),
0,
"route gauge must be released when direct relay task is aborted mid-flight"
);
drop(client_side);
tg_accept_task.abort();
let _ = tg_accept_task.await;
}
#[tokio::test]
async fn direct_relay_cutover_midflight_releases_route_gauge() {
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let tg_addr = tg_listener.local_addr().unwrap();
let tg_accept_task = tokio::spawn(async move {
let (stream, _) = tg_listener.accept().await.unwrap();
let _hold_stream = stream;
tokio::time::sleep(Duration::from_secs(60)).await;
});
let stats = Arc::new(Stats::new());
let mut config = ProxyConfig::default();
config
.dc_overrides
.insert("2".to_string(), vec![tg_addr.to_string()]);
let config = Arc::new(config);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let rng = Arc::new(SecureRandom::new());
let buffer_pool = Arc::new(BufferPool::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let route_snapshot = route_runtime.snapshot();
let (server_side, client_side) = duplex(64 * 1024);
let (server_reader, server_writer) = tokio::io::split(server_side);
let client_reader = make_crypto_reader(server_reader);
let client_writer = make_crypto_writer(server_writer);
let success = HandshakeSuccess {
user: "cutover-direct-user".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Intermediate,
dec_key: [0u8; 32],
dec_iv: 0,
enc_key: [0u8; 32],
enc_iv: 0,
peer: "127.0.0.1:50002".parse().unwrap(),
is_tls: false,
};
let relay_task = tokio::spawn(handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats.clone(),
config,
buffer_pool,
rng,
route_runtime.subscribe(),
route_snapshot,
0xface_cafe,
));
tokio::time::timeout(Duration::from_secs(2), async {
loop {
if stats.get_current_connections_direct() == 1 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await
.expect("direct relay must increment route gauge before cutover");
assert!(
route_runtime.set_mode(RelayRouteMode::Middle).is_some(),
"cutover must advance route generation"
);
let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task)
.await
.expect("direct relay must terminate after cutover")
.expect("direct relay task must not panic");
assert!(
relay_result.is_err(),
"cutover should terminate direct relay session"
);
assert_eq!(
stats.get_current_connections_direct(),
0,
"route gauge must be released when direct relay exits on cutover"
);
drop(client_side);
tg_accept_task.abort();
let _ = tg_accept_task.await;
}

View File

@ -317,6 +317,24 @@ fn decode_user_secrets(
    secrets
}
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
if config.censorship.server_hello_delay_max_ms == 0 {
return;
}
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms.max(min);
let delay_ms = if max == min {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
}
/// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
@ -368,11 +386,13 @@ where
    debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");

    if auth_probe_is_throttled(peer.ip(), Instant::now()) {
        maybe_apply_server_hello_delay(config).await;
        debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
        return HandshakeResult::BadClient { reader, writer };
    }

    if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
        maybe_apply_server_hello_delay(config).await;
        debug!(peer = %peer, "TLS handshake too short");
        return HandshakeResult::BadClient { reader, writer };
    }
@ -388,6 +408,7 @@ where
        Some(v) => v,
        None => {
            auth_probe_record_failure(peer.ip(), Instant::now());
            maybe_apply_server_hello_delay(config).await;
            debug!(
                peer = %peer,
                ignore_time_skew = config.access.ignore_time_skew,
@ -402,13 +423,17 @@ where
    let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
    if replay_checker.check_and_add_tls_digest(digest_half) {
        auth_probe_record_failure(peer.ip(), Instant::now());
        maybe_apply_server_hello_delay(config).await;
        warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
        return HandshakeResult::BadClient { reader, writer };
    }

    let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
        Some((_, s)) => s,
-        None => return HandshakeResult::BadClient { reader, writer },
        None => {
            maybe_apply_server_hello_delay(config).await;
            return HandshakeResult::BadClient { reader, writer };
        }
    };

    let cached = if config.censorship.tls_emulation {
@ -448,6 +473,7 @@ where
    } else if alpn_list.iter().any(|p| p == b"http/1.1") {
        Some(b"http/1.1".to_vec())
    } else if !alpn_list.is_empty() {
        maybe_apply_server_hello_delay(config).await;
        debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
        return HandshakeResult::BadClient { reader, writer };
    } else {
@ -480,19 +506,9 @@ where
        )
    };

-    // Optional anti-fingerprint delay before sending ServerHello.
-    if config.censorship.server_hello_delay_max_ms > 0 {
-        let min = config.censorship.server_hello_delay_min_ms;
-        let max = config.censorship.server_hello_delay_max_ms.max(min);
-        let delay_ms = if max == min {
-            max
-        } else {
-            rand::rng().random_range(min..=max)
-        };
-        if delay_ms > 0 {
-            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
-        }
-    }
    // Apply the same optional delay budget used by reject paths to reduce
    // distinguishability between success and fail-closed handshakes.
    maybe_apply_server_hello_delay(config).await;

    debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@ -539,6 +555,7 @@ where
    trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");

    if auth_probe_is_throttled(peer.ip(), Instant::now()) {
        maybe_apply_server_hello_delay(config).await;
        debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
        return HandshakeResult::BadClient { reader, writer };
    }
@ -609,6 +626,7 @@ where
    // authentication check first to avoid poisoning the replay cache.
    if replay_checker.check_and_add_handshake(dec_prekey_iv) {
        auth_probe_record_failure(peer.ip(), Instant::now());
        maybe_apply_server_hello_delay(config).await;
        warn!(peer = %peer, user = %user, "MTProto replay attack detected");
        return HandshakeResult::BadClient { reader, writer };
    }
@ -645,6 +663,7 @@ where
    }

    auth_probe_record_failure(peer.ip(), Instant::now());
    maybe_apply_server_hello_delay(config).await;
    debug!(peer = %peer, "MTProto handshake: no matching user found");
    HandshakeResult::BadClient { reader, writer }
}

View File

@ -580,6 +580,72 @@ async fn malformed_tls_classes_complete_within_bounded_time() {
    }
}
#[tokio::test]
async fn tls_invalid_hmac_respects_configured_anti_fingerprint_delay() {
let secret = [0x5Au8; 16];
let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a");
config.censorship.server_hello_delay_min_ms = 20;
config.censorship.server_hello_delay_max_ms = 20;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.32:44331".parse().unwrap();
let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01;
let started = Instant::now();
let result = handle_tls_handshake(
&bad_hmac,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert!(
started.elapsed() >= Duration::from_millis(18),
"configured anti-fingerprint delay must apply to invalid TLS handshakes"
);
}
#[tokio::test]
async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() {
let secret = [0x6Bu8; 16];
let mut config = test_config_with_secret_hex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b");
config.censorship.alpn_enforce = true;
config.censorship.server_hello_delay_min_ms = 20;
config.censorship.server_hello_delay_max_ms = 20;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.33:44332".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let started = Instant::now();
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert!(
started.elapsed() >= Duration::from_millis(18),
"configured anti-fingerprint delay must apply to ALPN-mismatch rejects"
);
}
#[tokio::test]
#[ignore = "timing-sensitive; run manually on low-jitter hosts"]
async fn malformed_tls_classes_share_close_latency_buckets() {
@ -643,6 +709,82 @@ async fn malformed_tls_classes_share_close_latency_buckets() {
); );
} }
#[tokio::test]
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
async fn timing_matrix_tls_classes_under_fixed_delay_budget() {
const ITER: usize = 48;
const BUCKET_MS: u128 = 10;
let secret = [0x77u8; 16];
let mut config = test_config_with_secret_hex("77777777777777777777777777777777");
config.censorship.alpn_enforce = true;
config.censorship.server_hello_delay_min_ms = 20;
config.censorship.server_hello_delay_max_ms = 20;
let rng = SecureRandom::new();
let base_ip = std::net::Ipv4Addr::new(198, 51, 100, 34);
let too_short = vec![0x16, 0x03, 0x01];
let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01;
let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let valid_h2 = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2"]);
let classes = vec![
("too_short", too_short),
("bad_hmac", bad_hmac),
("alpn_mismatch", alpn_mismatch),
("valid_h2", valid_h2),
];
for (class, probe) in classes {
let mut samples_ms = Vec::with_capacity(ITER);
for idx in 0..ITER {
clear_auth_probe_state_for_testing();
let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60));
let peer: SocketAddr = SocketAddr::from((base_ip, 44_000 + idx as u16));
let started = Instant::now();
let result = handle_tls_handshake(
&probe,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let elapsed = started.elapsed();
samples_ms.push(elapsed.as_millis());
if class == "valid_h2" {
assert!(matches!(result, HandshakeResult::Success(_)));
} else {
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
}
samples_ms.sort_unstable();
let sum: u128 = samples_ms.iter().copied().sum();
let mean = sum as f64 / samples_ms.len() as f64;
let min = samples_ms[0];
let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
let max = samples_ms[samples_ms.len() - 1];
println!(
"TIMING_MATRIX tls class={} mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
class,
mean,
min,
p95,
max,
(mean as u128) / BUCKET_MS
);
}
}
#[test]
fn secure_tag_requires_tls_mode_on_tls_transport() {
    let mut config = ProxyConfig::default();

View File

@ -7,7 +7,7 @@ use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
-use tokio::time::timeout;
use tokio::time::{Instant, timeout};
use tracing::debug;
use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
@ -49,6 +49,20 @@ where
    }
}
async fn wait_mask_connect_budget(started: Instant) {
let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT {
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
}
}
async fn wait_mask_outcome_budget(started: Instant) {
let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT {
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
}
}
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
    // Check for HTTP request
@ -107,6 +121,8 @@ where
    // Connect via Unix socket or TCP
    #[cfg(unix)]
    if let Some(ref sock_path) = config.censorship.mask_unix_sock {
        let outcome_started = Instant::now();
        let connect_started = Instant::now();
        debug!(
            client_type = client_type,
            sock = %sock_path,
@ -143,14 +159,18 @@ where
                if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
                    debug!("Mask relay timed out (unix socket)");
                }
                wait_mask_outcome_budget(outcome_started).await;
            }
            Ok(Err(e)) => {
                wait_mask_connect_budget(connect_started).await;
                debug!(error = %e, "Failed to connect to mask unix socket");
                consume_client_data_with_timeout(reader).await;
                wait_mask_outcome_budget(outcome_started).await;
            }
            Err(_) => {
                debug!("Timeout connecting to mask unix socket");
                consume_client_data_with_timeout(reader).await;
                wait_mask_outcome_budget(outcome_started).await;
            }
        }
        return;
@ -172,6 +192,8 @@ where
    let mask_addr = resolve_socket_addr(mask_host, mask_port)
        .map(|addr| addr.to_string())
        .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
    let outcome_started = Instant::now();
    let connect_started = Instant::now();
    let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
    match connect_result {
        Ok(Ok(stream)) => {
@ -202,14 +224,18 @@ where
            if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
                debug!("Mask relay timed out");
            }
            wait_mask_outcome_budget(outcome_started).await;
        }
        Ok(Err(e)) => {
            wait_mask_connect_budget(connect_started).await;
            debug!(error = %e, "Failed to connect to mask host");
            consume_client_data_with_timeout(reader).await;
            wait_mask_outcome_budget(outcome_started).await;
        }
        Err(_) => {
            debug!("Timeout connecting to mask host");
            consume_client_data_with_timeout(reader).await;
            wait_mask_outcome_budget(outcome_started).await;
        }
    }
}

View File

@ -8,7 +8,7 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader};
use tokio::net::TcpListener;
#[cfg(unix)]
use tokio::net::UnixListener;
-use tokio::time::{sleep, timeout, Duration};
use tokio::time::{Instant, sleep, timeout, Duration};

#[tokio::test]
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
@ -216,6 +216,372 @@ async fn backend_unavailable_falls_back_to_silent_consume() {
    assert_eq!(n, 0);
}
#[tokio::test]
async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n";
// Keep reader open so fallback path does not terminate immediately on EOF.
let (_client_reader_side, client_reader) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
timeout(Duration::from_millis(35), task)
.await
.expect_err("masking fallback must not complete before connect budget elapses");
assert!(
started.elapsed() >= Duration::from_millis(35),
"fallback path must absorb immediate connect refusal into connect budget"
);
}
#[tokio::test]
async fn backend_reachable_fast_response_waits_mask_outcome_budget() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
assert!(
started.elapsed() >= Duration::from_millis(45),
"reachable mask path must also satisfy coarse outcome budget"
);
accept_task.await.unwrap();
}
#[tokio::test]
async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = false;
let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
b"x",
peer,
local_addr,
&config,
&beobachten,
)
.await;
assert!(
started.elapsed() < Duration::from_millis(20),
"mask-disabled fallback should keep immediate EOF behavior"
);
}
#[tokio::test]
async fn backend_reachable_slow_response_not_padded_twice() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
sleep(Duration::from_millis(90)).await;
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(elapsed >= Duration::from_millis(85));
assert!(
elapsed < Duration::from_millis(170),
"slow reachable backend should not incur an extra full budget after already exceeding it"
);
accept_task.await.unwrap();
}
#[tokio::test]
async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() {
const ITER: usize = 20;
const BUCKET_MS: u128 = 10;
let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n";
let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let mut refused = Vec::with_capacity(ITER);
for _ in 0..ITER {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
refused.push(started.elapsed().as_millis());
}
let mut reachable = Vec::with_capacity(ITER);
for _ in 0..ITER {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe_vec = probe.to_vec();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe_vec.len()];
stream.read_exact(&mut received).await.unwrap();
stream.write_all(&backend_reply).await.unwrap();
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
reachable.push(started.elapsed().as_millis());
accept_task.await.unwrap();
}
let refused_mean = refused.iter().copied().sum::<u128>() as f64 / refused.len() as f64;
let reachable_mean = reachable.iter().copied().sum::<u128>() as f64 / reachable.len() as f64;
let refused_bucket = (refused_mean as u128) / BUCKET_MS;
let reachable_bucket = (reachable_mean as u128) / BUCKET_MS;
assert!(
refused_bucket.abs_diff(reachable_bucket) <= 1,
"enabled refused and reachable paths must collapse into the same coarse latency bucket"
);
}
#[tokio::test]
async fn light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() {
let mut seed: u64 = 0xA5A5_5A5A_1337_4242;
let mut next = || {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
seed
};
let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
for _ in 0..40 {
let probe_len = (next() as usize % 96).saturating_add(8);
let mut probe = vec![0u8; probe_len];
for byte in &mut probe {
*byte = next() as u8;
}
let use_reachable = (next() & 1) == 0;
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(512);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
if use_reachable {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
let probe_vec = probe.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; probe_vec.len()];
stream.read_exact(&mut observed).await.unwrap();
});
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
accept_task.await.unwrap();
} else {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
}
assert!(
started.elapsed() >= Duration::from_millis(45),
"mask-enabled fallback must preserve coarse timing budget under varied probe shapes"
);
}
}
#[tokio::test]
async fn mask_disabled_consumes_client_data_without_response() {
let mut config = ProxyConfig::default();
@@ -729,3 +1095,158 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
assert!(mask_reader_dropped.load(Ordering::SeqCst));
assert!(mask_writer_dropped.load(Ordering::SeqCst));
}
#[tokio::test]
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
async fn timing_matrix_masking_classes_under_controlled_inputs() {
const ITER: usize = 24;
const BUCKET_MS: u128 = 10;
let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n";
let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
// Class 1: masking disabled with immediate EOF (fast fail-closed consume path).
let mut disabled_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = false;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
disabled_samples.push(started.elapsed().as_millis());
}
// Class 2: masking enabled, backend connect refused.
let mut refused_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
refused_samples.push(started.elapsed().as_millis());
}
// Class 3: masking enabled, backend reachable and immediately responds.
let mut reachable_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let probe_vec = probe.to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe_vec.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe_vec);
stream.write_all(&backend_reply).await.unwrap();
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
reachable_samples.push(started.elapsed().as_millis());
accept_task.await.unwrap();
}
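// Collapse each class's raw samples into mean/min/p95/max for the printed timing matrix below.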
fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) {
samples_ms.sort_unstable();
let sum: u128 = samples_ms.iter().copied().sum();
let mean = sum as f64 / samples_ms.len() as f64;
let min = samples_ms[0];
let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
let max = samples_ms[samples_ms.len() - 1];
(mean, min, p95, max)
}
let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples);
let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples);
let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples);
println!(
"TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
disabled_mean,
disabled_min,
disabled_p95,
disabled_max,
(disabled_mean as u128) / BUCKET_MS
);
println!(
"TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
refused_mean,
refused_min,
refused_p95,
refused_max,
(refused_mean as u128) / BUCKET_MS
);
println!(
"TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
reachable_mean,
reachable_min,
reachable_p95,
reachable_max,
(reachable_mean as u128) / BUCKET_MS
);
}

View File

@@ -306,7 +306,7 @@ where
};
stats.increment_user_connects(&user);
-stats.increment_current_connections_me();
+let _me_connection_lease = stats.acquire_me_connection_lease();
if let Some(cutover) = affected_cutover_state(
&route_rx,
@@ -324,7 +324,6 @@
tokio::time::sleep(delay).await;
let _ = me_pool.send_close(conn_id).await;
me_pool.registry().unregister(conn_id).await;
-stats.decrement_current_connections_me();
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
}
@@ -672,7 +671,6 @@ where
"ME relay cleanup"
);
me_pool.registry().unregister(conn_id).await;
-stats.decrement_current_connections_me();
result
}

View File

@@ -2,8 +2,13 @@ use super::*;
use bytes::Bytes;
use crate::crypto::AesCtr;
use crate::crypto::SecureRandom;
+use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
+use crate::network::probe::NetworkDecision;
+use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
+use crate::transport::middle_proxy::MePool;
+use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
@@ -229,18 +234,108 @@ fn make_forensics_state() -> RelayForensicsState {
}
}
-fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader<tokio::io::DuplexStream> {
+fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
+where
+R: tokio::io::AsyncRead + Unpin,
+{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
-fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter<tokio::io::DuplexStream> {
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+W: tokio::io::AsyncWrite + Unpin,
+{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
async fn make_me_pool_for_abort_test(stats: Arc<Stats>) -> Arc<MePool> {
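// Builds a MePool wired to the shared Stats from GeneralConfig defaults; the abort/cutover tests
// only need the pool to exist, not to reach a real middle proxy.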
let general = GeneralConfig::default();
MePool::new(
None,
vec![1u8; 32],
None,
false,
None,
Vec::new(),
1,
None,
12,
1200,
HashMap::new(),
HashMap::new(),
None,
NetworkDecision::default(),
None,
Arc::new(SecureRandom::new()),
stats,
general.me_keepalive_enabled,
general.me_keepalive_interval_secs,
general.me_keepalive_jitter_secs,
general.me_keepalive_payload_random,
general.rpc_proxy_req_every,
general.me_warmup_stagger_enabled,
general.me_warmup_step_delay_ms,
general.me_warmup_step_jitter_ms,
general.me_reconnect_max_concurrent_per_dc,
general.me_reconnect_backoff_base_ms,
general.me_reconnect_backoff_cap_ms,
general.me_reconnect_fast_retry_count,
general.me_single_endpoint_shadow_writers,
general.me_single_endpoint_outage_mode_enabled,
general.me_single_endpoint_outage_disable_quarantine,
general.me_single_endpoint_outage_backoff_min_ms,
general.me_single_endpoint_outage_backoff_max_ms,
general.me_single_endpoint_shadow_rotate_every_secs,
general.me_floor_mode,
general.me_adaptive_floor_idle_secs,
general.me_adaptive_floor_min_writers_single_endpoint,
general.me_adaptive_floor_min_writers_multi_endpoint,
general.me_adaptive_floor_recover_grace_secs,
general.me_adaptive_floor_writers_per_core_total,
general.me_adaptive_floor_cpu_cores_override,
general.me_adaptive_floor_max_extra_writers_single_per_core,
general.me_adaptive_floor_max_extra_writers_multi_per_core,
general.me_adaptive_floor_max_active_writers_per_core,
general.me_adaptive_floor_max_warm_writers_per_core,
general.me_adaptive_floor_max_active_writers_global,
general.me_adaptive_floor_max_warm_writers_global,
general.hardswap,
general.me_pool_drain_ttl_secs,
general.me_pool_drain_threshold,
general.effective_me_pool_force_close_secs(),
general.me_pool_min_fresh_ratio,
general.me_hardswap_warmup_delay_min_ms,
general.me_hardswap_warmup_delay_max_ms,
general.me_hardswap_warmup_extra_passes,
general.me_hardswap_warmup_pass_backoff_base_ms,
general.me_bind_stale_mode,
general.me_bind_stale_ttl_secs,
general.me_secret_atomic_snapshot,
general.me_deterministic_writer_sort,
MeWriterPickMode::default(),
general.me_writer_pick_sample_size,
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,
general.me_reader_route_data_wait_ms,
general.me_health_interval_ms_unhealthy,
general.me_health_interval_ms_healthy,
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
)
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
@@ -779,3 +874,148 @@ async fn process_me_writer_response_data_updates_byte_accounting() {
"ME->C byte accounting must increase by emitted payload size"
);
}
#[tokio::test]
async fn middle_relay_abort_midflight_releases_route_gauge() {
let stats = Arc::new(Stats::new());
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
let route_snapshot = route_runtime.snapshot();
let (server_side, client_side) = duplex(64 * 1024);
let (server_reader, server_writer) = tokio::io::split(server_side);
let crypto_reader = make_crypto_reader(server_reader);
let crypto_writer = make_crypto_writer(server_writer);
let success = HandshakeSuccess {
user: "abort-middle-user".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Intermediate,
dec_key: [0u8; 32],
dec_iv: 0,
enc_key: [0u8; 32],
enc_iv: 0,
peer: "127.0.0.1:50001".parse().unwrap(),
is_tls: false,
};
let relay_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool,
stats.clone(),
config,
buffer_pool,
"127.0.0.1:443".parse().unwrap(),
rng,
route_runtime.subscribe(),
route_snapshot,
0xdecafbad,
));
let started = tokio::time::timeout(TokioDuration::from_secs(2), async {
loop {
if stats.get_current_connections_me() == 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await;
assert!(started.is_ok(), "middle relay must increment route gauge before abort");
relay_task.abort();
let joined = relay_task.await;
assert!(joined.is_err(), "aborted middle relay task must return join error");
tokio::time::sleep(TokioDuration::from_millis(20)).await;
assert_eq!(
stats.get_current_connections_me(),
0,
"route gauge must be released when middle relay task is aborted mid-flight"
);
drop(client_side);
}
#[tokio::test]
async fn middle_relay_cutover_midflight_releases_route_gauge() {
let stats = Arc::new(Stats::new());
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
let route_snapshot = route_runtime.snapshot();
let (server_side, client_side) = duplex(64 * 1024);
let (server_reader, server_writer) = tokio::io::split(server_side);
let crypto_reader = make_crypto_reader(server_reader);
let crypto_writer = make_crypto_writer(server_writer);
let success = HandshakeSuccess {
user: "cutover-middle-user".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Intermediate,
dec_key: [0u8; 32],
dec_iv: 0,
enc_key: [0u8; 32],
enc_iv: 0,
peer: "127.0.0.1:50003".parse().unwrap(),
is_tls: false,
};
let relay_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool,
stats.clone(),
config,
buffer_pool,
"127.0.0.1:443".parse().unwrap(),
rng,
route_runtime.subscribe(),
route_snapshot,
0xfeed_beef,
));
tokio::time::timeout(TokioDuration::from_secs(2), async {
loop {
if stats.get_current_connections_me() == 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("middle relay must increment route gauge before cutover");
assert!(
route_runtime.set_mode(RelayRouteMode::Direct).is_some(),
"cutover must advance route generation"
);
let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task)
.await
.expect("middle relay must terminate after cutover")
.expect("middle relay task must not panic");
assert!(
relay_result.is_err(),
"cutover should terminate middle relay session"
);
assert_eq!(
stats.get_current_connections_me(),
0,
"route gauge must be released when middle relay exits on cutover"
);
drop(client_side);
}

View File

@@ -0,0 +1,265 @@
use super::*;
use std::panic::{self, AssertUnwindSafe};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Barrier;
#[test]
fn direct_connection_lease_balances_on_drop() {
let stats = Arc::new(Stats::new());
assert_eq!(stats.get_current_connections_direct(), 0);
{
let _lease = stats.acquire_direct_connection_lease();
assert_eq!(stats.get_current_connections_direct(), 1);
}
assert_eq!(stats.get_current_connections_direct(), 0);
}
#[test]
fn middle_connection_lease_balances_on_drop() {
let stats = Arc::new(Stats::new());
assert_eq!(stats.get_current_connections_me(), 0);
{
let _lease = stats.acquire_me_connection_lease();
assert_eq!(stats.get_current_connections_me(), 1);
}
assert_eq!(stats.get_current_connections_me(), 0);
}
#[test]
fn connection_lease_disarm_prevents_double_release() {
let stats = Arc::new(Stats::new());
let mut lease = stats.acquire_direct_connection_lease();
assert_eq!(stats.get_current_connections_direct(), 1);
stats.decrement_current_connections_direct();
assert_eq!(stats.get_current_connections_direct(), 0);
lease.disarm();
drop(lease);
assert_eq!(stats.get_current_connections_direct(), 0);
}
#[test]
fn direct_connection_lease_balances_on_panic_unwind() {
let stats = Arc::new(Stats::new());
let stats_for_panic = stats.clone();
let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
let _lease = stats_for_panic.acquire_direct_connection_lease();
panic!("intentional panic to verify lease drop path");
}));
assert!(panic_result.is_err(), "panic must propagate from test closure");
assert_eq!(
stats.get_current_connections_direct(),
0,
"panic unwind must release direct route gauge"
);
}
#[test]
fn middle_connection_lease_balances_on_panic_unwind() {
let stats = Arc::new(Stats::new());
let stats_for_panic = stats.clone();
let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
let _lease = stats_for_panic.acquire_me_connection_lease();
panic!("intentional panic to verify middle lease drop path");
}));
assert!(panic_result.is_err(), "panic must propagate from test closure");
assert_eq!(
stats.get_current_connections_me(),
0,
"panic unwind must release middle route gauge"
);
}
#[tokio::test]
async fn concurrent_mixed_route_lease_churn_balances_to_zero() {
const TASKS: usize = 48;
const ITERATIONS_PER_TASK: usize = 256;
let stats = Arc::new(Stats::new());
let barrier = Arc::new(Barrier::new(TASKS));
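// The barrier releases every worker at once so lease acquire/drop operations interleave as much as possible.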
let mut workers = Vec::with_capacity(TASKS);
for task_idx in 0..TASKS {
let stats_for_task = stats.clone();
let barrier_for_task = barrier.clone();
workers.push(tokio::spawn(async move {
barrier_for_task.wait().await;
for iter in 0..ITERATIONS_PER_TASK {
if (task_idx + iter) % 2 == 0 {
let _lease = stats_for_task.acquire_direct_connection_lease();
tokio::task::yield_now().await;
} else {
let _lease = stats_for_task.acquire_me_connection_lease();
tokio::task::yield_now().await;
}
}
}));
}
for worker in workers {
worker
.await
.expect("lease churn worker must not panic");
}
assert_eq!(
stats.get_current_connections_direct(),
0,
"direct route gauge must return to zero after concurrent lease churn"
);
assert_eq!(
stats.get_current_connections_me(),
0,
"middle route gauge must return to zero after concurrent lease churn"
);
}
#[tokio::test]
async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() {
const TASKS: usize = 64;
let stats = Arc::new(Stats::new());
let mut workers = Vec::with_capacity(TASKS);
for task_idx in 0..TASKS {
let stats_for_task = stats.clone();
workers.push(tokio::spawn(async move {
if task_idx % 2 == 0 {
let _lease = stats_for_task.acquire_direct_connection_lease();
tokio::time::sleep(Duration::from_secs(60)).await;
} else {
let _lease = stats_for_task.acquire_me_connection_lease();
tokio::time::sleep(Duration::from_secs(60)).await;
}
}));
}
tokio::time::timeout(Duration::from_secs(2), async {
loop {
let total = stats.get_current_connections_direct() + stats.get_current_connections_me();
if total == TASKS as u64 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await
.expect("all storm tasks must acquire route leases before abort");
for worker in &workers {
worker.abort();
}
for worker in workers {
let joined = worker.await;
assert!(joined.is_err(), "aborted worker must return join error");
}
tokio::time::timeout(Duration::from_secs(2), async {
loop {
if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 {
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await
.expect("all route gauges must drain to zero after abort storm");
}
#[test]
fn saturating_route_decrements_do_not_underflow_under_race() {
const THREADS: usize = 16;
const DECREMENTS_PER_THREAD: usize = 4096;
let stats = Arc::new(Stats::new());
let mut workers = Vec::with_capacity(THREADS);
for _ in 0..THREADS {
let stats_for_thread = stats.clone();
workers.push(std::thread::spawn(move || {
for _ in 0..DECREMENTS_PER_THREAD {
stats_for_thread.decrement_current_connections_direct();
stats_for_thread.decrement_current_connections_me();
}
}));
}
for worker in workers {
worker
.join()
.expect("decrement race worker must not panic");
}
assert_eq!(
stats.get_current_connections_direct(),
0,
"direct route decrement races must never underflow"
);
assert_eq!(
stats.get_current_connections_me(),
0,
"middle route decrement races must never underflow"
);
}
#[tokio::test]
async fn direct_connection_lease_balances_on_task_abort() {
let stats = Arc::new(Stats::new());
let stats_for_task = stats.clone();
let task = tokio::spawn(async move {
let _lease = stats_for_task.acquire_direct_connection_lease();
tokio::time::sleep(Duration::from_secs(60)).await;
});
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(stats.get_current_connections_direct(), 1);
task.abort();
let joined = task.await;
assert!(joined.is_err(), "aborted task must return a join error");
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(
stats.get_current_connections_direct(),
0,
"aborted task must release direct route gauge"
);
}
#[tokio::test]
async fn middle_connection_lease_balances_on_task_abort() {
let stats = Arc::new(Stats::new());
let stats_for_task = stats.clone();
let task = tokio::spawn(async move {
let _lease = stats_for_task.acquire_me_connection_lease();
tokio::time::sleep(Duration::from_secs(60)).await;
});
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(stats.get_current_connections_me(), 1);
task.abort();
let joined = task.await;
assert!(joined.is_err(), "aborted task must return a join error");
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(
stats.get_current_connections_me(),
0,
"aborted task must release middle route gauge"
);
}

View File

@@ -6,6 +6,7 @@ pub mod beobachten;
pub mod telemetry;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
+use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use parking_lot::Mutex;
@@ -19,6 +20,46 @@ use tracing::debug;
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
use self::telemetry::TelemetryPolicy;
#[derive(Clone, Copy)]
enum RouteConnectionGauge {
Direct,
Middle,
}
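// RAII lease: the gauge is incremented in the acquire_* constructors and decremented exactly once
// on Drop, so panics and task aborts still release it.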
#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"]
pub struct RouteConnectionLease {
stats: Arc<Stats>,
gauge: RouteConnectionGauge,
active: bool,
}
impl RouteConnectionLease {
fn new(stats: Arc<Stats>, gauge: RouteConnectionGauge) -> Self {
Self {
stats,
gauge,
active: true,
}
}
#[cfg(test)]
fn disarm(&mut self) {
self.active = false;
}
}
impl Drop for RouteConnectionLease {
fn drop(&mut self) {
if !self.active {
return;
}
match self.gauge {
RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(),
RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(),
}
}
}
// ============= Stats =============
#[derive(Default)]
@@ -285,6 +326,16 @@ impl Stats {
pub fn decrement_current_connections_me(&self) {
Self::decrement_atomic_saturating(&self.current_connections_me);
}
pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
self.increment_current_connections_direct();
RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct)
}
pub fn acquire_me_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
self.increment_current_connections_me();
RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle)
}
pub fn increment_handshake_timeouts(&self) {
if self.telemetry_core_enabled() {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
@@ -1772,3 +1823,7 @@ mod tests {
assert_eq!(checker.stats().total_entries, 500);
}
}
#[cfg(test)]
#[path = "connection_lease_security_tests.rs"]
mod connection_lease_security_tests;

View File

@@ -25,6 +25,9 @@ const HEALTH_RECONNECT_BUDGET_PER_CORE: usize = 2;
const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1;
const HEALTH_RECONNECT_BUDGET_MIN: usize = 4;
const HEALTH_RECONNECT_BUDGET_MAX: usize = 128;
+const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16;
+const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16;
+const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256;
#[derive(Debug, Clone)]
struct DcFloorPlanEntry {
@@ -111,106 +114,75 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
}
}
-async fn reap_draining_writers(
+pub(super) async fn reap_draining_writers(
pool: &Arc<MePool>,
warn_next_allowed: &mut HashMap<u64, Instant>,
) {
-if pool.draining_active_runtime() == 0 {
-return;
-}
let now_epoch_secs = MePool::now_epoch_secs();
let now = Instant::now();
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
let drain_threshold = pool
.me_pool_drain_threshold
.load(std::sync::atomic::Ordering::Relaxed);
-let mut draining_writers = {
-let writers = pool.writers.read().await;
-let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
-for writer in writers.iter() {
-if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
-continue;
-}
-draining_writers.push(DrainingWriterSnapshot {
-id: writer.id,
-writer_dc: writer.writer_dc,
-addr: writer.addr,
-generation: writer.generation,
-created_at: writer.created_at,
-draining_started_at_epoch_secs: writer
-.draining_started_at_epoch_secs
-.load(std::sync::atomic::Ordering::Relaxed),
-drain_deadline_epoch_secs: writer
-.drain_deadline_epoch_secs
-.load(std::sync::atomic::Ordering::Relaxed),
-allow_drain_fallback: writer
-.allow_drain_fallback
-.load(std::sync::atomic::Ordering::Relaxed),
-});
-}
-draining_writers
-};
-if draining_writers.is_empty() {
-return;
-}
-let draining_ids: Vec<u64> = draining_writers.iter().map(|writer| writer.id).collect();
-let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await;
-let mut non_empty_draining_writers =
-Vec::<DrainingWriterSnapshot>::with_capacity(draining_writers.len());
-for writer in draining_writers.drain(..) {
-if non_empty_writer_ids.contains(&writer.id) {
-non_empty_draining_writers.push(writer);
-} else {
-pool.remove_writer_and_close_clients(writer.id).await;
-}
-}
-draining_writers = non_empty_draining_writers;
-if draining_writers.is_empty() {
-return;
-}
+let activity = pool.registry.writer_activity_snapshot().await;
+let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
+let mut empty_writer_ids = Vec::<u64>::new();
+let mut force_close_writer_ids = Vec::<u64>::new();
+let writers = pool.writers.read().await;
+for writer in writers.iter() {
+if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
+continue;
+}
+if activity
+.bound_clients_by_writer
+.get(&writer.id)
+.copied()
+.unwrap_or(0)
+== 0
+{
+empty_writer_ids.push(writer.id);
+continue;
+}
+draining_writers.push(DrainingWriterSnapshot {
+id: writer.id,
+writer_dc: writer.writer_dc,
+addr: writer.addr,
+generation: writer.generation,
+created_at: writer.created_at,
+draining_started_at_epoch_secs: writer
+.draining_started_at_epoch_secs
+.load(std::sync::atomic::Ordering::Relaxed),
+drain_deadline_epoch_secs: writer
+.drain_deadline_epoch_secs
+.load(std::sync::atomic::Ordering::Relaxed),
+allow_drain_fallback: writer
+.allow_drain_fallback
+.load(std::sync::atomic::Ordering::Relaxed),
+});
+}
+drop(writers);
let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
draining_writers.len().saturating_sub(drain_threshold as usize)
} else {
0
};
-let has_deadline_expired = draining_writers.iter().any(|writer| {
-writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
-});
-let can_drop_with_replacement = if overflow > 0 || has_deadline_expired {
-pool.has_non_draining_writer_per_desired_dc_group().await
-} else {
-false
-};
if overflow > 0 {
-if can_drop_with_replacement {
draining_writers.sort_by(|left, right| {
left.draining_started_at_epoch_secs
.cmp(&right.draining_started_at_epoch_secs)
.then_with(|| left.created_at.cmp(&right.created_at))
.then_with(|| left.id.cmp(&right.id))
});
warn!(
draining_writers = draining_writers.len(),
me_pool_drain_threshold = drain_threshold,
removing_writers = overflow,
"ME draining writer threshold exceeded, force-closing oldest draining writers"
);
for writer in draining_writers.drain(..overflow) {
-pool.stats.increment_pool_force_close_total();
-pool.remove_writer_and_close_clients(writer.id).await;
+force_close_writer_ids.push(writer.id);
}
-} else {
-warn!(
-draining_writers = draining_writers.len(),
-me_pool_drain_threshold = drain_threshold,
-overflow,
-"ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers"
-);
-}
}
@@ -238,25 +210,71 @@ async fn reap_draining_writers(
}
if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
{
-if can_drop_with_replacement {
-warn!(writer_id = writer.id, "Drain timeout, force-closing");
-pool.stats.increment_pool_force_close_total();
-pool.remove_writer_and_close_clients(writer.id).await;
-} else if should_emit_writer_warn(
-warn_next_allowed,
-writer.id,
-now,
-pool.warn_rate_limit_duration(),
-) {
-warn!(
-writer_id = writer.id,
-writer_dc = writer.writer_dc,
-endpoint = %writer.addr,
-"Drain timeout reached, but replacement coverage is incomplete; keeping draining writer"
-);
-}
+warn!(writer_id = writer.id, "Drain timeout, force-closing");
+force_close_writer_ids.push(writer.id);
}
}
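// Apply the queued close work under a per-cycle budget: force-closes first, then empty draining
// writers; anything over budget rolls to the next health cycle.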
let close_budget = health_drain_close_budget();
let requested_force_close = force_close_writer_ids.len();
let requested_empty_close = empty_writer_ids.len();
let requested_close_total = requested_force_close.saturating_add(requested_empty_close);
let mut closed_writer_ids = HashSet::<u64>::new();
let mut closed_total = 0usize;
for writer_id in force_close_writer_ids {
if closed_total >= close_budget {
break;
}
if !closed_writer_ids.insert(writer_id) {
continue;
}
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer_id).await;
closed_total = closed_total.saturating_add(1);
}
for writer_id in empty_writer_ids {
if closed_total >= close_budget {
break;
}
if !closed_writer_ids.insert(writer_id) {
continue;
}
if !pool.remove_writer_if_empty(writer_id).await {
continue;
}
closed_total = closed_total.saturating_add(1);
}
let pending_close_total = requested_close_total.saturating_sub(closed_total);
if pending_close_total > 0 {
warn!(
close_budget,
closed_total,
pending_close_total,
"ME draining close backlog deferred to next health cycle"
);
}
// Keep warn cooldown state for draining writers still present in the pool;
// drop state only once a writer is actually removed.
let active_draining_writer_ids = {
let writers = pool.writers.read().await;
writers
.iter()
.filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed))
.map(|writer| writer.id)
.collect::<HashSet<u64>>()
};
warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id));
}
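// Per-cycle close budget: scales with available CPU cores, clamped to
// [HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX].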
pub(super) fn health_drain_close_budget() -> usize {
let cpu_cores = std::thread::available_parallelism()
.map(std::num::NonZeroUsize::get)
.unwrap_or(1);
cpu_cores
.saturating_mul(HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE)
.clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX)
}
#[derive(Debug, Clone)]
@@ -1521,7 +1539,6 @@ mod tests {
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
-pool.increment_draining_active_runtime();
assert!(
pool.registry
.bind_writer(
@@ -1570,7 +1587,6 @@ mod tests {
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
let pool = make_pool(2).await;
insert_live_writer(&pool, 1, 2).await;
-assert!(pool.has_non_draining_writer_per_desired_dc_group().await);
let now_epoch_secs = MePool::now_epoch_secs();
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await;
@@ -1588,7 +1604,7 @@ mod tests {
}
#[tokio::test]
-async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() {
+async fn reap_draining_writers_force_closes_overflow_without_replacement() {
let pool = make_pool(2).await;
let now_epoch_secs = MePool::now_epoch_secs();
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
@@ -1600,8 +1616,8 @@ mod tests {
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
writer_ids.sort_unstable();
-assert_eq!(writer_ids, vec![10, 20, 30]);
-assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10);
+assert_eq!(writer_ids, vec![20, 30]);
+assert!(pool.registry.get_writer(conn_a).await.is_none());
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
}

View File

@@ -0,0 +1,615 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::{health_drain_close_budget, reap_draining_writers};
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use super::me_health_monitor;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
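// Test pool factory: only the drain threshold and health intervals vary per test; every other knob
// keeps its GeneralConfig default.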
async fn make_pool(
me_pool_drain_threshold: u64,
me_health_interval_ms_unhealthy: u64,
me_health_interval_ms_healthy: u64,
) -> (Arc<MePool>, Arc<SecureRandom>) {
let general = GeneralConfig {
me_pool_drain_threshold,
me_health_interval_ms_unhealthy,
me_health_interval_ms_healthy,
..GeneralConfig::default()
};
let rng = Arc::new(SecureRandom::new());
let pool = MePool::new(
None,
vec![1u8; 32],
None,
false,
None,
Vec::new(),
1,
None,
12,
1200,
HashMap::new(),
HashMap::new(),
None,
NetworkDecision::default(),
None,
rng.clone(),
Arc::new(Stats::default()),
general.me_keepalive_enabled,
general.me_keepalive_interval_secs,
general.me_keepalive_jitter_secs,
general.me_keepalive_payload_random,
general.rpc_proxy_req_every,
general.me_warmup_stagger_enabled,
general.me_warmup_step_delay_ms,
general.me_warmup_step_jitter_ms,
general.me_reconnect_max_concurrent_per_dc,
general.me_reconnect_backoff_base_ms,
general.me_reconnect_backoff_cap_ms,
general.me_reconnect_fast_retry_count,
general.me_single_endpoint_shadow_writers,
general.me_single_endpoint_outage_mode_enabled,
general.me_single_endpoint_outage_disable_quarantine,
general.me_single_endpoint_outage_backoff_min_ms,
general.me_single_endpoint_outage_backoff_max_ms,
general.me_single_endpoint_shadow_rotate_every_secs,
general.me_floor_mode,
general.me_adaptive_floor_idle_secs,
general.me_adaptive_floor_min_writers_single_endpoint,
general.me_adaptive_floor_min_writers_multi_endpoint,
general.me_adaptive_floor_recover_grace_secs,
general.me_adaptive_floor_writers_per_core_total,
general.me_adaptive_floor_cpu_cores_override,
general.me_adaptive_floor_max_extra_writers_single_per_core,
general.me_adaptive_floor_max_extra_writers_multi_per_core,
general.me_adaptive_floor_max_active_writers_per_core,
general.me_adaptive_floor_max_warm_writers_per_core,
general.me_adaptive_floor_max_active_writers_global,
general.me_adaptive_floor_max_warm_writers_global,
general.hardswap,
general.me_pool_drain_ttl_secs,
general.me_pool_drain_threshold,
general.effective_me_pool_force_close_secs(),
general.me_pool_min_fresh_ratio,
general.me_hardswap_warmup_delay_min_ms,
general.me_hardswap_warmup_delay_max_ms,
general.me_hardswap_warmup_extra_passes,
general.me_hardswap_warmup_pass_backoff_base_ms,
general.me_bind_stale_mode,
general.me_bind_stale_ttl_secs,
general.me_secret_atomic_snapshot,
general.me_deterministic_writer_sort,
MeWriterPickMode::default(),
general.me_writer_pick_sample_size,
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,
general.me_reader_route_data_wait_ms,
general.me_health_interval_ms_unhealthy,
general.me_health_interval_ms_healthy,
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
);
(pool, rng)
}
async fn insert_draining_writer(
pool: &Arc<MePool>,
writer_id: u64,
drain_started_at_epoch_secs: u64,
bound_clients: usize,
drain_deadline_epoch_secs: u64,
) {
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
let writer = MeWriter {
id: writer_id,
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6000 + writer_id as u16),
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
writer_dc: 2,
generation: 1,
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
created_at: Instant::now() - Duration::from_secs(writer_id),
tx: tx.clone(),
cancel: CancellationToken::new(),
degraded: Arc::new(AtomicBool::new(false)),
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
draining: Arc::new(AtomicBool::new(true)),
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
};
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
for idx in 0..bound_clients {
let (conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
conn_id,
writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
8000 + idx as u16,
),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await
);
}
}
async fn writer_count(pool: &Arc<MePool>) -> usize {
pool.writers.read().await.len()
}
async fn sorted_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
let mut ids = pool
.writers
.read()
.await
.iter()
.map(|writer| writer.id)
.collect::<Vec<_>>();
ids.sort_unstable();
ids
}
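// Small deterministic LCG so the mixed-state churn below replays identically on every run.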
fn lcg_next(state: &mut u64) -> u64 {
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
*state
}
async fn draining_writer_ids(pool: &Arc<MePool>) -> HashSet<u64> {
pool.writers
.read()
.await
.iter()
.filter(|writer| writer.draining.load(Ordering::Relaxed))
.map(|writer| writer.id)
.collect::<HashSet<u64>>()
}
async fn set_writer_runtime_state(
pool: &Arc<MePool>,
writer_id: u64,
draining: bool,
drain_started_at_epoch_secs: u64,
drain_deadline_epoch_secs: u64,
) {
let writers = pool.writers.read().await;
if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
writer.draining.store(draining, Ordering::Relaxed);
writer
.draining_started_at_epoch_secs
.store(drain_started_at_epoch_secs, Ordering::Relaxed);
writer
.drain_deadline_epoch_secs
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
}
}
#[tokio::test]
async fn reap_draining_writers_clears_warn_state_when_pool_empty() {
let (pool, _rng) = make_pool(128, 1, 1).await;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(11, Instant::now() + Duration::from_secs(5));
warn_next_allowed.insert(22, Instant::now() + Duration::from_secs(5));
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.is_empty());
}
#[tokio::test]
async fn reap_draining_writers_respects_threshold_across_multiple_overflow_cycles() {
let threshold = 3u64;
let (pool, _rng) = make_pool(threshold, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=60u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(600).saturating_add(writer_id),
1,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..64 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
if writer_count(&pool).await <= threshold as usize {
break;
}
}
assert_eq!(writer_count(&pool).await, threshold as usize);
assert_eq!(sorted_writer_ids(&pool).await, vec![58, 59, 60]);
}
#[tokio::test]
async fn reap_draining_writers_handles_large_empty_writer_population() {
let (pool, _rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let total = health_drain_close_budget().saturating_mul(3).saturating_add(27);
for writer_id in 1..=total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(120),
0,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..24 {
if writer_count(&pool).await == 0 {
break;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
}
assert_eq!(writer_count(&pool).await, 0);
}
#[tokio::test]
async fn reap_draining_writers_processes_mass_deadline_expiry_without_unbounded_growth() {
let (pool, _rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let total = health_drain_close_budget().saturating_mul(4).saturating_add(31);
for writer_id in 1..=total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(180),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..40 {
if writer_count(&pool).await == 0 {
break;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
}
assert_eq!(writer_count(&pool).await, 0);
}
#[tokio::test]
async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_churn() {
let (pool, _rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let mut warn_next_allowed = HashMap::new();
for wave in 0..40u64 {
for offset in 0..8u64 {
insert_draining_writer(
&pool,
wave * 100 + offset,
now_epoch_secs.saturating_sub(400 + offset),
1,
0,
)
.await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.len() <= writer_count(&pool).await);
let ids = sorted_writer_ids(&pool).await;
for writer_id in ids.into_iter().take(3) {
let _ = pool.remove_writer_and_close_clients(writer_id).await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.len() <= writer_count(&pool).await);
}
}
#[tokio::test]
async fn reap_draining_writers_budgeted_cleanup_never_increases_pool_size() {
let (pool, _rng) = make_pool(5, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=200u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(240).saturating_add(writer_id),
1,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
let mut previous = writer_count(&pool).await;
for _ in 0..32 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
let current = writer_count(&pool).await;
assert!(current <= previous);
previous = current;
}
}
#[tokio::test]
async fn me_health_monitor_converges_to_threshold_under_live_injection_churn() {
let threshold = 7u64;
let (pool, rng) = make_pool(threshold, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=40u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
1,
0,
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
for wave in 0..8u64 {
for offset in 0..10u64 {
insert_draining_writer(
&pool,
1000 + wave * 100 + offset,
now_epoch_secs.saturating_sub(120).saturating_add(offset),
1,
0,
)
.await;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
tokio::time::sleep(Duration::from_millis(120)).await;
monitor.abort();
let _ = monitor.await;
assert!(writer_count(&pool).await <= threshold as usize);
}
#[tokio::test]
async fn me_health_monitor_drains_deadline_storm_with_budgeted_progress() {
let (pool, rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=220u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(120),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
tokio::time::sleep(Duration::from_millis(120)).await;
monitor.abort();
let _ = monitor.await;
assert_eq!(writer_count(&pool).await, 0);
}
#[tokio::test]
async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() {
let threshold = 12u64;
let (pool, rng) = make_pool(threshold, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=180u64 {
let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
let deadline = if writer_id % 2 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(250).saturating_add(writer_id),
bound_clients,
deadline,
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
tokio::time::sleep(Duration::from_millis(140)).await;
monitor.abort();
let _ = monitor.await;
assert!(writer_count(&pool).await <= threshold as usize);
}
#[tokio::test]
async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() {
let threshold = 9u64;
let (pool, _rng) = make_pool(threshold, 1, 1).await;
let mut warn_next_allowed = HashMap::new();
let mut seed = 0x9E37_79B9_7F4A_7C15u64;
let mut next_writer_id = 20_000u64;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=72u64 {
let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 };
let deadline = if writer_id % 5 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(500).saturating_add(writer_id),
bound_clients,
deadline,
)
.await;
}
for _round in 0..90 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
let draining_ids = draining_writer_ids(&pool).await;
assert!(
warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
"warn-state keys must always be a subset of live draining writers"
);
let writer_ids = sorted_writer_ids(&pool).await;
if writer_ids.is_empty() {
continue;
}
let remove_n = (lcg_next(&mut seed) % 3) as usize;
for writer_id in writer_ids.iter().copied().take(remove_n) {
let _ = pool.remove_writer_and_close_clients(writer_id).await;
}
let survivors = sorted_writer_ids(&pool).await;
if !survivors.is_empty() {
let idx = (lcg_next(&mut seed) as usize) % survivors.len();
let target = survivors[idx];
set_writer_runtime_state(&pool, target, false, 0, 0).await;
}
let survivors = sorted_writer_ids(&pool).await;
if survivors.len() > 1 {
let idx = (lcg_next(&mut seed) as usize) % survivors.len();
let target = survivors[idx];
let expired_deadline = if lcg_next(&mut seed) & 1 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
set_writer_runtime_state(
&pool,
target,
true,
now_epoch_secs.saturating_sub(120),
expired_deadline,
)
.await;
}
let inject_n = (lcg_next(&mut seed) % 4) as usize;
for _ in 0..inject_n {
let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 };
let deadline = if lcg_next(&mut seed) & 1 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
insert_draining_writer(
&pool,
next_writer_id,
now_epoch_secs.saturating_sub(240),
bound_clients,
deadline,
)
.await;
next_writer_id = next_writer_id.saturating_add(1);
}
}
for _ in 0..64 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
if writer_count(&pool).await <= threshold as usize {
break;
}
}
assert!(writer_count(&pool).await <= threshold as usize);
let draining_ids = draining_writer_ids(&pool).await;
assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id)));
}
#[tokio::test]
async fn reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() {
let (pool, _rng) = make_pool(64, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=24u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(240),
1,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _round in 0..48u64 {
for writer_id in 1..=24u64 {
let draining = (writer_id + _round) % 3 != 0;
set_writer_runtime_state(
&pool,
writer_id,
draining,
now_epoch_secs.saturating_sub(120),
0,
)
.await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
let draining_ids = draining_writer_ids(&pool).await;
assert!(
warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
"warn-state map must not retain entries for writers outside draining set"
);
}
}
#[test]
fn health_drain_close_budget_is_within_expected_bounds() {
let budget = health_drain_close_budget();
assert!((16..=256).contains(&budget));
}

View File

@@ -0,0 +1,241 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::health_drain_close_budget;
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use super::me_health_monitor;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
async fn make_pool(
me_pool_drain_threshold: u64,
me_health_interval_ms_unhealthy: u64,
me_health_interval_ms_healthy: u64,
) -> (Arc<MePool>, Arc<SecureRandom>) {
let general = GeneralConfig {
me_pool_drain_threshold,
me_health_interval_ms_unhealthy,
me_health_interval_ms_healthy,
..GeneralConfig::default()
};
let rng = Arc::new(SecureRandom::new());
let pool = MePool::new(
None,
vec![1u8; 32],
None,
false,
None,
Vec::new(),
1,
None,
12,
1200,
HashMap::new(),
HashMap::new(),
None,
NetworkDecision::default(),
None,
rng.clone(),
Arc::new(Stats::default()),
general.me_keepalive_enabled,
general.me_keepalive_interval_secs,
general.me_keepalive_jitter_secs,
general.me_keepalive_payload_random,
general.rpc_proxy_req_every,
general.me_warmup_stagger_enabled,
general.me_warmup_step_delay_ms,
general.me_warmup_step_jitter_ms,
general.me_reconnect_max_concurrent_per_dc,
general.me_reconnect_backoff_base_ms,
general.me_reconnect_backoff_cap_ms,
general.me_reconnect_fast_retry_count,
general.me_single_endpoint_shadow_writers,
general.me_single_endpoint_outage_mode_enabled,
general.me_single_endpoint_outage_disable_quarantine,
general.me_single_endpoint_outage_backoff_min_ms,
general.me_single_endpoint_outage_backoff_max_ms,
general.me_single_endpoint_shadow_rotate_every_secs,
general.me_floor_mode,
general.me_adaptive_floor_idle_secs,
general.me_adaptive_floor_min_writers_single_endpoint,
general.me_adaptive_floor_min_writers_multi_endpoint,
general.me_adaptive_floor_recover_grace_secs,
general.me_adaptive_floor_writers_per_core_total,
general.me_adaptive_floor_cpu_cores_override,
general.me_adaptive_floor_max_extra_writers_single_per_core,
general.me_adaptive_floor_max_extra_writers_multi_per_core,
general.me_adaptive_floor_max_active_writers_per_core,
general.me_adaptive_floor_max_warm_writers_per_core,
general.me_adaptive_floor_max_active_writers_global,
general.me_adaptive_floor_max_warm_writers_global,
general.hardswap,
general.me_pool_drain_ttl_secs,
general.me_pool_drain_threshold,
general.effective_me_pool_force_close_secs(),
general.me_pool_min_fresh_ratio,
general.me_hardswap_warmup_delay_min_ms,
general.me_hardswap_warmup_delay_max_ms,
general.me_hardswap_warmup_extra_passes,
general.me_hardswap_warmup_pass_backoff_base_ms,
general.me_bind_stale_mode,
general.me_bind_stale_ttl_secs,
general.me_secret_atomic_snapshot,
general.me_deterministic_writer_sort,
MeWriterPickMode::default(),
general.me_writer_pick_sample_size,
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,
general.me_reader_route_data_wait_ms,
general.me_health_interval_ms_unhealthy,
general.me_health_interval_ms_healthy,
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
);
(pool, rng)
}
async fn insert_draining_writer(
pool: &Arc<MePool>,
writer_id: u64,
drain_started_at_epoch_secs: u64,
bound_clients: usize,
drain_deadline_epoch_secs: u64,
) {
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
let writer = MeWriter {
id: writer_id,
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5500 + writer_id as u16),
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
writer_dc: 2,
generation: 1,
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
created_at: Instant::now() - Duration::from_secs(writer_id),
tx: tx.clone(),
cancel: CancellationToken::new(),
degraded: Arc::new(AtomicBool::new(false)),
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
draining: Arc::new(AtomicBool::new(true)),
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
};
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
for idx in 0..bound_clients {
let (conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
conn_id,
writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
7200 + idx as u16,
),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await
);
}
}
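/// Polls `pool.writers` until it is empty, failing the test if `timeout` elapses first.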
async fn wait_for_pool_empty(pool: &Arc<MePool>, timeout: Duration) {
let start = Instant::now();
loop {
if pool.writers.read().await.is_empty() {
return;
}
assert!(
start.elapsed() < timeout,
"timed out waiting for pool.writers to become empty"
);
tokio::time::sleep(Duration::from_millis(5)).await;
}
}
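// The tests below spawn the real me_health_monitor task and assert that a draining
// backlog is fully reaped within a bounded wait, even when it exceeds one tick's close budget.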
#[tokio::test]
async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() {
let (pool, rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let writer_total = health_drain_close_budget().saturating_mul(2).saturating_add(9);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(120),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
monitor.abort();
let _ = monitor.await;
assert!(pool.writers.read().await.is_empty());
}
#[tokio::test]
async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() {
let (pool, rng) = make_pool(128, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
for writer_id in 1..=24u64 {
insert_draining_writer(&pool, writer_id, now_epoch_secs.saturating_sub(60), 0, 0).await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
monitor.abort();
let _ = monitor.await;
assert!(pool.writers.read().await.is_empty());
}
#[tokio::test]
async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() {
let threshold = 4u64;
let (pool, rng) = make_pool(threshold, 1, 1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let writer_total = threshold as usize + health_drain_close_budget().saturating_add(11);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
1,
0,
)
.await;
}
let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0));
wait_for_pool_empty(&pool, Duration::from_secs(1)).await;
monitor.abort();
let _ = monitor.await;
assert!(pool.writers.read().await.is_empty());
}

View File

@ -0,0 +1,658 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use super::codec::WriterCommand;
use super::health::{health_drain_close_budget, reap_draining_writers};
use super::pool::{MePool, MeWriter, WriterContour};
use super::registry::ConnMeta;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::crypto::SecureRandom;
use crate::network::probe::NetworkDecision;
use crate::stats::Stats;
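/// Builds a minimal MePool where the drain threshold is the only overridden setting;
/// every other GeneralConfig field keeps its default value.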
async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
let general = GeneralConfig {
me_pool_drain_threshold,
..GeneralConfig::default()
};
MePool::new(
None,
vec![1u8; 32],
None,
false,
None,
Vec::new(),
1,
None,
12,
1200,
HashMap::new(),
HashMap::new(),
None,
NetworkDecision::default(),
None,
Arc::new(SecureRandom::new()),
Arc::new(Stats::default()),
general.me_keepalive_enabled,
general.me_keepalive_interval_secs,
general.me_keepalive_jitter_secs,
general.me_keepalive_payload_random,
general.rpc_proxy_req_every,
general.me_warmup_stagger_enabled,
general.me_warmup_step_delay_ms,
general.me_warmup_step_jitter_ms,
general.me_reconnect_max_concurrent_per_dc,
general.me_reconnect_backoff_base_ms,
general.me_reconnect_backoff_cap_ms,
general.me_reconnect_fast_retry_count,
general.me_single_endpoint_shadow_writers,
general.me_single_endpoint_outage_mode_enabled,
general.me_single_endpoint_outage_disable_quarantine,
general.me_single_endpoint_outage_backoff_min_ms,
general.me_single_endpoint_outage_backoff_max_ms,
general.me_single_endpoint_shadow_rotate_every_secs,
general.me_floor_mode,
general.me_adaptive_floor_idle_secs,
general.me_adaptive_floor_min_writers_single_endpoint,
general.me_adaptive_floor_min_writers_multi_endpoint,
general.me_adaptive_floor_recover_grace_secs,
general.me_adaptive_floor_writers_per_core_total,
general.me_adaptive_floor_cpu_cores_override,
general.me_adaptive_floor_max_extra_writers_single_per_core,
general.me_adaptive_floor_max_extra_writers_multi_per_core,
general.me_adaptive_floor_max_active_writers_per_core,
general.me_adaptive_floor_max_warm_writers_per_core,
general.me_adaptive_floor_max_active_writers_global,
general.me_adaptive_floor_max_warm_writers_global,
general.hardswap,
general.me_pool_drain_ttl_secs,
general.me_pool_drain_threshold,
general.effective_me_pool_force_close_secs(),
general.me_pool_min_fresh_ratio,
general.me_hardswap_warmup_delay_min_ms,
general.me_hardswap_warmup_delay_max_ms,
general.me_hardswap_warmup_extra_passes,
general.me_hardswap_warmup_pass_backoff_base_ms,
general.me_bind_stale_mode,
general.me_bind_stale_ttl_secs,
general.me_secret_atomic_snapshot,
general.me_deterministic_writer_sort,
MeWriterPickMode::default(),
general.me_writer_pick_sample_size,
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,
general.me_reader_route_data_wait_ms,
general.me_health_interval_ms_unhealthy,
general.me_health_interval_ms_healthy,
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
)
}
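/// Fabricates a draining writer, binds `bound_clients` placeholder sessions to it, and
/// returns their connection ids so tests can assert on registry state after a reap.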
async fn insert_draining_writer(
pool: &Arc<MePool>,
writer_id: u64,
drain_started_at_epoch_secs: u64,
bound_clients: usize,
drain_deadline_epoch_secs: u64,
) -> Vec<u64> {
let mut conn_ids = Vec::with_capacity(bound_clients);
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
let writer = MeWriter {
id: writer_id,
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4500 + writer_id as u16),
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
writer_dc: 2,
generation: 1,
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
created_at: Instant::now() - Duration::from_secs(writer_id),
tx: tx.clone(),
cancel: CancellationToken::new(),
degraded: Arc::new(AtomicBool::new(false)),
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
draining: Arc::new(AtomicBool::new(true)),
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(drain_started_at_epoch_secs)),
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(drain_deadline_epoch_secs)),
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
};
pool.writers.write().await.push(writer);
pool.registry.register_writer(writer_id, tx).await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
for idx in 0..bound_clients {
let (conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
conn_id,
writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
6200 + idx as u16,
),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await
);
conn_ids.push(conn_id);
}
conn_ids
}
async fn current_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
let mut writer_ids = pool
.writers
.read()
.await
.iter()
.map(|writer| writer.id)
.collect::<Vec<_>>();
writer_ids.sort_unstable();
writer_ids
}
async fn writer_exists(pool: &Arc<MePool>, writer_id: u64) -> bool {
pool.writers
.read()
.await
.iter()
.any(|writer| writer.id == writer_id)
}
async fn set_writer_draining(pool: &Arc<MePool>, writer_id: u64, draining: bool) {
let writers = pool.writers.read().await;
if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
writer.draining.store(draining, Ordering::Relaxed);
}
}
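// The tests below call reap_draining_writers directly and cover empty-writer cleanup,
// deadline force-closes, the per-tick close budget, threshold convergence, and
// warn-cooldown bookkeeping under churn.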
#[tokio::test]
async fn reap_draining_writers_drops_warn_state_for_removed_writer() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let conn_ids =
insert_draining_writer(&pool, 7, now_epoch_secs.saturating_sub(180), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.contains_key(&7));
let _ = pool.remove_writer_and_close_clients(7).await;
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(!warn_next_allowed.contains_key(&7));
}
#[tokio::test]
async fn reap_draining_writers_removes_empty_draining_writers() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(40), 0, 0).await;
insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(30), 0, 0).await;
insert_draining_writer(&pool, 3, now_epoch_secs.saturating_sub(20), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(current_writer_ids(&pool).await, vec![3]);
}
#[tokio::test]
async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() {
let pool = make_pool(2).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 11, now_epoch_secs.saturating_sub(40), 1, 0).await;
insert_draining_writer(&pool, 22, now_epoch_secs.saturating_sub(30), 1, 0).await;
insert_draining_writer(&pool, 33, now_epoch_secs.saturating_sub(20), 1, 0).await;
insert_draining_writer(&pool, 44, now_epoch_secs.saturating_sub(10), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(current_writer_ids(&pool).await, vec![33, 44]);
}
#[tokio::test]
async fn reap_draining_writers_deadline_force_close_applies_under_threshold() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(
&pool,
50,
now_epoch_secs.saturating_sub(15),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(current_writer_ids(&pool).await.is_empty());
}
#[tokio::test]
async fn reap_draining_writers_limits_closes_per_health_tick() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_add(19);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(20),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(pool.writers.read().await.len(), writer_total - close_budget);
}
#[tokio::test]
async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() {
let pool = make_pool(0).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_add(5);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(60),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let target_writer_id = writer_total as u64;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(
target_writer_id,
Instant::now() + Duration::from_secs(300),
);
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, target_writer_id).await);
assert!(warn_next_allowed.contains_key(&target_writer_id));
}
#[tokio::test]
async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() {
let pool = make_pool(1).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_add(6);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
1,
0,
)
.await;
}
let target_writer_id = writer_total.saturating_sub(1) as u64;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(
target_writer_id,
Instant::now() + Duration::from_secs(300),
);
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, target_writer_id).await);
assert!(warn_next_allowed.contains_key(&target_writer_id));
}
#[tokio::test]
async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300));
set_writer_draining(&pool, 71, false).await;
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, 71).await);
assert!(
!warn_next_allowed.contains_key(&71),
"warn cooldown state must be dropped after writer leaves draining state"
);
}
#[tokio::test]
async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() {
let pool = make_pool(0).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_mul(2).saturating_add(1);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(120),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let tail_writer_id = writer_total as u64;
let mut warn_next_allowed = HashMap::new();
warn_next_allowed.insert(
tail_writer_id,
Instant::now() + Duration::from_secs(300),
);
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, tail_writer_id).await);
assert!(warn_next_allowed.contains_key(&tail_writer_id));
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(writer_exists(&pool, tail_writer_id).await);
assert!(warn_next_allowed.contains_key(&tail_writer_id));
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(!writer_exists(&pool, tail_writer_id).await);
assert!(
!warn_next_allowed.contains_key(&tail_writer_id),
"warn cooldown state must clear once writer is actually removed"
);
}
#[tokio::test]
async fn reap_draining_writers_backlog_drains_across_ticks() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = close_budget.saturating_mul(2).saturating_add(7);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(20),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..8 {
if pool.writers.read().await.is_empty() {
break;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
}
assert!(pool.writers.read().await.is_empty());
}
#[tokio::test]
async fn reap_draining_writers_threshold_backlog_converges_to_threshold() {
let threshold = 5u64;
let pool = make_pool(threshold).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
let writer_total = threshold as usize + close_budget.saturating_add(12);
for writer_id in 1..=writer_total as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(200).saturating_add(writer_id),
1,
0,
)
.await;
}
let mut warn_next_allowed = HashMap::new();
for _ in 0..16 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
if pool.writers.read().await.len() <= threshold as usize {
break;
}
}
assert_eq!(pool.writers.read().await.len(), threshold as usize);
}
#[tokio::test]
async fn reap_draining_writers_threshold_zero_preserves_non_expired_non_empty_writers() {
let pool = make_pool(0).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(40), 1, 0).await;
insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(30), 1, 0).await;
insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(20), 1, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(current_writer_ids(&pool).await, vec![10, 20, 30]);
}
#[tokio::test]
async fn reap_draining_writers_prioritizes_force_close_before_empty_cleanup() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let close_budget = health_drain_close_budget();
for writer_id in 1..=close_budget as u64 {
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(20),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
}
let empty_writer_id = close_budget as u64 + 1;
insert_draining_writer(&pool, empty_writer_id, now_epoch_secs.saturating_sub(20), 0, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert_eq!(current_writer_ids(&pool).await, vec![empty_writer_id]);
}
#[tokio::test]
async fn reap_draining_writers_empty_cleanup_does_not_increment_force_close_metric() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(60), 0, 0).await;
insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(50), 0, 0).await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(current_writer_ids(&pool).await.is_empty());
assert_eq!(pool.stats.get_pool_force_close_total(), 0);
}
#[tokio::test]
async fn reap_draining_writers_handles_duplicate_force_close_requests_for_same_writer() {
let pool = make_pool(1).await;
let now_epoch_secs = MePool::now_epoch_secs();
insert_draining_writer(
&pool,
10,
now_epoch_secs.saturating_sub(30),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
insert_draining_writer(
&pool,
20,
now_epoch_secs.saturating_sub(20),
1,
now_epoch_secs.saturating_sub(1),
)
.await;
let mut warn_next_allowed = HashMap::new();
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(current_writer_ids(&pool).await.is_empty());
}
#[tokio::test]
async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population_under_churn() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let mut warn_next_allowed = HashMap::new();
for wave in 0..12u64 {
for offset in 0..9u64 {
insert_draining_writer(
&pool,
wave * 100 + offset,
now_epoch_secs.saturating_sub(120 + offset),
1,
0,
)
.await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
let existing_writer_ids = current_writer_ids(&pool).await;
for writer_id in existing_writer_ids.into_iter().take(4) {
let _ = pool.remove_writer_and_close_clients(writer_id).await;
}
reap_draining_writers(&pool, &mut warn_next_allowed).await;
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
}
}
#[tokio::test]
async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_state() {
let pool = make_pool(6).await;
let now_epoch_secs = MePool::now_epoch_secs();
let mut warn_next_allowed = HashMap::new();
for writer_id in 1..=18u64 {
let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 };
let deadline = if writer_id % 2 == 0 {
now_epoch_secs.saturating_sub(1)
} else {
0
};
insert_draining_writer(
&pool,
writer_id,
now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
bound_clients,
deadline,
)
.await;
}
for _ in 0..16 {
reap_draining_writers(&pool, &mut warn_next_allowed).await;
if pool.writers.read().await.len() <= 6 {
break;
}
}
assert!(pool.writers.read().await.len() <= 6);
assert!(warn_next_allowed.len() <= pool.writers.read().await.len());
}
#[test]
fn general_config_default_drain_threshold_remains_enabled() {
assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128);
}
#[tokio::test]
async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let empty_writer_id = 700u64;
insert_draining_writer(
&pool,
empty_writer_id,
now_epoch_secs.saturating_sub(60),
0,
0,
)
.await;
let stale_empty_snapshot = vec![empty_writer_id];
let (rebound_conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
rebound_conn_id,
empty_writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await,
"writer should accept a new bind after stale empty snapshot"
);
for writer_id in stale_empty_snapshot {
assert!(
!pool.remove_writer_if_empty(writer_id).await,
"atomic empty cleanup must reject writers that gained bound clients"
);
}
assert!(
writer_exists(&pool, empty_writer_id).await,
"empty-path cleanup must not remove a writer that gained a bound client"
);
assert_eq!(
pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id),
Some(empty_writer_id)
);
let _ = pool.registry.unregister(rebound_conn_id).await;
}
#[tokio::test]
async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await;
pool.prune_closed_writers().await;
assert!(!writer_exists(&pool, 910).await);
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
}

View File

@ -21,6 +21,12 @@ mod secret;
mod selftest;
mod wire;
mod pool_status;
#[cfg(test)]
mod health_regression_tests;
#[cfg(test)]
mod health_integration_tests;
#[cfg(test)]
mod health_adversarial_tests;
use bytes::Bytes;

View File

@ -42,11 +42,10 @@ impl MePool {
}
for writer_id in closed_writer_ids {
- if self.registry.is_writer_empty(writer_id).await {
- let _ = self.remove_writer_only(writer_id).await;
- } else {
- let _ = self.remove_writer_and_close_clients(writer_id).await;
+ if self.remove_writer_if_empty(writer_id).await {
+ continue;
}
+ let _ = self.remove_writer_and_close_clients(writer_id).await;
}
}
@ -501,6 +500,17 @@ impl MePool {
}
}
pub(crate) async fn remove_writer_if_empty(self: &Arc<Self>, writer_id: u64) -> bool {
if !self.registry.unregister_writer_if_empty(writer_id).await {
return false;
}
// The registry empty-check and unregister are atomic with respect to binds,
// so remove_writer_only cannot return active bound sessions here.
let _ = self.remove_writer_only(writer_id).await;
true
}
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
let mut removed_addr: Option<SocketAddr> = None;

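The refactor above replaces a check-then-remove sequence (`is_writer_empty` followed by `remove_writer_only`) with a single `remove_writer_if_empty` call that holds the registry write lock across both the emptiness check and the unregister. A minimal sketch of why that matters follows; `Registry`, `bind`, `is_empty`, `remove`, and `remove_if_empty` are simplified stand-ins invented for illustration, not the actual `ConnRegistry` API, and it assumes tokio with the `macros` and `rt-multi-thread` features.

```rust
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;

/// Illustrative stand-in for the registry's writer -> bound-connection index.
#[derive(Default)]
struct Registry {
    conns_for_writer: RwLock<HashMap<u64, HashSet<u64>>>,
}

impl Registry {
    async fn bind(&self, writer_id: u64, conn_id: u64) {
        self.conns_for_writer
            .write()
            .await
            .entry(writer_id)
            .or_default()
            .insert(conn_id);
    }

    /// Old shape: emptiness check and removal take the lock separately,
    /// so a concurrent `bind` can land in between and be silently dropped.
    async fn is_empty(&self, writer_id: u64) -> bool {
        self.conns_for_writer
            .read()
            .await
            .get(&writer_id)
            .map(|conns| conns.is_empty())
            .unwrap_or(true)
    }

    async fn remove(&self, writer_id: u64) {
        self.conns_for_writer.write().await.remove(&writer_id);
    }

    /// New shape: one write guard covers both the check and the removal,
    /// so the writer is only dropped if it is still empty at that instant.
    async fn remove_if_empty(&self, writer_id: u64) -> bool {
        let mut map = self.conns_for_writer.write().await;
        let empty = map.get(&writer_id).map(|conns| conns.is_empty()).unwrap_or(true);
        if empty {
            map.remove(&writer_id);
        }
        empty
    }
}

#[tokio::main]
async fn main() {
    let reg = Arc::new(Registry::default());

    // Racy variant: the writer looks empty, a bind sneaks in, and remove() then
    // discards the freshly bound connection along with the writer entry.
    reg.conns_for_writer.write().await.insert(7, HashSet::new());
    if reg.is_empty(7).await {
        reg.bind(7, 1).await; // simulates a bind_writer arriving between check and removal
        reg.remove(7).await;
    }
    assert!(reg.conns_for_writer.read().await.get(&7).is_none());

    // Atomic variant: once a connection is bound, removal is refused.
    reg.conns_for_writer.write().await.insert(8, HashSet::new());
    reg.bind(8, 2).await;
    assert!(!reg.remove_if_empty(8).await);
    assert!(reg.conns_for_writer.read().await.get(&8).is_some());
}
```

In the real code the same guarantee comes from `unregister_writer_if_empty` in the registry hunk below, which is why the in-code comment can state that `remove_writer_only` no longer observes freshly bound sessions on this path.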
View File

@ -437,6 +437,23 @@ impl ConnRegistry {
.unwrap_or(true)
}
pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
let mut inner = self.inner.write().await;
let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else {
// Writer is already absent from the registry.
return true;
};
if !conn_ids.is_empty() {
return false;
}
inner.writers.remove(&writer_id);
inner.last_meta_for_writer.remove(&writer_id);
inner.writer_idle_since_epoch_secs.remove(&writer_id);
inner.conns_for_writer.remove(&writer_id);
true
}
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
let inner = self.inner.read().await;
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());