Refactor and enhance tests for proxy and relay functionality

- Renamed test functions in `client_tls_clienthello_truncation_adversarial_tests.rs` to remove "but_leaks" suffix for clarity.
- Added new tests in `direct_relay_business_logic_tests.rs` to validate business logic for data center resolution and scope hints.
- Introduced tests in `direct_relay_common_mistakes_tests.rs` to cover common mistakes in direct relay configurations.
- Added security tests in `direct_relay_security_tests.rs` to ensure proper handling of symlink and parent swap scenarios.
- Created `direct_relay_subtle_adversarial_tests.rs` to stress test concurrent logging and validate scope hint behavior.
- Implemented `relay_quota_lock_pressure_adversarial_tests.rs` to test quota lock behavior under high contention and stress.
- Updated `relay_security_tests.rs` to include quota lock contention tests ensuring proper behavior under concurrent access.
- Introduced `ip_tracker_hotpath_adversarial_tests.rs` to validate the performance and correctness of the IP tracking logic under various scenarios.
This commit is contained in:
David Osipov 2026-03-21 13:38:17 +04:00
parent 8188fedf6a
commit 5933b5e821
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
13 changed files with 1138 additions and 21 deletions

View File

@ -9,6 +9,9 @@ mod ip_tracker;
#[cfg(test)]
#[path = "tests/ip_tracker_regression_tests.rs"]
mod ip_tracker_regression_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests;
mod maestro;
mod metrics;
mod network;

View File

@ -307,9 +307,8 @@ fn extract_sni_with_duplicate_extensions_rejected() {
h.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
h.extend_from_slice(&handshake);
// Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups.
// Telemt's `extract_sni` returns the first one found.
assert!(extract_sni_from_client_hello(&h).is_some());
// Duplicate SNI extensions are ambiguous and must fail closed.
assert!(extract_sni_from_client_hello(&h).is_none());
}
#[test]

View File

@ -588,6 +588,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
return None;
}
let mut saw_sni_extension = false;
let mut extracted_sni = None;
while pos + 4 <= ext_end {
let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]);
let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize;
@ -595,6 +598,12 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
if pos + elen > ext_end {
break;
}
if etype == 0x0000 {
if saw_sni_extension {
return None;
}
saw_sni_extension = true;
}
if etype == 0x0000 && elen >= 5 {
// server_name extension
let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
@ -611,7 +620,8 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
{
if is_valid_sni_hostname(host) {
return Some(host.to_string());
extracted_sni = Some(host.to_string());
break;
}
}
sn_pos += name_len;
@ -620,7 +630,7 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
pos += elen;
}
None
extracted_sni
}
fn is_valid_sni_hostname(host: &str) -> bool {

View File

@ -27,6 +27,10 @@ use crate::transport::UpstreamManager;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, FromRawFd};
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
@ -136,6 +140,7 @@ fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
true
}
#[cfg(test)]
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
#[cfg(unix)]
{
@ -155,6 +160,64 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
}
}
fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std::io::Result<std::fs::File> {
    // Open the unknown-DC log for append without ever following a symlink:
    // pin the sanitized parent directory first, then create/open the file
    // relative to that directory fd via openat(2). This defeats parent-swap
    // races where the directory is replaced between validation and open.
    #[cfg(unix)]
    {
        // O_DIRECTORY guarantees the parent really is a directory, and
        // O_NOFOLLOW makes the open fail (ELOOP) if it was swapped for a
        // symlink after the earlier safety check.
        let parent = OpenOptions::new()
            .read(true)
            .custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC)
            .open(&path.allowed_parent)?;
        // openat() needs a NUL-terminated name; an embedded NUL in the file
        // name is rejected up front as InvalidInput.
        let file_name = std::ffi::CString::new(path.file_name.as_os_str().as_bytes())
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "unknown DC log file name contains NUL byte"))?;
        // SAFETY: `parent` keeps its descriptor alive for the duration of the
        // call and `file_name` is a valid NUL-terminated C string.
        let fd = unsafe {
            libc::openat(
                parent.as_raw_fd(),
                file_name.as_ptr(),
                libc::O_CREAT | libc::O_APPEND | libc::O_WRONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC,
                0o600,
            )
        };
        if fd < 0 {
            return Err(std::io::Error::last_os_error());
        }
        // SAFETY: `fd` is a freshly opened, owned descriptor, checked >= 0 above.
        let file = unsafe { std::fs::File::from_raw_fd(fd) };
        Ok(file)
    }
    #[cfg(not(unix))]
    {
        // Without openat/O_NOFOLLOW semantics a safe open is impossible: fail closed.
        let _ = path;
        Err(std::io::Error::new(
            std::io::ErrorKind::PermissionDenied,
            "unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
        ))
    }
}
fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> {
    // Append a single "dc_idx=N" line. On unix the write is serialized across
    // processes with an exclusive flock(2) so concurrent writers cannot
    // interleave partial lines.
    #[cfg(unix)]
    {
        if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX) } != 0 {
            return Err(std::io::Error::last_os_error());
        }
        let write_result = writeln!(file, "dc_idx={dc_idx}");
        // Always release the lock, even when the write failed. Note that an
        // unlock failure is returned in preference to a write failure.
        if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) } != 0 {
            return Err(std::io::Error::last_os_error());
        }
        write_result
    }
    #[cfg(not(unix))]
    {
        // No flock available: best-effort unsynchronized append.
        writeln!(file, "dc_idx={dc_idx}")
    }
}
#[cfg(test)]
fn clear_unknown_dc_log_cache_for_testing() {
if let Some(set) = LOGGED_UNKNOWN_DCS.get()
@ -321,9 +384,9 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
if should_log_unknown_dc(dc_idx) {
handle.spawn_blocking(move || {
if unknown_dc_log_path_is_still_safe(&path)
&& let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path)
&& let Ok(mut file) = open_unknown_dc_log_append_anchored(&path)
{
let _ = writeln!(file, "dc_idx={dc_idx}");
let _ = append_unknown_dc_line(&mut file, dc_idx);
}
});
}
@ -394,3 +457,15 @@ where
#[cfg(test)]
#[path = "tests/direct_relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/direct_relay_business_logic_tests.rs"]
mod business_logic_tests;
#[cfg(test)]
#[path = "tests/direct_relay_common_mistakes_tests.rs"]
mod common_mistakes_tests;
#[cfg(test)]
#[path = "tests/direct_relay_subtle_adversarial_tests.rs"]
mod subtle_adversarial_tests;

