Compare commits

..

No commits in common. "f8e1e2f2ea0228e1e478ffee7dd07717806758fc" and "44376b5652e35e12cae55e92dd6d8adad5f70b5d" have entirely different histories.

18 changed files with 81 additions and 2219 deletions

View File

@ -544,11 +544,6 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
return None; return None;
} }
let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize;
if handshake.len() < 5 + record_len {
return None;
}
let mut pos = 5; // after record header let mut pos = 5; // after record header
if handshake.get(pos).copied()? != 0x01 { if handshake.get(pos).copied()? != 0x01 {
return None; // not ClientHello return None; // not ClientHello
@ -654,15 +649,6 @@ fn is_valid_sni_hostname(host: &str) -> bool {
/// Extract ALPN protocol list from ClientHello, return in offered order. /// Extract ALPN protocol list from ClientHello, return in offered order.
pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> { pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE {
return Vec::new();
}
let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize;
if handshake.len() < 5 + record_len {
return Vec::new();
}
let mut pos = 5; // after record header let mut pos = 5; // after record header
if handshake.get(pos) != Some(&0x01) { if handshake.get(pos) != Some(&0x01) {
return Vec::new(); return Vec::new();
@ -816,11 +802,3 @@ mod compile_time_security_checks {
#[cfg(test)] #[cfg(test)]
#[path = "tls_security_tests.rs"] #[path = "tls_security_tests.rs"]
mod security_tests; mod security_tests;
#[cfg(test)]
#[path = "tls_adversarial_tests.rs"]
mod adversarial_tests;
#[cfg(test)]
#[path = "tls_fuzz_security_tests.rs"]
mod fuzz_security_tests;

View File

@ -1,352 +0,0 @@
use super::*;
use std::time::Instant;
use crate::crypto::sha256_hmac;
/// Helper to create a byte vector of specific length.
fn make_garbage(len: usize) -> Vec<u8> {
vec![0x42u8; len]
}
/// Helper to create a valid-looking HMAC digest for test.
fn make_digest(secret: &[u8], msg: &[u8], ts: u32) -> [u8; 32] {
let mut hmac = sha256_hmac(secret, msg);
let ts_bytes = ts.to_le_bytes();
for i in 0..4 {
hmac[28 + i] ^= ts_bytes[i];
}
hmac
}
fn make_valid_tls_handshake_with_session_id(
secret: &[u8],
timestamp: u32,
session_id: &[u8],
) -> Vec<u8> {
let session_id_len = session_id.len();
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let digest = make_digest(secret, &handshake, timestamp);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32])
}
// ------------------------------------------------------------------
// Truncated Packet Tests (OWASP ASVS 5.1.4, 5.1.5)
// ------------------------------------------------------------------
#[test]
fn validate_tls_handshake_truncated_10_bytes_rejected() {
let secrets = vec![("user".to_string(), b"secret".to_vec())];
let truncated = make_garbage(10);
assert!(validate_tls_handshake(&truncated, &secrets, true).is_none());
}
#[test]
fn validate_tls_handshake_truncated_at_digest_start_rejected() {
let secrets = vec![("user".to_string(), b"secret".to_vec())];
// TLS_DIGEST_POS = 11. 11 bytes should be rejected.
let truncated = make_garbage(TLS_DIGEST_POS);
assert!(validate_tls_handshake(&truncated, &secrets, true).is_none());
}
#[test]
fn validate_tls_handshake_truncated_inside_digest_rejected() {
let secrets = vec![("user".to_string(), b"secret".to_vec())];
// TLS_DIGEST_POS + 16 (half digest)
let truncated = make_garbage(TLS_DIGEST_POS + 16);
assert!(validate_tls_handshake(&truncated, &secrets, true).is_none());
}
#[test]
fn extract_sni_truncated_at_record_header_rejected() {
let truncated = make_garbage(3);
assert!(extract_sni_from_client_hello(&truncated).is_none());
}
#[test]
fn extract_sni_truncated_at_handshake_header_rejected() {
let mut truncated = vec![TLS_RECORD_HANDSHAKE, 0x03, 0x03, 0x00, 0x05];
truncated.extend_from_slice(&[0x01, 0x00]); // ClientHello type but truncated length
assert!(extract_sni_from_client_hello(&truncated).is_none());
}
// ------------------------------------------------------------------
// Malformed Extension Parsing Tests
// ------------------------------------------------------------------
#[test]
fn extract_sni_with_overlapping_extension_lengths_rejected() {
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header
h.push(0x01); // Handshake type: ClientHello
h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92
h.extend_from_slice(&[0x03, 0x03]); // Version
h.extend_from_slice(&[0u8; 32]); // Random
h.push(0); // Session ID length: 0
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites
h.extend_from_slice(&[0x01, 0x00]); // Compression
// Extensions start
h.extend_from_slice(&[0x00, 0x20]); // Total Extensions length: 32
// Extension 1: SNI (type 0)
h.extend_from_slice(&[0x00, 0x00]);
h.extend_from_slice(&[0x00, 0x40]); // Claimed len: 64 (OVERFLOWS total extensions len 32)
h.extend_from_slice(&[0u8; 64]);
assert!(extract_sni_from_client_hello(&h).is_none());
}
#[test]
fn extract_sni_with_infinite_loop_potential_extension_rejected() {
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header
h.push(0x01); // Handshake type: ClientHello
h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92
h.extend_from_slice(&[0x03, 0x03]); // Version
h.extend_from_slice(&[0u8; 32]); // Random
h.push(0); // Session ID length: 0
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites
h.extend_from_slice(&[0x01, 0x00]); // Compression
// Extensions start
h.extend_from_slice(&[0x00, 0x10]); // Total Extensions length: 16
// Extension: zero length but claims more?
// If our parser didn't advance, it might loop.
// Telemt uses `pos += 4 + elen;` so it always advances.
h.extend_from_slice(&[0x12, 0x34]); // Unknown type
h.extend_from_slice(&[0x00, 0x00]); // Length 0
// Fill the rest with garbage
h.extend_from_slice(&[0x42; 12]);
// We expect it to finish without SNI found
assert!(extract_sni_from_client_hello(&h).is_none());
}
#[test]
fn extract_sni_with_invalid_hostname_rejected() {
let host = b"invalid_host!%^";
let mut sni = Vec::new();
sni.extend_from_slice(&((host.len() + 3) as u16).to_be_bytes());
sni.push(0);
sni.extend_from_slice(&(host.len() as u16).to_be_bytes());
sni.extend_from_slice(host);
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header
h.push(0x01); // ClientHello
h.extend_from_slice(&[0x00, 0x00, 0x5C]);
h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&[0u8; 32]);
h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
h.extend_from_slice(&[0x01, 0x00]);
let mut ext = Vec::new();
ext.extend_from_slice(&0x0000u16.to_be_bytes());
ext.extend_from_slice(&(sni.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni);
h.extend_from_slice(&(ext.len() as u16).to_be_bytes());
h.extend_from_slice(&ext);
assert!(extract_sni_from_client_hello(&h).is_none(), "Invalid SNI hostname must be rejected");
}
// ------------------------------------------------------------------
// Timing Neutrality Tests (OWASP ASVS 5.1.7)
// ------------------------------------------------------------------
#[test]
fn validate_tls_handshake_timing_neutrality() {
let secret = b"timing_test_secret_32_bytes_long_";
let secrets = vec![("u".to_string(), secret.to_vec())];
let mut base = vec![0x42u8; 100];
base[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;
const ITER: usize = 600;
const ROUNDS: usize = 7;
let mut per_round_avg_diff_ns = Vec::with_capacity(ROUNDS);
for round in 0..ROUNDS {
let mut success_h = base.clone();
let mut fail_h = base.clone();
let start_success = Instant::now();
for _ in 0..ITER {
let digest = make_digest(secret, &success_h, 0);
success_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
let _ = validate_tls_handshake_at_time(&success_h, &secrets, true, 0);
}
let success_elapsed = start_success.elapsed();
let start_fail = Instant::now();
for i in 0..ITER {
let mut digest = make_digest(secret, &fail_h, 0);
let flip_idx = (i + round) % (TLS_DIGEST_LEN - 4);
digest[flip_idx] ^= 0xFF;
fail_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
let _ = validate_tls_handshake_at_time(&fail_h, &secrets, true, 0);
}
let fail_elapsed = start_fail.elapsed();
let diff = if success_elapsed > fail_elapsed {
success_elapsed - fail_elapsed
} else {
fail_elapsed - success_elapsed
};
per_round_avg_diff_ns.push(diff.as_nanos() as f64 / ITER as f64);
}
per_round_avg_diff_ns.sort_by(|a, b| a.partial_cmp(b).unwrap());
let median_avg_diff_ns = per_round_avg_diff_ns[ROUNDS / 2];
// Keep this as a coarse side-channel guard only; noisy shared CI hosts can
// introduce microsecond-level jitter that should not fail deterministic suites.
assert!(
median_avg_diff_ns < 50_000.0,
"Median timing delta too large: {} ns/iter",
median_avg_diff_ns
);
}
// ------------------------------------------------------------------
// Adversarial Fingerprinting / Active Probing Tests
// ------------------------------------------------------------------
#[test]
fn is_tls_handshake_robustness_against_probing() {
// Valid TLS 1.0 ClientHello
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
// Valid TLS 1.2/1.3 ClientHello (Legacy Record Layer)
assert!(is_tls_handshake(&[0x16, 0x03, 0x03]));
// Invalid record type but matching version
assert!(!is_tls_handshake(&[0x17, 0x03, 0x03]));
// Plaintext HTTP request
assert!(!is_tls_handshake(b"GET / HTTP/1.1"));
// Short garbage
assert!(!is_tls_handshake(&[0x16, 0x03]));
}
#[test]
fn validate_tls_handshake_at_time_strict_boundary() {
let secret = b"strict_boundary_secret_32_bytes_";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_000_000_000;
// Boundary: exactly TIME_SKEW_MAX (120s past)
let ts_past = (now - TIME_SKEW_MAX) as u32;
let h = make_valid_tls_handshake_with_session_id(secret, ts_past, &[0x42; 32]);
assert!(validate_tls_handshake_at_time(&h, &secrets, false, now).is_some());
// Boundary + 1s: should be rejected
let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32;
let h2 = make_valid_tls_handshake_with_session_id(secret, ts_too_past, &[0x42; 32]);
assert!(validate_tls_handshake_at_time(&h2, &secrets, false, now).is_none());
}
#[test]
fn extract_sni_with_duplicate_extensions_rejected() {
// Construct a ClientHello with TWO SNI extensions
let host1 = b"first.com";
let mut sni1 = Vec::new();
sni1.extend_from_slice(&((host1.len() + 3) as u16).to_be_bytes());
sni1.push(0);
sni1.extend_from_slice(&(host1.len() as u16).to_be_bytes());
sni1.extend_from_slice(host1);
let host2 = b"second.com";
let mut sni2 = Vec::new();
sni2.extend_from_slice(&((host2.len() + 3) as u16).to_be_bytes());
sni2.push(0);
sni2.extend_from_slice(&(host2.len() as u16).to_be_bytes());
sni2.extend_from_slice(host2);
let mut ext = Vec::new();
// Ext 1: SNI
ext.extend_from_slice(&0x0000u16.to_be_bytes());
ext.extend_from_slice(&(sni1.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni1);
// Ext 2: SNI again
ext.extend_from_slice(&0x0000u16.to_be_bytes());
ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni2);
let mut body = Vec::new();
body.extend_from_slice(&[0x03, 0x03]);
body.extend_from_slice(&[0u8; 32]);
body.push(0);
body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
body.extend_from_slice(&[0x01, 0x00]);
body.extend_from_slice(&(ext.len() as u16).to_be_bytes());
body.extend_from_slice(&ext);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut h = Vec::new();
h.push(0x16);
h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
h.extend_from_slice(&handshake);
// Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups.
// Telemt's `extract_sni` returns the first one found.
assert!(extract_sni_from_client_hello(&h).is_some());
}
#[test]
fn extract_alpn_with_malformed_list_rejected() {
let mut alpn_payload = Vec::new();
alpn_payload.extend_from_slice(&0x0005u16.to_be_bytes()); // Total len 5
alpn_payload.push(10); // Labeled len 10 (OVERFLOWS total 5)
alpn_payload.extend_from_slice(b"h2");
let mut ext = Vec::new();
ext.extend_from_slice(&0x0010u16.to_be_bytes()); // Type: ALPN (16)
ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes());
ext.extend_from_slice(&alpn_payload);
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03];
h.extend_from_slice(&[0u8; 32]);
h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]);
h.extend_from_slice(&(ext.len() as u16).to_be_bytes());
h.extend_from_slice(&ext);
let res = extract_alpn_from_client_hello(&h);
assert!(res.is_empty(), "Malformed ALPN list must return empty or fail");
}
#[test]
fn extract_sni_with_huge_extension_header_rejected() {
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x00]; // Record header
h.push(0x01); // ClientHello
h.extend_from_slice(&[0x00, 0xFF, 0xFF]); // Huge length (65535) - overflows record
h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&[0u8; 32]);
h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]);
// Extensions start
h.extend_from_slice(&[0xFF, 0xFF]); // Total extensions: 65535 (OVERFLOWS everything)
assert!(extract_sni_from_client_hello(&h).is_none());
}

