Add security tests for connection limits and handshake integrity

- Implement a test to ensure that exceeding the user connection limit does not leak the current connections counter.
- Add tests for direct relay connection refusal and adversarial scenarios to verify proper error handling.
- Introduce fuzz testing for MTProto handshake to ensure robustness against malformed inputs and replay attacks.
- Remove obsolete short TLS probe throttle tests and integrate their functionality into existing security tests.
- Enhance middle relay tests to validate behavior during connection drops and cutovers, ensuring graceful error handling.
- Add a test for half-close scenarios in relay to confirm bidirectional data flow continues as expected.
This commit is contained in:
David Osipov 2026-03-19 14:56:28 +04:00
parent 2a01ca2d6f
commit e6ad9e4c7f
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
11 changed files with 1198 additions and 91 deletions

View File

@ -544,6 +544,11 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
return None;
}
let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize;
if handshake.len() < 5 + record_len {
return None;
}
let mut pos = 5; // after record header
if handshake.get(pos).copied()? != 0x01 {
return None; // not ClientHello
@ -649,6 +654,15 @@ fn is_valid_sni_hostname(host: &str) -> bool {
/// Extract ALPN protocol list from ClientHello, return in offered order.
pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE {
return Vec::new();
}
let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize;
if handshake.len() < 5 + record_len {
return Vec::new();
}
let mut pos = 5; // after record header
if handshake.get(pos) != Some(&0x01) {
return Vec::new();
@ -806,3 +820,7 @@ mod security_tests;
#[cfg(test)]
#[path = "tls_adversarial_tests.rs"]
mod adversarial_tests;
#[cfg(test)]
#[path = "tls_fuzz_security_tests.rs"]
mod fuzz_security_tests;

View File

@ -286,13 +286,26 @@ fn extract_sni_with_duplicate_extensions_rejected() {
ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni2);
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x80, 0x01, 0x00, 0x00, 0x7C, 0x03, 0x03];
h.extend_from_slice(&[0u8; 32]);
h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
h.extend_from_slice(&[0x01, 0x00]);
h.extend_from_slice(&(ext.len() as u16).to_be_bytes());
h.extend_from_slice(&ext);
let mut body = Vec::new();
body.extend_from_slice(&[0x03, 0x03]);
body.extend_from_slice(&[0u8; 32]);
body.push(0);
body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
body.extend_from_slice(&[0x01, 0x00]);
body.extend_from_slice(&(ext.len() as u16).to_be_bytes());
body.extend_from_slice(&ext);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut h = Vec::new();
h.push(0x16);
h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
h.extend_from_slice(&handshake);
// Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups.
// Telemt's `extract_sni` returns the first one found.

View File

@ -0,0 +1,195 @@
use super::*;
use crate::crypto::sha256_hmac;
use std::panic::catch_unwind;
fn make_valid_tls_handshake_with_session_id(
secret: &[u8],
timestamp: u32,
session_id: &[u8],
) -> Vec<u8> {
let session_id_len = session_id.len();
assert!(session_id_len <= u8::MAX as usize);
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let mut digest = sha256_hmac(secret, &handshake);
let ts = timestamp.to_le_bytes();
for idx in 0..4 {
digest[28 + idx] ^= ts[idx];
}
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
handshake
}
fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(0);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
let host_bytes = host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record
}
#[test]
fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() {
let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]);
assert_eq!(extract_sni_from_client_hello(&valid).as_deref(), Some("example.com"));
assert_eq!(
extract_alpn_from_client_hello(&valid),
vec![b"h2".to_vec(), b"http/1.1".to_vec()]
);
assert!(
extract_sni_from_client_hello(&make_valid_client_hello_record("127.0.0.1", &[])).is_none(),
"literal IP hostnames must be rejected"
);
let mut corpus = vec![
Vec::new(),
vec![0x16, 0x03, 0x03],
valid[..9].to_vec(),
valid[..valid.len() - 1].to_vec(),
];
let mut wrong_type = valid.clone();
wrong_type[0] = 0x15;
corpus.push(wrong_type);
let mut wrong_handshake = valid.clone();
wrong_handshake[5] = 0x02;
corpus.push(wrong_handshake);
let mut wrong_length = valid.clone();
wrong_length[3] ^= 0x7f;
corpus.push(wrong_length);
for (idx, input) in corpus.iter().enumerate() {
assert!(catch_unwind(|| extract_sni_from_client_hello(input)).is_ok());
assert!(catch_unwind(|| extract_alpn_from_client_hello(input)).is_ok());
if idx == 0 {
continue;
}
assert!(extract_sni_from_client_hello(input).is_none(), "corpus item {idx} must fail closed for SNI");
assert!(extract_alpn_from_client_hello(input).is_empty(), "corpus item {idx} must fail closed for ALPN");
}
}
#[test]
fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() {
let secret = b"tls_fuzz_security_secret";
let now: i64 = 1_700_000_000;
let base = make_valid_tls_handshake_with_session_id(secret, now as u32, &[0x42; 32]);
let secrets = vec![("fuzz-user".to_string(), secret.to_vec())];
assert!(validate_tls_handshake_at_time(&base, &secrets, false, now).is_some());
let mut corpus = Vec::new();
let mut truncated = base.clone();
truncated.truncate(TLS_DIGEST_POS + 16);
corpus.push(truncated);
let mut digest_flip = base.clone();
digest_flip[TLS_DIGEST_POS + 7] ^= 0x80;
corpus.push(digest_flip);
let mut session_id_len_overflow = base.clone();
session_id_len_overflow[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 33;
corpus.push(session_id_len_overflow);
let mut timestamp_far_past = base.clone();
timestamp_far_past[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32]
.copy_from_slice(&((now - i64::from(TIME_SKEW_MAX) - 1) as u32).to_le_bytes());
corpus.push(timestamp_far_past);
let mut timestamp_far_future = base.clone();
timestamp_far_future[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32]
.copy_from_slice(&((now - TIME_SKEW_MIN + 1) as u32).to_le_bytes());
corpus.push(timestamp_far_future);
let mut seed = 0xA5A5_5A5A_F00D_BAAD_u64;
for _ in 0..32 {
let mut mutated = base.clone();
for _ in 0..2 {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN);
mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, handshake) in corpus.iter().enumerate() {
let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now));
assert!(result.is_ok(), "corpus item {idx} must not panic");
assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed");
}
}
#[test]
fn tls_boot_time_acceptance_is_capped_by_replay_window() {
let secret = b"tls_boot_time_cap_secret";
let secrets = vec![("boot-user".to_string(), secret.to_vec())];
let boot_ts = 1u32;
let handshake = make_valid_tls_handshake_with_session_id(secret, boot_ts, &[0x42; 32]);
assert!(
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 300).is_some(),
"boot-time timestamp should be accepted while replay window permits it"
);
assert!(
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 0).is_none(),
"boot-time timestamp must be rejected when replay window disables the bypass"
);
}