View File

@ -263,6 +263,33 @@ fn is_quota_io_error(err: &io::Error) -> bool {
}
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
    // Process-wide serialization point for tests that mutate the shared
    // quota-lock cache; created lazily on first use and never dropped.
    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    let guard = TEST_LOCK.get_or_init(Mutex::default);
    guard
}
fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
    // Fallback lock used when the per-user cache is saturated: the user is
    // hashed onto a fixed set of lock stripes, so the same user always maps
    // to the same stripe while total lock count stays bounded.
    let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
        std::iter::repeat_with(|| Arc::new(Mutex::new(())))
            .take(QUOTA_OVERFLOW_LOCK_STRIPES)
            .collect()
    });
    let stripe_idx = crc32fast::hash(user.as_bytes()) as usize % stripes.len();
    Arc::clone(&stripes[stripe_idx])
}
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
@ -270,6 +297,14 @@ fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user);
}
let created = Arc::new(Mutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
@ -662,4 +697,8 @@ mod security_tests;
#[cfg(test)]
#[path = "tests/relay_adversarial_tests.rs"]
mod adversarial_tests;
mod adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"]
mod relay_quota_lock_pressure_adversarial_tests;

View File

@ -249,7 +249,7 @@ async fn run_blackhat_client_handler_fragmented_probe_should_mask(
}
#[tokio::test]
async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_leaks() {
async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let mask_addr = mask_listener.local_addr().unwrap();
let backend_reply = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec();
@ -309,7 +309,7 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_
client_side.shutdown().await.unwrap();
// Security expectation: even malformed in-range TLS should be masked.
// Current code leaks by returning EOF/timeout instead of masking.
// This invariant must hold to avoid probe-distinguishable EOF/timeout behavior.
let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed))
.await
@ -329,7 +329,7 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_
}
#[tokio::test]
async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask_but_leaks() {
async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask() {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let mask_addr = mask_listener.local_addr().unwrap();
@ -429,7 +429,7 @@ async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask_but_
}
#[tokio::test]
async fn blackhat_generic_truncated_min_body_1_should_mask_but_leaks() {
async fn blackhat_generic_truncated_min_body_1_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(1),
&[6],
@ -440,7 +440,7 @@ async fn blackhat_generic_truncated_min_body_1_should_mask_but_leaks() {
}
#[tokio::test]
async fn blackhat_generic_truncated_min_body_8_should_mask_but_leaks() {
async fn blackhat_generic_truncated_min_body_8_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(8),
&[13],
@ -451,7 +451,7 @@ async fn blackhat_generic_truncated_min_body_8_should_mask_but_leaks() {
}
#[tokio::test]
async fn blackhat_generic_truncated_min_body_99_should_mask_but_leaks() {
async fn blackhat_generic_truncated_min_body_99_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1),
&[5, MIN_TLS_CLIENT_HELLO_SIZE - 1],
@ -462,7 +462,7 @@ async fn blackhat_generic_truncated_min_body_99_should_mask_but_leaks() {
}
#[tokio::test]
async fn blackhat_generic_fragmented_header_then_close_should_mask_but_leaks() {
async fn blackhat_generic_fragmented_header_then_close_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(0),
&[1, 1, 1, 1, 1],
@ -473,7 +473,7 @@ async fn blackhat_generic_fragmented_header_then_close_should_mask_but_leaks() {
}
#[tokio::test]
async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask_but_leaks() {
async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(5),
&[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
@ -495,7 +495,7 @@ async fn blackhat_generic_slowloris_fragmented_min_probe_should_mask_but_times_o
}
#[tokio::test]
async fn blackhat_client_handler_truncated_min_body_1_should_mask_but_leaks() {
async fn blackhat_client_handler_truncated_min_body_1_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(1),
&[6],
@ -506,7 +506,7 @@ async fn blackhat_client_handler_truncated_min_body_1_should_mask_but_leaks() {
}
#[tokio::test]
async fn blackhat_client_handler_truncated_min_body_8_should_mask_but_leaks() {
async fn blackhat_client_handler_truncated_min_body_8_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(8),
&[13],
@ -517,7 +517,7 @@ async fn blackhat_client_handler_truncated_min_body_8_should_mask_but_leaks() {
}
#[tokio::test]
async fn blackhat_client_handler_truncated_min_body_99_should_mask_but_leaks() {
async fn blackhat_client_handler_truncated_min_body_99_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1),
&[5, MIN_TLS_CLIENT_HELLO_SIZE - 1],
@ -528,7 +528,7 @@ async fn blackhat_client_handler_truncated_min_body_99_should_mask_but_leaks() {
}
#[tokio::test]
async fn blackhat_client_handler_fragmented_header_then_close_should_mask_but_leaks() {
async fn blackhat_client_handler_fragmented_header_then_close_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(0),
&[1, 1, 1, 1, 1],
@ -539,7 +539,7 @@ async fn blackhat_client_handler_fragmented_header_then_close_should_mask_but_le
}
#[tokio::test]
async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask_but_leaks() {
async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(5),
&[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],

View File

@ -0,0 +1,51 @@
use super::*;
use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4, TG_DATACENTERS_V6};
use std::net::SocketAddr;
#[test]
fn business_scope_hint_accepts_exact_boundary_length() {
    // A hint whose suffix is exactly MAX_SCOPE_HINT_LEN characters must be
    // accepted. Derive the expected value from the constant instead of a
    // hard-coded 64-character literal, so the test keeps tracking
    // MAX_SCOPE_HINT_LEN if that limit ever changes.
    let suffix = "a".repeat(MAX_SCOPE_HINT_LEN);
    let value = format!("scope_{suffix}");
    assert_eq!(validated_scope_hint(&value), Some(suffix.as_str()));
}
#[test]
fn business_scope_hint_rejects_missing_prefix_even_when_charset_is_valid() {
    // A valid character set alone is not enough: the mandatory "scope_"
    // prefix is missing, so validation must fail.
    let verdict = validated_scope_hint("alpha-01");
    assert!(verdict.is_none());
}
#[test]
fn business_known_dc_uses_ipv4_table_by_default() {
    // With a default config (no IPv6 preference), DC 2 must resolve from the
    // static IPv4 table on the standard Telegram datacenter port.
    let config = ProxyConfig::default();
    let want = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT);
    let got = get_dc_addr_static(2, &config).expect("known dc must resolve");
    assert_eq!(got, want);
}
#[test]
fn business_negative_dc_maps_by_absolute_value() {
    // DC index -3 is treated as DC 3, i.e. the third IPv4 table entry.
    let config = ProxyConfig::default();
    let want = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT);
    let got =
        get_dc_addr_static(-3, &config).expect("negative dc index must map by absolute value");
    assert_eq!(got, want);
}
#[test]
fn business_known_dc_uses_ipv6_table_when_preferred_and_enabled() {
    // When IPv6 is both preferred and enabled, DC 1 must come from the V6 table.
    let mut config = ProxyConfig::default();
    config.network.prefer = 6;
    config.network.ipv6 = Some(true);
    let want = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
    let got = get_dc_addr_static(1, &config).expect("known dc must resolve on ipv6 path");
    assert_eq!(got, want);
}
#[test]
fn business_unknown_dc_uses_configured_default_dc_when_in_range() {
    // An out-of-table index must fall back to the operator-configured default DC.
    let mut config = ProxyConfig::default();
    config.default_dc = Some(4);
    let want = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT);
    let got = get_dc_addr_static(29_999, &config)
        .expect("unknown dc must resolve to configured default");
    assert_eq!(got, want);
}

View File

@ -0,0 +1,98 @@
use super::*;
use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4};
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Mutex;
#[test]
fn common_invalid_override_entries_fallback_to_static_table() {
    // When every override entry for a DC is unparseable, resolution must fall
    // back to the built-in static table instead of erroring out.
    let mut config = ProxyConfig::default();
    config.dc_overrides.insert(
        "2".to_string(),
        vec!["bad-address".to_string(), "still-bad".to_string()],
    );
    let got =
        get_dc_addr_static(2, &config).expect("fallback to static table must still resolve");
    assert_eq!(got, SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT));
}
#[test]
fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it() {
    // Preferring IPv6 must not silently discard an explicit IPv4-only override.
    let mut config = ProxyConfig::default();
    config.network.prefer = 6;
    config.network.ipv6 = Some(true);
    config
        .dc_overrides
        .insert("3".to_string(), vec!["203.0.113.203:443".to_string()]);
    let got = get_dc_addr_static(3, &config)
        .expect("ipv4 override must be used if no ipv6 override exists");
    let want: SocketAddr = "203.0.113.203:443".parse().unwrap();
    assert_eq!(got, want);
}
#[test]
fn common_scope_hint_rejects_unicode_lookalike_characters() {
    // Both inputs contain confusable non-ASCII letters: the first uses a
    // Cyrillic "а" (U+0430) and the second a Greek capital "Α" (U+0391) in
    // place of the Latin letters they resemble. The validator must reject
    // them even though the strings visually look like plain ASCII hints.
    assert_eq!(validated_scope_hint("scope_аlpha"), None);
    assert_eq!(validated_scope_hint("scope_Αlpha"), None);
}
#[cfg(unix)]
#[test]
fn common_anchored_open_rejects_nul_filename() {
    use std::os::unix::ffi::OsStringExt;
    // Per-process scratch directory under target/ so parallel test binaries
    // do not collide.
    let parent = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-direct-relay-nul-{}", std::process::id()));
    std::fs::create_dir_all(&parent).expect("parent directory must be creatable");
    // Hand-construct a sanitized path whose file name embeds a NUL byte;
    // the CString conversion inside the anchored open must refuse it.
    let path = SanitizedUnknownDcLogPath {
        resolved_path: parent.join("placeholder.log"),
        allowed_parent: parent,
        file_name: std::ffi::OsString::from_vec(vec![b'a', 0, b'b']),
    };
    let err = open_unknown_dc_log_append_anchored(&path)
        .expect_err("anchored open must fail on NUL in filename");
    // Must surface as InvalidInput, not as a raw OS error from openat().
    assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
}
#[cfg(unix)]
#[test]
fn common_anchored_open_creates_owner_only_file_permissions() {
    use std::os::unix::fs::PermissionsExt;
    let parent = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-direct-relay-perm-{}", std::process::id()));
    std::fs::create_dir_all(&parent).expect("parent directory must be creatable");
    // Ensure the open actually exercises the O_CREAT path: a leftover file
    // from a previous run would keep its old mode and mask a regression in
    // the creation permissions.
    let log_path = parent.join("unknown-dc.log");
    let _ = std::fs::remove_file(&log_path);
    let sanitized = SanitizedUnknownDcLogPath {
        resolved_path: log_path.clone(),
        allowed_parent: parent,
        file_name: std::ffi::OsString::from("unknown-dc.log"),
    };
    let mut file = open_unknown_dc_log_append_anchored(&sanitized)
        .expect("anchored open must create regular file");
    use std::io::Write;
    writeln!(file, "dc_idx=1").expect("write must succeed");
    // openat() is passed mode 0o600; assuming the process umask does not mask
    // owner bits (the default for typical umasks), the file must be owner-only.
    let mode = std::fs::metadata(&log_path)
        .expect("metadata must be readable")
        .permissions()
        .mode()
        & 0o777;
    assert_eq!(mode, 0o600);
}
#[test]
fn common_duplicate_dc_attempts_do_not_consume_unique_slots() {
    // Re-logging an already-seen DC must be a no-op: only distinct indices
    // may take up slots in the dedup set.
    let seen = Mutex::new(HashSet::new());
    assert!(should_log_unknown_dc_with_set(&seen, 100));
    assert!(!should_log_unknown_dc_with_set(&seen, 100));
    assert!(should_log_unknown_dc_with_set(&seen, 101));
    let tracked = seen.lock().expect("set lock must be available").len();
    assert_eq!(tracked, 2);
}

View File

@ -667,6 +667,56 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() {
);
}
#[cfg(unix)]
#[test]
fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
    use std::os::unix::fs::symlink;
    // TOCTOU simulation: after the path passes sanitization and revalidation,
    // the parent directory is renamed away and replaced by a symlink pointing
    // outside the allowed tree. The anchored open must refuse to follow it
    // and must never create a file at the attacker-chosen target.
    let base = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-parent-swap-openat-{}", std::process::id()));
    fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable");
    let rel_candidate = format!(
        "target/telemt-unknown-dc-parent-swap-openat-{}/unknown-dc.log",
        std::process::id()
    );
    let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
        .expect("candidate must sanitize before parent swap");
    fs::write(&sanitized.resolved_path, "seed\n").expect("seed target file must be writable");
    // Precondition: before the swap, the sanitized path still revalidates.
    assert!(
        unknown_dc_log_path_is_still_safe(&sanitized),
        "precondition: target should initially pass revalidation"
    );
    // Attacker-controlled destination outside the allowed parent tree.
    let outside_parent = std::env::temp_dir().join(format!(
        "telemt-unknown-dc-parent-swap-openat-outside-{}",
        std::process::id()
    ));
    fs::create_dir_all(&outside_parent).expect("outside parent directory must be creatable");
    let outside_target = outside_parent.join("unknown-dc.log");
    let _ = fs::remove_file(&outside_target);
    // The swap: move the real parent aside, drop a symlink in its place.
    let moved = base.with_extension("bak");
    let _ = fs::remove_dir_all(&moved);
    fs::rename(&base, &moved).expect("base parent must be movable for swap simulation");
    symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable");
    let err = open_unknown_dc_log_append_anchored(&sanitized)
        .expect_err("anchored open must fail when parent is swapped to symlink");
    // O_NOFOLLOW on the directory open surfaces as ELOOP (or ENOTDIR/ENOENT
    // depending on platform/kernel ordering); any of these is fail-closed.
    let raw = err.raw_os_error();
    assert!(
        matches!(raw, Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)),
        "anchored open must fail closed on parent swap race, got raw_os_error={raw:?}"
    );
    assert!(
        !outside_target.exists(),
        "anchored open must never create a log file in swapped outside parent"
    );
}
#[tokio::test]
async fn unknown_dc_absolute_log_path_writes_one_entry() {
let _guard = unknown_dc_test_lock()

View File

@ -0,0 +1,197 @@
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
fn nonempty_line_count(text: &str) -> usize {
    // Count the lines that contain at least one non-whitespace character.
    let mut total = 0;
    for line in text.lines() {
        if line.chars().any(|c| !c.is_whitespace()) {
            total += 1;
        }
    }
    total
}
#[test]
fn subtle_stress_single_unknown_dc_under_concurrency_logs_once() {
    let _guard = unknown_dc_test_lock()
        .lock()
        .expect("unknown dc test lock must be available");
    clear_unknown_dc_log_cache_for_testing();
    // Race 128 threads on the same DC index; exactly one may win the slot.
    let winners = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..128)
        .map(|_| {
            let winners = Arc::clone(&winners);
            std::thread::spawn(move || {
                if should_log_unknown_dc(31_333) {
                    winners.fetch_add(1, Ordering::Relaxed);
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("worker must not panic");
    }
    assert_eq!(winners.load(Ordering::Relaxed), 1);
}
#[test]
fn subtle_light_fuzz_scope_hint_matches_oracle() {
    // Differential test: a tiny reference implementation of the scope-hint
    // contract ("scope_" prefix, non-empty suffix of at most
    // MAX_SCOPE_HINT_LEN bytes, ASCII alphanumerics or '-') is compared
    // against validated_scope_hint on pseudo-random inputs.
    fn oracle(input: &str) -> bool {
        let Some(rest) = input.strip_prefix("scope_") else {
            return false;
        };
        !rest.is_empty()
            && rest.len() <= MAX_SCOPE_HINT_LEN
            && rest
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-')
    }
    // Deterministic xorshift-style PRNG so any failure is reproducible.
    let mut state: u64 = 0xC0FF_EE11_D15C_AFE5;
    for _ in 0..4_096 {
        state ^= state << 7;
        state ^= state >> 9;
        state ^= state << 8;
        let len = (state as usize % 72) + 1;
        let mut s = String::with_capacity(len + 6);
        // Half the inputs carry the required prefix, half a decoy prefix.
        if (state & 1) == 0 {
            s.push_str("scope_");
        } else {
            s.push_str("user_");
        }
        // Suffix mixes valid (letters, digits, '-') and invalid ('_', '.')
        // characters derived from the PRNG state bytes.
        for idx in 0..len {
            let v = ((state >> ((idx % 8) * 8)) & 0xff) as u8;
            let ch = match v % 6 {
                0 => (b'a' + (v % 26)) as char,
                1 => (b'A' + (v % 26)) as char,
                2 => (b'0' + (v % 10)) as char,
                3 => '-',
                4 => '_',
                _ => '.',
            };
            s.push(ch);
        }
        let got = validated_scope_hint(&s).is_some();
        assert_eq!(got, oracle(&s), "mismatch for input: {s}");
    }
}
#[test]
fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() {
    // Property test: for arbitrary (including negative and out-of-range) DC
    // indices and randomized network preferences, resolution must succeed,
    // keep the standard Telegram port, and honor the IPv4/IPv6 preference.
    let mut state: u64 = 0x1234_5678_9ABC_DEF0;
    for _ in 0..2_048 {
        // Classic xorshift64 step (13, 7, 17) — deterministic sequence.
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        let mut cfg = ProxyConfig::default();
        cfg.network.prefer = if (state & 1) == 0 { 4 } else { 6 };
        cfg.network.ipv6 = Some((state & 2) != 0);
        // NOTE(review): default_dc is randomized over 1..=255, which can
        // exceed the known DC table; assumes get_dc_addr_static tolerates
        // out-of-range defaults — confirm against its fallback behavior.
        cfg.default_dc = Some(((state >> 8) as u8).max(1));
        let dc_idx = (state as i16).wrapping_sub(16_384);
        let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail");
        assert_eq!(resolved.port(), crate::protocol::constants::TG_DATACENTER_PORT);
        // ipv6 is always Some here, so unwrap_or(true) only documents intent.
        let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true);
        assert_eq!(resolved.is_ipv6(), expect_v6);
    }
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn subtle_integration_parallel_same_dc_logs_one_line() {
    // 32 concurrent resolutions of the same unknown DC must produce exactly
    // one log line: the in-memory dedup set admits a single winner.
    let _guard = unknown_dc_test_lock()
        .lock()
        .expect("unknown dc test lock must be available");
    clear_unknown_dc_log_cache_for_testing();
    let rel_dir = format!("target/telemt-direct-relay-same-{}", std::process::id());
    let rel_file = format!("{rel_dir}/unknown-dc.log");
    let abs_dir = std::env::current_dir()
        .expect("cwd must be available")
        .join(&rel_dir);
    std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable");
    let abs_file = abs_dir.join("unknown-dc.log");
    // Start from a clean file so the final line count is attributable here.
    let _ = std::fs::remove_file(&abs_file);
    let mut cfg = ProxyConfig::default();
    cfg.general.unknown_dc_file_log_enabled = true;
    cfg.general.unknown_dc_log_path = Some(rel_file);
    let cfg = Arc::new(cfg);
    let mut tasks = Vec::new();
    for _ in 0..32 {
        let cfg = Arc::clone(&cfg);
        tasks.push(tokio::spawn(async move {
            let _ = get_dc_addr_static(31_777, cfg.as_ref());
        }));
    }
    for task in tasks {
        task.await.expect("task must not panic");
    }
    // The file write happens on a background blocking task, so poll for up
    // to ~1.2s before asserting the final state.
    for _ in 0..60 {
        if let Ok(content) = std::fs::read_to_string(&abs_file)
            && nonempty_line_count(&content) == 1
        {
            return;
        }
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
    }
    let content = std::fs::read_to_string(&abs_file).unwrap_or_default();
    assert_eq!(nonempty_line_count(&content), 1);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn subtle_integration_parallel_unique_dcs_log_unique_lines() {
    // Eight distinct unknown DC indices resolved concurrently must each get
    // their own log line: dedup is per-index, not global.
    let _guard = unknown_dc_test_lock()
        .lock()
        .expect("unknown dc test lock must be available");
    clear_unknown_dc_log_cache_for_testing();
    let rel_dir = format!("target/telemt-direct-relay-unique-{}", std::process::id());
    let rel_file = format!("{rel_dir}/unknown-dc.log");
    let abs_dir = std::env::current_dir()
        .expect("cwd must be available")
        .join(&rel_dir);
    std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable");
    let abs_file = abs_dir.join("unknown-dc.log");
    // Clean slate so counts below reflect only this test's writes.
    let _ = std::fs::remove_file(&abs_file);
    let mut cfg = ProxyConfig::default();
    cfg.general.unknown_dc_file_log_enabled = true;
    cfg.general.unknown_dc_log_path = Some(rel_file);
    let cfg = Arc::new(cfg);
    let dcs = [31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908];
    let mut tasks = Vec::new();
    for dc in dcs {
        let cfg = Arc::clone(&cfg);
        tasks.push(tokio::spawn(async move {
            let _ = get_dc_addr_static(dc, cfg.as_ref());
        }));
    }
    for task in tasks {
        task.await.expect("task must not panic");
    }
    // Background blocking writes: poll up to ~1.6s for all eight lines.
    for _ in 0..80 {
        if let Ok(content) = std::fs::read_to_string(&abs_file)
            && nonempty_line_count(&content) >= 8
        {
            return;
        }
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
    }
    let content = std::fs::read_to_string(&abs_file).unwrap_or_default();
    assert!(
        nonempty_line_count(&content) >= 8,
        "expected at least one line per unique dc, content: {content}"
    );
}

View File

@ -0,0 +1,409 @@
use super::*;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::sync::Barrier;
use tokio::time::Instant;
#[test]
fn quota_lock_same_user_returns_same_arc_instance() {
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    // Start from an empty cache so the first lookup materializes the entry.
    QUOTA_USER_LOCKS.get_or_init(DashMap::new).clear();
    let first = quota_user_lock("quota-lock-same-user");
    let second = quota_user_lock("quota-lock-same-user");
    // Two lookups for the same user must share one Arc allocation.
    assert!(Arc::ptr_eq(&first, &second));
}
#[test]
fn quota_lock_parallel_same_user_reuses_single_lock() {
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    let user = "quota-lock-parallel-same";
    // 64 threads race the very first insertion for this user; all of them
    // must come back holding the same shared Arc.
    let spawned: Vec<_> = (0..64)
        .map(|_| std::thread::spawn(move || quota_user_lock(user)))
        .collect();
    let mut joined = Vec::with_capacity(spawned.len());
    for handle in spawned {
        joined.push(handle.join().expect("thread must return lock handle"));
    }
    let reference = &joined[0];
    for lock in &joined[1..] {
        assert!(Arc::ptr_eq(reference, lock));
    }
}
#[test]
fn quota_lock_unique_users_materialize_distinct_entries() {
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Insert half the capacity worth of distinct users; none may be evicted.
    let base = format!("quota-lock-distinct-{}", std::process::id());
    let users: Vec<String> = (0..(QUOTA_USER_LOCKS_MAX / 2))
        .map(|idx| format!("{base}-{idx}"))
        .collect();
    users.iter().for_each(|user| {
        let _ = quota_user_lock(user);
    });
    for user in &users {
        assert!(map.get(user).is_some(), "lock cache must contain entry for {user}");
    }
}
#[test]
fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() {
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    let base = format!("quota-lock-churn-{}", std::process::id());
    // Push well past capacity with unique, immediately-dropped users.
    for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) {
        let user = format!("{base}-{idx}");
        drop(quota_user_lock(&user));
    }
    assert!(
        map.len() <= QUOTA_USER_LOCKS_MAX,
        "quota lock cache must stay bounded under unique-user churn"
    );
}
#[test]
fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Saturate the cache while KEEPING the Arcs alive so the retain() pass
    // cannot reclaim any entry and the overflow path is genuinely exercised.
    let prefix = format!("quota-held-{}", std::process::id());
    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
    }
    assert_eq!(
        map.len(),
        QUOTA_USER_LOCKS_MAX,
        "cache must be saturated for overflow check"
    );
    let overflow_user = format!("quota-overflow-{}", std::process::id());
    let overflow_a = quota_user_lock(&overflow_user);
    let overflow_b = quota_user_lock(&overflow_user);
    // The overflow user must be served from the striped fallback locks:
    // no cache growth, no cache entry, and a stable lock across calls.
    assert_eq!(
        map.len(),
        QUOTA_USER_LOCKS_MAX,
        "overflow path must not grow lock cache"
    );
    assert!(
        map.get(&overflow_user).is_none(),
        "overflow user lock must stay outside bounded cache under saturation"
    );
    assert!(
        Arc::ptr_eq(&overflow_a, &overflow_b),
        "overflow user must receive stable striped overflow lock while saturated"
    );
    drop(retained);
}
#[test]
fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Fill and immediately drop strong references, leaving only map-owned Arcs.
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        let _ = quota_user_lock(&format!("quota-reclaim-drop-{}-{idx}", std::process::id()));
    }
    assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
    // Cache is nominally full but every entry is stale (strong_count == 1),
    // so the next lookup should evict stale entries and cache the new user
    // instead of falling back to the striped overflow locks.
    let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
    let overflow = quota_user_lock(&overflow_user);
    assert!(
        map.get(&overflow_user).is_some(),
        "after reclaiming stale entries, overflow user should become cacheable"
    );
    assert!(
        Arc::strong_count(&overflow) >= 2,
        "cacheable overflow lock should be held by both map and caller"
    );
}
#[test]
fn quota_lock_saturated_same_user_must_not_return_distinct_locks() {
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Saturate the cache with held (strongly referenced) locks so any
    // further user is forced onto the overflow path.
    let retained: Vec<_> = (0..QUOTA_USER_LOCKS_MAX)
        .map(|idx| {
            quota_user_lock(&format!("quota-saturated-held-{}-{idx}", std::process::id()))
        })
        .collect();
    let overflow_user = format!("quota-saturated-same-user-{}", std::process::id());
    let first = quota_user_lock(&overflow_user);
    let second = quota_user_lock(&overflow_user);
    assert!(
        Arc::ptr_eq(&first, &second),
        "same user must not receive distinct locks under saturation because that enables quota race bypass"
    );
    drop(retained);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() {
    // End-to-end race: with the lock cache saturated, two writers for the
    // same user are released simultaneously (barrier) against a 1-octet
    // quota. The striped overflow lock must still serialize them so the
    // accounted total never exceeds the quota.
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Hold Arcs so retain() cannot free slots and the overflow path is used.
    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        retained.push(quota_user_lock(&format!("quota-saturated-race-held-{}-{idx}", std::process::id())));
    }
    let stats = Arc::new(Stats::new());
    let user = format!("quota-saturated-race-user-{}", std::process::id());
    let gate = Arc::new(Barrier::new(2));
    // Each worker writes one byte through StatsIo with a quota of 1 octet.
    let worker = |label: u8, stats: Arc<Stats>, user: String, gate: Arc<Barrier>| {
        tokio::spawn(async move {
            let counters = Arc::new(SharedCounters::new());
            let quota_exceeded = Arc::new(AtomicBool::new(false));
            let mut io = StatsIo::new(
                tokio::io::sink(),
                counters,
                Arc::clone(&stats),
                user,
                Some(1),
                quota_exceeded,
                Instant::now(),
            );
            gate.wait().await;
            io.write_all(&[label]).await
        })
    };
    let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
    let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
    // Bound the whole race so a deadlocked lock path fails fast, not forever.
    let _ = tokio::time::timeout(Duration::from_secs(2), async {
        let _ = one.await.expect("task one must not panic");
        let _ = two.await.expect("task two must not panic");
    })
    .await
    .expect("quota race workers must complete");
    assert!(
        stats.get_user_total_octets(&user) <= 1,
        "saturated lock path must never overshoot quota for same user"
    );
    drop(retained);
}
// Stress variant of the saturation race: repeat the two-writer quota race
// 128 times with a fresh Stats instance and user name per round to widen
// the interleaving space, while the lock map stays saturated throughout.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() {
    // Serialize against the other tests that mutate the global lock map.
    let _guard = super::quota_user_lock_test_guard()
        .lock()
        .expect("quota lock test guard must be available");
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Saturate the map once; the Arcs stay alive for all 128 rounds.
    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        retained.push(quota_user_lock(&format!("quota-saturated-stress-held-{}-{idx}", std::process::id())));
    }
    for round in 0..128u32 {
        let stats = Arc::new(Stats::new());
        let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id());
        // Barrier aligns the two writers so every round is a genuine race.
        let gate = Arc::new(Barrier::new(2));
        let one = {
            let stats = Arc::clone(&stats);
            let user = user.clone();
            let gate = Arc::clone(&gate);
            tokio::spawn(async move {
                let counters = Arc::new(SharedCounters::new());
                let quota_exceeded = Arc::new(AtomicBool::new(false));
                // Quota of Some(1): at most one octet may be accounted per round.
                let mut io = StatsIo::new(
                    tokio::io::sink(),
                    counters,
                    Arc::clone(&stats),
                    user,
                    Some(1),
                    quota_exceeded,
                    Instant::now(),
                );
                gate.wait().await;
                io.write_all(&[0x31]).await
            })
        };
        let two = {
            let stats = Arc::clone(&stats);
            let user = user.clone();
            let gate = Arc::clone(&gate);
            tokio::spawn(async move {
                let counters = Arc::new(SharedCounters::new());
                let quota_exceeded = Arc::new(AtomicBool::new(false));
                let mut io = StatsIo::new(
                    tokio::io::sink(),
                    counters,
                    Arc::clone(&stats),
                    user,
                    Some(1),
                    quota_exceeded,
                    Instant::now(),
                );
                gate.wait().await;
                io.write_all(&[0x32]).await
            })
        };
        // Write results per task are irrelevant; only accounted octets matter.
        let _ = one.await.expect("stress task one must not panic");
        let _ = two.await.expect("stress task two must not panic");
        assert!(
            stats.get_user_total_octets(&user) <= 1,
            "round {round}: saturated path must not overshoot quota"
        );
    }
    // Release the saturating locks only after all rounds complete.
    drop(retained);
}
// The classifier must recognize the crate's own quota sentinel error.
#[test]
fn quota_error_classifier_accepts_internal_quota_sentinel_only() {
    let sentinel = quota_io_error();
    assert!(
        is_quota_io_error(&sentinel),
        "internally produced quota sentinel must classify as a quota error"
    );
}
// A generic PermissionDenied coming from the OS or another layer must not
// be mistaken for quota exhaustion by the classifier.
#[test]
fn quota_error_classifier_rejects_plain_permission_denied() {
    let os_style_err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied");
    assert!(
        !is_quota_io_error(&os_style_err),
        "a plain PermissionDenied must not classify as quota exhaustion"
    );
}
// Integration: a user with quota Some(0) must be cut off before any payload
// byte is forwarded, and the relay must terminate with DataQuotaExceeded.
#[tokio::test]
async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() {
    let stats = Arc::new(Stats::new());
    let user = "quota-zero-user";
    // Topology: client_peer <-> relay <-> server_peer over in-memory pipes.
    let (mut client_peer, relay_client) = duplex(2048);
    let (relay_server, mut server_peer) = duplex(2048);
    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);
    let relay = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        512,
        512,
        user,
        Arc::clone(&stats),
        Some(0),
        Arc::new(BufferPool::new()),
    ));
    client_peer
        .write_all(b"x")
        .await
        .expect("client write must succeed");
    // A timeout or read error here both mean "nothing forwarded", which is
    // the desired outcome; only a successful read of payload bytes fails.
    let mut probe = [0u8; 1];
    let forwarded = tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await;
    if let Ok(Ok(n)) = forwarded {
        assert_eq!(n, 0, "zero quota path must not forward payload bytes");
    }
    // The relay must shut itself down promptly and surface the quota error.
    let result = tokio::time::timeout(Duration::from_secs(2), relay)
        .await
        .expect("relay must terminate under zero quota")
        .expect("relay task must not panic");
    assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
}
// Integration: with no quota (None) the relay must pass bursts larger than
// its per-direction sizes (the two 1024 arguments — presumably buffer
// sizes; confirm against relay_bidirectional's signature) through intact
// in both directions, then shut down cleanly once both peers close.
#[tokio::test]
async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() {
    let stats = Arc::new(Stats::new());
    // Topology: client_peer <-> relay <-> server_peer over in-memory pipes.
    let (mut client_peer, relay_client) = duplex(8192);
    let (relay_server, mut server_peer) = duplex(8192);
    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);
    let relay = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        1024,
        1024,
        "quota-none-burst-user",
        Arc::clone(&stats),
        None,
        Arc::new(BufferPool::new()),
    ));
    // Bursts exceed 1024 bytes to force multiple relay passes per direction.
    let c2s = vec![0xA5; 2048];
    let s2c = vec![0x5A; 1536];
    client_peer.write_all(&c2s).await.expect("client burst write must succeed");
    let mut got_c2s = vec![0u8; c2s.len()];
    server_peer.read_exact(&mut got_c2s).await.expect("server must receive c2s burst");
    assert_eq!(got_c2s, c2s);
    server_peer.write_all(&s2c).await.expect("server burst write must succeed");
    let mut got_s2c = vec![0u8; s2c.len()];
    client_peer.read_exact(&mut got_s2c).await.expect("client must receive s2c burst");
    assert_eq!(got_s2c, s2c);
    // Closing both peers signals EOF; the relay must finish without error.
    drop(client_peer);
    drop(server_peer);
    let done = tokio::time::timeout(Duration::from_secs(2), relay)
        .await
        .expect("relay must terminate after peers close")
        .expect("relay task must not panic");
    assert!(done.is_ok());
}

