diff --git a/.cargo/deny.toml b/.cargo/deny.toml new file mode 100644 index 0000000..09a5dd9 --- /dev/null +++ b/.cargo/deny.toml @@ -0,0 +1,15 @@ +[bans] +multiple-versions = "deny" +wildcards = "allow" +highlight = "all" + +# Explicitly flag the weak cryptography so the agent is forced to justify its existence +[[bans.skip]] +name = "md-5" +version = "*" +reason = "MUST VERIFY: Only allowed for legacy checksums, never for security." + +[[bans.skip]] +name = "sha1" +version = "*" +reason = "MUST VERIFY: Only allowed for backwards compatibility." diff --git a/AGENTS.md b/AGENTS.md index e6c5f2e..e7f94a5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,6 +5,22 @@ Your responses are precise, minimal, and architecturally sound. You are working --- +### Context: The Telemt Project + +You are working on **Telemt**, a high-performance, production-grade Telegram MTProxy implementation written in Rust. It is explicitly designed to operate in highly hostile network environments and evade advanced network censorship. + +**Adversarial Threat Model:** +The proxy operates under constant surveillance by DPI (Deep Packet Inspection) systems and active scanners (state firewalls, mobile operator fraud controls). These entities actively probe IPs, analyze protocol handshakes, and look for known proxy signatures to block or throttle traffic. + +**Core Architectural Pillars:** +1. **TLS-Fronting (TLS-F) & TCP-Splitting (TCP-S):** To the outside world, Telemt looks like a standard TLS server. If a client presents a valid MTProxy key, the connection is handled internally. If a censor's scanner, web browser, or unauthorized crawler connects, Telemt seamlessly splices the TCP connection (L4) to a real, legitimate HTTPS fallback server (e.g., Nginx) without modifying the `ClientHello` or terminating the TLS handshake. +2. **Middle-End (ME) Orchestration:** A highly concurrent, generation-based pool managing upstream connections to Telegram Datacenters (DCs). It utilizes an **Adaptive Floor** (dynamically scaling writer connections based on traffic), **Hardswaps** (zero-downtime pool reconfiguration), and **STUN/NAT** reflection mechanisms. +3. **Strict KDF Routing:** Cryptographic Key Derivation Functions (KDF) in this protocol strictly rely on the exact pairing of Source IP/Port and Destination IP/Port. Deviations or missing port logic will silently break the MTProto handshake. +4. **Data Plane vs. Control Plane Isolation:** The Data Plane (readers, writers, payload relay, TCP splicing) must remain strictly non-blocking, zero-allocation in hot paths, and highly resilient to network backpressure. The Control Plane (API, metrics, pool generation swaps, config reloads) orchestrates the state asynchronously without stalling the Data Plane. + +Any modification you make must preserve Telemt's invisibility to censors, its strict memory-safety invariants, and its hot-path throughput. + + ### 0. Priority Resolution — Scope Control This section resolves conflicts between code quality enforcement and scope limitation. diff --git a/Cargo.lock b/Cargo.lock index 06ea5c6..89eefd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2025,6 +2025,12 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "subtle" version = "2.6.1" @@ -2087,7 +2093,7 @@ dependencies = [ [[package]] name = "telemt" -version = "3.3.15" +version = "3.3.19" dependencies = [ "aes", "anyhow", @@ -2127,6 +2133,8 @@ dependencies = [ "sha1", "sha2", "socket2 0.5.10", + "static_assertions", + "subtle", "thiserror 2.0.18", "tokio", "tokio-rustls", diff --git a/Cargo.toml b/Cargo.toml index 9374924..4e12cad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.3.19" +version = "3.3.20" edition = "2024" [dependencies] @@ -22,6 +22,8 @@ hmac = "0.12" crc32fast = "1.4" crc32c = "0.6" zeroize = { version = "1.8", features = ["derive"] } +subtle = "2.6" +static_assertions = "1.1" # Network socket2 = { version = "0.5", features = ["all"] } diff --git a/src/config/types.rs b/src/config/types.rs index f676f54..808698d 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1156,6 +1156,13 @@ pub struct ServerConfig { #[serde(default = "default_proxy_protocol_header_timeout_ms")] pub proxy_protocol_header_timeout_ms: u64, + /// Trusted source CIDRs allowed to send incoming PROXY protocol headers. + /// + /// When non-empty, connections from addresses outside this allowlist are + /// rejected before `src_addr` is applied. + #[serde(default)] + pub proxy_protocol_trusted_cidrs: Vec, + #[serde(default)] pub metrics_port: Option, @@ -1185,6 +1192,7 @@ impl Default for ServerConfig { listen_tcp: None, proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), + proxy_protocol_trusted_cidrs: Vec::new(), metrics_port: None, metrics_whitelist: default_metrics_whitelist(), api: ApiConfig::default(), diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index fbe7ad5..c82c9fe 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -13,6 +13,7 @@ use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; use num_bigint::BigUint; use num_traits::One; +use subtle::ConstantTimeEq; // ============= Public Constants ============= @@ -28,6 +29,8 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16; /// Time skew limits for anti-replay (in seconds) pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after +/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. +pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; // ============= Private Constants ============= @@ -125,7 +128,7 @@ impl TlsExtensionBuilder { // protocol name length (1 byte) // protocol name bytes let proto_len = proto.len() as u8; - let list_len: u16 = 1 + proto_len as u16; + let list_len: u16 = 1 + u16::from(proto_len); let ext_len: u16 = 2 + list_len; self.extensions.extend_from_slice(&ext_len.to_be_bytes()); @@ -273,13 +276,86 @@ impl ServerHelloBuilder { // ============= Public Functions ============= -/// Validate TLS ClientHello against user secrets +/// Validate TLS ClientHello against user secrets. /// /// Returns validation result if a matching user is found. +/// The result **must** be used — ignoring it silently bypasses authentication. +#[must_use] pub fn validate_tls_handshake( handshake: &[u8], secrets: &[(String, Vec)], ignore_time_skew: bool, +) -> Option { + validate_tls_handshake_with_replay_window( + handshake, + secrets, + ignore_time_skew, + u64::from(BOOT_TIME_MAX_SECS), + ) +} + +/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL. +/// +/// A boot-time timestamp is only accepted when it falls below both +/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp +/// reuse outside replay cache coverage. +#[must_use] +pub fn validate_tls_handshake_with_replay_window( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + replay_window_secs: u64, +) -> Option { + // Only pay the clock syscall when we will actually compare against it. + // If `ignore_time_skew` is set, a broken or unavailable system clock + // must not block legitimate clients — that would be a DoS via clock failure. + let now = if !ignore_time_skew { + system_time_to_unix_secs(SystemTime::now())? + } else { + 0_i64 + }; + + let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); + let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32); + + validate_tls_handshake_at_time_with_boot_cap( + handshake, + secrets, + ignore_time_skew, + now, + boot_time_cap_secs, + ) +} + +fn system_time_to_unix_secs(now: SystemTime) -> Option { + // `try_from` rejects values that overflow i64 (> ~292 billion years CE), + // whereas `as i64` would silently wrap to a negative timestamp and corrupt + // every subsequent time-skew comparison. + let d = now.duration_since(UNIX_EPOCH).ok()?; + i64::try_from(d.as_secs()).ok() +} + +fn validate_tls_handshake_at_time( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + now: i64, +) -> Option { + validate_tls_handshake_at_time_with_boot_cap( + handshake, + secrets, + ignore_time_skew, + now, + BOOT_TIME_MAX_SECS, + ) +} + +fn validate_tls_handshake_at_time_with_boot_cap( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + now: i64, + boot_time_cap_secs: u32, ) -> Option { if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { return None; @@ -305,50 +381,56 @@ pub fn validate_tls_handshake( let mut msg = handshake.to_vec(); msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); - // Get current time - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - + let mut first_match: Option = None; + for (user, secret) in secrets { let computed = sha256_hmac(secret, &msg); - - // XOR digests - let xored: Vec = digest.iter() - .zip(computed.iter()) - .map(|(a, b)| a ^ b) - .collect(); - - // Check that first 28 bytes are zeros (timestamp in last 4) - if !xored[..28].iter().all(|&b| b == 0) { + + // Constant-time equality check on the 28-byte HMAC window. + // A variable-time short-circuit here lets an active censor measure how many + // bytes matched, enabling secret brute-force via timing side-channels. + // Direct comparison on the original arrays avoids a heap allocation and + // removes the `try_into().unwrap()` that the intermediate Vec would require. + if !bool::from(digest[..28].ct_eq(&computed[..28])) { continue; } - - // Extract timestamp - let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap()); - let time_diff = now - timestamp as i64; - - // Check time skew + + // The last 4 bytes encode the timestamp as XOR(digest[28..32], computed[28..32]). + // Inline array construction is infallible: both slices are [u8; 32] by construction. + let timestamp = u32::from_le_bytes([ + digest[28] ^ computed[28], + digest[29] ^ computed[29], + digest[30] ^ computed[30], + digest[31] ^ computed[31], + ]); + + // time_diff is only meaningful (and `now` is only valid) when we are + // actually checking the window. Keep both inside the guard to make + // the dead-code path explicit and prevent accidental future use of + // a sentinel `now` value outside its intended scope. if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) // This is a quirk in some clients that use uptime instead of real time - let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds - - if !is_boot_time && !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { - continue; + let is_boot_time = timestamp < boot_time_cap_secs; + if !is_boot_time { + let time_diff = now - i64::from(timestamp); + if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { + continue; + } } } - return Some(TlsValidation { - user: user.clone(), - session_id, - digest, - timestamp, - }); + if first_match.is_none() { + first_match = Some(TlsValidation { + user: user.clone(), + session_id: session_id.clone(), + digest, + timestamp, + }); + } } - - None + + first_match } fn curve25519_prime() -> BigUint { @@ -528,7 +610,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { if name_type == 0 && name_len > 0 && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) { - return Some(host.to_string()); + if is_valid_sni_hostname(host) { + return Some(host.to_string()); + } } sn_pos += name_len; } @@ -539,6 +623,35 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { None } +fn is_valid_sni_hostname(host: &str) -> bool { + if host.is_empty() || host.len() > 253 { + return false; + } + if host.starts_with('.') || host.ends_with('.') { + return false; + } + if host.parse::().is_ok() { + return false; + } + + for label in host.split('.') { + if label.is_empty() || label.len() > 63 { + return false; + } + if label.starts_with('-') || label.ends_with('-') { + return false; + } + if !label + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + { + return false; + } + } + + true +} + /// Extract ALPN protocol list from ClientHello, return in offered order. pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { let mut pos = 5; // after record header @@ -667,291 +780,29 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { Ok(()) } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_tls_handshake() { - assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); - assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); - assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data - assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version - assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short - } - - #[test] - fn test_parse_tls_record_header() { - let header = [0x16, 0x03, 0x01, 0x02, 0x00]; - let result = parse_tls_record_header(&header).unwrap(); - assert_eq!(result.0, TLS_RECORD_HANDSHAKE); - assert_eq!(result.1, 512); - - let header = [0x17, 0x03, 0x03, 0x40, 0x00]; - let result = parse_tls_record_header(&header).unwrap(); - assert_eq!(result.0, TLS_RECORD_APPLICATION); - assert_eq!(result.1, 16384); - } - - #[test] - fn test_gen_fake_x25519_key() { - let rng = SecureRandom::new(); - let key1 = gen_fake_x25519_key(&rng); - let key2 = gen_fake_x25519_key(&rng); - - assert_eq!(key1.len(), 32); - assert_eq!(key2.len(), 32); - assert_ne!(key1, key2); // Should be random - } +// ============= Compile-time Security Invariants ============= - #[test] - fn test_fake_x25519_key_is_quadratic_residue() { - let rng = SecureRandom::new(); - let key = gen_fake_x25519_key(&rng); - let p = curve25519_prime(); - let k_num = BigUint::from_bytes_le(&key); - let exponent = (&p - BigUint::one()) >> 1; - let legendre = k_num.modpow(&exponent, &p); - assert_eq!(legendre, BigUint::one()); - } - - #[test] - fn test_tls_extension_builder() { - let key = [0x42u8; 32]; - - let mut builder = TlsExtensionBuilder::new(); - builder.add_key_share(&key); - builder.add_supported_versions(0x0304); - - let result = builder.build(); - - // Check length prefix - let len = u16::from_be_bytes([result[0], result[1]]) as usize; - assert_eq!(len, result.len() - 2); - - // Check key_share extension is present - assert!(result.len() > 40); // At least key share - } - - #[test] - fn test_server_hello_builder() { - let session_id = vec![0x01, 0x02, 0x03, 0x04]; - let key = [0x55u8; 32]; - - let builder = ServerHelloBuilder::new(session_id.clone()) - .with_x25519_key(&key) - .with_tls13_version(); - - let record = builder.build_record(); - - // Validate structure - validate_server_hello_structure(&record).expect("Invalid ServerHello structure"); - - // Check record type - assert_eq!(record[0], TLS_RECORD_HANDSHAKE); - - // Check version - assert_eq!(&record[1..3], &TLS_VERSION); - - // Check message type (ServerHello = 0x02) - assert_eq!(record[5], 0x02); - } - - #[test] - fn test_build_server_hello_structure() { - let secret = b"test secret"; - let client_digest = [0x42u8; 32]; - let session_id = vec![0xAA; 32]; - - let rng = SecureRandom::new(); - let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0); - - // Should have at least 3 records - assert!(response.len() > 100); - - // First record should be ServerHello - assert_eq!(response[0], TLS_RECORD_HANDSHAKE); - - // Validate ServerHello structure - validate_server_hello_structure(&response).expect("Invalid ServerHello"); - - // Find Change Cipher Spec - let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize; - let ccs_start = server_hello_len; - - assert!(response.len() > ccs_start + 6); - assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); - - // Find Application Data - let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; - let app_start = ccs_start + ccs_len; - - assert!(response.len() > app_start + 5); - assert_eq!(response[app_start], TLS_RECORD_APPLICATION); - } - - #[test] - fn test_build_server_hello_digest() { - let secret = b"test secret key here"; - let client_digest = [0x42u8; 32]; - let session_id = vec![0xAA; 32]; - - let rng = SecureRandom::new(); - let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); - let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); - - // Digest position should have non-zero data - let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; - assert!(!digest1.iter().all(|&b| b == 0)); - - // Different calls should have different digests (due to random cert) - let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; - assert_ne!(digest1, digest2); - } - - #[test] - fn test_server_hello_extensions_length() { - let session_id = vec![0x01; 32]; - let key = [0x55u8; 32]; - - let builder = ServerHelloBuilder::new(session_id) - .with_x25519_key(&key) - .with_tls13_version(); - - let record = builder.build_record(); - - // Parse to find extensions - let msg_start = 5; // After record header - let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; - - // Skip to session ID - let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32) - let session_id_len = record[session_id_pos] as usize; - - // Skip to extensions - let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1) - let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize; - - // Verify extensions length matches actual data - let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len]; - assert_eq!(ext_len, extensions_data.len(), - "Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len()); - } - - #[test] - fn test_validate_tls_handshake_format() { - // Build a minimal ClientHello-like structure - let mut handshake = vec![0u8; 100]; - - // Put a valid-looking digest at position 11 - handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&[0x42; 32]); - - // Session ID length - handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; - - // This won't validate (wrong HMAC) but shouldn't panic - let secrets = vec![("test".to_string(), b"secret".to_vec())]; - let result = validate_tls_handshake(&handshake, &secrets, true); - - // Should return None (no match) but not panic - assert!(result.is_none()); - } +/// Compile-time checks that enforce invariants the rest of the code relies on. +/// Using `static_assertions` ensures these can never silently break across +/// refactors without a compile error. +mod compile_time_security_checks { + use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN}; + use static_assertions::const_assert; - fn build_client_hello_with_exts(exts: Vec<(u16, Vec)>, host: &str) -> Vec { - let mut body = Vec::new(); - body.extend_from_slice(&TLS_VERSION); // legacy version - body.extend_from_slice(&[0u8; 32]); // random - body.push(0); // session id len - body.extend_from_slice(&2u16.to_be_bytes()); // cipher suites len - body.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256 - body.push(1); // compression len - body.push(0); // null compression + // The digest must be exactly one SHA-256 output. + const_assert!(TLS_DIGEST_LEN == 32); - // Build SNI extension - let host_bytes = host.as_bytes(); - let mut sni_ext = Vec::new(); - sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes()); - sni_ext.push(0); - sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); - sni_ext.extend_from_slice(host_bytes); + // Replay-dedup stores the first half; verify it is literally half. + const_assert!(TLS_DIGEST_HALF_LEN * 2 == TLS_DIGEST_LEN); - let mut ext_blob = Vec::new(); - for (typ, data) in exts { - ext_blob.extend_from_slice(&typ.to_be_bytes()); - ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes()); - ext_blob.extend_from_slice(&data); - } - // SNI last - ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); - ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); - ext_blob.extend_from_slice(&sni_ext); - - body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); - body.extend_from_slice(&ext_blob); - - let mut handshake = Vec::new(); - handshake.push(0x01); // ClientHello - let len_bytes = (body.len() as u32).to_be_bytes(); - handshake.extend_from_slice(&len_bytes[1..4]); - handshake.extend_from_slice(&body); - - let mut record = Vec::new(); - record.push(TLS_RECORD_HANDSHAKE); - record.extend_from_slice(&[0x03, 0x01]); - record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); - record.extend_from_slice(&handshake); - record - } - - #[test] - fn test_extract_sni_with_grease_extension() { - // GREASE type 0x0a0a with zero length before SNI - let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com"); - let sni = extract_sni_from_client_hello(&ch); - assert_eq!(sni.as_deref(), Some("example.com")); - } - - #[test] - fn test_extract_sni_tolerates_empty_unknown_extension() { - let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local"); - let sni = extract_sni_from_client_hello(&ch); - assert_eq!(sni.as_deref(), Some("test.local")); - } - - #[test] - fn test_extract_alpn_single() { - let mut alpn_data = Vec::new(); - // list length = 3 (1 length byte + "h2") - alpn_data.extend_from_slice(&3u16.to_be_bytes()); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h2"); - let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); - let alpn = extract_alpn_from_client_hello(&ch); - let alpn_str: Vec = alpn - .iter() - .map(|p| std::str::from_utf8(p).unwrap().to_string()) - .collect(); - assert_eq!(alpn_str, vec!["h2"]); - } - - #[test] - fn test_extract_alpn_multiple() { - let mut alpn_data = Vec::new(); - // list length = 11 (sum of per-proto lengths including length bytes) - alpn_data.extend_from_slice(&11u16.to_be_bytes()); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h2"); - alpn_data.push(4); - alpn_data.extend_from_slice(b"spdy"); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h3"); - let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); - let alpn = extract_alpn_from_client_hello(&ch); - let alpn_str: Vec = alpn - .iter() - .map(|p| std::str::from_utf8(p).unwrap().to_string()) - .collect(); - assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]); - } + // The HMAC check window (28 bytes) plus the embedded timestamp (4 bytes) + // must exactly fill the digest. If TLS_DIGEST_LEN ever changes, these + // assertions will catch the mismatch before any timing-oracle fix is broke. + const_assert!(28 + 4 == TLS_DIGEST_LEN); } + +// ============= Security-focused regression tests ============= + +#[cfg(test)] +#[path = "tls_security_tests.rs"] +mod security_tests; diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs new file mode 100644 index 0000000..c25a517 --- /dev/null +++ b/src/protocol/tls_security_tests.rs @@ -0,0 +1,1289 @@ +use super::*; +use crate::crypto::sha256_hmac; + +/// Build a TLS-handshake-like buffer that contains a valid HMAC digest +/// for the given `secret` and `timestamp`. +/// +/// Layout (bytes): +/// [0..TLS_DIGEST_POS] : fixed filler (0x42) +/// [TLS_DIGEST_POS..+32] : digest = HMAC XOR [0..0 || timestamp_le] +/// [TLS_DIGEST_POS+32] : session_id_len = 32 +/// [TLS_DIGEST_POS+33..+65] : session_id filler (0x42) +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + // Zero the digest slot before computing HMAC (mirrors what validate does). + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + + // digest = HMAC such that XOR with stored digest yields [0..0, timestamp_le]. + // bytes 0-27 of digest == computed[0..28] -> xored[..28] == 0 + // bytes 28-31 of digest == computed[28..32] XOR timestamp_le + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +// ------------------------------------------------------------------ +// Happy-path sanity +// ------------------------------------------------------------------ + +#[test] +fn valid_handshake_with_correct_secret_accepted() { + let secret = b"correct_horse_battery_staple_32b"; + // timestamp = 0 triggers is_boot_time path, accepted without wall-clock check. + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("alice".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "Valid handshake must be accepted"); + assert_eq!(result.unwrap().user, "alice"); +} + +#[test] +fn deterministic_external_vector_validates_without_helper() { + // Deterministic vector generated by an external Python stdlib HMAC script, + // not by this test module helper. This catches mirrored helper mistakes. + let secret = hex::decode("00112233445566778899aabbccddeeff").unwrap(); + let handshake = hex::decode( + "4242424242424242424242a93225d1d6b46260bc9ce0cc48c7487d2b1ca5afa7ae9fc6609d9e60a3ca842b204242424242424242424242424242424242424242424242424242424242424242", + ) + .unwrap(); + + let secrets = vec![("vector_user".to_string(), secret)]; + let result = validate_tls_handshake(&handshake, &secrets, true).unwrap(); + + assert_eq!(result.user, "vector_user"); + assert_eq!(result.timestamp, 0x01020304); +} + +#[test] +fn valid_handshake_timestamp_extracted_correctly() { + let secret = b"ts_extraction_test"; + let ts: u32 = 0xDEAD_BEEF; + let handshake = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some()); + assert_eq!(result.unwrap().timestamp, ts); +} + +// ------------------------------------------------------------------ +// HMAC bit-flip rejection - adversarial HMAC forgery attempts +// ------------------------------------------------------------------ + +/// Flip every single bit across the 28-byte HMAC check window one at a +/// time. Each flip must cause rejection. This is the primary guard +/// against a censor gradually narrowing down a valid HMAC via partial +/// matches (which would be exploitable with a variable-time comparison). +#[test] +fn hmac_single_bit_flip_anywhere_in_check_window_rejected() { + let secret = b"flip_test_secret"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // First ensure the unmodified handshake is accepted. + assert!( + validate_tls_handshake(&base, &secrets, true).is_some(), + "Baseline handshake must be accepted before flip tests" + ); + + for byte_pos in 0..28usize { + for bit in 0u8..8 { + let mut h = base.clone(); + h[TLS_DIGEST_POS + byte_pos] ^= 1 << bit; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip of bit {bit} in HMAC byte {byte_pos} must be rejected" + ); + } + } +} + +/// XOR entire check window (bytes 0-27) with 0xFF - must still fail. +#[test] +fn hmac_full_window_corruption_rejected() { + let secret = b"full_window_test"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + for i in 0..28 { + h[TLS_DIGEST_POS + i] ^= 0xFF; + } + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +/// Byte 27 is the last byte in the checked window. A non-constant-time +/// `all(|b| b == 0)` that short-circuits on byte 0 would never even reach +/// byte 27, making this an effective "did the fix actually run to the end" +/// sentinel: if this passes but the earlier byte-0 test fails, the check +/// window is not being evaluated end-to-end. +#[test] +fn hmac_last_byte_of_check_window_enforced() { + let secret = b"last_byte_sentinel"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + // Corrupt only byte 27. + h[TLS_DIGEST_POS + 27] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Corruption at byte 27 (end of HMAC window) must cause rejection" + ); +} + +// ------------------------------------------------------------------ +// User enumeration / multi-user ordering +// ------------------------------------------------------------------ + +#[test] +fn wrong_user_secret_rejected_even_with_valid_structure() { + let secret_a = b"secret_alpha"; + let secret_b = b"secret_beta"; + let handshake = make_valid_tls_handshake(secret_b, 0); + // Only user_a is configured. + let secrets = vec![("user_a".to_string(), secret_a.to_vec())]; + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "Handshake for user_b must fail when only user_a is configured" + ); +} + +#[test] +fn second_user_in_list_found_when_first_does_not_match() { + let secret_a = b"secret_alpha"; + let secret_b = b"secret_beta"; + let handshake = make_valid_tls_handshake(secret_b, 0); + let secrets = vec![ + ("user_a".to_string(), secret_a.to_vec()), + ("user_b".to_string(), secret_b.to_vec()), + ]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "user_b must be found even though user_a comes first"); + assert_eq!(result.unwrap().user, "user_b"); +} + +#[test] +fn duplicate_secret_keeps_first_user_identity() { + // If multiple entries share the same secret, the selected identity must + // stay stable and deterministic (first entry wins). + let shared = b"same_secret_for_two_users"; + let handshake = make_valid_tls_handshake(shared, 0); + let secrets = vec![ + ("first_user".to_string(), shared.to_vec()), + ("second_user".to_string(), shared.to_vec()), + ]; + + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some()); + assert_eq!(result.unwrap().user, "first_user"); +} + +#[test] +fn no_user_matches_returns_none() { + let secret_a = b"aaa"; + let secret_b = b"bbb"; + let secret_c = b"ccc"; + let handshake = make_valid_tls_handshake(b"unknown_secret", 0); + let secrets = vec![ + ("a".to_string(), secret_a.to_vec()), + ("b".to_string(), secret_b.to_vec()), + ("c".to_string(), secret_c.to_vec()), + ]; + assert!(validate_tls_handshake(&handshake, &secrets, true).is_none()); +} + +#[test] +fn empty_secrets_list_rejects_everything() { + let secret = b"test"; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets: Vec<(String, Vec)> = Vec::new(); + assert!(validate_tls_handshake(&handshake, &secrets, true).is_none()); +} + +// ------------------------------------------------------------------ +// Timestamp / time-skew boundary attacks +// ------------------------------------------------------------------ + +#[test] +fn timestamp_at_time_skew_boundaries_accepted() { + let secret = b"skew_boundary_test_secret"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // time_diff = now - ts = TIME_SKEW_MIN = -1200 + // -> ts = now - TIME_SKEW_MIN = now + 1200 (20 min in the future). + let ts_at_future_limit = (now - TIME_SKEW_MIN) as u32; + let h = make_valid_tls_handshake(secret, ts_at_future_limit); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_some(), + "Timestamp at max-allowed future (time_diff = TIME_SKEW_MIN) must be accepted" + ); + + // time_diff = now - ts = TIME_SKEW_MAX = 600 + // -> ts = now - TIME_SKEW_MAX = now - 600 (10 min in the past). + let ts_at_past_limit = (now - TIME_SKEW_MAX) as u32; + let h = make_valid_tls_handshake(secret, ts_at_past_limit); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_some(), + "Timestamp at max-allowed past (time_diff = TIME_SKEW_MAX) must be accepted" + ); +} + +#[test] +fn timestamp_one_second_outside_skew_window_rejected() { + let secret = b"skew_outside_test_secret"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // time_diff = TIME_SKEW_MAX + 1 = 601 (one second too far in the past) + // -> ts = now - (TIME_SKEW_MAX + 1) = now - 601 + let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; + let h = make_valid_tls_handshake(secret, ts_too_past); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "Timestamp one second too far in the past must be rejected" + ); + + // time_diff = TIME_SKEW_MIN - 1 = -1201 (one second too far in the future) + // -> ts = now - (TIME_SKEW_MIN - 1) = now + 1201 + let ts_too_future = (now - TIME_SKEW_MIN + 1) as u32; + let h = make_valid_tls_handshake(secret, ts_too_future); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "Timestamp one second too far in the future must be rejected" + ); +} + +#[test] +fn ignore_time_skew_accepts_far_future_timestamp() { + let secret = b"ignore_skew_test"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // 1 hour in the future - outside TIME_SKEW_MAX but should pass with flag. + let future_ts = (now + 3600) as u32; + let h = make_valid_tls_handshake(secret, future_ts); + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, now).is_some(), + "ignore_time_skew=true must override window rejection" + ); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "ignore_time_skew=false must still reject far-future timestamp" + ); +} + +#[test] +fn boot_time_timestamp_accepted_without_ignore_flag() { + // Timestamps below the boot-time threshold are treated as client uptime, + // not real wall-clock time. The proxy allows them regardless of skew. + let secret = b"boot_time_test"; + // Keep this safely below BOOT_TIME_MAX_SECS to assert bypass behavior. + let boot_ts: u32 = BOOT_TIME_MAX_SECS / 2; + let handshake = make_valid_tls_handshake(secret, boot_ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + assert!( + validate_tls_handshake(&handshake, &secrets, false).is_some(), + "Boot-time timestamp must be accepted even with ignore_time_skew=false" + ); +} + +// ------------------------------------------------------------------ +// Structural / length boundary attacks +// ------------------------------------------------------------------ + +#[test] +fn too_short_handshake_rejected_without_panic() { + let secrets = vec![("u".to_string(), b"s".to_vec())]; + // Exactly one byte short of the minimum required length. + let h = vec![0u8; TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); + + // Empty buffer. + assert!(validate_tls_handshake(&[], &secrets, true).is_none()); +} + +#[test] +fn claimed_session_id_overflows_buffer_rejected() { + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0u8; min_len]; + // Claim session_id is 33 bytes - one more than the buffer holds. + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = (session_id_len + 1) as u8; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn max_session_id_len_255_does_not_panic() { + // session_id_len = 255 with a buffer that is far too small for it. + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + 32; + let mut h = vec![0u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 255; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +// ------------------------------------------------------------------ +// Adversarial digest values +// ------------------------------------------------------------------ + +#[test] +fn all_zeros_digest_rejected() { + // An all-zeros digest would only pass if HMAC(secret, msg) happens to + // have its first 28 bytes all zero, which is computationally infeasible. + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + let secrets = vec![("u".to_string(), b"test_secret".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn all_ones_digest_rejected() { + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0xFF); + let secrets = vec![("u".to_string(), b"test_secret".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +/// Simulate a censor that sends 200 crafted packets with random digests. +/// Every single one must be rejected; no random digest should accidentally +/// pass (probability 2^{-224} per attempt; negligible for 200 trials). +#[test] +fn censor_probe_random_digests_all_rejected() { + use crate::crypto::SecureRandom; + let secret = b"production_like_secret_value_xyz"; + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let rng = SecureRandom::new(); + + for attempt in 0..200 { + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let rand_digest = rng.bytes(TLS_DIGEST_LEN); + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&rand_digest); + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Random digest at attempt {attempt} must not match" + ); + } +} + +/// The check window is bytes 0-27 of the XOR result. Bytes 28-31 encode +/// the timestamp and must NOT affect whether the HMAC portion validates - +/// only the timestamp range check uses them. Build a valid handshake with +/// timestamp = 0 (boot-time), flip each of bytes 28-31 with ignore_time_skew +/// enabled, and verify the HMAC portion still passes (the timestamp changes +/// but the proxy still accepts the connection under ignore_time_skew). +#[test] +fn timestamp_bytes_28_31_do_not_affect_hmac_window() { + let secret = b"window_boundary_test"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // Baseline must pass. + assert!(validate_tls_handshake(&base, &secrets, true).is_some()); + + // Flip each of the timestamp bytes; with ignore_time_skew the + // modified timestamps (small absolute values) still pass boot-time check. + for i in 28..32usize { + let mut h = base.clone(); + h[TLS_DIGEST_POS + i] ^= 0xFF; + // The new timestamp is non-zero but potentially still < boot threshold; + // use ignore_time_skew=true so wallet test is HMAC-only. + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "Flipping byte {i} (timestamp region) must not invalidate HMAC window" + ); + } +} + +// ------------------------------------------------------------------ +// session_id preservation +// ------------------------------------------------------------------ + +#[test] +fn session_id_is_preserved_verbatim_in_validation_result() { + // If session_id extraction is ever broken (wrong offset, wrong length, + // off-by-one), this test will catch it before it silently corrupts the + // ServerHello that echoes the session_id back to the client. + let secret = b"session_id_preservation_test"; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true).unwrap(); + + let sid_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; + let sid_len = handshake[sid_len_pos] as usize; + let expected = &handshake[sid_len_pos + 1..sid_len_pos + 1 + sid_len]; + + assert_eq!( + result.session_id, expected, + "session_id in TlsValidation must be the verbatim bytes from the handshake" + ); +} + +// ------------------------------------------------------------------ +// Clock decoupling - ignore_time_skew must not consult the system clock +// ------------------------------------------------------------------ + +/// When `ignore_time_skew = true`, a valid HMAC must be accepted even if +/// `now = 0` (the sentinel used when the clock is not needed). A broken +/// system clock cannot silently deny service when the admin has explicitly +/// disabled timestamp checking. +#[test] +fn ignore_time_skew_accepts_valid_hmac_with_now_zero() { + let secret = b"clock_decoupling_test"; + // Use a realistic Unix timestamp that would be far outside the window + // if compared against now=0 (time_diff would be ~-1_700_000_000). + let realistic_ts: u32 = 1_700_000_000; + let h = make_valid_tls_handshake(secret, realistic_ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, 0).is_some(), + "ignore_time_skew=true must accept a valid HMAC regardless of `now`" + ); + + // Confirm that the same handshake IS rejected when the window is enforced + // and now=0 (time_diff very negative -> outside window). This distinguishes + // "clock decoupling" from "always accept". + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), + "ignore_time_skew=false with now=0 must still reject out-of-window timestamps" + ); +} + +/// An HMAC-invalid handshake must be rejected even when ignore_time_skew=true +/// and now=0. Verifies that the clock-decoupling fix did not weaken HMAC +/// enforcement in the ignore_time_skew path. +#[test] +fn ignore_time_skew_with_now_zero_still_rejects_bad_hmac() { + let secret = b"clock_no_backdoor_test"; + let mut h = make_valid_tls_handshake(secret, 1_700_000_000); + let secrets = vec![("u".to_string(), secret.to_vec())]; + // Corrupt the HMAC check window. + h[TLS_DIGEST_POS] ^= 0xFF; + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, 0).is_none(), + "Broken HMAC must be rejected even with ignore_time_skew=true and now=0" + ); +} + +#[test] +fn system_time_before_unix_epoch_is_rejected_without_panic() { + let before_epoch = UNIX_EPOCH + .checked_sub(std::time::Duration::from_secs(1)) + .expect("UNIX_EPOCH minus one second must be representable"); + assert!(system_time_to_unix_secs(before_epoch).is_none()); +} + +/// `i64::MAX` is 9_223_372_036_854_775_807 seconds (~292 billion years CE). +/// Any `SystemTime` whose duration since epoch exceeds `i64::MAX` seconds +/// must return `None` rather than silently wrapping to a large negative +/// timestamp that would corrupt every subsequent time-skew comparison. +#[test] +fn system_time_far_future_overflowing_i64_returns_none() { + // i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`. + let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1; + if let Some(far_future) = + UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs)) + { + assert!( + system_time_to_unix_secs(far_future).is_none(), + "Seconds > i64::MAX must return None, not a wrapped negative timestamp" + ); + } + // If the platform cannot represent this SystemTime, the test is vacuously + // satisfied: `checked_add` returning None means the platform already rejects it. +} + +// ------------------------------------------------------------------ +// Message canonicalization — HMAC covers every byte of the handshake +// ------------------------------------------------------------------ + +/// Every byte before TLS_DIGEST_POS is part of the HMAC input (because msg +/// = full handshake with only the digest slot zeroed). An attacker cannot +/// replay a valid handshake with a modified ClientHello header while keeping +/// the stored digest; each such modification produces a different HMAC. +#[test] +fn pre_digest_bytes_are_hmac_covered() { + // TLS_DIGEST_POS = 11, so 11 bytes precede the digest. + let secret = b"pre_digest_coverage_test"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + for byte_pos in 0..TLS_DIGEST_POS { + let mut h = base.clone(); + h[byte_pos] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip in pre-digest byte {byte_pos} must cause HMAC check failure" + ); + } +} + +/// session_id bytes follow the digest in the buffer and are also part of the +/// HMAC input. Flipping any of them invalidates the stored digest, preventing +/// a censor from capturing a valid session_id and replaying it with a different +/// one while keeping the rest of the packet intact. +#[test] +fn session_id_bytes_are_hmac_covered() { + let secret = b"session_id_coverage_test"; + let base = make_valid_tls_handshake(secret, 0); // session_id_len = 32 + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + for byte_pos in sid_start..base.len() { + let mut h = base.clone(); + h[byte_pos] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip in session_id byte at offset {byte_pos} must cause HMAC check failure" + ); + } +} + +/// Appending even one byte to a valid handshake changes the HMAC input (msg +/// includes all bytes) and therefore invalidates the stored digest. This +/// prevents a length-extension-style modification of the payload. +#[test] +fn appended_trailing_byte_causes_rejection() { + let secret = b"trailing_byte_test"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!(validate_tls_handshake(&h, &secrets, true).is_some(), "baseline"); + + h.push(0x00); + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Appending a trailing byte to a valid handshake must invalidate the HMAC" + ); +} + +// ------------------------------------------------------------------ +// Zero-length session_id (structural edge case) +// ------------------------------------------------------------------ + +/// session_id_len = 0 is legal in the TLS spec. The validator must accept a +/// valid handshake with an empty session_id and return an empty session_id +/// slice without panicking or accessing out-of-bounds memory. +#[test] +fn zero_length_session_id_accepted() { + let secret = b"zero_sid_test"; + // Buffer: pre-digest | digest | session_id_len=0 (no session_id bytes follow) + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + let mut handshake = vec![0x42u8; len]; + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 0; // session_id_len = 0 + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + // timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged. + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&computed); + + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "zero-length session_id must be accepted"); + assert!( + result.unwrap().session_id.is_empty(), + "session_id field must be empty when session_id_len = 0" + ); +} + +// ------------------------------------------------------------------ +// Boot-time threshold — exact boundary precision +// ------------------------------------------------------------------ + +/// timestamp = BOOT_TIME_MAX_SECS - 1 is the last value inside the boot-time window. +/// is_boot_time = true → skew check is skipped entirely → accepted even +/// when `now` is far from the timestamp. +#[test] +fn timestamp_one_below_boot_threshold_bypasses_skew_check() { + let secret = b"boot_last_value_test"; + let ts: u32 = BOOT_TIME_MAX_SECS - 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // now = 0 → time_diff would be -86_399_999, way outside [-1200, 600]. + // Boot-time bypass must prevent the skew check from running. + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(), + "ts=BOOT_TIME_MAX_SECS-1 must bypass skew check regardless of now" + ); +} + +/// timestamp = BOOT_TIME_MAX_SECS is the first value outside the boot-time window. +/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this: +/// once with now chosen so the skew passes (accepted) and once where it fails. +#[test] +fn timestamp_at_boot_threshold_triggers_skew_check() { + let secret = b"boot_exact_value_test"; + let ts: u32 = BOOT_TIME_MAX_SECS; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // now = ts + 50 → time_diff = 50, within [-1200, 600] → accepted. + let now_valid: i64 = ts as i64 + 50; + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(), + "ts=BOOT_TIME_MAX_SECS within skew window must be accepted via skew check" + ); + + // now = 0 → time_diff = -86_400_000, outside window → rejected. + // If the boot-time bypass were wrongly applied here this would pass. + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), + "ts=BOOT_TIME_MAX_SECS far from now must be rejected — no boot-time bypass" + ); +} + +#[test] +fn replay_window_cap_disables_boot_bypass_for_old_timestamps() { + let secret = b"boot_cap_disabled_test"; + let ts: u32 = 900; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300); + assert!( + result.is_none(), + "timestamp above replay-window cap must not use boot-time bypass" + ); +} + +#[test] +fn replay_window_cap_still_allows_small_boot_timestamp() { + let secret = b"boot_cap_enabled_test"; + let ts: u32 = 120; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300); + assert!( + result.is_some(), + "timestamp below replay-window cap must retain boot-time compatibility" + ); +} + +// ------------------------------------------------------------------ +// Extreme timestamp values +// ------------------------------------------------------------------ + +/// u32::MAX is a valid timestamp value. When ignore_time_skew=true the HMAC +/// is the only gate, and a correctly constructed handshake must be accepted. +#[test] +fn u32_max_timestamp_accepted_with_ignore_time_skew() { + let secret = b"u32_max_ts_accept_test"; + let h = make_valid_tls_handshake(secret, u32::MAX); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some(), "u32::MAX timestamp must be accepted with ignore_time_skew=true"); + assert_eq!( + result.unwrap().timestamp, + u32::MAX, + "timestamp field must equal u32::MAX verbatim" + ); +} + +/// u32::MAX > BOOT_TIME_MAX_SECS so the skew check runs. With any realistic `now` +/// (~1.7 billion), time_diff = now - u32::MAX is deeply negative — far outside +/// [-1200, 600] — so the handshake must be rejected without overflow. +#[test] +fn u32_max_timestamp_rejected_by_skew_enforcement() { + let secret = b"u32_max_ts_reject_test"; + let h = make_valid_tls_handshake(secret, u32::MAX); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let now: i64 = 1_700_000_000; + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "u32::MAX timestamp must be rejected by skew check with realistic now" + ); +} + +// ------------------------------------------------------------------ +// Validation result field correctness +// ------------------------------------------------------------------ + +/// result.digest must be the verbatim bytes stored in the handshake buffer, +/// not the freshly recomputed HMAC. Callers use this field directly when +/// constructing the ServerHello response digest. +#[test] +fn result_digest_field_is_verbatim_stored_digest() { + let secret = b"digest_field_verbatim_test"; + let ts: u32 = 0xCAFE_BABE; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&h, &secrets, true).unwrap(); + + let stored: [u8; TLS_DIGEST_LEN] = h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .try_into() + .unwrap(); + assert_eq!( + result.digest, stored, + "result.digest must equal the stored bytes, not the computed HMAC" + ); +} + +// ------------------------------------------------------------------ +// Secret length edge cases +// ------------------------------------------------------------------ + +/// HMAC-SHA256 pads or hashes keys of any length; a single-byte key must work. +#[test] +fn single_byte_secret_works() { + let secret = b"x"; + let h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "single-byte secret must produce a valid and verifiable HMAC" + ); +} + +/// Keys longer than the HMAC block size (64 bytes for SHA-256) are hashed +/// before use. A 256-byte key must work without truncation or panic. +#[test] +fn very_long_secret_256_bytes_works() { + let secret = vec![0xABu8; 256]; + let h = make_valid_tls_handshake(&secret, 0); + let secrets = vec![("u".to_string(), secret.clone())]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "256-byte secret must be accepted without truncation" + ); +} + +// ------------------------------------------------------------------ +// Determinism — same input must always produce same result +// ------------------------------------------------------------------ + +/// Calling validate twice on the same input must return identical results. +/// Non-determinism (e.g. from an accidentally global mutable state or a +/// shared nonce) would be a critical security defect in a proxy that rejects +/// censors by relying on stable authentication outcomes. +#[test] +fn validation_is_deterministic() { + let secret = b"determinism_test_key"; + let h = make_valid_tls_handshake(secret, 42); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let r1 = validate_tls_handshake(&h, &secrets, true).unwrap(); + let r2 = validate_tls_handshake(&h, &secrets, true).unwrap(); + + assert_eq!(r1.user, r2.user); + assert_eq!(r1.session_id, r2.session_id); + assert_eq!(r1.digest, r2.digest); + assert_eq!(r1.timestamp, r2.timestamp); +} + +// ------------------------------------------------------------------ +// Multi-user: scan-all correctness guarantees +// ------------------------------------------------------------------ + +/// The matching logic must scan through the entire secrets list. A user +/// at position 99 of 100 must be found; an implementation that stops early +/// on the first non-match would fail this test. +#[test] +fn last_user_in_large_list_is_found() { + let target_secret = b"needle_in_haystack"; + let h = make_valid_tls_handshake(target_secret, 0); + + let mut secrets: Vec<(String, Vec)> = (0..99) + .map(|i| (format!("decoy_{i}"), format!("wrong_{i}").into_bytes())) + .collect(); + secrets.push(("needle".to_string(), target_secret.to_vec())); + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some(), "100th user must be found"); + assert_eq!(result.unwrap().user, "needle"); +} + +/// When multiple users share the same secret the first occurrence must always +/// win. The scan-all loop must not replace first_match with a later one. +#[test] +fn first_matching_user_wins_over_later_duplicate_secret() { + let shared = b"duplicated_secret_key"; + let h = make_valid_tls_handshake(shared, 0); + + let secrets = vec![ + ("decoy_1".to_string(), b"wrong_1".to_vec()), + ("winner".to_string(), shared.to_vec()), // first match + ("decoy_2".to_string(), b"wrong_2".to_vec()), + ("loser".to_string(), shared.to_vec()), // second match — must not win + ("decoy_3".to_string(), b"wrong_3".to_vec()), + ]; + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some()); + assert_eq!( + result.unwrap().user, "winner", + "first matching user must be returned even when a later entry also matches" + ); +} + +// ------------------------------------------------------------------ +// Legacy tls.rs tests moved here +// ------------------------------------------------------------------ + +#[test] +fn test_is_tls_handshake() { + assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); + assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); + assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); + assert!(!is_tls_handshake(&[0x16, 0x03])); +} + +#[test] +fn test_parse_tls_record_header() { + let header = [0x16, 0x03, 0x01, 0x02, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_HANDSHAKE); + assert_eq!(result.1, 512); + + let header = [0x17, 0x03, 0x03, 0x40, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_APPLICATION); + assert_eq!(result.1, 16384); +} + +#[test] +fn test_gen_fake_x25519_key() { + let rng = crate::crypto::SecureRandom::new(); + let key1 = gen_fake_x25519_key(&rng); + let key2 = gen_fake_x25519_key(&rng); + + assert_eq!(key1.len(), 32); + assert_eq!(key2.len(), 32); + assert_ne!(key1, key2); +} + +#[test] +fn test_fake_x25519_key_is_quadratic_residue() { + use num_bigint::BigUint; + use num_traits::One; + + let rng = crate::crypto::SecureRandom::new(); + let key = gen_fake_x25519_key(&rng); + let p = curve25519_prime(); + let k_num = BigUint::from_bytes_le(&key); + let exponent = (&p - BigUint::one()) >> 1; + let legendre = k_num.modpow(&exponent, &p); + assert_eq!(legendre, BigUint::one()); +} + +#[test] +fn test_tls_extension_builder() { + let key = [0x42u8; 32]; + + let mut builder = TlsExtensionBuilder::new(); + builder.add_key_share(&key); + builder.add_supported_versions(0x0304); + + let result = builder.build(); + let len = u16::from_be_bytes([result[0], result[1]]) as usize; + + assert_eq!(len, result.len() - 2); + assert!(result.len() > 40); +} + +#[test] +fn test_server_hello_builder() { + let session_id = vec![0x01, 0x02, 0x03, 0x04]; + let key = [0x55u8; 32]; + + let builder = ServerHelloBuilder::new(session_id.clone()) + .with_x25519_key(&key) + .with_tls13_version(); + + let record = builder.build_record(); + validate_server_hello_structure(&record).expect("Invalid ServerHello structure"); + + assert_eq!(record[0], TLS_RECORD_HANDSHAKE); + assert_eq!(&record[1..3], &TLS_VERSION); + assert_eq!(record[5], 0x02); +} + +#[test] +fn test_build_server_hello_structure() { + let secret = b"test secret"; + let client_digest = [0x42u8; 32]; + let session_id = vec![0xAA; 32]; + + let rng = crate::crypto::SecureRandom::new(); + let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0); + + assert!(response.len() > 100); + assert_eq!(response[0], TLS_RECORD_HANDSHAKE); + validate_server_hello_structure(&response).expect("Invalid ServerHello"); + + let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_start = server_hello_len; + assert!(response.len() > ccs_start + 6); + assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); + + let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; + let app_start = ccs_start + ccs_len; + assert!(response.len() > app_start + 5); + assert_eq!(response[app_start], TLS_RECORD_APPLICATION); +} + +#[test] +fn test_build_server_hello_digest() { + let secret = b"test secret key here"; + let client_digest = [0x42u8; 32]; + let session_id = vec![0xAA; 32]; + + let rng = crate::crypto::SecureRandom::new(); + let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + + let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert!(!digest1.iter().all(|&b| b == 0)); + + let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert_ne!(digest1, digest2); +} + +#[test] +fn test_server_hello_extensions_length() { + let session_id = vec![0x01; 32]; + let key = [0x55u8; 32]; + + let builder = ServerHelloBuilder::new(session_id) + .with_x25519_key(&key) + .with_tls13_version(); + + let record = builder.build_record(); + let msg_start = 5; + let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; + let session_id_pos = msg_start + 4 + 2 + 32; + let session_id_len = record[session_id_pos] as usize; + let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; + let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize; + let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len]; + + assert_eq!( + ext_len, + extensions_data.len(), + "Extension length mismatch: declared {}, actual {}", + ext_len, + extensions_data.len() + ); +} + +#[test] +fn test_validate_tls_handshake_format() { + let mut handshake = vec![0u8; 100]; + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&[0x42; 32]); + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; + + let secrets = vec![("test".to_string(), b"secret".to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_none()); +} + +fn build_client_hello_with_exts(exts: Vec<(u16, Vec)>, host: &str) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let host_bytes = host.as_bytes(); + let mut sni_ext = Vec::new(); + sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes()); + sni_ext.push(0); + sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_ext.extend_from_slice(host_bytes); + + let mut ext_blob = Vec::new(); + for (typ, data) in exts { + ext_blob.extend_from_slice(&typ.to_be_bytes()); + ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&data); + } + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_ext); + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let len_bytes = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&len_bytes[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +fn build_client_hello_with_raw_extensions(ext_blob: &[u8]) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let len_bytes = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&len_bytes[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +#[test] +fn test_extract_sni_with_grease_extension() { + let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com"); + let sni = extract_sni_from_client_hello(&ch); + assert_eq!(sni.as_deref(), Some("example.com")); +} + +#[test] +fn test_extract_sni_tolerates_empty_unknown_extension() { + let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local"); + let sni = extract_sni_from_client_hello(&ch); + assert_eq!(sni.as_deref(), Some("test.local")); +} + +#[test] +fn test_extract_alpn_single() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&3u16.to_be_bytes()); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h2"); + let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); + let alpn = extract_alpn_from_client_hello(&ch); + let alpn_str: Vec = alpn + .iter() + .map(|p| std::str::from_utf8(p).unwrap().to_string()) + .collect(); + assert_eq!(alpn_str, vec!["h2"]); +} + +#[test] +fn test_extract_alpn_multiple() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&11u16.to_be_bytes()); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h2"); + alpn_data.push(4); + alpn_data.extend_from_slice(b"spdy"); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h3"); + let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); + let alpn = extract_alpn_from_client_hello(&ch); + let alpn_str: Vec = alpn + .iter() + .map(|p| std::str::from_utf8(p).unwrap().to_string()) + .collect(); + assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]); +} + +#[test] +fn extract_sni_rejects_zero_length_host_name() { + let mut sni_ext = Vec::new(); + sni_ext.extend_from_slice(&3u16.to_be_bytes()); + sni_ext.push(0); + sni_ext.extend_from_slice(&0u16.to_be_bytes()); + + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_ext); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_raw_ipv4_literals() { + let ch = build_client_hello_with_exts(Vec::new(), "203.0.113.10"); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_invalid_label_characters() { + let ch = build_client_hello_with_exts(Vec::new(), "exa_mple.com"); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_oversized_label() { + let oversized = format!("{}.example.com", "a".repeat(64)); + let ch = build_client_hello_with_exts(Vec::new(), &oversized); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_when_extension_block_is_truncated() { + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&5u16.to_be_bytes()); + ext_blob.extend_from_slice(&[0, 3, 0]); + + let mut ch = build_client_hello_with_raw_extensions(&ext_blob); + ch.pop(); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_alpn_rejects_when_extension_block_is_truncated() { + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&5u16.to_be_bytes()); + ext_blob.extend_from_slice(&[0, 3, 2, b'h']); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_alpn_from_client_hello(&ch).is_empty()); +} + +#[test] +fn extract_alpn_rejects_nested_length_overflow() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&10u16.to_be_bytes()); + alpn_data.push(8); + alpn_data.extend_from_slice(b"h2"); + + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_alpn_from_client_hello(&ch).is_empty()); +} + +// ------------------------------------------------------------------ +// Additional adversarial checks +// ------------------------------------------------------------------ + +#[test] +fn empty_secret_hmac_is_supported() { + let secret: &[u8] = b""; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("empty".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "Empty HMAC key must not panic and must validate when correct"); +} + +#[test] +fn server_hello_digest_verifies_against_full_response() { + let secret = b"fronting_digest_verify_key"; + let client_digest = [0x42u8; TLS_DIGEST_LEN]; + let session_id = vec![0xAA; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 1); + let mut zeroed = response.clone(); + zeroed[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + zeroed.len()); + hmac_input.extend_from_slice(&client_digest); + hmac_input.extend_from_slice(&zeroed); + let expected = sha256_hmac(secret, &hmac_input); + + assert_eq!( + &response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN], + &expected, + "ServerHello digest must be verifiable by a client that recomputes HMAC over full response" + ); +} + +#[test] +fn server_hello_digest_fails_after_single_byte_tamper() { + let secret = b"fronting_tamper_detect_key"; + let client_digest = [0x24u8; TLS_DIGEST_LEN]; + let session_id = vec![0xBB; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let mut response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + response[TLS_DIGEST_POS + TLS_DIGEST_LEN + 1] ^= 0x01; + + let mut zeroed = response.clone(); + zeroed[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + zeroed.len()); + hmac_input.extend_from_slice(&client_digest); + hmac_input.extend_from_slice(&zeroed); + let expected = sha256_hmac(secret, &hmac_input); + + assert_ne!( + &response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN], + &expected, + "Tampering any response byte must invalidate the embedded digest" + ); +} + +#[test] +fn server_hello_application_data_payload_varies_across_runs() { + use std::collections::HashSet; + + let secret = b"fronting_payload_variability_key"; + let client_digest = [0x13u8; TLS_DIGEST_LEN]; + let session_id = vec![0x44; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let mut unique_payloads: HashSet> = HashSet::new(); + for _ in 0..16 { + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + + assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec(); + + assert!(payload.iter().any(|&b| b != 0), "Payload must not be all-zero deterministic filler"); + unique_payloads.insert(payload); + } + + assert!( + unique_payloads.len() >= 4, + "ApplicationData payload should vary across runs to reduce fingerprintability" + ); +} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 99e6837..5ccbd40 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -4,7 +4,10 @@ use std::future::Future; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; +use std::sync::OnceLock; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; +use ipnetwork::IpNetwork; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::net::TcpStream; use tokio::time::timeout; @@ -23,7 +26,7 @@ enum HandshakeOutcome { use crate::config::ProxyConfig; use crate::crypto::SecureRandom; -use crate::error::{HandshakeResult, ProxyError, Result}; +use crate::error::{HandshakeResult, ProxyError, Result, StreamError}; use crate::ip_tracker::UserIpTracker; use crate::protocol::constants::*; use crate::protocol::tls; @@ -63,14 +66,30 @@ fn record_handshake_failure_class( peer_ip: IpAddr, error: &ProxyError, ) { - let class = if error.to_string().contains("expected 64 bytes, got 0") { - "expected_64_got_0" - } else { - "other" + let class = match error { + ProxyError::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + "expected_64_got_0" + } + ProxyError::Stream(StreamError::UnexpectedEof) => "expected_64_got_0", + _ => "other", }; record_beobachten_class(beobachten, config, peer_ip, class); } +fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { + if trusted.is_empty() { + static EMPTY_PROXY_TRUST_WARNED: OnceLock = OnceLock::new(); + let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default" + ); + } + return false; + } + trusted.iter().any(|cidr| cidr.contains(peer_ip)) +} + pub async fn handle_client_stream( mut stream: S, peer: SocketAddr, @@ -104,6 +123,17 @@ where ); match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await { Ok(Ok(info)) => { + if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs) + { + stats.increment_connects_bad(); + warn!( + peer = %peer, + trusted = ?config.server.proxy_protocol_trusted_cidrs, + "Rejecting PROXY protocol header from untrusted source" + ); + record_beobachten_class(&beobachten, &config, peer.ip(), "other"); + return Err(ProxyError::InvalidProxyProtocol); + } debug!( peer = %peer, client = %info.src_addr, @@ -149,8 +179,13 @@ where if is_tls { let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; - if tls_len < 512 { - debug!(peer = %real_peer, tls_len = tls_len, "TLS handshake too short"); +// RFC 8446 §5.1 mandates that TLSPlaintext records must not exceed 2^14 + // bytes (16_384). A client claiming a larger record is non-compliant and + // may be an active probe attempting to force large allocations. + // + // Also enforce a minimum record size to avoid trivial/garbage probes. + if !(512..=MAX_TLS_RECORD_SIZE).contains(&tls_len) { + debug!(peer = %real_peer, tls_len = tls_len, max_tls_len = MAX_TLS_RECORD_SIZE, "TLS handshake length out of bounds"); stats.increment_connects_bad(); let (reader, writer) = tokio::io::split(stream); handle_bad_client( @@ -204,9 +239,19 @@ where &config, &replay_checker, true, Some(tls_user.as_str()), ).await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader: _, writer: _ } => { + HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + handle_bad_client( + reader, + writer, + &mtproto_handshake, + real_peer, + local_addr, + &config, + &beobachten, + ) + .await; return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), @@ -445,6 +490,24 @@ impl RunningClientHandler { .await { Ok(Ok(info)) => { + if !is_trusted_proxy_source( + self.peer.ip(), + &self.config.server.proxy_protocol_trusted_cidrs, + ) { + self.stats.increment_connects_bad(); + warn!( + peer = %self.peer, + trusted = ?self.config.server.proxy_protocol_trusted_cidrs, + "Rejecting PROXY protocol header from untrusted source" + ); + record_beobachten_class( + &self.beobachten, + &self.config, + self.peer.ip(), + "other", + ); + return Err(ProxyError::InvalidProxyProtocol); + } debug!( peer = %self.peer, client = %info.src_addr, @@ -513,8 +576,10 @@ impl RunningClientHandler { debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); - if tls_len < 512 { - debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + // See RFC 8446 §5.1: TLSPlaintext records must not exceed 16_384 bytes. + // Treat too-small or too-large lengths as active probes and mask them. + if !(512..=MAX_TLS_RECORD_SIZE).contains(&tls_len) { + debug!(peer = %peer, tls_len = tls_len, max_tls_len = MAX_TLS_RECORD_SIZE, "TLS handshake length out of bounds"); self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client( @@ -590,12 +655,19 @@ impl RunningClientHandler { .await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { - reader: _, - writer: _, - } => { + HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + handle_bad_client( + reader, + writer, + &mtproto_handshake, + peer, + local_addr, + &config, + &self.beobachten, + ) + .await; return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), @@ -742,7 +814,7 @@ impl RunningClientHandler { client_writer, success, pool.clone(), - stats, + stats.clone(), config, buffer_pool, local_addr, @@ -759,7 +831,7 @@ impl RunningClientHandler { client_writer, success, upstream_manager, - stats, + stats.clone(), config, buffer_pool, rng, @@ -776,7 +848,7 @@ impl RunningClientHandler { client_writer, success, upstream_manager, - stats, + stats.clone(), config, buffer_pool, rng, @@ -787,6 +859,7 @@ impl RunningClientHandler { .await }; + stats.decrement_user_curr_connects(&user); ip_tracker.remove_ip(&user, peer_addr.ip()).await; relay_result } @@ -806,9 +879,29 @@ impl RunningClientHandler { }); } - let ip_reserved = match ip_tracker.check_and_add(user, peer_addr.ip()).await { - Ok(()) => true, + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let limit = config + .access + .user_max_tcp_conns + .get(user) + .map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => {} Err(reason) => { + stats.decrement_user_curr_connects(user); warn!( user = %user, ip = %peer_addr.ip(), @@ -819,33 +912,12 @@ impl RunningClientHandler { user: user.to_string(), }); } - }; - // IP limit check - - if let Some(limit) = config.access.user_max_tcp_conns.get(user) - && stats.get_user_curr_connects(user) >= *limit as u64 - { - if ip_reserved { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.increment_ip_reservation_rollback_tcp_limit_total(); - } - return Err(ProxyError::ConnectionLimitExceeded { - user: user.to_string(), - }); - } - - if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota - { - if ip_reserved { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.increment_ip_reservation_rollback_quota_limit_total(); - } - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); } Ok(()) } } + +#[cfg(test)] +#[path = "client_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs new file mode 100644 index 0000000..415cafd --- /dev/null +++ b/src/proxy/client_security_tests.rs @@ -0,0 +1,2071 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::tls; +use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +#[tokio::test] +async fn short_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10]; + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn handle_client_stream_increments_connects_all_exactly_once() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10]; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let before = stats.get_connects_all(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.177:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + drop(client_side); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + assert_eq!( + stats.get_connects_all(), + before + 1, + "handle_client_stream must increment connects_all exactly once" + ); +} + +#[tokio::test] +async fn running_client_handler_increments_connects_all_exactly_once() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [0x16, 0x03, 0x01, 0x00, 0x10]; + + let mask_accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let before = stats.get_connects_all(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + assert_eq!( + stats.get_connects_all(), + before + 1, + "ClientHandler::run must increment connects_all exactly once" + ); +} + +#[tokio::test] +async fn partial_tls_header_stall_triggers_handshake_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.170:55201".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]) + .await + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout))); +} + +fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + + let total_len = 5 + tls_len; + let mut handshake = vec![0x42u8; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec { + make_valid_tls_client_hello_with_len(secret, timestamp, 600) +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(0x16); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + record +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&[0x03, 0x03]); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +#[tokio::test] +async fn valid_tls_path_does_not_fall_back_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x11u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "11111111111111111111111111111111".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.80:55002".parse().unwrap(); + let stats_for_assert = stats.clone(); + let bad_before = stats_for_assert.get_connects_bad(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut record_header = [0u8; 5]; + client_side.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16); + + drop(client_side); + let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(handler_result.is_err()); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Mask backend must not be contacted on authenticated TLS path" + ); + + let bad_after = stats_for_assert.get_connects_bad(); + assert_eq!( + bad_before, + bad_after, + "Authenticated TLS path must not increment connects_bad" + ); +} + +#[tokio::test] +async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x33u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; + let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; invalid_mtproto.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, invalid_mtproto); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "33333333333333333333333333333333".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(32768); + let peer: SocketAddr = "198.51.100.90:55111".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&tls_app_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let secret = [0x44u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; + let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + + let mask_accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = vec![0u8; invalid_mtproto.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, invalid_mtproto); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "44444444444444444444444444444444".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client.write_all(&tls_app_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x66u8; 16]; + let probe = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.censorship.alpn_enforce = true; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "66666666666666666666666666666666".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.66:55211".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x77u8; 16]; + let mut probe = make_valid_tls_client_hello(&secret, 0); + probe[tls::TLS_DIGEST_POS] ^= 0x01; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "77777777777777777777777777777777".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.77:55212".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn burst_invalid_tls_probes_are_masked_verbatim() { + const N: usize = 12; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x88u8; 16]; + let mut probe = make_valid_tls_client_hello(&secret, 0); + probe[tls::TLS_DIGEST_POS + 1] ^= 0x01; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + for _ in 0..N { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + } + }); + + let mut handlers = Vec::with_capacity(N); + for i in 0..N { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "88888888888888888888888888888888".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = format!("198.51.100.{}:{}", 100 + i, 56000 + i) + .parse() + .unwrap(); + let probe_bytes = probe.clone(); + + let h = tokio::spawn(async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe_bytes).await.unwrap(); + drop(client_side); + + tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap() + .unwrap(); + }); + handlers.push(h); + } + + for h in handlers { + tokio::time::timeout(Duration::from_secs(5), h) + .await + .unwrap() + .unwrap(); + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn unexpected_eof_is_classified_without_string_matching() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof = ProxyError::Io(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)); + let peer_ip: IpAddr = "198.51.100.200".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!( + snapshot.contains("[expected_64_got_0]"), + "UnexpectedEof must be classified as expected_64_got_0" + ); + assert!( + snapshot.contains("198.51.100.200-1"), + "Classified record must include source IP" + ); +} + +#[test] +fn non_eof_error_is_classified_as_other() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let non_eof = ProxyError::Io(std::io::Error::other("different error")); + let peer_ip: IpAddr = "203.0.113.201".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &non_eof); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!( + snapshot.contains("[other]"), + "Non-EOF errors must map to other" + ); + assert!( + snapshot.contains("203.0.113.201-1"), + "Classified record must include source IP" + ); + assert!( + !snapshot.contains("[expected_64_got_0]"), + "Non-EOF errors must not be misclassified as expected_64_got_0" + ); +} + +#[tokio::test] +async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let stats = Stats::new(); + stats.increment_user_curr_connects("user"); + + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.210:50000".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Rejected client must not reserve IP slot" + ); + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 0, + "No rollback should occur when reservation is not taken" + ); +} + +#[tokio::test] +async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert("user".to_string(), 1024); + + let stats = Stats::new(); + stats.add_user_octets_from("user", 1024); + + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::DataQuotaExceeded { user }) if user == "user" + )); + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Quota-rejected client must not reserve IP slot" + ); + assert_eq!( + stats.get_ip_reservation_rollback_quota_limit_total(), + 0, + "No rollback should occur when reservation is not taken" + ); +} + +#[tokio::test] +async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() { + const PARALLEL_IPS: usize = 64; + const ATTEMPTS_PER_IP: usize = 8; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + stats.increment_user_curr_connects("user"); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..PARALLEL_IPS { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + + tasks.spawn(async move { + let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 100, (i + 1) as u8)); + for _ in 0..ATTEMPTS_PER_IP { + let peer_addr = SocketAddr::new(ip, 40000 + i as u16); + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + } + }); + } + + while let Some(joined) = tasks.join_next().await { + joined.unwrap(); + } + + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Concurrent rejected attempts must not leave active IP reservations" + ); + + let recent = ip_tracker + .get_recent_ips_for_users(&["user".to_string()]) + .await; + assert!( + recent + .get("user") + .map(|ips| ips.is_empty()) + .unwrap_or(true), + "Concurrent rejected attempts must not leave recent IP footprint" + ); + + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 0, + "No rollback should occur under concurrent rejection storms" + ); +} + +#[tokio::test] +async fn atomic_limit_gate_allows_only_one_concurrent_acquire() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..64u16 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (i + 1) as u8)), + 30000 + i, + ); + RunningClientHandler::check_user_limits_static("user", &config, &stats, peer, &ip_tracker) + .await + .is_ok() + }); + } + + let mut successes = 0u64; + while let Some(joined) = tasks.join_next().await { + if joined.unwrap() { + successes += 1; + } + } + + assert_eq!( + successes, 1, + "exactly one concurrent acquire must pass for a limit=1 user" + ); + assert_eq!(stats.get_user_curr_connects("user"), 1); +} + +#[tokio::test] +async fn untrusted_proxy_header_source_is_rejected() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs = vec!["10.10.0.0/16".parse().unwrap()]; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: SocketAddr = "198.51.100.44:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); +} + +#[tokio::test] +async fn empty_proxy_trusted_cidrs_rejects_proxy_header_by_default() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs.clear(); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: SocketAddr = "198.51.100.45:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); +} + +#[tokio::test] +async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = [ + 0x16, + 0x03, + 0x01, + (((MAX_TLS_RECORD_SIZE + 1) >> 8) & 0xff) as u8, + ((MAX_TLS_RECORD_SIZE + 1) & 0xff) as u8, + ]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.123:55123".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); + + assert_eq!( + stats.get_connects_bad(), + bad_before + 1, + "Oversized TLS probe must be classified as bad" + ); +} + +#[tokio::test] +async fn oversized_tls_record_is_masked_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [ + 0x16, + 0x03, + 0x01, + (((MAX_TLS_RECORD_SIZE + 1) >> 8) & 0xff) as u8, + ((MAX_TLS_RECORD_SIZE + 1) & 0xff) as u8, + ]; + let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + client.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_record_len_511_is_rejected_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = [0x16, 0x03, 0x01, 0x01, 0xff]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.130:55130".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); + + assert_eq!( + stats.get_connects_bad(), + bad_before + 1, + "TLS record length 511 must be rejected" + ); +} + +#[tokio::test] +async fn tls_record_len_511_is_rejected_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [0x16, 0x03, 0x01, 0x01, 0xff]; + let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + client.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x55u8; 16]; + let client_hello = make_valid_tls_client_hello_with_len(&secret, 0, MAX_TLS_RECORD_SIZE); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "55555555555555555555555555555555".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.55:56055".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut record_header = [0u8; 5]; + client_side.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); + + drop(client_side); + let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(handler_result.is_err()); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Valid max-length ClientHello must not trigger mask fallback" + ); + + assert_eq!( + bad_before, + stats.get_connects_bad(), + "Valid max-length ClientHello must not increment bad counter" + ); +} + +#[tokio::test] +async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let secret = [0x66u8; 16]; + let client_hello = make_valid_tls_client_hello_with_len(&secret, 0, MAX_TLS_RECORD_SIZE); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "66666666666666666666666666666666".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&client_hello).await.unwrap(); + + let mut record_header = [0u8; 5]; + client.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Valid max-length ClientHello must not trigger mask fallback in ClientHandler path" + ); + + assert_eq!( + bad_before, + stats.get_connects_bad(), + "Valid max-length ClientHello must not increment bad counter" + ); +} diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 7a7810a..9c6116c 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -2,6 +2,8 @@ use std::fs::OpenOptions; use std::io::Write; use std::net::SocketAddr; use std::sync::Arc; +use std::collections::HashSet; +use std::sync::{Mutex, OnceLock}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; @@ -22,6 +24,45 @@ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; +const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; +static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); + +// In tests, this function shares global mutable state. Callers that also use +// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions +// deterministic under parallel execution. +fn should_log_unknown_dc(dc_idx: i16) -> bool { + let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new())); + match set.lock() { + Ok(mut guard) => { + if guard.contains(&dc_idx) { + return false; + } + if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + return false; + } + guard.insert(dc_idx) + } + // If the lock is poisoned, keep logging rather than silently dropping + // operator-visible diagnostics. + Err(_) => true, + } +} + +#[cfg(test)] +fn clear_unknown_dc_log_cache_for_testing() { + if let Some(set) = LOGGED_UNKNOWN_DCS.get() + && let Ok(mut guard) = set.lock() + { + guard.clear(); + } +} + +#[cfg(test)] +fn unknown_dc_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + pub(crate) async fn handle_via_direct( client_reader: CryptoReader, client_writer: CryptoWriter, @@ -64,7 +105,6 @@ where debug!(peer = %success.peer, "TG handshake complete, starting relay"); stats.increment_user_connects(user); - stats.increment_user_curr_connects(user); stats.increment_current_connections_direct(); let relay_result = relay_bidirectional( @@ -109,7 +149,6 @@ where }; stats.decrement_current_connections_direct(); - stats.decrement_user_curr_connects(user); match &relay_result { Ok(()) => debug!(user = %user, "Direct relay completed"), @@ -160,6 +199,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster"); if config.general.unknown_dc_file_log_enabled && let Some(path) = &config.general.unknown_dc_log_path + && should_log_unknown_dc(dc_idx) && let Ok(handle) = tokio::runtime::Handle::try_current() { let path = path.clone(); @@ -175,7 +215,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { default_dc - 1 } else { - 1 + 0 }; info!( @@ -203,8 +243,6 @@ async fn do_tg_handshake_static( let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( success.proto_tag, success.dc_idx, - &success.dec_key, - success.dec_iv, &success.enc_key, success.enc_iv, rng, @@ -230,3 +268,7 @@ async fn do_tg_handshake_static( CryptoWriter::new(write_half, tg_encryptor, max_pending), )) } + +#[cfg(test)] +#[path = "direct_relay_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs new file mode 100644 index 0000000..3b3185a --- /dev/null +++ b/src/proxy/direct_relay_security_tests.rs @@ -0,0 +1,51 @@ +use super::*; + +#[test] +fn unknown_dc_log_is_deduplicated_per_dc_idx() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + assert!(should_log_unknown_dc(777)); + assert!( + !should_log_unknown_dc(777), + "same unknown dc_idx must not be logged repeatedly" + ); + assert!( + should_log_unknown_dc(778), + "different unknown dc_idx must still be loggable" + ); +} + +#[test] +fn unknown_dc_log_respects_distinct_limit() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + for dc in 1..=UNKNOWN_DC_LOG_DISTINCT_LIMIT { + assert!( + should_log_unknown_dc(dc as i16), + "expected first-time unknown dc_idx to be loggable" + ); + } + + assert!( + !should_log_unknown_dc(i16::MAX), + "distinct unknown dc_idx entries above limit must not be logged" + ); +} + +#[test] +fn fallback_dc_never_panics_with_single_dc_list() { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.default_dc = Some(42); + + let addr = get_dc_addr_static(999, &cfg).expect("fallback dc must resolve safely"); + let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); + assert_eq!(addr, expected); +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 296432f..ef98144 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -3,8 +3,12 @@ #![allow(dead_code)] use std::net::SocketAddr; +use std::collections::HashSet; +use std::net::IpAddr; use std::sync::Arc; -use std::time::Duration; +use std::sync::{Mutex, OnceLock}; +use std::time::{Duration, Instant}; +use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace}; use zeroize::Zeroize; @@ -19,6 +23,231 @@ use crate::stats::ReplayChecker; use crate::config::ProxyConfig; use crate::tls_front::{TlsFrontCache, emulator}; +const ACCESS_SECRET_BYTES: usize = 16; +static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); + +const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60; +#[cfg(test)] +const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256; +#[cfg(not(test))] +const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536; +const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024; +const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4; + +#[cfg(test)] +const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1; +#[cfg(not(test))] +const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 25; + +#[cfg(test)] +const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16; +#[cfg(not(test))] +const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000; + +#[derive(Clone, Copy)] +struct AuthProbeState { + fail_streak: u32, + blocked_until: Instant, + last_seen: Instant, +} + +static AUTH_PROBE_STATE: OnceLock> = OnceLock::new(); + +fn auth_probe_state_map() -> &'static DashMap { + AUTH_PROBE_STATE.get_or_init(DashMap::new) +} + +fn auth_probe_backoff(fail_streak: u32) -> Duration { + if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS { + return Duration::ZERO; + } + let shift = (fail_streak - AUTH_PROBE_BACKOFF_START_FAILS).min(10); + let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX); + let ms = AUTH_PROBE_BACKOFF_BASE_MS + .saturating_mul(multiplier) + .min(AUTH_PROBE_BACKOFF_MAX_MS); + Duration::from_millis(ms) +} + +fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool { + let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS); + now.duration_since(state.last_seen) > retention +} + +fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { + let state = auth_probe_state_map(); + let Some(entry) = state.get(&peer_ip) else { + return false; + }; + if auth_probe_state_expired(&entry, now) { + drop(entry); + state.remove(&peer_ip); + return false; + } + now < entry.blocked_until +} + +fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { + let state = auth_probe_state_map(); + auth_probe_record_failure_with_state(state, peer_ip, now); +} + +fn auth_probe_record_failure_with_state( + state: &DashMap, + peer_ip: IpAddr, + now: Instant, +) { + if let Some(mut entry) = state.get_mut(&peer_ip) { + if auth_probe_state_expired(&entry, now) { + *entry = AuthProbeState { + fail_streak: 1, + blocked_until: now + auth_probe_backoff(1), + last_seen: now, + }; + return; + } + entry.fail_streak = entry.fail_streak.saturating_add(1); + entry.last_seen = now; + entry.blocked_until = now + auth_probe_backoff(entry.fail_streak); + return; + }; + + if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + let mut stale_keys = Vec::new(); + for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(*entry.key()); + } + } + for stale_key in stale_keys { + state.remove(&stale_key); + } + if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + return; + } + } + + state.insert(peer_ip, AuthProbeState { + fail_streak: 0, + blocked_until: now, + last_seen: now, + }); + + if let Some(mut entry) = state.get_mut(&peer_ip) { + entry.fail_streak = 1; + entry.blocked_until = now + auth_probe_backoff(1); + } +} + +fn auth_probe_record_success(peer_ip: IpAddr) { + let state = auth_probe_state_map(); + state.remove(&peer_ip); +} + +#[cfg(test)] +fn clear_auth_probe_state_for_testing() { + if let Some(state) = AUTH_PROBE_STATE.get() { + state.clear(); + } +} + +#[cfg(test)] +fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option { + let state = AUTH_PROBE_STATE.get()?; + state.get(&peer_ip).map(|entry| entry.fail_streak) +} + +#[cfg(test)] +fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool { + auth_probe_is_throttled(peer_ip, Instant::now()) +} + +#[cfg(test)] +fn auth_probe_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn clear_warned_secrets_for_testing() { + if let Some(warned) = INVALID_SECRET_WARNED.get() + && let Ok(mut guard) = warned.lock() + { + guard.clear(); + } +} + +fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option) { + let key = (name.to_string(), reason.to_string()); + let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new())); + let should_warn = match warned.lock() { + Ok(mut guard) => guard.insert(key), + Err(_) => true, + }; + + if !should_warn { + return; + } + + match got { + Some(actual) => { + warn!( + user = %name, + expected = expected, + got = actual, + "Skipping user: access secret has unexpected length" + ); + } + None => { + warn!( + user = %name, + "Skipping user: access secret is not valid hex" + ); + } + } +} + +fn decode_user_secret(name: &str, secret_hex: &str) -> Option> { + match hex::decode(secret_hex) { + Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes), + Ok(bytes) => { + warn_invalid_secret_once( + name, + "invalid_length", + ACCESS_SECRET_BYTES, + Some(bytes.len()), + ); + None + } + Err(_) => { + warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None); + None + } + } +} + +// Decide whether a client-supplied proto tag is allowed given the configured +// proxy modes and the transport that carried the handshake. +// +// A common mistake is to treat `modes.tls` and `modes.secure` as interchangeable +// even though they correspond to different transport profiles: `modes.tls` is +// for the TLS-fronted (EE-TLS) path, while `modes.secure` is for direct MTProto +// over TCP (DD). Enforcing this separation prevents an attacker from using a +// TLS-capable client to bypass the operator intent for the direct MTProto mode, +// and vice versa. +fn mode_enabled_for_proto(config: &ProxyConfig, proto_tag: ProtoTag, is_tls: bool) -> bool { + match proto_tag { + ProtoTag::Secure => { + if is_tls { + config.general.modes.tls + } else { + config.general.modes.secure + } + } + ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, + } +} + fn decode_user_secrets( config: &ProxyConfig, preferred_user: Option<&str>, @@ -27,7 +256,7 @@ fn decode_user_secrets( if let Some(preferred) = preferred_user && let Some(secret_hex) = config.access.users.get(preferred) - && let Ok(bytes) = hex::decode(secret_hex) + && let Some(bytes) = decode_user_secret(preferred, secret_hex) { secrets.push((preferred.to_string(), bytes)); } @@ -36,7 +265,7 @@ fn decode_user_secrets( if preferred_user.is_some_and(|preferred| preferred == name.as_str()) { continue; } - if let Ok(bytes) = hex::decode(secret_hex) { + if let Some(bytes) = decode_user_secret(name, secret_hex) { secrets.push((name.clone(), bytes)); } } @@ -48,7 +277,7 @@ fn decode_user_secrets( /// /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is /// zeroized on drop. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct HandshakeSuccess { /// Authenticated user name pub user: String, @@ -94,28 +323,27 @@ where { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); + if auth_probe_is_throttled(peer.ip(), Instant::now()) { + debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle"); + return HandshakeResult::BadClient { reader, writer }; + } + if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; } - let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; - let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; - - if replay_checker.check_and_add_tls_digest(digest_half) { - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - let secrets = decode_user_secrets(config, None); - let validation = match tls::validate_tls_handshake( + let validation = match tls::validate_tls_handshake_with_replay_window( handshake, &secrets, config.access.ignore_time_skew, + config.access.replay_window_secs, ) { Some(v) => v, None => { + auth_probe_record_failure(peer.ip(), Instant::now()); debug!( peer = %peer, ignore_time_skew = config.access.ignore_time_skew, @@ -125,6 +353,15 @@ where } }; + // Replay tracking is applied only after successful authentication to avoid + // letting unauthenticated probes evict valid entries from the replay cache. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_and_add_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => return HandshakeResult::BadClient { reader, writer }, @@ -166,6 +403,9 @@ where Some(b"h2".to_vec()) } else if alpn_list.iter().any(|p| p == b"http/1.1") { Some(b"http/1.1".to_vec()) + } else if !alpn_list.is_empty() { + debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); + return HandshakeResult::BadClient { reader, writer }; } else { None } @@ -228,6 +468,8 @@ where "TLS handshake successful" ); + auth_probe_record_success(peer.ip()); + HandshakeResult::Success(( FakeTlsReader::new(reader), FakeTlsWriter::new(writer), @@ -252,13 +494,13 @@ where { trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); - let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; - - if replay_checker.check_and_add_handshake(dec_prekey_iv) { - warn!(peer = %peer, "MTProto replay attack detected"); + if auth_probe_is_throttled(peer.ip(), Instant::now()) { + debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; } + let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); let decoded_users = decode_user_secrets(config, preferred_user); @@ -273,39 +515,33 @@ where dec_key_input.extend_from_slice(&secret); let dec_key = sha256(&dec_key_input); - let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut decryptor = AesCtr::new(&dec_key, dec_iv); let decrypted = decryptor.decrypt(handshake); - let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] - .try_into() - .unwrap(); + let tag_bytes: [u8; 4] = [ + decrypted[PROTO_TAG_POS], + decrypted[PROTO_TAG_POS + 1], + decrypted[PROTO_TAG_POS + 2], + decrypted[PROTO_TAG_POS + 3], + ]; let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, None => continue, }; - let mode_ok = match proto_tag { - ProtoTag::Secure => { - if is_tls { - config.general.modes.tls || config.general.modes.secure - } else { - config.general.modes.secure || config.general.modes.tls - } - } - ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, - }; + let mode_ok = mode_enabled_for_proto(config, proto_tag, is_tls); if !mode_ok { debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled"); continue; } - let dc_idx = i16::from_le_bytes( - decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() - ); + let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]); let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; @@ -315,10 +551,24 @@ where enc_key_input.extend_from_slice(&secret); let enc_key = sha256(&enc_key_input); - let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(enc_iv_bytes); + let enc_iv = u128::from_be_bytes(enc_iv_arr); let encryptor = AesCtr::new(&enc_key, enc_iv); +// Apply replay tracking only after successful authentication. + // + // This ordering prevents an attacker from producing invalid handshakes that + // still collide with a valid handshake's replay slot and thus evict a valid + // entry from the cache. We accept the cost of performing the full + // authentication check first to avoid poisoning the replay cache. + if replay_checker.check_and_add_handshake(dec_prekey_iv) { + auth_probe_record_failure(peer.ip(), Instant::now()); + warn!(peer = %peer, user = %user, "MTProto replay attack detected"); + return HandshakeResult::BadClient { reader, writer }; + } + let success = HandshakeSuccess { user: user.clone(), dc_idx, @@ -340,6 +590,8 @@ where "MTProto handshake successful" ); + auth_probe_record_success(peer.ip()); + let max_pending = config.general.crypto_pending_buffer; return HandshakeResult::Success(( CryptoReader::new(reader, decryptor), @@ -348,6 +600,7 @@ where )); } + auth_probe_record_failure(peer.ip(), Instant::now()); debug!(peer = %peer, "MTProto handshake: no matching user found"); HandshakeResult::BadClient { reader, writer } } @@ -356,8 +609,6 @@ where pub fn generate_tg_nonce( proto_tag: ProtoTag, dc_idx: i16, - _client_dec_key: &[u8; 32], - _client_dec_iv: u128, client_enc_key: &[u8; 32], client_enc_iv: u128, rng: &SecureRandom, @@ -365,14 +616,16 @@ pub fn generate_tg_nonce( ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { loop { let bytes = rng.bytes(HANDSHAKE_LEN); - let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); + let Ok(mut nonce): Result<[u8; HANDSHAKE_LEN], _> = bytes.try_into() else { + continue; + }; if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } - let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; } - let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); + let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]]; if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; } nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); @@ -390,11 +643,17 @@ pub fn generate_tg_nonce( let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); - let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + let mut tg_enc_key = [0u8; 32]; + tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut tg_enc_iv_arr = [0u8; IV_LEN]; + tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let tg_enc_iv = u128::from_be_bytes(tg_enc_iv_arr); - let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); - let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + let mut tg_dec_key = [0u8; 32]; + tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut tg_dec_iv_arr = [0u8; IV_LEN]; + tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let tg_dec_iv = u128::from_be_bytes(tg_dec_iv_arr); return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); } @@ -405,11 +664,17 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, A let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); - let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + let mut enc_key = [0u8; 32]; + enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let enc_iv = u128::from_be_bytes(enc_iv_arr); - let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); - let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + let mut dec_key = [0u8; 32]; + dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut encryptor = AesCtr::new(&enc_key, enc_iv); let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 @@ -429,80 +694,15 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { } #[cfg(test)] -mod tests { - use super::*; +#[path = "handshake_security_tests.rs"] +mod security_tests; - #[test] - fn test_generate_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; - let client_enc_key = [0x24u8; 32]; - let client_enc_iv = 54321u128; +/// Compile-time guard: HandshakeSuccess holds cryptographic key material and +/// must never be Copy. A Copy impl would allow silent key duplication, +/// undermining the zeroize-on-drop guarantee. +mod compile_time_security_checks { + use super::HandshakeSuccess; + use static_assertions::assert_not_impl_all; - let rng = SecureRandom::new(); - let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = - generate_tg_nonce( - ProtoTag::Secure, - 2, - &client_dec_key, - client_dec_iv, - &client_enc_key, - client_enc_iv, - &rng, - false, - ); - - assert_eq!(nonce.len(), HANDSHAKE_LEN); - - let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); - assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); - } - - #[test] - fn test_encrypt_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; - let client_enc_key = [0x24u8; 32]; - let client_enc_iv = 54321u128; - - let rng = SecureRandom::new(); - let (nonce, _, _, _, _) = - generate_tg_nonce( - ProtoTag::Secure, - 2, - &client_dec_key, - client_dec_iv, - &client_enc_key, - client_enc_iv, - &rng, - false, - ); - - let encrypted = encrypt_tg_nonce(&nonce); - - assert_eq!(encrypted.len(), HANDSHAKE_LEN); - assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); - assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); - } - - #[test] - fn test_handshake_success_zeroize_on_drop() { - let success = HandshakeSuccess { - user: "test".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Secure, - dec_key: [0xAA; 32], - dec_iv: 0xBBBBBBBB, - enc_key: [0xCC; 32], - enc_iv: 0xDDDDDDDD, - peer: "127.0.0.1:1234".parse().unwrap(), - is_tls: true, - }; - - assert_eq!(success.dec_key, [0xAA; 32]); - assert_eq!(success.enc_key, [0xCC; 32]); - - drop(success); - // Drop impl zeroizes key material without panic - } + assert_not_impl_all!(HandshakeSuccess: Copy, Clone); } diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs new file mode 100644 index 0000000..f2d7d03 --- /dev/null +++ b/src/proxy/handshake_security_tests.rs @@ -0,0 +1,891 @@ +use super::*; +use crate::crypto::sha256_hmac; +use dashmap::DashMap; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + record +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + clear_auth_probe_state_for_testing(); + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg +} + +#[test] +fn test_generate_tg_nonce() { + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert_eq!(nonce.len(), HANDSHAKE_LEN); + + let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); + assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); +} + +#[test] +fn test_encrypt_tg_nonce() { + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let encrypted = encrypt_tg_nonce(&nonce); + + assert_eq!(encrypted.len(), HANDSHAKE_LEN); + assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); + assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); +} + +#[test] +fn test_handshake_success_drop_does_not_panic() { + let success = HandshakeSuccess { + user: "test".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Secure, + dec_key: [0xAA; 32], + dec_iv: 0xBBBBBBBB, + enc_key: [0xCC; 32], + enc_iv: 0xDDDDDDDD, + peer: "198.51.100.10:1234".parse().unwrap(), + is_tls: true, + }; + + assert_eq!(success.dec_key, [0xAA; 32]); + assert_eq!(success.enc_key, [0xCC; 32]); + + drop(success); +} + +#[test] +fn test_generate_tg_nonce_enc_dec_material_is_consistent() { + let client_enc_key = [0x34u8; 32]; + let client_enc_iv = 0xffeeddccbbaa00998877665544332211u128; + let rng = SecureRandom::new(); + + let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( + ProtoTag::Secure, + 7, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + + let mut expected_tg_enc_key = [0u8; 32]; + expected_tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_tg_enc_iv_arr = [0u8; IV_LEN]; + expected_tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_tg_enc_iv = u128::from_be_bytes(expected_tg_enc_iv_arr); + + let mut expected_tg_dec_key = [0u8; 32]; + expected_tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut expected_tg_dec_iv_arr = [0u8; IV_LEN]; + expected_tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let expected_tg_dec_iv = u128::from_be_bytes(expected_tg_dec_iv_arr); + + assert_eq!(tg_enc_key, expected_tg_enc_key); + assert_eq!(tg_enc_iv, expected_tg_enc_iv); + assert_eq!(tg_dec_key, expected_tg_dec_key); + assert_eq!(tg_dec_iv, expected_tg_dec_iv); + assert_eq!( + i16::from_le_bytes([nonce[DC_IDX_POS], nonce[DC_IDX_POS + 1]]), + 7, + "Generated nonce must keep target dc index in protocol slot" + ); +} + +#[test] +fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() { + let client_enc_key = [0xABu8; 32]; + let client_enc_iv = 0x11223344556677889900aabbccddeeffu128; + let rng = SecureRandom::new(); + + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 9, + &client_enc_key, + client_enc_iv, + &rng, + true, + ); + + let mut expected = Vec::with_capacity(KEY_LEN + IV_LEN); + expected.extend_from_slice(&client_enc_key); + expected.extend_from_slice(&client_enc_iv.to_be_bytes()); + expected.reverse(); + + assert_eq!(&nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN], expected.as_slice()); +} + +#[test] +fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() { + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(&nonce); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let mut expected_enc_key = [0u8; 32]; + expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_enc_iv_arr = [0u8; IV_LEN]; + expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr); + + let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv); + let manual = manual_encryptor.encrypt(&nonce); + + assert_eq!(encrypted.len(), HANDSHAKE_LEN); + assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); + assert_eq!( + &encrypted[PROTO_TAG_POS..], + &manual[PROTO_TAG_POS..], + "Encrypted nonce suffix must match AES-CTR output with derived enc key/iv" + ); +} + +#[tokio::test] +async fn tls_replay_second_identical_handshake_is_rejected() { + let secret = [0x11u8; 16]; + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.21:44321".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 0); + + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let second = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(second, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { + let secret = [0x77u8; 16]; + let config = Arc::new(test_config_with_secret_hex("77777777777777777777777777777777")); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let handshake = Arc::new(make_valid_tls_handshake(&secret, 0)); + + let mut tasks = Vec::new(); + for _ in 0..50 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let handshake = handshake.clone(); + tasks.push(tokio::spawn(async move { + handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.22:45000".parse().unwrap(), + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let mut success_count = 0usize; + for task in tasks { + let result = task.await.unwrap(); + if matches!(result, HandshakeResult::Success(_)) { + success_count += 1; + } else { + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + } + + assert_eq!( + success_count, 1, + "Concurrent replay attempts must allow exactly one successful handshake" + ); +} + +#[tokio::test] +async fn invalid_tls_probe_does_not_pollute_replay_cache() { + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.23:44322".parse().unwrap(); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + let before = replay_checker.stats(); + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let after = replay_checker.stats(); + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(before.total_additions, after.total_additions); + assert_eq!(before.total_hits, after.total_hits); +} + +#[tokio::test] +async fn empty_decoded_secret_is_rejected() { + clear_warned_secrets_for_testing(); + let config = test_config_with_secret_hex(""); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.24:44323".parse().unwrap(); + let handshake = make_valid_tls_handshake(&[], 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn wrong_length_decoded_secret_is_rejected() { + clear_warned_secrets_for_testing(); + let config = test_config_with_secret_hex("aa"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.25:44324".parse().unwrap(); + let handshake = make_valid_tls_handshake(&[0xaau8], 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn invalid_mtproto_probe_does_not_pollute_replay_cache() { + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.26:44325".parse().unwrap(); + let handshake = [0u8; HANDSHAKE_LEN]; + + let before = replay_checker.stats(); + let result = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + let after = replay_checker.stats(); + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(before.total_additions, after.total_additions); + assert_eq!(before.total_hits, after.total_hits); +} + +#[tokio::test] +async fn mixed_secret_lengths_keep_valid_user_authenticating() { + clear_warned_secrets_for_testing(); + clear_auth_probe_state_for_testing(); + let good_secret = [0x22u8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config + .access + .users + .insert("broken_user".to_string(), "aa".to_string()); + config + .access + .users + .insert("valid_user".to_string(), "22222222222222222222222222222222".to_string()); + config.access.ignore_time_skew = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.27:44326".parse().unwrap(); + let handshake = make_valid_tls_handshake(&good_secret, 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + +#[tokio::test] +async fn alpn_enforce_rejects_unsupported_client_alpn() { + let secret = [0x33u8; 16]; + let mut config = test_config_with_secret_hex("33333333333333333333333333333333"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.28:44327".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn alpn_enforce_accepts_h2() { + let secret = [0x44u8; 16]; + let mut config = test_config_with_secret_hex("44444444444444444444444444444444"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.29:44328".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2", b"h3"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + +#[tokio::test] +async fn malformed_tls_classes_complete_within_bounded_time() { + let secret = [0x55u8; 16]; + let mut config = test_config_with_secret_hex("55555555555555555555555555555555"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(512, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.30:44329".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01; + + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + for probe in [too_short, bad_hmac, alpn_mismatch] { + let result = tokio::time::timeout( + Duration::from_millis(200), + handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ), + ) + .await + .expect("Malformed TLS classes must be rejected within bounded time"); + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } +} + +#[tokio::test] +#[ignore = "timing-sensitive; run manually on low-jitter hosts"] +async fn malformed_tls_classes_share_close_latency_buckets() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 10; + + let secret = [0x99u8; 16]; + let mut config = test_config_with_secret_hex("99999999999999999999999999999999"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.31:44330".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01; + + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let mut class_means_ms = Vec::new(); + for probe in [too_short, bad_hmac, alpn_mismatch] { + let mut sum_micros: u128 = 0; + for _ in 0..ITER { + let started = Instant::now(); + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = started.elapsed(); + assert!(matches!(result, HandshakeResult::BadClient { .. })); + sum_micros += elapsed.as_micros(); + } + + class_means_ms.push(sum_micros / ITER as u128 / 1_000); + } + + let min_bucket = class_means_ms + .iter() + .map(|ms| ms / BUCKET_MS) + .min() + .unwrap(); + let max_bucket = class_means_ms + .iter() + .map(|ms| ms / BUCKET_MS) + .max() + .unwrap(); + + assert!( + max_bucket <= min_bucket + 1, + "Malformed TLS classes diverged across latency buckets: means_ms={:?}", + class_means_ms + ); +} + +#[test] +fn secure_tag_requires_tls_mode_on_tls_transport() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = true; + config.general.modes.tls = false; + + assert!( + !mode_enabled_for_proto(&config, ProtoTag::Secure, true), + "Secure tag over TLS must be rejected when tls mode is disabled" + ); + + config.general.modes.tls = true; + assert!( + mode_enabled_for_proto(&config, ProtoTag::Secure, true), + "Secure tag over TLS must be accepted when tls mode is enabled" + ); +} + +#[test] +fn secure_tag_requires_secure_mode_on_direct_transport() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = false; + config.general.modes.tls = true; + + assert!( + !mode_enabled_for_proto(&config, ProtoTag::Secure, false), + "Secure tag without TLS must be rejected when secure mode is disabled" + ); + + config.general.modes.secure = true; + assert!( + mode_enabled_for_proto(&config, ProtoTag::Secure, false), + "Secure tag without TLS must be accepted when secure mode is enabled" + ); +} + +#[test] +fn mode_policy_matrix_is_stable_for_all_tag_transport_mode_combinations() { + let tags = [ProtoTag::Secure, ProtoTag::Intermediate, ProtoTag::Abridged]; + + for classic in [false, true] { + for secure in [false, true] { + for tls in [false, true] { + let mut config = ProxyConfig::default(); + config.general.modes.classic = classic; + config.general.modes.secure = secure; + config.general.modes.tls = tls; + + for is_tls in [false, true] { + for tag in tags { + let expected = match (tag, is_tls) { + (ProtoTag::Secure, true) => tls, + (ProtoTag::Secure, false) => secure, + (ProtoTag::Intermediate | ProtoTag::Abridged, _) => classic, + }; + + assert_eq!( + mode_enabled_for_proto(&config, tag, is_tls), + expected, + "mode policy drifted for tag={:?}, transport_tls={}, modes=(classic={}, secure={}, tls={})", + tag, + is_tls, + classic, + secure, + tls + ); + } + } + } + } + } +} + +#[test] +fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() { + clear_warned_secrets_for_testing(); + + warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1)); + warn_invalid_secret_once("a", "b:c", ACCESS_SECRET_BYTES, Some(2)); + + let warned = INVALID_SECRET_WARNED + .get() + .expect("warned set must be initialized"); + let guard = warned.lock().expect("warned set lock must be available"); + assert_eq!( + guard.len(), + 2, + "(name, reason) pairs that stringify to the same colon-joined key must remain distinct" + ); +} + +#[tokio::test] +async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.61:44361".parse().unwrap(); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert!( + auth_probe_is_throttled_for_testing(peer.ip()), + "invalid probe burst must activate per-IP pre-auth throttle" + ); +} + +#[tokio::test] +async fn successful_tls_handshake_clears_pre_auth_failure_streak() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x23u8; 16]; + let config = test_config_with_secret_hex("23232323232323232323232323232323"); + let replay_checker = ReplayChecker::new(256, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.62:44362".parse().unwrap(); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(expected), + "failure streak must grow before a successful authentication" + ); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let success = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(success, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful authentication must clear accumulated pre-auth failures" + ); +} + +#[test] +fn auth_probe_capacity_prunes_stale_entries_for_new_ips() { + let state = DashMap::new(); + let now = Instant::now(); + let stale_seen = now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 1, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: stale_seen, + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert_eq!( + state.get(&newcomer).map(|entry| entry.fail_streak), + Some(1), + "stale-entry pruning must admit and track a new probe source" + ); + assert!( + state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map must remain bounded after stale pruning" + ); +} + +#[test] +fn auth_probe_capacity_stays_fail_closed_when_map_is_fresh_and_full() { + let state = DashMap::new(); + let now = Instant::now(); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 16, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now, + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 55)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!( + state.get(&newcomer).is_none(), + "when all entries are fresh and full, new probes must not be admitted" + ); + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map must stay at the configured cap" + ); +} diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 318071b..e347d73 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -14,12 +14,41 @@ use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; +#[cfg(not(test))] const MASK_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const MASK_TIMEOUT: Duration = Duration::from_millis(50); /// Maximum duration for the entire masking relay. /// Limits resource consumption from slow-loris attacks and port scanners. +#[cfg(not(test))] const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60); +#[cfg(test)] +const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200); const MASK_BUFFER_SIZE: usize = 8192; +async fn write_proxy_header_with_timeout(mask_write: &mut W, header: &[u8]) -> bool +where + W: AsyncWrite + Unpin, +{ + match timeout(MASK_TIMEOUT, mask_write.write_all(header)).await { + Ok(Ok(())) => true, + Ok(Err(_)) => false, + Err(_) => { + debug!("Timeout writing proxy protocol header to mask backend"); + false + } + } +} + +async fn consume_client_data_with_timeout(reader: R) +where + R: AsyncRead + Unpin, +{ + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() { + debug!("Timed out while consuming client data on masking fallback path"); + } +} + /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request @@ -71,7 +100,7 @@ where if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; return; } @@ -107,7 +136,7 @@ where } }; if let Some(header) = proxy_header { - if mask_write.write_all(&header).await.is_err() { + if !write_proxy_header_with_timeout(&mut mask_write, &header).await { return; } } @@ -117,11 +146,11 @@ where } Ok(Err(e)) => { debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; } } return; @@ -166,7 +195,7 @@ where let (mask_read, mut mask_write) = stream.into_split(); if let Some(header) = proxy_header { - if mask_write.write_all(&header).await.is_err() { + if !write_proxy_header_with_timeout(&mut mask_write, &header).await { return; } } @@ -176,11 +205,11 @@ where } Ok(Err(e)) => { debug!(error = %e, "Failed to connect to mask host"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; } } } @@ -194,55 +223,51 @@ async fn relay_to_mask( initial_data: &[u8], ) where - R: AsyncRead + Unpin + Send + 'static, - W: AsyncWrite + Unpin + Send + 'static, - MR: AsyncRead + Unpin + Send + 'static, - MW: AsyncWrite + Unpin + Send + 'static, + R: AsyncRead + Unpin + Send, + W: AsyncWrite + Unpin + Send, + MR: AsyncRead + Unpin + Send, + MW: AsyncWrite + Unpin + Send, { // Send initial data to mask host if mask_write.write_all(initial_data).await.is_err() { return; } + if mask_write.flush().await.is_err() { + return; + } - // Relay traffic - let c2m = tokio::spawn(async move { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - loop { - match reader.read(&mut buf).await { - Ok(0) | Err(_) => { - let _ = mask_write.shutdown().await; - break; - } - Ok(n) => { - if mask_write.write_all(&buf[..n]).await.is_err() { + let mut client_buf = vec![0u8; MASK_BUFFER_SIZE]; + let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE]; + + loop { + tokio::select! { + client_read = reader.read(&mut client_buf) => { + match client_read { + Ok(0) | Err(_) => { + let _ = mask_write.shutdown().await; break; } + Ok(n) => { + if mask_write.write_all(&client_buf[..n]).await.is_err() { + break; + } + } + } + } + mask_read_res = mask_read.read(&mut mask_buf) => { + match mask_read_res { + Ok(0) | Err(_) => { + let _ = writer.shutdown().await; + break; + } + Ok(n) => { + if writer.write_all(&mask_buf[..n]).await.is_err() { + break; + } + } } } } - }); - - let m2c = tokio::spawn(async move { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - loop { - match mask_read.read(&mut buf).await { - Ok(0) | Err(_) => { - let _ = writer.shutdown().await; - break; - } - Ok(n) => { - if writer.write_all(&buf[..n]).await.is_err() { - break; - } - } - } - } - }); - - // Wait for either to complete - tokio::select! { - _ = c2m => {} - _ = m2c => {} } } @@ -255,3 +280,7 @@ async fn consume_client_data(mut reader: R) { } } } + +#[cfg(test)] +#[path = "masking_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs new file mode 100644 index 0000000..52e9f69 --- /dev/null +++ b/src/proxy/masking_security_tests.rs @@ -0,0 +1,550 @@ +use super::*; +use crate::config::ProxyConfig; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{duplex, AsyncBufReadExt, BufReader}; +use tokio::net::TcpListener; +#[cfg(unix)] +use tokio::net::UnixListener; +use tokio::time::{timeout, Duration}; + +#[tokio::test] +async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn tls_scanner_probe_keeps_http_like_fallback_surface() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.44:55221".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[TLS-scanner]")); + assert!(snapshot.contains("198.51.100.44-1")); + accept_task.await.unwrap(); +} + +#[test] +fn detect_client_type_covers_ssh_port_scanner_and_unknown() { + assert_eq!(detect_client_type(b"SSH-2.0-OpenSSH_9.7"), "SSH"); + assert_eq!(detect_client_type(b"\x01\x02\x03"), "port-scanner"); + assert_eq!(detect_client_type(b"random-binary-payload"), "unknown"); +} + +#[test] +fn detect_client_type_len_boundary_9_vs_10_bytes() { + assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + +#[tokio::test] +async fn beobachten_records_scanner_class_when_mask_is_disabled() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = false; + + let peer: SocketAddr = "203.0.113.99:41234".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let initial = b"SSH-2.0-probe"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + beobachten + }); + + client_reader_side.write_all(b"noise").await.unwrap(); + drop(client_reader_side); + + let beobachten = timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[SSH]")); + assert!(snapshot.contains("203.0.113.99-1")); +} + +#[tokio::test] +async fn backend_unavailable_falls_back_to_silent_consume() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.11:42425".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_reader_side.write_all(b"noise").await.unwrap(); + drop(client_reader_side); + + timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn mask_disabled_consumes_client_data_without_response() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "198.51.100.12:45454".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let initial = b"scanner"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_reader_side.write_all(b"untrusted payload").await.unwrap(); + drop(client_reader_side); + + timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn proxy_protocol_v1_header_is_sent_before_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line.clone()).unwrap(); + assert!(header_text.starts_with("PROXY TCP4 ")); + assert!(header_text.ends_with("\r\n")); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.15:50001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_protocol_v2_header_is_sent_before_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut sig = [0u8; 12]; + stream.read_exact(&mut sig).await.unwrap(); + assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n"); + + let mut fixed = [0u8; 4]; + stream.read_exact(&mut fixed).await.unwrap(); + let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize; + + let mut addr_block = vec![0u8; addr_len]; + stream.read_exact(&mut addr_block).await.unwrap(); + + let mut received_probe = vec![0u8; probe.len()]; + stream.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.18:50004".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /mix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line).unwrap(); + assert_eq!(header_text, "PROXY UNKNOWN\r\n"); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.20:50006".parse().unwrap(); + let local_addr: SocketAddr = "[::1]:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[cfg(unix)] +#[tokio::test] +async fn unix_socket_mask_path_forwards_probe_and_response() { + let sock_path = format!("/tmp/telemt-mask-test-{}-{}.sock", std::process::id(), rand::random::()); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.30:50010".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[tokio::test] +async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "198.51.100.33:45455".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + b"slowloris", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_secs(1), task).await.unwrap().unwrap(); +} + +struct PendingWriter; + +impl tokio::io::AsyncWrite for PendingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn proxy_header_write_timeout_returns_false() { + let mut writer = PendingWriter; + let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await; + assert!(!ok, "Proxy header writes that never complete must time out"); +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index aaae1b3..ba01c74 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,14 +1,17 @@ -use std::collections::HashMap; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant}; +#[cfg(test)] +use std::sync::Mutex; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; +use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot, watch}; +use tokio::time::timeout; use tracing::{debug, trace, warn}; use crate::config::ProxyConfig; @@ -30,13 +33,15 @@ enum C2MeCommand { } const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); +const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536; +const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024; const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; -static DESYNC_DEDUP: OnceLock>> = OnceLock::new(); +static DESYNC_DEDUP: OnceLock> = OnceLock::new(); struct RelayForensicsState { trace_id: u64, @@ -90,24 +95,46 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { return true; } - let dedup = DESYNC_DEDUP.get_or_init(|| Mutex::new(HashMap::new())); - let mut guard = dedup.lock().expect("desync dedup mutex poisoned"); - guard.retain(|_, seen_at| now.duration_since(*seen_at) < DESYNC_DEDUP_WINDOW); + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - match guard.get_mut(&key) { - Some(seen_at) => { - if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { - *seen_at = now; - true - } else { - false + if let Some(mut seen_at) = dedup.get_mut(&key) { + if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { + *seen_at = now; + return true; + } + return false; + } + + if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + let mut stale_keys = Vec::new(); + for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) { + if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW { + stale_keys.push(*entry.key()); } } - None => { - guard.insert(key, now); - true + for stale_key in stale_keys { + dedup.remove(&stale_key); + } + if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + return false; } } + + dedup.insert(key, now); + true +} + +#[cfg(test)] +fn clear_desync_dedup_for_testing() { + if let Some(dedup) = DESYNC_DEDUP.get() { + dedup.clear(); + } +} + +#[cfg(test)] +fn desync_dedup_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) } fn report_desync_frame_too_large( @@ -229,7 +256,7 @@ pub(crate) async fn handle_via_middle_proxy( me_pool: Arc, stats: Arc, config: Arc, - _buffer_pool: Arc, + buffer_pool: Arc, local_addr: SocketAddr, rng: Arc, mut route_rx: watch::Receiver, @@ -271,7 +298,6 @@ where }; stats.increment_user_connects(&user); - stats.increment_user_curr_connects(&user); stats.increment_current_connections_me(); if let Some(cutover) = affected_cutover_state( @@ -291,7 +317,6 @@ where let _ = me_pool.send_close(conn_id).await; me_pool.registry().unregister(conn_id).await; stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); } @@ -557,6 +582,8 @@ where &mut crypto_reader, proto_tag, frame_limit, + Duration::from_secs(config.timeouts.client_handshake.max(1)), + &buffer_pool, &forensics, &mut frame_counter, &stats, @@ -638,7 +665,6 @@ where ); me_pool.registry().unregister(conn_id).await; stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); result } @@ -646,6 +672,8 @@ async fn read_client_payload( client_reader: &mut CryptoReader, proto_tag: ProtoTag, max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, forensics: &RelayForensicsState, frame_counter: &mut u64, stats: &Stats, @@ -653,23 +681,40 @@ async fn read_client_payload( where R: AsyncRead + Unpin + Send + 'static, { + async fn read_exact_with_timeout( + client_reader: &mut CryptoReader, + buf: &mut [u8], + frame_read_timeout: Duration, + ) -> Result<()> + where + R: AsyncRead + Unpin + Send + 'static, + { + match timeout(frame_read_timeout, client_reader.read_exact(buf)).await { + Ok(Ok(_)) => Ok(()), + Ok(Err(e)) => Err(ProxyError::Io(e)), + Err(_) => Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "middle-relay client frame read timeout", + ))), + } + } + loop { let (len, quickack, raw_len_bytes) = match proto_tag { ProtoTag::Abridged => { let mut first = [0u8; 1]; - match client_reader.read_exact(&mut first).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + match read_exact_with_timeout(client_reader, &mut first, frame_read_timeout).await { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } let quickack = (first[0] & 0x80) != 0; let len_words = if (first[0] & 0x7f) == 0x7f { let mut ext = [0u8; 3]; - client_reader - .read_exact(&mut ext) - .await - .map_err(ProxyError::Io)?; + read_exact_with_timeout(client_reader, &mut ext, frame_read_timeout).await?; u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize } else { (first[0] & 0x7f) as usize @@ -682,10 +727,12 @@ where } ProtoTag::Intermediate | ProtoTag::Secure => { let mut len_buf = [0u8; 4]; - match client_reader.read_exact(&mut len_buf).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + match read_exact_with_timeout(client_reader, &mut len_buf, frame_read_timeout).await { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } let quickack = (len_buf[3] & 0x80) != 0; ( @@ -737,18 +784,25 @@ where len }; - let mut payload = vec![0u8; len]; - client_reader - .read_exact(&mut payload) - .await - .map_err(ProxyError::Io)?; + let chunk_cap = buffer_pool.buffer_size().max(1024); + let mut payload = BytesMut::with_capacity(len.min(chunk_cap)); + let mut remaining = len; + while remaining > 0 { + let chunk_len = remaining.min(chunk_cap); + let mut chunk = buffer_pool.get(); + chunk.resize(chunk_len, 0); + read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout) + .await?; + payload.extend_from_slice(&chunk[..chunk_len]); + remaining -= chunk_len; + } // Secure Intermediate: strip validated trailing padding bytes. if proto_tag == ProtoTag::Secure { payload.truncate(secure_payload_len); } *frame_counter += 1; - return Ok(Some((Bytes::from(payload), quickack))); + return Ok(Some((payload.freeze(), quickack))); } } @@ -940,82 +994,5 @@ where } #[cfg(test)] -mod tests { - use super::*; - use tokio::time::{Duration as TokioDuration, timeout}; - - #[test] - fn should_yield_sender_only_on_budget_with_backlog() { - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); - } - - #[tokio::test] - async fn enqueue_c2me_command_uses_try_send_fast_path() { - let (tx, mut rx) = mpsc::channel::(2); - enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: Bytes::from_static(&[1, 2, 3]), - flags: 0, - }, - ) - .await - .unwrap(); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1, 2, 3]); - assert_eq!(flags, 0); - } - C2MeCommand::Close => panic!("unexpected close command"), - } - } - - #[tokio::test] - async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: Bytes::from_static(&[9]), - flags: 9, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: Bytes::from_static(&[7, 7]), - flags: 7, - }, - ) - .await - .unwrap(); - }); - - let _ = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap(); - producer.await.unwrap(); - - let recv = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[7, 7]); - assert_eq!(flags, 7); - } - C2MeCommand::Close => panic!("unexpected close command"), - } - } -} +#[path = "middle_relay_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs new file mode 100644 index 0000000..a2f89f8 --- /dev/null +++ b/src/proxy/middle_relay_security_tests.rs @@ -0,0 +1,201 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use tokio::io::AsyncWriteExt; +use tokio::io::duplex; +use tokio::time::{Duration as TokioDuration, timeout}; + +#[test] +fn should_yield_sender_only_on_budget_with_backlog() { + assert!(!should_yield_c2me_sender(0, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); + assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); +} + +#[tokio::test] +async fn enqueue_c2me_command_uses_try_send_fast_path() { + let (tx, mut rx) = mpsc::channel::(2); + enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload: Bytes::from_static(&[1, 2, 3]), + flags: 0, + }, + ) + .await + .unwrap(); + + let recv = timeout(TokioDuration::from_millis(50), rx.recv()) + .await + .unwrap() + .unwrap(); + match recv { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[1, 2, 3]); + assert_eq!(flags, 0); + } + C2MeCommand::Close => panic!("unexpected close command"), + } +} + +#[tokio::test] +async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: Bytes::from_static(&[9]), + flags: 9, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let producer = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: Bytes::from_static(&[7, 7]), + flags: 7, + }, + ) + .await + .unwrap(); + }); + + let _ = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap(); + producer.await.unwrap(); + + let recv = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .unwrap(); + match recv { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[7, 7]); + assert_eq!(flags, 7); + } + C2MeCommand::Close => panic!("unexpected close command"), + } +} + +#[test] +fn desync_dedup_cache_is_bounded() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!( + should_emit_full_desync(key, false, now), + "unique keys up to cap must be tracked" + ); + } + + assert!( + !should_emit_full_desync(u64::MAX, false, now), + "new key above cap must be suppressed to bound memory" + ); + + assert!( + !should_emit_full_desync(7, false, now), + "already tracked key inside dedup window must stay suppressed" + ); +} + +fn make_forensics_state() -> RelayForensicsState { + RelayForensicsState { + trace_id: 1, + conn_id: 2, + user: "test-user".to_string(), + peer: "127.0.0.1:50000".parse::().unwrap(), + peer_hash: 3, + started_at: Instant::now(), + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader { + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +#[tokio::test] +async fn read_client_payload_times_out_on_header_stall() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_millis(25), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), + "stalled header read must time out" + ); +} + +#[tokio::test] +async fn read_client_payload_times_out_on_payload_stall() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + let (reader, mut writer) = duplex(1024); + let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]); + writer.write_all(&encrypted_len).await.unwrap(); + + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_millis(25), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), + "stalled payload body read must time out" + ); +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 25905b2..603552d 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -1256,6 +1256,33 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); } + + pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option) -> bool { + if !self.telemetry_user_enabled() { + return true; + } + + self.maybe_cleanup_user_stats(); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + + let counter = &stats.curr_connects; + let mut current = counter.load(Ordering::Relaxed); + loop { + if let Some(max) = limit && current >= max { + return false; + } + match counter.compare_exchange_weak( + current, + current.saturating_add(1), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return true, + Err(actual) => current = actual, + } + } + } pub fn decrement_user_curr_connects(&self, user: &str) { self.maybe_cleanup_user_stats(); diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 2ff7de7..403f695 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -513,6 +513,7 @@ impl FrameCodecTrait for SecureCodec { #[cfg(test)] mod tests { use super::*; + use std::collections::HashSet; use tokio_util::codec::{FramedRead, FramedWrite}; use tokio::io::duplex; use futures::{SinkExt, StreamExt}; @@ -630,4 +631,31 @@ mod tests { let result = codec.decode(&mut buf); assert!(result.is_err()); } + + #[test] + fn secure_codec_always_adds_padding_and_jitters_wire_length() { + let codec = SecureCodec::new(Arc::new(SecureRandom::new())); + let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); + let mut wire_lens = HashSet::new(); + + for _ in 0..64 { + let frame = Frame::new(payload.clone()); + let mut out = BytesMut::new(); + codec.encode(&frame, &mut out).unwrap(); + + assert!(out.len() >= 4 + payload.len() + 1); + let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize; + assert!( + (payload.len() + 1..=payload.len() + 3).contains(&wire_len), + "Secure wire length must be payload+1..3, got {wire_len}" + ); + assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned"); + wire_lens.insert(wire_len); + } + + assert!( + wire_lens.len() >= 2, + "Secure padding should create observable wire-length jitter" + ); + } }