Add security tests for connection limits and handshake integrity

- Implement a test to ensure that exceeding the user connection limit does not leak the current connections counter.
- Add tests for direct relay connection refusal and adversarial scenarios to verify proper error handling.
- Introduce fuzz testing for MTProto handshake to ensure robustness against malformed inputs and replay attacks.
- Remove obsolete short TLS probe throttle tests and integrate their functionality into existing security tests.
- Enhance middle relay tests to validate behavior during connection drops and cutovers, ensuring graceful error handling.
- Add a test for half-close scenarios in relay to confirm bidirectional data flow continues as expected.
This commit is contained in:
David Osipov 2026-03-19 14:56:28 +04:00
parent 2a01ca2d6f
commit e6ad9e4c7f
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
11 changed files with 1198 additions and 91 deletions

View File

@ -544,6 +544,11 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
return None; return None;
} }
let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize;
if handshake.len() < 5 + record_len {
return None;
}
let mut pos = 5; // after record header let mut pos = 5; // after record header
if handshake.get(pos).copied()? != 0x01 { if handshake.get(pos).copied()? != 0x01 {
return None; // not ClientHello return None; // not ClientHello
@ -649,6 +654,15 @@ fn is_valid_sni_hostname(host: &str) -> bool {
/// Extract ALPN protocol list from ClientHello, return in offered order. /// Extract ALPN protocol list from ClientHello, return in offered order.
pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> { pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE {
return Vec::new();
}
let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize;
if handshake.len() < 5 + record_len {
return Vec::new();
}
let mut pos = 5; // after record header let mut pos = 5; // after record header
if handshake.get(pos) != Some(&0x01) { if handshake.get(pos) != Some(&0x01) {
return Vec::new(); return Vec::new();
@ -806,3 +820,7 @@ mod security_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tls_adversarial_tests.rs"] #[path = "tls_adversarial_tests.rs"]
mod adversarial_tests; mod adversarial_tests;
#[cfg(test)]
#[path = "tls_fuzz_security_tests.rs"]
mod fuzz_security_tests;

View File

@ -286,13 +286,26 @@ fn extract_sni_with_duplicate_extensions_rejected() {
ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes()); ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni2); ext.extend_from_slice(&sni2);
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x80, 0x01, 0x00, 0x00, 0x7C, 0x03, 0x03]; let mut body = Vec::new();
h.extend_from_slice(&[0u8; 32]); body.extend_from_slice(&[0x03, 0x03]);
h.push(0); body.extend_from_slice(&[0u8; 32]);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); body.push(0);
h.extend_from_slice(&[0x01, 0x00]); body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); body.extend_from_slice(&[0x01, 0x00]);
h.extend_from_slice(&ext); body.extend_from_slice(&(ext.len() as u16).to_be_bytes());
body.extend_from_slice(&ext);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut h = Vec::new();
h.push(0x16);
h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
h.extend_from_slice(&handshake);
// Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups. // Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups.
// Telemt's `extract_sni` returns the first one found. // Telemt's `extract_sni` returns the first one found.

View File

@ -0,0 +1,195 @@
use super::*;
use crate::crypto::sha256_hmac;
use std::panic::catch_unwind;
fn make_valid_tls_handshake_with_session_id(
secret: &[u8],
timestamp: u32,
session_id: &[u8],
) -> Vec<u8> {
let session_id_len = session_id.len();
assert!(session_id_len <= u8::MAX as usize);
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let mut digest = sha256_hmac(secret, &handshake);
let ts = timestamp.to_le_bytes();
for idx in 0..4 {
digest[28 + idx] ^= ts[idx];
}
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
handshake
}
fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(0);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
let host_bytes = host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record
}
#[test]
fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() {
let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]);
assert_eq!(extract_sni_from_client_hello(&valid).as_deref(), Some("example.com"));
assert_eq!(
extract_alpn_from_client_hello(&valid),
vec![b"h2".to_vec(), b"http/1.1".to_vec()]
);
assert!(
extract_sni_from_client_hello(&make_valid_client_hello_record("127.0.0.1", &[])).is_none(),
"literal IP hostnames must be rejected"
);
let mut corpus = vec![
Vec::new(),
vec![0x16, 0x03, 0x03],
valid[..9].to_vec(),
valid[..valid.len() - 1].to_vec(),
];
let mut wrong_type = valid.clone();
wrong_type[0] = 0x15;
corpus.push(wrong_type);
let mut wrong_handshake = valid.clone();
wrong_handshake[5] = 0x02;
corpus.push(wrong_handshake);
let mut wrong_length = valid.clone();
wrong_length[3] ^= 0x7f;
corpus.push(wrong_length);
for (idx, input) in corpus.iter().enumerate() {
assert!(catch_unwind(|| extract_sni_from_client_hello(input)).is_ok());
assert!(catch_unwind(|| extract_alpn_from_client_hello(input)).is_ok());
if idx == 0 {
continue;
}
assert!(extract_sni_from_client_hello(input).is_none(), "corpus item {idx} must fail closed for SNI");
assert!(extract_alpn_from_client_hello(input).is_empty(), "corpus item {idx} must fail closed for ALPN");
}
}
#[test]
fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() {
let secret = b"tls_fuzz_security_secret";
let now: i64 = 1_700_000_000;
let base = make_valid_tls_handshake_with_session_id(secret, now as u32, &[0x42; 32]);
let secrets = vec![("fuzz-user".to_string(), secret.to_vec())];
assert!(validate_tls_handshake_at_time(&base, &secrets, false, now).is_some());
let mut corpus = Vec::new();
let mut truncated = base.clone();
truncated.truncate(TLS_DIGEST_POS + 16);
corpus.push(truncated);
let mut digest_flip = base.clone();
digest_flip[TLS_DIGEST_POS + 7] ^= 0x80;
corpus.push(digest_flip);
let mut session_id_len_overflow = base.clone();
session_id_len_overflow[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 33;
corpus.push(session_id_len_overflow);
let mut timestamp_far_past = base.clone();
timestamp_far_past[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32]
.copy_from_slice(&((now - i64::from(TIME_SKEW_MAX) - 1) as u32).to_le_bytes());
corpus.push(timestamp_far_past);
let mut timestamp_far_future = base.clone();
timestamp_far_future[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32]
.copy_from_slice(&((now - TIME_SKEW_MIN + 1) as u32).to_le_bytes());
corpus.push(timestamp_far_future);
let mut seed = 0xA5A5_5A5A_F00D_BAAD_u64;
for _ in 0..32 {
let mut mutated = base.clone();
for _ in 0..2 {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN);
mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, handshake) in corpus.iter().enumerate() {
let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now));
assert!(result.is_ok(), "corpus item {idx} must not panic");
assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed");
}
}
#[test]
fn tls_boot_time_acceptance_is_capped_by_replay_window() {
let secret = b"tls_boot_time_cap_secret";
let secrets = vec![("boot-user".to_string(), secret.to_vec())];
let boot_ts = 1u32;
let handshake = make_valid_tls_handshake_with_session_id(secret, boot_ts, &[0x42; 32]);
assert!(
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 300).is_some(),
"boot-time timestamp should be accepted while replay window permits it"
);
assert!(
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 0).is_none(),
"boot-time timestamp must be rejected when replay window disables the bypass"
);
}