View File

@ -31,6 +31,12 @@ impl std::task::Wake for WakeCounter {
#[tokio::test]
async fn quota_lock_contention_does_not_self_wake_pending_writer() {
let _guard = super::quota_user_lock_test_guard()
.lock()
.expect("quota lock test guard must be available");
let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new);
map.clear();
let stats = Arc::new(Stats::new());
let user = "quota-lock-contention-user";
@ -66,6 +72,12 @@ async fn quota_lock_contention_does_not_self_wake_pending_writer() {
#[tokio::test]
async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() {
let _guard = super::quota_user_lock_test_guard()
.lock()
.expect("quota lock test guard must be available");
let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new);
map.clear();
let stats = Arc::new(Stats::new());
let user = "quota-lock-writer-liveness-user";
@ -133,6 +145,12 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_
#[tokio::test]
async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() {
let _guard = super::quota_user_lock_test_guard()
.lock()
.expect("quota lock test guard must be available");
let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new);
map.clear();
let stats = Arc::new(Stats::new());
let user = "quota-lock-read-liveness-user";

View File

@ -0,0 +1,168 @@
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::Duration;
use crate::config::UserMaxUniqueIpsMode;
use crate::ip_tracker::UserIpTracker;
/// Maps a test index to a deterministic IPv4 address inside 10.0.0.0/8.
///
/// Only the low 24 bits of `idx` are representable: indices that differ
/// solely above bit 23 would silently produce the *same* address, which
/// would quietly weaken every uniqueness-based limit test in this file.
/// That misuse is therefore trapped in debug builds (release behavior is
/// unchanged — the low 24 bits are used as before).
fn ip_from_idx(idx: u32) -> IpAddr {
    debug_assert!(
        idx <= 0x00FF_FFFF,
        "ip_from_idx encodes only the low 24 bits; larger indices collide"
    );
    IpAddr::V4(Ipv4Addr::new(10, ((idx >> 16) & 0xff) as u8, ((idx >> 8) & 0xff) as u8, (idx & 0xff) as u8))
}
// Repeatedly draining an empty cleanup queue must be a no-op: no panics,
// no phantom state for any user.
#[tokio::test]
async fn hotpath_empty_drain_is_idempotent() {
    let tracker = UserIpTracker::new();
    let mut remaining = 128u32;
    while remaining > 0 {
        tracker.drain_cleanup_queue().await;
        remaining -= 1;
    }
    assert_eq!(
        tracker.get_active_ip_count("none").await,
        0,
        "empty drains must not conjure active IPs"
    );
}
#[tokio::test]
async fn hotpath_batch_cleanup_drain_clears_all_active_entries() {
let tracker = UserIpTracker::new();
tracker.set_user_limit("u", 100).await;
for idx in 0..32 {
let ip = ip_from_idx(idx);
tracker.check_and_add("u", ip).await.unwrap();
tracker.enqueue_cleanup("u".to_string(), ip);
}
tracker.drain_cleanup_queue().await;
assert_eq!(tracker.get_active_ip_count("u").await, 0);
}
// Interleaved add/enqueue/drain cycles from 32 concurrent tasks must never
// wedge the tracker; every worker must finish within the timeout.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn hotpath_parallel_enqueue_and_drain_does_not_deadlock() {
    let tracker = Arc::new(UserIpTracker::new());
    tracker.set_user_limit("p", 64).await;
    let handles: Vec<_> = (0..32u32)
        .map(|worker| {
            let shared = Arc::clone(&tracker);
            tokio::spawn(async move {
                // Each worker hammers its own distinct IP 64 times.
                let ip = ip_from_idx(1_000 + worker);
                for _ in 0..64 {
                    let _ = shared.check_and_add("p", ip).await;
                    shared.enqueue_cleanup("p".to_string(), ip);
                    shared.drain_cleanup_queue().await;
                }
            })
        })
        .collect();
    for handle in handles {
        tokio::time::timeout(Duration::from_secs(3), handle)
            .await
            .expect("worker must not deadlock")
            .expect("worker task must not panic");
    }
}
// With a cap of 5, concurrent admission attempts for 64 distinct IPs must
// never admit more than 5, and the tracker's own count must agree.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn hotpath_parallel_unique_ip_limit_never_exceeds_cap() {
    let tracker = Arc::new(UserIpTracker::new());
    tracker.set_user_limit("limit", 5).await;
    let attempts: Vec<_> = (0..64u32)
        .map(|idx| {
            let shared = Arc::clone(&tracker);
            tokio::spawn(async move { shared.check_and_add("limit", ip_from_idx(idx)).await.is_ok() })
        })
        .collect();
    let mut admitted = 0usize;
    for attempt in attempts {
        if attempt.await.expect("task must not panic") {
            admitted += 1;
        }
    }
    assert!(admitted <= 5, "admitted unique IPs must not exceed configured cap");
    assert!(tracker.get_active_ip_count("limit").await <= 5);
}
// 512 admissions of the same IP followed by 512 removals must balance the
// per-IP bookkeeping back to zero (no refcount drift in either direction).
#[tokio::test]
async fn hotpath_repeated_same_ip_counter_balances_to_zero() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("same", 1).await;
    let ip = ip_from_idx(77);
    let rounds = 512;
    for _ in 0..rounds {
        tracker.check_and_add("same", ip).await.unwrap();
    }
    for _ in 0..rounds {
        tracker.remove_ip("same", ip).await;
    }
    assert_eq!(
        tracker.get_active_ip_count("same").await,
        0,
        "balanced add/remove must leave no residual active IPs"
    );
}
// Light deterministic fuzz: a fixed-seed xorshift-style generator drives
// 4000 mixed add/remove operations across 8 candidate IPs, checking after
// every single step that the active count never exceeds the cap of 4.
#[tokio::test]
async fn hotpath_light_fuzz_mixed_operations_preserve_limit_invariants() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("fuzz", 4).await;
    // Fixed seed keeps the operation sequence reproducible across runs.
    let mut state: u64 = 0xA55A_5AA5_D15C_B00B;
    for _ in 0..4_000 {
        // Scramble step; the exact shift amounts define the reproducible
        // sequence and must not be changed.
        state ^= state << 7;
        state ^= state >> 9;
        state ^= state << 8;
        let ip = ip_from_idx((state as u32) % 8);
        // Low two bits pick the operation: ~50/50 split add vs. remove.
        match state & 0x3 {
            0 | 1 => {
                let _ = tracker.check_and_add("fuzz", ip).await;
            }
            _ => {
                tracker.remove_ip("fuzz", ip).await;
            }
        }
        // Invariant checked after every operation, not just at the end.
        assert!(
            tracker.get_active_ip_count("fuzz").await <= 4,
            "active count must stay within configured cap"
        );
    }
}
// Two users with independent caps churned in lockstep must not bleed state
// into each other: each user's final active count respects its own cap.
#[tokio::test]
async fn hotpath_multi_user_churn_keeps_isolation() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("u1", 2).await;
    tracker.set_user_limit("u2", 3).await;
    for round in 0..200u32 {
        // Disjoint IP ranges per user; different cycle lengths (5 vs. 7)
        // and removal cadences (every 2nd vs. every 3rd round) churn them.
        let first = ip_from_idx(round % 5);
        let second = ip_from_idx(100 + (round % 7));
        let _ = tracker.check_and_add("u1", first).await;
        let _ = tracker.check_and_add("u2", second).await;
        if round % 2 == 0 {
            tracker.remove_ip("u1", first).await;
        }
        if round % 3 == 0 {
            tracker.remove_ip("u2", second).await;
        }
    }
    assert!(tracker.get_active_ip_count("u1").await <= 2);
    assert!(tracker.get_active_ip_count("u2").await <= 3);
}
// TimeWindow policy: an IP released inside the window still occupies the
// quota slot; only after the window elapses may a new IP be admitted.
// NOTE(review): uses a real 1.1 s sleep — presumably the tracker measures
// wall-clock time, so tokio's paused clock would not help; confirm.
#[tokio::test]
async fn hotpath_time_window_expiry_allows_new_ip_after_window() {
    let tracker = UserIpTracker::new();
    tracker.set_user_limit("tw", 1).await;
    // 1-second window combined with a 1-IP cap.
    tracker
        .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
        .await;
    let ip1 = ip_from_idx(901);
    let ip2 = ip_from_idx(902);
    tracker.check_and_add("tw", ip1).await.unwrap();
    tracker.remove_ip("tw", ip1).await;
    // ip1 still counts against the window even though it was removed.
    assert!(tracker.check_and_add("tw", ip2).await.is_err());
    // Sleep past the window; the 100 ms margin guards against timer jitter.
    tokio::time::sleep(Duration::from_millis(1_100)).await;
    assert!(tracker.check_and_add("tw", ip2).await.is_ok());
}