From e6ad9e4c7f89baaa6721df4aa987858771acce81 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Thu, 19 Mar 2026 14:56:28 +0400 Subject: [PATCH] Add security tests for connection limits and handshake integrity - Implement a test to ensure that exceeding the user connection limit does not leak the current connections counter. - Add tests for direct relay connection refusal and adversarial scenarios to verify proper error handling. - Introduce fuzz testing for MTProto handshake to ensure robustness against malformed inputs and replay attacks. - Remove obsolete short TLS probe throttle tests and integrate their functionality into existing security tests. - Enhance middle relay tests to validate behavior during connection drops and cutovers, ensuring graceful error handling. - Add a test for half-close scenarios in relay to confirm bidirectional data flow continues as expected. --- src/protocol/tls.rs | 18 + src/protocol/tls_adversarial_tests.rs | 27 +- src/protocol/tls_fuzz_security_tests.rs | 195 +++++++++ src/proxy/client_security_tests.rs | 49 +++ src/proxy/direct_relay_security_tests.rs | 193 +++++++++ src/proxy/handshake.rs | 8 +- src/proxy/handshake_fuzz_security_tests.rs | 270 ++++++++++++ ...short_tls_probe_throttle_security_tests.rs | 50 --- src/proxy/handshake_security_tests.rs | 36 ++ src/proxy/middle_relay_security_tests.rs | 400 ++++++++++++++++-- src/proxy/relay_security_tests.rs | 43 ++ 11 files changed, 1198 insertions(+), 91 deletions(-) create mode 100644 src/protocol/tls_fuzz_security_tests.rs create mode 100644 src/proxy/handshake_fuzz_security_tests.rs delete mode 100644 src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 77af648..ac49ae3 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -544,6 +544,11 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { return None; } + let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize; + if handshake.len() < 5 + record_len { + return None; + } + let mut pos = 5; // after record header if handshake.get(pos).copied()? != 0x01 { return None; // not ClientHello @@ -649,6 +654,15 @@ fn is_valid_sni_hostname(host: &str) -> bool { /// Extract ALPN protocol list from ClientHello, return in offered order. pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { + if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE { + return Vec::new(); + } + + let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize; + if handshake.len() < 5 + record_len { + return Vec::new(); + } + let mut pos = 5; // after record header if handshake.get(pos) != Some(&0x01) { return Vec::new(); @@ -806,3 +820,7 @@ mod security_tests; #[cfg(test)] #[path = "tls_adversarial_tests.rs"] mod adversarial_tests; + +#[cfg(test)] +#[path = "tls_fuzz_security_tests.rs"] +mod fuzz_security_tests; diff --git a/src/protocol/tls_adversarial_tests.rs b/src/protocol/tls_adversarial_tests.rs index e17c9f8..4c8aa72 100644 --- a/src/protocol/tls_adversarial_tests.rs +++ b/src/protocol/tls_adversarial_tests.rs @@ -286,13 +286,26 @@ fn extract_sni_with_duplicate_extensions_rejected() { ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes()); ext.extend_from_slice(&sni2); - let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x80, 0x01, 0x00, 0x00, 0x7C, 0x03, 0x03]; - h.extend_from_slice(&[0u8; 32]); - h.push(0); - h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); - h.extend_from_slice(&[0x01, 0x00]); - h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); - h.extend_from_slice(&ext); + let mut body = Vec::new(); + body.extend_from_slice(&[0x03, 0x03]); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); + body.extend_from_slice(&[0x01, 0x00]); + body.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut h = Vec::new(); + h.push(0x16); + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + h.extend_from_slice(&handshake); // Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups. // Telemt's `extract_sni` returns the first one found. diff --git a/src/protocol/tls_fuzz_security_tests.rs b/src/protocol/tls_fuzz_security_tests.rs new file mode 100644 index 0000000..32d8efe --- /dev/null +++ b/src/protocol/tls_fuzz_security_tests.rs @@ -0,0 +1,195 @@ +use super::*; +use crate::crypto::sha256_hmac; +use std::panic::catch_unwind; + +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = session_id.len(); + assert!(session_id_len <= u8::MAX as usize); + + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut digest = sha256_hmac(secret, &handshake); + let ts = timestamp.to_le_bytes(); + for idx in 0..4 { + digest[28 + idx] ^= ts[idx]; + } + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake +} + +fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +#[test] +fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() { + let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]); + assert_eq!(extract_sni_from_client_hello(&valid).as_deref(), Some("example.com")); + assert_eq!( + extract_alpn_from_client_hello(&valid), + vec![b"h2".to_vec(), b"http/1.1".to_vec()] + ); + assert!( + extract_sni_from_client_hello(&make_valid_client_hello_record("127.0.0.1", &[])).is_none(), + "literal IP hostnames must be rejected" + ); + + let mut corpus = vec![ + Vec::new(), + vec![0x16, 0x03, 0x03], + valid[..9].to_vec(), + valid[..valid.len() - 1].to_vec(), + ]; + + let mut wrong_type = valid.clone(); + wrong_type[0] = 0x15; + corpus.push(wrong_type); + + let mut wrong_handshake = valid.clone(); + wrong_handshake[5] = 0x02; + corpus.push(wrong_handshake); + + let mut wrong_length = valid.clone(); + wrong_length[3] ^= 0x7f; + corpus.push(wrong_length); + + for (idx, input) in corpus.iter().enumerate() { + assert!(catch_unwind(|| extract_sni_from_client_hello(input)).is_ok()); + assert!(catch_unwind(|| extract_alpn_from_client_hello(input)).is_ok()); + + if idx == 0 { + continue; + } + + assert!(extract_sni_from_client_hello(input).is_none(), "corpus item {idx} must fail closed for SNI"); + assert!(extract_alpn_from_client_hello(input).is_empty(), "corpus item {idx} must fail closed for ALPN"); + } +} + +#[test] +fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() { + let secret = b"tls_fuzz_security_secret"; + let now: i64 = 1_700_000_000; + let base = make_valid_tls_handshake_with_session_id(secret, now as u32, &[0x42; 32]); + let secrets = vec![("fuzz-user".to_string(), secret.to_vec())]; + + assert!(validate_tls_handshake_at_time(&base, &secrets, false, now).is_some()); + + let mut corpus = Vec::new(); + + let mut truncated = base.clone(); + truncated.truncate(TLS_DIGEST_POS + 16); + corpus.push(truncated); + + let mut digest_flip = base.clone(); + digest_flip[TLS_DIGEST_POS + 7] ^= 0x80; + corpus.push(digest_flip); + + let mut session_id_len_overflow = base.clone(); + session_id_len_overflow[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 33; + corpus.push(session_id_len_overflow); + + let mut timestamp_far_past = base.clone(); + timestamp_far_past[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32] + .copy_from_slice(&((now - i64::from(TIME_SKEW_MAX) - 1) as u32).to_le_bytes()); + corpus.push(timestamp_far_past); + + let mut timestamp_far_future = base.clone(); + timestamp_far_future[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32] + .copy_from_slice(&((now - TIME_SKEW_MIN + 1) as u32).to_le_bytes()); + corpus.push(timestamp_far_future); + + let mut seed = 0xA5A5_5A5A_F00D_BAAD_u64; + for _ in 0..32 { + let mut mutated = base.clone(); + for _ in 0..2 { + seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN); + mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, handshake) in corpus.iter().enumerate() { + let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now)); + assert!(result.is_ok(), "corpus item {idx} must not panic"); + assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed"); + } +} + +#[test] +fn tls_boot_time_acceptance_is_capped_by_replay_window() { + let secret = b"tls_boot_time_cap_secret"; + let secrets = vec![("boot-user".to_string(), secret.to_vec())]; + let boot_ts = 1u32; + let handshake = make_valid_tls_handshake_with_session_id(secret, boot_ts, &[0x42; 32]); + + assert!( + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 300).is_some(), + "boot-time timestamp should be accepted while replay window permits it" + ); + assert!( + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 0).is_none(), + "boot-time timestamp must be rejected when replay window disables the bypass" + ); +} diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index abd6266..d3c411e 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -7,6 +7,7 @@ use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use std::net::Ipv4Addr; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; @@ -630,6 +631,54 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load( } } +#[tokio::test] +async fn reservation_limit_failure_does_not_leak_curr_connects_counter() { + let user = "leak-check-user"; + let mut config = crate::config::ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Arc::new(crate::stats::Stats::new()); + let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; + + let first_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 1)), 50001); + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + first_peer, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + let second_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 2)), 50002); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + second_peer, + ip_tracker.clone(), + ) + .await; + + assert!( + matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user), + "second reservation must be rejected at the configured tcp-conns limit" + ); + assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment"); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state"); + + first.release().await; + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + #[tokio::test] async fn short_tls_probe_is_masked_through_client_pipeline() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index 6c25068..e8016a5 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -12,8 +12,10 @@ use std::path::Path; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; +use tokio::io::AsyncReadExt; use tokio::io::duplex; use tokio::net::TcpListener; +use tokio::time::{timeout, Duration as TokioDuration}; fn make_crypto_reader(reader: R) -> CryptoReader where @@ -1322,3 +1324,194 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() { assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); } } + +#[tokio::test] +async fn negative_direct_relay_dc_connection_refused_fails_fast() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_client_reader_relay, client_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + // Reserve an ephemeral port and immediately release it to deterministically + // exercise the direct-connect failure path without long-lived hangs. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dc_addr = listener.local_addr().unwrap(); + drop(listener); + + let mut config_with_override = ProxyConfig::default(); + config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); + let config = Arc::new(config_with_override); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + enabled: true, + weight: 1, + scopes: String::new(), + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + selected_scope: String::new(), + }], + 1, + 100, + 5000, + 3, + false, + stats.clone(), + )); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let result = timeout( + TokioDuration::from_secs(2), + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_runtime.snapshot(), + 0xABCD_1234, + ), + ) + .await + .expect("direct relay must fail fast on connection-refused upstream"); + + assert!( + result.is_err(), + "connection-refused upstream must fail closed" + ); +} + +#[tokio::test] +async fn adversarial_direct_relay_cutover_integrity() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_client_reader_relay, client_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + // Mock upstream server. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dc_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + // Read handshake nonce. + let mut nonce = [0u8; 64]; + let _ = stream.read_exact(&mut nonce).await; + // Keep connection open. + tokio::time::sleep(TokioDuration::from_secs(5)).await; + }); + + let mut config_with_override = ProxyConfig::default(); + config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); + let config = Arc::new(config_with_override); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + enabled: true, + weight: 1, + scopes: String::new(), + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + selected_scope: String::new(), + }], + 1, + 100, + 5000, + 3, + false, + stats.clone(), + )); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let stats_for_task = stats.clone(); + let runtime_clone = route_runtime.clone(); + let session_task = tokio::spawn(async move { + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats_for_task, + config, + buffer_pool, + rng, + runtime_clone.subscribe(), + runtime_clone.snapshot(), + 0xABCD_1234, + ).await + }); + + timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("direct relay session must start before cutover"); + + // Trigger cutover. + route_runtime.set_mode(RelayRouteMode::Middle).unwrap(); + + // The session should terminate after the staggered delay (1000-2000ms). + let result = timeout(TokioDuration::from_secs(5), session_task) + .await + .expect("Session must terminate after cutover") + .expect("Session must not panic"); + + assert!( + matches!( + result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "Session must terminate with route switch error on cutover" + ); +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index a23d514..b930caf 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -967,14 +967,14 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { #[path = "handshake_security_tests.rs"] mod security_tests; -#[cfg(test)] -#[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"] -mod gap_short_tls_probe_throttle_security_tests; - #[cfg(test)] #[path = "handshake_adversarial_tests.rs"] mod adversarial_tests; +#[cfg(test)] +#[path = "handshake_fuzz_security_tests.rs"] +mod fuzz_security_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. diff --git a/src/proxy/handshake_fuzz_security_tests.rs b/src/proxy/handshake_fuzz_security_tests.rs new file mode 100644 index 0000000..d72c9cd --- /dev/null +++ b/src/proxy/handshake_fuzz_security_tests.rs @@ -0,0 +1,270 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::crypto::AesCtr; +use crate::crypto::sha256; +use crate::protocol::constants::ProtoTag; +use crate::stats::ReplayChecker; +use std::net::SocketAddr; +use std::sync::MutexGuard; +use tokio::time::{timeout, Duration as TokioDuration}; + +fn make_mtproto_handshake_with_proto_bytes( + secret_hex: &str, + proto_bytes: [u8; 4], + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_bytes); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { + make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access.users.insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg +} + +fn auth_probe_test_guard() -> MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[tokio::test] +async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "11223344556677889900aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + let first = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let second = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(second, HandshakeResult::BadClient { .. })); + + clear_auth_probe_state_for_testing(); +} + +#[tokio::test] +async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); + + let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new(); + + corpus.push(make_mtproto_handshake_with_proto_bytes( + secret_hex, + [0x00, 0x00, 0x00, 0x00], + 1, + )); + corpus.push(make_mtproto_handshake_with_proto_bytes( + secret_hex, + [0xff, 0xff, 0xff, 0xff], + 1, + )); + corpus.push(make_valid_mtproto_handshake( + "ffeeddccbbaa99887766554433221100", + ProtoTag::Secure, + 1, + )); + + let mut seed = 0xF0F0_F00D_BAAD_CAFEu64; + for _ in 0..32 { + let mut mutated = base; + for _ in 0..4 { + seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); + mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, input) in corpus.into_iter().enumerate() { + let result = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &input, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed handshake must complete in time"); + + assert!( + matches!(result, HandshakeResult::BadClient { .. }), + "corpus item {idx} must fail closed" + ); + } + + clear_auth_probe_state_for_testing(); +} + +#[tokio::test] +async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "99887766554433221100ffeeddccbbaa"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(256, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.44:45444".parse().unwrap(); + + let first = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("base handshake must not hang"); + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("duplicate handshake must not hang"); + assert!(matches!(replay, HandshakeResult::BadClient { .. })); + + let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new(); + + let mut prekey_flip = base; + prekey_flip[SKIP_LEN] ^= 0x80; + corpus.push(prekey_flip); + + let mut iv_flip = base; + iv_flip[SKIP_LEN + PREKEY_LEN] ^= 0x01; + corpus.push(iv_flip); + + let mut tail_flip = base; + tail_flip[SKIP_LEN + PREKEY_LEN + IV_LEN - 1] ^= 0x40; + corpus.push(tail_flip); + + let mut seed = 0xBADC_0FFE_EE11_4242u64; + for _ in 0..24 { + let mut mutated = base; + for _ in 0..3 { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); + mutated[idx] ^= ((seed >> 16) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, input) in corpus.iter().enumerate() { + let result = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + input, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed handshake must complete in time"); + + assert!( + matches!(result, HandshakeResult::BadClient { .. }), + "mixed corpus item {idx} must fail closed" + ); + } + + clear_auth_probe_state_for_testing(); +} \ No newline at end of file diff --git a/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs b/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs deleted file mode 100644 index 2ea32bc..0000000 --- a/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs +++ /dev/null @@ -1,50 +0,0 @@ -use super::*; -use crate::stats::ReplayChecker; -use std::net::SocketAddr; -use std::time::Duration; - -fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { - let mut cfg = ProxyConfig::default(); - cfg.access.users.clear(); - cfg.access - .users - .insert("user".to_string(), secret_hex.to_string()); - cfg.access.ignore_time_skew = true; - cfg -} - -#[tokio::test] -async fn gap_t01_short_tls_probe_burst_is_throttled() { - let _guard = auth_probe_test_lock() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - clear_auth_probe_state_for_testing(); - - let config = test_config_with_secret_hex("11111111111111111111111111111111"); - let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); - let rng = SecureRandom::new(); - let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap(); - - let too_short = vec![0x16, 0x03, 0x01]; - - for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { - let result = handle_tls_handshake( - &too_short, - tokio::io::empty(), - tokio::io::sink(), - peer, - &config, - &replay_checker, - &rng, - None, - ) - .await; - assert!(matches!(result, HandshakeResult::BadClient { .. })); - } - - assert!( - auth_probe_fail_streak_for_testing(peer.ip()) - .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS), - "short TLS probe bursts must increase auth-probe fail streak" - ); -} diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index c93d18e..5263413 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1997,6 +1997,42 @@ fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer ); } +#[tokio::test] +async fn gap_t01_short_tls_probe_burst_is_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &too_short, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert!( + auth_probe_fail_streak_for_testing(peer.ip()) + .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS), + "short TLS probe bursts must increase auth-probe fail streak" + ); +} + #[test] fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() { let _guard = auth_probe_test_lock() diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index 896e465..b8ed52a 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -1,10 +1,11 @@ use super::*; +use crate::proxy::handshake::HandshakeSuccess; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use bytes::Bytes; use crate::crypto::AesCtr; use crate::crypto::SecureRandom; use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; use crate::network::probe::NetworkDecision; -use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::MePool; @@ -20,6 +21,7 @@ use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, timeout}; +use std::sync::{Mutex, OnceLock}; fn make_pooled_payload(data: &[u8]) -> PooledBuffer { let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); @@ -36,6 +38,11 @@ fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer payload } +fn quota_user_lock_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + #[test] fn should_yield_sender_only_on_budget_with_backlog() { assert!(!should_yield_c2me_sender(0, true)); @@ -244,6 +251,10 @@ fn quota_user_lock_cache_reuses_entry_for_same_user() { #[test] fn quota_user_lock_cache_is_bounded_under_unique_churn() { + let _guard = quota_user_lock_test_lock() + .lock() + .expect("quota user lock test lock must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); map.clear(); @@ -261,39 +272,51 @@ fn quota_user_lock_cache_is_bounded_under_unique_churn() { #[test] fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); + let _guard = quota_user_lock_test_lock() + .lock() + .expect("quota user lock test lock must be available"); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-held-user-{idx}"); - retained.push(quota_user_lock(&user)); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + for attempt in 0..8u32 { + map.clear(); + + let prefix = format!("quota-held-user-{}-{attempt}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + let user = format!("{prefix}-{idx}"); + retained.push(quota_user_lock(&user)); + } + + if map.len() != QUOTA_USER_LOCKS_MAX { + drop(retained); + continue; + } + + let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id()); + let overflow_a = quota_user_lock(&overflow_user); + let overflow_b = quota_user_lock(&overflow_user); + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow acquisition must not grow cache past hard limit" + ); + assert!( + map.get(&overflow_user).is_none(), + "overflow path should not cache new user lock when map is saturated and all entries are retained" + ); + assert!( + !Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user lock should be ephemeral under saturation to preserve bounded cache size" + ); + + drop(retained); + return; } - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: cache should be full before overflow acquisition" + panic!( + "unable to observe stable saturated lock-cache precondition after bounded retries" ); - - let overflow_a = quota_user_lock("quota-overflow-user"); - let overflow_b = quota_user_lock("quota-overflow-user"); - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow acquisition must not grow cache past hard limit" - ); - assert!( - map.get("quota-overflow-user").is_none(), - "overflow path should not cache new user lock when map is saturated and all entries are retained" - ); - assert!( - !Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user lock should be ephemeral under saturation to preserve bounded cache size" - ); - - drop(retained); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -2169,3 +2192,320 @@ async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea drop(client_sides); } + +#[tokio::test] +async fn secure_padding_distribution_in_relay_writer() { + timeout(TokioDuration::from_secs(10), async { + let (mut client_side, relay_side) = duplex(512 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = Arc::new(SecureRandom::new()); + let mut frame_buf = Vec::new(); + let mut decryptor = AesCtr::new(&key, iv); + + let mut padding_counts = [0usize; 4]; + let iterations = 180usize; + let payload = vec![0xAAu8; 100]; // 4-byte aligned + + for _ in 0..iterations { + write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("payload write must succeed"); + writer + .flush() + .await + .expect("writer flush must complete so encrypted frame becomes readable"); + + let mut len_buf = [0u8; 4]; + client_side + .read_exact(&mut len_buf) + .await + .expect("must read encrypted secure length"); + let decrypted_len_bytes = decryptor.decrypt(&len_buf); + let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes + .try_into() + .expect("decrypted length must be 4 bytes"); + let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize; + + assert!( + wire_len >= payload.len(), + "wire length must include at least payload bytes" + ); + let padding_len = wire_len - payload.len(); + assert!(padding_len >= 1 && padding_len <= 3); + padding_counts[padding_len] += 1; + + // Drain and decrypt frame bytes so CTR state stays aligned across writes. + let mut trash = vec![0u8; wire_len]; + client_side + .read_exact(&mut trash) + .await + .expect("must read encrypted secure frame body"); + let _ = decryptor.decrypt(&trash); + } + + for p in 1..=3 { + let count = padding_counts[p]; + assert!( + count > iterations / 8, + "padding length {p} is under-represented ({count}/{iterations})" + ); + } + }) + .await + .expect("secure padding distribution test exceeded runtime budget"); +} + +#[tokio::test] +async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { + let (client_reader_side, client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + + // Create an ME pool. + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + // ConnRegistry ids are monotonic; reserve one id so we can predict the + // next session conn_id and close it deterministically without relying on + // writer-bound views such as active_conn_ids(). + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + "127.0.0.1:443".parse().unwrap(), + rng.clone(), + route_runtime.subscribe(), + route_runtime.snapshot(), + 0x1234_5678, + )); + + // Wait until session startup is visible, then unregister the predicted + // conn_id to close the per-session ME response channel. + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("ME session must start before channel close simulation"); + + me_pool.registry().unregister(target_conn_id).await; + + drop(client_writer_side); + + let result = timeout(TokioDuration::from_secs(2), session_task) + .await + .expect("Session task must terminate after ME drop and client EOF") + .expect("Session task must not panic"); + + assert!( + result.is_ok(), + "Session should complete cleanly after ME drop when client closes, got: {:?}", + result + ); +} + +#[tokio::test] +async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + // Predict the next conn_id so we can force-drop its ME channel deterministically. + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: "test-user-cutover".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let runtime_clone = route_runtime.clone(); + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + runtime_clone.subscribe(), + runtime_clone.snapshot(), + 0xC001_CAFE, + )); + + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("ME session must start before race trigger"); + + // Race ME channel drop with route cutover and assert generic client-visible outcome. + me_pool.registry().unregister(target_conn_id).await; + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance generation" + ); + + let relay_result = timeout(TokioDuration::from_secs(6), session_task) + .await + .expect("session must terminate under ME-drop + cutover race") + .expect("session task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "race outcome must remain generic and not leak ME internals, got: {:?}", + relay_result + ); +} + +#[tokio::test] +async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + for round in 0..32u64 { + let (client_reader_side, client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: format!("stress-me-drop-eof-{round}"), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_runtime.snapshot(), + 0xD00D_0000 + round, + )); + + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("session must start before forced drop in burst round"); + + me_pool.registry().unregister(target_conn_id).await; + drop(client_writer_side); + + let result = timeout(TokioDuration::from_secs(2), session_task) + .await + .expect("burst round session must terminate quickly") + .expect("burst round session must not panic"); + + assert!( + result.is_ok(), + "burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}", + result + ); + } +} diff --git a/src/proxy/relay_security_tests.rs b/src/proxy/relay_security_tests.rs index 9ba8295..4b002a4 100644 --- a/src/proxy/relay_security_tests.rs +++ b/src/proxy/relay_security_tests.rs @@ -1140,3 +1140,46 @@ async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_err ); } } + +#[tokio::test] +async fn relay_half_close_keeps_reverse_direction_progressing() { + let stats = Arc::new(Stats::new()); + let user = "half-close-user"; + + let (client_peer, relay_client) = duplex(1024); + let (relay_server, server_peer) = duplex(1024); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 8192, + 8192, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + sp_writer.write_all(&[0x10, 0x20, 0x30, 0x40]).await.unwrap(); + sp_writer.shutdown().await.unwrap(); + + let mut inbound = [0u8; 4]; + cp_reader.read_exact(&mut inbound).await.unwrap(); + assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); + + cp_writer.write_all(&[0xaa, 0xbb, 0xcc, 0xdd]).await.unwrap(); + let mut outbound = [0u8; 4]; + sp_reader.read_exact(&mut outbound).await.unwrap(); + assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); +}