View File

@ -1,195 +0,0 @@
use super::*;
use crate::crypto::sha256_hmac;
use std::panic::catch_unwind;
fn make_valid_tls_handshake_with_session_id(
secret: &[u8],
timestamp: u32,
session_id: &[u8],
) -> Vec<u8> {
let session_id_len = session_id.len();
assert!(session_id_len <= u8::MAX as usize);
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let mut digest = sha256_hmac(secret, &handshake);
let ts = timestamp.to_le_bytes();
for idx in 0..4 {
digest[28 + idx] ^= ts[idx];
}
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
handshake
}
fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(0);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
let host_bytes = host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record
}
#[test]
fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() {
let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]);
assert_eq!(extract_sni_from_client_hello(&valid).as_deref(), Some("example.com"));
assert_eq!(
extract_alpn_from_client_hello(&valid),
vec![b"h2".to_vec(), b"http/1.1".to_vec()]
);
assert!(
extract_sni_from_client_hello(&make_valid_client_hello_record("127.0.0.1", &[])).is_none(),
"literal IP hostnames must be rejected"
);
let mut corpus = vec![
Vec::new(),
vec![0x16, 0x03, 0x03],
valid[..9].to_vec(),
valid[..valid.len() - 1].to_vec(),
];
let mut wrong_type = valid.clone();
wrong_type[0] = 0x15;
corpus.push(wrong_type);
let mut wrong_handshake = valid.clone();
wrong_handshake[5] = 0x02;
corpus.push(wrong_handshake);
let mut wrong_length = valid.clone();
wrong_length[3] ^= 0x7f;
corpus.push(wrong_length);
for (idx, input) in corpus.iter().enumerate() {
assert!(catch_unwind(|| extract_sni_from_client_hello(input)).is_ok());
assert!(catch_unwind(|| extract_alpn_from_client_hello(input)).is_ok());
if idx == 0 {
continue;
}
assert!(extract_sni_from_client_hello(input).is_none(), "corpus item {idx} must fail closed for SNI");
assert!(extract_alpn_from_client_hello(input).is_empty(), "corpus item {idx} must fail closed for ALPN");
}
}
#[test]
fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() {
let secret = b"tls_fuzz_security_secret";
let now: i64 = 1_700_000_000;
let base = make_valid_tls_handshake_with_session_id(secret, now as u32, &[0x42; 32]);
let secrets = vec![("fuzz-user".to_string(), secret.to_vec())];
assert!(validate_tls_handshake_at_time(&base, &secrets, false, now).is_some());
let mut corpus = Vec::new();
let mut truncated = base.clone();
truncated.truncate(TLS_DIGEST_POS + 16);
corpus.push(truncated);
let mut digest_flip = base.clone();
digest_flip[TLS_DIGEST_POS + 7] ^= 0x80;
corpus.push(digest_flip);
let mut session_id_len_overflow = base.clone();
session_id_len_overflow[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 33;
corpus.push(session_id_len_overflow);
let mut timestamp_far_past = base.clone();
timestamp_far_past[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32]
.copy_from_slice(&((now - i64::from(TIME_SKEW_MAX) - 1) as u32).to_le_bytes());
corpus.push(timestamp_far_past);
let mut timestamp_far_future = base.clone();
timestamp_far_future[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32]
.copy_from_slice(&((now - TIME_SKEW_MIN + 1) as u32).to_le_bytes());
corpus.push(timestamp_far_future);
let mut seed = 0xA5A5_5A5A_F00D_BAAD_u64;
for _ in 0..32 {
let mut mutated = base.clone();
for _ in 0..2 {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN);
mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, handshake) in corpus.iter().enumerate() {
let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now));
assert!(result.is_ok(), "corpus item {idx} must not panic");
assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed");
}
}
#[test]
fn tls_boot_time_acceptance_is_capped_by_replay_window() {
let secret = b"tls_boot_time_cap_secret";
let secrets = vec![("boot-user".to_string(), secret.to_vec())];
let boot_ts = 1u32;
let handshake = make_valid_tls_handshake_with_session_id(secret, boot_ts, &[0x42; 32]);
assert!(
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 300).is_some(),
"boot-time timestamp should be accepted while replay window permits it"
);
assert!(
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 0).is_none(),
"boot-time timestamp must be rejected when replay window disables the bypass"
);
}