View File

@ -7,6 +7,7 @@ use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter}; use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use std::net::Ipv4Addr;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
@ -630,6 +631,54 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load(
} }
} }
#[tokio::test]
async fn reservation_limit_failure_does_not_leak_curr_connects_counter() {
let user = "leak-check-user";
let mut config = crate::config::ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
let stats = Arc::new(crate::stats::Stats::new());
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let first_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 1)), 50001);
let first = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
first_peer,
ip_tracker.clone(),
)
.await
.expect("first reservation must succeed");
assert_eq!(stats.get_user_curr_connects(user), 1);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
let second_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 2)), 50002);
let second = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
second_peer,
ip_tracker.clone(),
)
.await;
assert!(
matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user),
"second reservation must be rejected at the configured tcp-conns limit"
);
assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment");
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state");
first.release().await;
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test] #[tokio::test]
async fn short_tls_probe_is_masked_through_client_pipeline() { async fn short_tls_probe_is_masked_through_client_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

View File

@ -12,8 +12,10 @@ use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration; use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{timeout, Duration as TokioDuration};
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R> fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where where
@ -1322,3 +1324,194 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() {
assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred");
} }
} }
#[tokio::test]
async fn negative_direct_relay_dc_connection_refused_fails_fast() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_client_reader_relay, client_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
// Reserve an ephemeral port and immediately release it to deterministically
// exercise the direct-connect failure path without long-lived hangs.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let dc_addr = listener.local_addr().unwrap();
drop(listener);
let mut config_with_override = ProxyConfig::default();
config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
enabled: true,
weight: 1,
scopes: String::new(),
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
selected_scope: String::new(),
}],
1,
100,
5000,
3,
false,
stats.clone(),
));
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let result = timeout(
TokioDuration::from_secs(2),
handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
route_runtime.subscribe(),
route_runtime.snapshot(),
0xABCD_1234,
),
)
.await
.expect("direct relay must fail fast on connection-refused upstream");
assert!(
result.is_err(),
"connection-refused upstream must fail closed"
);
}
#[tokio::test]
async fn adversarial_direct_relay_cutover_integrity() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_client_reader_relay, client_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
// Mock upstream server.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let dc_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
// Read handshake nonce.
let mut nonce = [0u8; 64];
let _ = stream.read_exact(&mut nonce).await;
// Keep connection open.
tokio::time::sleep(TokioDuration::from_secs(5)).await;
});
let mut config_with_override = ProxyConfig::default();
config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
enabled: true,
weight: 1,
scopes: String::new(),
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
selected_scope: String::new(),
}],
1,
100,
5000,
3,
false,
stats.clone(),
));
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let stats_for_task = stats.clone();
let runtime_clone = route_runtime.clone();
let session_task = tokio::spawn(async move {
handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats_for_task,
config,
buffer_pool,
rng,
runtime_clone.subscribe(),
runtime_clone.snapshot(),
0xABCD_1234,
).await
});
timeout(TokioDuration::from_secs(2), async {
loop {
if stats.get_current_connections_direct() == 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("direct relay session must start before cutover");
// Trigger cutover.
route_runtime.set_mode(RelayRouteMode::Middle).unwrap();
// The session should terminate after the staggered delay (1000-2000ms).
let result = timeout(TokioDuration::from_secs(5), session_task)
.await
.expect("Session must terminate after cutover")
.expect("Session must not panic");
assert!(
matches!(
result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"Session must terminate with route switch error on cutover"
);
}

View File

@ -967,14 +967,14 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
#[path = "handshake_security_tests.rs"] #[path = "handshake_security_tests.rs"]
mod security_tests; mod security_tests;
#[cfg(test)]
#[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"]
mod gap_short_tls_probe_throttle_security_tests;
#[cfg(test)] #[cfg(test)]
#[path = "handshake_adversarial_tests.rs"] #[path = "handshake_adversarial_tests.rs"]
mod adversarial_tests; mod adversarial_tests;
#[cfg(test)]
#[path = "handshake_fuzz_security_tests.rs"]
mod fuzz_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication, /// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee. /// undermining the zeroize-on-drop guarantee.

View File

@ -0,0 +1,270 @@
use super::*;
use crate::config::ProxyConfig;
use crate::crypto::AesCtr;
use crate::crypto::sha256;
use crate::protocol::constants::ProtoTag;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::sync::MutexGuard;
use tokio::time::{timeout, Duration as TokioDuration};
fn make_mtproto_handshake_with_proto_bytes(
secret_hex: &str,
proto_bytes: [u8; 4],
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_bytes);
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] {
make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx)
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access.users.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg
}
fn auth_probe_test_guard() -> MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[tokio::test]
async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "11223344556677889900aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
let first = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(first, HandshakeResult::Success(_)));
let second = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(second, HandshakeResult::BadClient { .. }));
clear_auth_probe_state_for_testing();
}
#[tokio::test]
async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "00112233445566778899aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap();
let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new();
corpus.push(make_mtproto_handshake_with_proto_bytes(
secret_hex,
[0x00, 0x00, 0x00, 0x00],
1,
));
corpus.push(make_mtproto_handshake_with_proto_bytes(
secret_hex,
[0xff, 0xff, 0xff, 0xff],
1,
));
corpus.push(make_valid_mtproto_handshake(
"ffeeddccbbaa99887766554433221100",
ProtoTag::Secure,
1,
));
let mut seed = 0xF0F0_F00D_BAAD_CAFEu64;
for _ in 0..32 {
let mut mutated = base;
for _ in 0..4 {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, input) in corpus.into_iter().enumerate() {
let result = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&input,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed handshake must complete in time");
assert!(
matches!(result, HandshakeResult::BadClient { .. }),
"corpus item {idx} must fail closed"
);
}
clear_auth_probe_state_for_testing();
}
#[tokio::test]
async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "99887766554433221100ffeeddccbbaa";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(256, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.44:45444".parse().unwrap();
let first = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("base handshake must not hang");
assert!(matches!(first, HandshakeResult::Success(_)));
let replay = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("duplicate handshake must not hang");
assert!(matches!(replay, HandshakeResult::BadClient { .. }));
let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new();
let mut prekey_flip = base;
prekey_flip[SKIP_LEN] ^= 0x80;
corpus.push(prekey_flip);
let mut iv_flip = base;
iv_flip[SKIP_LEN + PREKEY_LEN] ^= 0x01;
corpus.push(iv_flip);
let mut tail_flip = base;
tail_flip[SKIP_LEN + PREKEY_LEN + IV_LEN - 1] ^= 0x40;
corpus.push(tail_flip);
let mut seed = 0xBADC_0FFE_EE11_4242u64;
for _ in 0..24 {
let mut mutated = base;
for _ in 0..3 {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 16) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, input) in corpus.iter().enumerate() {
let result = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
input,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed handshake must complete in time");
assert!(
matches!(result, HandshakeResult::BadClient { .. }),
"mixed corpus item {idx} must fail closed"
);
}
clear_auth_probe_state_for_testing();
}

