diff --git a/src/main.rs b/src/main.rs index 16a8bdf..dff8c8a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,9 @@ mod ip_tracker; #[cfg(test)] #[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; +#[cfg(test)] +#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] +mod ip_tracker_hotpath_adversarial_tests; mod maestro; mod metrics; mod network; diff --git a/src/protocol/tests/tls_adversarial_tests.rs b/src/protocol/tests/tls_adversarial_tests.rs index 4c8aa72..b8df41a 100644 --- a/src/protocol/tests/tls_adversarial_tests.rs +++ b/src/protocol/tests/tls_adversarial_tests.rs @@ -307,9 +307,8 @@ fn extract_sni_with_duplicate_extensions_rejected() { h.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); h.extend_from_slice(&handshake); - // Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups. - // Telemt's `extract_sni` returns the first one found. - assert!(extract_sni_from_client_hello(&h).is_some()); + // Duplicate SNI extensions are ambiguous and must fail closed. + assert!(extract_sni_from_client_hello(&h).is_none()); } #[test] diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index dc15a1e..9cac85e 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -588,6 +588,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { return None; } + let mut saw_sni_extension = false; + let mut extracted_sni = None; + while pos + 4 <= ext_end { let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]); let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize; @@ -595,6 +598,12 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { if pos + elen > ext_end { break; } + if etype == 0x0000 { + if saw_sni_extension { + return None; + } + saw_sni_extension = true; + } if etype == 0x0000 && elen >= 5 { // server_name extension let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; @@ -611,7 +620,8 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) { if is_valid_sni_hostname(host) { - return Some(host.to_string()); + extracted_sni = Some(host.to_string()); + break; } } sn_pos += name_len; @@ -620,7 +630,7 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { pos += elen; } - None + extracted_sni } fn is_valid_sni_hostname(host: &str) -> bool { diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 801206b..18cbda3 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -27,6 +27,10 @@ use crate::transport::UpstreamManager; #[cfg(unix)] use std::os::unix::fs::OpenOptionsExt; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, FromRawFd}; const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); @@ -136,6 +140,7 @@ fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool { true } +#[cfg(test)] fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { #[cfg(unix)] { @@ -155,6 +160,64 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { } } +fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std::io::Result { + #[cfg(unix)] + { + let parent = OpenOptions::new() + .read(true) + .custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC) + .open(&path.allowed_parent)?; + + let file_name = std::ffi::CString::new(path.file_name.as_os_str().as_bytes()) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "unknown DC log file name contains NUL byte"))?; + + let fd = unsafe { + libc::openat( + parent.as_raw_fd(), + file_name.as_ptr(), + libc::O_CREAT | libc::O_APPEND | libc::O_WRONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC, + 0o600, + ) + }; + + if fd < 0 { + return Err(std::io::Error::last_os_error()); + } + + let file = unsafe { std::fs::File::from_raw_fd(fd) }; + Ok(file) + } + #[cfg(not(unix))] + { + let _ = path; + Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "unknown_dc_file_log_enabled requires unix O_NOFOLLOW support", + )) + } +} + +fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> { + #[cfg(unix)] + { + if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX) } != 0 { + return Err(std::io::Error::last_os_error()); + } + + let write_result = writeln!(file, "dc_idx={dc_idx}"); + + if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) } != 0 { + return Err(std::io::Error::last_os_error()); + } + + write_result + } + #[cfg(not(unix))] + { + writeln!(file, "dc_idx={dc_idx}") + } +} + #[cfg(test)] fn clear_unknown_dc_log_cache_for_testing() { if let Some(set) = LOGGED_UNKNOWN_DCS.get() @@ -321,9 +384,9 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { if should_log_unknown_dc(dc_idx) { handle.spawn_blocking(move || { if unknown_dc_log_path_is_still_safe(&path) - && let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path) + && let Ok(mut file) = open_unknown_dc_log_append_anchored(&path) { - let _ = writeln!(file, "dc_idx={dc_idx}"); + let _ = append_unknown_dc_line(&mut file, dc_idx); } }); } @@ -394,3 +457,15 @@ where #[cfg(test)] #[path = "tests/direct_relay_security_tests.rs"] mod security_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_business_logic_tests.rs"] +mod business_logic_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_common_mistakes_tests.rs"] +mod common_mistakes_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_subtle_adversarial_tests.rs"] +mod subtle_adversarial_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index ed7e758..949f2c2 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -263,6 +263,33 @@ fn is_quota_io_error(err: &io::Error) -> bool { } static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); +static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); + +#[cfg(test)] +const QUOTA_USER_LOCKS_MAX: usize = 64; +#[cfg(not(test))] +const QUOTA_USER_LOCKS_MAX: usize = 4_096; +#[cfg(test)] +const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; +#[cfg(not(test))] +const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; + +#[cfg(test)] +fn quota_user_lock_test_guard() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +fn quota_overflow_user_lock(user: &str) -> Arc> { + let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { + (0..QUOTA_OVERFLOW_LOCK_STRIPES) + .map(|_| Arc::new(Mutex::new(()))) + .collect() + }); + + let hash = crc32fast::hash(user.as_bytes()) as usize; + Arc::clone(&stripes[hash % stripes.len()]) +} fn quota_user_lock(user: &str) -> Arc> { let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); @@ -270,6 +297,14 @@ fn quota_user_lock(user: &str) -> Arc> { return Arc::clone(existing.value()); } + if locks.len() >= QUOTA_USER_LOCKS_MAX { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } + + if locks.len() >= QUOTA_USER_LOCKS_MAX { + return quota_overflow_user_lock(user); + } + let created = Arc::new(Mutex::new(())); match locks.entry(user.to_string()) { dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), @@ -662,4 +697,8 @@ mod security_tests; #[cfg(test)] #[path = "tests/relay_adversarial_tests.rs"] -mod adversarial_tests; \ No newline at end of file +mod adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"] +mod relay_quota_lock_pressure_adversarial_tests; \ No newline at end of file diff --git a/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs index dfd0c55..6ac02dd 100644 --- a/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs +++ b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs @@ -249,7 +249,7 @@ async fn run_blackhat_client_handler_fragmented_probe_should_mask( } #[tokio::test] -async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_leaks() { +async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() { let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let mask_addr = mask_listener.local_addr().unwrap(); let backend_reply = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(); @@ -309,7 +309,7 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_ client_side.shutdown().await.unwrap(); // Security expectation: even malformed in-range TLS should be masked. - // Current code leaks by returning EOF/timeout instead of masking. + // This invariant must hold to avoid probe-distinguishable EOF/timeout behavior. let mut observed = vec![0u8; backend_reply.len()]; tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) .await @@ -329,7 +329,7 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_ } #[tokio::test] -async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask_but_leaks() { +async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask() { let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let mask_addr = mask_listener.local_addr().unwrap(); @@ -429,7 +429,7 @@ async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask_but_ } #[tokio::test] -async fn blackhat_generic_truncated_min_body_1_should_mask_but_leaks() { +async fn blackhat_generic_truncated_min_body_1_should_mask() { run_blackhat_generic_fragmented_probe_should_mask( truncated_in_range_record(1), &[6], @@ -440,7 +440,7 @@ async fn blackhat_generic_truncated_min_body_1_should_mask_but_leaks() { } #[tokio::test] -async fn blackhat_generic_truncated_min_body_8_should_mask_but_leaks() { +async fn blackhat_generic_truncated_min_body_8_should_mask() { run_blackhat_generic_fragmented_probe_should_mask( truncated_in_range_record(8), &[13], @@ -451,7 +451,7 @@ async fn blackhat_generic_truncated_min_body_8_should_mask_but_leaks() { } #[tokio::test] -async fn blackhat_generic_truncated_min_body_99_should_mask_but_leaks() { +async fn blackhat_generic_truncated_min_body_99_should_mask() { run_blackhat_generic_fragmented_probe_should_mask( truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1), &[5, MIN_TLS_CLIENT_HELLO_SIZE - 1], @@ -462,7 +462,7 @@ async fn blackhat_generic_truncated_min_body_99_should_mask_but_leaks() { } #[tokio::test] -async fn blackhat_generic_fragmented_header_then_close_should_mask_but_leaks() { +async fn blackhat_generic_fragmented_header_then_close_should_mask() { run_blackhat_generic_fragmented_probe_should_mask( truncated_in_range_record(0), &[1, 1, 1, 1, 1], @@ -473,7 +473,7 @@ async fn blackhat_generic_fragmented_header_then_close_should_mask_but_leaks() { } #[tokio::test] -async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask_but_leaks() { +async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask() { run_blackhat_generic_fragmented_probe_should_mask( truncated_in_range_record(5), &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], @@ -495,7 +495,7 @@ async fn blackhat_generic_slowloris_fragmented_min_probe_should_mask_but_times_o } #[tokio::test] -async fn blackhat_client_handler_truncated_min_body_1_should_mask_but_leaks() { +async fn blackhat_client_handler_truncated_min_body_1_should_mask() { run_blackhat_client_handler_fragmented_probe_should_mask( truncated_in_range_record(1), &[6], @@ -506,7 +506,7 @@ async fn blackhat_client_handler_truncated_min_body_1_should_mask_but_leaks() { } #[tokio::test] -async fn blackhat_client_handler_truncated_min_body_8_should_mask_but_leaks() { +async fn blackhat_client_handler_truncated_min_body_8_should_mask() { run_blackhat_client_handler_fragmented_probe_should_mask( truncated_in_range_record(8), &[13], @@ -517,7 +517,7 @@ async fn blackhat_client_handler_truncated_min_body_8_should_mask_but_leaks() { } #[tokio::test] -async fn blackhat_client_handler_truncated_min_body_99_should_mask_but_leaks() { +async fn blackhat_client_handler_truncated_min_body_99_should_mask() { run_blackhat_client_handler_fragmented_probe_should_mask( truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1), &[5, MIN_TLS_CLIENT_HELLO_SIZE - 1], @@ -528,7 +528,7 @@ async fn blackhat_client_handler_truncated_min_body_99_should_mask_but_leaks() { } #[tokio::test] -async fn blackhat_client_handler_fragmented_header_then_close_should_mask_but_leaks() { +async fn blackhat_client_handler_fragmented_header_then_close_should_mask() { run_blackhat_client_handler_fragmented_probe_should_mask( truncated_in_range_record(0), &[1, 1, 1, 1, 1], @@ -539,7 +539,7 @@ async fn blackhat_client_handler_fragmented_header_then_close_should_mask_but_le } #[tokio::test] -async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask_but_leaks() { +async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask() { run_blackhat_client_handler_fragmented_probe_should_mask( truncated_in_range_record(5), &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], diff --git a/src/proxy/tests/direct_relay_business_logic_tests.rs b/src/proxy/tests/direct_relay_business_logic_tests.rs new file mode 100644 index 0000000..166518e --- /dev/null +++ b/src/proxy/tests/direct_relay_business_logic_tests.rs @@ -0,0 +1,51 @@ +use super::*; +use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4, TG_DATACENTERS_V6}; +use std::net::SocketAddr; + +#[test] +fn business_scope_hint_accepts_exact_boundary_length() { + let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN)); + assert_eq!(validated_scope_hint(&value), Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); +} + +#[test] +fn business_scope_hint_rejects_missing_prefix_even_when_charset_is_valid() { + assert_eq!(validated_scope_hint("alpha-01"), None); +} + +#[test] +fn business_known_dc_uses_ipv4_table_by_default() { + let cfg = ProxyConfig::default(); + let resolved = get_dc_addr_static(2, &cfg).expect("known dc must resolve"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_negative_dc_maps_by_absolute_value() { + let cfg = ProxyConfig::default(); + let resolved = get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_known_dc_uses_ipv6_table_when_preferred_and_enabled() { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + + let resolved = get_dc_addr_static(1, &cfg).expect("known dc must resolve on ipv6 path"); + let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_unknown_dc_uses_configured_default_dc_when_in_range() { + let mut cfg = ProxyConfig::default(); + cfg.default_dc = Some(4); + + let resolved = get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} diff --git a/src/proxy/tests/direct_relay_common_mistakes_tests.rs b/src/proxy/tests/direct_relay_common_mistakes_tests.rs new file mode 100644 index 0000000..ef40f37 --- /dev/null +++ b/src/proxy/tests/direct_relay_common_mistakes_tests.rs @@ -0,0 +1,98 @@ +use super::*; +use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4}; +use std::collections::HashSet; +use std::net::SocketAddr; +use std::sync::Mutex; + +#[test] +fn common_invalid_override_entries_fallback_to_static_table() { + let mut cfg = ProxyConfig::default(); + cfg.dc_overrides.insert( + "2".to_string(), + vec!["bad-address".to_string(), "still-bad".to_string()], + ); + + let resolved = get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it() { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.dc_overrides + .insert("3".to_string(), vec!["203.0.113.203:443".to_string()]); + + let resolved = get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists"); + assert_eq!(resolved, "203.0.113.203:443".parse::().unwrap()); +} + +#[test] +fn common_scope_hint_rejects_unicode_lookalike_characters() { + assert_eq!(validated_scope_hint("scope_аlpha"), None); + assert_eq!(validated_scope_hint("scope_Αlpha"), None); +} + +#[cfg(unix)] +#[test] +fn common_anchored_open_rejects_nul_filename() { + use std::os::unix::ffi::OsStringExt; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-direct-relay-nul-{}", std::process::id())); + std::fs::create_dir_all(&parent).expect("parent directory must be creatable"); + + let path = SanitizedUnknownDcLogPath { + resolved_path: parent.join("placeholder.log"), + allowed_parent: parent, + file_name: std::ffi::OsString::from_vec(vec![b'a', 0, b'b']), + }; + + let err = open_unknown_dc_log_append_anchored(&path) + .expect_err("anchored open must fail on NUL in filename"); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); +} + +#[cfg(unix)] +#[test] +fn common_anchored_open_creates_owner_only_file_permissions() { + use std::os::unix::fs::PermissionsExt; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-direct-relay-perm-{}", std::process::id())); + std::fs::create_dir_all(&parent).expect("parent directory must be creatable"); + + let sanitized = SanitizedUnknownDcLogPath { + resolved_path: parent.join("unknown-dc.log"), + allowed_parent: parent.clone(), + file_name: std::ffi::OsString::from("unknown-dc.log"), + }; + + let mut file = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must create regular file"); + use std::io::Write; + writeln!(file, "dc_idx=1").expect("write must succeed"); + + let mode = std::fs::metadata(parent.join("unknown-dc.log")) + .expect("metadata must be readable") + .permissions() + .mode() + & 0o777; + assert_eq!(mode, 0o600); +} + +#[test] +fn common_duplicate_dc_attempts_do_not_consume_unique_slots() { + let set = Mutex::new(HashSet::new()); + + assert!(should_log_unknown_dc_with_set(&set, 100)); + assert!(!should_log_unknown_dc_with_set(&set, 100)); + assert!(should_log_unknown_dc_with_set(&set, 101)); + assert_eq!(set.lock().expect("set lock must be available").len(), 2); +} diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index e8016a5..7c3a51e 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -667,6 +667,56 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() { ); } +#[cfg(unix)] +#[test] +fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-parent-swap-openat-{}", std::process::id())); + fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-parent-swap-openat-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent swap"); + fs::write(&sanitized.resolved_path, "seed\n").expect("seed target file must be writable"); + + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "precondition: target should initially pass revalidation" + ); + + let outside_parent = std::env::temp_dir().join(format!( + "telemt-unknown-dc-parent-swap-openat-outside-{}", + std::process::id() + )); + fs::create_dir_all(&outside_parent).expect("outside parent directory must be creatable"); + let outside_target = outside_parent.join("unknown-dc.log"); + let _ = fs::remove_file(&outside_target); + + let moved = base.with_extension("bak"); + let _ = fs::remove_dir_all(&moved); + fs::rename(&base, &moved).expect("base parent must be movable for swap simulation"); + symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable"); + + let err = open_unknown_dc_log_append_anchored(&sanitized) + .expect_err("anchored open must fail when parent is swapped to symlink"); + let raw = err.raw_os_error(); + assert!( + matches!(raw, Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)), + "anchored open must fail closed on parent swap race, got raw_os_error={raw:?}" + ); + assert!( + !outside_target.exists(), + "anchored open must never create a log file in swapped outside parent" + ); +} + #[tokio::test] async fn unknown_dc_absolute_log_path_writes_one_entry() { let _guard = unknown_dc_test_lock() diff --git a/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs b/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs new file mode 100644 index 0000000..5cbbc68 --- /dev/null +++ b/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs @@ -0,0 +1,197 @@ +use super::*; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +fn nonempty_line_count(text: &str) -> usize { + text.lines().filter(|line| !line.trim().is_empty()).count() +} + +#[test] +fn subtle_stress_single_unknown_dc_under_concurrency_logs_once() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let winners = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + + for _ in 0..128 { + let winners = Arc::clone(&winners); + workers.push(std::thread::spawn(move || { + if should_log_unknown_dc(31_333) { + winners.fetch_add(1, Ordering::Relaxed); + } + })); + } + + for worker in workers { + worker.join().expect("worker must not panic"); + } + + assert_eq!(winners.load(Ordering::Relaxed), 1); +} + +#[test] +fn subtle_light_fuzz_scope_hint_matches_oracle() { + fn oracle(input: &str) -> bool { + let Some(rest) = input.strip_prefix("scope_") else { + return false; + }; + !rest.is_empty() + && rest.len() <= MAX_SCOPE_HINT_LEN + && rest + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + } + + let mut state: u64 = 0xC0FF_EE11_D15C_AFE5; + for _ in 0..4_096 { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + + let len = (state as usize % 72) + 1; + let mut s = String::with_capacity(len + 6); + if (state & 1) == 0 { + s.push_str("scope_"); + } else { + s.push_str("user_"); + } + + for idx in 0..len { + let v = ((state >> ((idx % 8) * 8)) & 0xff) as u8; + let ch = match v % 6 { + 0 => (b'a' + (v % 26)) as char, + 1 => (b'A' + (v % 26)) as char, + 2 => (b'0' + (v % 10)) as char, + 3 => '-', + 4 => '_', + _ => '.', + }; + s.push(ch); + } + + let got = validated_scope_hint(&s).is_some(); + assert_eq!(got, oracle(&s), "mismatch for input: {s}"); + } +} + +#[test] +fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() { + let mut state: u64 = 0x1234_5678_9ABC_DEF0; + + for _ in 0..2_048 { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = if (state & 1) == 0 { 4 } else { 6 }; + cfg.network.ipv6 = Some((state & 2) != 0); + cfg.default_dc = Some(((state >> 8) as u8).max(1)); + + let dc_idx = (state as i16).wrapping_sub(16_384); + let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail"); + + assert_eq!(resolved.port(), crate::protocol::constants::TG_DATACENTER_PORT); + let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true); + assert_eq!(resolved.is_ipv6(), expect_v6); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn subtle_integration_parallel_same_dc_logs_one_line() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-direct-relay-same-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = std::fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let cfg = Arc::new(cfg); + let mut tasks = Vec::new(); + for _ in 0..32 { + let cfg = Arc::clone(&cfg); + tasks.push(tokio::spawn(async move { + let _ = get_dc_addr_static(31_777, cfg.as_ref()); + })); + } + for task in tasks { + task.await.expect("task must not panic"); + } + + for _ in 0..60 { + if let Ok(content) = std::fs::read_to_string(&abs_file) + && nonempty_line_count(&content) == 1 + { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + + let content = std::fs::read_to_string(&abs_file).unwrap_or_default(); + assert_eq!(nonempty_line_count(&content), 1); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn subtle_integration_parallel_unique_dcs_log_unique_lines() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-direct-relay-unique-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = std::fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let cfg = Arc::new(cfg); + let dcs = [31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908]; + let mut tasks = Vec::new(); + + for dc in dcs { + let cfg = Arc::clone(&cfg); + tasks.push(tokio::spawn(async move { + let _ = get_dc_addr_static(dc, cfg.as_ref()); + })); + } + + for task in tasks { + task.await.expect("task must not panic"); + } + + for _ in 0..80 { + if let Ok(content) = std::fs::read_to_string(&abs_file) + && nonempty_line_count(&content) >= 8 + { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + + let content = std::fs::read_to_string(&abs_file).unwrap_or_default(); + assert!( + nonempty_line_count(&content) >= 8, + "expected at least one line per unique dc, content: {content}" + ); +} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs new file mode 100644 index 0000000..fd8fb2f --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs @@ -0,0 +1,409 @@ +use super::*; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::time::Duration; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Barrier; +use tokio::time::Instant; + +#[test] +fn quota_lock_same_user_returns_same_arc_instance() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let a = quota_user_lock("quota-lock-same-user"); + let b = quota_user_lock("quota-lock-same-user"); + assert!(Arc::ptr_eq(&a, &b)); +} + +#[test] +fn quota_lock_parallel_same_user_reuses_single_lock() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let user = "quota-lock-parallel-same"; + let mut handles = Vec::new(); + + for _ in 0..64 { + handles.push(std::thread::spawn(move || quota_user_lock(user))); + } + + let first = handles + .remove(0) + .join() + .expect("thread must return lock handle"); + + for handle in handles { + let got = handle.join().expect("thread must return lock handle"); + assert!(Arc::ptr_eq(&first, &got)); + } +} + +#[test] +fn quota_lock_unique_users_materialize_distinct_entries() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + + map.clear(); + + let base = format!("quota-lock-distinct-{}", std::process::id()); + let users: Vec = (0..(QUOTA_USER_LOCKS_MAX / 2)) + .map(|idx| format!("{base}-{idx}")) + .collect(); + + for user in &users { + let _ = quota_user_lock(user); + } + + for user in &users { + assert!(map.get(user).is_some(), "lock cache must contain entry for {user}"); + } +} + +#[test] +fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + + map.clear(); + + let base = format!("quota-lock-churn-{}", std::process::id()); + for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) { + let _ = quota_user_lock(&format!("{base}-{idx}")); + } + + assert!( + map.len() <= QUOTA_USER_LOCKS_MAX, + "quota lock cache must stay bounded under unique-user churn" + ); +} + +#[test] +fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let prefix = format!("quota-held-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "cache must be saturated for overflow check" + ); + + let overflow_user = format!("quota-overflow-{}", std::process::id()); + let overflow_a = quota_user_lock(&overflow_user); + let overflow_b = quota_user_lock(&overflow_user); + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow path must not grow lock cache" + ); + assert!( + map.get(&overflow_user).is_none(), + "overflow user lock must stay outside bounded cache under saturation" + ); + assert!( + Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user must receive stable striped overflow lock while saturated" + ); + + drop(retained); +} + +#[test] +fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + // Fill and immediately drop strong references, leaving only map-owned Arcs. + for idx in 0..QUOTA_USER_LOCKS_MAX { + let _ = quota_user_lock(&format!("quota-reclaim-drop-{}-{idx}", std::process::id())); + } + assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); + + let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); + let overflow = quota_user_lock(&overflow_user); + + assert!( + map.get(&overflow_user).is_some(), + "after reclaiming stale entries, overflow user should become cacheable" + ); + assert!( + Arc::strong_count(&overflow) >= 2, + "cacheable overflow lock should be held by both map and caller" + ); +} + +#[test] +fn quota_lock_saturated_same_user_must_not_return_distinct_locks() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-saturated-held-{}-{idx}", std::process::id()))); + } + + let overflow_user = format!("quota-saturated-same-user-{}", std::process::id()); + let a = quota_user_lock(&overflow_user); + let b = quota_user_lock(&overflow_user); + + assert!( + Arc::ptr_eq(&a, &b), + "same user must not receive distinct locks under saturation because that enables quota race bypass" + ); + + drop(retained); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-saturated-race-held-{}-{idx}", std::process::id()))); + } + + let stats = Arc::new(Stats::new()); + let user = format!("quota-saturated-race-user-{}", std::process::id()); + let gate = Arc::new(Barrier::new(2)); + + let worker = |label: u8, stats: Arc, user: String, gate: Arc| { + tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(1), + quota_exceeded, + Instant::now(), + ); + gate.wait().await; + io.write_all(&[label]).await + }) + }; + + let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); + let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); + + let _ = tokio::time::timeout(Duration::from_secs(2), async { + let _ = one.await.expect("task one must not panic"); + let _ = two.await.expect("task two must not panic"); + }) + .await + .expect("quota race workers must complete"); + + assert!( + stats.get_user_total_octets(&user) <= 1, + "saturated lock path must never overshoot quota for same user" + ); + + drop(retained); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-saturated-stress-held-{}-{idx}", std::process::id()))); + } + + for round in 0..128u32 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id()); + let gate = Arc::new(Barrier::new(2)); + + let one = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let gate = Arc::clone(&gate); + tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(1), + quota_exceeded, + Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x31]).await + }) + }; + + let two = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let gate = Arc::clone(&gate); + tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(1), + quota_exceeded, + Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x32]).await + }) + }; + + let _ = one.await.expect("stress task one must not panic"); + let _ = two.await.expect("stress task two must not panic"); + + assert!( + stats.get_user_total_octets(&user) <= 1, + "round {round}: saturated path must not overshoot quota" + ); + } + + drop(retained); +} + +#[test] +fn quota_error_classifier_accepts_internal_quota_sentinel_only() { + let err = quota_io_error(); + assert!(is_quota_io_error(&err)); +} + +#[test] +fn quota_error_classifier_rejects_plain_permission_denied() { + let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied"); + assert!(!is_quota_io_error(&err)); +} + +#[tokio::test] +async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() { + let stats = Arc::new(Stats::new()); + let user = "quota-zero-user"; + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(0), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(b"x") + .await + .expect("client write must succeed"); + + let mut probe = [0u8; 1]; + let forwarded = tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; + if let Ok(Ok(n)) = forwarded { + assert_eq!(n, 0, "zero quota path must not forward payload bytes"); + } + + let result = tokio::time::timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under zero quota") + .expect("relay task must not panic"); + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() { + let stats = Arc::new(Stats::new()); + + let (mut client_peer, relay_client) = duplex(8192); + let (relay_server, mut server_peer) = duplex(8192); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "quota-none-burst-user", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + let c2s = vec![0xA5; 2048]; + let s2c = vec![0x5A; 1536]; + + client_peer.write_all(&c2s).await.expect("client burst write must succeed"); + let mut got_c2s = vec![0u8; c2s.len()]; + server_peer.read_exact(&mut got_c2s).await.expect("server must receive c2s burst"); + assert_eq!(got_c2s, c2s); + + server_peer.write_all(&s2c).await.expect("server burst write must succeed"); + let mut got_s2c = vec![0u8; s2c.len()]; + client_peer.read_exact(&mut got_s2c).await.expect("client must receive s2c burst"); + assert_eq!(got_s2c, s2c); + + drop(client_peer); + drop(server_peer); + + let done = tokio::time::timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate after peers close") + .expect("relay task must not panic"); + assert!(done.is_ok()); +} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs index 4b002a4..c7aa918 100644 --- a/src/proxy/tests/relay_security_tests.rs +++ b/src/proxy/tests/relay_security_tests.rs @@ -31,6 +31,12 @@ impl std::task::Wake for WakeCounter { #[tokio::test] async fn quota_lock_contention_does_not_self_wake_pending_writer() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + map.clear(); + let stats = Arc::new(Stats::new()); let user = "quota-lock-contention-user"; @@ -66,6 +72,12 @@ async fn quota_lock_contention_does_not_self_wake_pending_writer() { #[tokio::test] async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + map.clear(); + let stats = Arc::new(Stats::new()); let user = "quota-lock-writer-liveness-user"; @@ -133,6 +145,12 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_ #[tokio::test] async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { + let _guard = super::quota_user_lock_test_guard() + .lock() + .expect("quota lock test guard must be available"); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + map.clear(); + let stats = Arc::new(Stats::new()); let user = "quota-lock-read-liveness-user"; diff --git a/src/tests/ip_tracker_hotpath_adversarial_tests.rs b/src/tests/ip_tracker_hotpath_adversarial_tests.rs new file mode 100644 index 0000000..53c4123 --- /dev/null +++ b/src/tests/ip_tracker_hotpath_adversarial_tests.rs @@ -0,0 +1,168 @@ +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::Duration; + +use crate::config::UserMaxUniqueIpsMode; +use crate::ip_tracker::UserIpTracker; + +fn ip_from_idx(idx: u32) -> IpAddr { + IpAddr::V4(Ipv4Addr::new(10, ((idx >> 16) & 0xff) as u8, ((idx >> 8) & 0xff) as u8, (idx & 0xff) as u8)) +} + +#[tokio::test] +async fn hotpath_empty_drain_is_idempotent() { + let tracker = UserIpTracker::new(); + for _ in 0..128 { + tracker.drain_cleanup_queue().await; + } + assert_eq!(tracker.get_active_ip_count("none").await, 0); +} + +#[tokio::test] +async fn hotpath_batch_cleanup_drain_clears_all_active_entries() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("u", 100).await; + + for idx in 0..32 { + let ip = ip_from_idx(idx); + tracker.check_and_add("u", ip).await.unwrap(); + tracker.enqueue_cleanup("u".to_string(), ip); + } + + tracker.drain_cleanup_queue().await; + assert_eq!(tracker.get_active_ip_count("u").await, 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn hotpath_parallel_enqueue_and_drain_does_not_deadlock() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("p", 64).await; + + let mut tasks = Vec::new(); + for worker in 0..32u32 { + let t = tracker.clone(); + tasks.push(tokio::spawn(async move { + let ip = ip_from_idx(1_000 + worker); + for _ in 0..64 { + let _ = t.check_and_add("p", ip).await; + t.enqueue_cleanup("p".to_string(), ip); + t.drain_cleanup_queue().await; + } + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(3), task) + .await + .expect("worker must not deadlock") + .expect("worker task must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn hotpath_parallel_unique_ip_limit_never_exceeds_cap() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("limit", 5).await; + + let mut tasks = Vec::new(); + for idx in 0..64u32 { + let t = tracker.clone(); + tasks.push(tokio::spawn(async move { t.check_and_add("limit", ip_from_idx(idx)).await.is_ok() })); + } + + let mut admitted = 0usize; + for task in tasks { + if task.await.expect("task must not panic") { + admitted += 1; + } + } + + assert!(admitted <= 5, "admitted unique IPs must not exceed configured cap"); + assert!(tracker.get_active_ip_count("limit").await <= 5); +} + +#[tokio::test] +async fn hotpath_repeated_same_ip_counter_balances_to_zero() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("same", 1).await; + let ip = ip_from_idx(77); + + for _ in 0..512 { + tracker.check_and_add("same", ip).await.unwrap(); + } + for _ in 0..512 { + tracker.remove_ip("same", ip).await; + } + + assert_eq!(tracker.get_active_ip_count("same").await, 0); +} + +#[tokio::test] +async fn hotpath_light_fuzz_mixed_operations_preserve_limit_invariants() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("fuzz", 4).await; + + let mut state: u64 = 0xA55A_5AA5_D15C_B00B; + for _ in 0..4_000 { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + + let ip = ip_from_idx((state as u32) % 8); + match state & 0x3 { + 0 | 1 => { + let _ = tracker.check_and_add("fuzz", ip).await; + } + _ => { + tracker.remove_ip("fuzz", ip).await; + } + } + + assert!( + tracker.get_active_ip_count("fuzz").await <= 4, + "active count must stay within configured cap" + ); + } +} + +#[tokio::test] +async fn hotpath_multi_user_churn_keeps_isolation() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("u1", 2).await; + tracker.set_user_limit("u2", 3).await; + + for idx in 0..200u32 { + let ip1 = ip_from_idx(idx % 5); + let ip2 = ip_from_idx(100 + (idx % 7)); + let _ = tracker.check_and_add("u1", ip1).await; + let _ = tracker.check_and_add("u2", ip2).await; + if idx % 2 == 0 { + tracker.remove_ip("u1", ip1).await; + } + if idx % 3 == 0 { + tracker.remove_ip("u2", ip2).await; + } + } + + assert!(tracker.get_active_ip_count("u1").await <= 2); + assert!(tracker.get_active_ip_count("u2").await <= 3); +} + +#[tokio::test] +async fn hotpath_time_window_expiry_allows_new_ip_after_window() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("tw", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) + .await; + + let ip1 = ip_from_idx(901); + let ip2 = ip_from_idx(902); + + tracker.check_and_add("tw", ip1).await.unwrap(); + tracker.remove_ip("tw", ip1).await; + assert!(tracker.check_and_add("tw", ip2).await.is_err()); + + tokio::time::sleep(Duration::from_millis(1_100)).await; + assert!(tracker.check_and_add("tw", ip2).await.is_ok()); +}