View File

@ -1040,6 +1040,3 @@ impl RunningClientHandler {
#[cfg(test)] #[cfg(test)]
#[path = "client_security_tests.rs"] #[path = "client_security_tests.rs"]
mod security_tests; mod security_tests;
#[cfg(test)]
#[path = "client_adversarial_tests.rs"]
mod adversarial_tests;

View File

@ -1,109 +0,0 @@
use super::*;
use crate::config::ProxyConfig;
use crate::stats::Stats;
use crate::ip_tracker::UserIpTracker;
use crate::error::ProxyError;
use std::sync::Arc;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
// ------------------------------------------------------------------
// Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6)
// ------------------------------------------------------------------
#[tokio::test]
async fn client_stress_10k_connections_limit_strict() {
let user = "stress-user";
let limit = 512;
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), limit);
let iterations = 1000;
let mut tasks = Vec::new();
for i in 0..iterations {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
let user_str = user.to_string();
tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, (i % 254 + 1) as u8)),
10000 + (i % 1000) as u16,
);
match RunningClientHandler::acquire_user_connection_reservation_static(
&user_str,
&config,
stats,
peer,
ip_tracker,
).await {
Ok(res) => Ok(res),
Err(ProxyError::ConnectionLimitExceeded { .. }) => Err(()),
Err(e) => panic!("Unexpected error: {:?}", e),
}
}));
}
let results = futures::future::join_all(tasks).await;
let mut successes = 0;
let mut failures = 0;
let mut reservations = Vec::new();
for res in results {
match res.unwrap() {
Ok(r) => {
successes += 1;
reservations.push(r);
}
Err(_) => failures += 1,
}
}
assert_eq!(successes, limit, "Should allow exactly 'limit' connections");
assert_eq!(failures, iterations - limit, "Should fail the rest with LimitExceeded");
assert_eq!(stats.get_user_curr_connects(user), limit as u64);
drop(reservations);
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0, "Stats must converge to 0 after all drops");
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP tracker must converge to 0");
}
// ------------------------------------------------------------------
// Priority 3: IP Tracker Race Stress
// ------------------------------------------------------------------
#[tokio::test]
async fn client_ip_tracker_race_condition_stress() {
let user = "race-user";
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 100).await;
let iterations = 1000;
let mut tasks = Vec::new();
for i in 0..iterations {
let ip_tracker = Arc::clone(&ip_tracker);
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 254 + 1) as u8));
tasks.push(tokio::spawn(async move {
for _ in 0..10 {
if let Ok(()) = ip_tracker.check_and_add("race-user", ip).await {
ip_tracker.remove_ip("race-user", ip).await;
}
}
}));
}
futures::future::join_all(tasks).await;
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP count must be zero after balanced add/remove burst");
}

View File

@ -7,7 +7,6 @@ use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter}; use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use std::net::Ipv4Addr;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
@ -631,54 +630,6 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load(
} }
} }
#[tokio::test]
async fn reservation_limit_failure_does_not_leak_curr_connects_counter() {
let user = "leak-check-user";
let mut config = crate::config::ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
let stats = Arc::new(crate::stats::Stats::new());
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let first_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 1)), 50001);
let first = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
first_peer,
ip_tracker.clone(),
)
.await
.expect("first reservation must succeed");
assert_eq!(stats.get_user_curr_connects(user), 1);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
let second_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 2)), 50002);
let second = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
second_peer,
ip_tracker.clone(),
)
.await;
assert!(
matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user),
"second reservation must be rejected at the configured tcp-conns limit"
);
assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment");
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state");
first.release().await;
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test] #[tokio::test]
async fn short_tls_probe_is_masked_through_client_pipeline() { async fn short_tls_probe_is_masked_through_client_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

View File

@ -12,10 +12,8 @@ use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration; use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{timeout, Duration as TokioDuration};
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R> fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where where
@ -1324,194 +1322,3 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() {
assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred");
} }
} }
#[tokio::test]
async fn negative_direct_relay_dc_connection_refused_fails_fast() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_client_reader_relay, client_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
// Reserve an ephemeral port and immediately release it to deterministically
// exercise the direct-connect failure path without long-lived hangs.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let dc_addr = listener.local_addr().unwrap();
drop(listener);
let mut config_with_override = ProxyConfig::default();
config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
enabled: true,
weight: 1,
scopes: String::new(),
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
selected_scope: String::new(),
}],
1,
100,
5000,
3,
false,
stats.clone(),
));
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let result = timeout(
TokioDuration::from_secs(2),
handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
route_runtime.subscribe(),
route_runtime.snapshot(),
0xABCD_1234,
),
)
.await
.expect("direct relay must fail fast on connection-refused upstream");
assert!(
result.is_err(),
"connection-refused upstream must fail closed"
);
}
#[tokio::test]
async fn adversarial_direct_relay_cutover_integrity() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_client_reader_relay, client_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
// Mock upstream server.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let dc_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
// Read handshake nonce.
let mut nonce = [0u8; 64];
let _ = stream.read_exact(&mut nonce).await;
// Keep connection open.
tokio::time::sleep(TokioDuration::from_secs(5)).await;
});
let mut config_with_override = ProxyConfig::default();
config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
enabled: true,
weight: 1,
scopes: String::new(),
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
selected_scope: String::new(),
}],
1,
100,
5000,
3,
false,
stats.clone(),
));
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let stats_for_task = stats.clone();
let runtime_clone = route_runtime.clone();
let session_task = tokio::spawn(async move {
handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats_for_task,
config,
buffer_pool,
rng,
runtime_clone.subscribe(),
runtime_clone.snapshot(),
0xABCD_1234,
).await
});
timeout(TokioDuration::from_secs(2), async {
loop {
if stats.get_current_connections_direct() == 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("direct relay session must start before cutover");
// Trigger cutover.
route_runtime.set_mode(RelayRouteMode::Middle).unwrap();
// The session should terminate after the staggered delay (1000-2000ms).
let result = timeout(TokioDuration::from_secs(5), session_task)
.await
.expect("Session must terminate after cutover")
.expect("Session must not panic");
assert!(
matches!(
result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"Session must terminate with route switch error on cutover"
);
}

View File

@ -968,12 +968,8 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
mod security_tests; mod security_tests;
#[cfg(test)] #[cfg(test)]
#[path = "handshake_adversarial_tests.rs"] #[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"]
mod adversarial_tests; mod gap_short_tls_probe_throttle_security_tests;
#[cfg(test)]
#[path = "handshake_fuzz_security_tests.rs"]
mod fuzz_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication, /// must never be Copy. A Copy impl would allow silent key duplication,

View File

@ -1,231 +0,0 @@
use super::*;
use std::sync::Arc;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
use crate::crypto::sha256;
fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access.users.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg
}
// ------------------------------------------------------------------
// Mutational Bit-Flipping Tests (OWASP ASVS 5.1.4)
// ------------------------------------------------------------------
#[tokio::test]
async fn mtproto_handshake_bit_flip_anywhere_rejected() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "11223344556677889900aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
// Baseline check
let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
match res {
HandshakeResult::Success(_) => {},
_ => panic!("Baseline failed: expected Success"),
}
// Flip bits in the encrypted part (beyond the key material)
for byte_pos in SKIP_LEN..HANDSHAKE_LEN {
let mut h = base;
h[byte_pos] ^= 0x01; // Flip 1 bit
let res = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "Flip at byte {byte_pos} bit 0 must be rejected");
}
}
// ------------------------------------------------------------------
// Adversarial Probing / Timing Neutrality (OWASP ASVS 5.1.7)
// ------------------------------------------------------------------
#[tokio::test]
async fn mtproto_handshake_timing_neutrality_mocked() {
let secret_hex = "00112233445566778899aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap();
const ITER: usize = 50;
let mut start = Instant::now();
for _ in 0..ITER {
let _ = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
}
let duration_success = start.elapsed();
start = Instant::now();
for i in 0..ITER {
let mut h = base;
h[SKIP_LEN + (i % 48)] ^= 0xFF;
let _ = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
}
let duration_fail = start.elapsed();
let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64).abs() / ITER as f64;
// Threshold (loose for CI)
assert!(avg_diff_ms < 100.0, "Timing difference too large: {} ms/iter", avg_diff_ms);
}
// ------------------------------------------------------------------
// Stress Tests (OWASP ASVS 5.1.6)
// ------------------------------------------------------------------
#[tokio::test]
async fn auth_probe_throttle_saturation_stress() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let now = Instant::now();
// Record enough failures for one IP to trigger backoff
let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
auth_probe_record_failure(target_ip, now);
}
assert!(auth_probe_is_throttled(target_ip, now));
// Stress test with many unique IPs
for i in 0..500u32 {
let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 256) as u8));
auth_probe_record_failure(ip, now);
}
let tracked = AUTH_PROBE_STATE
.get()
.map(|state| state.len())
.unwrap_or(0);
assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}"
);
}
#[tokio::test]
async fn mtproto_handshake_abridged_prefix_rejected() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
handshake[0] = 0xef; // Abridged prefix
let config = ProxyConfig::default();
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap();
let res = handle_mtproto_handshake(&handshake, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
// MTProxy stops immediately on 0xef
assert!(matches!(res, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn mtproto_handshake_preferred_user_mismatch_continues() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret1_hex = "11111111111111111111111111111111";
let secret2_hex = "22222222222222222222222222222222";
let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1);
let mut config = ProxyConfig::default();
config.access.users.insert("user1".to_string(), secret1_hex.to_string());
config.access.users.insert("user2".to_string(), secret2_hex.to_string());
config.access.ignore_time_skew = true;
config.general.modes.secure = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap();
// Even if we prefer user1, if user2 matches, it should succeed.
let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, Some("user1")).await;
if let HandshakeResult::Success((_, _, success)) = res {
assert_eq!(success.user, "user2");
} else {
panic!("Handshake failed even though user2 matched");
}
}
#[tokio::test]
async fn mtproto_handshake_concurrent_flood_stability() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "00112233445566778899aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let mut config = test_config_with_secret_hex(secret_hex);
config.access.ignore_time_skew = true;
let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60)));
let config = Arc::new(config);
let mut tasks = Vec::new();
for i in 0..50 {
let base = base;
let config = Arc::clone(&config);
let replay_checker = Arc::clone(&replay_checker);
let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap();
tasks.push(tokio::spawn(async move {
let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
matches!(res, HandshakeResult::Success(_))
}));
}
// We don't necessarily care if they all succeed (some might fail due to replay if they hit the same chunk),
// but the system must not panic or hang.
for task in tasks {
let _ = task.await.unwrap();
}
}

View File

@ -1,270 +0,0 @@
use super::*;
use crate::config::ProxyConfig;
use crate::crypto::AesCtr;
use crate::crypto::sha256;
use crate::protocol::constants::ProtoTag;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::sync::MutexGuard;
use tokio::time::{timeout, Duration as TokioDuration};
fn make_mtproto_handshake_with_proto_bytes(
secret_hex: &str,
proto_bytes: [u8; 4],
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_bytes);
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] {
make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx)
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access.users.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg
}
fn auth_probe_test_guard() -> MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[tokio::test]
async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "11223344556677889900aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
let first = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(first, HandshakeResult::Success(_)));
let second = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(second, HandshakeResult::BadClient { .. }));
clear_auth_probe_state_for_testing();
}
#[tokio::test]
async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "00112233445566778899aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap();
let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new();
corpus.push(make_mtproto_handshake_with_proto_bytes(
secret_hex,
[0x00, 0x00, 0x00, 0x00],
1,
));
corpus.push(make_mtproto_handshake_with_proto_bytes(
secret_hex,
[0xff, 0xff, 0xff, 0xff],
1,
));
corpus.push(make_valid_mtproto_handshake(
"ffeeddccbbaa99887766554433221100",
ProtoTag::Secure,
1,
));
let mut seed = 0xF0F0_F00D_BAAD_CAFEu64;
for _ in 0..32 {
let mut mutated = base;
for _ in 0..4 {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, input) in corpus.into_iter().enumerate() {
let result = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&input,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed handshake must complete in time");
assert!(
matches!(result, HandshakeResult::BadClient { .. }),
"corpus item {idx} must fail closed"
);
}
clear_auth_probe_state_for_testing();
}
#[tokio::test]
async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "99887766554433221100ffeeddccbbaa";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(256, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.44:45444".parse().unwrap();
let first = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("base handshake must not hang");
assert!(matches!(first, HandshakeResult::Success(_)));
let replay = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("duplicate handshake must not hang");
assert!(matches!(replay, HandshakeResult::BadClient { .. }));
let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new();
let mut prekey_flip = base;
prekey_flip[SKIP_LEN] ^= 0x80;
corpus.push(prekey_flip);
let mut iv_flip = base;
iv_flip[SKIP_LEN + PREKEY_LEN] ^= 0x01;
corpus.push(iv_flip);
let mut tail_flip = base;
tail_flip[SKIP_LEN + PREKEY_LEN + IV_LEN - 1] ^= 0x40;
corpus.push(tail_flip);
let mut seed = 0xBADC_0FFE_EE11_4242u64;
for _ in 0..24 {
let mut mutated = base;
for _ in 0..3 {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 16) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, input) in corpus.iter().enumerate() {
let result = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
input,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed handshake must complete in time");
assert!(
matches!(result, HandshakeResult::BadClient { .. }),
"mixed corpus item {idx} must fail closed"
);
}
clear_auth_probe_state_for_testing();
}

