diff --git a/.cargo/deny.toml b/.cargo/deny.toml index cee6f6a..09a5dd9 100644 --- a/.cargo/deny.toml +++ b/.cargo/deny.toml @@ -12,4 +12,4 @@ reason = "MUST VERIFY: Only allowed for legacy checksums, never for security." [[bans.skip]] name = "sha1" version = "*" -reason = "MUST VERIFY: Only allowed for backwards compatibility." \ No newline at end of file +reason = "MUST VERIFY: Only allowed for backwards compatibility." diff --git a/AGENTS.md b/AGENTS.md index e6c5f2e..e7f94a5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,6 +5,22 @@ Your responses are precise, minimal, and architecturally sound. You are working --- +### Context: The Telemt Project + +You are working on **Telemt**, a high-performance, production-grade Telegram MTProxy implementation written in Rust. It is explicitly designed to operate in highly hostile network environments and evade advanced network censorship. + +**Adversarial Threat Model:** +The proxy operates under constant surveillance by DPI (Deep Packet Inspection) systems and active scanners (state firewalls, mobile operator fraud controls). These entities actively probe IPs, analyze protocol handshakes, and look for known proxy signatures to block or throttle traffic. + +**Core Architectural Pillars:** +1. **TLS-Fronting (TLS-F) & TCP-Splitting (TCP-S):** To the outside world, Telemt looks like a standard TLS server. If a client presents a valid MTProxy key, the connection is handled internally. If a censor's scanner, web browser, or unauthorized crawler connects, Telemt seamlessly splices the TCP connection (L4) to a real, legitimate HTTPS fallback server (e.g., Nginx) without modifying the `ClientHello` or terminating the TLS handshake. +2. **Middle-End (ME) Orchestration:** A highly concurrent, generation-based pool managing upstream connections to Telegram Datacenters (DCs). It utilizes an **Adaptive Floor** (dynamically scaling writer connections based on traffic), **Hardswaps** (zero-downtime pool reconfiguration), and **STUN/NAT** reflection mechanisms. +3. **Strict KDF Routing:** Cryptographic Key Derivation Functions (KDF) in this protocol strictly rely on the exact pairing of Source IP/Port and Destination IP/Port. Deviations or missing port logic will silently break the MTProto handshake. +4. **Data Plane vs. Control Plane Isolation:** The Data Plane (readers, writers, payload relay, TCP splicing) must remain strictly non-blocking, zero-allocation in hot paths, and highly resilient to network backpressure. The Control Plane (API, metrics, pool generation swaps, config reloads) orchestrates the state asynchronously without stalling the Data Plane. + +Any modification you make must preserve Telemt's invisibility to censors, its strict memory-safety invariants, and its hot-path throughput. + + ### 0. Priority Resolution — Scope Control This section resolves conflicts between code quality enforcement and scope limitation. diff --git a/Cargo.toml b/Cargo.toml index 66a80c5..932c523 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ crc32fast = "1.4" crc32c = "0.6" zeroize = { version = "1.8", features = ["derive"] } subtle = "2.6" +static_assertions = "1.1" # Network socket2 = { version = "0.5", features = ["all"] } @@ -70,7 +71,6 @@ tokio-test = "0.4" criterion = "0.5" proptest = "1.4" futures = "0.3" -static_assertions = "1.1" [[bench]] name = "crypto_bench" diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index fbe7ad5..33d28c4 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -13,6 +13,7 @@ use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; use num_bigint::BigUint; use num_traits::One; +use subtle::ConstantTimeEq; // ============= Public Constants ============= @@ -125,7 +126,7 @@ impl TlsExtensionBuilder { // protocol name length (1 byte) // protocol name bytes let proto_len = proto.len() as u8; - let list_len: u16 = 1 + proto_len as u16; + let list_len: u16 = 1 + u16::from(proto_len); let ext_len: u16 = 2 + list_len; self.extensions.extend_from_slice(&ext_len.to_be_bytes()); @@ -273,13 +274,41 @@ impl ServerHelloBuilder { // ============= Public Functions ============= -/// Validate TLS ClientHello against user secrets +/// Validate TLS ClientHello against user secrets. /// /// Returns validation result if a matching user is found. +/// The result **must** be used — ignoring it silently bypasses authentication. +#[must_use] pub fn validate_tls_handshake( handshake: &[u8], secrets: &[(String, Vec)], ignore_time_skew: bool, +) -> Option { + // Only pay the clock syscall when we will actually compare against it. + // If `ignore_time_skew` is set, a broken or unavailable system clock + // must not block legitimate clients — that would be a DoS via clock failure. + let now = if !ignore_time_skew { + system_time_to_unix_secs(SystemTime::now())? + } else { + 0_i64 + }; + + validate_tls_handshake_at_time(handshake, secrets, ignore_time_skew, now) +} + +fn system_time_to_unix_secs(now: SystemTime) -> Option { + // `try_from` rejects values that overflow i64 (> ~292 billion years CE), + // whereas `as i64` would silently wrap to a negative timestamp and corrupt + // every subsequent time-skew comparison. + let d = now.duration_since(UNIX_EPOCH).ok()?; + i64::try_from(d.as_secs()).ok() +} + +fn validate_tls_handshake_at_time( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + now: i64, ) -> Option { if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { return None; @@ -305,50 +334,56 @@ pub fn validate_tls_handshake( let mut msg = handshake.to_vec(); msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); - // Get current time - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - + let mut first_match: Option = None; + for (user, secret) in secrets { let computed = sha256_hmac(secret, &msg); - - // XOR digests - let xored: Vec = digest.iter() - .zip(computed.iter()) - .map(|(a, b)| a ^ b) - .collect(); - - // Check that first 28 bytes are zeros (timestamp in last 4) - if !xored[..28].iter().all(|&b| b == 0) { + + // Constant-time equality check on the 28-byte HMAC window. + // A variable-time short-circuit here lets an active censor measure how many + // bytes matched, enabling secret brute-force via timing side-channels. + // Direct comparison on the original arrays avoids a heap allocation and + // removes the `try_into().unwrap()` that the intermediate Vec would require. + if !bool::from(digest[..28].ct_eq(&computed[..28])) { continue; } - - // Extract timestamp - let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap()); - let time_diff = now - timestamp as i64; - - // Check time skew + + // The last 4 bytes encode the timestamp as XOR(digest[28..32], computed[28..32]). + // Inline array construction is infallible: both slices are [u8; 32] by construction. + let timestamp = u32::from_le_bytes([ + digest[28] ^ computed[28], + digest[29] ^ computed[29], + digest[30] ^ computed[30], + digest[31] ^ computed[31], + ]); + + // time_diff is only meaningful (and `now` is only valid) when we are + // actually checking the window. Keep both inside the guard to make + // the dead-code path explicit and prevent accidental future use of + // a sentinel `now` value outside its intended scope. if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) // This is a quirk in some clients that use uptime instead of real time let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds - - if !is_boot_time && !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { - continue; + if !is_boot_time { + let time_diff = now - i64::from(timestamp); + if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { + continue; + } } } - return Some(TlsValidation { - user: user.clone(), - session_id, - digest, - timestamp, - }); + if first_match.is_none() { + first_match = Some(TlsValidation { + user: user.clone(), + session_id: session_id.clone(), + digest, + timestamp, + }); + } } - - None + + first_match } fn curve25519_prime() -> BigUint { @@ -667,291 +702,29 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { Ok(()) } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_tls_handshake() { - assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); - assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); - assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data - assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version - assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short - } - - #[test] - fn test_parse_tls_record_header() { - let header = [0x16, 0x03, 0x01, 0x02, 0x00]; - let result = parse_tls_record_header(&header).unwrap(); - assert_eq!(result.0, TLS_RECORD_HANDSHAKE); - assert_eq!(result.1, 512); - - let header = [0x17, 0x03, 0x03, 0x40, 0x00]; - let result = parse_tls_record_header(&header).unwrap(); - assert_eq!(result.0, TLS_RECORD_APPLICATION); - assert_eq!(result.1, 16384); - } - - #[test] - fn test_gen_fake_x25519_key() { - let rng = SecureRandom::new(); - let key1 = gen_fake_x25519_key(&rng); - let key2 = gen_fake_x25519_key(&rng); - - assert_eq!(key1.len(), 32); - assert_eq!(key2.len(), 32); - assert_ne!(key1, key2); // Should be random - } +// ============= Compile-time Security Invariants ============= - #[test] - fn test_fake_x25519_key_is_quadratic_residue() { - let rng = SecureRandom::new(); - let key = gen_fake_x25519_key(&rng); - let p = curve25519_prime(); - let k_num = BigUint::from_bytes_le(&key); - let exponent = (&p - BigUint::one()) >> 1; - let legendre = k_num.modpow(&exponent, &p); - assert_eq!(legendre, BigUint::one()); - } - - #[test] - fn test_tls_extension_builder() { - let key = [0x42u8; 32]; - - let mut builder = TlsExtensionBuilder::new(); - builder.add_key_share(&key); - builder.add_supported_versions(0x0304); - - let result = builder.build(); - - // Check length prefix - let len = u16::from_be_bytes([result[0], result[1]]) as usize; - assert_eq!(len, result.len() - 2); - - // Check key_share extension is present - assert!(result.len() > 40); // At least key share - } - - #[test] - fn test_server_hello_builder() { - let session_id = vec![0x01, 0x02, 0x03, 0x04]; - let key = [0x55u8; 32]; - - let builder = ServerHelloBuilder::new(session_id.clone()) - .with_x25519_key(&key) - .with_tls13_version(); - - let record = builder.build_record(); - - // Validate structure - validate_server_hello_structure(&record).expect("Invalid ServerHello structure"); - - // Check record type - assert_eq!(record[0], TLS_RECORD_HANDSHAKE); - - // Check version - assert_eq!(&record[1..3], &TLS_VERSION); - - // Check message type (ServerHello = 0x02) - assert_eq!(record[5], 0x02); - } - - #[test] - fn test_build_server_hello_structure() { - let secret = b"test secret"; - let client_digest = [0x42u8; 32]; - let session_id = vec![0xAA; 32]; - - let rng = SecureRandom::new(); - let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0); - - // Should have at least 3 records - assert!(response.len() > 100); - - // First record should be ServerHello - assert_eq!(response[0], TLS_RECORD_HANDSHAKE); - - // Validate ServerHello structure - validate_server_hello_structure(&response).expect("Invalid ServerHello"); - - // Find Change Cipher Spec - let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize; - let ccs_start = server_hello_len; - - assert!(response.len() > ccs_start + 6); - assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); - - // Find Application Data - let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; - let app_start = ccs_start + ccs_len; - - assert!(response.len() > app_start + 5); - assert_eq!(response[app_start], TLS_RECORD_APPLICATION); - } - - #[test] - fn test_build_server_hello_digest() { - let secret = b"test secret key here"; - let client_digest = [0x42u8; 32]; - let session_id = vec![0xAA; 32]; - - let rng = SecureRandom::new(); - let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); - let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); - - // Digest position should have non-zero data - let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; - assert!(!digest1.iter().all(|&b| b == 0)); - - // Different calls should have different digests (due to random cert) - let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; - assert_ne!(digest1, digest2); - } - - #[test] - fn test_server_hello_extensions_length() { - let session_id = vec![0x01; 32]; - let key = [0x55u8; 32]; - - let builder = ServerHelloBuilder::new(session_id) - .with_x25519_key(&key) - .with_tls13_version(); - - let record = builder.build_record(); - - // Parse to find extensions - let msg_start = 5; // After record header - let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; - - // Skip to session ID - let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32) - let session_id_len = record[session_id_pos] as usize; - - // Skip to extensions - let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1) - let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize; - - // Verify extensions length matches actual data - let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len]; - assert_eq!(ext_len, extensions_data.len(), - "Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len()); - } - - #[test] - fn test_validate_tls_handshake_format() { - // Build a minimal ClientHello-like structure - let mut handshake = vec![0u8; 100]; - - // Put a valid-looking digest at position 11 - handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&[0x42; 32]); - - // Session ID length - handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; - - // This won't validate (wrong HMAC) but shouldn't panic - let secrets = vec![("test".to_string(), b"secret".to_vec())]; - let result = validate_tls_handshake(&handshake, &secrets, true); - - // Should return None (no match) but not panic - assert!(result.is_none()); - } +/// Compile-time checks that enforce invariants the rest of the code relies on. +/// Using `static_assertions` ensures these can never silently break across +/// refactors without a compile error. +mod compile_time_security_checks { + use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN}; + use static_assertions::const_assert; - fn build_client_hello_with_exts(exts: Vec<(u16, Vec)>, host: &str) -> Vec { - let mut body = Vec::new(); - body.extend_from_slice(&TLS_VERSION); // legacy version - body.extend_from_slice(&[0u8; 32]); // random - body.push(0); // session id len - body.extend_from_slice(&2u16.to_be_bytes()); // cipher suites len - body.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256 - body.push(1); // compression len - body.push(0); // null compression + // The digest must be exactly one SHA-256 output. + const_assert!(TLS_DIGEST_LEN == 32); - // Build SNI extension - let host_bytes = host.as_bytes(); - let mut sni_ext = Vec::new(); - sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes()); - sni_ext.push(0); - sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); - sni_ext.extend_from_slice(host_bytes); + // Replay-dedup stores the first half; verify it is literally half. + const_assert!(TLS_DIGEST_HALF_LEN * 2 == TLS_DIGEST_LEN); - let mut ext_blob = Vec::new(); - for (typ, data) in exts { - ext_blob.extend_from_slice(&typ.to_be_bytes()); - ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes()); - ext_blob.extend_from_slice(&data); - } - // SNI last - ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); - ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); - ext_blob.extend_from_slice(&sni_ext); - - body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); - body.extend_from_slice(&ext_blob); - - let mut handshake = Vec::new(); - handshake.push(0x01); // ClientHello - let len_bytes = (body.len() as u32).to_be_bytes(); - handshake.extend_from_slice(&len_bytes[1..4]); - handshake.extend_from_slice(&body); - - let mut record = Vec::new(); - record.push(TLS_RECORD_HANDSHAKE); - record.extend_from_slice(&[0x03, 0x01]); - record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); - record.extend_from_slice(&handshake); - record - } - - #[test] - fn test_extract_sni_with_grease_extension() { - // GREASE type 0x0a0a with zero length before SNI - let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com"); - let sni = extract_sni_from_client_hello(&ch); - assert_eq!(sni.as_deref(), Some("example.com")); - } - - #[test] - fn test_extract_sni_tolerates_empty_unknown_extension() { - let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local"); - let sni = extract_sni_from_client_hello(&ch); - assert_eq!(sni.as_deref(), Some("test.local")); - } - - #[test] - fn test_extract_alpn_single() { - let mut alpn_data = Vec::new(); - // list length = 3 (1 length byte + "h2") - alpn_data.extend_from_slice(&3u16.to_be_bytes()); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h2"); - let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); - let alpn = extract_alpn_from_client_hello(&ch); - let alpn_str: Vec = alpn - .iter() - .map(|p| std::str::from_utf8(p).unwrap().to_string()) - .collect(); - assert_eq!(alpn_str, vec!["h2"]); - } - - #[test] - fn test_extract_alpn_multiple() { - let mut alpn_data = Vec::new(); - // list length = 11 (sum of per-proto lengths including length bytes) - alpn_data.extend_from_slice(&11u16.to_be_bytes()); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h2"); - alpn_data.push(4); - alpn_data.extend_from_slice(b"spdy"); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h3"); - let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); - let alpn = extract_alpn_from_client_hello(&ch); - let alpn_str: Vec = alpn - .iter() - .map(|p| std::str::from_utf8(p).unwrap().to_string()) - .collect(); - assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]); - } + // The HMAC check window (28 bytes) plus the embedded timestamp (4 bytes) + // must exactly fill the digest. If TLS_DIGEST_LEN ever changes, these + // assertions will catch the mismatch before any timing-oracle fix is broke. + const_assert!(28 + 4 == TLS_DIGEST_LEN); } + +// ============= Security-focused regression tests ============= + +#[cfg(test)] +#[path = "tls_security_tests.rs"] +mod security_tests; diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs new file mode 100644 index 0000000..476f24a --- /dev/null +++ b/src/protocol/tls_security_tests.rs @@ -0,0 +1,1242 @@ +use super::*; +use crate::crypto::sha256_hmac; + +/// Build a TLS-handshake-like buffer that contains a valid HMAC digest +/// for the given `secret` and `timestamp`. +/// +/// Layout (bytes): +/// [0..TLS_DIGEST_POS] : fixed filler (0x42) +/// [TLS_DIGEST_POS..+32] : digest = HMAC XOR [0..0 || timestamp_le] +/// [TLS_DIGEST_POS+32] : session_id_len = 32 +/// [TLS_DIGEST_POS+33..+65] : session_id filler (0x42) +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + // Zero the digest slot before computing HMAC (mirrors what validate does). + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + + // digest = HMAC such that XOR with stored digest yields [0..0, timestamp_le]. + // bytes 0-27 of digest == computed[0..28] -> xored[..28] == 0 + // bytes 28-31 of digest == computed[28..32] XOR timestamp_le + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +// ------------------------------------------------------------------ +// Happy-path sanity +// ------------------------------------------------------------------ + +#[test] +fn valid_handshake_with_correct_secret_accepted() { + let secret = b"correct_horse_battery_staple_32b"; + // timestamp = 0 triggers is_boot_time path, accepted without wall-clock check. + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("alice".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "Valid handshake must be accepted"); + assert_eq!(result.unwrap().user, "alice"); +} + +#[test] +fn deterministic_external_vector_validates_without_helper() { + // Deterministic vector generated by an external Python stdlib HMAC script, + // not by this test module helper. This catches mirrored helper mistakes. + let secret = hex::decode("00112233445566778899aabbccddeeff").unwrap(); + let handshake = hex::decode( + "4242424242424242424242a93225d1d6b46260bc9ce0cc48c7487d2b1ca5afa7ae9fc6609d9e60a3ca842b204242424242424242424242424242424242424242424242424242424242424242", + ) + .unwrap(); + + let secrets = vec![("vector_user".to_string(), secret)]; + let result = validate_tls_handshake(&handshake, &secrets, true).unwrap(); + + assert_eq!(result.user, "vector_user"); + assert_eq!(result.timestamp, 0x01020304); +} + +#[test] +fn valid_handshake_timestamp_extracted_correctly() { + let secret = b"ts_extraction_test"; + let ts: u32 = 0xDEAD_BEEF; + let handshake = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some()); + assert_eq!(result.unwrap().timestamp, ts); +} + +// ------------------------------------------------------------------ +// HMAC bit-flip rejection - adversarial HMAC forgery attempts +// ------------------------------------------------------------------ + +/// Flip every single bit across the 28-byte HMAC check window one at a +/// time. Each flip must cause rejection. This is the primary guard +/// against a censor gradually narrowing down a valid HMAC via partial +/// matches (which would be exploitable with a variable-time comparison). +#[test] +fn hmac_single_bit_flip_anywhere_in_check_window_rejected() { + let secret = b"flip_test_secret"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // First ensure the unmodified handshake is accepted. + assert!( + validate_tls_handshake(&base, &secrets, true).is_some(), + "Baseline handshake must be accepted before flip tests" + ); + + for byte_pos in 0..28usize { + for bit in 0u8..8 { + let mut h = base.clone(); + h[TLS_DIGEST_POS + byte_pos] ^= 1 << bit; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip of bit {bit} in HMAC byte {byte_pos} must be rejected" + ); + } + } +} + +/// XOR entire check window (bytes 0-27) with 0xFF - must still fail. +#[test] +fn hmac_full_window_corruption_rejected() { + let secret = b"full_window_test"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + for i in 0..28 { + h[TLS_DIGEST_POS + i] ^= 0xFF; + } + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +/// Byte 27 is the last byte in the checked window. A non-constant-time +/// `all(|b| b == 0)` that short-circuits on byte 0 would never even reach +/// byte 27, making this an effective "did the fix actually run to the end" +/// sentinel: if this passes but the earlier byte-0 test fails, the check +/// window is not being evaluated end-to-end. +#[test] +fn hmac_last_byte_of_check_window_enforced() { + let secret = b"last_byte_sentinel"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + // Corrupt only byte 27. + h[TLS_DIGEST_POS + 27] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Corruption at byte 27 (end of HMAC window) must cause rejection" + ); +} + +// ------------------------------------------------------------------ +// User enumeration / multi-user ordering +// ------------------------------------------------------------------ + +#[test] +fn wrong_user_secret_rejected_even_with_valid_structure() { + let secret_a = b"secret_alpha"; + let secret_b = b"secret_beta"; + let handshake = make_valid_tls_handshake(secret_b, 0); + // Only user_a is configured. + let secrets = vec![("user_a".to_string(), secret_a.to_vec())]; + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "Handshake for user_b must fail when only user_a is configured" + ); +} + +#[test] +fn second_user_in_list_found_when_first_does_not_match() { + let secret_a = b"secret_alpha"; + let secret_b = b"secret_beta"; + let handshake = make_valid_tls_handshake(secret_b, 0); + let secrets = vec![ + ("user_a".to_string(), secret_a.to_vec()), + ("user_b".to_string(), secret_b.to_vec()), + ]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "user_b must be found even though user_a comes first"); + assert_eq!(result.unwrap().user, "user_b"); +} + +#[test] +fn duplicate_secret_keeps_first_user_identity() { + // If multiple entries share the same secret, the selected identity must + // stay stable and deterministic (first entry wins). + let shared = b"same_secret_for_two_users"; + let handshake = make_valid_tls_handshake(shared, 0); + let secrets = vec![ + ("first_user".to_string(), shared.to_vec()), + ("second_user".to_string(), shared.to_vec()), + ]; + + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some()); + assert_eq!(result.unwrap().user, "first_user"); +} + +#[test] +fn no_user_matches_returns_none() { + let secret_a = b"aaa"; + let secret_b = b"bbb"; + let secret_c = b"ccc"; + let handshake = make_valid_tls_handshake(b"unknown_secret", 0); + let secrets = vec![ + ("a".to_string(), secret_a.to_vec()), + ("b".to_string(), secret_b.to_vec()), + ("c".to_string(), secret_c.to_vec()), + ]; + assert!(validate_tls_handshake(&handshake, &secrets, true).is_none()); +} + +#[test] +fn empty_secrets_list_rejects_everything() { + let secret = b"test"; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets: Vec<(String, Vec)> = Vec::new(); + assert!(validate_tls_handshake(&handshake, &secrets, true).is_none()); +} + +// ------------------------------------------------------------------ +// Timestamp / time-skew boundary attacks +// ------------------------------------------------------------------ + +#[test] +fn timestamp_at_time_skew_boundaries_accepted() { + let secret = b"skew_boundary_test_secret"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // time_diff = now - ts = TIME_SKEW_MIN = -1200 + // -> ts = now - TIME_SKEW_MIN = now + 1200 (20 min in the future). + let ts_at_future_limit = (now - TIME_SKEW_MIN) as u32; + let h = make_valid_tls_handshake(secret, ts_at_future_limit); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_some(), + "Timestamp at max-allowed future (time_diff = TIME_SKEW_MIN) must be accepted" + ); + + // time_diff = now - ts = TIME_SKEW_MAX = 600 + // -> ts = now - TIME_SKEW_MAX = now - 600 (10 min in the past). + let ts_at_past_limit = (now - TIME_SKEW_MAX) as u32; + let h = make_valid_tls_handshake(secret, ts_at_past_limit); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_some(), + "Timestamp at max-allowed past (time_diff = TIME_SKEW_MAX) must be accepted" + ); +} + +#[test] +fn timestamp_one_second_outside_skew_window_rejected() { + let secret = b"skew_outside_test_secret"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // time_diff = TIME_SKEW_MAX + 1 = 601 (one second too far in the past) + // -> ts = now - (TIME_SKEW_MAX + 1) = now - 601 + let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; + let h = make_valid_tls_handshake(secret, ts_too_past); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "Timestamp one second too far in the past must be rejected" + ); + + // time_diff = TIME_SKEW_MIN - 1 = -1201 (one second too far in the future) + // -> ts = now - (TIME_SKEW_MIN - 1) = now + 1201 + let ts_too_future = (now - TIME_SKEW_MIN + 1) as u32; + let h = make_valid_tls_handshake(secret, ts_too_future); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "Timestamp one second too far in the future must be rejected" + ); +} + +#[test] +fn ignore_time_skew_accepts_far_future_timestamp() { + let secret = b"ignore_skew_test"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // 1 hour in the future - outside TIME_SKEW_MAX but should pass with flag. + let future_ts = (now + 3600) as u32; + let h = make_valid_tls_handshake(secret, future_ts); + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, now).is_some(), + "ignore_time_skew=true must override window rejection" + ); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "ignore_time_skew=false must still reject far-future timestamp" + ); +} + +#[test] +fn boot_time_timestamp_accepted_without_ignore_flag() { + // Timestamps below the boot-time threshold are treated as client uptime, + // not real wall-clock time. The proxy allows them regardless of skew. + let secret = b"boot_time_test"; + // 86_400_000 / 2 is well below the boot-time threshold (~2.74 years worth of seconds). + let boot_ts: u32 = 86_400_000 / 2; + let handshake = make_valid_tls_handshake(secret, boot_ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + assert!( + validate_tls_handshake(&handshake, &secrets, false).is_some(), + "Boot-time timestamp must be accepted even with ignore_time_skew=false" + ); +} + +// ------------------------------------------------------------------ +// Structural / length boundary attacks +// ------------------------------------------------------------------ + +#[test] +fn too_short_handshake_rejected_without_panic() { + let secrets = vec![("u".to_string(), b"s".to_vec())]; + // Exactly one byte short of the minimum required length. + let h = vec![0u8; TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); + + // Empty buffer. + assert!(validate_tls_handshake(&[], &secrets, true).is_none()); +} + +#[test] +fn claimed_session_id_overflows_buffer_rejected() { + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0u8; min_len]; + // Claim session_id is 33 bytes - one more than the buffer holds. + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = (session_id_len + 1) as u8; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn max_session_id_len_255_does_not_panic() { + // session_id_len = 255 with a buffer that is far too small for it. + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + 32; + let mut h = vec![0u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 255; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +// ------------------------------------------------------------------ +// Adversarial digest values +// ------------------------------------------------------------------ + +#[test] +fn all_zeros_digest_rejected() { + // An all-zeros digest would only pass if HMAC(secret, msg) happens to + // have its first 28 bytes all zero, which is computationally infeasible. + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + let secrets = vec![("u".to_string(), b"test_secret".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn all_ones_digest_rejected() { + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0xFF); + let secrets = vec![("u".to_string(), b"test_secret".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +/// Simulate a censor that sends 200 crafted packets with random digests. +/// Every single one must be rejected; no random digest should accidentally +/// pass (probability 2^{-224} per attempt; negligible for 200 trials). +#[test] +fn censor_probe_random_digests_all_rejected() { + use crate::crypto::SecureRandom; + let secret = b"production_like_secret_value_xyz"; + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let rng = SecureRandom::new(); + + for attempt in 0..200 { + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let rand_digest = rng.bytes(TLS_DIGEST_LEN); + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&rand_digest); + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Random digest at attempt {attempt} must not match" + ); + } +} + +/// The check window is bytes 0-27 of the XOR result. Bytes 28-31 encode +/// the timestamp and must NOT affect whether the HMAC portion validates - +/// only the timestamp range check uses them. Build a valid handshake with +/// timestamp = 0 (boot-time), flip each of bytes 28-31 with ignore_time_skew +/// enabled, and verify the HMAC portion still passes (the timestamp changes +/// but the proxy still accepts the connection under ignore_time_skew). +#[test] +fn timestamp_bytes_28_31_do_not_affect_hmac_window() { + let secret = b"window_boundary_test"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // Baseline must pass. + assert!(validate_tls_handshake(&base, &secrets, true).is_some()); + + // Flip each of the timestamp bytes; with ignore_time_skew the + // modified timestamps (small absolute values) still pass boot-time check. + for i in 28..32usize { + let mut h = base.clone(); + h[TLS_DIGEST_POS + i] ^= 0xFF; + // The new timestamp is non-zero but potentially still < boot threshold; + // use ignore_time_skew=true so wallet test is HMAC-only. + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "Flipping byte {i} (timestamp region) must not invalidate HMAC window" + ); + } +} + +// ------------------------------------------------------------------ +// session_id preservation +// ------------------------------------------------------------------ + +#[test] +fn session_id_is_preserved_verbatim_in_validation_result() { + // If session_id extraction is ever broken (wrong offset, wrong length, + // off-by-one), this test will catch it before it silently corrupts the + // ServerHello that echoes the session_id back to the client. + let secret = b"session_id_preservation_test"; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true).unwrap(); + + let sid_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; + let sid_len = handshake[sid_len_pos] as usize; + let expected = &handshake[sid_len_pos + 1..sid_len_pos + 1 + sid_len]; + + assert_eq!( + result.session_id, expected, + "session_id in TlsValidation must be the verbatim bytes from the handshake" + ); +} + +// ------------------------------------------------------------------ +// Clock decoupling - ignore_time_skew must not consult the system clock +// ------------------------------------------------------------------ + +/// When `ignore_time_skew = true`, a valid HMAC must be accepted even if +/// `now = 0` (the sentinel used when the clock is not needed). A broken +/// system clock cannot silently deny service when the admin has explicitly +/// disabled timestamp checking. +#[test] +fn ignore_time_skew_accepts_valid_hmac_with_now_zero() { + let secret = b"clock_decoupling_test"; + // Use a realistic Unix timestamp that would be far outside the window + // if compared against now=0 (time_diff would be ~-1_700_000_000). + let realistic_ts: u32 = 1_700_000_000; + let h = make_valid_tls_handshake(secret, realistic_ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, 0).is_some(), + "ignore_time_skew=true must accept a valid HMAC regardless of `now`" + ); + + // Confirm that the same handshake IS rejected when the window is enforced + // and now=0 (time_diff very negative -> outside window). This distinguishes + // "clock decoupling" from "always accept". + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), + "ignore_time_skew=false with now=0 must still reject out-of-window timestamps" + ); +} + +/// An HMAC-invalid handshake must be rejected even when ignore_time_skew=true +/// and now=0. Verifies that the clock-decoupling fix did not weaken HMAC +/// enforcement in the ignore_time_skew path. +#[test] +fn ignore_time_skew_with_now_zero_still_rejects_bad_hmac() { + let secret = b"clock_no_backdoor_test"; + let mut h = make_valid_tls_handshake(secret, 1_700_000_000); + let secrets = vec![("u".to_string(), secret.to_vec())]; + // Corrupt the HMAC check window. + h[TLS_DIGEST_POS] ^= 0xFF; + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, 0).is_none(), + "Broken HMAC must be rejected even with ignore_time_skew=true and now=0" + ); +} + +#[test] +fn system_time_before_unix_epoch_is_rejected_without_panic() { + let before_epoch = UNIX_EPOCH + .checked_sub(std::time::Duration::from_secs(1)) + .expect("UNIX_EPOCH minus one second must be representable"); + assert!(system_time_to_unix_secs(before_epoch).is_none()); +} + +/// `i64::MAX` is 9_223_372_036_854_775_807 seconds (~292 billion years CE). +/// Any `SystemTime` whose duration since epoch exceeds `i64::MAX` seconds +/// must return `None` rather than silently wrapping to a large negative +/// timestamp that would corrupt every subsequent time-skew comparison. +#[test] +fn system_time_far_future_overflowing_i64_returns_none() { + // i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`. + let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1; + if let Some(far_future) = + UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs)) + { + assert!( + system_time_to_unix_secs(far_future).is_none(), + "Seconds > i64::MAX must return None, not a wrapped negative timestamp" + ); + } + // If the platform cannot represent this SystemTime, the test is vacuously + // satisfied: `checked_add` returning None means the platform already rejects it. +} + +// ------------------------------------------------------------------ +// Message canonicalization — HMAC covers every byte of the handshake +// ------------------------------------------------------------------ + +/// Every byte before TLS_DIGEST_POS is part of the HMAC input (because msg +/// = full handshake with only the digest slot zeroed). An attacker cannot +/// replay a valid handshake with a modified ClientHello header while keeping +/// the stored digest; each such modification produces a different HMAC. +#[test] +fn pre_digest_bytes_are_hmac_covered() { + // TLS_DIGEST_POS = 11, so 11 bytes precede the digest. + let secret = b"pre_digest_coverage_test"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + for byte_pos in 0..TLS_DIGEST_POS { + let mut h = base.clone(); + h[byte_pos] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip in pre-digest byte {byte_pos} must cause HMAC check failure" + ); + } +} + +/// session_id bytes follow the digest in the buffer and are also part of the +/// HMAC input. Flipping any of them invalidates the stored digest, preventing +/// a censor from capturing a valid session_id and replaying it with a different +/// one while keeping the rest of the packet intact. +#[test] +fn session_id_bytes_are_hmac_covered() { + let secret = b"session_id_coverage_test"; + let base = make_valid_tls_handshake(secret, 0); // session_id_len = 32 + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + for byte_pos in sid_start..base.len() { + let mut h = base.clone(); + h[byte_pos] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip in session_id byte at offset {byte_pos} must cause HMAC check failure" + ); + } +} + +/// Appending even one byte to a valid handshake changes the HMAC input (msg +/// includes all bytes) and therefore invalidates the stored digest. This +/// prevents a length-extension-style modification of the payload. +#[test] +fn appended_trailing_byte_causes_rejection() { + let secret = b"trailing_byte_test"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!(validate_tls_handshake(&h, &secrets, true).is_some(), "baseline"); + + h.push(0x00); + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Appending a trailing byte to a valid handshake must invalidate the HMAC" + ); +} + +// ------------------------------------------------------------------ +// Zero-length session_id (structural edge case) +// ------------------------------------------------------------------ + +/// session_id_len = 0 is legal in the TLS spec. The validator must accept a +/// valid handshake with an empty session_id and return an empty session_id +/// slice without panicking or accessing out-of-bounds memory. +#[test] +fn zero_length_session_id_accepted() { + let secret = b"zero_sid_test"; + // Buffer: pre-digest | digest | session_id_len=0 (no session_id bytes follow) + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + let mut handshake = vec![0x42u8; len]; + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 0; // session_id_len = 0 + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + // timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged. + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&computed); + + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "zero-length session_id must be accepted"); + assert!( + result.unwrap().session_id.is_empty(), + "session_id field must be empty when session_id_len = 0" + ); +} + +// ------------------------------------------------------------------ +// Boot-time threshold — exact boundary precision +// ------------------------------------------------------------------ + +/// timestamp = 86_399_999 is the last value inside the boot-time window. +/// is_boot_time = true → skew check is skipped entirely → accepted even +/// when `now` is far from the timestamp. +#[test] +fn timestamp_one_below_boot_threshold_bypasses_skew_check() { + let secret = b"boot_last_value_test"; + let ts: u32 = 86_400_000 - 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // now = 0 → time_diff would be -86_399_999, way outside [-1200, 600]. + // Boot-time bypass must prevent the skew check from running. + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(), + "ts=86_399_999 must bypass skew check regardless of now" + ); +} + +/// timestamp = 86_400_000 is the first value outside the boot-time window. +/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this: +/// once with now chosen so the skew passes (accepted) and once where it fails. +#[test] +fn timestamp_at_boot_threshold_triggers_skew_check() { + let secret = b"boot_exact_value_test"; + let ts: u32 = 86_400_000; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // now = ts + 50 → time_diff = 50, within [-1200, 600] → accepted. + let now_valid: i64 = ts as i64 + 50; + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(), + "ts=86_400_000 within skew window must be accepted via skew check" + ); + + // now = 0 → time_diff = -86_400_000, outside window → rejected. + // If the boot-time bypass were wrongly applied here this would pass. + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), + "ts=86_400_000 far from now must be rejected — no boot-time bypass" + ); +} + +// ------------------------------------------------------------------ +// Extreme timestamp values +// ------------------------------------------------------------------ + +/// u32::MAX is a valid timestamp value. When ignore_time_skew=true the HMAC +/// is the only gate, and a correctly constructed handshake must be accepted. +#[test] +fn u32_max_timestamp_accepted_with_ignore_time_skew() { + let secret = b"u32_max_ts_accept_test"; + let h = make_valid_tls_handshake(secret, u32::MAX); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some(), "u32::MAX timestamp must be accepted with ignore_time_skew=true"); + assert_eq!( + result.unwrap().timestamp, + u32::MAX, + "timestamp field must equal u32::MAX verbatim" + ); +} + +/// u32::MAX > 86_400_000 so the skew check runs. With any realistic `now` +/// (~1.7 billion), time_diff = now - u32::MAX is deeply negative — far outside +/// [-1200, 600] — so the handshake must be rejected without overflow. +#[test] +fn u32_max_timestamp_rejected_by_skew_enforcement() { + let secret = b"u32_max_ts_reject_test"; + let h = make_valid_tls_handshake(secret, u32::MAX); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let now: i64 = 1_700_000_000; + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "u32::MAX timestamp must be rejected by skew check with realistic now" + ); +} + +// ------------------------------------------------------------------ +// Validation result field correctness +// ------------------------------------------------------------------ + +/// result.digest must be the verbatim bytes stored in the handshake buffer, +/// not the freshly recomputed HMAC. Callers use this field directly when +/// constructing the ServerHello response digest. +#[test] +fn result_digest_field_is_verbatim_stored_digest() { + let secret = b"digest_field_verbatim_test"; + let ts: u32 = 0xCAFE_BABE; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&h, &secrets, true).unwrap(); + + let stored: [u8; TLS_DIGEST_LEN] = h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .try_into() + .unwrap(); + assert_eq!( + result.digest, stored, + "result.digest must equal the stored bytes, not the computed HMAC" + ); +} + +// ------------------------------------------------------------------ +// Secret length edge cases +// ------------------------------------------------------------------ + +/// HMAC-SHA256 pads or hashes keys of any length; a single-byte key must work. +#[test] +fn single_byte_secret_works() { + let secret = b"x"; + let h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "single-byte secret must produce a valid and verifiable HMAC" + ); +} + +/// Keys longer than the HMAC block size (64 bytes for SHA-256) are hashed +/// before use. A 256-byte key must work without truncation or panic. +#[test] +fn very_long_secret_256_bytes_works() { + let secret = vec![0xABu8; 256]; + let h = make_valid_tls_handshake(&secret, 0); + let secrets = vec![("u".to_string(), secret.clone())]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "256-byte secret must be accepted without truncation" + ); +} + +// ------------------------------------------------------------------ +// Determinism — same input must always produce same result +// ------------------------------------------------------------------ + +/// Calling validate twice on the same input must return identical results. +/// Non-determinism (e.g. from an accidentally global mutable state or a +/// shared nonce) would be a critical security defect in a proxy that rejects +/// censors by relying on stable authentication outcomes. +#[test] +fn validation_is_deterministic() { + let secret = b"determinism_test_key"; + let h = make_valid_tls_handshake(secret, 42); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let r1 = validate_tls_handshake(&h, &secrets, true).unwrap(); + let r2 = validate_tls_handshake(&h, &secrets, true).unwrap(); + + assert_eq!(r1.user, r2.user); + assert_eq!(r1.session_id, r2.session_id); + assert_eq!(r1.digest, r2.digest); + assert_eq!(r1.timestamp, r2.timestamp); +} + +// ------------------------------------------------------------------ +// Multi-user: scan-all correctness guarantees +// ------------------------------------------------------------------ + +/// The matching logic must scan through the entire secrets list. A user +/// at position 99 of 100 must be found; an implementation that stops early +/// on the first non-match would fail this test. +#[test] +fn last_user_in_large_list_is_found() { + let target_secret = b"needle_in_haystack"; + let h = make_valid_tls_handshake(target_secret, 0); + + let mut secrets: Vec<(String, Vec)> = (0..99) + .map(|i| (format!("decoy_{i}"), format!("wrong_{i}").into_bytes())) + .collect(); + secrets.push(("needle".to_string(), target_secret.to_vec())); + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some(), "100th user must be found"); + assert_eq!(result.unwrap().user, "needle"); +} + +/// When multiple users share the same secret the first occurrence must always +/// win. The scan-all loop must not replace first_match with a later one. +#[test] +fn first_matching_user_wins_over_later_duplicate_secret() { + let shared = b"duplicated_secret_key"; + let h = make_valid_tls_handshake(shared, 0); + + let secrets = vec![ + ("decoy_1".to_string(), b"wrong_1".to_vec()), + ("winner".to_string(), shared.to_vec()), // first match + ("decoy_2".to_string(), b"wrong_2".to_vec()), + ("loser".to_string(), shared.to_vec()), // second match — must not win + ("decoy_3".to_string(), b"wrong_3".to_vec()), + ]; + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some()); + assert_eq!( + result.unwrap().user, "winner", + "first matching user must be returned even when a later entry also matches" + ); +} + +// ------------------------------------------------------------------ +// Legacy tls.rs tests moved here +// ------------------------------------------------------------------ + +#[test] +fn test_is_tls_handshake() { + assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); + assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); + assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); + assert!(!is_tls_handshake(&[0x16, 0x03])); +} + +#[test] +fn test_parse_tls_record_header() { + let header = [0x16, 0x03, 0x01, 0x02, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_HANDSHAKE); + assert_eq!(result.1, 512); + + let header = [0x17, 0x03, 0x03, 0x40, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_APPLICATION); + assert_eq!(result.1, 16384); +} + +#[test] +fn test_gen_fake_x25519_key() { + let rng = crate::crypto::SecureRandom::new(); + let key1 = gen_fake_x25519_key(&rng); + let key2 = gen_fake_x25519_key(&rng); + + assert_eq!(key1.len(), 32); + assert_eq!(key2.len(), 32); + assert_ne!(key1, key2); +} + +#[test] +fn test_fake_x25519_key_is_quadratic_residue() { + use num_bigint::BigUint; + use num_traits::One; + + let rng = crate::crypto::SecureRandom::new(); + let key = gen_fake_x25519_key(&rng); + let p = curve25519_prime(); + let k_num = BigUint::from_bytes_le(&key); + let exponent = (&p - BigUint::one()) >> 1; + let legendre = k_num.modpow(&exponent, &p); + assert_eq!(legendre, BigUint::one()); +} + +#[test] +fn test_tls_extension_builder() { + let key = [0x42u8; 32]; + + let mut builder = TlsExtensionBuilder::new(); + builder.add_key_share(&key); + builder.add_supported_versions(0x0304); + + let result = builder.build(); + let len = u16::from_be_bytes([result[0], result[1]]) as usize; + + assert_eq!(len, result.len() - 2); + assert!(result.len() > 40); +} + +#[test] +fn test_server_hello_builder() { + let session_id = vec![0x01, 0x02, 0x03, 0x04]; + let key = [0x55u8; 32]; + + let builder = ServerHelloBuilder::new(session_id.clone()) + .with_x25519_key(&key) + .with_tls13_version(); + + let record = builder.build_record(); + validate_server_hello_structure(&record).expect("Invalid ServerHello structure"); + + assert_eq!(record[0], TLS_RECORD_HANDSHAKE); + assert_eq!(&record[1..3], &TLS_VERSION); + assert_eq!(record[5], 0x02); +} + +#[test] +fn test_build_server_hello_structure() { + let secret = b"test secret"; + let client_digest = [0x42u8; 32]; + let session_id = vec![0xAA; 32]; + + let rng = crate::crypto::SecureRandom::new(); + let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0); + + assert!(response.len() > 100); + assert_eq!(response[0], TLS_RECORD_HANDSHAKE); + validate_server_hello_structure(&response).expect("Invalid ServerHello"); + + let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_start = server_hello_len; + assert!(response.len() > ccs_start + 6); + assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); + + let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; + let app_start = ccs_start + ccs_len; + assert!(response.len() > app_start + 5); + assert_eq!(response[app_start], TLS_RECORD_APPLICATION); +} + +#[test] +fn test_build_server_hello_digest() { + let secret = b"test secret key here"; + let client_digest = [0x42u8; 32]; + let session_id = vec![0xAA; 32]; + + let rng = crate::crypto::SecureRandom::new(); + let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + + let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert!(!digest1.iter().all(|&b| b == 0)); + + let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert_ne!(digest1, digest2); +} + +#[test] +fn test_server_hello_extensions_length() { + let session_id = vec![0x01; 32]; + let key = [0x55u8; 32]; + + let builder = ServerHelloBuilder::new(session_id) + .with_x25519_key(&key) + .with_tls13_version(); + + let record = builder.build_record(); + let msg_start = 5; + let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; + let session_id_pos = msg_start + 4 + 2 + 32; + let session_id_len = record[session_id_pos] as usize; + let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; + let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize; + let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len]; + + assert_eq!( + ext_len, + extensions_data.len(), + "Extension length mismatch: declared {}, actual {}", + ext_len, + extensions_data.len() + ); +} + +#[test] +fn test_validate_tls_handshake_format() { + let mut handshake = vec![0u8; 100]; + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&[0x42; 32]); + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; + + let secrets = vec![("test".to_string(), b"secret".to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_none()); +} + +fn build_client_hello_with_exts(exts: Vec<(u16, Vec)>, host: &str) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let host_bytes = host.as_bytes(); + let mut sni_ext = Vec::new(); + sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes()); + sni_ext.push(0); + sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_ext.extend_from_slice(host_bytes); + + let mut ext_blob = Vec::new(); + for (typ, data) in exts { + ext_blob.extend_from_slice(&typ.to_be_bytes()); + ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&data); + } + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_ext); + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let len_bytes = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&len_bytes[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +fn build_client_hello_with_raw_extensions(ext_blob: &[u8]) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let len_bytes = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&len_bytes[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +#[test] +fn test_extract_sni_with_grease_extension() { + let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com"); + let sni = extract_sni_from_client_hello(&ch); + assert_eq!(sni.as_deref(), Some("example.com")); +} + +#[test] +fn test_extract_sni_tolerates_empty_unknown_extension() { + let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local"); + let sni = extract_sni_from_client_hello(&ch); + assert_eq!(sni.as_deref(), Some("test.local")); +} + +#[test] +fn test_extract_alpn_single() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&3u16.to_be_bytes()); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h2"); + let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); + let alpn = extract_alpn_from_client_hello(&ch); + let alpn_str: Vec = alpn + .iter() + .map(|p| std::str::from_utf8(p).unwrap().to_string()) + .collect(); + assert_eq!(alpn_str, vec!["h2"]); +} + +#[test] +fn test_extract_alpn_multiple() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&11u16.to_be_bytes()); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h2"); + alpn_data.push(4); + alpn_data.extend_from_slice(b"spdy"); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h3"); + let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); + let alpn = extract_alpn_from_client_hello(&ch); + let alpn_str: Vec = alpn + .iter() + .map(|p| std::str::from_utf8(p).unwrap().to_string()) + .collect(); + assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]); +} + +#[test] +fn extract_sni_rejects_zero_length_host_name() { + let mut sni_ext = Vec::new(); + sni_ext.extend_from_slice(&3u16.to_be_bytes()); + sni_ext.push(0); + sni_ext.extend_from_slice(&0u16.to_be_bytes()); + + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_ext); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_when_extension_block_is_truncated() { + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&5u16.to_be_bytes()); + ext_blob.extend_from_slice(&[0, 3, 0]); + + let mut ch = build_client_hello_with_raw_extensions(&ext_blob); + ch.pop(); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_alpn_rejects_when_extension_block_is_truncated() { + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&5u16.to_be_bytes()); + ext_blob.extend_from_slice(&[0, 3, 2, b'h']); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_alpn_from_client_hello(&ch).is_empty()); +} + +#[test] +fn extract_alpn_rejects_nested_length_overflow() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&10u16.to_be_bytes()); + alpn_data.push(8); + alpn_data.extend_from_slice(b"h2"); + + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_alpn_from_client_hello(&ch).is_empty()); +} + +// ------------------------------------------------------------------ +// Additional adversarial checks +// ------------------------------------------------------------------ + +#[test] +fn empty_secret_hmac_is_supported() { + let secret: &[u8] = b""; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("empty".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "Empty HMAC key must not panic and must validate when correct"); +} + +#[test] +fn server_hello_digest_verifies_against_full_response() { + let secret = b"fronting_digest_verify_key"; + let client_digest = [0x42u8; TLS_DIGEST_LEN]; + let session_id = vec![0xAA; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 1); + let mut zeroed = response.clone(); + zeroed[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + zeroed.len()); + hmac_input.extend_from_slice(&client_digest); + hmac_input.extend_from_slice(&zeroed); + let expected = sha256_hmac(secret, &hmac_input); + + assert_eq!( + &response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN], + &expected, + "ServerHello digest must be verifiable by a client that recomputes HMAC over full response" + ); +} + +#[test] +fn server_hello_digest_fails_after_single_byte_tamper() { + let secret = b"fronting_tamper_detect_key"; + let client_digest = [0x24u8; TLS_DIGEST_LEN]; + let session_id = vec![0xBB; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let mut response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + response[TLS_DIGEST_POS + TLS_DIGEST_LEN + 1] ^= 0x01; + + let mut zeroed = response.clone(); + zeroed[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + zeroed.len()); + hmac_input.extend_from_slice(&client_digest); + hmac_input.extend_from_slice(&zeroed); + let expected = sha256_hmac(secret, &hmac_input); + + assert_ne!( + &response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN], + &expected, + "Tampering any response byte must invalidate the embedded digest" + ); +} + +#[test] +fn server_hello_application_data_payload_varies_across_runs() { + use std::collections::HashSet; + + let secret = b"fronting_payload_variability_key"; + let client_digest = [0x13u8; TLS_DIGEST_LEN]; + let session_id = vec![0x44; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let mut unique_payloads: HashSet> = HashSet::new(); + for _ in 0..16 { + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + + assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec(); + + assert!(payload.iter().any(|&b| b != 0), "Payload must not be all-zero deterministic filler"); + unique_payloads.insert(payload); + } + + assert!( + unique_payloads.len() >= 4, + "ApplicationData payload should vary across runs to reduce fingerprintability" + ); +} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 99e6837..0ef2cc6 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -23,7 +23,7 @@ enum HandshakeOutcome { use crate::config::ProxyConfig; use crate::crypto::SecureRandom; -use crate::error::{HandshakeResult, ProxyError, Result}; +use crate::error::{HandshakeResult, ProxyError, Result, StreamError}; use crate::ip_tracker::UserIpTracker; use crate::protocol::constants::*; use crate::protocol::tls; @@ -63,10 +63,12 @@ fn record_handshake_failure_class( peer_ip: IpAddr, error: &ProxyError, ) { - let class = if error.to_string().contains("expected 64 bytes, got 0") { - "expected_64_got_0" - } else { - "other" + let class = match error { + ProxyError::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + "expected_64_got_0" + } + ProxyError::Stream(StreamError::UnexpectedEof) => "expected_64_got_0", + _ => "other", }; record_beobachten_class(beobachten, config, peer_ip, class); } @@ -204,9 +206,19 @@ where &config, &replay_checker, true, Some(tls_user.as_str()), ).await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader: _, writer: _ } => { + HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + handle_bad_client( + reader, + writer, + &mtproto_handshake, + real_peer, + local_addr, + &config, + &beobachten, + ) + .await; return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), @@ -590,12 +602,19 @@ impl RunningClientHandler { .await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { - reader: _, - writer: _, - } => { + HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + handle_bad_client( + reader, + writer, + &mtproto_handshake, + peer, + local_addr, + &config, + &self.beobachten, + ) + .await; return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), @@ -806,8 +825,24 @@ impl RunningClientHandler { }); } - let ip_reserved = match ip_tracker.check_and_add(user, peer_addr.ip()).await { - Ok(()) => true, + if let Some(limit) = config.access.user_max_tcp_conns.get(user) + && stats.get_user_curr_connects(user) >= *limit as u64 + { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => {} Err(reason) => { warn!( user = %user, @@ -819,33 +854,12 @@ impl RunningClientHandler { user: user.to_string(), }); } - }; - // IP limit check - - if let Some(limit) = config.access.user_max_tcp_conns.get(user) - && stats.get_user_curr_connects(user) >= *limit as u64 - { - if ip_reserved { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.increment_ip_reservation_rollback_tcp_limit_total(); - } - return Err(ProxyError::ConnectionLimitExceeded { - user: user.to_string(), - }); - } - - if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota - { - if ip_reserved { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.increment_ip_reservation_rollback_quota_limit_total(); - } - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); } Ok(()) } } + +#[cfg(test)] +#[path = "client_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs new file mode 100644 index 0000000..70930ea --- /dev/null +++ b/src/proxy/client_security_tests.rs @@ -0,0 +1,631 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::tls; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +#[tokio::test] +async fn short_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10]; + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec { + let tls_len: usize = 600; + let total_len = 5 + tls_len; + let mut handshake = vec![0x42u8; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&[0x03, 0x03]); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +#[tokio::test] +async fn valid_tls_path_does_not_fall_back_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x11u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "11111111111111111111111111111111".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.80:55002".parse().unwrap(); + let stats_for_assert = stats.clone(); + let bad_before = stats_for_assert.get_connects_bad(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut record_header = [0u8; 5]; + client_side.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16); + + drop(client_side); + let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(handler_result.is_err()); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Mask backend must not be contacted on authenticated TLS path" + ); + + let bad_after = stats_for_assert.get_connects_bad(); + assert_eq!( + bad_before, + bad_after, + "Authenticated TLS path must not increment connects_bad" + ); +} + +#[tokio::test] +async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x33u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; + let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; invalid_mtproto.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, invalid_mtproto); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "33333333333333333333333333333333".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(32768); + let peer: SocketAddr = "198.51.100.90:55111".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&tls_app_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let secret = [0x44u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; + let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + + let mask_accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = vec![0u8; invalid_mtproto.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, invalid_mtproto); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "44444444444444444444444444444444".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client.write_all(&tls_app_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn unexpected_eof_is_classified_without_string_matching() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof = ProxyError::Io(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)); + let peer_ip: IpAddr = "198.51.100.200".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!( + snapshot.contains("[expected_64_got_0]"), + "UnexpectedEof must be classified as expected_64_got_0" + ); + assert!( + snapshot.contains("198.51.100.200-1"), + "Classified record must include source IP" + ); +} + +#[test] +fn non_eof_error_is_classified_as_other() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let non_eof = ProxyError::Io(std::io::Error::other("different error")); + let peer_ip: IpAddr = "203.0.113.201".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &non_eof); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!( + snapshot.contains("[other]"), + "Non-EOF errors must map to other" + ); + assert!( + snapshot.contains("203.0.113.201-1"), + "Classified record must include source IP" + ); + assert!( + !snapshot.contains("[expected_64_got_0]"), + "Non-EOF errors must not be misclassified as expected_64_got_0" + ); +} + +#[tokio::test] +async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let stats = Stats::new(); + stats.increment_user_curr_connects("user"); + + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.210:50000".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Rejected client must not reserve IP slot" + ); + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 0, + "No rollback should occur when reservation is not taken" + ); +} + +#[tokio::test] +async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert("user".to_string(), 1024); + + let stats = Stats::new(); + stats.add_user_octets_from("user", 1024); + + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::DataQuotaExceeded { user }) if user == "user" + )); + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Quota-rejected client must not reserve IP slot" + ); + assert_eq!( + stats.get_ip_reservation_rollback_quota_limit_total(), + 0, + "No rollback should occur when reservation is not taken" + ); +} + +#[tokio::test] +async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() { + const PARALLEL_IPS: usize = 64; + const ATTEMPTS_PER_IP: usize = 8; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + stats.increment_user_curr_connects("user"); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..PARALLEL_IPS { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + + tasks.spawn(async move { + let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 100, (i + 1) as u8)); + for _ in 0..ATTEMPTS_PER_IP { + let peer_addr = SocketAddr::new(ip, 40000 + i as u16); + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + } + }); + } + + while let Some(joined) = tasks.join_next().await { + joined.unwrap(); + } + + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Concurrent rejected attempts must not leave active IP reservations" + ); + + let recent = ip_tracker + .get_recent_ips_for_users(&["user".to_string()]) + .await; + assert!( + recent + .get("user") + .map(|ips| ips.is_empty()) + .unwrap_or(true), + "Concurrent rejected attempts must not leave recent IP footprint" + ); + + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 0, + "No rollback should occur under concurrent rejection storms" + ); +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 296432f..4e7b371 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -19,6 +19,8 @@ use crate::stats::ReplayChecker; use crate::config::ProxyConfig; use crate::tls_front::{TlsFrontCache, emulator}; +const ACCESS_SECRET_BYTES: usize = 16; + fn decode_user_secrets( config: &ProxyConfig, preferred_user: Option<&str>, @@ -28,6 +30,7 @@ fn decode_user_secrets( if let Some(preferred) = preferred_user && let Some(secret_hex) = config.access.users.get(preferred) && let Ok(bytes) = hex::decode(secret_hex) + && bytes.len() == ACCESS_SECRET_BYTES { secrets.push((preferred.to_string(), bytes)); } @@ -36,7 +39,9 @@ fn decode_user_secrets( if preferred_user.is_some_and(|preferred| preferred == name.as_str()) { continue; } - if let Ok(bytes) = hex::decode(secret_hex) { + if let Ok(bytes) = hex::decode(secret_hex) + && bytes.len() == ACCESS_SECRET_BYTES + { secrets.push((name.clone(), bytes)); } } @@ -48,7 +53,7 @@ fn decode_user_secrets( /// /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is /// zeroized on drop. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct HandshakeSuccess { /// Authenticated user name pub user: String, @@ -99,14 +104,6 @@ where return HandshakeResult::BadClient { reader, writer }; } - let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; - let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; - - if replay_checker.check_and_add_tls_digest(digest_half) { - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - let secrets = decode_user_secrets(config, None); let validation = match tls::validate_tls_handshake( @@ -125,6 +122,14 @@ where } }; + // Replay tracking is applied only after successful authentication to avoid + // letting unauthenticated probes evict valid entries from the replay cache. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_and_add_tls_digest(digest_half) { + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => return HandshakeResult::BadClient { reader, writer }, @@ -254,11 +259,6 @@ where let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; - if replay_checker.check_and_add_handshake(dec_prekey_iv) { - warn!(peer = %peer, "MTProto replay attack detected"); - return HandshakeResult::BadClient { reader, writer }; - } - let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); let decoded_users = decode_user_secrets(config, preferred_user); @@ -273,14 +273,19 @@ where dec_key_input.extend_from_slice(&secret); let dec_key = sha256(&dec_key_input); - let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut decryptor = AesCtr::new(&dec_key, dec_iv); let decrypted = decryptor.decrypt(handshake); - let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] - .try_into() - .unwrap(); + let tag_bytes: [u8; 4] = [ + decrypted[PROTO_TAG_POS], + decrypted[PROTO_TAG_POS + 1], + decrypted[PROTO_TAG_POS + 2], + decrypted[PROTO_TAG_POS + 3], + ]; let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, @@ -303,9 +308,7 @@ where continue; } - let dc_idx = i16::from_le_bytes( - decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() - ); + let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]); let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; @@ -315,10 +318,19 @@ where enc_key_input.extend_from_slice(&secret); let enc_key = sha256(&enc_key_input); - let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(enc_iv_bytes); + let enc_iv = u128::from_be_bytes(enc_iv_arr); let encryptor = AesCtr::new(&enc_key, enc_iv); + // Apply replay tracking only after successful authentication to prevent + // unauthenticated probes from evicting legitimate replay-cache entries. + if replay_checker.check_and_add_handshake(dec_prekey_iv) { + warn!(peer = %peer, user = %user, "MTProto replay attack detected"); + return HandshakeResult::BadClient { reader, writer }; + } + let success = HandshakeSuccess { user: user.clone(), dc_idx, @@ -365,14 +377,16 @@ pub fn generate_tg_nonce( ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { loop { let bytes = rng.bytes(HANDSHAKE_LEN); - let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); + let Ok(mut nonce): Result<[u8; HANDSHAKE_LEN], _> = bytes.try_into() else { + continue; + }; if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } - let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; } - let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); + let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]]; if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; } nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); @@ -390,11 +404,17 @@ pub fn generate_tg_nonce( let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); - let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + let mut tg_enc_key = [0u8; 32]; + tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut tg_enc_iv_arr = [0u8; IV_LEN]; + tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let tg_enc_iv = u128::from_be_bytes(tg_enc_iv_arr); - let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); - let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + let mut tg_dec_key = [0u8; 32]; + tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut tg_dec_iv_arr = [0u8; IV_LEN]; + tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let tg_dec_iv = u128::from_be_bytes(tg_dec_iv_arr); return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); } @@ -405,11 +425,17 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, A let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); - let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + let mut enc_key = [0u8; 32]; + enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let enc_iv = u128::from_be_bytes(enc_iv_arr); - let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); - let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + let mut dec_key = [0u8; 32]; + dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut encryptor = AesCtr::new(&enc_key, enc_iv); let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 @@ -429,80 +455,15 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { } #[cfg(test)] -mod tests { - use super::*; +#[path = "handshake_security_tests.rs"] +mod security_tests; - #[test] - fn test_generate_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; - let client_enc_key = [0x24u8; 32]; - let client_enc_iv = 54321u128; +/// Compile-time guard: HandshakeSuccess holds cryptographic key material and +/// must never be Copy. A Copy impl would allow silent key duplication, +/// undermining the zeroize-on-drop guarantee. +mod compile_time_security_checks { + use super::HandshakeSuccess; + use static_assertions::assert_not_impl_all; - let rng = SecureRandom::new(); - let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = - generate_tg_nonce( - ProtoTag::Secure, - 2, - &client_dec_key, - client_dec_iv, - &client_enc_key, - client_enc_iv, - &rng, - false, - ); - - assert_eq!(nonce.len(), HANDSHAKE_LEN); - - let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); - assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); - } - - #[test] - fn test_encrypt_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; - let client_enc_key = [0x24u8; 32]; - let client_enc_iv = 54321u128; - - let rng = SecureRandom::new(); - let (nonce, _, _, _, _) = - generate_tg_nonce( - ProtoTag::Secure, - 2, - &client_dec_key, - client_dec_iv, - &client_enc_key, - client_enc_iv, - &rng, - false, - ); - - let encrypted = encrypt_tg_nonce(&nonce); - - assert_eq!(encrypted.len(), HANDSHAKE_LEN); - assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); - assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); - } - - #[test] - fn test_handshake_success_zeroize_on_drop() { - let success = HandshakeSuccess { - user: "test".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Secure, - dec_key: [0xAA; 32], - dec_iv: 0xBBBBBBBB, - enc_key: [0xCC; 32], - enc_iv: 0xDDDDDDDD, - peer: "127.0.0.1:1234".parse().unwrap(), - is_tls: true, - }; - - assert_eq!(success.dec_key, [0xAA; 32]); - assert_eq!(success.enc_key, [0xCC; 32]); - - drop(success); - // Drop impl zeroizes key material without panic - } + assert_not_impl_all!(HandshakeSuccess: Copy, Clone); } diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs new file mode 100644 index 0000000..58178d9 --- /dev/null +++ b/src/proxy/handshake_security_tests.rs @@ -0,0 +1,276 @@ +use super::*; +use crate::crypto::sha256_hmac; +use std::time::Duration; + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg +} + +#[test] +fn test_generate_tg_nonce() { + let client_dec_key = [0x42u8; 32]; + let client_dec_iv = 12345u128; + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_dec_key, + client_dec_iv, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert_eq!(nonce.len(), HANDSHAKE_LEN); + + let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); + assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); +} + +#[test] +fn test_encrypt_tg_nonce() { + let client_dec_key = [0x42u8; 32]; + let client_dec_iv = 12345u128; + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_dec_key, + client_dec_iv, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let encrypted = encrypt_tg_nonce(&nonce); + + assert_eq!(encrypted.len(), HANDSHAKE_LEN); + assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); + assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); +} + +#[test] +fn test_handshake_success_zeroize_on_drop() { + let success = HandshakeSuccess { + user: "test".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Secure, + dec_key: [0xAA; 32], + dec_iv: 0xBBBBBBBB, + enc_key: [0xCC; 32], + enc_iv: 0xDDDDDDDD, + peer: "127.0.0.1:1234".parse().unwrap(), + is_tls: true, + }; + + assert_eq!(success.dec_key, [0xAA; 32]); + assert_eq!(success.enc_key, [0xCC; 32]); + + drop(success); +} + +#[tokio::test] +async fn tls_replay_second_identical_handshake_is_rejected() { + let secret = [0x11u8; 16]; + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44321".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 0); + + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let second = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(second, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn invalid_tls_probe_does_not_pollute_replay_cache() { + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44322".parse().unwrap(); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + let before = replay_checker.stats(); + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let after = replay_checker.stats(); + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(before.total_additions, after.total_additions); + assert_eq!(before.total_hits, after.total_hits); +} + +#[tokio::test] +async fn empty_decoded_secret_is_rejected() { + let config = test_config_with_secret_hex(""); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44323".parse().unwrap(); + let handshake = make_valid_tls_handshake(&[], 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn wrong_length_decoded_secret_is_rejected() { + let config = test_config_with_secret_hex("aa"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44324".parse().unwrap(); + let handshake = make_valid_tls_handshake(&[0xaau8], 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn invalid_mtproto_probe_does_not_pollute_replay_cache() { + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "127.0.0.1:44325".parse().unwrap(); + let handshake = [0u8; HANDSHAKE_LEN]; + + let before = replay_checker.stats(); + let result = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + let after = replay_checker.stats(); + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(before.total_additions, after.total_additions); + assert_eq!(before.total_hits, after.total_hits); +} + +#[tokio::test] +async fn mixed_secret_lengths_keep_valid_user_authenticating() { + let good_secret = [0x22u8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config + .access + .users + .insert("broken_user".to_string(), "aa".to_string()); + config + .access + .users + .insert("valid_user".to_string(), "22222222222222222222222222222222".to_string()); + config.access.ignore_time_skew = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44326".parse().unwrap(); + let handshake = make_valid_tls_handshake(&good_secret, 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 318071b..fd0b404 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -194,55 +194,48 @@ async fn relay_to_mask( initial_data: &[u8], ) where - R: AsyncRead + Unpin + Send + 'static, - W: AsyncWrite + Unpin + Send + 'static, - MR: AsyncRead + Unpin + Send + 'static, - MW: AsyncWrite + Unpin + Send + 'static, + R: AsyncRead + Unpin + Send, + W: AsyncWrite + Unpin + Send, + MR: AsyncRead + Unpin + Send, + MW: AsyncWrite + Unpin + Send, { // Send initial data to mask host if mask_write.write_all(initial_data).await.is_err() { return; } - // Relay traffic - let c2m = tokio::spawn(async move { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - loop { - match reader.read(&mut buf).await { - Ok(0) | Err(_) => { - let _ = mask_write.shutdown().await; - break; - } - Ok(n) => { - if mask_write.write_all(&buf[..n]).await.is_err() { + let mut client_buf = vec![0u8; MASK_BUFFER_SIZE]; + let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE]; + + loop { + tokio::select! { + client_read = reader.read(&mut client_buf) => { + match client_read { + Ok(0) | Err(_) => { + let _ = mask_write.shutdown().await; break; } + Ok(n) => { + if mask_write.write_all(&client_buf[..n]).await.is_err() { + break; + } + } + } + } + mask_read_res = mask_read.read(&mut mask_buf) => { + match mask_read_res { + Ok(0) | Err(_) => { + let _ = writer.shutdown().await; + break; + } + Ok(n) => { + if writer.write_all(&mask_buf[..n]).await.is_err() { + break; + } + } } } } - }); - - let m2c = tokio::spawn(async move { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - loop { - match mask_read.read(&mut buf).await { - Ok(0) | Err(_) => { - let _ = writer.shutdown().await; - break; - } - Ok(n) => { - if writer.write_all(&buf[..n]).await.is_err() { - break; - } - } - } - } - }); - - // Wait for either to complete - tokio::select! { - _ = c2m => {} - _ = m2c => {} } } @@ -255,3 +248,7 @@ async fn consume_client_data(mut reader: R) { } } } + +#[cfg(test)] +#[path = "masking_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs new file mode 100644 index 0000000..50ea8ed --- /dev/null +++ b/src/proxy/masking_security_tests.rs @@ -0,0 +1,257 @@ +use super::*; +use crate::config::ProxyConfig; +use tokio::io::{duplex, AsyncBufReadExt, BufReader}; +use tokio::net::TcpListener; +use tokio::time::{timeout, Duration}; + +#[tokio::test] +async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn tls_scanner_probe_keeps_http_like_fallback_surface() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.44:55221".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + assert!(observed.starts_with(b"HTTP/")); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn backend_unavailable_falls_back_to_silent_consume() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.11:42425".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_reader_side.write_all(b"noise").await.unwrap(); + drop(client_reader_side); + + timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn mask_disabled_consumes_client_data_without_response() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "198.51.100.12:45454".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let initial = b"scanner"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_reader_side.write_all(b"untrusted payload").await.unwrap(); + drop(client_reader_side); + + timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn proxy_protocol_v1_header_is_sent_before_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line.clone()).unwrap(); + assert!(header_text.starts_with("PROXY TCP4 ")); + assert!(header_text.ends_with("\r\n")); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.15:50001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +}