From 97d4a1c5c8a1dfa388132bff85af861099d4b68d Mon Sep 17 00:00:00 2001 From: David Osipov Date: Wed, 18 Mar 2026 01:40:38 +0400 Subject: [PATCH] Refactor and enhance security in proxy and handshake modules - Updated `direct_relay_security_tests.rs` to ensure sanitized paths are correctly validated against resolved paths. - Added tests for symlink handling in `unknown_dc_log_path_revalidation` to prevent symlink target escape vulnerabilities. - Modified `handshake.rs` to use a more robust hashing strategy for eviction offsets, improving the eviction logic in `auth_probe_record_failure_with_state`. - Introduced new tests in `handshake_security_tests.rs` to validate eviction logic under various conditions, ensuring low fail streak entries are prioritized for eviction. - Simplified `route_mode.rs` by removing unnecessary atomic mode tracking, streamlining the transition logic in `RouteRuntimeController`. - Enhanced `route_mode_security_tests.rs` with comprehensive tests for mode transitions and their effects on session states, ensuring consistency under concurrent modifications. - Cleaned up `emulator.rs` by removing unused ALPN extension handling, improving code clarity and maintainability. --- AGENTS.md | 6 + Cargo.lock | 57 +++- Cargo.toml | 1 + src/protocol/tls.rs | 67 +---- src/protocol/tls_security_tests.rs | 174 ++++++++++- src/proxy/direct_relay.rs | 64 +++- src/proxy/direct_relay_security_tests.rs | 353 ++++++++++++++++++++++- src/proxy/handshake.rs | 93 ++++-- src/proxy/handshake_security_tests.rs | 284 +++++++++++++++++- src/proxy/route_mode.rs | 47 ++- src/proxy/route_mode_security_tests.rs | 234 +++++++++++++++ src/tls_front/emulator.rs | 11 +- 12 files changed, 1247 insertions(+), 144 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index e7f94a5..c17cc76 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -390,6 +390,12 @@ you MUST explain why existing invariants remain valid. - Do not modify existing tests unless the task explicitly requires it. - Do not weaken assertions. - Preserve determinism in testable components. +- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new. +- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code. +- Differential/model catches logic drift over time. +- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find subtle bugs; 10,000 concurrent calls will find it on the first run. +- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly. +- Dead parameter is a code smell rule. ### 15. Security Constraints diff --git a/Cargo.lock b/Cargo.lock index 677ab84..7749ef5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -425,6 +425,32 @@ dependencies = [ "cipher", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -517,6 +543,12 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "filetime" version = "0.2.27" @@ -1609,7 +1641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha", - "rand_core", + "rand_core 0.9.5", ] [[package]] @@ -1619,9 +1651,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.5", ] +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + [[package]] name = "rand_core" version = "0.9.5" @@ -1637,7 +1675,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" dependencies = [ - "rand_core", + "rand_core 0.9.5", ] [[package]] @@ -2145,6 +2183,7 @@ dependencies = [ "tracing-subscriber", "url", "webpki-roots 0.26.11", + "x25519-dalek", "x509-parser", "zeroize", ] @@ -3144,6 +3183,18 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4", + "serde", + "zeroize", +] + [[package]] name = "x509-parser" version = "0.15.1" diff --git a/Cargo.toml b/Cargo.toml index 4e12cad..a47a4e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,7 @@ regex = "1.11" crossbeam-queue = "0.3" num-bigint = "0.4" num-traits = "0.2" +x25519-dalek = "2" anyhow = "1.0" # HTTP diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 5ff38ae..3f9f981 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -11,9 +11,8 @@ use crate::crypto::{sha256_hmac, SecureRandom}; use crate::error::ProxyError; use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; -use num_bigint::BigUint; -use num_traits::One; use subtle::ConstantTimeEq; +use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; // ============= Public Constants ============= @@ -121,27 +120,6 @@ impl TlsExtensionBuilder { self } - /// Add ALPN extension with a single selected protocol. - fn add_alpn(&mut self, proto: &[u8]) -> &mut Self { - // Extension type: ALPN (0x0010) - self.extensions.extend_from_slice(&extension_type::ALPN.to_be_bytes()); - - // ALPN extension format: - // extension_data length (2 bytes) - // protocols length (2 bytes) - // protocol name length (1 byte) - // protocol name bytes - let proto_len = proto.len() as u8; - let list_len: u16 = 1 + u16::from(proto_len); - let ext_len: u16 = 2 + list_len; - - self.extensions.extend_from_slice(&ext_len.to_be_bytes()); - self.extensions.extend_from_slice(&list_len.to_be_bytes()); - self.extensions.push(proto_len); - self.extensions.extend_from_slice(proto); - self - } - /// Build final extensions with length prefix fn build(self) -> Vec { let mut result = Vec::with_capacity(2 + self.extensions.len()); @@ -177,8 +155,6 @@ struct ServerHelloBuilder { compression: u8, /// Extensions extensions: TlsExtensionBuilder, - /// Selected ALPN protocol (if any) - alpn: Option>, } impl ServerHelloBuilder { @@ -189,7 +165,6 @@ impl ServerHelloBuilder { cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256, compression: 0x00, extensions: TlsExtensionBuilder::new(), - alpn: None, } } @@ -204,18 +179,9 @@ impl ServerHelloBuilder { self } - fn with_alpn(mut self, proto: Option>) -> Self { - self.alpn = proto; - self - } - /// Build ServerHello message (without record header) fn build_message(&self) -> Vec { - let mut ext_builder = self.extensions.clone(); - if let Some(ref alpn) = self.alpn { - ext_builder.add_alpn(alpn); - } - let extensions = ext_builder.extensions.clone(); + let extensions = self.extensions.extensions.clone(); let extensions_len = extensions.len() as u16; // Calculate total length @@ -380,6 +346,9 @@ fn validate_tls_handshake_at_time_with_boot_cap( // Extract session ID let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; let session_id_len = handshake.get(session_id_len_pos).copied()? as usize; + if session_id_len > 32 { + return None; + } let session_id_start = session_id_len_pos + 1; if handshake.len() < session_id_start + session_id_len { @@ -444,27 +413,14 @@ fn validate_tls_handshake_at_time_with_boot_cap( }) } -fn curve25519_prime() -> BigUint { - (BigUint::one() << 255) - BigUint::from(19u32) -} - /// Generate a fake X25519 public key for TLS /// -/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p, -/// which matches Python/C behavior and avoids DPI fingerprinting. +/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint, +/// yielding distribution-consistent public keys for anti-fingerprinting. pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] { - let mut n_bytes = [0u8; 32]; - n_bytes.copy_from_slice(&rng.bytes(32)); - - let n = BigUint::from_bytes_le(&n_bytes); - let p = curve25519_prime(); - let pk = (&n * &n) % &p; - - let mut out = pk.to_bytes_le(); - out.resize(32, 0); - let mut result = [0u8; 32]; - result.copy_from_slice(&out[..32]); - result + let mut scalar = [0u8; 32]; + scalar.copy_from_slice(&rng.bytes(32)); + x25519(scalar, X25519_BASEPOINT_BYTES) } /// Build TLS ServerHello response @@ -481,7 +437,7 @@ pub fn build_server_hello( session_id: &[u8], fake_cert_len: usize, rng: &SecureRandom, - alpn: Option>, + _alpn: Option>, new_session_tickets: u8, ) -> Vec { const MIN_APP_DATA: usize = 64; @@ -493,7 +449,6 @@ pub fn build_server_hello( let server_hello = ServerHelloBuilder::new(session_id.to_vec()) .with_x25519_key(&x25519_key) .with_tls13_version() - .with_alpn(alpn) .build_record(); // Build Change Cipher Spec record diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index 74baa2f..9f568b5 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -1,5 +1,8 @@ use super::*; use crate::crypto::sha256_hmac; +use crate::tls_front::emulator::build_emulated_server_hello; +use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource}; +use std::time::SystemTime; /// Build a TLS-handshake-like buffer that contains a valid HMAC digest /// for the given `secret` and `timestamp`. @@ -369,16 +372,16 @@ fn one_byte_session_id_validates_and_is_preserved() { } #[test] -fn max_session_id_len_255_with_valid_digest_is_accepted() { +fn max_session_id_len_255_with_valid_digest_is_rejected_by_rfc_cap() { let secret = b"sid_len_255_test"; let session_id = vec![0xCCu8; 255]; let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id); let secrets = vec![("u".to_string(), secret.to_vec())]; - let result = validate_tls_handshake(&handshake, &secrets, true) - .expect("session_id_len=255 with valid digest must validate"); - assert_eq!(result.session_id.len(), 255); - assert_eq!(result.session_id, session_id); + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "legacy_session_id length > 32 must be rejected even with valid digest" + ); } // ------------------------------------------------------------------ @@ -1187,17 +1190,158 @@ fn test_gen_fake_x25519_key() { } #[test] -fn test_fake_x25519_key_is_quadratic_residue() { - use num_bigint::BigUint; - use num_traits::One; - +fn test_fake_x25519_key_is_nonzero_and_varies() { let rng = crate::crypto::SecureRandom::new(); - let key = gen_fake_x25519_key(&rng); - let p = curve25519_prime(); - let k_num = BigUint::from_bytes_le(&key); - let exponent = (&p - BigUint::one()) >> 1; - let legendre = k_num.modpow(&exponent, &p); - assert_eq!(legendre, BigUint::one()); + let mut unique = std::collections::HashSet::new(); + let mut saw_non_zero = false; + + for _ in 0..64 { + let key = gen_fake_x25519_key(&rng); + if key != [0u8; 32] { + saw_non_zero = true; + } + unique.insert(key); + } + + assert!( + saw_non_zero, + "generated X25519 public keys must not collapse to all-zero output" + ); + assert!( + unique.len() > 1, + "generated X25519 public keys must vary across invocations" + ); +} + +#[test] +fn validate_tls_handshake_rejects_session_id_longer_than_rfc_cap() { + let secret = b"session_id_cap_secret"; + let oversized_sid = vec![0x42u8; 33]; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &oversized_sid); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "legacy_session_id length > 32 must be rejected" + ); +} + +fn server_hello_extension_types(record: &[u8]) -> Vec { + if record.len() < 9 || record[0] != TLS_RECORD_HANDSHAKE || record[5] != 0x02 { + return Vec::new(); + } + + let record_len = u16::from_be_bytes([record[3], record[4]]) as usize; + if record.len() < 5 + record_len { + return Vec::new(); + } + + let hs_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; + let hs_start = 5; + let hs_end = hs_start + 4 + hs_len; + if hs_end > record.len() { + return Vec::new(); + } + + let mut pos = hs_start + 4 + 2 + 32; + if pos >= hs_end { + return Vec::new(); + } + let sid_len = record[pos] as usize; + pos += 1 + sid_len; + if pos + 2 + 1 + 2 > hs_end { + return Vec::new(); + } + + pos += 2 + 1; + let ext_len = u16::from_be_bytes([record[pos], record[pos + 1]]) as usize; + pos += 2; + let ext_end = pos + ext_len; + if ext_end > hs_end { + return Vec::new(); + } + + let mut out = Vec::new(); + while pos + 4 <= ext_end { + let etype = u16::from_be_bytes([record[pos], record[pos + 1]]); + let elen = u16::from_be_bytes([record[pos + 2], record[pos + 3]]) as usize; + pos += 4; + if pos + elen > ext_end { + break; + } + out.push(etype); + pos += elen; + } + out +} + +#[test] +fn build_server_hello_never_places_alpn_in_server_hello_extensions() { + let secret = b"alpn_sh_forbidden"; + let client_digest = [0x11u8; 32]; + let session_id = vec![0xAA; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 1024, + &rng, + Some(b"h2".to_vec()), + 0, + ); + let exts = server_hello_extension_types(&response); + assert!( + !exts.contains(&0x0010), + "ALPN extension must not appear in ServerHello" + ); +} + +#[test] +fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() { + let secret = b"alpn_emulated_forbidden"; + let client_digest = [0x22u8; 32]; + let session_id = vec![0xAB; 32]; + let rng = crate::crypto::SecureRandom::new(); + let cached = CachedTlsData { + server_hello_template: ParsedServerHello { + version: TLS_VERSION, + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }, + cert_info: None, + cert_payload: None, + app_data_records_sizes: vec![1024], + total_app_data_len: 1024, + behavior_profile: TlsBehaviorProfile { + change_cipher_spec_count: 1, + app_data_record_sizes: vec![1024], + ticket_record_sizes: Vec::new(), + source: TlsProfileSource::Default, + }, + fetched_at: SystemTime::now(), + domain: "example.com".to_string(), + }; + + let response = build_emulated_server_hello( + secret, + &client_digest, + &session_id, + &cached, + false, + &rng, + Some(b"h2".to_vec()), + 0, + ); + let exts = server_hello_extension_types(&response); + assert!( + !exts.contains(&0x0010), + "ALPN extension must not appear in emulated ServerHello" + ); } #[test] diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 72a5c91..4a7b9a9 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -1,3 +1,4 @@ +use std::ffi::OsString; use std::fs::OpenOptions; use std::io::Write; use std::net::SocketAddr; @@ -25,9 +26,19 @@ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; + const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); +#[derive(Clone)] +struct SanitizedUnknownDcLogPath { + resolved_path: PathBuf, + allowed_parent: PathBuf, + file_name: OsString, +} + // In tests, this function shares global mutable state. Callers that also use // cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions // deterministic under parallel execution. @@ -52,7 +63,7 @@ fn should_log_unknown_dc_with_set(set: &Mutex>, dc_idx: i16) -> boo } } -fn sanitize_unknown_dc_log_path(path: &str) -> Option { +fn sanitize_unknown_dc_log_path(path: &str) -> Option { let candidate = Path::new(path); if candidate.as_os_str().is_empty() { return None; @@ -77,7 +88,52 @@ fn sanitize_unknown_dc_log_path(path: &str) -> Option { return None; } - Some(canonical_parent.join(file_name)) + Some(SanitizedUnknownDcLogPath { + resolved_path: canonical_parent.join(file_name), + allowed_parent: canonical_parent, + file_name: file_name.to_os_string(), + }) +} + +fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool { + let Some(parent) = path.resolved_path.parent() else { + return false; + }; + let Ok(current_parent) = parent.canonicalize() else { + return false; + }; + if current_parent != path.allowed_parent { + return false; + } + + if let Ok(canonical_target) = path.resolved_path.canonicalize() { + let Some(target_parent) = canonical_target.parent() else { + return false; + }; + let Some(target_name) = canonical_target.file_name() else { + return false; + }; + if target_parent != path.allowed_parent || target_name != path.file_name { + return false; + } + } + + true +} + +fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { + #[cfg(unix)] + { + OpenOptions::new() + .create(true) + .append(true) + .custom_flags(libc::O_NOFOLLOW) + .open(path) + } + #[cfg(not(unix))] + { + OpenOptions::new().create(true).append(true).open(path) + } } #[cfg(test)] @@ -234,7 +290,9 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { { if let Some(path) = sanitize_unknown_dc_log_path(path) { handle.spawn_blocking(move || { - if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { + if unknown_dc_log_path_is_still_safe(&path) + && let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path) + { let _ = writeln!(file, "dc_idx={dc_idx}"); } }); diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index d967da3..e47164f 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -7,6 +7,7 @@ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; use std::fs; +use std::io::Write; use std::path::Path; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -182,7 +183,7 @@ fn unknown_dc_log_path_sanitizer_accepts_absolute_paths_with_existing_parent() { let sanitized = sanitize_unknown_dc_log_path(absolute_str) .expect("absolute paths with existing parent must be accepted"); - assert_eq!(sanitized, absolute); + assert_eq!(sanitized.resolved_path, absolute); } #[test] @@ -206,7 +207,7 @@ fn unknown_dc_log_path_sanitizer_accepts_safe_relative_path() { let sanitized = sanitize_unknown_dc_log_path(&candidate_relative) .expect("safe relative path with existing parent must be accepted"); - assert_eq!(sanitized, candidate); + assert_eq!(sanitized.resolved_path, candidate); } #[test] @@ -226,7 +227,7 @@ fn unknown_dc_log_path_sanitizer_accepts_directory_only_as_filename_projection() let sanitized = sanitize_unknown_dc_log_path("target/") .expect("directory-only input is interpreted as filename projection in current sanitizer"); assert!( - sanitized.ends_with("target"), + sanitized.resolved_path.ends_with("target"), "directory-only input should resolve to canonical parent plus filename projection" ); } @@ -243,7 +244,7 @@ fn unknown_dc_log_path_sanitizer_accepts_dot_prefixed_relative_path() { let expected = abs_dir.join("unknown-dc.log"); let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) .expect("dot-prefixed safe path must be accepted"); - assert_eq!(sanitized, expected); + assert_eq!(sanitized.resolved_path, expected); } #[test] @@ -300,7 +301,7 @@ fn unknown_dc_log_path_sanitizer_accepts_symlinked_parent_inside_workspace() { let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) .expect("symlinked parent that resolves inside workspace must be accepted"); assert!( - sanitized.starts_with(&real_parent), + sanitized.resolved_path.starts_with(&real_parent), "sanitized path must resolve to canonical internal parent" ); } @@ -328,11 +329,304 @@ fn unknown_dc_log_path_sanitizer_accepts_symlink_parent_escape_as_canonical_path let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) .expect("symlinked parent must canonicalize to target path"); assert!( - sanitized.starts_with(Path::new("/tmp")), + sanitized.resolved_path.starts_with(Path::new("/tmp")), "sanitized path must resolve to canonical symlink target" ); } +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_revalidation_rejects_symlinked_target_escape() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-target-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("target-link base must be creatable"); + + let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id())); + let _ = fs::remove_file(&outside); + fs::write(&outside, "outside").expect("outside file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("target symlink must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-target-link-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate should sanitize before final revalidation"); + + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "final revalidation must reject symlinked target escape" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_open_append_rejects_symlink_target_with_nofollow() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-nofollow-{}", std::process::id())); + fs::create_dir_all(&base).expect("nofollow base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-nofollow-outside-{}.log", + std::process::id() + )); + let _ = fs::remove_file(&outside); + fs::write(&outside, "outside\n").expect("outside file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("symlink target must be creatable"); + + let err = open_unknown_dc_log_append(&linked_target) + .expect_err("O_NOFOLLOW open must fail for symlink target"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "symlink target must be rejected with ELOOP when O_NOFOLLOW is applied" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_open_append_rejects_broken_symlink_target_with_nofollow() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-broken-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("broken-link base must be creatable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(base.join("missing-target.log"), &linked_target) + .expect("broken symlink target must be creatable"); + + let err = open_unknown_dc_log_append(&linked_target) + .expect_err("O_NOFOLLOW open must fail for broken symlink target"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "broken symlink target must be rejected with ELOOP when O_NOFOLLOW is applied" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_unknown_dc_open_append_symlink_flip_never_writes_outside_file() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-symlink-flip-{}", std::process::id())); + fs::create_dir_all(&base).expect("symlink-flip base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-symlink-flip-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside-baseline\n").expect("outside baseline file must be writable"); + let outside_before = fs::read_to_string(&outside).expect("outside baseline must be readable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + for step in 0..1024usize { + let _ = fs::remove_file(&target); + if step % 2 == 0 { + symlink(&outside, &target).expect("symlink creation in flip loop must succeed"); + } + if let Ok(mut file) = open_unknown_dc_log_append(&target) { + writeln!(file, "dc_idx={step}").expect("append on regular file must succeed"); + } + } + + let outside_after = fs::read_to_string(&outside).expect("outside file must remain readable"); + assert_eq!( + outside_after, outside_before, + "outside file must never be modified under symlink-flip adversarial churn" + ); +} + +#[test] +fn unknown_dc_open_append_creates_regular_file() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-open-{}", std::process::id())); + fs::create_dir_all(&base).expect("open test base must be creatable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + { + let mut file = open_unknown_dc_log_append(&target) + .expect("regular target must be creatable with append open"); + writeln!(file, "dc_idx=1234").expect("append write must succeed"); + } + + let meta = fs::symlink_metadata(&target).expect("created target metadata must be readable"); + assert!(meta.file_type().is_file(), "target must be a regular file"); + assert!( + !meta.file_type().is_symlink(), + "regular target open path must not produce symlink artifacts" + ); +} + +#[test] +fn stress_unknown_dc_open_append_regular_file_preserves_line_integrity() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-open-stress-{}", std::process::id())); + fs::create_dir_all(&base).expect("stress open base must be creatable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + let writes = 2048usize; + for idx in 0..writes { + let mut file = open_unknown_dc_log_append(&target) + .expect("stress append open on regular file must succeed"); + writeln!(file, "dc_idx={idx}").expect("stress append write must succeed"); + } + + let content = fs::read_to_string(&target).expect("stress output file must be readable"); + assert_eq!( + nonempty_line_count(&content), + writes, + "regular-file append stress must preserve one logical line per write" + ); +} + +#[test] +fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-safe-target-{}", std::process::id())); + fs::create_dir_all(&base).expect("safe target base must be creatable"); + + let target = base.join("unknown-dc.log"); + fs::write(&target, "seed\n").expect("safe target seed write must succeed"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-safe-target-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("safe candidate must sanitize"); + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must allow safe existing regular files" + ); +} + +#[test] +fn unknown_dc_log_path_revalidation_rejects_deleted_parent_after_sanitize() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-vanish-parent-{}", std::process::id())); + fs::create_dir_all(&base).expect("vanish-parent base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-vanish-parent-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent deletion"); + + fs::remove_dir_all(&base).expect("test parent directory must be removable"); + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must fail when sanitized parent disappears before write" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() { + use std::os::unix::fs::symlink; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-parent-swap-{}", std::process::id())); + fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-parent-swap-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent swap"); + + let moved = parent.with_extension("bak"); + let _ = fs::remove_dir_all(&moved); + fs::rename(&parent, &moved).expect("parent must be movable for swap simulation"); + symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable"); + + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must fail when canonical parent is swapped to a symlinked target" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() { + use std::os::unix::fs::symlink; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-check-open-race-{}", std::process::id())); + fs::create_dir_all(&parent).expect("check-open-race parent must be creatable"); + + let target = parent.join("unknown-dc.log"); + fs::write(&target, "seed\n").expect("seed target file must be writable"); + let rel_candidate = format!( + "target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize"); + + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "precondition: target should initially pass revalidation" + ); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-check-open-race-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside\n").expect("outside file must be writable"); + fs::remove_file(&target).expect("target removal before flip must succeed"); + symlink(&outside, &target).expect("target symlink flip must be creatable"); + + let err = open_unknown_dc_log_append(&sanitized.resolved_path) + .expect_err("nofollow open must fail after symlink flip between check and open"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "symlink flip in check/open window must be neutralized by O_NOFOLLOW" + ); +} + #[tokio::test] async fn unknown_dc_absolute_log_path_writes_one_entry() { let _guard = unknown_dc_test_lock() @@ -499,6 +793,53 @@ async fn unknown_dc_distinct_burst_is_hard_capped_on_file_writes() { ); } +#[cfg(unix)] +#[tokio::test] +async fn unknown_dc_symlinked_target_escape_is_not_written_integration() { + use std::os::unix::fs::symlink; + + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-no-write-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("integration symlink base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "baseline\n").expect("outside baseline file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("symlink target must be creatable"); + + let rel_file = format!( + "target/telemt-unknown-dc-no-write-link-{}/unknown-dc.log", + std::process::id() + ); + let dc_idx: i16 = 31_050; + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let before = fs::read_to_string(&outside).expect("must read baseline outside file"); + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + tokio::time::sleep(Duration::from_millis(80)).await; + let after = fs::read_to_string(&outside).expect("must read outside file after attempt"); + + assert_eq!( + after, before, + "symlink target escape must not be written by unknown-DC logging" + ); +} + #[test] fn fallback_dc_never_panics_with_single_dc_list() { let mut cfg = ProxyConfig::default(); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 03b5012..3659754 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -4,11 +4,11 @@ use std::net::SocketAddr; use std::collections::HashSet; +use std::collections::hash_map::RandomState; use std::net::{IpAddr, Ipv6Addr}; use std::sync::Arc; use std::sync::{Mutex, OnceLock}; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; +use std::hash::{BuildHasher, Hash, Hasher}; use std::time::{Duration, Instant}; use dashmap::DashMap; use dashmap::mapref::entry::Entry; @@ -64,6 +64,7 @@ struct AuthProbeSaturationState { static AUTH_PROBE_STATE: OnceLock> = OnceLock::new(); static AUTH_PROBE_SATURATION_STATE: OnceLock>> = OnceLock::new(); +static AUTH_PROBE_EVICTION_HASHER: OnceLock = OnceLock::new(); fn auth_probe_state_map() -> &'static DashMap { AUTH_PROBE_STATE.get_or_init(DashMap::new) @@ -101,7 +102,8 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool { } fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { - let mut hasher = DefaultHasher::new(); + let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new); + let mut hasher = hasher_state.build_hasher(); peer_ip.hash(&mut hasher); now.hash(&mut hasher); hasher.finish() as usize @@ -234,32 +236,79 @@ fn auth_probe_record_failure_with_state( } if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - let mut stale_keys = Vec::new(); - let mut oldest_candidate: Option<(IpAddr, Instant)> = None; - for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { - let key = *entry.key(); - let last_seen = entry.value().last_seen; - match oldest_candidate { - Some((_, oldest_seen)) if last_seen >= oldest_seen => {} - _ => oldest_candidate = Some((key, last_seen)), + let mut rounds = 0usize; + while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + rounds += 1; + if rounds > 8 { + auth_probe_note_saturation(now); + return; } - if auth_probe_state_expired(entry.value(), now) { - stale_keys.push(key); + + let mut stale_keys = Vec::new(); + let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; + let state_len = state.len(); + let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); + let start_offset = if state_len == 0 { + 0 + } else { + auth_probe_eviction_offset(peer_ip, now) % state_len + }; + + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => + { + } + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } } - } - for stale_key in stale_keys { - state.remove(&stale_key); - } - if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - let Some((evict_key, _)) = oldest_candidate else { + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => + { + } + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } + + for stale_key in stale_keys { + state.remove(&stale_key); + } + + if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES { + break; + } + + let Some((evict_key, _, _)) = eviction_candidate else { auth_probe_note_saturation(now); return; }; state.remove(&evict_key); auth_probe_note_saturation(now); - if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - return; - } } } diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 1823167..2132fbe 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1539,7 +1539,7 @@ fn auth_probe_capacity_fresh_full_map_still_tracks_newcomer_with_bounded_evictio fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() { let _guard = auth_probe_test_lock() .lock() - .expect("auth probe test lock must be available"); + .unwrap_or_else(|poisoned| poisoned.into_inner()); clear_auth_probe_state_for_testing(); let state = DashMap::new(); @@ -1584,6 +1584,197 @@ fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() { } } +#[test] +fn auth_probe_capacity_prefers_evicting_low_fail_streak_entries_first() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + + // Fill map at capacity with mostly high fail streak entries. + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 20, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 9, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let low_fail = IpAddr::V4(Ipv4Addr::new(172, 21, 0, 1)); + state.insert( + low_fail, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_secs(30), + }, + ); + + let high_fail_old = IpAddr::V4(Ipv4Addr::new(172, 21, 0, 2)); + state.insert( + high_fail_old, + AuthProbeState { + fail_streak: 12, + blocked_until: now, + last_seen: now - Duration::from_secs(10), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 201)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!(state.get(&newcomer).is_some(), "new source must be tracked"); + assert!( + state.get(&low_fail).is_none(), + "least-penalized entry should be evicted before high-penalty entries" + ); + assert!( + state.get(&high_fail_old).is_some(), + "high fail-streak entry should be preserved under mixed-priority eviction" + ); +} + +#[test] +fn auth_probe_capacity_tie_breaker_evicts_oldest_with_equal_fail_streak() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES - 2) { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 30, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 5, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let oldest = IpAddr::V4(Ipv4Addr::new(172, 31, 0, 1)); + let newer = IpAddr::V4(Ipv4Addr::new(172, 31, 0, 2)); + state.insert( + oldest, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(20), + }, + ); + state.insert( + newer, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 202)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!(state.get(&newcomer).is_some(), "new source must be tracked"); + assert!( + state.get(&oldest).is_none(), + "among equal fail streak candidates, oldest entry must be evicted" + ); + assert!( + state.get(&newer).is_some(), + "newer equal-priority entry should be retained" + ); +} + +#[test] +fn stress_auth_probe_capacity_churn_preserves_high_fail_sentinels() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + let sentinel_a = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250)); + let sentinel_b = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 251)); + + state.insert( + sentinel_a, + AuthProbeState { + fail_streak: 20, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(30), + }, + ); + state.insert( + sentinel_b, + AuthProbeState { + fail_streak: 21, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(31), + }, + ); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES - 2) { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 4, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 1024) as u64), + }, + ); + } + + for step in 0..1024usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 1, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + let now = base_now + Duration::from_millis(10_000 + step as u64); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map must remain hard-bounded at capacity" + ); + assert!( + state.get(&sentinel_a).is_some() && state.get(&sentinel_b).is_some(), + "high fail-streak sentinels should survive low-streak newcomer churn" + ); + } +} + #[test] fn auth_probe_ipv6_is_bucketed_by_prefix_64() { let state = DashMap::new(); @@ -1674,6 +1865,97 @@ fn auth_probe_eviction_offset_varies_with_input() { assert_ne!(a, c, "different peer IPs should not collapse to one offset"); } +#[test] +fn auth_probe_eviction_offset_changes_with_time_component() { + let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 77)); + let now = Instant::now(); + let later = now + Duration::from_millis(1); + + let a = auth_probe_eviction_offset(ip, now); + let b = auth_probe_eviction_offset(ip, later); + + assert_ne!( + a, b, + "eviction offset must incorporate timestamp entropy and not only peer IP" + ); +} + +#[test] +fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() { + let mut rng = StdRng::seed_from_u64(0xA11CE5EED); + let base = Instant::now(); + + for _ in 0..4096usize { + let ip = IpAddr::V4(Ipv4Addr::new(rng.random(), rng.random(), rng.random(), rng.random())); + let offset_ns = rng.random_range(0_u64..2_000_000); + let when = base + Duration::from_nanos(offset_ns); + + let first = auth_probe_eviction_offset(ip, when); + let second = auth_probe_eviction_offset(ip, when); + assert_eq!( + first, second, + "eviction offset must be stable for identical (ip, now) pairs" + ); + } +} + +#[test] +fn adversarial_eviction_offset_spread_avoids_single_bucket_collapse() { + let modulus = AUTH_PROBE_TRACK_MAX_ENTRIES; + let mut bucket_hits = vec![0usize; modulus]; + let now = Instant::now(); + + for idx in 0..8192usize { + let ip = IpAddr::V4(Ipv4Addr::new( + 100, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + ((idx.wrapping_mul(37)) & 0xff) as u8, + )); + let bucket = auth_probe_eviction_offset(ip, now) % modulus; + bucket_hits[bucket] += 1; + } + + let non_empty_buckets = bucket_hits.iter().filter(|&&hits| hits > 0).count(); + assert!( + non_empty_buckets >= modulus / 2, + "adversarial sequential input should cover a broad bucket set (covered {non_empty_buckets}/{modulus})" + ); + + let max_hits = bucket_hits.iter().copied().max().unwrap_or(0); + let min_non_zero_hits = bucket_hits + .iter() + .copied() + .filter(|&hits| hits > 0) + .min() + .unwrap_or(0); + assert!( + max_hits <= min_non_zero_hits.saturating_mul(32).max(1), + "bucket skew is unexpectedly extreme for keyed hasher spread (max={max_hits}, min_non_zero={min_non_zero_hits})" + ); +} + +#[test] +fn stress_auth_probe_eviction_offset_high_volume_uniqueness_sanity() { + let now = Instant::now(); + let mut seen = std::collections::HashSet::new(); + + for idx in 0..50_000usize { + let ip = IpAddr::V4(Ipv4Addr::new( + 198, + ((idx >> 16) & 0xff) as u8, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + seen.insert(auth_probe_eviction_offset(ip, now)); + } + + assert!( + seen.len() >= 40_000, + "high-volume eviction offsets should not collapse excessively under keyed hashing" + ); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() { let _guard = auth_probe_test_lock() diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs index 2b109d1..114babe 100644 --- a/src/proxy/route_mode.rs +++ b/src/proxy/route_mode.rs @@ -1,5 +1,5 @@ use std::sync::Arc; -use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::watch; @@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode { } impl RelayRouteMode { - pub(crate) fn as_u8(self) -> u8 { - self as u8 - } - - pub(crate) fn from_u8(value: u8) -> Self { - match value { - 1 => Self::Middle, - _ => Self::Direct, - } - } - pub(crate) fn as_str(self) -> &'static str { match self { Self::Direct => "direct", @@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState { #[derive(Clone)] pub(crate) struct RouteRuntimeController { - mode: Arc, - generation: Arc, direct_since_epoch_secs: Arc, tx: watch::Sender, } @@ -60,18 +47,13 @@ impl RouteRuntimeController { 0 }; Self { - mode: Arc::new(AtomicU8::new(initial_mode.as_u8())), - generation: Arc::new(AtomicU64::new(0)), direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)), tx, } } pub(crate) fn snapshot(&self) -> RouteCutoverState { - RouteCutoverState { - mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)), - generation: self.generation.load(Ordering::Relaxed), - } + *self.tx.borrow() } pub(crate) fn subscribe(&self) -> watch::Receiver { @@ -84,20 +66,29 @@ impl RouteRuntimeController { } pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option { - let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed); - if previous == mode.as_u8() { + let mut next = None; + let changed = self.tx.send_if_modified(|state| { + if state.mode == mode { + return false; + } + state.mode = mode; + state.generation = state.generation.saturating_add(1); + next = Some(*state); + true + }); + + if !changed { return None; } + if matches!(mode, RelayRouteMode::Direct) { self.direct_since_epoch_secs .store(now_epoch_secs(), Ordering::Relaxed); } else { self.direct_since_epoch_secs.store(0, Ordering::Relaxed); } - let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; - let next = RouteCutoverState { mode, generation }; - self.tx.send_replace(next); - Some(next) + + next } } @@ -110,10 +101,10 @@ fn now_epoch_secs() -> u64 { pub(crate) fn is_session_affected_by_cutover( current: RouteCutoverState, - _session_mode: RelayRouteMode, + session_mode: RelayRouteMode, session_generation: u64, ) -> bool { - current.generation > session_generation + current.generation > session_generation && current.mode != session_mode } pub(crate) fn affected_cutover_state( diff --git a/src/proxy/route_mode_security_tests.rs b/src/proxy/route_mode_security_tests.rs index 36ab5c3..e86d574 100644 --- a/src/proxy/route_mode_security_tests.rs +++ b/src/proxy/route_mode_security_tests.rs @@ -1,4 +1,8 @@ use super::*; +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; #[test] fn cutover_stagger_delay_is_deterministic_for_same_inputs() { @@ -81,6 +85,236 @@ fn affected_cutover_state_triggers_only_for_newer_generation() { assert_eq!(seen.mode, RelayRouteMode::Middle); } +#[test] +fn integration_watch_and_snapshot_follow_same_transition_sequence() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + + let sequence = [ + RelayRouteMode::Middle, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + RelayRouteMode::Direct, + RelayRouteMode::Middle, + ]; + + let mut expected_generation = 0u64; + let mut expected_mode = RelayRouteMode::Direct; + + for target in sequence { + let changed = runtime.set_mode(target); + if target == expected_mode { + assert!(changed.is_none(), "idempotent transition must return none"); + } else { + expected_mode = target; + expected_generation = expected_generation.saturating_add(1); + let emitted = changed.expect("real transition must emit cutover state"); + assert_eq!(emitted.mode, expected_mode); + assert_eq!(emitted.generation, expected_generation); + } + + let snap = runtime.snapshot(); + let watched = *rx.borrow(); + assert_eq!(snap, watched, "snapshot and watch state must stay aligned"); + assert_eq!(snap.mode, expected_mode); + assert_eq!(snap.generation, expected_generation); + } +} + +#[test] +fn session_is_not_affected_when_mode_matches_even_if_generation_advanced() { + let session_mode = RelayRouteMode::Direct; + let current = RouteCutoverState { + mode: RelayRouteMode::Direct, + generation: 2, + }; + let session_generation = 0; + + assert!( + !is_session_affected_by_cutover(current, session_mode, session_generation), + "session on matching final route mode should not be force-cut over on intermediate generation bumps" + ); +} + +#[test] +fn cutover_predicate_rejects_equal_generation_even_if_mode_differs() { + let current = RouteCutoverState { + mode: RelayRouteMode::Middle, + generation: 77, + }; + assert!( + !is_session_affected_by_cutover(current, RelayRouteMode::Direct, 77), + "equal generation must never trigger cutover regardless of mode mismatch" + ); +} + +#[test] +fn adversarial_route_oscillation_only_cuts_over_sessions_with_different_final_mode() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + let session_generation = runtime.snapshot().generation; + + runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must transition"); + runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must transition"); + + assert!( + affected_cutover_state(&rx, RelayRouteMode::Direct, session_generation).is_none(), + "direct session should survive when final mode returns to direct" + ); + assert!( + affected_cutover_state(&rx, RelayRouteMode::Middle, session_generation).is_some(), + "middle session should be cut over when final mode is direct" + ); +} + +#[test] +fn light_fuzz_cutover_predicate_matches_reference_oracle() { + let mut rng = StdRng::seed_from_u64(0xC0DEC0DE5EED); + for _ in 0..20_000 { + let current = RouteCutoverState { + mode: if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }, + generation: rng.random_range(0u64..1_000_000), + }; + let session_mode = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let session_generation = rng.random_range(0u64..1_000_000); + + let expected = current.generation > session_generation && current.mode != session_mode; + let actual = is_session_affected_by_cutover(current, session_mode, session_generation); + assert_eq!( + actual, expected, + "cutover predicate must match mode-aware generation oracle" + ); + } +} + +#[test] +fn light_fuzz_set_mode_generation_tracks_only_real_transitions() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let mut rng = StdRng::seed_from_u64(0x0DDC0FFE); + + let mut expected_mode = RelayRouteMode::Direct; + let mut expected_generation = 0u64; + + for _ in 0..10_000 { + let candidate = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let changed = runtime.set_mode(candidate); + + if candidate == expected_mode { + assert!(changed.is_none(), "idempotent set_mode must not emit cutover state"); + } else { + expected_mode = candidate; + expected_generation = expected_generation.saturating_add(1); + let next = changed.expect("mode transition must emit cutover state"); + assert_eq!(next.mode, expected_mode); + assert_eq!(next.generation, expected_generation); + } + } + + let final_state = runtime.snapshot(); + assert_eq!(final_state.mode, expected_mode); + assert_eq!(final_state.generation, expected_generation); +} + +#[test] +fn stress_snapshot_and_watch_state_remain_consistent_under_concurrent_switch_storm() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + std::thread::scope(|scope| { + let mut writers = Vec::new(); + for worker in 0..4usize { + let runtime = Arc::clone(&runtime); + writers.push(scope.spawn(move || { + for step in 0..20_000usize { + let mode = if (worker + step) % 2 == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = runtime.set_mode(mode); + } + })); + } + + for writer in writers { + writer + .join() + .expect("route mode writer thread must not panic"); + } + + let rx = runtime.subscribe(); + for _ in 0..128 { + assert_eq!( + runtime.snapshot(), + *rx.borrow(), + "snapshot and watch state must converge after concurrent set_mode churn" + ); + std::thread::yield_now(); + } + }); +} + +#[test] +fn stress_concurrent_transition_count_matches_final_generation() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let successful_transitions = Arc::new(AtomicU64::new(0)); + + std::thread::scope(|scope| { + let mut workers = Vec::new(); + for worker in 0..6usize { + let runtime = Arc::clone(&runtime); + let successful_transitions = Arc::clone(&successful_transitions); + workers.push(scope.spawn(move || { + let mut state = (worker as u64 + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15); + for _ in 0..25_000usize { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + let mode = if (state & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + if runtime.set_mode(mode).is_some() { + successful_transitions.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("route mode transition worker must not panic"); + } + }); + + let final_state = runtime.snapshot(); + assert_eq!( + final_state.generation, + successful_transitions.load(Ordering::Relaxed), + "final generation must equal number of accepted mode transitions" + ); + assert_eq!( + final_state, + *runtime.subscribe().borrow(), + "watch and snapshot state must match after concurrent transition accounting" + ); +} + #[test] fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() { // Deterministic xorshift fuzzing keeps this test stable across runs. diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index 3278f63..7e329c5 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -103,7 +103,7 @@ pub fn build_emulated_server_hello( cached: &CachedTlsData, use_full_cert_payload: bool, rng: &SecureRandom, - alpn: Option>, + _alpn: Option>, new_session_tickets: u8, ) -> Vec { // --- ServerHello --- @@ -117,15 +117,6 @@ pub fn build_emulated_server_hello( extensions.extend_from_slice(&0x002bu16.to_be_bytes()); extensions.extend_from_slice(&(2u16).to_be_bytes()); extensions.extend_from_slice(&0x0304u16.to_be_bytes()); - if let Some(alpn_proto) = &alpn { - extensions.extend_from_slice(&0x0010u16.to_be_bytes()); - let list_len: u16 = 1 + alpn_proto.len() as u16; - let ext_len: u16 = 2 + list_len; - extensions.extend_from_slice(&ext_len.to_be_bytes()); - extensions.extend_from_slice(&list_len.to_be_bytes()); - extensions.push(alpn_proto.len() as u8); - extensions.extend_from_slice(alpn_proto); - } let extensions_len = extensions.len() as u16; let body_len = 2 + // version