View File

@ -0,0 +1,50 @@
use super::*;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::time::Duration;
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg
}
#[tokio::test]
async fn gap_t01_short_tls_probe_burst_is_throttled() {
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
let too_short = vec![0x16, 0x03, 0x01];
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&too_short,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
assert!(
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"short TLS probe bursts must increase auth-probe fail streak"
);
}

View File

@ -1997,42 +1997,6 @@ fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer
); );
} }
#[tokio::test]
async fn gap_t01_short_tls_probe_burst_is_throttled() {
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
let too_short = vec![0x16, 0x03, 0x01];
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&too_short,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
assert!(
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"short TLS probe bursts must increase auth-probe fail streak"
);
}
#[test] #[test]
fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() { fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() {
let _guard = auth_probe_test_lock() let _guard = auth_probe_test_lock()

View File

@ -317,7 +317,3 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
#[cfg(test)] #[cfg(test)]
#[path = "masking_security_tests.rs"] #[path = "masking_security_tests.rs"]
mod security_tests; mod security_tests;
#[cfg(test)]
#[path = "masking_adversarial_tests.rs"]
mod adversarial_tests;

View File

@ -1,213 +0,0 @@
use super::*;
use std::sync::Arc;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{Instant, Duration};
use crate::config::ProxyConfig;
use crate::stats::beobachten::BeobachtenStore;
// ------------------------------------------------------------------
// Probing Indistinguishability (OWASP ASVS 5.1.7)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_probes_indistinguishable_timing() {
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 80; // Should timeout/refuse
let peer: SocketAddr = "192.0.2.10:443".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
// Test different probe types
let probes = vec![
(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"),
(b"SSH-2.0-probe".to_vec(), "SSH"),
(vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00], "TLS-scanner"),
(vec![0x42; 5], "port-scanner"),
];
for (probe, type_name) in probes {
let (client_reader, _client_writer) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let start = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
).await;
let elapsed = start.elapsed();
// We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests)
// to mask whether the backend was reachable or refused.
assert!(elapsed >= Duration::from_millis(30), "Probe {type_name} finished too fast: {elapsed:?}");
}
}
// ------------------------------------------------------------------
// Masking Budget Stress Tests (OWASP ASVS 5.1.6)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_budget_stress_under_load() {
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1; // Unlikely port
let peer: SocketAddr = "192.0.2.20:443".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = Arc::new(BeobachtenStore::new());
let mut tasks = Vec::new();
for _ in 0..50 {
let (client_reader, _client_writer) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let config = config.clone();
let beobachten = Arc::clone(&beobachten);
tasks.push(tokio::spawn(async move {
let start = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
b"probe",
peer,
local_addr,
&config,
&beobachten,
).await;
start.elapsed()
}));
}
for task in tasks {
let elapsed = task.await.unwrap();
assert!(elapsed >= Duration::from_millis(30), "Stress probe finished too fast: {elapsed:?}");
}
}
// ------------------------------------------------------------------
// detect_client_type Fingerprint Check
// ------------------------------------------------------------------
#[test]
fn test_detect_client_type_boundary_cases() {
// 9 bytes = port-scanner
assert_eq!(detect_client_type(&[0x42; 9]), "port-scanner");
// 10 bytes = unknown
assert_eq!(detect_client_type(&[0x42; 10]), "unknown");
// HTTP verbs without trailing space
assert_eq!(detect_client_type(b"GET/"), "port-scanner"); // because len < 10
assert_eq!(detect_client_type(b"GET /path"), "HTTP");
}
// ------------------------------------------------------------------
// Priority 2: Slowloris and Slow Read Attacks (OWASP ASVS 5.1.5)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_slowloris_client_idle_timeout_rejected() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let initial = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let initial = initial.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; initial.len()];
stream.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, initial);
let mut drip = [0u8; 1];
let drip_read = tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip)).await;
assert!(
drip_read.is_err() || drip_read.unwrap().is_err(),
"backend must not receive post-timeout slowloris drip bytes"
);
}
});
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
let beobachten = BeobachtenStore::new();
let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap();
let local: SocketAddr = "192.0.2.1:443".parse().unwrap();
let (mut client_writer, client_reader) = duplex(1024);
let (_client_visible_reader, client_visible_writer) = duplex(1024);
let handle = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
&initial,
peer,
local,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(160)).await;
let _ = client_writer.write_all(b"X").await;
handle.await.unwrap();
accept_task.await.unwrap();
}
// ------------------------------------------------------------------
// Priority 2: Fallback Server Down / Fingerprinting (OWASP ASVS 5.1.7)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_fallback_down_mimics_timeout() {
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1; // Unlikely port
let (server_reader, server_writer) = duplex(1024);
let beobachten = BeobachtenStore::new();
let peer: SocketAddr = "192.0.2.12:12345".parse().unwrap();
let local: SocketAddr = "192.0.2.1:443".parse().unwrap();
let start = Instant::now();
handle_bad_client(server_reader, server_writer, b"GET / HTTP/1.1\r\n", peer, local, &config, &beobachten).await;
let elapsed = start.elapsed();
// It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately
assert!(elapsed >= Duration::from_millis(40), "Must respect connect budget even on failure: {:?}", elapsed);
}
// ------------------------------------------------------------------
// Priority 2: SSRF Prevention (OWASP ASVS 5.1.2)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_ssrf_resolve_internal_ranges_blocked() {
use crate::network::dns_overrides::resolve_socket_addr;
let blocked_ips = ["127.0.0.1", "169.254.169.254", "10.0.0.1", "192.168.1.1", "0.0.0.0"];
for ip in blocked_ips {
assert!(
resolve_socket_addr(ip, 80).is_none(),
"runtime DNS overrides must not resolve unconfigured literal host targets"
);
}
}

View File