View File

@ -7,6 +7,7 @@ use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use std::net::Ipv4Addr;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
@ -630,6 +631,54 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load(
}
}
#[tokio::test]
async fn reservation_limit_failure_does_not_leak_curr_connects_counter() {
let user = "leak-check-user";
let mut config = crate::config::ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
let stats = Arc::new(crate::stats::Stats::new());
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let first_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 1)), 50001);
let first = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
first_peer,
ip_tracker.clone(),
)
.await
.expect("first reservation must succeed");
assert_eq!(stats.get_user_curr_connects(user), 1);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
let second_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 2)), 50002);
let second = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
second_peer,
ip_tracker.clone(),
)
.await;
assert!(
matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user),
"second reservation must be rejected at the configured tcp-conns limit"
);
assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment");
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state");
first.release().await;
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn short_tls_probe_is_masked_through_client_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

View File

@ -12,8 +12,10 @@ use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{timeout, Duration as TokioDuration};
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where
@ -1322,3 +1324,194 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() {
assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred");
}
}
#[tokio::test]
async fn negative_direct_relay_dc_connection_refused_fails_fast() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_client_reader_relay, client_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
// Reserve an ephemeral port and immediately release it to deterministically
// exercise the direct-connect failure path without long-lived hangs.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let dc_addr = listener.local_addr().unwrap();
drop(listener);
let mut config_with_override = ProxyConfig::default();
config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
enabled: true,
weight: 1,
scopes: String::new(),
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
selected_scope: String::new(),
}],
1,
100,
5000,
3,
false,
stats.clone(),
));
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let result = timeout(
TokioDuration::from_secs(2),
handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
route_runtime.subscribe(),
route_runtime.snapshot(),
0xABCD_1234,
),
)
.await
.expect("direct relay must fail fast on connection-refused upstream");
assert!(
result.is_err(),
"connection-refused upstream must fail closed"
);
}
#[tokio::test]
async fn adversarial_direct_relay_cutover_integrity() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_client_reader_relay, client_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
// Mock upstream server.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let dc_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
// Read handshake nonce.
let mut nonce = [0u8; 64];
let _ = stream.read_exact(&mut nonce).await;
// Keep connection open.
tokio::time::sleep(TokioDuration::from_secs(5)).await;
});
let mut config_with_override = ProxyConfig::default();
config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
enabled: true,
weight: 1,
scopes: String::new(),
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
selected_scope: String::new(),
}],
1,
100,
5000,
3,
false,
stats.clone(),
));
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let stats_for_task = stats.clone();
let runtime_clone = route_runtime.clone();
let session_task = tokio::spawn(async move {
handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats_for_task,
config,
buffer_pool,
rng,
runtime_clone.subscribe(),
runtime_clone.snapshot(),
0xABCD_1234,
).await
});
timeout(TokioDuration::from_secs(2), async {
loop {
if stats.get_current_connections_direct() == 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("direct relay session must start before cutover");
// Trigger cutover.
route_runtime.set_mode(RelayRouteMode::Middle).unwrap();
// The session should terminate after the staggered delay (1000-2000ms).
let result = timeout(TokioDuration::from_secs(5), session_task)
.await
.expect("Session must terminate after cutover")
.expect("Session must not panic");
assert!(
matches!(
result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"Session must terminate with route switch error on cutover"
);
}

View File

@ -967,14 +967,14 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
#[path = "handshake_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"]
mod gap_short_tls_probe_throttle_security_tests;
#[cfg(test)]
#[path = "handshake_adversarial_tests.rs"]
mod adversarial_tests;
#[cfg(test)]
#[path = "handshake_fuzz_security_tests.rs"]
mod fuzz_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.

View File

@ -0,0 +1,270 @@
use super::*;
use crate::config::ProxyConfig;
use crate::crypto::AesCtr;
use crate::crypto::sha256;
use crate::protocol::constants::ProtoTag;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::sync::MutexGuard;
use tokio::time::{timeout, Duration as TokioDuration};
fn make_mtproto_handshake_with_proto_bytes(
secret_hex: &str,
proto_bytes: [u8; 4],
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_bytes);
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] {
make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx)
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access.users.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg
}
fn auth_probe_test_guard() -> MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[tokio::test]
async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "11223344556677889900aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
let first = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(first, HandshakeResult::Success(_)));
let second = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(second, HandshakeResult::BadClient { .. }));
clear_auth_probe_state_for_testing();
}
#[tokio::test]
async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "00112233445566778899aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap();
let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new();
corpus.push(make_mtproto_handshake_with_proto_bytes(
secret_hex,
[0x00, 0x00, 0x00, 0x00],
1,
));
corpus.push(make_mtproto_handshake_with_proto_bytes(
secret_hex,
[0xff, 0xff, 0xff, 0xff],
1,
));
corpus.push(make_valid_mtproto_handshake(
"ffeeddccbbaa99887766554433221100",
ProtoTag::Secure,
1,
));
let mut seed = 0xF0F0_F00D_BAAD_CAFEu64;
for _ in 0..32 {
let mut mutated = base;
for _ in 0..4 {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, input) in corpus.into_iter().enumerate() {
let result = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&input,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed handshake must complete in time");
assert!(
matches!(result, HandshakeResult::BadClient { .. }),
"corpus item {idx} must fail closed"
);
}
clear_auth_probe_state_for_testing();
}
#[tokio::test]
async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "99887766554433221100ffeeddccbbaa";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(256, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.44:45444".parse().unwrap();
let first = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("base handshake must not hang");
assert!(matches!(first, HandshakeResult::Success(_)));
let replay = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("duplicate handshake must not hang");
assert!(matches!(replay, HandshakeResult::BadClient { .. }));
let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new();
let mut prekey_flip = base;
prekey_flip[SKIP_LEN] ^= 0x80;
corpus.push(prekey_flip);
let mut iv_flip = base;
iv_flip[SKIP_LEN + PREKEY_LEN] ^= 0x01;
corpus.push(iv_flip);
let mut tail_flip = base;
tail_flip[SKIP_LEN + PREKEY_LEN + IV_LEN - 1] ^= 0x40;
corpus.push(tail_flip);
let mut seed = 0xBADC_0FFE_EE11_4242u64;
for _ in 0..24 {
let mut mutated = base;
for _ in 0..3 {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 16) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, input) in corpus.iter().enumerate() {
let result = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
input,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed handshake must complete in time");
assert!(
matches!(result, HandshakeResult::BadClient { .. }),
"mixed corpus item {idx} must fail closed"
);
}
clear_auth_probe_state_for_testing();
}

