Refactor health management: implement remove_writer_if_empty method for cleaner writer removal logic and update related functions to enhance efficiency in handling closed writers.

This commit is contained in:
David Osipov 2026-03-17 21:38:15 +04:00
parent 60953bcc2c
commit f0c37f233e
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
8 changed files with 822 additions and 21 deletions

View File

@ -317,6 +317,24 @@ fn decode_user_secrets(
secrets
}
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
if config.censorship.server_hello_delay_max_ms == 0 {
return;
}
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms.max(min);
let delay_ms = if max == min {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
}
/// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
@ -368,11 +386,13 @@ where
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer };
}
@ -388,6 +408,7 @@ where
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
@ -402,13 +423,17 @@ where
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient { reader, writer },
None => {
maybe_apply_server_hello_delay(config).await;
return HandshakeResult::BadClient { reader, writer };
}
};
let cached = if config.censorship.tls_emulation {
@ -448,6 +473,7 @@ where
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else {
@ -480,19 +506,9 @@ where
)
};
// Optional anti-fingerprint delay before sending ServerHello.
if config.censorship.server_hello_delay_max_ms > 0 {
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms.max(min);
let delay_ms = if max == min {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
}
// Apply the same optional delay budget used by reject paths to reduce
// distinguishability between success and fail-closed handshakes.
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@ -539,6 +555,7 @@ where
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
@ -609,6 +626,7 @@ where
// authentication check first to avoid poisoning the replay cache.
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
}
@ -645,6 +663,7 @@ where
}
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient { reader, writer }
}

View File

@ -580,6 +580,72 @@ async fn malformed_tls_classes_complete_within_bounded_time() {
}
}
#[tokio::test]
async fn tls_invalid_hmac_respects_configured_anti_fingerprint_delay() {
let secret = [0x5Au8; 16];
let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a");
config.censorship.server_hello_delay_min_ms = 20;
config.censorship.server_hello_delay_max_ms = 20;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.32:44331".parse().unwrap();
let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01;
let started = Instant::now();
let result = handle_tls_handshake(
&bad_hmac,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert!(
started.elapsed() >= Duration::from_millis(18),
"configured anti-fingerprint delay must apply to invalid TLS handshakes"
);
}
#[tokio::test]
async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() {
let secret = [0x6Bu8; 16];
let mut config = test_config_with_secret_hex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b");
config.censorship.alpn_enforce = true;
config.censorship.server_hello_delay_min_ms = 20;
config.censorship.server_hello_delay_max_ms = 20;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.33:44332".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let started = Instant::now();
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert!(
started.elapsed() >= Duration::from_millis(18),
"configured anti-fingerprint delay must apply to ALPN-mismatch rejects"
);
}
#[tokio::test]
#[ignore = "timing-sensitive; run manually on low-jitter hosts"]
async fn malformed_tls_classes_share_close_latency_buckets() {
@ -643,6 +709,82 @@ async fn malformed_tls_classes_share_close_latency_buckets() {
);
}
#[tokio::test]
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
async fn timing_matrix_tls_classes_under_fixed_delay_budget() {
const ITER: usize = 48;
const BUCKET_MS: u128 = 10;
let secret = [0x77u8; 16];
let mut config = test_config_with_secret_hex("77777777777777777777777777777777");
config.censorship.alpn_enforce = true;
config.censorship.server_hello_delay_min_ms = 20;
config.censorship.server_hello_delay_max_ms = 20;
let rng = SecureRandom::new();
let base_ip = std::net::Ipv4Addr::new(198, 51, 100, 34);
let too_short = vec![0x16, 0x03, 0x01];
let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01;
let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let valid_h2 = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2"]);
let classes = vec![
("too_short", too_short),
("bad_hmac", bad_hmac),
("alpn_mismatch", alpn_mismatch),
("valid_h2", valid_h2),
];
for (class, probe) in classes {
let mut samples_ms = Vec::with_capacity(ITER);
for idx in 0..ITER {
clear_auth_probe_state_for_testing();
let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60));
let peer: SocketAddr = SocketAddr::from((base_ip, 44_000 + idx as u16));
let started = Instant::now();
let result = handle_tls_handshake(
&probe,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let elapsed = started.elapsed();
samples_ms.push(elapsed.as_millis());
if class == "valid_h2" {
assert!(matches!(result, HandshakeResult::Success(_)));
} else {
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
}
samples_ms.sort_unstable();
let sum: u128 = samples_ms.iter().copied().sum();
let mean = sum as f64 / samples_ms.len() as f64;
let min = samples_ms[0];
let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
let max = samples_ms[samples_ms.len() - 1];
println!(
"TIMING_MATRIX tls class={} mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
class,
mean,
min,
p95,
max,
(mean as u128) / BUCKET_MS
);
}
}
#[test]
fn secure_tag_requires_tls_mode_on_tls_transport() {
let mut config = ProxyConfig::default();

View File

@ -7,7 +7,7 @@ use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tokio::time::{Instant, timeout};
use tracing::debug;
use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
@ -49,6 +49,20 @@ where
}
}
async fn wait_mask_connect_budget(started: Instant) {
let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT {
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
}
}
async fn wait_mask_outcome_budget(started: Instant) {
let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT {
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
}
}
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
@ -107,6 +121,8 @@ where
// Connect via Unix socket or TCP
#[cfg(unix)]
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
let outcome_started = Instant::now();
let connect_started = Instant::now();
debug!(
client_type = client_type,
sock = %sock_path,
@ -143,14 +159,18 @@ where
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
debug!("Mask relay timed out (unix socket)");
}
wait_mask_outcome_budget(outcome_started).await;
}
Ok(Err(e)) => {
wait_mask_connect_budget(connect_started).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
}
return;
@ -172,6 +192,8 @@ where
let mask_addr = resolve_socket_addr(mask_host, mask_port)
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let outcome_started = Instant::now();
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
Ok(Ok(stream)) => {
@ -202,14 +224,18 @@ where
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
debug!("Mask relay timed out");
}
wait_mask_outcome_budget(outcome_started).await;
}
Ok(Err(e)) => {
wait_mask_connect_budget(connect_started).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started).await;
}
}
}