@ -1,11 +1,10 @@
use super::*; use super::*;
use crate::proxy::handshake::HandshakeSuccess;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use bytes::Bytes; use bytes::Bytes;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::network::probe::NetworkDecision; use crate::network::probe::NetworkDecision;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
@ -21,7 +20,6 @@ use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, timeout}; use tokio::time::{Duration as TokioDuration, timeout};
use std::sync::{Mutex, OnceLock};
fn make_pooled_payload(data: &[u8]) -> PooledBuffer { fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
@ -38,11 +36,6 @@ fn make_pooled_payload_from(pool: &Arc<BufferPool>, data: &[u8]) -> PooledBuffer
payload payload
} }
fn quota_user_lock_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[test] #[test]
fn should_yield_sender_only_on_budget_with_backlog() { fn should_yield_sender_only_on_budget_with_backlog() {
assert!(!should_yield_c2me_sender(0, true)); assert!(!should_yield_c2me_sender(0, true));
@ -251,10 +244,6 @@ fn quota_user_lock_cache_reuses_entry_for_same_user() {
#[test] #[test]
fn quota_user_lock_cache_is_bounded_under_unique_churn() { fn quota_user_lock_cache_is_bounded_under_unique_churn() {
let _guard = quota_user_lock_test_lock()
.lock()
.expect("quota user lock test lock must be available");
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear(); map.clear();
@ -272,29 +261,23 @@ fn quota_user_lock_cache_is_bounded_under_unique_churn() {
#[test] #[test]
fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
let _guard = quota_user_lock_test_lock()
.lock()
.expect("quota user lock test lock must be available");
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
for attempt in 0..8u32 {
map.clear(); map.clear();
let prefix = format!("quota-held-user-{}-{attempt}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX { for idx in 0..QUOTA_USER_LOCKS_MAX {
let user = format!("{prefix}-{idx}"); let user = format!("quota-held-user-{idx}");
retained.push(quota_user_lock(&user)); retained.push(quota_user_lock(&user));
} }
if map.len() != QUOTA_USER_LOCKS_MAX { assert_eq!(
drop(retained); map.len(),
continue; QUOTA_USER_LOCKS_MAX,
} "precondition: cache should be full before overflow acquisition"
);
let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id()); let overflow_a = quota_user_lock("quota-overflow-user");
let overflow_a = quota_user_lock(&overflow_user); let overflow_b = quota_user_lock("quota-overflow-user");
let overflow_b = quota_user_lock(&overflow_user);
assert_eq!( assert_eq!(
map.len(), map.len(),
@ -302,7 +285,7 @@ fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
"overflow acquisition must not grow cache past hard limit" "overflow acquisition must not grow cache past hard limit"
); );
assert!( assert!(
map.get(&overflow_user).is_none(), map.get("quota-overflow-user").is_none(),
"overflow path should not cache new user lock when map is saturated and all entries are retained" "overflow path should not cache new user lock when map is saturated and all entries are retained"
); );
assert!( assert!(
@ -311,12 +294,6 @@ fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
); );
drop(retained); drop(retained);
return;
}
panic!(
"unable to observe stable saturated lock-cache precondition after bounded retries"
);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 4)] #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@ -2192,320 +2169,3 @@ async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea
drop(client_sides); drop(client_sides);
} }
#[tokio::test]
async fn secure_padding_distribution_in_relay_writer() {
timeout(TokioDuration::from_secs(10), async {
let (mut client_side, relay_side) = duplex(512 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = Arc::new(SecureRandom::new());
let mut frame_buf = Vec::new();
let mut decryptor = AesCtr::new(&key, iv);
let mut padding_counts = [0usize; 4];
let iterations = 180usize;
let payload = vec![0xAAu8; 100]; // 4-byte aligned
for _ in 0..iterations {
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload write must succeed");
writer
.flush()
.await
.expect("writer flush must complete so encrypted frame becomes readable");
let mut len_buf = [0u8; 4];
client_side
.read_exact(&mut len_buf)
.await
.expect("must read encrypted secure length");
let decrypted_len_bytes = decryptor.decrypt(&len_buf);
let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes
.try_into()
.expect("decrypted length must be 4 bytes");
let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize;
assert!(
wire_len >= payload.len(),
"wire length must include at least payload bytes"
);
let padding_len = wire_len - payload.len();
assert!(padding_len >= 1 && padding_len <= 3);
padding_counts[padding_len] += 1;
// Drain and decrypt frame bytes so CTR state stays aligned across writes.
let mut trash = vec![0u8; wire_len];
client_side
.read_exact(&mut trash)
.await
.expect("must read encrypted secure frame body");
let _ = decryptor.decrypt(&trash);
}
for p in 1..=3 {
let count = padding_counts[p];
assert!(
count > iterations / 8,
"padding length {p} is under-represented ({count}/{iterations})"
);
}
})
.await
.expect("secure padding distribution test exceeded runtime budget");
}
#[tokio::test]
async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() {
let (client_reader_side, client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle);
// Create an ME pool.
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
// ConnRegistry ids are monotonic; reserve one id so we can predict the
// next session conn_id and close it deterministically without relying on
// writer-bound views such as active_conn_ids().
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config.clone(),
buffer_pool.clone(),
"127.0.0.1:443".parse().unwrap(),
rng.clone(),
route_runtime.subscribe(),
route_runtime.snapshot(),
0x1234_5678,
));
// Wait until session startup is visible, then unregister the predicted
// conn_id to close the per-session ME response channel.
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("ME session must start before channel close simulation");
me_pool.registry().unregister(target_conn_id).await;
drop(client_writer_side);
let result = timeout(TokioDuration::from_secs(2), session_task)
.await
.expect("Session task must terminate after ME drop and client EOF")
.expect("Session task must not panic");
assert!(
result.is_ok(),
"Session should complete cleanly after ME drop when client closes, got: {:?}",
result
);
}
#[tokio::test]
async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() {
let (client_reader_side, _client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new());
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
// Predict the next conn_id so we can force-drop its ME channel deterministically.
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: "test-user-cutover".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let runtime_clone = route_runtime.clone();
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config,
buffer_pool,
"127.0.0.1:443".parse().unwrap(),
rng,
runtime_clone.subscribe(),
runtime_clone.snapshot(),
0xC001_CAFE,
));
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("ME session must start before race trigger");
// Race ME channel drop with route cutover and assert generic client-visible outcome.
me_pool.registry().unregister(target_conn_id).await;
assert!(
route_runtime.set_mode(RelayRouteMode::Direct).is_some(),
"cutover must advance generation"
);
let relay_result = timeout(TokioDuration::from_secs(6), session_task)
.await
.expect("session must terminate under ME-drop + cutover race")
.expect("session task must not panic");
assert!(
matches!(
relay_result,
Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG
),
"race outcome must remain generic and not leak ME internals, got: {:?}",
relay_result
);
}
#[tokio::test]
async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() {
let stats = Arc::new(Stats::new());
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
for round in 0..32u64 {
let (client_reader_side, client_writer_side) = duplex(1024);
let (_relay_reader_side, relay_writer_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024);
let config = Arc::new(ProxyConfig::default());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle);
let (probe_conn_id, probe_rx) = me_pool.registry().register().await;
drop(probe_rx);
me_pool.registry().unregister(probe_conn_id).await;
let target_conn_id = probe_conn_id.wrapping_add(1);
let success = HandshakeSuccess {
user: format!("stress-me-drop-eof-{round}"),
peer: "127.0.0.1:12345".parse().unwrap(),
dc_idx: 1,
proto_tag: ProtoTag::Intermediate,
enc_key: key,
enc_iv: iv,
dec_key: key,
dec_iv: iv,
is_tls: false,
};
let session_task = tokio::spawn(handle_via_middle_proxy(
crypto_reader,
crypto_writer,
success,
me_pool.clone(),
stats.clone(),
config,
buffer_pool,
"127.0.0.1:443".parse().unwrap(),
rng,
route_runtime.subscribe(),
route_runtime.snapshot(),
0xD00D_0000 + round,
));
timeout(TokioDuration::from_millis(500), async {
loop {
if stats.get_current_connections_me() >= 1 {
break;
}
tokio::time::sleep(TokioDuration::from_millis(10)).await;
}
})
.await
.expect("session must start before forced drop in burst round");
me_pool.registry().unregister(target_conn_id).await;
drop(client_writer_side);
let result = timeout(TokioDuration::from_secs(2), session_task)
.await
.expect("burst round session must terminate quickly")
.expect("burst round session must not panic");
assert!(
result.is_ok(),
"burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}",
result
);
}
}