View File

@ -1,50 +0,0 @@
use super::*;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::time::Duration;
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg
}
#[tokio::test]
async fn gap_t01_short_tls_probe_burst_is_throttled() {
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
let too_short = vec![0x16, 0x03, 0x01];
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&too_short,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
assert!(
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"short TLS probe bursts must increase auth-probe fail streak"
);
}

View File

@ -1997,6 +1997,42 @@ fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer
); );
} }
#[tokio::test]
async fn gap_t01_short_tls_probe_burst_is_throttled() {
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
let too_short = vec![0x16, 0x03, 0x01];
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&too_short,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
assert!(
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"short TLS probe bursts must increase auth-probe fail streak"
);
}
#[test] #[test]
fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() { fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() {
let _guard = auth_probe_test_lock() let _guard = auth_probe_test_lock()

View File

@ -1,10 +1,11 @@
use super::*; use super::*;
use crate::proxy::handshake::HandshakeSuccess;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use bytes::Bytes; use bytes::Bytes;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::network::probe::NetworkDecision; use crate::network::probe::NetworkDecision;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
@ -20,6 +21,7 @@ use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, timeout}; use tokio::time::{Duration as TokioDuration, timeout};
use std::sync::{Mutex, OnceLock};
fn make_pooled_payload(data: &[u8]) -> PooledBuffer { fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
@ -36,6 +38,11 @@ fn make_pooled_payload_from(pool: &Arc<BufferPool>, data: &[u8]) -> PooledBuffer
payload payload
} }
fn quota_user_lock_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[test] #[test]
fn should_yield_sender_only_on_budget_with_backlog() { fn should_yield_sender_only_on_budget_with_backlog() {
assert!(!should_yield_c2me_sender(0, true)); assert!(!should_yield_c2me_sender(0, true));
@ -244,6 +251,10 @@ fn quota_user_lock_cache_reuses_entry_for_same_user() {
#[test] #[test]
fn quota_user_lock_cache_is_bounded_under_unique_churn() { fn quota_user_lock_cache_is_bounded_under_unique_churn() {
let _guard = quota_user_lock_test_lock()
.lock()
.expect("quota user lock test lock must be available");
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear(); map.clear();
@ -261,23 +272,29 @@ fn quota_user_lock_cache_is_bounded_under_unique_churn() {
#[test] #[test]
fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
let _guard = quota_user_lock_test_lock()
.lock()
.expect("quota user lock test lock must be available");
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
for attempt in 0..8u32 {
map.clear(); map.clear();
let prefix = format!("quota-held-user-{}-{attempt}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX { for idx in 0..QUOTA_USER_LOCKS_MAX {
let user = format!("quota-held-user-{idx}"); let user = format!("{prefix}-{idx}");
retained.push(quota_user_lock(&user)); retained.push(quota_user_lock(&user));
} }
assert_eq!( if map.len() != QUOTA_USER_LOCKS_MAX {
map.len(), drop(retained);
QUOTA_USER_LOCKS_MAX, continue;
"precondition: cache should be full before overflow acquisition" }
);
let overflow_a = quota_user_lock("quota-overflow-user"); let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id());
let overflow_b = quota_user_lock("quota-overflow-user"); let overflow_a = quota_user_lock(&overflow_user);
let overflow_b = quota_user_lock(&overflow_user);
assert_eq!( assert_eq!(
map.len(), map.len(),
@ -285,7 +302,7 @@ fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
"overflow acquisition must not grow cache past hard limit" "overflow acquisition must not grow cache past hard limit"
); );
assert!( assert!(
map.get("quota-overflow-user").is_none(), map.get(&overflow_user).is_none(),
"overflow path should not cache new user lock when map is saturated and all entries are retained" "overflow path should not cache new user lock when map is saturated and all entries are retained"
); );
assert!( assert!(
@ -294,6 +311,12 @@ fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
); );
drop(retained); drop(retained);
return;
}
panic!(
"unable to observe stable saturated lock-cache precondition after bounded retries"
);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 4)] #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@ -2169,3 +2192,320 @@ async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea
drop(client_sides); drop(client_sides);
} }
#[tokio::test]
async fn secure_padding_distribution_in_relay_writer() {
timeout(TokioDuration::from_secs(10), async {
let (mut client_side, relay_side) = duplex(512 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = Arc::new(SecureRandom::new());
let mut frame_buf = Vec::new();
let mut decryptor = AesCtr::new(&key, iv);
let mut padding_counts = [0usize; 4];
let iterations = 180usize;
let payload = vec![0xAAu8; 100]; // 4-byte aligned
for _ in 0..iterations {
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload write must succeed");
writer
.flush()
.await
.expect("writer flush must complete so encrypted frame becomes readable");
let mut len_buf = [0u8; 4];
client_side
.read_exact(&mut len_buf)
.await
.expect("must read encrypted secure length");
let decrypted_len_bytes = decryptor.decrypt(&len_buf);
let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes
.try_into()
.expect("decrypted length must be 4 bytes");
let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize;
assert!(
wire_len >= payload.len(),
"wire length must include at least payload bytes"
);
let padding_len = wire_len - payload.len();
assert!(padding_len >= 1 && padding_len <= 3);
padding_counts[padding_len] += 1;
// Drain and decrypt frame bytes so CTR state stays aligned across writes.
let mut trash = vec![0u8; wire_len];
client_side
.read_exact(&mut trash)
.await
.expect("must read encrypted secure frame body");
let _ = decryptor.decrypt(&trash);
}
for p in 1..=3 {
let count = padding_counts[p];
assert!(
count > iterations / 8,
"padding length {p} is under-represented ({count}/{iterations})"
);
}
})
.await
.expect("secure padding distribution test exceeded runtime budget");
}
#[tokio::test]
async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() {
let (client_reader_side, client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle);
// Create an ME pool.
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
// ConnRegistry ids are monotonic; reserve one id so we can predict the
// next session conn_id and close it deterministically without relying on
// writer-bound views such as active_conn_ids().
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config.clone(),
buffer_pool.clone(),
"127.0.0.1:443".parse().unwrap(),
rng.clone(),
route_runtime.subscribe(),
route_runtime.snapshot(),
0x1234_5678,
));
// Wait until session startup is visible, then unregister the predicted
// conn_id to close the per-session ME response channel.
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("ME session must start before channel close simulation");
me_pool.registry().unregister(target_conn_id).await;
drop(client_writer_side);
let result = timeout(TokioDuration::from_secs(2), session_task)
.await
.expect("Session task must terminate after ME drop and client EOF")
.expect("Session task must not panic");
assert!(
result.is_ok(),
"Session should complete cleanly after ME drop when client closes, got: {:?}",
result
);
}
#[tokio::test]
async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
// Predict the next conn_id so we can force-drop its ME channel deterministically.
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: "test-user-cutover".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let runtime_clone = route_runtime.clone();
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config,
buffer_pool,
"127.0.0.1:443".parse().unwrap(),
rng,
runtime_clone.subscribe(),
runtime_clone.snapshot(),
0xC001_CAFE,
));
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("ME session must start before race trigger");
// Race ME channel drop with route cutover and assert generic client-visible outcome.
me_pool.registry().unregister(target_conn_id).await;
assert!(
route_runtime.set_mode(RelayRouteMode::Direct).is_some(),
"cutover must advance generation"
);
let relay_result = timeout(TokioDuration::from_secs(6), session_task)
.await
.expect("session must terminate under ME-drop + cutover race")
.expect("session task must not panic");
assert!(
matches!(
relay_result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"race outcome must remain generic and not leak ME internals, got: {:?}",
relay_result
);
}
#[tokio::test]
async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() {
let stats = Arc::new(Stats::new());
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
for round in 0..32u64 {
let (client_reader_side, client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle);
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: format!("stress-me-drop-eof-{round}"),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config,
buffer_pool,
"127.0.0.1:443".parse().unwrap(),
rng,
route_runtime.subscribe(),
route_runtime.snapshot(),
0xD00D_0000 + round,
));
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("session must start before forced drop in burst round");
me_pool.registry().unregister(target_conn_id).await;
drop(client_writer_side);
let result = timeout(TokioDuration::from_secs(2), session_task)
.await
.expect("burst round session must terminate quickly")
.expect("burst round session must not panic");
assert!(
result.is_ok(),
"burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}",
result
);
}
}

View File

@ -1140,3 +1140,46 @@ async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_err
); );
} }
} }
#[tokio::test]
async fn relay_half_close_keeps_reverse_direction_progressing() {
let stats = Arc::new(Stats::new());
let user = "half-close-user";
let (client_peer, relay_client) = duplex(1024);
let (relay_server, server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer);
let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
8192,
8192,
user,
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
));
sp_writer.write_all(&[0x10, 0x20, 0x30, 0x40]).await.unwrap();
sp_writer.shutdown().await.unwrap();
let mut inbound = [0u8; 4];
cp_reader.read_exact(&mut inbound).await.unwrap();
assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]);
cp_writer.write_all(&[0xaa, 0xbb, 0xcc, 0xdd]).await.unwrap();
let mut outbound = [0u8; 4];
sp_reader.read_exact(&mut outbound).await.unwrap();
assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]);
relay_task.abort();
let joined = relay_task.await;
assert!(joined.is_err(), "aborted relay task must return join error");
}