View File

@ -8,7 +8,7 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader};
use tokio::net::TcpListener;
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio::time::{sleep, timeout, Duration};
use tokio::time::{Instant, sleep, timeout, Duration};
#[tokio::test]
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
@ -216,6 +216,372 @@ async fn backend_unavailable_falls_back_to_silent_consume() {
assert_eq!(n, 0);
}
#[tokio::test]
async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n";
// Keep reader open so fallback path does not terminate immediately on EOF.
let (_client_reader_side, client_reader) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
timeout(Duration::from_millis(35), task)
.await
.expect_err("masking fallback must not complete before connect budget elapses");
assert!(
started.elapsed() >= Duration::from_millis(35),
"fallback path must absorb immediate connect refusal into connect budget"
);
}
#[tokio::test]
async fn backend_reachable_fast_response_waits_mask_outcome_budget() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
assert!(
started.elapsed() >= Duration::from_millis(45),
"reachable mask path must also satisfy coarse outcome budget"
);
accept_task.await.unwrap();
}
#[tokio::test]
async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = false;
let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
b"x",
peer,
local_addr,
&config,
&beobachten,
)
.await;
assert!(
started.elapsed() < Duration::from_millis(20),
"mask-disabled fallback should keep immediate EOF behavior"
);
}
#[tokio::test]
async fn backend_reachable_slow_response_not_padded_twice() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
sleep(Duration::from_millis(90)).await;
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(elapsed >= Duration::from_millis(85));
assert!(
elapsed < Duration::from_millis(170),
"slow reachable backend should not incur an extra full budget after already exceeding it"
);
accept_task.await.unwrap();
}
#[tokio::test]
async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() {
const ITER: usize = 20;
const BUCKET_MS: u128 = 10;
let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n";
let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let mut refused = Vec::with_capacity(ITER);
for _ in 0..ITER {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
refused.push(started.elapsed().as_millis());
}
let mut reachable = Vec::with_capacity(ITER);
for _ in 0..ITER {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe_vec = probe.to_vec();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe_vec.len()];
stream.read_exact(&mut received).await.unwrap();
stream.write_all(&backend_reply).await.unwrap();
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
reachable.push(started.elapsed().as_millis());
accept_task.await.unwrap();
}
let refused_mean = refused.iter().copied().sum::<u128>() as f64 / refused.len() as f64;
let reachable_mean = reachable.iter().copied().sum::<u128>() as f64 / reachable.len() as f64;
let refused_bucket = (refused_mean as u128) / BUCKET_MS;
let reachable_bucket = (reachable_mean as u128) / BUCKET_MS;
assert!(
refused_bucket.abs_diff(reachable_bucket) <= 1,
"enabled refused and reachable paths must collapse into the same coarse latency bucket"
);
}
#[tokio::test]
async fn light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() {
let mut seed: u64 = 0xA5A5_5A5A_1337_4242;
let mut next = || {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
seed
};
let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
for _ in 0..40 {
let probe_len = (next() as usize % 96).saturating_add(8);
let mut probe = vec![0u8; probe_len];
for byte in &mut probe {
*byte = next() as u8;
}
let use_reachable = (next() & 1) == 0;
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(512);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(512);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
if use_reachable {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
let probe_vec = probe.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; probe_vec.len()];
stream.read_exact(&mut observed).await.unwrap();
});
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
accept_task.await.unwrap();
} else {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
}
assert!(
started.elapsed() >= Duration::from_millis(45),
"mask-enabled fallback must preserve coarse timing budget under varied probe shapes"
);
}
}
#[tokio::test]
async fn mask_disabled_consumes_client_data_without_response() {
let mut config = ProxyConfig::default();
@ -729,3 +1095,158 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
assert!(mask_reader_dropped.load(Ordering::SeqCst));
assert!(mask_writer_dropped.load(Ordering::SeqCst));
}
#[tokio::test]
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
async fn timing_matrix_masking_classes_under_controlled_inputs() {
const ITER: usize = 24;
const BUCKET_MS: u128 = 10;
let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n";
let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
// Class 1: masking disabled with immediate EOF (fast fail-closed consume path).
let mut disabled_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = false;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
disabled_samples.push(started.elapsed().as_millis());
}
// Class 2: masking enabled, backend connect refused.
let mut refused_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
refused_samples.push(started.elapsed().as_millis());
}
// Class 3: masking enabled, backend reachable and immediately responds.
let mut reachable_samples = Vec::with_capacity(ITER);
for _ in 0..ITER {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let probe_vec = probe.to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe_vec.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe_vec);
stream.write_all(&backend_reply).await.unwrap();
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let (client_writer_side, client_reader) = duplex(256);
drop(client_writer_side);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
reachable_samples.push(started.elapsed().as_millis());
accept_task.await.unwrap();
}
fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) {
samples_ms.sort_unstable();
let sum: u128 = samples_ms.iter().copied().sum();
let mean = sum as f64 / samples_ms.len() as f64;
let min = samples_ms[0];
let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
let max = samples_ms[samples_ms.len() - 1];
(mean, min, p95, max)
}
let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples);
let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples);
let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples);
println!(
"TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
disabled_mean,
disabled_min,
disabled_p95,
disabled_max,
(disabled_mean as u128) / BUCKET_MS
);
println!(
"TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
refused_mean,
refused_min,
refused_p95,
refused_max,
(refused_mean as u128) / BUCKET_MS
);
println!(
"TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
reachable_mean,
reachable_min,
reachable_p95,
reachable_max,
(reachable_mean as u128) / BUCKET_MS
);
}