View File

@ -659,5 +659,3 @@ where
#[cfg(test)] #[cfg(test)]
#[path = "relay_security_tests.rs"] #[path = "relay_security_tests.rs"]
mod security_tests; mod security_tests;
#[path = "relay_adversarial_tests.rs"]
mod adversarial_tests;

View File

@ -1,122 +0,0 @@
use super::*;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::time::{Duration, Instant, timeout};
// ------------------------------------------------------------------
// Priority 3: Async Relay HOL Blocking Prevention (OWASP ASVS 5.1.5)
// ------------------------------------------------------------------
#[tokio::test]
async fn relay_hol_blocking_prevention_regression() {
let stats = Arc::new(Stats::new());
let user = "hol-user";
let (client_peer, relay_client) = duplex(65536);
let (relay_server, server_peer) = duplex(65536);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer);
let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
8192,
8192,
user,
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
));
let payload_size = 1024 * 10;
let s2c_payload = vec![0x41; payload_size];
let c2s_payload = vec![0x42; payload_size];
let s2c_handle = tokio::spawn(async move {
sp_writer.write_all(&s2c_payload).await.unwrap();
let mut total_read = 0;
let mut buf = [0u8; 10];
while total_read < payload_size {
let n = cp_reader.read(&mut buf).await.unwrap();
total_read += n;
tokio::time::sleep(Duration::from_millis(100)).await;
}
});
let start = Instant::now();
cp_writer.write_all(&c2s_payload).await.unwrap();
let mut server_buf = vec![0u8; payload_size];
sp_reader.read_exact(&mut server_buf).await.unwrap();
let elapsed = start.elapsed();
assert!(elapsed < Duration::from_millis(1000), "C->S must not be blocked by slow S->C (HOL blocking): {:?}", elapsed);
assert_eq!(server_buf, c2s_payload);
s2c_handle.abort();
relay_task.abort();
}
// ------------------------------------------------------------------
// Priority 3: Data Quota Mid-Session Cutoff (OWASP ASVS 5.1.6)
// ------------------------------------------------------------------
#[tokio::test]
async fn relay_quota_mid_session_cutoff() {
let stats = Arc::new(Stats::new());
let user = "quota-mid-user";
let quota = 5000;
let (client_peer, relay_client) = duplex(8192);
let (relay_server, server_peer) = duplex(8192);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let (mut _cp_reader, mut cp_writer) = tokio::io::split(client_peer);
let (mut sp_reader, _sp_writer) = tokio::io::split(server_peer);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
user,
Arc::clone(&stats),
Some(quota),
Arc::new(BufferPool::new()),
));
// Send 4000 bytes (Ok)
let buf1 = vec![0x42; 4000];
cp_writer.write_all(&buf1).await.unwrap();
let mut server_recv = vec![0u8; 4000];
sp_reader.read_exact(&mut server_recv).await.unwrap();
// Send another 2000 bytes (Total 6000 > 5000)
let buf2 = vec![0x42; 2000];
let _ = cp_writer.write_all(&buf2).await;
let relay_res = timeout(Duration::from_secs(1), relay_task).await.unwrap();
match relay_res {
Ok(Err(ProxyError::DataQuotaExceeded { .. })) => {
// Expected
}
other => panic!("Expected DataQuotaExceeded error, got: {:?}", other),
}
let mut small_buf = [0u8; 1];
let n = sp_reader.read(&mut small_buf).await.unwrap();
assert_eq!(n, 0, "Server must see EOF after quota reached");
}

View File

@ -1140,46 +1140,3 @@ async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_err
); );
} }
} }
#[tokio::test]
async fn relay_half_close_keeps_reverse_direction_progressing() {
let stats = Arc::new(Stats::new());
let user = "half-close-user";
let (client_peer, relay_client) = duplex(1024);
let (relay_server, server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer);
let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
8192,
8192,
user,
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
));
sp_writer.write_all(&[0x10, 0x20, 0x30, 0x40]).await.unwrap();
sp_writer.shutdown().await.unwrap();
let mut inbound = [0u8; 4];
cp_reader.read_exact(&mut inbound).await.unwrap();
assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]);
cp_writer.write_all(&[0xaa, 0xbb, 0xcc, 0xdd]).await.unwrap();
let mut outbound = [0u8; 4];
sp_reader.read_exact(&mut outbound).await.unwrap();
assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]);
relay_task.abort();
let joined = relay_task.await;
assert!(joined.is_err(), "aborted relay task must return join error");
}