View File

@ -1,50 +0,0 @@
use super::*;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::time::Duration;
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg
}
#[tokio::test]
async fn gap_t01_short_tls_probe_burst_is_throttled() {
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
let too_short = vec![0x16, 0x03, 0x01];
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&too_short,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
assert!(
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"short TLS probe bursts must increase auth-probe fail streak"
);
}

View File

@ -1997,6 +1997,42 @@ fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer
);
}
#[tokio::test]
async fn gap_t01_short_tls_probe_burst_is_throttled() {
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
let too_short = vec![0x16, 0x03, 0x01];
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&too_short,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
assert!(
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"short TLS probe bursts must increase auth-probe fail streak"
);
}
#[test]
fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() {
let _guard = auth_probe_test_lock()

View File

@ -1,10 +1,11 @@
use super::*;
use crate::proxy::handshake::HandshakeSuccess;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use bytes::Bytes;
use crate::crypto::AesCtr;
use crate::crypto::SecureRandom;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::network::probe::NetworkDecision;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::MePool;
@ -20,6 +21,7 @@ use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, timeout};
use std::sync::{Mutex, OnceLock};
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
@ -36,6 +38,11 @@ fn make_pooled_payload_from(pool: &Arc<BufferPool>, data: &[u8]) -> PooledBuffer
payload
}
fn quota_user_lock_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[test]
fn should_yield_sender_only_on_budget_with_backlog() {
assert!(!should_yield_c2me_sender(0, true));
@ -244,6 +251,10 @@ fn quota_user_lock_cache_reuses_entry_for_same_user() {
#[test]
fn quota_user_lock_cache_is_bounded_under_unique_churn() {
let _guard = quota_user_lock_test_lock()
.lock()
.expect("quota user lock test lock must be available");
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
@ -261,39 +272,51 @@ fn quota_user_lock_cache_is_bounded_under_unique_churn() {
#[test]
fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let _guard = quota_user_lock_test_lock()
.lock()
.expect("quota user lock test lock must be available");
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
let user = format!("quota-held-user-{idx}");
retained.push(quota_user_lock(&user));
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
for attempt in 0..8u32 {
map.clear();
let prefix = format!("quota-held-user-{}-{attempt}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
let user = format!("{prefix}-{idx}");
retained.push(quota_user_lock(&user));
}
if map.len() != QUOTA_USER_LOCKS_MAX {
drop(retained);
continue;
}
let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id());
let overflow_a = quota_user_lock(&overflow_user);
let overflow_b = quota_user_lock(&overflow_user);
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"overflow acquisition must not grow cache past hard limit"
);
assert!(
map.get(&overflow_user).is_none(),
"overflow path should not cache new user lock when map is saturated and all entries are retained"
);
assert!(
!Arc::ptr_eq(&overflow_a, &overflow_b),
"overflow user lock should be ephemeral under saturation to preserve bounded cache size"
);
drop(retained);
return;
}
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"precondition: cache should be full before overflow acquisition"
panic!(
"unable to observe stable saturated lock-cache precondition after bounded retries"
);
let overflow_a = quota_user_lock("quota-overflow-user");
let overflow_b = quota_user_lock("quota-overflow-user");
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"overflow acquisition must not grow cache past hard limit"
);
assert!(
map.get("quota-overflow-user").is_none(),
"overflow path should not cache new user lock when map is saturated and all entries are retained"
);
assert!(
!Arc::ptr_eq(&overflow_a, &overflow_b),
"overflow user lock should be ephemeral under saturation to preserve bounded cache size"
);
drop(retained);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@ -2169,3 +2192,320 @@ async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea
drop(client_sides);
}
#[tokio::test]
async fn secure_padding_distribution_in_relay_writer() {
timeout(TokioDuration::from_secs(10), async {
let (mut client_side, relay_side) = duplex(512 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = Arc::new(SecureRandom::new());
let mut frame_buf = Vec::new();
let mut decryptor = AesCtr::new(&key, iv);
let mut padding_counts = [0usize; 4];
let iterations = 180usize;
let payload = vec![0xAAu8; 100]; // 4-byte aligned
for _ in 0..iterations {
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload write must succeed");
writer
.flush()
.await
.expect("writer flush must complete so encrypted frame becomes readable");
let mut len_buf = [0u8; 4];
client_side
.read_exact(&mut len_buf)
.await
.expect("must read encrypted secure length");
let decrypted_len_bytes = decryptor.decrypt(&len_buf);
let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes
.try_into()
.expect("decrypted length must be 4 bytes");
let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize;
assert!(
wire_len >= payload.len(),
"wire length must include at least payload bytes"
);
let padding_len = wire_len - payload.len();
assert!(padding_len >= 1 && padding_len <= 3);
padding_counts[padding_len] += 1;
// Drain and decrypt frame bytes so CTR state stays aligned across writes.
let mut trash = vec![0u8; wire_len];
client_side
.read_exact(&mut trash)
.await
.expect("must read encrypted secure frame body");
let _ = decryptor.decrypt(&trash);
}
for p in 1..=3 {
let count = padding_counts[p];
assert!(
count > iterations / 8,
"padding length {p} is under-represented ({count}/{iterations})"
);
}
})
.await
.expect("secure padding distribution test exceeded runtime budget");
}
#[tokio::test]
async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() {
let (client_reader_side, client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle);
// Create an ME pool.
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
// ConnRegistry ids are monotonic; reserve one id so we can predict the
// next session conn_id and close it deterministically without relying on
// writer-bound views such as active_conn_ids().
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config.clone(),
buffer_pool.clone(),
"127.0.0.1:443".parse().unwrap(),
rng.clone(),
route_runtime.subscribe(),
route_runtime.snapshot(),
0x1234_5678,
));
// Wait until session startup is visible, then unregister the predicted
// conn_id to close the per-session ME response channel.
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("ME session must start before channel close simulation");
me_pool.registry().unregister(target_conn_id).await;
drop(client_writer_side);
let result = timeout(TokioDuration::from_secs(2), session_task)
.await
.expect("Session task must terminate after ME drop and client EOF")
.expect("Session task must not panic");
assert!(
result.is_ok(),
"Session should complete cleanly after ME drop when client closes, got: {:?}",
result
);
}
#[tokio::test]
async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
// Predict the next conn_id so we can force-drop its ME channel deterministically.
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: "test-user-cutover".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let runtime_clone = route_runtime.clone();
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config,
buffer_pool,
"127.0.0.1:443".parse().unwrap(),
rng,
runtime_clone.subscribe(),
runtime_clone.snapshot(),
0xC001_CAFE,
));
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("ME session must start before race trigger");
// Race ME channel drop with route cutover and assert generic client-visible outcome.
me_pool.registry().unregister(target_conn_id).await;
assert!(
route_runtime.set_mode(RelayRouteMode::Direct).is_some(),
"cutover must advance generation"
);
let relay_result = timeout(TokioDuration::from_secs(6), session_task)
.await
.expect("session must terminate under ME-drop + cutover race")
.expect("session task must not panic");
assert!(
matches!(
relay_result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"race outcome must remain generic and not leak ME internals, got: {:?}",
relay_result
);
}
#[tokio::test]
async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() {
let stats = Arc::new(Stats::new());
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
for round in 0..32u64 {
let (client_reader_side, client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle);
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: format!("stress-me-drop-eof-{round}"),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config,
buffer_pool,
"127.0.0.1:443".parse().unwrap(),
rng,
route_runtime.subscribe(),
route_runtime.snapshot(),
0xD00D_0000 + round,
));
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("session must start before forced drop in burst round");
me_pool.registry().unregister(target_conn_id).await;
drop(client_writer_side);
let result = timeout(TokioDuration::from_secs(2), session_task)
.await
.expect("burst round session must terminate quickly")
.expect("burst round session must not panic");
assert!(
result.is_ok(),
"burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}",
result
);
}
}

View File

@ -1140,3 +1140,46 @@ async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_err
);
}
}
#[tokio::test]
async fn relay_half_close_keeps_reverse_direction_progressing() {
let stats = Arc::new(Stats::new());
let user = "half-close-user";
let (client_peer, relay_client) = duplex(1024);
let (relay_server, server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer);
let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
8192,
8192,
user,
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
));
sp_writer.write_all(&[0x10, 0x20, 0x30, 0x40]).await.unwrap();
sp_writer.shutdown().await.unwrap();
let mut inbound = [0u8; 4];
cp_reader.read_exact(&mut inbound).await.unwrap();
assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]);
cp_writer.write_all(&[0xaa, 0xbb, 0xcc, 0xdd]).await.unwrap();
let mut outbound = [0u8; 4];
sp_reader.read_exact(&mut outbound).await.unwrap();
assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]);
relay_task.abort();
let joined = relay_task.await;
assert!(joined.is_err(), "aborted relay task must return join error");
}