View File

@ -239,7 +239,9 @@ pub(super) async fn reap_draining_writers(
if !closed_writer_ids.insert(writer_id) {
continue;
}
pool.remove_writer_and_close_clients(writer_id).await;
if !pool.remove_writer_if_empty(writer_id).await {
continue;
}
closed_total = closed_total.saturating_add(1);
}

View File

@ -592,3 +592,67 @@ async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_stat
fn general_config_default_drain_threshold_remains_enabled() {
assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128);
}
#[tokio::test]
async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let empty_writer_id = 700u64;
insert_draining_writer(
&pool,
empty_writer_id,
now_epoch_secs.saturating_sub(60),
0,
0,
)
.await;
let stale_empty_snapshot = vec![empty_writer_id];
let (rebound_conn_id, _rx) = pool.registry.register().await;
assert!(
pool.registry
.bind_writer(
rebound_conn_id,
empty_writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await,
"writer should accept a new bind after stale empty snapshot"
);
for writer_id in stale_empty_snapshot {
assert!(
!pool.remove_writer_if_empty(writer_id).await,
"atomic empty cleanup must reject writers that gained bound clients"
);
}
assert!(
writer_exists(&pool, empty_writer_id).await,
"empty-path cleanup must not remove a writer that gained a bound client"
);
assert_eq!(
pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id),
Some(empty_writer_id)
);
let _ = pool.registry.unregister(rebound_conn_id).await;
}
#[tokio::test]
async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await;
pool.prune_closed_writers().await;
assert!(!writer_exists(&pool, 910).await);
assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
}

View File

@ -42,11 +42,10 @@ impl MePool {
}
for writer_id in closed_writer_ids {
if self.registry.is_writer_empty(writer_id).await {
let _ = self.remove_writer_only(writer_id).await;
} else {
let _ = self.remove_writer_and_close_clients(writer_id).await;
if self.remove_writer_if_empty(writer_id).await {
continue;
}
let _ = self.remove_writer_and_close_clients(writer_id).await;
}
}
@ -501,6 +500,17 @@ impl MePool {
}
}
pub(crate) async fn remove_writer_if_empty(self: &Arc<Self>, writer_id: u64) -> bool {
if !self.registry.unregister_writer_if_empty(writer_id).await {
return false;
}
// The registry empty-check and unregister are atomic with respect to binds,
// so remove_writer_only cannot return active bound sessions here.
let _ = self.remove_writer_only(writer_id).await;
true
}
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
let mut removed_addr: Option<SocketAddr> = None;

View File

@ -437,6 +437,23 @@ impl ConnRegistry {
.unwrap_or(true)
}
pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
let mut inner = self.inner.write().await;
let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else {
// Writer is already absent from the registry.
return true;
};
if !conn_ids.is_empty() {
return false;
}
inner.writers.remove(&writer_id);
inner.last_meta_for_writer.remove(&writer_id);
inner.writer_idle_since_epoch_secs.remove(&writer_id);
inner.conns_for_writer.remove(&writer_id);
true
}
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
let inner = self.inner.read().await;
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());