mirror of https://github.com/telemt/telemt.git
Refactor and enhance tests for proxy and relay functionality
- Renamed test functions in `client_tls_clienthello_truncation_adversarial_tests.rs` to remove "but_leaks" suffix for clarity. - Added new tests in `direct_relay_business_logic_tests.rs` to validate business logic for data center resolution and scope hints. - Introduced tests in `direct_relay_common_mistakes_tests.rs` to cover common mistakes in direct relay configurations. - Added security tests in `direct_relay_security_tests.rs` to ensure proper handling of symlink and parent swap scenarios. - Created `direct_relay_subtle_adversarial_tests.rs` to stress test concurrent logging and validate scope hint behavior. - Implemented `relay_quota_lock_pressure_adversarial_tests.rs` to test quota lock behavior under high contention and stress. - Updated `relay_security_tests.rs` to include quota lock contention tests ensuring proper behavior under concurrent access. - Introduced `ip_tracker_hotpath_adversarial_tests.rs` to validate the performance and correctness of the IP tracking logic under various scenarios.
This commit is contained in:
parent
8188fedf6a
commit
5933b5e821
|
|
@ -9,6 +9,9 @@ mod ip_tracker;
|
|||
#[cfg(test)]
|
||||
#[path = "tests/ip_tracker_regression_tests.rs"]
|
||||
mod ip_tracker_regression_tests;
|
||||
#[cfg(test)]
|
||||
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
|
||||
mod ip_tracker_hotpath_adversarial_tests;
|
||||
mod maestro;
|
||||
mod metrics;
|
||||
mod network;
|
||||
|
|
|
|||
|
|
@ -307,9 +307,8 @@ fn extract_sni_with_duplicate_extensions_rejected() {
|
|||
h.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
|
||||
h.extend_from_slice(&handshake);
|
||||
|
||||
// Parser might return first, see second, or fail. OWASP ASVS prefers rejection of unexpected dups.
|
||||
// Telemt's `extract_sni` returns the first one found.
|
||||
assert!(extract_sni_from_client_hello(&h).is_some());
|
||||
// Duplicate SNI extensions are ambiguous and must fail closed.
|
||||
assert!(extract_sni_from_client_hello(&h).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -588,6 +588,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
|||
return None;
|
||||
}
|
||||
|
||||
let mut saw_sni_extension = false;
|
||||
let mut extracted_sni = None;
|
||||
|
||||
while pos + 4 <= ext_end {
|
||||
let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]);
|
||||
let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize;
|
||||
|
|
@ -595,6 +598,12 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
|||
if pos + elen > ext_end {
|
||||
break;
|
||||
}
|
||||
if etype == 0x0000 {
|
||||
if saw_sni_extension {
|
||||
return None;
|
||||
}
|
||||
saw_sni_extension = true;
|
||||
}
|
||||
if etype == 0x0000 && elen >= 5 {
|
||||
// server_name extension
|
||||
let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
|
||||
|
|
@ -611,7 +620,8 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
|||
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
|
||||
{
|
||||
if is_valid_sni_hostname(host) {
|
||||
return Some(host.to_string());
|
||||
extracted_sni = Some(host.to_string());
|
||||
break;
|
||||
}
|
||||
}
|
||||
sn_pos += name_len;
|
||||
|
|
@ -620,7 +630,7 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
|||
pos += elen;
|
||||
}
|
||||
|
||||
None
|
||||
extracted_sni
|
||||
}
|
||||
|
||||
fn is_valid_sni_hostname(host: &str) -> bool {
|
||||
|
|
|
|||
|
|
@ -27,6 +27,10 @@ use crate::transport::UpstreamManager;
|
|||
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::io::{AsRawFd, FromRawFd};
|
||||
|
||||
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
|
||||
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
|
||||
|
|
@ -136,6 +140,7 @@ fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
|
|||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
|
|
@ -155,6 +160,64 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
|
|||
}
|
||||
}
|
||||
|
||||
fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std::io::Result<std::fs::File> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let parent = OpenOptions::new()
|
||||
.read(true)
|
||||
.custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC)
|
||||
.open(&path.allowed_parent)?;
|
||||
|
||||
let file_name = std::ffi::CString::new(path.file_name.as_os_str().as_bytes())
|
||||
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "unknown DC log file name contains NUL byte"))?;
|
||||
|
||||
let fd = unsafe {
|
||||
libc::openat(
|
||||
parent.as_raw_fd(),
|
||||
file_name.as_ptr(),
|
||||
libc::O_CREAT | libc::O_APPEND | libc::O_WRONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC,
|
||||
0o600,
|
||||
)
|
||||
};
|
||||
|
||||
if fd < 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
let file = unsafe { std::fs::File::from_raw_fd(fd) };
|
||||
Ok(file)
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
let _ = path;
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::PermissionDenied,
|
||||
"unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX) } != 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
let write_result = writeln!(file, "dc_idx={dc_idx}");
|
||||
|
||||
if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) } != 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
|
||||
write_result
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
writeln!(file, "dc_idx={dc_idx}")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn clear_unknown_dc_log_cache_for_testing() {
|
||||
if let Some(set) = LOGGED_UNKNOWN_DCS.get()
|
||||
|
|
@ -321,9 +384,9 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|||
if should_log_unknown_dc(dc_idx) {
|
||||
handle.spawn_blocking(move || {
|
||||
if unknown_dc_log_path_is_still_safe(&path)
|
||||
&& let Ok(mut file) = open_unknown_dc_log_append(&path.resolved_path)
|
||||
&& let Ok(mut file) = open_unknown_dc_log_append_anchored(&path)
|
||||
{
|
||||
let _ = writeln!(file, "dc_idx={dc_idx}");
|
||||
let _ = append_unknown_dc_line(&mut file, dc_idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -394,3 +457,15 @@ where
|
|||
#[cfg(test)]
|
||||
#[path = "tests/direct_relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/direct_relay_business_logic_tests.rs"]
|
||||
mod business_logic_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/direct_relay_common_mistakes_tests.rs"]
|
||||
mod common_mistakes_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/direct_relay_subtle_adversarial_tests.rs"]
|
||||
mod subtle_adversarial_tests;
|
||||
|
|
|
|||
|
|
@ -263,6 +263,33 @@ fn is_quota_io_error(err: &io::Error) -> bool {
|
|||
}
|
||||
|
||||
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
#[cfg(test)]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
.map(|_| Arc::new(Mutex::new(())))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let hash = crc32fast::hash(user.as_bytes()) as usize;
|
||||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
|
|
@ -270,6 +297,14 @@ fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
|||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
return quota_overflow_user_lock(user);
|
||||
}
|
||||
|
||||
let created = Arc::new(Mutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
|
|
@ -662,4 +697,8 @@ mod security_tests;
|
|||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_adversarial_tests.rs"]
|
||||
mod adversarial_tests;
|
||||
mod adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"]
|
||||
mod relay_quota_lock_pressure_adversarial_tests;
|
||||
|
|
@ -249,7 +249,7 @@ async fn run_blackhat_client_handler_fragmented_probe_should_mask(
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_leaks() {
|
||||
async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() {
|
||||
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let mask_addr = mask_listener.local_addr().unwrap();
|
||||
let backend_reply = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec();
|
||||
|
|
@ -309,7 +309,7 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_
|
|||
client_side.shutdown().await.unwrap();
|
||||
|
||||
// Security expectation: even malformed in-range TLS should be masked.
|
||||
// Current code leaks by returning EOF/timeout instead of masking.
|
||||
// This invariant must hold to avoid probe-distinguishable EOF/timeout behavior.
|
||||
let mut observed = vec![0u8; backend_reply.len()];
|
||||
tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed))
|
||||
.await
|
||||
|
|
@ -329,7 +329,7 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask_but_
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask_but_leaks() {
|
||||
async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask() {
|
||||
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let mask_addr = mask_listener.local_addr().unwrap();
|
||||
|
||||
|
|
@ -429,7 +429,7 @@ async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask_but_
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_generic_truncated_min_body_1_should_mask_but_leaks() {
|
||||
async fn blackhat_generic_truncated_min_body_1_should_mask() {
|
||||
run_blackhat_generic_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(1),
|
||||
&[6],
|
||||
|
|
@ -440,7 +440,7 @@ async fn blackhat_generic_truncated_min_body_1_should_mask_but_leaks() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_generic_truncated_min_body_8_should_mask_but_leaks() {
|
||||
async fn blackhat_generic_truncated_min_body_8_should_mask() {
|
||||
run_blackhat_generic_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(8),
|
||||
&[13],
|
||||
|
|
@ -451,7 +451,7 @@ async fn blackhat_generic_truncated_min_body_8_should_mask_but_leaks() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_generic_truncated_min_body_99_should_mask_but_leaks() {
|
||||
async fn blackhat_generic_truncated_min_body_99_should_mask() {
|
||||
run_blackhat_generic_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1),
|
||||
&[5, MIN_TLS_CLIENT_HELLO_SIZE - 1],
|
||||
|
|
@ -462,7 +462,7 @@ async fn blackhat_generic_truncated_min_body_99_should_mask_but_leaks() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_generic_fragmented_header_then_close_should_mask_but_leaks() {
|
||||
async fn blackhat_generic_fragmented_header_then_close_should_mask() {
|
||||
run_blackhat_generic_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(0),
|
||||
&[1, 1, 1, 1, 1],
|
||||
|
|
@ -473,7 +473,7 @@ async fn blackhat_generic_fragmented_header_then_close_should_mask_but_leaks() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask_but_leaks() {
|
||||
async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask() {
|
||||
run_blackhat_generic_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(5),
|
||||
&[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
|
|
@ -495,7 +495,7 @@ async fn blackhat_generic_slowloris_fragmented_min_probe_should_mask_but_times_o
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_client_handler_truncated_min_body_1_should_mask_but_leaks() {
|
||||
async fn blackhat_client_handler_truncated_min_body_1_should_mask() {
|
||||
run_blackhat_client_handler_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(1),
|
||||
&[6],
|
||||
|
|
@ -506,7 +506,7 @@ async fn blackhat_client_handler_truncated_min_body_1_should_mask_but_leaks() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_client_handler_truncated_min_body_8_should_mask_but_leaks() {
|
||||
async fn blackhat_client_handler_truncated_min_body_8_should_mask() {
|
||||
run_blackhat_client_handler_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(8),
|
||||
&[13],
|
||||
|
|
@ -517,7 +517,7 @@ async fn blackhat_client_handler_truncated_min_body_8_should_mask_but_leaks() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_client_handler_truncated_min_body_99_should_mask_but_leaks() {
|
||||
async fn blackhat_client_handler_truncated_min_body_99_should_mask() {
|
||||
run_blackhat_client_handler_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1),
|
||||
&[5, MIN_TLS_CLIENT_HELLO_SIZE - 1],
|
||||
|
|
@ -528,7 +528,7 @@ async fn blackhat_client_handler_truncated_min_body_99_should_mask_but_leaks() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_client_handler_fragmented_header_then_close_should_mask_but_leaks() {
|
||||
async fn blackhat_client_handler_fragmented_header_then_close_should_mask() {
|
||||
run_blackhat_client_handler_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(0),
|
||||
&[1, 1, 1, 1, 1],
|
||||
|
|
@ -539,7 +539,7 @@ async fn blackhat_client_handler_fragmented_header_then_close_should_mask_but_le
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask_but_leaks() {
|
||||
async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask() {
|
||||
run_blackhat_client_handler_fragmented_probe_should_mask(
|
||||
truncated_in_range_record(5),
|
||||
&[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
use super::*;
|
||||
use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4, TG_DATACENTERS_V6};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
#[test]
|
||||
fn business_scope_hint_accepts_exact_boundary_length() {
|
||||
let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN));
|
||||
assert_eq!(validated_scope_hint(&value), Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn business_scope_hint_rejects_missing_prefix_even_when_charset_is_valid() {
|
||||
assert_eq!(validated_scope_hint("alpha-01"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn business_known_dc_uses_ipv4_table_by_default() {
|
||||
let cfg = ProxyConfig::default();
|
||||
let resolved = get_dc_addr_static(2, &cfg).expect("known dc must resolve");
|
||||
let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT);
|
||||
assert_eq!(resolved, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn business_negative_dc_maps_by_absolute_value() {
|
||||
let cfg = ProxyConfig::default();
|
||||
let resolved = get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value");
|
||||
let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT);
|
||||
assert_eq!(resolved, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn business_known_dc_uses_ipv6_table_when_preferred_and_enabled() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.network.prefer = 6;
|
||||
cfg.network.ipv6 = Some(true);
|
||||
|
||||
let resolved = get_dc_addr_static(1, &cfg).expect("known dc must resolve on ipv6 path");
|
||||
let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
|
||||
assert_eq!(resolved, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn business_unknown_dc_uses_configured_default_dc_when_in_range() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.default_dc = Some(4);
|
||||
|
||||
let resolved = get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default");
|
||||
let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT);
|
||||
assert_eq!(resolved, expected);
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
use super::*;
|
||||
use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4};
|
||||
use std::collections::HashSet;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Mutex;
|
||||
|
||||
#[test]
|
||||
fn common_invalid_override_entries_fallback_to_static_table() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.dc_overrides.insert(
|
||||
"2".to_string(),
|
||||
vec!["bad-address".to_string(), "still-bad".to_string()],
|
||||
);
|
||||
|
||||
let resolved = get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve");
|
||||
let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT);
|
||||
assert_eq!(resolved, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.network.prefer = 6;
|
||||
cfg.network.ipv6 = Some(true);
|
||||
cfg.dc_overrides
|
||||
.insert("3".to_string(), vec!["203.0.113.203:443".to_string()]);
|
||||
|
||||
let resolved = get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists");
|
||||
assert_eq!(resolved, "203.0.113.203:443".parse::<SocketAddr>().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn common_scope_hint_rejects_unicode_lookalike_characters() {
|
||||
assert_eq!(validated_scope_hint("scope_аlpha"), None);
|
||||
assert_eq!(validated_scope_hint("scope_Αlpha"), None);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn common_anchored_open_rejects_nul_filename() {
|
||||
use std::os::unix::ffi::OsStringExt;
|
||||
|
||||
let parent = std::env::current_dir()
|
||||
.expect("cwd must be available")
|
||||
.join("target")
|
||||
.join(format!("telemt-direct-relay-nul-{}", std::process::id()));
|
||||
std::fs::create_dir_all(&parent).expect("parent directory must be creatable");
|
||||
|
||||
let path = SanitizedUnknownDcLogPath {
|
||||
resolved_path: parent.join("placeholder.log"),
|
||||
allowed_parent: parent,
|
||||
file_name: std::ffi::OsString::from_vec(vec![b'a', 0, b'b']),
|
||||
};
|
||||
|
||||
let err = open_unknown_dc_log_append_anchored(&path)
|
||||
.expect_err("anchored open must fail on NUL in filename");
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn common_anchored_open_creates_owner_only_file_permissions() {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let parent = std::env::current_dir()
|
||||
.expect("cwd must be available")
|
||||
.join("target")
|
||||
.join(format!("telemt-direct-relay-perm-{}", std::process::id()));
|
||||
std::fs::create_dir_all(&parent).expect("parent directory must be creatable");
|
||||
|
||||
let sanitized = SanitizedUnknownDcLogPath {
|
||||
resolved_path: parent.join("unknown-dc.log"),
|
||||
allowed_parent: parent.clone(),
|
||||
file_name: std::ffi::OsString::from("unknown-dc.log"),
|
||||
};
|
||||
|
||||
let mut file = open_unknown_dc_log_append_anchored(&sanitized)
|
||||
.expect("anchored open must create regular file");
|
||||
use std::io::Write;
|
||||
writeln!(file, "dc_idx=1").expect("write must succeed");
|
||||
|
||||
let mode = std::fs::metadata(parent.join("unknown-dc.log"))
|
||||
.expect("metadata must be readable")
|
||||
.permissions()
|
||||
.mode()
|
||||
& 0o777;
|
||||
assert_eq!(mode, 0o600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn common_duplicate_dc_attempts_do_not_consume_unique_slots() {
|
||||
let set = Mutex::new(HashSet::new());
|
||||
|
||||
assert!(should_log_unknown_dc_with_set(&set, 100));
|
||||
assert!(!should_log_unknown_dc_with_set(&set, 100));
|
||||
assert!(should_log_unknown_dc_with_set(&set, 101));
|
||||
assert_eq!(set.lock().expect("set lock must be available").len(), 2);
|
||||
}
|
||||
|
|
@ -667,6 +667,56 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() {
|
|||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
|
||||
use std::os::unix::fs::symlink;
|
||||
|
||||
let base = std::env::current_dir()
|
||||
.expect("cwd must be available")
|
||||
.join("target")
|
||||
.join(format!("telemt-unknown-dc-parent-swap-openat-{}", std::process::id()));
|
||||
fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable");
|
||||
|
||||
let rel_candidate = format!(
|
||||
"target/telemt-unknown-dc-parent-swap-openat-{}/unknown-dc.log",
|
||||
std::process::id()
|
||||
);
|
||||
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
|
||||
.expect("candidate must sanitize before parent swap");
|
||||
fs::write(&sanitized.resolved_path, "seed\n").expect("seed target file must be writable");
|
||||
|
||||
assert!(
|
||||
unknown_dc_log_path_is_still_safe(&sanitized),
|
||||
"precondition: target should initially pass revalidation"
|
||||
);
|
||||
|
||||
let outside_parent = std::env::temp_dir().join(format!(
|
||||
"telemt-unknown-dc-parent-swap-openat-outside-{}",
|
||||
std::process::id()
|
||||
));
|
||||
fs::create_dir_all(&outside_parent).expect("outside parent directory must be creatable");
|
||||
let outside_target = outside_parent.join("unknown-dc.log");
|
||||
let _ = fs::remove_file(&outside_target);
|
||||
|
||||
let moved = base.with_extension("bak");
|
||||
let _ = fs::remove_dir_all(&moved);
|
||||
fs::rename(&base, &moved).expect("base parent must be movable for swap simulation");
|
||||
symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable");
|
||||
|
||||
let err = open_unknown_dc_log_append_anchored(&sanitized)
|
||||
.expect_err("anchored open must fail when parent is swapped to symlink");
|
||||
let raw = err.raw_os_error();
|
||||
assert!(
|
||||
matches!(raw, Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)),
|
||||
"anchored open must fail closed on parent swap race, got raw_os_error={raw:?}"
|
||||
);
|
||||
assert!(
|
||||
!outside_target.exists(),
|
||||
"anchored open must never create a log file in swapped outside parent"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_dc_absolute_log_path_writes_one_entry() {
|
||||
let _guard = unknown_dc_test_lock()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,197 @@
|
|||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
fn nonempty_line_count(text: &str) -> usize {
|
||||
text.lines().filter(|line| !line.trim().is_empty()).count()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subtle_stress_single_unknown_dc_under_concurrency_logs_once() {
|
||||
let _guard = unknown_dc_test_lock()
|
||||
.lock()
|
||||
.expect("unknown dc test lock must be available");
|
||||
clear_unknown_dc_log_cache_for_testing();
|
||||
|
||||
let winners = Arc::new(AtomicUsize::new(0));
|
||||
let mut workers = Vec::new();
|
||||
|
||||
for _ in 0..128 {
|
||||
let winners = Arc::clone(&winners);
|
||||
workers.push(std::thread::spawn(move || {
|
||||
if should_log_unknown_dc(31_333) {
|
||||
winners.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for worker in workers {
|
||||
worker.join().expect("worker must not panic");
|
||||
}
|
||||
|
||||
assert_eq!(winners.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subtle_light_fuzz_scope_hint_matches_oracle() {
|
||||
fn oracle(input: &str) -> bool {
|
||||
let Some(rest) = input.strip_prefix("scope_") else {
|
||||
return false;
|
||||
};
|
||||
!rest.is_empty()
|
||||
&& rest.len() <= MAX_SCOPE_HINT_LEN
|
||||
&& rest
|
||||
.bytes()
|
||||
.all(|b| b.is_ascii_alphanumeric() || b == b'-')
|
||||
}
|
||||
|
||||
let mut state: u64 = 0xC0FF_EE11_D15C_AFE5;
|
||||
for _ in 0..4_096 {
|
||||
state ^= state << 7;
|
||||
state ^= state >> 9;
|
||||
state ^= state << 8;
|
||||
|
||||
let len = (state as usize % 72) + 1;
|
||||
let mut s = String::with_capacity(len + 6);
|
||||
if (state & 1) == 0 {
|
||||
s.push_str("scope_");
|
||||
} else {
|
||||
s.push_str("user_");
|
||||
}
|
||||
|
||||
for idx in 0..len {
|
||||
let v = ((state >> ((idx % 8) * 8)) & 0xff) as u8;
|
||||
let ch = match v % 6 {
|
||||
0 => (b'a' + (v % 26)) as char,
|
||||
1 => (b'A' + (v % 26)) as char,
|
||||
2 => (b'0' + (v % 10)) as char,
|
||||
3 => '-',
|
||||
4 => '_',
|
||||
_ => '.',
|
||||
};
|
||||
s.push(ch);
|
||||
}
|
||||
|
||||
let got = validated_scope_hint(&s).is_some();
|
||||
assert_eq!(got, oracle(&s), "mismatch for input: {s}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() {
|
||||
let mut state: u64 = 0x1234_5678_9ABC_DEF0;
|
||||
|
||||
for _ in 0..2_048 {
|
||||
state ^= state << 13;
|
||||
state ^= state >> 7;
|
||||
state ^= state << 17;
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.network.prefer = if (state & 1) == 0 { 4 } else { 6 };
|
||||
cfg.network.ipv6 = Some((state & 2) != 0);
|
||||
cfg.default_dc = Some(((state >> 8) as u8).max(1));
|
||||
|
||||
let dc_idx = (state as i16).wrapping_sub(16_384);
|
||||
let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail");
|
||||
|
||||
assert_eq!(resolved.port(), crate::protocol::constants::TG_DATACENTER_PORT);
|
||||
let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true);
|
||||
assert_eq!(resolved.is_ipv6(), expect_v6);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn subtle_integration_parallel_same_dc_logs_one_line() {
|
||||
let _guard = unknown_dc_test_lock()
|
||||
.lock()
|
||||
.expect("unknown dc test lock must be available");
|
||||
clear_unknown_dc_log_cache_for_testing();
|
||||
|
||||
let rel_dir = format!("target/telemt-direct-relay-same-{}", std::process::id());
|
||||
let rel_file = format!("{rel_dir}/unknown-dc.log");
|
||||
let abs_dir = std::env::current_dir()
|
||||
.expect("cwd must be available")
|
||||
.join(&rel_dir);
|
||||
std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable");
|
||||
let abs_file = abs_dir.join("unknown-dc.log");
|
||||
let _ = std::fs::remove_file(&abs_file);
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.unknown_dc_file_log_enabled = true;
|
||||
cfg.general.unknown_dc_log_path = Some(rel_file);
|
||||
|
||||
let cfg = Arc::new(cfg);
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..32 {
|
||||
let cfg = Arc::clone(&cfg);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let _ = get_dc_addr_static(31_777, cfg.as_ref());
|
||||
}));
|
||||
}
|
||||
for task in tasks {
|
||||
task.await.expect("task must not panic");
|
||||
}
|
||||
|
||||
for _ in 0..60 {
|
||||
if let Ok(content) = std::fs::read_to_string(&abs_file)
|
||||
&& nonempty_line_count(&content) == 1
|
||||
{
|
||||
return;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(&abs_file).unwrap_or_default();
|
||||
assert_eq!(nonempty_line_count(&content), 1);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn subtle_integration_parallel_unique_dcs_log_unique_lines() {
|
||||
let _guard = unknown_dc_test_lock()
|
||||
.lock()
|
||||
.expect("unknown dc test lock must be available");
|
||||
clear_unknown_dc_log_cache_for_testing();
|
||||
|
||||
let rel_dir = format!("target/telemt-direct-relay-unique-{}", std::process::id());
|
||||
let rel_file = format!("{rel_dir}/unknown-dc.log");
|
||||
let abs_dir = std::env::current_dir()
|
||||
.expect("cwd must be available")
|
||||
.join(&rel_dir);
|
||||
std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable");
|
||||
let abs_file = abs_dir.join("unknown-dc.log");
|
||||
let _ = std::fs::remove_file(&abs_file);
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.unknown_dc_file_log_enabled = true;
|
||||
cfg.general.unknown_dc_log_path = Some(rel_file);
|
||||
|
||||
let cfg = Arc::new(cfg);
|
||||
let dcs = [31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908];
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for dc in dcs {
|
||||
let cfg = Arc::clone(&cfg);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let _ = get_dc_addr_static(dc, cfg.as_ref());
|
||||
}));
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
task.await.expect("task must not panic");
|
||||
}
|
||||
|
||||
for _ in 0..80 {
|
||||
if let Ok(content) = std::fs::read_to_string(&abs_file)
|
||||
&& nonempty_line_count(&content) >= 8
|
||||
{
|
||||
return;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(&abs_file).unwrap_or_default();
|
||||
assert!(
|
||||
nonempty_line_count(&content) >= 8,
|
||||
"expected at least one line per unique dc, content: {content}"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,409 @@
|
|||
use super::*;
|
||||
use crate::error::ProxyError;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::Instant;
|
||||
|
||||
#[test]
|
||||
fn quota_lock_same_user_returns_same_arc_instance() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let a = quota_user_lock("quota-lock-same-user");
|
||||
let b = quota_user_lock("quota-lock-same-user");
|
||||
assert!(Arc::ptr_eq(&a, &b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_parallel_same_user_reuses_single_lock() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let user = "quota-lock-parallel-same";
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for _ in 0..64 {
|
||||
handles.push(std::thread::spawn(move || quota_user_lock(user)));
|
||||
}
|
||||
|
||||
let first = handles
|
||||
.remove(0)
|
||||
.join()
|
||||
.expect("thread must return lock handle");
|
||||
|
||||
for handle in handles {
|
||||
let got = handle.join().expect("thread must return lock handle");
|
||||
assert!(Arc::ptr_eq(&first, &got));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_unique_users_materialize_distinct_entries() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
|
||||
map.clear();
|
||||
|
||||
let base = format!("quota-lock-distinct-{}", std::process::id());
|
||||
let users: Vec<String> = (0..(QUOTA_USER_LOCKS_MAX / 2))
|
||||
.map(|idx| format!("{base}-{idx}"))
|
||||
.collect();
|
||||
|
||||
for user in &users {
|
||||
let _ = quota_user_lock(user);
|
||||
}
|
||||
|
||||
for user in &users {
|
||||
assert!(map.get(user).is_some(), "lock cache must contain entry for {user}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
|
||||
map.clear();
|
||||
|
||||
let base = format!("quota-lock-churn-{}", std::process::id());
|
||||
for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) {
|
||||
let _ = quota_user_lock(&format!("{base}-{idx}"));
|
||||
}
|
||||
|
||||
assert!(
|
||||
map.len() <= QUOTA_USER_LOCKS_MAX,
|
||||
"quota lock cache must stay bounded under unique-user churn"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let prefix = format!("quota-held-{}", std::process::id());
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
map.len(),
|
||||
QUOTA_USER_LOCKS_MAX,
|
||||
"cache must be saturated for overflow check"
|
||||
);
|
||||
|
||||
let overflow_user = format!("quota-overflow-{}", std::process::id());
|
||||
let overflow_a = quota_user_lock(&overflow_user);
|
||||
let overflow_b = quota_user_lock(&overflow_user);
|
||||
|
||||
assert_eq!(
|
||||
map.len(),
|
||||
QUOTA_USER_LOCKS_MAX,
|
||||
"overflow path must not grow lock cache"
|
||||
);
|
||||
assert!(
|
||||
map.get(&overflow_user).is_none(),
|
||||
"overflow user lock must stay outside bounded cache under saturation"
|
||||
);
|
||||
assert!(
|
||||
Arc::ptr_eq(&overflow_a, &overflow_b),
|
||||
"overflow user must receive stable striped overflow lock while saturated"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
// Fill and immediately drop strong references, leaving only map-owned Arcs.
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
let _ = quota_user_lock(&format!("quota-reclaim-drop-{}-{idx}", std::process::id()));
|
||||
}
|
||||
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
|
||||
|
||||
let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
|
||||
let overflow = quota_user_lock(&overflow_user);
|
||||
|
||||
assert!(
|
||||
map.get(&overflow_user).is_some(),
|
||||
"after reclaiming stale entries, overflow user should become cacheable"
|
||||
);
|
||||
assert!(
|
||||
Arc::strong_count(&overflow) >= 2,
|
||||
"cacheable overflow lock should be held by both map and caller"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_saturated_same_user_must_not_return_distinct_locks() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("quota-saturated-held-{}-{idx}", std::process::id())));
|
||||
}
|
||||
|
||||
let overflow_user = format!("quota-saturated-same-user-{}", std::process::id());
|
||||
let a = quota_user_lock(&overflow_user);
|
||||
let b = quota_user_lock(&overflow_user);
|
||||
|
||||
assert!(
|
||||
Arc::ptr_eq(&a, &b),
|
||||
"same user must not receive distinct locks under saturation because that enables quota race bypass"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("quota-saturated-race-held-{}-{idx}", std::process::id())));
|
||||
}
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-saturated-race-user-{}", std::process::id());
|
||||
let gate = Arc::new(Barrier::new(2));
|
||||
|
||||
let worker = |label: u8, stats: Arc<Stats>, user: String, gate: Arc<Barrier>| {
|
||||
tokio::spawn(async move {
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user,
|
||||
Some(1),
|
||||
quota_exceeded,
|
||||
Instant::now(),
|
||||
);
|
||||
gate.wait().await;
|
||||
io.write_all(&[label]).await
|
||||
})
|
||||
};
|
||||
|
||||
let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
|
||||
let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), async {
|
||||
let _ = one.await.expect("task one must not panic");
|
||||
let _ = two.await.expect("task two must not panic");
|
||||
})
|
||||
.await
|
||||
.expect("quota race workers must complete");
|
||||
|
||||
assert!(
|
||||
stats.get_user_total_octets(&user) <= 1,
|
||||
"saturated lock path must never overshoot quota for same user"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("quota-saturated-stress-held-{}-{idx}", std::process::id())));
|
||||
}
|
||||
|
||||
for round in 0..128u32 {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id());
|
||||
let gate = Arc::new(Barrier::new(2));
|
||||
|
||||
let one = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let gate = Arc::clone(&gate);
|
||||
tokio::spawn(async move {
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user,
|
||||
Some(1),
|
||||
quota_exceeded,
|
||||
Instant::now(),
|
||||
);
|
||||
gate.wait().await;
|
||||
io.write_all(&[0x31]).await
|
||||
})
|
||||
};
|
||||
|
||||
let two = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let gate = Arc::clone(&gate);
|
||||
tokio::spawn(async move {
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user,
|
||||
Some(1),
|
||||
quota_exceeded,
|
||||
Instant::now(),
|
||||
);
|
||||
gate.wait().await;
|
||||
io.write_all(&[0x32]).await
|
||||
})
|
||||
};
|
||||
|
||||
let _ = one.await.expect("stress task one must not panic");
|
||||
let _ = two.await.expect("stress task two must not panic");
|
||||
|
||||
assert!(
|
||||
stats.get_user_total_octets(&user) <= 1,
|
||||
"round {round}: saturated path must not overshoot quota"
|
||||
);
|
||||
}
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_error_classifier_accepts_internal_quota_sentinel_only() {
|
||||
let err = quota_io_error();
|
||||
assert!(is_quota_io_error(&err));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_error_classifier_rejects_plain_permission_denied() {
|
||||
let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied");
|
||||
assert!(!is_quota_io_error(&err));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-zero-user";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(2048);
|
||||
let (relay_server, mut server_peer) = duplex(2048);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
512,
|
||||
512,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(0),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer
|
||||
.write_all(b"x")
|
||||
.await
|
||||
.expect("client write must succeed");
|
||||
|
||||
let mut probe = [0u8; 1];
|
||||
let forwarded = tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await;
|
||||
if let Ok(Ok(n)) = forwarded {
|
||||
assert_eq!(n, 0, "zero quota path must not forward payload bytes");
|
||||
}
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.expect("relay must terminate under zero quota")
|
||||
.expect("relay task must not panic");
|
||||
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(8192);
|
||||
let (relay_server, mut server_peer) = duplex(8192);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
"quota-none-burst-user",
|
||||
Arc::clone(&stats),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
let c2s = vec![0xA5; 2048];
|
||||
let s2c = vec![0x5A; 1536];
|
||||
|
||||
client_peer.write_all(&c2s).await.expect("client burst write must succeed");
|
||||
let mut got_c2s = vec![0u8; c2s.len()];
|
||||
server_peer.read_exact(&mut got_c2s).await.expect("server must receive c2s burst");
|
||||
assert_eq!(got_c2s, c2s);
|
||||
|
||||
server_peer.write_all(&s2c).await.expect("server burst write must succeed");
|
||||
let mut got_s2c = vec![0u8; s2c.len()];
|
||||
client_peer.read_exact(&mut got_s2c).await.expect("client must receive s2c burst");
|
||||
assert_eq!(got_s2c, s2c);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let done = tokio::time::timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.expect("relay must terminate after peers close")
|
||||
.expect("relay task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
|
@ -31,6 +31,12 @@ impl std::task::Wake for WakeCounter {
|
|||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_contention_does_not_self_wake_pending_writer() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-lock-contention-user";
|
||||
|
||||
|
|
@ -66,6 +72,12 @@ async fn quota_lock_contention_does_not_self_wake_pending_writer() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-lock-writer-liveness-user";
|
||||
|
||||
|
|
@ -133,6 +145,12 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_
|
|||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() {
|
||||
let _guard = super::quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.expect("quota lock test guard must be available");
|
||||
let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-lock-read-liveness-user";
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config::UserMaxUniqueIpsMode;
|
||||
use crate::ip_tracker::UserIpTracker;
|
||||
|
||||
fn ip_from_idx(idx: u32) -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(10, ((idx >> 16) & 0xff) as u8, ((idx >> 8) & 0xff) as u8, (idx & 0xff) as u8))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hotpath_empty_drain_is_idempotent() {
|
||||
let tracker = UserIpTracker::new();
|
||||
for _ in 0..128 {
|
||||
tracker.drain_cleanup_queue().await;
|
||||
}
|
||||
assert_eq!(tracker.get_active_ip_count("none").await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hotpath_batch_cleanup_drain_clears_all_active_entries() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("u", 100).await;
|
||||
|
||||
for idx in 0..32 {
|
||||
let ip = ip_from_idx(idx);
|
||||
tracker.check_and_add("u", ip).await.unwrap();
|
||||
tracker.enqueue_cleanup("u".to_string(), ip);
|
||||
}
|
||||
|
||||
tracker.drain_cleanup_queue().await;
|
||||
assert_eq!(tracker.get_active_ip_count("u").await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn hotpath_parallel_enqueue_and_drain_does_not_deadlock() {
|
||||
let tracker = Arc::new(UserIpTracker::new());
|
||||
tracker.set_user_limit("p", 64).await;
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for worker in 0..32u32 {
|
||||
let t = tracker.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let ip = ip_from_idx(1_000 + worker);
|
||||
for _ in 0..64 {
|
||||
let _ = t.check_and_add("p", ip).await;
|
||||
t.enqueue_cleanup("p".to_string(), ip);
|
||||
t.drain_cleanup_queue().await;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
tokio::time::timeout(Duration::from_secs(3), task)
|
||||
.await
|
||||
.expect("worker must not deadlock")
|
||||
.expect("worker task must not panic");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn hotpath_parallel_unique_ip_limit_never_exceeds_cap() {
|
||||
let tracker = Arc::new(UserIpTracker::new());
|
||||
tracker.set_user_limit("limit", 5).await;
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for idx in 0..64u32 {
|
||||
let t = tracker.clone();
|
||||
tasks.push(tokio::spawn(async move { t.check_and_add("limit", ip_from_idx(idx)).await.is_ok() }));
|
||||
}
|
||||
|
||||
let mut admitted = 0usize;
|
||||
for task in tasks {
|
||||
if task.await.expect("task must not panic") {
|
||||
admitted += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(admitted <= 5, "admitted unique IPs must not exceed configured cap");
|
||||
assert!(tracker.get_active_ip_count("limit").await <= 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hotpath_repeated_same_ip_counter_balances_to_zero() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("same", 1).await;
|
||||
let ip = ip_from_idx(77);
|
||||
|
||||
for _ in 0..512 {
|
||||
tracker.check_and_add("same", ip).await.unwrap();
|
||||
}
|
||||
for _ in 0..512 {
|
||||
tracker.remove_ip("same", ip).await;
|
||||
}
|
||||
|
||||
assert_eq!(tracker.get_active_ip_count("same").await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hotpath_light_fuzz_mixed_operations_preserve_limit_invariants() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("fuzz", 4).await;
|
||||
|
||||
let mut state: u64 = 0xA55A_5AA5_D15C_B00B;
|
||||
for _ in 0..4_000 {
|
||||
state ^= state << 7;
|
||||
state ^= state >> 9;
|
||||
state ^= state << 8;
|
||||
|
||||
let ip = ip_from_idx((state as u32) % 8);
|
||||
match state & 0x3 {
|
||||
0 | 1 => {
|
||||
let _ = tracker.check_and_add("fuzz", ip).await;
|
||||
}
|
||||
_ => {
|
||||
tracker.remove_ip("fuzz", ip).await;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
tracker.get_active_ip_count("fuzz").await <= 4,
|
||||
"active count must stay within configured cap"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hotpath_multi_user_churn_keeps_isolation() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("u1", 2).await;
|
||||
tracker.set_user_limit("u2", 3).await;
|
||||
|
||||
for idx in 0..200u32 {
|
||||
let ip1 = ip_from_idx(idx % 5);
|
||||
let ip2 = ip_from_idx(100 + (idx % 7));
|
||||
let _ = tracker.check_and_add("u1", ip1).await;
|
||||
let _ = tracker.check_and_add("u2", ip2).await;
|
||||
if idx % 2 == 0 {
|
||||
tracker.remove_ip("u1", ip1).await;
|
||||
}
|
||||
if idx % 3 == 0 {
|
||||
tracker.remove_ip("u2", ip2).await;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(tracker.get_active_ip_count("u1").await <= 2);
|
||||
assert!(tracker.get_active_ip_count("u2").await <= 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hotpath_time_window_expiry_allows_new_ip_after_window() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("tw", 1).await;
|
||||
tracker
|
||||
.set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
|
||||
.await;
|
||||
|
||||
let ip1 = ip_from_idx(901);
|
||||
let ip2 = ip_from_idx(902);
|
||||
|
||||
tracker.check_and_add("tw", ip1).await.unwrap();
|
||||
tracker.remove_ip("tw", ip1).await;
|
||||
assert!(tracker.check_and_add("tw", ip2).await.is_err());
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1_100)).await;
|
||||
assert!(tracker.check_and_add("tw", ip2).await.is_ok());
|
||||
}
|
||||
Loading…
Reference in New Issue