diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 12a2158..ac49ae3 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -544,6 +544,11 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { return None; } + let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize; + if handshake.len() < 5 + record_len { + return None; + } + let mut pos = 5; // after record header if handshake.get(pos).copied()? != 0x01 { return None; // not ClientHello @@ -649,6 +654,15 @@ fn is_valid_sni_hostname(host: &str) -> bool { /// Extract ALPN protocol list from ClientHello, return in offered order. pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { + if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE { + return Vec::new(); + } + + let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize; + if handshake.len() < 5 + record_len { + return Vec::new(); + } + let mut pos = 5; // after record header if handshake.get(pos) != Some(&0x01) { return Vec::new(); @@ -802,3 +816,11 @@ mod compile_time_security_checks { #[cfg(test)] #[path = "tls_security_tests.rs"] mod security_tests; + +#[cfg(test)] +#[path = "tls_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "tls_fuzz_security_tests.rs"] +mod fuzz_security_tests; diff --git a/src/protocol/tls_adversarial_tests.rs b/src/protocol/tls_adversarial_tests.rs new file mode 100644 index 0000000..4c8aa72 --- /dev/null +++ b/src/protocol/tls_adversarial_tests.rs @@ -0,0 +1,352 @@ +use super::*; +use std::time::Instant; +use crate::crypto::sha256_hmac; + +/// Helper to create a byte vector of specific length. +fn make_garbage(len: usize) -> Vec { + vec![0x42u8; len] +} + +/// Helper to create a valid-looking HMAC digest for test. +fn make_digest(secret: &[u8], msg: &[u8], ts: u32) -> [u8; 32] { + let mut hmac = sha256_hmac(secret, msg); + let ts_bytes = ts.to_le_bytes(); + for i in 0..4 { + hmac[28 + i] ^= ts_bytes[i]; + } + hmac +} + +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = session_id.len(); + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let digest = make_digest(secret, &handshake, timestamp); + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32]) +} + +// ------------------------------------------------------------------ +// Truncated Packet Tests (OWASP ASVS 5.1.4, 5.1.5) +// ------------------------------------------------------------------ + +#[test] +fn validate_tls_handshake_truncated_10_bytes_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + let truncated = make_garbage(10); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn validate_tls_handshake_truncated_at_digest_start_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + // TLS_DIGEST_POS = 11. 11 bytes should be rejected. + let truncated = make_garbage(TLS_DIGEST_POS); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn validate_tls_handshake_truncated_inside_digest_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + // TLS_DIGEST_POS + 16 (half digest) + let truncated = make_garbage(TLS_DIGEST_POS + 16); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn extract_sni_truncated_at_record_header_rejected() { + let truncated = make_garbage(3); + assert!(extract_sni_from_client_hello(&truncated).is_none()); +} + +#[test] +fn extract_sni_truncated_at_handshake_header_rejected() { + let mut truncated = vec![TLS_RECORD_HANDSHAKE, 0x03, 0x03, 0x00, 0x05]; + truncated.extend_from_slice(&[0x01, 0x00]); // ClientHello type but truncated length + assert!(extract_sni_from_client_hello(&truncated).is_none()); +} + +// ------------------------------------------------------------------ +// Malformed Extension Parsing Tests +// ------------------------------------------------------------------ + +#[test] +fn extract_sni_with_overlapping_extension_lengths_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // Handshake type: ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92 + h.extend_from_slice(&[0x03, 0x03]); // Version + h.extend_from_slice(&[0u8; 32]); // Random + h.push(0); // Session ID length: 0 + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites + h.extend_from_slice(&[0x01, 0x00]); // Compression + + // Extensions start + h.extend_from_slice(&[0x00, 0x20]); // Total Extensions length: 32 + + // Extension 1: SNI (type 0) + h.extend_from_slice(&[0x00, 0x00]); + h.extend_from_slice(&[0x00, 0x40]); // Claimed len: 64 (OVERFLOWS total extensions len 32) + h.extend_from_slice(&[0u8; 64]); + + assert!(extract_sni_from_client_hello(&h).is_none()); +} + +#[test] +fn extract_sni_with_infinite_loop_potential_extension_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // Handshake type: ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92 + h.extend_from_slice(&[0x03, 0x03]); // Version + h.extend_from_slice(&[0u8; 32]); // Random + h.push(0); // Session ID length: 0 + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites + h.extend_from_slice(&[0x01, 0x00]); // Compression + + // Extensions start + h.extend_from_slice(&[0x00, 0x10]); // Total Extensions length: 16 + + // Extension: zero length but claims more? + // If our parser didn't advance, it might loop. + // Telemt uses `pos += 4 + elen;` so it always advances. + h.extend_from_slice(&[0x12, 0x34]); // Unknown type + h.extend_from_slice(&[0x00, 0x00]); // Length 0 + + // Fill the rest with garbage + h.extend_from_slice(&[0x42; 12]); + + // We expect it to finish without SNI found + assert!(extract_sni_from_client_hello(&h).is_none()); +} + +#[test] +fn extract_sni_with_invalid_hostname_rejected() { + let host = b"invalid_host!%^"; + let mut sni = Vec::new(); + sni.extend_from_slice(&((host.len() + 3) as u16).to_be_bytes()); + sni.push(0); + sni.extend_from_slice(&(host.len() as u16).to_be_bytes()); + sni.extend_from_slice(host); + + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); + h.extend_from_slice(&[0x01, 0x00]); + + let mut ext = Vec::new(); + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni); + + h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + h.extend_from_slice(&ext); + + assert!(extract_sni_from_client_hello(&h).is_none(), "Invalid SNI hostname must be rejected"); +} + +// ------------------------------------------------------------------ +// Timing Neutrality Tests (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[test] +fn validate_tls_handshake_timing_neutrality() { + let secret = b"timing_test_secret_32_bytes_long_"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let mut base = vec![0x42u8; 100]; + base[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; + + const ITER: usize = 600; + const ROUNDS: usize = 7; + + let mut per_round_avg_diff_ns = Vec::with_capacity(ROUNDS); + + for round in 0..ROUNDS { + let mut success_h = base.clone(); + let mut fail_h = base.clone(); + + let start_success = Instant::now(); + for _ in 0..ITER { + let digest = make_digest(secret, &success_h, 0); + success_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + let _ = validate_tls_handshake_at_time(&success_h, &secrets, true, 0); + } + let success_elapsed = start_success.elapsed(); + + let start_fail = Instant::now(); + for i in 0..ITER { + let mut digest = make_digest(secret, &fail_h, 0); + let flip_idx = (i + round) % (TLS_DIGEST_LEN - 4); + digest[flip_idx] ^= 0xFF; + fail_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + let _ = validate_tls_handshake_at_time(&fail_h, &secrets, true, 0); + } + let fail_elapsed = start_fail.elapsed(); + + let diff = if success_elapsed > fail_elapsed { + success_elapsed - fail_elapsed + } else { + fail_elapsed - success_elapsed + }; + per_round_avg_diff_ns.push(diff.as_nanos() as f64 / ITER as f64); + } + + per_round_avg_diff_ns.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let median_avg_diff_ns = per_round_avg_diff_ns[ROUNDS / 2]; + + // Keep this as a coarse side-channel guard only; noisy shared CI hosts can + // introduce microsecond-level jitter that should not fail deterministic suites. + assert!( + median_avg_diff_ns < 50_000.0, + "Median timing delta too large: {} ns/iter", + median_avg_diff_ns + ); +} + +// ------------------------------------------------------------------ +// Adversarial Fingerprinting / Active Probing Tests +// ------------------------------------------------------------------ + +#[test] +fn is_tls_handshake_robustness_against_probing() { + // Valid TLS 1.0 ClientHello + assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + // Valid TLS 1.2/1.3 ClientHello (Legacy Record Layer) + assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); + + // Invalid record type but matching version + assert!(!is_tls_handshake(&[0x17, 0x03, 0x03])); + // Plaintext HTTP request + assert!(!is_tls_handshake(b"GET / HTTP/1.1")); + // Short garbage + assert!(!is_tls_handshake(&[0x16, 0x03])); +} + +#[test] +fn validate_tls_handshake_at_time_strict_boundary() { + let secret = b"strict_boundary_secret_32_bytes_"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_000_000_000; + + // Boundary: exactly TIME_SKEW_MAX (120s past) + let ts_past = (now - TIME_SKEW_MAX) as u32; + let h = make_valid_tls_handshake_with_session_id(secret, ts_past, &[0x42; 32]); + assert!(validate_tls_handshake_at_time(&h, &secrets, false, now).is_some()); + + // Boundary + 1s: should be rejected + let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; + let h2 = make_valid_tls_handshake_with_session_id(secret, ts_too_past, &[0x42; 32]); + assert!(validate_tls_handshake_at_time(&h2, &secrets, false, now).is_none()); +} + +#[test] +fn extract_sni_with_duplicate_extensions_rejected() { + // Construct a ClientHello with TWO SNI extensions + let host1 = b"first.com"; + let mut sni1 = Vec::new(); + sni1.extend_from_slice(&((host1.len() + 3) as u16).to_be_bytes()); + sni1.push(0); + sni1.extend_from_slice(&(host1.len() as u16).to_be_bytes()); + sni1.extend_from_slice(host1); + + let host2 = b"second.com"; + let mut sni2 = Vec::new(); + sni2.extend_from_slice(&((host2.len() + 3) as u16).to_be_bytes()); + sni2.push(0); + sni2.extend_from_slice(&(host2.len() as u16).to_be_bytes()); + sni2.extend_from_slice(host2); + + let mut ext = Vec::new(); + // Ext 1: SNI + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni1.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni1); + // Ext 2: SNI again + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni2); + + let mut body = Vec::new(); + body.extend_from_slice(&[0x03, 0x03]); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); + body.extend_from_slice(&[0x01, 0x00]); + body.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut h = Vec::new(); + h.push(0x16); + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + h.extend_from_slice(&handshake); + + // Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups. + // Telemt's `extract_sni` returns the first one found. + assert!(extract_sni_from_client_hello(&h).is_some()); +} + +#[test] +fn extract_alpn_with_malformed_list_rejected() { + let mut alpn_payload = Vec::new(); + alpn_payload.extend_from_slice(&0x0005u16.to_be_bytes()); // Total len 5 + alpn_payload.push(10); // Labeled len 10 (OVERFLOWS total 5) + alpn_payload.extend_from_slice(b"h2"); + + let mut ext = Vec::new(); + ext.extend_from_slice(&0x0010u16.to_be_bytes()); // Type: ALPN (16) + ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes()); + ext.extend_from_slice(&alpn_payload); + + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03]; + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); + h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + h.extend_from_slice(&ext); + + let res = extract_alpn_from_client_hello(&h); + assert!(res.is_empty(), "Malformed ALPN list must return empty or fail"); +} + +#[test] +fn extract_sni_with_huge_extension_header_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x00]; // Record header + h.push(0x01); // ClientHello + h.extend_from_slice(&[0x00, 0xFF, 0xFF]); // Huge length (65535) - overflows record + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); + + // Extensions start + h.extend_from_slice(&[0xFF, 0xFF]); // Total extensions: 65535 (OVERFLOWS everything) + + assert!(extract_sni_from_client_hello(&h).is_none()); +} diff --git a/src/protocol/tls_fuzz_security_tests.rs b/src/protocol/tls_fuzz_security_tests.rs new file mode 100644 index 0000000..32d8efe --- /dev/null +++ b/src/protocol/tls_fuzz_security_tests.rs @@ -0,0 +1,195 @@ +use super::*; +use crate::crypto::sha256_hmac; +use std::panic::catch_unwind; + +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = session_id.len(); + assert!(session_id_len <= u8::MAX as usize); + + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut digest = sha256_hmac(secret, &handshake); + let ts = timestamp.to_le_bytes(); + for idx in 0..4 { + digest[28 + idx] ^= ts[idx]; + } + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake +} + +fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +#[test] +fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() { + let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]); + assert_eq!(extract_sni_from_client_hello(&valid).as_deref(), Some("example.com")); + assert_eq!( + extract_alpn_from_client_hello(&valid), + vec![b"h2".to_vec(), b"http/1.1".to_vec()] + ); + assert!( + extract_sni_from_client_hello(&make_valid_client_hello_record("127.0.0.1", &[])).is_none(), + "literal IP hostnames must be rejected" + ); + + let mut corpus = vec![ + Vec::new(), + vec![0x16, 0x03, 0x03], + valid[..9].to_vec(), + valid[..valid.len() - 1].to_vec(), + ]; + + let mut wrong_type = valid.clone(); + wrong_type[0] = 0x15; + corpus.push(wrong_type); + + let mut wrong_handshake = valid.clone(); + wrong_handshake[5] = 0x02; + corpus.push(wrong_handshake); + + let mut wrong_length = valid.clone(); + wrong_length[3] ^= 0x7f; + corpus.push(wrong_length); + + for (idx, input) in corpus.iter().enumerate() { + assert!(catch_unwind(|| extract_sni_from_client_hello(input)).is_ok()); + assert!(catch_unwind(|| extract_alpn_from_client_hello(input)).is_ok()); + + if idx == 0 { + continue; + } + + assert!(extract_sni_from_client_hello(input).is_none(), "corpus item {idx} must fail closed for SNI"); + assert!(extract_alpn_from_client_hello(input).is_empty(), "corpus item {idx} must fail closed for ALPN"); + } +} + +#[test] +fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() { + let secret = b"tls_fuzz_security_secret"; + let now: i64 = 1_700_000_000; + let base = make_valid_tls_handshake_with_session_id(secret, now as u32, &[0x42; 32]); + let secrets = vec![("fuzz-user".to_string(), secret.to_vec())]; + + assert!(validate_tls_handshake_at_time(&base, &secrets, false, now).is_some()); + + let mut corpus = Vec::new(); + + let mut truncated = base.clone(); + truncated.truncate(TLS_DIGEST_POS + 16); + corpus.push(truncated); + + let mut digest_flip = base.clone(); + digest_flip[TLS_DIGEST_POS + 7] ^= 0x80; + corpus.push(digest_flip); + + let mut session_id_len_overflow = base.clone(); + session_id_len_overflow[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 33; + corpus.push(session_id_len_overflow); + + let mut timestamp_far_past = base.clone(); + timestamp_far_past[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32] + .copy_from_slice(&((now - i64::from(TIME_SKEW_MAX) - 1) as u32).to_le_bytes()); + corpus.push(timestamp_far_past); + + let mut timestamp_far_future = base.clone(); + timestamp_far_future[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32] + .copy_from_slice(&((now - TIME_SKEW_MIN + 1) as u32).to_le_bytes()); + corpus.push(timestamp_far_future); + + let mut seed = 0xA5A5_5A5A_F00D_BAAD_u64; + for _ in 0..32 { + let mut mutated = base.clone(); + for _ in 0..2 { + seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN); + mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, handshake) in corpus.iter().enumerate() { + let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now)); + assert!(result.is_ok(), "corpus item {idx} must not panic"); + assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed"); + } +} + +#[test] +fn tls_boot_time_acceptance_is_capped_by_replay_window() { + let secret = b"tls_boot_time_cap_secret"; + let secrets = vec![("boot-user".to_string(), secret.to_vec())]; + let boot_ts = 1u32; + let handshake = make_valid_tls_handshake_with_session_id(secret, boot_ts, &[0x42; 32]); + + assert!( + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 300).is_some(), + "boot-time timestamp should be accepted while replay window permits it" + ); + assert!( + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 0).is_none(), + "boot-time timestamp must be rejected when replay window disables the bypass" + ); +} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 6c64a94..cbf68a7 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1040,3 +1040,6 @@ impl RunningClientHandler { #[cfg(test)] #[path = "client_security_tests.rs"] mod security_tests; +#[cfg(test)] +#[path = "client_adversarial_tests.rs"] +mod adversarial_tests; diff --git a/src/proxy/client_adversarial_tests.rs b/src/proxy/client_adversarial_tests.rs new file mode 100644 index 0000000..80d65f2 --- /dev/null +++ b/src/proxy/client_adversarial_tests.rs @@ -0,0 +1,109 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::stats::Stats; +use crate::ip_tracker::UserIpTracker; +use crate::error::ProxyError; +use std::sync::Arc; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +// ------------------------------------------------------------------ +// Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn client_stress_10k_connections_limit_strict() { + let user = "stress-user"; + let limit = 512; + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), limit); + + let iterations = 1000; + let mut tasks = Vec::new(); + + for i in 0..iterations { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + let user_str = user.to_string(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, (i % 254 + 1) as u8)), + 10000 + (i % 1000) as u16, + ); + + match RunningClientHandler::acquire_user_connection_reservation_static( + &user_str, + &config, + stats, + peer, + ip_tracker, + ).await { + Ok(res) => Ok(res), + Err(ProxyError::ConnectionLimitExceeded { .. }) => Err(()), + Err(e) => panic!("Unexpected error: {:?}", e), + } + })); + } + + let results = futures::future::join_all(tasks).await; + let mut successes = 0; + let mut failures = 0; + let mut reservations = Vec::new(); + + for res in results { + match res.unwrap() { + Ok(r) => { + successes += 1; + reservations.push(r); + } + Err(_) => failures += 1, + } + } + + assert_eq!(successes, limit, "Should allow exactly 'limit' connections"); + assert_eq!(failures, iterations - limit, "Should fail the rest with LimitExceeded"); + assert_eq!(stats.get_user_curr_connects(user), limit as u64); + + drop(reservations); + + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0, "Stats must converge to 0 after all drops"); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP tracker must converge to 0"); +} + +// ------------------------------------------------------------------ +// Priority 3: IP Tracker Race Stress +// ------------------------------------------------------------------ + +#[tokio::test] +async fn client_ip_tracker_race_condition_stress() { + let user = "race-user"; + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 100).await; + + let iterations = 1000; + let mut tasks = Vec::new(); + + for i in 0..iterations { + let ip_tracker = Arc::clone(&ip_tracker); + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 254 + 1) as u8)); + + tasks.push(tokio::spawn(async move { + for _ in 0..10 { + if let Ok(()) = ip_tracker.check_and_add("race-user", ip).await { + ip_tracker.remove_ip("race-user", ip).await; + } + } + })); + } + + futures::future::join_all(tasks).await; + + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP count must be zero after balanced add/remove burst"); +} diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index abd6266..d3c411e 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -7,6 +7,7 @@ use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use std::net::Ipv4Addr; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; @@ -630,6 +631,54 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load( } } +#[tokio::test] +async fn reservation_limit_failure_does_not_leak_curr_connects_counter() { + let user = "leak-check-user"; + let mut config = crate::config::ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Arc::new(crate::stats::Stats::new()); + let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; + + let first_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 1)), 50001); + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + first_peer, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + let second_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 2)), 50002); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + second_peer, + ip_tracker.clone(), + ) + .await; + + assert!( + matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user), + "second reservation must be rejected at the configured tcp-conns limit" + ); + assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment"); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state"); + + first.release().await; + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + #[tokio::test] async fn short_tls_probe_is_masked_through_client_pipeline() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index 6c25068..e8016a5 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -12,8 +12,10 @@ use std::path::Path; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; +use tokio::io::AsyncReadExt; use tokio::io::duplex; use tokio::net::TcpListener; +use tokio::time::{timeout, Duration as TokioDuration}; fn make_crypto_reader(reader: R) -> CryptoReader where @@ -1322,3 +1324,194 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() { assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); } } + +#[tokio::test] +async fn negative_direct_relay_dc_connection_refused_fails_fast() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_client_reader_relay, client_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + // Reserve an ephemeral port and immediately release it to deterministically + // exercise the direct-connect failure path without long-lived hangs. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dc_addr = listener.local_addr().unwrap(); + drop(listener); + + let mut config_with_override = ProxyConfig::default(); + config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); + let config = Arc::new(config_with_override); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + enabled: true, + weight: 1, + scopes: String::new(), + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + selected_scope: String::new(), + }], + 1, + 100, + 5000, + 3, + false, + stats.clone(), + )); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let result = timeout( + TokioDuration::from_secs(2), + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_runtime.snapshot(), + 0xABCD_1234, + ), + ) + .await + .expect("direct relay must fail fast on connection-refused upstream"); + + assert!( + result.is_err(), + "connection-refused upstream must fail closed" + ); +} + +#[tokio::test] +async fn adversarial_direct_relay_cutover_integrity() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_client_reader_relay, client_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + // Mock upstream server. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dc_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + // Read handshake nonce. + let mut nonce = [0u8; 64]; + let _ = stream.read_exact(&mut nonce).await; + // Keep connection open. + tokio::time::sleep(TokioDuration::from_secs(5)).await; + }); + + let mut config_with_override = ProxyConfig::default(); + config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); + let config = Arc::new(config_with_override); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + enabled: true, + weight: 1, + scopes: String::new(), + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + selected_scope: String::new(), + }], + 1, + 100, + 5000, + 3, + false, + stats.clone(), + )); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let stats_for_task = stats.clone(); + let runtime_clone = route_runtime.clone(); + let session_task = tokio::spawn(async move { + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats_for_task, + config, + buffer_pool, + rng, + runtime_clone.subscribe(), + runtime_clone.snapshot(), + 0xABCD_1234, + ).await + }); + + timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("direct relay session must start before cutover"); + + // Trigger cutover. + route_runtime.set_mode(RelayRouteMode::Middle).unwrap(); + + // The session should terminate after the staggered delay (1000-2000ms). + let result = timeout(TokioDuration::from_secs(5), session_task) + .await + .expect("Session must terminate after cutover") + .expect("Session must not panic"); + + assert!( + matches!( + result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "Session must terminate with route switch error on cutover" + ); +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 6886e65..b930caf 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -968,8 +968,12 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { mod security_tests; #[cfg(test)] -#[path = "handshake_gap_short_tls_probe_throttle_security_tests.rs"] -mod gap_short_tls_probe_throttle_security_tests; +#[path = "handshake_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "handshake_fuzz_security_tests.rs"] +mod fuzz_security_tests; /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, diff --git a/src/proxy/handshake_adversarial_tests.rs b/src/proxy/handshake_adversarial_tests.rs new file mode 100644 index 0000000..f93d8ce --- /dev/null +++ b/src/proxy/handshake_adversarial_tests.rs @@ -0,0 +1,231 @@ +use super::*; +use std::sync::Arc; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; +use crate::crypto::sha256; + +fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access.users.insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg +} + +// ------------------------------------------------------------------ +// Mutational Bit-Flipping Tests (OWASP ASVS 5.1.4) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn mtproto_handshake_bit_flip_anywhere_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "11223344556677889900aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + // Baseline check + let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + match res { + HandshakeResult::Success(_) => {}, + _ => panic!("Baseline failed: expected Success"), + } + + // Flip bits in the encrypted part (beyond the key material) + for byte_pos in SKIP_LEN..HANDSHAKE_LEN { + let mut h = base; + h[byte_pos] ^= 0x01; // Flip 1 bit + let res = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + assert!(matches!(res, HandshakeResult::BadClient { .. }), "Flip at byte {byte_pos} bit 0 must be rejected"); + } +} + +// ------------------------------------------------------------------ +// Adversarial Probing / Timing Neutrality (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn mtproto_handshake_timing_neutrality_mocked() { + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); + + const ITER: usize = 50; + + let mut start = Instant::now(); + for _ in 0..ITER { + let _ = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + } + let duration_success = start.elapsed(); + + start = Instant::now(); + for i in 0..ITER { + let mut h = base; + h[SKIP_LEN + (i % 48)] ^= 0xFF; + let _ = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + } + let duration_fail = start.elapsed(); + + let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64).abs() / ITER as f64; + + // Threshold (loose for CI) + assert!(avg_diff_ms < 100.0, "Timing difference too large: {} ms/iter", avg_diff_ms); +} + +// ------------------------------------------------------------------ +// Stress Tests (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn auth_probe_throttle_saturation_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + + // Record enough failures for one IP to trigger backoff + let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(target_ip, now); + } + + assert!(auth_probe_is_throttled(target_ip, now)); + + // Stress test with many unique IPs + for i in 0..500u32 { + let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 256) as u8)); + auth_probe_record_failure(ip, now); + } + + let tracked = AUTH_PROBE_STATE + .get() + .map(|state| state.len()) + .unwrap_or(0); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}" + ); +} + +#[tokio::test] +async fn mtproto_handshake_abridged_prefix_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + handshake[0] = 0xef; // Abridged prefix + let config = ProxyConfig::default(); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); + + let res = handle_mtproto_handshake(&handshake, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + // MTProxy stops immediately on 0xef + assert!(matches!(res, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn mtproto_handshake_preferred_user_mismatch_continues() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret1_hex = "11111111111111111111111111111111"; + let secret2_hex = "22222222222222222222222222222222"; + + let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1); + let mut config = ProxyConfig::default(); + config.access.users.insert("user1".to_string(), secret1_hex.to_string()); + config.access.users.insert("user2".to_string(), secret2_hex.to_string()); + config.access.ignore_time_skew = true; + config.general.modes.secure = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap(); + + // Even if we prefer user1, if user2 matches, it should succeed. + let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, Some("user1")).await; + if let HandshakeResult::Success((_, _, success)) = res { + assert_eq!(success.user, "user2"); + } else { + panic!("Handshake failed even though user2 matched"); + } +} + +#[tokio::test] +async fn mtproto_handshake_concurrent_flood_stability() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let mut config = test_config_with_secret_hex(secret_hex); + config.access.ignore_time_skew = true; + let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); + let config = Arc::new(config); + + let mut tasks = Vec::new(); + for i in 0..50 { + let base = base; + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap(); + + tasks.push(tokio::spawn(async move { + let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + matches!(res, HandshakeResult::Success(_)) + })); + } + + // We don't necessarily care if they all succeed (some might fail due to replay if they hit the same chunk), + // but the system must not panic or hang. + for task in tasks { + let _ = task.await.unwrap(); + } +} diff --git a/src/proxy/handshake_fuzz_security_tests.rs b/src/proxy/handshake_fuzz_security_tests.rs new file mode 100644 index 0000000..d72c9cd --- /dev/null +++ b/src/proxy/handshake_fuzz_security_tests.rs @@ -0,0 +1,270 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::crypto::AesCtr; +use crate::crypto::sha256; +use crate::protocol::constants::ProtoTag; +use crate::stats::ReplayChecker; +use std::net::SocketAddr; +use std::sync::MutexGuard; +use tokio::time::{timeout, Duration as TokioDuration}; + +fn make_mtproto_handshake_with_proto_bytes( + secret_hex: &str, + proto_bytes: [u8; 4], + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_bytes); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { + make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access.users.insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg +} + +fn auth_probe_test_guard() -> MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[tokio::test] +async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "11223344556677889900aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + let first = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let second = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(second, HandshakeResult::BadClient { .. })); + + clear_auth_probe_state_for_testing(); +} + +#[tokio::test] +async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); + + let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new(); + + corpus.push(make_mtproto_handshake_with_proto_bytes( + secret_hex, + [0x00, 0x00, 0x00, 0x00], + 1, + )); + corpus.push(make_mtproto_handshake_with_proto_bytes( + secret_hex, + [0xff, 0xff, 0xff, 0xff], + 1, + )); + corpus.push(make_valid_mtproto_handshake( + "ffeeddccbbaa99887766554433221100", + ProtoTag::Secure, + 1, + )); + + let mut seed = 0xF0F0_F00D_BAAD_CAFEu64; + for _ in 0..32 { + let mut mutated = base; + for _ in 0..4 { + seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); + mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, input) in corpus.into_iter().enumerate() { + let result = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &input, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed handshake must complete in time"); + + assert!( + matches!(result, HandshakeResult::BadClient { .. }), + "corpus item {idx} must fail closed" + ); + } + + clear_auth_probe_state_for_testing(); +} + +#[tokio::test] +async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "99887766554433221100ffeeddccbbaa"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(256, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.44:45444".parse().unwrap(); + + let first = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("base handshake must not hang"); + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("duplicate handshake must not hang"); + assert!(matches!(replay, HandshakeResult::BadClient { .. })); + + let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new(); + + let mut prekey_flip = base; + prekey_flip[SKIP_LEN] ^= 0x80; + corpus.push(prekey_flip); + + let mut iv_flip = base; + iv_flip[SKIP_LEN + PREKEY_LEN] ^= 0x01; + corpus.push(iv_flip); + + let mut tail_flip = base; + tail_flip[SKIP_LEN + PREKEY_LEN + IV_LEN - 1] ^= 0x40; + corpus.push(tail_flip); + + let mut seed = 0xBADC_0FFE_EE11_4242u64; + for _ in 0..24 { + let mut mutated = base; + for _ in 0..3 { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); + mutated[idx] ^= ((seed >> 16) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, input) in corpus.iter().enumerate() { + let result = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + input, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed handshake must complete in time"); + + assert!( + matches!(result, HandshakeResult::BadClient { .. }), + "mixed corpus item {idx} must fail closed" + ); + } + + clear_auth_probe_state_for_testing(); +} \ No newline at end of file diff --git a/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs b/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs deleted file mode 100644 index 2ea32bc..0000000 --- a/src/proxy/handshake_gap_short_tls_probe_throttle_security_tests.rs +++ /dev/null @@ -1,50 +0,0 @@ -use super::*; -use crate::stats::ReplayChecker; -use std::net::SocketAddr; -use std::time::Duration; - -fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { - let mut cfg = ProxyConfig::default(); - cfg.access.users.clear(); - cfg.access - .users - .insert("user".to_string(), secret_hex.to_string()); - cfg.access.ignore_time_skew = true; - cfg -} - -#[tokio::test] -async fn gap_t01_short_tls_probe_burst_is_throttled() { - let _guard = auth_probe_test_lock() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - clear_auth_probe_state_for_testing(); - - let config = test_config_with_secret_hex("11111111111111111111111111111111"); - let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); - let rng = SecureRandom::new(); - let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap(); - - let too_short = vec![0x16, 0x03, 0x01]; - - for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { - let result = handle_tls_handshake( - &too_short, - tokio::io::empty(), - tokio::io::sink(), - peer, - &config, - &replay_checker, - &rng, - None, - ) - .await; - assert!(matches!(result, HandshakeResult::BadClient { .. })); - } - - assert!( - auth_probe_fail_streak_for_testing(peer.ip()) - .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS), - "short TLS probe bursts must increase auth-probe fail streak" - ); -} diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index c93d18e..5263413 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1997,6 +1997,42 @@ fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer ); } +#[tokio::test] +async fn gap_t01_short_tls_probe_burst_is_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &too_short, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert!( + auth_probe_fail_streak_for_testing(peer.ip()) + .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS), + "short TLS probe bursts must increase auth-probe fail streak" + ); +} + #[test] fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() { let _guard = auth_probe_test_lock() diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 030fb2f..a7da35a 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -317,3 +317,7 @@ async fn consume_client_data(mut reader: R) { #[cfg(test)] #[path = "masking_security_tests.rs"] mod security_tests; + +#[cfg(test)] +#[path = "masking_adversarial_tests.rs"] +mod adversarial_tests; diff --git a/src/proxy/masking_adversarial_tests.rs b/src/proxy/masking_adversarial_tests.rs new file mode 100644 index 0000000..16b0047 --- /dev/null +++ b/src/proxy/masking_adversarial_tests.rs @@ -0,0 +1,213 @@ +use super::*; +use std::sync::Arc; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Instant, Duration}; +use crate::config::ProxyConfig; +use crate::stats::beobachten::BeobachtenStore; + +// ------------------------------------------------------------------ +// Probing Indistinguishability (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_probes_indistinguishable_timing() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 80; // Should timeout/refuse + + let peer: SocketAddr = "192.0.2.10:443".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + // Test different probe types + let probes = vec![ + (b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"), + (b"SSH-2.0-probe".to_vec(), "SSH"), + (vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00], "TLS-scanner"), + (vec![0x42; 5], "port-scanner"), + ]; + + for (probe, type_name) in probes { + let (client_reader, _client_writer) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + + let start = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ).await; + + let elapsed = start.elapsed(); + + // We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests) + // to mask whether the backend was reachable or refused. + assert!(elapsed >= Duration::from_millis(30), "Probe {type_name} finished too fast: {elapsed:?}"); + } +} + +// ------------------------------------------------------------------ +// Masking Budget Stress Tests (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_budget_stress_under_load() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; // Unlikely port + + let peer: SocketAddr = "192.0.2.20:443".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = Arc::new(BeobachtenStore::new()); + + let mut tasks = Vec::new(); + for _ in 0..50 { + let (client_reader, _client_writer) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let config = config.clone(); + let beobachten = Arc::clone(&beobachten); + + tasks.push(tokio::spawn(async move { + let start = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"probe", + peer, + local_addr, + &config, + &beobachten, + ).await; + start.elapsed() + })); + } + + for task in tasks { + let elapsed = task.await.unwrap(); + assert!(elapsed >= Duration::from_millis(30), "Stress probe finished too fast: {elapsed:?}"); + } +} + +// ------------------------------------------------------------------ +// detect_client_type Fingerprint Check +// ------------------------------------------------------------------ + +#[test] +fn test_detect_client_type_boundary_cases() { + // 9 bytes = port-scanner + assert_eq!(detect_client_type(&[0x42; 9]), "port-scanner"); + // 10 bytes = unknown + assert_eq!(detect_client_type(&[0x42; 10]), "unknown"); + + // HTTP verbs without trailing space + assert_eq!(detect_client_type(b"GET/"), "port-scanner"); // because len < 10 + assert_eq!(detect_client_type(b"GET /path"), "HTTP"); +} + +// ------------------------------------------------------------------ +// Priority 2: Slowloris and Slow Read Attacks (OWASP ASVS 5.1.5) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_slowloris_client_idle_timeout_rejected() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut drip = [0u8; 1]; + let drip_read = tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip)).await; + assert!( + drip_read.is_err() || drip_read.unwrap().is_err(), + "backend must not receive post-timeout slowloris drip bytes" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let beobachten = BeobachtenStore::new(); + let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap(); + let local: SocketAddr = "192.0.2.1:443".parse().unwrap(); + + let (mut client_writer, client_reader) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let handle = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(160)).await; + let _ = client_writer.write_all(b"X").await; + + handle.await.unwrap(); + accept_task.await.unwrap(); +} + +// ------------------------------------------------------------------ +// Priority 2: Fallback Server Down / Fingerprinting (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_fallback_down_mimics_timeout() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; // Unlikely port + + let (server_reader, server_writer) = duplex(1024); + let beobachten = BeobachtenStore::new(); + let peer: SocketAddr = "192.0.2.12:12345".parse().unwrap(); + let local: SocketAddr = "192.0.2.1:443".parse().unwrap(); + + let start = Instant::now(); + handle_bad_client(server_reader, server_writer, b"GET / HTTP/1.1\r\n", peer, local, &config, &beobachten).await; + + let elapsed = start.elapsed(); + // It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately + assert!(elapsed >= Duration::from_millis(40), "Must respect connect budget even on failure: {:?}", elapsed); +} + +// ------------------------------------------------------------------ +// Priority 2: SSRF Prevention (OWASP ASVS 5.1.2) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_ssrf_resolve_internal_ranges_blocked() { + use crate::network::dns_overrides::resolve_socket_addr; + + let blocked_ips = ["127.0.0.1", "169.254.169.254", "10.0.0.1", "192.168.1.1", "0.0.0.0"]; + + for ip in blocked_ips { + assert!( + resolve_socket_addr(ip, 80).is_none(), + "runtime DNS overrides must not resolve unconfigured literal host targets" + ); + } +} diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index 896e465..b8ed52a 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -1,10 +1,11 @@ use super::*; +use crate::proxy::handshake::HandshakeSuccess; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use bytes::Bytes; use crate::crypto::AesCtr; use crate::crypto::SecureRandom; use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; use crate::network::probe::NetworkDecision; -use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::MePool; @@ -20,6 +21,7 @@ use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, timeout}; +use std::sync::{Mutex, OnceLock}; fn make_pooled_payload(data: &[u8]) -> PooledBuffer { let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); @@ -36,6 +38,11 @@ fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer payload } +fn quota_user_lock_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + #[test] fn should_yield_sender_only_on_budget_with_backlog() { assert!(!should_yield_c2me_sender(0, true)); @@ -244,6 +251,10 @@ fn quota_user_lock_cache_reuses_entry_for_same_user() { #[test] fn quota_user_lock_cache_is_bounded_under_unique_churn() { + let _guard = quota_user_lock_test_lock() + .lock() + .expect("quota user lock test lock must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); map.clear(); @@ -261,39 +272,51 @@ fn quota_user_lock_cache_is_bounded_under_unique_churn() { #[test] fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); + let _guard = quota_user_lock_test_lock() + .lock() + .expect("quota user lock test lock must be available"); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-held-user-{idx}"); - retained.push(quota_user_lock(&user)); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + for attempt in 0..8u32 { + map.clear(); + + let prefix = format!("quota-held-user-{}-{attempt}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + let user = format!("{prefix}-{idx}"); + retained.push(quota_user_lock(&user)); + } + + if map.len() != QUOTA_USER_LOCKS_MAX { + drop(retained); + continue; + } + + let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id()); + let overflow_a = quota_user_lock(&overflow_user); + let overflow_b = quota_user_lock(&overflow_user); + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow acquisition must not grow cache past hard limit" + ); + assert!( + map.get(&overflow_user).is_none(), + "overflow path should not cache new user lock when map is saturated and all entries are retained" + ); + assert!( + !Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user lock should be ephemeral under saturation to preserve bounded cache size" + ); + + drop(retained); + return; } - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: cache should be full before overflow acquisition" + panic!( + "unable to observe stable saturated lock-cache precondition after bounded retries" ); - - let overflow_a = quota_user_lock("quota-overflow-user"); - let overflow_b = quota_user_lock("quota-overflow-user"); - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow acquisition must not grow cache past hard limit" - ); - assert!( - map.get("quota-overflow-user").is_none(), - "overflow path should not cache new user lock when map is saturated and all entries are retained" - ); - assert!( - !Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user lock should be ephemeral under saturation to preserve bounded cache size" - ); - - drop(retained); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -2169,3 +2192,320 @@ async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea drop(client_sides); } + +#[tokio::test] +async fn secure_padding_distribution_in_relay_writer() { + timeout(TokioDuration::from_secs(10), async { + let (mut client_side, relay_side) = duplex(512 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = Arc::new(SecureRandom::new()); + let mut frame_buf = Vec::new(); + let mut decryptor = AesCtr::new(&key, iv); + + let mut padding_counts = [0usize; 4]; + let iterations = 180usize; + let payload = vec![0xAAu8; 100]; // 4-byte aligned + + for _ in 0..iterations { + write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("payload write must succeed"); + writer + .flush() + .await + .expect("writer flush must complete so encrypted frame becomes readable"); + + let mut len_buf = [0u8; 4]; + client_side + .read_exact(&mut len_buf) + .await + .expect("must read encrypted secure length"); + let decrypted_len_bytes = decryptor.decrypt(&len_buf); + let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes + .try_into() + .expect("decrypted length must be 4 bytes"); + let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize; + + assert!( + wire_len >= payload.len(), + "wire length must include at least payload bytes" + ); + let padding_len = wire_len - payload.len(); + assert!(padding_len >= 1 && padding_len <= 3); + padding_counts[padding_len] += 1; + + // Drain and decrypt frame bytes so CTR state stays aligned across writes. + let mut trash = vec![0u8; wire_len]; + client_side + .read_exact(&mut trash) + .await + .expect("must read encrypted secure frame body"); + let _ = decryptor.decrypt(&trash); + } + + for p in 1..=3 { + let count = padding_counts[p]; + assert!( + count > iterations / 8, + "padding length {p} is under-represented ({count}/{iterations})" + ); + } + }) + .await + .expect("secure padding distribution test exceeded runtime budget"); +} + +#[tokio::test] +async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { + let (client_reader_side, client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + + // Create an ME pool. + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + // ConnRegistry ids are monotonic; reserve one id so we can predict the + // next session conn_id and close it deterministically without relying on + // writer-bound views such as active_conn_ids(). + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + "127.0.0.1:443".parse().unwrap(), + rng.clone(), + route_runtime.subscribe(), + route_runtime.snapshot(), + 0x1234_5678, + )); + + // Wait until session startup is visible, then unregister the predicted + // conn_id to close the per-session ME response channel. + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("ME session must start before channel close simulation"); + + me_pool.registry().unregister(target_conn_id).await; + + drop(client_writer_side); + + let result = timeout(TokioDuration::from_secs(2), session_task) + .await + .expect("Session task must terminate after ME drop and client EOF") + .expect("Session task must not panic"); + + assert!( + result.is_ok(), + "Session should complete cleanly after ME drop when client closes, got: {:?}", + result + ); +} + +#[tokio::test] +async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + // Predict the next conn_id so we can force-drop its ME channel deterministically. + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: "test-user-cutover".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let runtime_clone = route_runtime.clone(); + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + runtime_clone.subscribe(), + runtime_clone.snapshot(), + 0xC001_CAFE, + )); + + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("ME session must start before race trigger"); + + // Race ME channel drop with route cutover and assert generic client-visible outcome. + me_pool.registry().unregister(target_conn_id).await; + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance generation" + ); + + let relay_result = timeout(TokioDuration::from_secs(6), session_task) + .await + .expect("session must terminate under ME-drop + cutover race") + .expect("session task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "race outcome must remain generic and not leak ME internals, got: {:?}", + relay_result + ); +} + +#[tokio::test] +async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + for round in 0..32u64 { + let (client_reader_side, client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: format!("stress-me-drop-eof-{round}"), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_runtime.snapshot(), + 0xD00D_0000 + round, + )); + + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("session must start before forced drop in burst round"); + + me_pool.registry().unregister(target_conn_id).await; + drop(client_writer_side); + + let result = timeout(TokioDuration::from_secs(2), session_task) + .await + .expect("burst round session must terminate quickly") + .expect("burst round session must not panic"); + + assert!( + result.is_ok(), + "burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}", + result + ); + } +} diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 8b4c87f..a742e33 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -659,3 +659,5 @@ where #[cfg(test)] #[path = "relay_security_tests.rs"] mod security_tests; +#[path = "relay_adversarial_tests.rs"] +mod adversarial_tests; \ No newline at end of file diff --git a/src/proxy/relay_adversarial_tests.rs b/src/proxy/relay_adversarial_tests.rs new file mode 100644 index 0000000..08de0b8 --- /dev/null +++ b/src/proxy/relay_adversarial_tests.rs @@ -0,0 +1,122 @@ +use super::*; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{Duration, Instant, timeout}; + +// ------------------------------------------------------------------ +// Priority 3: Async Relay HOL Blocking Prevention (OWASP ASVS 5.1.5) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn relay_hol_blocking_prevention_regression() { + let stats = Arc::new(Stats::new()); + let user = "hol-user"; + + let (client_peer, relay_client) = duplex(65536); + let (relay_server, server_peer) = duplex(65536); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 8192, + 8192, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + let payload_size = 1024 * 10; + let s2c_payload = vec![0x41; payload_size]; + let c2s_payload = vec![0x42; payload_size]; + + let s2c_handle = tokio::spawn(async move { + sp_writer.write_all(&s2c_payload).await.unwrap(); + + let mut total_read = 0; + let mut buf = [0u8; 10]; + while total_read < payload_size { + let n = cp_reader.read(&mut buf).await.unwrap(); + total_read += n; + tokio::time::sleep(Duration::from_millis(100)).await; + } + }); + + let start = Instant::now(); + cp_writer.write_all(&c2s_payload).await.unwrap(); + + let mut server_buf = vec![0u8; payload_size]; + sp_reader.read_exact(&mut server_buf).await.unwrap(); + let elapsed = start.elapsed(); + + assert!(elapsed < Duration::from_millis(1000), "C->S must not be blocked by slow S->C (HOL blocking): {:?}", elapsed); + assert_eq!(server_buf, c2s_payload); + + s2c_handle.abort(); + relay_task.abort(); +} + +// ------------------------------------------------------------------ +// Priority 3: Data Quota Mid-Session Cutoff (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn relay_quota_mid_session_cutoff() { + let stats = Arc::new(Stats::new()); + let user = "quota-mid-user"; + let quota = 5000; + + let (client_peer, relay_client) = duplex(8192); + let (relay_server, server_peer) = duplex(8192); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut _cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, _sp_writer) = tokio::io::split(server_peer); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + // Send 4000 bytes (Ok) + let buf1 = vec![0x42; 4000]; + cp_writer.write_all(&buf1).await.unwrap(); + let mut server_recv = vec![0u8; 4000]; + sp_reader.read_exact(&mut server_recv).await.unwrap(); + + // Send another 2000 bytes (Total 6000 > 5000) + let buf2 = vec![0x42; 2000]; + let _ = cp_writer.write_all(&buf2).await; + + let relay_res = timeout(Duration::from_secs(1), relay_task).await.unwrap(); + + match relay_res { + Ok(Err(ProxyError::DataQuotaExceeded { .. })) => { + // Expected + } + other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), + } + + let mut small_buf = [0u8; 1]; + let n = sp_reader.read(&mut small_buf).await.unwrap(); + assert_eq!(n, 0, "Server must see EOF after quota reached"); +} diff --git a/src/proxy/relay_security_tests.rs b/src/proxy/relay_security_tests.rs index 9ba8295..4b002a4 100644 --- a/src/proxy/relay_security_tests.rs +++ b/src/proxy/relay_security_tests.rs @@ -1140,3 +1140,46 @@ async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_err ); } } + +#[tokio::test] +async fn relay_half_close_keeps_reverse_direction_progressing() { + let stats = Arc::new(Stats::new()); + let user = "half-close-user"; + + let (client_peer, relay_client) = duplex(1024); + let (relay_server, server_peer) = duplex(1024); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 8192, + 8192, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + sp_writer.write_all(&[0x10, 0x20, 0x30, 0x40]).await.unwrap(); + sp_writer.shutdown().await.unwrap(); + + let mut inbound = [0u8; 4]; + cp_reader.read_exact(&mut inbound).await.unwrap(); + assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); + + cp_writer.write_all(&[0xaa, 0xbb, 0xcc, 0xdd]).await.unwrap(); + let mut outbound = [0u8; 4]; + sp_reader.read_exact(&mut outbound).await.unwrap(); + assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); +}