diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs
index dbd50d5..a1b3eb7 100644
--- a/src/proxy/handshake.rs
+++ b/src/proxy/handshake.rs
@@ -317,6 +317,24 @@ fn decode_user_secrets(
     secrets
 }
 
+async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
+    if config.censorship.server_hello_delay_max_ms == 0 {
+        return;
+    }
+
+    let min = config.censorship.server_hello_delay_min_ms;
+    let max = config.censorship.server_hello_delay_max_ms.max(min);
+    let delay_ms = if max == min {
+        max
+    } else {
+        rand::rng().random_range(min..=max)
+    };
+
+    if delay_ms > 0 {
+        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
+    }
+}
+
 /// Result of successful handshake
 ///
 /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
@@ -368,11 +386,13 @@ where
     debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
 
     if auth_probe_is_throttled(peer.ip(), Instant::now()) {
+        maybe_apply_server_hello_delay(config).await;
         debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
         return HandshakeResult::BadClient { reader, writer };
     }
 
     if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
+        maybe_apply_server_hello_delay(config).await;
         debug!(peer = %peer, "TLS handshake too short");
         return HandshakeResult::BadClient { reader, writer };
     }
@@ -388,6 +408,7 @@ where
         Some(v) => v,
         None => {
             auth_probe_record_failure(peer.ip(), Instant::now());
+            maybe_apply_server_hello_delay(config).await;
             debug!(
                 peer = %peer,
                 ignore_time_skew = config.access.ignore_time_skew,
@@ -402,13 +423,17 @@ where
     let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
     if replay_checker.check_and_add_tls_digest(digest_half) {
         auth_probe_record_failure(peer.ip(), Instant::now());
+        maybe_apply_server_hello_delay(config).await;
         warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
         return HandshakeResult::BadClient { reader, writer };
     }
 
     let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
         Some((_, s)) => s,
-        None => return HandshakeResult::BadClient { reader, writer },
+        None => {
+            maybe_apply_server_hello_delay(config).await;
+            return HandshakeResult::BadClient { reader, writer };
+        }
     };
 
     let cached = if config.censorship.tls_emulation {
@@ -448,6 +473,7 @@ where
     } else if alpn_list.iter().any(|p| p == b"http/1.1") {
         Some(b"http/1.1".to_vec())
     } else if !alpn_list.is_empty() {
+        maybe_apply_server_hello_delay(config).await;
         debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
         return HandshakeResult::BadClient { reader, writer };
     } else {
@@ -480,19 +506,9 @@ where
         )
     };
 
-    // Optional anti-fingerprint delay before sending ServerHello.
-    if config.censorship.server_hello_delay_max_ms > 0 {
-        let min = config.censorship.server_hello_delay_min_ms;
-        let max = config.censorship.server_hello_delay_max_ms.max(min);
-        let delay_ms = if max == min {
-            max
-        } else {
-            rand::rng().random_range(min..=max)
-        };
-        if delay_ms > 0 {
-            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
-        }
-    }
+    // Apply the same optional delay budget used by reject paths to reduce
+    // distinguishability between success and fail-closed handshakes.
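+    // The helper returns immediately when server_hello_delay_max_ms is 0,
+    // so default configurations pay no extra latency on this path.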
+    maybe_apply_server_hello_delay(config).await;
 
     debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@@ -539,6 +555,7 @@ where
     trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
 
     if auth_probe_is_throttled(peer.ip(), Instant::now()) {
+        maybe_apply_server_hello_delay(config).await;
         debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
         return HandshakeResult::BadClient { reader, writer };
     }
@@ -609,6 +626,7 @@ where
     // authentication check first to avoid poisoning the replay cache.
     if replay_checker.check_and_add_handshake(dec_prekey_iv) {
         auth_probe_record_failure(peer.ip(), Instant::now());
+        maybe_apply_server_hello_delay(config).await;
         warn!(peer = %peer, user = %user, "MTProto replay attack detected");
         return HandshakeResult::BadClient { reader, writer };
    }
@@ -645,6 +663,7 @@ where
     }
 
     auth_probe_record_failure(peer.ip(), Instant::now());
+    maybe_apply_server_hello_delay(config).await;
     debug!(peer = %peer, "MTProto handshake: no matching user found");
     HandshakeResult::BadClient { reader, writer }
 }
diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs
index 6bdc345..7040025 100644
--- a/src/proxy/handshake_security_tests.rs
+++ b/src/proxy/handshake_security_tests.rs
@@ -580,6 +580,72 @@ async fn malformed_tls_classes_complete_within_bounded_time() {
     }
 }
 
+#[tokio::test]
+async fn tls_invalid_hmac_respects_configured_anti_fingerprint_delay() {
+    let secret = [0x5Au8; 16];
+    let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a");
+    config.censorship.server_hello_delay_min_ms = 20;
+    config.censorship.server_hello_delay_max_ms = 20;
+
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.32:44331".parse().unwrap();
+    let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
+    bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01;
+
+    let started = Instant::now();
+    let result = handle_tls_handshake(
+        &bad_hmac,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::BadClient { .. }));
+    assert!(
+        started.elapsed() >= Duration::from_millis(18),
+        "configured anti-fingerprint delay must apply to invalid TLS handshakes"
+    );
+}
+
+#[tokio::test]
+async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() {
+    let secret = [0x6Bu8; 16];
+    let mut config = test_config_with_secret_hex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b");
+    config.censorship.alpn_enforce = true;
+    config.censorship.server_hello_delay_min_ms = 20;
+    config.censorship.server_hello_delay_max_ms = 20;
+
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.33:44332".parse().unwrap();
+    let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
+
+    let started = Instant::now();
+    let result = handle_tls_handshake(
+        &handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::BadClient { .. }));
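+    // Assert slightly under the configured 20ms to tolerate timer coarseness.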
+    assert!(
+        started.elapsed() >= Duration::from_millis(18),
+        "configured anti-fingerprint delay must apply to ALPN-mismatch rejects"
+    );
+}
+
 #[tokio::test]
 #[ignore = "timing-sensitive; run manually on low-jitter hosts"]
 async fn malformed_tls_classes_share_close_latency_buckets() {
@@ -643,6 +709,82 @@ async fn malformed_tls_classes_share_close_latency_buckets() {
     );
 }
 
+#[tokio::test]
+#[ignore = "timing matrix; run manually with --ignored --nocapture"]
+async fn timing_matrix_tls_classes_under_fixed_delay_budget() {
+    const ITER: usize = 48;
+    const BUCKET_MS: u128 = 10;
+
+    let secret = [0x77u8; 16];
+    let mut config = test_config_with_secret_hex("77777777777777777777777777777777");
+    config.censorship.alpn_enforce = true;
+    config.censorship.server_hello_delay_min_ms = 20;
+    config.censorship.server_hello_delay_max_ms = 20;
+
+    let rng = SecureRandom::new();
+    let base_ip = std::net::Ipv4Addr::new(198, 51, 100, 34);
+
+    let too_short = vec![0x16, 0x03, 0x01];
+    let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
+    bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01;
+    let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
+    let valid_h2 = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2"]);
+
+    let classes = vec![
+        ("too_short", too_short),
+        ("bad_hmac", bad_hmac),
+        ("alpn_mismatch", alpn_mismatch),
+        ("valid_h2", valid_h2),
+    ];
+
+    for (class, probe) in classes {
+        let mut samples_ms = Vec::with_capacity(ITER);
+        for idx in 0..ITER {
+            clear_auth_probe_state_for_testing();
+            let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60));
+            let peer: SocketAddr = SocketAddr::from((base_ip, 44_000 + idx as u16));
+            let started = Instant::now();
+            let result = handle_tls_handshake(
+                &probe,
+                tokio::io::empty(),
+                tokio::io::sink(),
+                peer,
+                &config,
+                &replay_checker,
+                &rng,
+                None,
+            )
+            .await;
+            let elapsed = started.elapsed();
+            samples_ms.push(elapsed.as_millis());
+
+            if class == "valid_h2" {
+                assert!(matches!(result, HandshakeResult::Success(_)));
+            } else {
+                assert!(matches!(result, HandshakeResult::BadClient { .. }));
+            }
+        }
+
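+        // Only outcome correctness is asserted; the latency statistics below
+        // are printed for manual review under --nocapture.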
+        samples_ms.sort_unstable();
+        let sum: u128 = samples_ms.iter().copied().sum();
+        let mean = sum as f64 / samples_ms.len() as f64;
+        let min = samples_ms[0];
+        let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
+        let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
+        let max = samples_ms[samples_ms.len() - 1];
+
+        println!(
+            "TIMING_MATRIX tls class={} mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
+            class,
+            mean,
+            min,
+            p95,
+            max,
+            (mean as u128) / BUCKET_MS
+        );
+    }
+}
+
 #[test]
 fn secure_tag_requires_tls_mode_on_tls_transport() {
     let mut config = ProxyConfig::default();
diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs
index 9a23c5b..eb6f6da 100644
--- a/src/proxy/masking.rs
+++ b/src/proxy/masking.rs
@@ -7,7 +7,7 @@ use tokio::net::TcpStream;
 #[cfg(unix)]
 use tokio::net::UnixStream;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
-use tokio::time::timeout;
+use tokio::time::{Instant, timeout};
 use tracing::debug;
 use crate::config::ProxyConfig;
 use crate::network::dns_overrides::resolve_socket_addr;
@@ -49,6 +49,20 @@ where
     }
 }
 
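+// Each budget helper below sleeps out the remainder of MASK_TIMEOUT measured
+// from `started`, so error paths and relay paths exit on the same coarse schedule.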
+async fn wait_mask_connect_budget(started: Instant) {
+    let elapsed = started.elapsed();
+    if elapsed < MASK_TIMEOUT {
+        tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
+    }
+}
+
+async fn wait_mask_outcome_budget(started: Instant) {
+    let elapsed = started.elapsed();
+    if elapsed < MASK_TIMEOUT {
+        tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
+    }
+}
+
 /// Detect client type based on initial data
 fn detect_client_type(data: &[u8]) -> &'static str {
     // Check for HTTP request
@@ -107,6 +121,8 @@ where
     // Connect via Unix socket or TCP
     #[cfg(unix)]
     if let Some(ref sock_path) = config.censorship.mask_unix_sock {
+        let outcome_started = Instant::now();
+        let connect_started = Instant::now();
         debug!(
             client_type = client_type,
             sock = %sock_path,
@@ -143,14 +159,18 @@ where
             if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
                 debug!("Mask relay timed out (unix socket)");
             }
+            wait_mask_outcome_budget(outcome_started).await;
         }
         Ok(Err(e)) => {
+            wait_mask_connect_budget(connect_started).await;
             debug!(error = %e, "Failed to connect to mask unix socket");
             consume_client_data_with_timeout(reader).await;
+            wait_mask_outcome_budget(outcome_started).await;
         }
         Err(_) => {
             debug!("Timeout connecting to mask unix socket");
             consume_client_data_with_timeout(reader).await;
+            wait_mask_outcome_budget(outcome_started).await;
         }
     }
     return;
@@ -172,6 +192,8 @@ where
     let mask_addr = resolve_socket_addr(mask_host, mask_port)
         .map(|addr| addr.to_string())
         .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
+    let outcome_started = Instant::now();
+    let connect_started = Instant::now();
     let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
     match connect_result {
         Ok(Ok(stream)) => {
@@ -202,14 +224,18 @@ where
             if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
                 debug!("Mask relay timed out");
             }
+            wait_mask_outcome_budget(outcome_started).await;
         }
         Ok(Err(e)) => {
+            wait_mask_connect_budget(connect_started).await;
             debug!(error = %e, "Failed to connect to mask host");
             consume_client_data_with_timeout(reader).await;
+            wait_mask_outcome_budget(outcome_started).await;
         }
         Err(_) => {
             debug!("Timeout connecting to mask host");
             consume_client_data_with_timeout(reader).await;
+            wait_mask_outcome_budget(outcome_started).await;
         }
     }
 }
diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs
index 25b6a76..2310846 100644
--- a/src/proxy/masking_security_tests.rs
+++ b/src/proxy/masking_security_tests.rs
@@ -8,7 +8,7 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader};
 use tokio::net::TcpListener;
 #[cfg(unix)]
 use tokio::net::UnixListener;
-use tokio::time::{sleep, timeout, Duration};
+use tokio::time::{Instant, sleep, timeout, Duration};
 
 #[tokio::test]
 async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
@@ -216,6 +216,372 @@ async fn backend_unavailable_falls_back_to_silent_consume() {
     assert_eq!(n, 0);
 }
 
+#[tokio::test]
+async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() {
+    let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let unused_port = temp_listener.local_addr().unwrap().port();
+    drop(temp_listener);
+
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = false;
+    config.censorship.mask = true;
+    config.censorship.mask_host = Some("127.0.0.1".to_string());
+    config.censorship.mask_port = unused_port;
+    config.censorship.mask_unix_sock = None;
+    config.censorship.mask_proxy_protocol = 0;
+
+    let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+    let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n";
+
+    // Keep reader open so fallback path does not terminate immediately on EOF.
+    let (_client_reader_side, client_reader) = duplex(256);
+    let (_client_visible_reader, client_visible_writer) = duplex(256);
+    let beobachten = BeobachtenStore::new();
+
+    let started = Instant::now();
+    let task = tokio::spawn(async move {
+        handle_bad_client(
+            client_reader,
+            client_visible_writer,
+            probe,
+            peer,
+            local_addr,
+            &config,
+            &beobachten,
+        )
+        .await;
+    });
+
+    timeout(Duration::from_millis(35), task)
+        .await
+        .expect_err("masking fallback must not complete before connect budget elapses");
+    assert!(
+        started.elapsed() >= Duration::from_millis(35),
+        "fallback path must absorb immediate connect refusal into connect budget"
+    );
+}
+
+#[tokio::test]
+async fn backend_reachable_fast_response_waits_mask_outcome_budget() {
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = listener.local_addr().unwrap();
+    let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
+    let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
+
+    let accept_task = tokio::spawn({
+        let probe = probe.clone();
+        let backend_reply = backend_reply.clone();
+        async move {
+            let (mut stream, _) = listener.accept().await.unwrap();
+            let mut received = vec![0u8; probe.len()];
+            stream.read_exact(&mut received).await.unwrap();
+            assert_eq!(received, probe);
+            stream.write_all(&backend_reply).await.unwrap();
+        }
+    });
+
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = false;
+    config.censorship.mask = true;
+    config.censorship.mask_host = Some("127.0.0.1".to_string());
+    config.censorship.mask_port = backend_addr.port();
+    config.censorship.mask_unix_sock = None;
+    config.censorship.mask_proxy_protocol = 0;
+
+    let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+
+    let (client_writer_side, client_reader) = duplex(256);
+    drop(client_writer_side);
+    let (_client_visible_reader, client_visible_writer) = duplex(512);
+    let beobachten = BeobachtenStore::new();
+
+    let started = Instant::now();
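+    // The client reader is already at EOF (writer half dropped above), so the
+    // elapsed time below is dominated by the mask outcome budget, not relay traffic.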
+    handle_bad_client(
+        client_reader,
+        client_visible_writer,
+        &probe,
+        peer,
+        local_addr,
+        &config,
+        &beobachten,
+    )
+    .await;
+
+    assert!(
+        started.elapsed() >= Duration::from_millis(45),
+        "reachable mask path must also satisfy coarse outcome budget"
+    );
+    accept_task.await.unwrap();
+}
+
+#[tokio::test]
+async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() {
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = false;
+    config.censorship.mask = false;
+
+    let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+
+    let (client_writer_side, client_reader) = duplex(256);
+    drop(client_writer_side);
+    let (_client_visible_reader, client_visible_writer) = duplex(256);
+    let beobachten = BeobachtenStore::new();
+
+    let started = Instant::now();
+    handle_bad_client(
+        client_reader,
+        client_visible_writer,
+        b"x",
+        peer,
+        local_addr,
+        &config,
+        &beobachten,
+    )
+    .await;
+
+    assert!(
+        started.elapsed() < Duration::from_millis(20),
+        "mask-disabled fallback should keep immediate EOF behavior"
+    );
+}
+
+#[tokio::test]
+async fn backend_reachable_slow_response_not_padded_twice() {
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = listener.local_addr().unwrap();
+    let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
+    let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
+
+    let accept_task = tokio::spawn({
+        let probe = probe.clone();
+        let backend_reply = backend_reply.clone();
+        async move {
+            let (mut stream, _) = listener.accept().await.unwrap();
+            let mut received = vec![0u8; probe.len()];
+            stream.read_exact(&mut received).await.unwrap();
+            assert_eq!(received, probe);
+            sleep(Duration::from_millis(90)).await;
+            stream.write_all(&backend_reply).await.unwrap();
+        }
+    });
+
+    let mut config = ProxyConfig::default();
+    config.general.beobachten = false;
+    config.censorship.mask = true;
+    config.censorship.mask_host = Some("127.0.0.1".to_string());
+    config.censorship.mask_port = backend_addr.port();
+    config.censorship.mask_unix_sock = None;
+    config.censorship.mask_proxy_protocol = 0;
+
+    let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+
+    let (client_writer_side, client_reader) = duplex(256);
+    drop(client_writer_side);
+    let (_client_visible_reader, client_visible_writer) = duplex(512);
+    let beobachten = BeobachtenStore::new();
+
+    let started = Instant::now();
+    handle_bad_client(
+        client_reader,
+        client_visible_writer,
+        &probe,
+        peer,
+        local_addr,
+        &config,
+        &beobachten,
+    )
+    .await;
+    let elapsed = started.elapsed();
+
+    assert!(elapsed >= Duration::from_millis(85));
+    assert!(
+        elapsed < Duration::from_millis(170),
+        "slow reachable backend should not incur an extra full budget after already exceeding it"
+    );
+    accept_task.await.unwrap();
+}
+
+#[tokio::test]
+async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() {
+    const ITER: usize = 20;
+    const BUCKET_MS: u128 = 10;
+
+    let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n";
+    let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+
+    let mut refused = Vec::with_capacity(ITER);
+    for _ in 0..ITER {
+        let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let unused_port = temp_listener.local_addr().unwrap().port();
+        drop(temp_listener);
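+        // Bind-then-drop frees a local port that should refuse connections immediately.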
+
+        let mut config = ProxyConfig::default();
+        config.general.beobachten = false;
+        config.censorship.mask = true;
+        config.censorship.mask_host = Some("127.0.0.1".to_string());
+        config.censorship.mask_port = unused_port;
+        config.censorship.mask_unix_sock = None;
+        config.censorship.mask_proxy_protocol = 0;
+
+        let (client_writer_side, client_reader) = duplex(256);
+        drop(client_writer_side);
+        let (_client_visible_reader, client_visible_writer) = duplex(256);
+        let beobachten = BeobachtenStore::new();
+
+        let started = Instant::now();
+        handle_bad_client(
+            client_reader,
+            client_visible_writer,
+            probe,
+            peer,
+            local_addr,
+            &config,
+            &beobachten,
+        )
+        .await;
+        refused.push(started.elapsed().as_millis());
+    }
+
+    let mut reachable = Vec::with_capacity(ITER);
+    for _ in 0..ITER {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let backend_addr = listener.local_addr().unwrap();
+        let probe_vec = probe.to_vec();
+        let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
+
+        let accept_task = tokio::spawn(async move {
+            let (mut stream, _) = listener.accept().await.unwrap();
+            let mut received = vec![0u8; probe_vec.len()];
+            stream.read_exact(&mut received).await.unwrap();
+            stream.write_all(&backend_reply).await.unwrap();
+        });
+
+        let mut config = ProxyConfig::default();
+        config.general.beobachten = false;
+        config.censorship.mask = true;
+        config.censorship.mask_host = Some("127.0.0.1".to_string());
+        config.censorship.mask_port = backend_addr.port();
+        config.censorship.mask_unix_sock = None;
+        config.censorship.mask_proxy_protocol = 0;
+
+        let (client_writer_side, client_reader) = duplex(256);
+        drop(client_writer_side);
+        let (_client_visible_reader, client_visible_writer) = duplex(256);
+        let beobachten = BeobachtenStore::new();
+
+        let started = Instant::now();
+        handle_bad_client(
+            client_reader,
+            client_visible_writer,
+            probe,
+            peer,
+            local_addr,
+            &config,
+            &beobachten,
+        )
+        .await;
+        reachable.push(started.elapsed().as_millis());
+        accept_task.await.unwrap();
+    }
+
+    let refused_mean = refused.iter().copied().sum::<u128>() as f64 / refused.len() as f64;
+    let reachable_mean = reachable.iter().copied().sum::<u128>() as f64 / reachable.len() as f64;
+    let refused_bucket = (refused_mean as u128) / BUCKET_MS;
+    let reachable_bucket = (reachable_mean as u128) / BUCKET_MS;
+
+    assert!(
+        refused_bucket.abs_diff(reachable_bucket) <= 1,
+        "enabled refused and reachable paths must collapse into the same coarse latency bucket"
+    );
+}
+
+#[tokio::test]
+async fn light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() {
+    let mut seed: u64 = 0xA5A5_5A5A_1337_4242;
+    let mut next = || {
+        seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
+        seed
+    };
+
+    let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+
+    for _ in 0..40 {
+        let probe_len = (next() as usize % 96).saturating_add(8);
+        let mut probe = vec![0u8; probe_len];
+        for byte in &mut probe {
+            *byte = next() as u8;
+        }
+
+        let use_reachable = (next() & 1) == 0;
+        let mut config = ProxyConfig::default();
+        config.general.beobachten = false;
+        config.censorship.mask = true;
+        config.censorship.mask_unix_sock = None;
+        config.censorship.mask_proxy_protocol = 0;
+
+        let (client_writer_side, client_reader) = duplex(512);
+        drop(client_writer_side);
+        let (_client_visible_reader, client_visible_writer) = duplex(512);
+        let beobachten = BeobachtenStore::new();
+
+        let started = Instant::now();
+        if use_reachable {
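+            // Reachable arm: the backend accepts and reads the probe but never
+            // replies, so the relay ends on client EOF and the outcome budget dominates.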
+            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+            let backend_addr = listener.local_addr().unwrap();
+            config.censorship.mask_host = Some("127.0.0.1".to_string());
+            config.censorship.mask_port = backend_addr.port();
+
+            let probe_vec = probe.clone();
+            let accept_task = tokio::spawn(async move {
+                let (mut stream, _) = listener.accept().await.unwrap();
+                let mut observed = vec![0u8; probe_vec.len()];
+                stream.read_exact(&mut observed).await.unwrap();
+            });
+
+            handle_bad_client(
+                client_reader,
+                client_visible_writer,
+                &probe,
+                peer,
+                local_addr,
+                &config,
+                &beobachten,
+            )
+            .await;
+            accept_task.await.unwrap();
+        } else {
+            let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+            let unused_port = temp_listener.local_addr().unwrap().port();
+            drop(temp_listener);
+
+            config.censorship.mask_host = Some("127.0.0.1".to_string());
+            config.censorship.mask_port = unused_port;
+
+            handle_bad_client(
+                client_reader,
+                client_visible_writer,
+                &probe,
+                peer,
+                local_addr,
+                &config,
+                &beobachten,
+            )
+            .await;
+        }
+
+        assert!(
+            started.elapsed() >= Duration::from_millis(45),
+            "mask-enabled fallback must preserve coarse timing budget under varied probe shapes"
+        );
+    }
+}
+
 #[tokio::test]
 async fn mask_disabled_consumes_client_data_without_response() {
     let mut config = ProxyConfig::default();
@@ -729,3 +1095,158 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
     assert!(mask_reader_dropped.load(Ordering::SeqCst));
     assert!(mask_writer_dropped.load(Ordering::SeqCst));
 }
+
+#[tokio::test]
+#[ignore = "timing matrix; run manually with --ignored --nocapture"]
+async fn timing_matrix_masking_classes_under_controlled_inputs() {
+    const ITER: usize = 24;
+    const BUCKET_MS: u128 = 10;
+
+    let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n";
+    let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap();
+    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+
+    // Class 1: masking disabled with immediate EOF (fast fail-closed consume path).
+    let mut disabled_samples = Vec::with_capacity(ITER);
+    for _ in 0..ITER {
+        let mut config = ProxyConfig::default();
+        config.general.beobachten = false;
+        config.censorship.mask = false;
+
+        let (client_writer_side, client_reader) = duplex(256);
+        drop(client_writer_side);
+        let (_client_visible_reader, client_visible_writer) = duplex(256);
+        let beobachten = BeobachtenStore::new();
+
+        let started = Instant::now();
+        handle_bad_client(
+            client_reader,
+            client_visible_writer,
+            probe,
+            peer,
+            local_addr,
+            &config,
+            &beobachten,
+        )
+        .await;
+        disabled_samples.push(started.elapsed().as_millis());
+    }
+
+    // Class 2: masking enabled, backend connect refused.
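+    // (bind-then-drop port, as in the bucket-collapse test above).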
+    let mut refused_samples = Vec::with_capacity(ITER);
+    for _ in 0..ITER {
+        let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let unused_port = temp_listener.local_addr().unwrap().port();
+        drop(temp_listener);
+
+        let mut config = ProxyConfig::default();
+        config.general.beobachten = false;
+        config.censorship.mask = true;
+        config.censorship.mask_host = Some("127.0.0.1".to_string());
+        config.censorship.mask_port = unused_port;
+        config.censorship.mask_unix_sock = None;
+        config.censorship.mask_proxy_protocol = 0;
+
+        let (client_writer_side, client_reader) = duplex(256);
+        drop(client_writer_side);
+        let (_client_visible_reader, client_visible_writer) = duplex(256);
+        let beobachten = BeobachtenStore::new();
+
+        let started = Instant::now();
+        handle_bad_client(
+            client_reader,
+            client_visible_writer,
+            probe,
+            peer,
+            local_addr,
+            &config,
+            &beobachten,
+        )
+        .await;
+        refused_samples.push(started.elapsed().as_millis());
+    }
+
+    // Class 3: masking enabled, backend reachable and immediately responds.
+    let mut reachable_samples = Vec::with_capacity(ITER);
+    for _ in 0..ITER {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let backend_addr = listener.local_addr().unwrap();
+        let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
+        let probe_vec = probe.to_vec();
+
+        let accept_task = tokio::spawn(async move {
+            let (mut stream, _) = listener.accept().await.unwrap();
+            let mut received = vec![0u8; probe_vec.len()];
+            stream.read_exact(&mut received).await.unwrap();
+            assert_eq!(received, probe_vec);
+            stream.write_all(&backend_reply).await.unwrap();
+        });
+
+        let mut config = ProxyConfig::default();
+        config.general.beobachten = false;
+        config.censorship.mask = true;
+        config.censorship.mask_host = Some("127.0.0.1".to_string());
+        config.censorship.mask_port = backend_addr.port();
+        config.censorship.mask_unix_sock = None;
+        config.censorship.mask_proxy_protocol = 0;
+
+        let (client_writer_side, client_reader) = duplex(256);
+        drop(client_writer_side);
+        let (_client_visible_reader, client_visible_writer) = duplex(256);
+        let beobachten = BeobachtenStore::new();
+
+        let started = Instant::now();
+        handle_bad_client(
+            client_reader,
+            client_visible_writer,
+            probe,
+            peer,
+            local_addr,
+            &config,
+            &beobachten,
+        )
+        .await;
+        reachable_samples.push(started.elapsed().as_millis());
+        accept_task.await.unwrap();
+    }
+
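+    // A single summary helper keeps the three classes' statistics directly comparable.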
+    fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) {
+        samples_ms.sort_unstable();
+        let sum: u128 = samples_ms.iter().copied().sum();
+        let mean = sum as f64 / samples_ms.len() as f64;
+        let min = samples_ms[0];
+        let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
+        let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
+        let max = samples_ms[samples_ms.len() - 1];
+        (mean, min, p95, max)
+    }
+
+    let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples);
+    let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples);
+    let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples);
+
+    println!(
+        "TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
+        disabled_mean,
+        disabled_min,
+        disabled_p95,
+        disabled_max,
+        (disabled_mean as u128) / BUCKET_MS
+    );
+    println!(
+        "TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
+        refused_mean,
+        refused_min,
+        refused_p95,
+        refused_max,
+        (refused_mean as u128) / BUCKET_MS
+    );
+    println!(
+        "TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
+        reachable_mean,
+        reachable_min,
+        reachable_p95,
+        reachable_max,
+        (reachable_mean as u128) / BUCKET_MS
+    );
+}
diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs
index 1c2c648..a6b1031 100644
--- a/src/transport/middle_proxy/health.rs
+++ b/src/transport/middle_proxy/health.rs
@@ -239,7 +239,9 @@ pub(super) async fn reap_draining_writers(
         if !closed_writer_ids.insert(writer_id) {
             continue;
         }
-        pool.remove_writer_and_close_clients(writer_id).await;
+        if !pool.remove_writer_if_empty(writer_id).await {
+            continue;
+        }
         closed_total = closed_total.saturating_add(1);
     }
diff --git a/src/transport/middle_proxy/health_regression_tests.rs b/src/transport/middle_proxy/health_regression_tests.rs
index fe73670..6b6b12a 100644
--- a/src/transport/middle_proxy/health_regression_tests.rs
+++ b/src/transport/middle_proxy/health_regression_tests.rs
@@ -592,3 +592,67 @@ async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_stat
 fn general_config_default_drain_threshold_remains_enabled() {
     assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 128);
 }
+
+#[tokio::test]
+async fn reap_draining_writers_does_not_close_writer_that_became_non_empty_after_snapshot() {
+    let pool = make_pool(128).await;
+    let now_epoch_secs = MePool::now_epoch_secs();
+
+    let empty_writer_id = 700u64;
+    insert_draining_writer(
+        &pool,
+        empty_writer_id,
+        now_epoch_secs.saturating_sub(60),
+        0,
+        0,
+    )
+    .await;
+
+    let stale_empty_snapshot = vec![empty_writer_id];
+    let (rebound_conn_id, _rx) = pool.registry.register().await;
+    assert!(
+        pool.registry
+            .bind_writer(
+                rebound_conn_id,
+                empty_writer_id,
+                ConnMeta {
+                    target_dc: 2,
+                    client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9050),
+                    our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
+                    proto_flags: 0,
+                },
+            )
+            .await,
+        "writer should accept a new bind after stale empty snapshot"
+    );
+
+    for writer_id in stale_empty_snapshot {
+        assert!(
+            !pool.remove_writer_if_empty(writer_id).await,
+            "atomic empty cleanup must reject writers that gained bound clients"
+        );
+    }
+
+    assert!(
+        writer_exists(&pool, empty_writer_id).await,
+        "empty-path cleanup must not remove a writer that gained a bound client"
+    );
+    assert_eq!(
+        pool.registry.get_writer(rebound_conn_id).await.map(|w| w.writer_id),
+        Some(empty_writer_id)
+    );
+
+    let _ = pool.registry.unregister(rebound_conn_id).await;
+}
+
+#[tokio::test]
+async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() {
+    let pool = make_pool(128).await;
+    let now_epoch_secs = MePool::now_epoch_secs();
+    let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await;
+
+    pool.prune_closed_writers().await;
+
+    assert!(!writer_exists(&pool, 910).await);
+    assert!(pool.registry.get_writer(conn_ids[0]).await.is_none());
+}
diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs
index 7490a98..5b23d7f 100644
--- a/src/transport/middle_proxy/pool_writer.rs
+++ b/src/transport/middle_proxy/pool_writer.rs
@@ -42,11 +42,10 @@ impl MePool {
         }
 
         for writer_id in closed_writer_ids {
-            if self.registry.is_writer_empty(writer_id).await {
-                let _ = self.remove_writer_only(writer_id).await;
-            } else {
-                let _ = self.remove_writer_and_close_clients(writer_id).await;
-            }
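+            // Prefer the atomic empty-writer removal; close bound clients only
+            // when the writer still has sessions attached.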
+            if self.remove_writer_if_empty(writer_id).await {
+                continue;
+            }
+            let _ = self.remove_writer_and_close_clients(writer_id).await;
         }
     }
 
@@ -501,6 +500,17 @@ impl MePool {
         }
     }
 
+    pub(crate) async fn remove_writer_if_empty(self: &Arc<Self>, writer_id: u64) -> bool {
+        if !self.registry.unregister_writer_if_empty(writer_id).await {
+            return false;
+        }
+
+        // The registry empty-check and unregister are atomic with respect to binds,
+        // so remove_writer_only cannot return active bound sessions here.
+        let _ = self.remove_writer_only(writer_id).await;
+        true
+    }
+
     async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<u64> {
         let mut close_tx: Option<oneshot::Sender<()>> = None;
         let mut removed_addr: Option<SocketAddr> = None;
diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs
index cbe1d9a..ea968b5 100644
--- a/src/transport/middle_proxy/registry.rs
+++ b/src/transport/middle_proxy/registry.rs
@@ -437,6 +437,23 @@ impl ConnRegistry {
             .unwrap_or(true)
     }
 
+    pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
+        let mut inner = self.inner.write().await;
+        let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else {
+            // Writer is already absent from the registry.
+            return true;
+        };
+        if !conn_ids.is_empty() {
+            return false;
+        }
+
+        inner.writers.remove(&writer_id);
+        inner.last_meta_for_writer.remove(&writer_id);
+        inner.writer_idle_since_epoch_secs.remove(&writer_id);
+        inner.conns_for_writer.remove(&writer_id);
+        true
+    }
+
     pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
         let inner = self.inner.read().await;
         let mut out = HashSet::<u64>::with_capacity(writer_ids.len());