From 20e205189c73f5199add6c6b7b6f85fbdd243c2b Mon Sep 17 00:00:00 2001
From: David Osipov
Date: Wed, 18 Mar 2026 17:04:50 +0400
Subject: [PATCH] Enhance TLS Emulator with ALPN Support and Add Adversarial
 Tests

- Modified `build_emulated_server_hello` to accept ALPN (Application-Layer
  Protocol Negotiation) as an optional parameter, allowing ALPN markers to be
  embedded in the application-data payload.
- Implemented handling for oversized ALPN values, skipping them entirely so
  they never interfere with the application-data payload.
- Added new security tests in `emulator_security_tests.rs` validating the
  ALPN embedding behavior, including oversized-ALPN scenarios and the
  preference for certificate payloads over ALPN markers.
- Introduced `send_adversarial_tests.rs` to cover edge cases in the middle
  proxy's send path, ensuring robustness against various failure modes.
- Updated the `middle_proxy` module to include the new test modules and to
  ensure proper handling of writer commands during data transmission.
---
 src/ip_tracker.rs                          |  74 +-
 src/ip_tracker_regression_tests.rs         | 169 +++
 src/protocol/tls.rs                        |  50 +-
 src/protocol/tls_security_tests.rs         | 283 ++++-
 src/proxy/client.rs                        |  27 +-
 src/proxy/client_security_tests.rs         |  29 +
 src/proxy/direct_relay.rs                  |   7 +-
 src/proxy/handshake.rs                     |  29 +-
 src/proxy/handshake_security_tests.rs      |  41 +
 src/proxy/middle_relay.rs                  | 193 +++-
 src/proxy/middle_relay_security_tests.rs   | 530 +++++++++-
 src/proxy/relay.rs                         | 175 +++-
 src/proxy/relay_security_tests.rs          | 972 ++++++++++++++++++
 src/tls_front/emulator.rs                  |  42 +-
 src/tls_front/emulator_security_tests.rs   | 136 +++
 src/transport/middle_proxy/mod.rs          |   2 +
 src/transport/middle_proxy/pool.rs         |   1 +
 src/transport/middle_proxy/registry.rs     |   1 +
 src/transport/middle_proxy/send.rs         |  24 +-
 .../middle_proxy/send_adversarial_tests.rs | 263 +++
 20 files changed, 2935 insertions(+), 113 deletions(-)
 create mode 100644 src/proxy/relay_security_tests.rs
 create mode 100644 src/tls_front/emulator_security_tests.rs
 create mode 100644 src/transport/middle_proxy/send_adversarial_tests.rs

diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs
index fce20b6..c35c587 100644
--- a/src/ip_tracker.rs
+++ b/src/ip_tracker.rs
@@ -7,8 +7,9 @@ use std::net::IpAddr;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::time::{Duration, Instant};
+use std::sync::Mutex;
 
-use tokio::sync::RwLock;
+use tokio::sync::{Mutex as AsyncMutex, RwLock};
 
 use crate::config::UserMaxUniqueIpsMode;
 
@@ -21,6 +22,8 @@ pub struct UserIpTracker {
     limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
     limit_window: Arc<RwLock<Duration>>,
     last_compact_epoch_secs: Arc<AtomicU64>,
+    pub(crate) cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
+    cleanup_drain_lock: Arc<AsyncMutex<()>>,
 }
 
 impl UserIpTracker {
@@ -33,6 +36,67 @@ impl UserIpTracker {
             limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
             limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
             last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
+            cleanup_queue: Arc::new(Mutex::new(Vec::new())),
+            cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
         }
     }
 
+
+    pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
+        match self.cleanup_queue.lock() {
+            Ok(mut queue) => queue.push((user, ip)),
+            Err(poisoned) => {
+                let mut queue = poisoned.into_inner();
+                queue.push((user.clone(), ip));
+                self.cleanup_queue.clear_poison();
+                tracing::warn!(
+                    "UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
+                    user,
+                    ip
+                );
+            }
+        }
+    }
+
+    pub(crate) async fn drain_cleanup_queue(&self) {
+        // Serialize queue draining and active-IP mutation so check-and-add cannot
+        // observe stale active entries that are already queued for removal.
+        let _drain_guard = self.cleanup_drain_lock.lock().await;
+        let to_remove = {
+            match self.cleanup_queue.lock() {
+                Ok(mut queue) => {
+                    if queue.is_empty() {
+                        return;
+                    }
+                    std::mem::take(&mut *queue)
+                }
+                Err(poisoned) => {
+                    let mut queue = poisoned.into_inner();
+                    if queue.is_empty() {
+                        self.cleanup_queue.clear_poison();
+                        return;
+                    }
+                    let drained = std::mem::take(&mut *queue);
+                    self.cleanup_queue.clear_poison();
+                    drained
+                }
+            }
+        };
+
+        let mut active_ips = self.active_ips.write().await;
+        for (user, ip) in to_remove {
+            if let Some(user_ips) = active_ips.get_mut(&user) {
+                if let Some(count) = user_ips.get_mut(&ip) {
+                    if *count > 1 {
+                        *count -= 1;
+                    } else {
+                        user_ips.remove(&ip);
+                    }
+                }
+                if user_ips.is_empty() {
+                    active_ips.remove(&user);
+                }
+            }
+        }
     }
 
@@ -118,6 +182,7 @@ impl UserIpTracker {
     }
 
     pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
+        self.drain_cleanup_queue().await;
         self.maybe_compact_empty_users().await;
         let default_max_ips = *self.default_max_ips.read().await;
         let limit = {
@@ -194,6 +259,7 @@ impl UserIpTracker {
     }
 
     pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap<String, usize> {
+        self.drain_cleanup_queue().await;
         let window = *self.limit_window.read().await;
         let now = Instant::now();
         let recent_ips = self.recent_ips.read().await;
@@ -214,6 +280,7 @@ impl UserIpTracker {
     }
 
     pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
+        self.drain_cleanup_queue().await;
         let active_ips = self.active_ips.read().await;
         let mut out = HashMap::with_capacity(users.len());
         for user in users {
@@ -228,6 +295,7 @@ impl UserIpTracker {
     }
 
     pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
+        self.drain_cleanup_queue().await;
         let window = *self.limit_window.read().await;
         let now = Instant::now();
         let recent_ips = self.recent_ips.read().await;
@@ -250,11 +318,13 @@ impl UserIpTracker {
     }
 
     pub async fn get_active_ip_count(&self, username: &str) -> usize {
+        self.drain_cleanup_queue().await;
         let active_ips = self.active_ips.read().await;
         active_ips.get(username).map(|ips| ips.len()).unwrap_or(0)
     }
 
     pub async fn get_active_ips(&self, username: &str) -> Vec<IpAddr> {
+        self.drain_cleanup_queue().await;
         let active_ips = self.active_ips.read().await;
         active_ips
             .get(username)
@@ -263,6 +333,7 @@ impl UserIpTracker {
     }
 
     pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
+        self.drain_cleanup_queue().await;
         let active_ips = self.active_ips.read().await;
         let max_ips = self.max_ips.read().await;
         let default_max_ips = *self.default_max_ips.read().await;
@@ -301,6 +372,7 @@ impl UserIpTracker {
     }
 
     pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool {
+        self.drain_cleanup_queue().await;
         let active_ips = self.active_ips.read().await;
         active_ips
             .get(username)
diff --git a/src/ip_tracker_regression_tests.rs b/src/ip_tracker_regression_tests.rs
index 5d6b358..57e135d 100644
--- a/src/ip_tracker_regression_tests.rs
+++ b/src/ip_tracker_regression_tests.rs
@@ -448,3 +448,172 @@ async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() {
 
     assert!(tracker.get_active_ip_count("cc").await <= 8);
 }
+
+#[tokio::test]
+async fn enqueue_cleanup_recovers_from_poisoned_mutex() {
+    let tracker = UserIpTracker::new();
+    let ip = ip_from_idx(99);
+
+    // Poison the lock by panicking while holding it.
+    let result = std::panic::catch_unwind(|| {
+        let _guard = 
tracker.cleanup_queue.lock().unwrap(); + panic!("Intentional poison panic"); + }); + assert!(result.is_err(), "Expected panic to poison mutex"); + + // Attempt to enqueue anyway; should hit the poison catch arm and still insert + tracker.enqueue_cleanup("poison-user".to_string(), ip); + + tracker.drain_cleanup_queue().await; + + assert_eq!(tracker.get_active_ip_count("poison-user").await, 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn mass_reconnect_sync_cleanup_prevents_temporary_reservation_bloat() { + // Tests that synchronous M-01 drop mechanism protects against starvation + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("mass", 5).await; + + let ip = ip_from_idx(42); + let mut join_handles = Vec::new(); + + // 10,000 rapid concurrent requests hitting the same IP limit + for _ in 0..10_000 { + let tracker_clone = tracker.clone(); + join_handles.push(tokio::spawn(async move { + if tracker_clone.check_and_add("mass", ip).await.is_ok() { + // Instantly enqueue cleanup, simulating synchronous reservation drop + tracker_clone.enqueue_cleanup("mass".to_string(), ip); + // The next caller will drain it before acquiring again + } + })); + } + + for handle in join_handles { + let _ = handle.await; + } + + // Force flush + tracker.drain_cleanup_queue().await; + assert_eq!(tracker.get_active_ip_count("mass").await, 0, "No leaked footprints"); +} + +#[tokio::test] +async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() { + // Regression guard: concurrent cleanup draining must not produce false + // limit denials for a new IP when the previous IP is already queued. + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("racer", 1).await; + let ip1 = ip_from_idx(1); + let ip2 = ip_from_idx(2); + + // Initial state: add ip1 + tracker.check_and_add("racer", ip1).await.unwrap(); + + // User disconnects from ip1, queuing it + tracker.enqueue_cleanup("racer".to_string(), ip1); + + let mut saw_false_rejection = false; + for _ in 0..100 { + // Queue cleanup then race explicit drain and check-and-add on the alternative IP. + tracker.enqueue_cleanup("racer".to_string(), ip1); + let tracker_a = tracker.clone(); + let tracker_b = tracker.clone(); + + let drain_handle = tokio::spawn(async move { + tracker_a.drain_cleanup_queue().await; + }); + let handle = tokio::spawn(async move { + tracker_b.check_and_add("racer", ip2).await + }); + + drain_handle.await.unwrap(); + let res = handle.await.unwrap(); + if res.is_err() { + saw_false_rejection = true; + break; + } + + // Restore baseline for next iteration. + tracker.remove_ip("racer", ip2).await; + tracker.check_and_add("racer", ip1).await.unwrap(); + } + + assert!( + !saw_false_rejection, + "Concurrent cleanup draining must not cause false-positive IP denials" + ); +} + +#[tokio::test] +async fn poisoned_cleanup_queue_still_releases_slot_for_next_ip() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("poison-slot", 1).await; + let ip1 = ip_from_idx(7001); + let ip2 = ip_from_idx(7002); + + tracker.check_and_add("poison-slot", ip1).await.unwrap(); + + // Poison the queue lock as an adversarial condition. + let _ = std::panic::catch_unwind(|| { + let _guard = tracker.cleanup_queue.lock().unwrap(); + panic!("intentional queue poison"); + }); + + // Disconnect path must still queue cleanup so the next IP can be admitted. 
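The `enqueue_cleanup` call below survives even a poisoned lock; for reference, a minimal standalone model of that recovery pattern, assuming only std (`Mutex::clear_poison` has been stable since Rust 1.77):

use std::sync::Mutex;

fn push_even_if_poisoned(queue: &Mutex<Vec<u32>>, value: u32) {
    match queue.lock() {
        // Normal path: the lock is healthy.
        Ok(mut q) => q.push(value),
        // A panic while the lock was held poisons it; recover the guard,
        // repair state, then clear the flag so later lock() calls succeed.
        Err(poisoned) => {
            let mut q = poisoned.into_inner();
            q.push(value);
            queue.clear_poison();
        }
    }
}

The key design point the tests pin down: recovery happens inline at the call site, so a single panicking thread can never permanently wedge the cleanup path.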
+ tracker.enqueue_cleanup("poison-slot".to_string(), ip1); + let admitted = tracker.check_and_add("poison-slot", ip2).await; + assert!( + admitted.is_ok(), + "cleanup queue poison must not permanently block slot release for the next IP" + ); +} + +#[tokio::test] +async fn duplicate_cleanup_entries_do_not_break_future_admission() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("dup-cleanup", 1).await; + let ip1 = ip_from_idx(7101); + let ip2 = ip_from_idx(7102); + + tracker.check_and_add("dup-cleanup", ip1).await.unwrap(); + tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1); + tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1); + tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1); + + tracker.drain_cleanup_queue().await; + + assert_eq!(tracker.get_active_ip_count("dup-cleanup").await, 0); + assert!( + tracker.check_and_add("dup-cleanup", ip2).await.is_ok(), + "extra queued cleanup entries must not leave user stuck in denied state" + ); +} + +#[tokio::test] +async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("poison-stress", 1).await; + let ip_primary = ip_from_idx(7201); + let ip_alt = ip_from_idx(7202); + + tracker.check_and_add("poison-stress", ip_primary).await.unwrap(); + + for _ in 0..64 { + let _ = std::panic::catch_unwind(|| { + let _guard = tracker.cleanup_queue.lock().unwrap(); + panic!("intentional queue poison in stress loop"); + }); + + tracker.enqueue_cleanup("poison-stress".to_string(), ip_primary); + + assert!( + tracker.check_and_add("poison-stress", ip_alt).await.is_ok(), + "poison recovery must preserve admission progress under repeated queue poisoning" + ); + + tracker.remove_ip("poison-stress", ip_alt).await; + tracker.check_and_add("poison-stress", ip_primary).await.unwrap(); + } +} diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 3f9f981..12a2158 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -34,6 +34,9 @@ pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after /// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; +/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance +/// windows when replay TTL is configured very large. +pub const BOOT_TIME_COMPAT_MAX_SECS: u32 = 2 * 60; // ============= Private Constants ============= @@ -66,6 +69,7 @@ pub struct TlsValidation { /// Client digest for response generation pub digest: [u8; TLS_DIGEST_LEN], /// Timestamp extracted from digest + pub timestamp: u32, } @@ -121,6 +125,7 @@ impl TlsExtensionBuilder { } /// Build final extensions with length prefix + fn build(self) -> Vec { let mut result = Vec::with_capacity(2 + self.extensions.len()); @@ -135,7 +140,7 @@ impl TlsExtensionBuilder { } /// Get current extensions without length prefix (for calculation) - #[allow(dead_code)] + fn as_bytes(&self) -> &[u8] { &self.extensions } @@ -251,6 +256,7 @@ impl ServerHelloBuilder { /// Returns validation result if a matching user is found. /// The result **must** be used — ignoring it silently bypasses authentication. #[must_use] + pub fn validate_tls_handshake( handshake: &[u8], secrets: &[(String, Vec)], @@ -266,9 +272,9 @@ pub fn validate_tls_handshake( /// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL. 
/// -/// A boot-time timestamp is only accepted when it falls below both -/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp -/// reuse outside replay cache coverage. +/// A boot-time timestamp is only accepted when it falls below all three +/// bounds: `BOOT_TIME_MAX_SECS`, configured replay window, and +/// `BOOT_TIME_COMPAT_MAX_SECS`, preventing oversized compatibility windows. #[must_use] pub fn validate_tls_handshake_with_replay_window( handshake: &[u8], @@ -292,7 +298,9 @@ pub fn validate_tls_handshake_with_replay_window( let boot_time_cap_secs = if ignore_time_skew { 0 } else { - BOOT_TIME_MAX_SECS.min(replay_window_u32) + BOOT_TIME_MAX_SECS + .min(replay_window_u32) + .min(BOOT_TIME_COMPAT_MAX_SECS) }; validate_tls_handshake_at_time_with_boot_cap( @@ -312,6 +320,7 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option { i64::try_from(d.as_secs()).ok() } + fn validate_tls_handshake_at_time( handshake: &[u8], secrets: &[(String, Vec)], @@ -437,7 +446,7 @@ pub fn build_server_hello( session_id: &[u8], fake_cert_len: usize, rng: &SecureRandom, - _alpn: Option>, + alpn: Option>, new_session_tickets: u8, ) -> Vec { const MIN_APP_DATA: usize = 64; @@ -459,8 +468,27 @@ pub fn build_server_hello( 0x01, // CCS byte ]; - // Build fake certificate (Application Data record) - let fake_cert = rng.bytes(fake_cert_len); + // Build first encrypted flight mimic as opaque ApplicationData bytes. + // Embed a compact EncryptedExtensions-like ALPN block when selected. + let mut fake_cert = Vec::with_capacity(fake_cert_len); + if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) { + let proto_list_len = 1usize + proto.len(); + let ext_data_len = 2usize + proto_list_len; + let marker_len = 4usize + ext_data_len; + if marker_len <= fake_cert_len { + fake_cert.extend_from_slice(&0x0010u16.to_be_bytes()); + fake_cert.extend_from_slice(&(ext_data_len as u16).to_be_bytes()); + fake_cert.extend_from_slice(&(proto_list_len as u16).to_be_bytes()); + fake_cert.push(proto.len() as u8); + fake_cert.extend_from_slice(proto); + } + } + if fake_cert.len() < fake_cert_len { + fake_cert.extend_from_slice(&rng.bytes(fake_cert_len - fake_cert.len())); + } else if fake_cert.len() > fake_cert_len { + fake_cert.truncate(fake_cert_len); + } + let mut app_data_record = Vec::with_capacity(5 + fake_cert_len); app_data_record.push(TLS_RECORD_APPLICATION); app_data_record.extend_from_slice(&TLS_VERSION); @@ -472,8 +500,9 @@ pub fn build_server_hello( // Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted; // here we mimic with opaque ApplicationData records of plausible size). 
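Before the ticket records below, note the wire shape of the ALPN block assembled above; a standalone sketch mirroring the builder's arithmetic (function name illustrative):

fn alpn_marker(proto: &[u8]) -> Vec<u8> {
    let proto_list_len = 1 + proto.len();  // 1-byte name length + name
    let ext_data_len = 2 + proto_list_len; // 2-byte protocol list length
    let mut m = Vec::with_capacity(4 + ext_data_len);
    m.extend_from_slice(&0x0010u16.to_be_bytes());               // extension type: ALPN
    m.extend_from_slice(&(ext_data_len as u16).to_be_bytes());   // extension data length
    m.extend_from_slice(&(proto_list_len as u16).to_be_bytes()); // protocol list length
    m.push(proto.len() as u8);                                   // protocol name length
    m.extend_from_slice(proto);                                  // protocol name
    m
}

So `alpn_marker(b"h2")` yields the 9 bytes `00 10 00 05 00 03 02 68 32`, exactly the marker asserted in the security tests below, and the total marker length is always `7 + proto.len()`.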
let mut tickets = Vec::new(); - if new_session_tickets > 0 { - for _ in 0..new_session_tickets { + let ticket_count = new_session_tickets.min(4); + if ticket_count > 0 { + for _ in 0..ticket_count { let ticket_len: usize = rng.range(48) + 48; // 48-95 bytes let mut record = Vec::with_capacity(5 + ticket_len); record.push(TLS_RECORD_APPLICATION); @@ -678,6 +707,7 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { } /// Parse TLS record header, returns (record_type, length) + pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { let record_type = header[0]; let version = [header[1], header[2]]; diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index 9f568b5..f8f2695 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -300,8 +300,8 @@ fn boot_time_timestamp_accepted_without_ignore_flag() { // Timestamps below the boot-time threshold are treated as client uptime, // not real wall-clock time. The proxy allows them regardless of skew. let secret = b"boot_time_test"; - // Keep this safely below BOOT_TIME_MAX_SECS to assert bypass behavior. - let boot_ts: u32 = BOOT_TIME_MAX_SECS / 2; + // Keep this safely below compatibility cap to assert bypass behavior. + let boot_ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1); let handshake = make_valid_tls_handshake(secret, boot_ts); let secrets = vec![("u".to_string(), secret.to_vec())]; assert!( @@ -663,13 +663,14 @@ fn zero_length_session_id_accepted() { // Boot-time threshold — exact boundary precision // ------------------------------------------------------------------ -/// timestamp = BOOT_TIME_MAX_SECS - 1 is the last value inside the boot-time window. +/// timestamp = BOOT_TIME_COMPAT_MAX_SECS - 1 is the last value inside +/// the runtime boot-time compatibility window. /// is_boot_time = true → skew check is skipped entirely → accepted even /// when `now` is far from the timestamp. #[test] fn timestamp_one_below_boot_threshold_bypasses_skew_check() { let secret = b"boot_last_value_test"; - let ts: u32 = BOOT_TIME_MAX_SECS - 1; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS - 1; let h = make_valid_tls_handshake(secret, ts); let secrets = vec![("u".to_string(), secret.to_vec())]; @@ -677,32 +678,48 @@ fn timestamp_one_below_boot_threshold_bypasses_skew_check() { // Boot-time bypass must prevent the skew check from running. assert!( validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(), - "ts=BOOT_TIME_MAX_SECS-1 must bypass skew check regardless of now" + "ts=BOOT_TIME_COMPAT_MAX_SECS-1 must bypass skew check regardless of now" ); } -/// timestamp = BOOT_TIME_MAX_SECS is the first value outside the boot-time window. +/// timestamp = BOOT_TIME_COMPAT_MAX_SECS is the first value outside the +/// runtime boot-time compatibility window. /// is_boot_time = false → skew check IS applied. Two sub-cases confirm this: /// once with now chosen so the skew passes (accepted) and once where it fails. #[test] fn timestamp_at_boot_threshold_triggers_skew_check() { let secret = b"boot_exact_value_test"; - let ts: u32 = BOOT_TIME_MAX_SECS; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS; let h = make_valid_tls_handshake(secret, ts); let secrets = vec![("u".to_string(), secret.to_vec())]; // now = ts + 50 → time_diff = 50, within [-1200, 600] → accepted. 
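For orientation before the two sub-cases below: the cap that decides whether the skew check runs at all composes as in this sketch (constant values copied from src/protocol/tls.rs above):

const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; // as defined above
const BOOT_TIME_COMPAT_MAX_SECS: u32 = 2 * 60;    // as defined above

fn effective_boot_cap(ignore_time_skew: bool, replay_window_secs: u32) -> u32 {
    if ignore_time_skew {
        0 // bypass disabled entirely; skew is ignored elsewhere
    } else {
        BOOT_TIME_MAX_SECS
            .min(replay_window_secs)        // never beyond replay-cache coverage
            .min(BOOT_TIME_COMPAT_MAX_SECS) // hard 120 s compatibility ceiling
    }
}

With any replay window of 120 s or more, the cap is pinned at 120 s, which is why `ts = BOOT_TIME_COMPAT_MAX_SECS - 1` bypasses the skew check while `ts = BOOT_TIME_COMPAT_MAX_SECS` falls through to it.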
let now_valid: i64 = ts as i64 + 50; assert!( - validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(), - "ts=BOOT_TIME_MAX_SECS within skew window must be accepted via skew check" + validate_tls_handshake_at_time_with_boot_cap( + &h, + &secrets, + false, + now_valid, + BOOT_TIME_COMPAT_MAX_SECS, + ) + .is_some(), + "ts=BOOT_TIME_COMPAT_MAX_SECS within skew window must be accepted via skew check" ); - // now = 0 → time_diff = -86_400_000, outside window → rejected. - // If the boot-time bypass were wrongly applied here this would pass. + // now = -1 → time_diff = -121 at the 120-second threshold, outside window + // for TIME_SKEW_MIN=-120. If boot-time bypass were wrongly applied this + // would pass. assert!( - validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), - "ts=BOOT_TIME_MAX_SECS far from now must be rejected — no boot-time bypass" + validate_tls_handshake_at_time_with_boot_cap( + &h, + &secrets, + false, + -1, + BOOT_TIME_COMPAT_MAX_SECS, + ) + .is_none(), + "ts=BOOT_TIME_COMPAT_MAX_SECS far from now must be rejected — no boot-time bypass" ); } @@ -723,7 +740,7 @@ fn replay_window_cap_disables_boot_bypass_for_old_timestamps() { #[test] fn replay_window_cap_still_allows_small_boot_timestamp() { let secret = b"boot_cap_enabled_test"; - let ts: u32 = 120; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1); let h = make_valid_tls_handshake(secret, ts); let secrets = vec![("u".to_string(), secret.to_vec())]; @@ -734,6 +751,20 @@ fn replay_window_cap_still_allows_small_boot_timestamp() { ); } +#[test] +fn large_replay_window_is_hard_capped_for_boot_compatibility() { + let secret = b"boot_cap_hard_limit_test"; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS + 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, u64::MAX); + assert!( + result.is_none(), + "very large replay window must not expand boot-time bypass beyond hard compatibility cap" + ); +} + #[test] fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() { let secret = b"ignore_skew_boot_cap_decouple_test"; @@ -743,7 +774,7 @@ fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() { let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0); let cap_nonzero = - validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_MAX_SECS); + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_COMPAT_MAX_SECS); assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC"); assert!( @@ -1889,6 +1920,228 @@ fn server_hello_new_session_ticket_count_matches_configuration() { ); } +#[test] +fn server_hello_new_session_ticket_count_is_safely_capped() { + let secret = b"ticket_count_cap_test"; + let client_digest = [0x44u8; TLS_DIGEST_LEN]; + let session_id = vec![0x54; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, u8::MAX); + + let mut pos = 0usize; + let mut app_records = 0usize; + while pos + 5 <= response.len() { + let rtype = response[pos]; + let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; + let next = pos + 5 + rlen; + assert!(next <= response.len(), "TLS record must stay inside response bounds"); + if rtype == TLS_RECORD_APPLICATION { + app_records += 1; + } + pos = next; + } + + assert_eq!( + app_records, + 5, + "response must cap ticket-like 
tail records to four plus one main application record" + ); +} + +#[test] +fn server_hello_application_data_contains_alpn_marker_when_selected() { + let secret = b"alpn_marker_test"; + let client_digest = [0x55u8; TLS_DIGEST_LEN]; + let session_id = vec![0xAB; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 512, + &rng, + Some(b"h2".to_vec()), + 0, + ); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let app_payload = &response[app_pos + 5..app_pos + 5 + app_len]; + + let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2']; + assert!( + app_payload.windows(expected.len()).any(|window| window == expected), + "first application payload must carry ALPN marker for selected protocol" + ); +} + +#[test] +fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() { + let secret = b"alpn_oversize_ignore_test"; + let client_digest = [0x56u8; TLS_DIGEST_LEN]; + let session_id = vec![0xCD; 32]; + let rng = crate::crypto::SecureRandom::new(); + let oversized_alpn = vec![b'x'; u8::MAX as usize + 1]; + + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 512, + &rng, + Some(oversized_alpn), + u8::MAX, + ); + + let mut pos = 0usize; + let mut app_records = 0usize; + let mut first_app_payload: Option<&[u8]> = None; + while pos + 5 <= response.len() { + let rtype = response[pos]; + let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; + let next = pos + 5 + rlen; + assert!(next <= response.len(), "TLS record must stay inside response bounds"); + if rtype == TLS_RECORD_APPLICATION { + app_records += 1; + if first_app_payload.is_none() { + first_app_payload = Some(&response[pos + 5..next]); + } + } + pos = next; + } + let marker = [0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x']; + + assert_eq!( + app_records, 5, + "oversized ALPN must not change the four-ticket cap on tail records" + ); + assert!( + !first_app_payload + .expect("response must contain an application record") + .windows(marker.len()) + .any(|window| window == marker), + "oversized ALPN must be ignored rather than embedded into the first application payload" + ); +} + +#[test] +fn server_hello_ignores_oversized_alpn_when_marker_would_not_fit() { + let secret = b"alpn_too_large_to_fit_test"; + let client_digest = [0x57u8; TLS_DIGEST_LEN]; + let session_id = vec![0xEF; 32]; + let rng = crate::crypto::SecureRandom::new(); + let oversized_alpn = vec![0xAB; u8::MAX as usize]; + + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 64, + &rng, + Some(oversized_alpn), + 0, + ); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let app_payload = &response[app_pos + 5..app_pos + 5 + app_len]; + + let mut marker_prefix = Vec::new(); + marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes()); + 
marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes()); + marker_prefix.push(0xff); + marker_prefix.extend_from_slice(&[0xab; 8]); + assert!( + !app_payload.starts_with(&marker_prefix), + "oversized ALPN must not be partially embedded into the ServerHello application record" + ); +} + +#[test] +fn server_hello_embeds_full_alpn_marker_when_it_exactly_fits_fake_cert_len() { + let secret = b"alpn_exact_fit_test"; + let client_digest = [0x58u8; TLS_DIGEST_LEN]; + let session_id = vec![0xA5; 32]; + let rng = crate::crypto::SecureRandom::new(); + let proto = vec![b'z'; 57]; + + // marker_len = 4 + (2 + (1 + proto_len)) = 7 + proto_len = 64 + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 64, + &rng, + Some(proto.clone()), + 0, + ); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let app_payload = &response[app_pos + 5..app_pos + 5 + app_len]; + + let mut expected_marker = Vec::new(); + expected_marker.extend_from_slice(&0x0010u16.to_be_bytes()); + expected_marker.extend_from_slice(&0x003Cu16.to_be_bytes()); + expected_marker.extend_from_slice(&0x003Au16.to_be_bytes()); + expected_marker.push(57u8); + expected_marker.extend_from_slice(&proto); + + assert_eq!(app_payload.len(), expected_marker.len()); + assert_eq!(app_payload, expected_marker.as_slice()); +} + +#[test] +fn server_hello_does_not_embed_partial_alpn_marker_when_one_byte_short() { + let secret = b"alpn_one_byte_short_test"; + let client_digest = [0x59u8; TLS_DIGEST_LEN]; + let session_id = vec![0xA6; 32]; + let rng = crate::crypto::SecureRandom::new(); + let proto = vec![0xAB; 58]; + + // marker_len = 65, fake_cert_len = 64 => marker must be fully skipped. 
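The skip decision below follows from the builder's size arithmetic; restated as compile-time checks (names illustrative):

// marker_len = 4 (type + ext_data_len fields) + 2 (list length)
//            + 1 (name length) + proto_len = 7 + proto_len
const PROTO_LEN: usize = 58;
const MARKER_LEN: usize = 7 + PROTO_LEN; // 65
const FAKE_CERT_LEN: usize = 64;
const _: () = assert!(MARKER_LEN > FAKE_CERT_LEN); // one byte short => skip whole marker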
+ let response = build_server_hello( + secret, + &client_digest, + &session_id, + 64, + &rng, + Some(proto), + 0, + ); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let app_payload = &response[app_pos + 5..app_pos + 5 + app_len]; + + let mut marker_prefix = Vec::new(); + marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x003Du16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x003Bu16.to_be_bytes()); + marker_prefix.push(58u8); + marker_prefix.extend_from_slice(&[0xAB; 8]); + + assert!( + !app_payload.starts_with(&marker_prefix), + "one-byte-short ALPN marker must be skipped entirely, not partially embedded" + ); +} + #[test] fn exhaustive_tls_minor_version_classification_matches_policy() { for minor in 0u8..=u8::MAX { diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 199f775..d7b3660 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -31,19 +31,16 @@ struct UserConnectionReservation { user: String, ip: IpAddr, active: bool, - runtime_handle: Option, } impl UserConnectionReservation { fn new(stats: Arc, ip_tracker: Arc, user: String, ip: IpAddr) -> Self { - let runtime_handle = tokio::runtime::Handle::try_current().ok(); Self { stats, ip_tracker, user, ip, active: true, - runtime_handle, } } @@ -64,29 +61,7 @@ impl Drop for UserConnectionReservation { } self.active = false; self.stats.decrement_user_curr_connects(&self.user); - - if let Some(handle) = &self.runtime_handle { - let ip_tracker = self.ip_tracker.clone(); - let user = self.user.clone(); - let ip = self.ip; - let handle = handle.clone(); - handle.spawn(async move { - ip_tracker.remove_ip(&user, ip).await; - }); - } else if let Ok(handle) = tokio::runtime::Handle::try_current() { - let ip_tracker = self.ip_tracker.clone(); - let user = self.user.clone(); - let ip = self.ip; - handle.spawn(async move { - ip_tracker.remove_ip(&user, ip).await; - }); - } else { - warn!( - user = %self.user, - ip = %self.ip, - "UserConnectionReservation dropped without Tokio runtime; IP reservation cleanup skipped" - ); - } + self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip); } } diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 6b236aa..7e34f4b 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -42,6 +42,35 @@ where CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +#[tokio::test] +async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { + let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); + let stats = Arc::new(crate::stats::Stats::new()); + let user = "sync-drop-user".to_string(); + let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap(); + + ip_tracker.set_user_limit(&user, 1).await; + ip_tracker.check_and_add(&user, ip).await.unwrap(); + stats.increment_user_curr_connects(&user); + + assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1); + assert_eq!(stats.get_user_curr_connects(&user), 1); + + let reservation = UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip); + + // Drop the reservation synchronously without any tokio::spawn/await yielding! 
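The `drop(reservation)` call below relies on this Drop-side contract; a minimal standalone model with illustrative types (not the real structs):

use std::net::IpAddr;
use std::sync::{Arc, Mutex};

struct Reservation {
    cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
    user: String,
    ip: IpAddr,
}

impl Drop for Reservation {
    fn drop(&mut self) {
        // Synchronous and runtime-free: a plain std mutex push that works on
        // any thread, with the async side draining the queue later. No
        // tokio::spawn, so no dependency on a live runtime handle.
        if let Ok(mut queue) = self.cleanup_queue.lock() {
            queue.push((self.user.clone(), self.ip));
        }
    }
}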
+ drop(reservation); + + // The IP is now inside the cleanup_queue, check that the queue has length 1 + let queue_len = ip_tracker.cleanup_queue.lock().unwrap().len(); + assert_eq!(queue_len, 1, "Reservation drop must push directly to synchronized IP queue"); + + assert_eq!(stats.get_user_curr_connects(&user), 0, "Stats must decrement immediately"); + + ip_tracker.drain_cleanup_queue().await; + assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0); +} + #[tokio::test] async fn relay_task_abort_releases_user_gate_and_ip_reservation() { let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 4a7b9a9..d36856d 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -132,7 +132,11 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { } #[cfg(not(unix))] { - OpenOptions::new().create(true).append(true).open(path) + let _ = path; + Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "unknown_dc_file_log_enabled requires unix O_NOFOLLOW support", + )) } } @@ -204,6 +208,7 @@ where config.general.direct_relay_copy_buf_s2c_bytes, user, Arc::clone(&stats), + config.access.user_data_quota.get(user).copied(), buffer_pool, ); tokio::pin!(relay_result); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 3659754..dc83ccc 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -241,7 +241,26 @@ fn auth_probe_record_failure_with_state( rounds += 1; if rounds > 8 { auth_probe_note_saturation(now); - return; + let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; + for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => + { + } + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + } + + let Some((evict_key, _, _)) = eviction_candidate else { + return; + }; + state.remove(&evict_key); + break; } let mut stale_keys = Vec::new(); @@ -518,6 +537,7 @@ pub struct HandshakeSuccess { /// Client address pub peer: SocketAddr, /// Whether TLS was used + pub is_tls: bool, } @@ -716,7 +736,11 @@ where R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, { - trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); + trace!( + peer = %peer, + handshake_head = %hex::encode(&handshake[..8]), + "MTProto handshake prefix" + ); let throttle_now = Instant::now(); if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { @@ -916,6 +940,7 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, A } /// Encrypt nonce for sending to Telegram (legacy function for compatibility) + pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce); encrypted diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 2132fbe..7af7192 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1584,6 +1584,47 @@ fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() { } } +#[test] +fn auth_probe_over_cap_churn_still_tracks_newcomer_after_round_limit() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| 
poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let state = DashMap::new();
+    let now = Instant::now();
+    let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 32;
+
+    for idx in 0..initial {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            10,
+            6,
+            ((idx >> 8) & 0xff) as u8,
+            (idx & 0xff) as u8,
+        ));
+        state.insert(
+            ip,
+            AuthProbeState {
+                fail_streak: 1,
+                blocked_until: now,
+                last_seen: now + Duration::from_millis((idx % 1024) as u64),
+            },
+        );
+    }
+
+    let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 114, 77));
+    auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_secs(1));
+
+    assert!(
+        state.get(&newcomer).is_some(),
+        "new probe source must still be tracked even when map starts above hard cap"
+    );
+    assert!(
+        state.len() < initial + 1,
+        "round-limited eviction path must still reclaim capacity under over-cap churn"
+    );
+}
+
 #[test]
 fn auth_probe_capacity_prefers_evicting_low_fail_streak_entries_first() {
     let _guard = auth_probe_test_lock()
diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs
index bf23045..1dbbbfd 100644
--- a/src/proxy/middle_relay.rs
+++ b/src/proxy/middle_relay.rs
@@ -2,15 +2,13 @@ use std::collections::hash_map::RandomState;
 use std::hash::BuildHasher;
 use std::hash::{Hash, Hasher};
 use std::net::{IpAddr, SocketAddr};
-use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::{Arc, OnceLock};
+use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+use std::sync::{Arc, Mutex, OnceLock};
 use std::time::{Duration, Instant};
-#[cfg(test)]
-use std::sync::Mutex;
 
 use dashmap::DashMap;
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
-use tokio::sync::{mpsc, oneshot, watch};
+use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex};
 use tokio::time::timeout;
 use tracing::{debug, trace, warn};
 
@@ -35,14 +33,22 @@ enum C2MeCommand {
 const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
 const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
 const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
+const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000);
 const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
 const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
 const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
 const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
+#[cfg(test)]
+const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
+#[cfg(not(test))]
+const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
 const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
 const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
 
 static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
 static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
+static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
+static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
+static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
 
 struct RelayForensicsState {
     trace_id: u64,
@@ -98,6 +104,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
     }
 
     let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
+    let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
+    let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false));
+    if saturated_before {
+        ever_saturated.store(true, Ordering::Relaxed);
+    }
 
     if let Some(mut seen_at) = dedup.get_mut(&key) {
         if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
@@ -132,12 +143,52 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
         };
         dedup.remove(&evict_key);
         dedup.insert(key, now);
-        return false;
+        return should_emit_full_desync_full_cache(now);
     }
     }
 
     dedup.insert(key, now);
-    true
+    let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
+    // Preserve the first sequential insert that reaches capacity as a normal
+    // emit, while still gating concurrent newcomer churn after the cache has
+    // ever been observed at saturation.
+    let was_ever_saturated = if saturated_after {
+        ever_saturated.swap(true, Ordering::Relaxed)
+    } else {
+        ever_saturated.load(Ordering::Relaxed)
+    };
+
+    if saturated_before || (saturated_after && was_ever_saturated) {
+        should_emit_full_desync_full_cache(now)
+    } else {
+        true
+    }
+}
+
+fn should_emit_full_desync_full_cache(now: Instant) -> bool {
+    let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
+    let Ok(mut last_emit_at) = gate.lock() else {
+        return false;
+    };
+
+    match *last_emit_at {
+        None => {
+            *last_emit_at = Some(now);
+            true
+        }
+        Some(last) => {
+            let Some(elapsed) = now.checked_duration_since(last) else {
+                *last_emit_at = Some(now);
+                return true;
+            };
+            if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL {
+                *last_emit_at = Some(now);
+                true
+            } else {
+                false
+            }
+        }
+    }
+}
 
 #[cfg(test)]
@@ -145,6 +196,21 @@ fn clear_desync_dedup_for_testing() {
     if let Some(dedup) = DESYNC_DEDUP.get() {
         dedup.clear();
     }
+    if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() {
+        ever_saturated.store(false, Ordering::Relaxed);
+    }
+    if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() {
+        match last_emit_at.lock() {
+            Ok(mut guard) => {
+                *guard = None;
+            }
+            Err(poisoned) => {
+                let mut guard = poisoned.into_inner();
+                *guard = None;
+                last_emit_at.clear_poison();
+            }
+        }
+    }
 }
 
 #[cfg(test)]
@@ -248,6 +314,38 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
     has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
 }
 
+fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
+    quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
+}
+
+fn quota_would_be_exceeded_for_user(
+    stats: &Stats,
+    user: &str,
+    quota_limit: Option<u64>,
+    bytes: u64,
+) -> bool {
+    quota_limit.is_some_and(|quota| {
+        let used = stats.get_user_total_octets(user);
+        used >= quota || bytes > quota.saturating_sub(used)
+    })
+}
+
+fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
+    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    if let Some(existing) = locks.get(user) {
+        return Arc::clone(existing.value());
+    }
+
+    let created = Arc::new(AsyncMutex::new(()));
+    match locks.entry(user.to_string()) {
+        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
+        dashmap::mapref::entry::Entry::Vacant(entry) => {
+            entry.insert(Arc::clone(&created));
+            created
+        }
+    }
+}
+
 async fn enqueue_c2me_command(
     tx: &mpsc::Sender<C2MeCommand>,
     cmd: C2MeCommand,
@@ -260,7 +358,14 @@
             if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
                 tokio::task::yield_now().await;
             }
-            tx.send(cmd).await
+            match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await {
+                Ok(Ok(permit)) => {
+                    permit.send(cmd);
+                    Ok(())
+                }
+                Ok(Err(_)) => Err(mpsc::error::SendError(cmd)),
+                Err(_) => Err(mpsc::error::SendError(cmd)),
+            }
         }
     }
 }
@@ -284,6 +389,7 @@ where
     W: AsyncWrite + Unpin + Send + 'static,
 {
     let user = success.user.clone();
+    let quota_limit = config.access.user_data_quota.get(&user).copied();
    let peer = success.peer;
    let proto_tag = success.proto_tag;
    let pool_generation = me_pool.current_generation();
@@ -432,6 
+538,7 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -464,6 +571,7 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -496,6 +604,7 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -528,6 +637,7 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -609,7 +719,19 @@ where forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); - stats.add_user_octets_from(&user, payload.len() as u64); + if let Some(limit) = quota_limit { + let quota_lock = quota_user_lock(&user); + let _quota_guard = quota_lock.lock().await; + stats.add_user_octets_from(&user, payload.len() as u64); + if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { + main_result = Err(ProxyError::DataQuotaExceeded { + user: user.clone(), + }); + break; + } + } else { + stats.add_user_octets_from(&user, payload.len() as u64); + } let mut flags = proto_flags; if quickack { flags |= RPC_FLAG_QUICKACK; @@ -833,6 +955,7 @@ async fn process_me_writer_response( frame_buf: &mut Vec, stats: &Stats, user: &str, + quota_limit: Option, bytes_me2c: &AtomicU64, conn_id: u64, ack_flush_immediate: bool, @@ -848,17 +971,47 @@ where } else { trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } - bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); - stats.add_user_octets_to(user, data.len() as u64); - write_client_payload( - client_writer, - proto_tag, - flags, - &data, - rng, - frame_buf, - ) - .await?; + let data_len = data.len() as u64; + if let Some(limit) = quota_limit { + let quota_lock = quota_user_lock(user); + let _quota_guard = quota_lock.lock().await; + if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + write_client_payload( + client_writer, + proto_tag, + flags, + &data, + rng, + frame_buf, + ) + .await?; + + bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); + stats.add_user_octets_to(user, data.len() as u64); + + if quota_exceeded_for_user(stats, user, Some(limit)) { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + } else { + write_client_payload( + client_writer, + proto_tag, + flags, + &data, + rng, + frame_buf, + ) + .await?; + + bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); + stats.add_user_octets_to(user, data.len() as u64); + } Ok(MeWriterResponseOutcome::Continue { frames: 1, diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index 441595e..4dd1178 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -13,8 +13,9 @@ use rand::{Rng, SeedableRng}; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; -use std::sync::atomic::AtomicU64; -use tokio::io::AsyncWriteExt; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::thread; +use tokio::io::AsyncReadExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, timeout}; @@ -176,6 +177,36 @@ async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() { ); } +#[tokio::test] +async fn 
enqueue_c2me_command_full_queue_times_out_without_receiver_progress() { + let (tx, _rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[1]), + flags: 0, + }) + .await + .unwrap(); + + let started = Instant::now(); + let result = enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload: make_pooled_payload(&[2, 2]), + flags: 1, + }, + ) + .await; + + assert!( + result.is_err(), + "enqueue must fail when queue stays full beyond bounded timeout" + ); + assert!( + started.elapsed() < TokioDuration::from_millis(400), + "full-queue timeout must resolve promptly" + ); +} + #[test] fn desync_dedup_cache_is_bounded() { let _guard = desync_dedup_test_lock() @@ -192,12 +223,12 @@ fn desync_dedup_cache_is_bounded() { } assert!( - !should_emit_full_desync(u64::MAX, false, now), - "new key above cap must remain suppressed to avoid log amplification" + should_emit_full_desync(u64::MAX, false, now), + "new key above cap must emit once after bounded eviction for forensic visibility" ); assert!( - !should_emit_full_desync(7, false, now), + !should_emit_full_desync(u64::MAX, false, now), "already tracked key inside dedup window must stay suppressed" ); } @@ -215,10 +246,18 @@ fn desync_dedup_full_cache_churn_stays_suppressed() { } for offset in 0..2048u64 { - assert!( - !should_emit_full_desync(u64::MAX - offset, false, now), - "fresh full-cache churn must remain suppressed under pressure" - ); + let emitted = should_emit_full_desync(u64::MAX - offset, false, now); + if offset == 0 { + assert!( + emitted, + "first full-cache newcomer should emit for forensic visibility" + ); + } else { + assert!( + !emitted, + "full-cache newcomer churn inside emit interval must stay suppressed" + ); + } } } @@ -296,18 +335,20 @@ fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { let now = Instant::now(); let total = DESYNC_DEDUP_MAX_ENTRIES + 8192; + let mut emitted_count = 0usize; for key in 0..total as u64 { let emitted = should_emit_full_desync(key, false, now); - if key < DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!(emitted, "keys below cap must be admitted initially"); - } else { - assert!( - !emitted, - "new keys above cap must stay suppressed under sustained churn" - ); + if emitted { + emitted_count += 1; } } + assert_eq!( + emitted_count, + DESYNC_DEDUP_MAX_ENTRIES + 1, + "after capacity is reached, same-tick newcomer churn must be rate-limited" + ); + let len = DESYNC_DEDUP .get() .expect("dedup cache must be initialized by stress run") @@ -318,6 +359,282 @@ fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { ); } +#[test] +fn full_cache_newcomer_emission_is_rate_limited_but_periodic() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + // Same-tick newcomer storm: only the first should emit full forensic record. + let mut burst_emits = 0usize; + for i in 0..1024u64 { + if should_emit_full_desync(10_000_000 + i, false, base_now) { + burst_emits += 1; + } + } + assert_eq!( + burst_emits, 1, + "full-cache newcomer burst must be bounded to a single full emit per interval" + ); + + // After each interval elapses, one newcomer may emit again. 
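Before the interval-stepping loop below, here is a reference model of the full-cache emission gate these assertions encode, as a standalone sketch over std::time (mirroring `should_emit_full_desync_full_cache` above):

use std::time::{Duration, Instant};

struct EmitGate {
    last_emit: Option<Instant>,
    min_interval: Duration,
}

impl EmitGate {
    fn should_emit(&mut self, now: Instant) -> bool {
        match self.last_emit {
            // First event after reset always emits.
            None => {
                self.last_emit = Some(now);
                true
            }
            Some(last) => match now.checked_duration_since(last) {
                // Interval elapsed: emit and restart the window.
                Some(elapsed) if elapsed >= self.min_interval => {
                    self.last_emit = Some(now);
                    true
                }
                // Inside the window: suppress.
                Some(_) => false,
                // Non-monotonic input (now < last): fail open and reset.
                None => {
                    self.last_emit = Some(now);
                    true
                }
            },
        }
    }
}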
+ for step in 1..=6u64 { + let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32; + assert!( + should_emit_full_desync(20_000_000 + step, false, t), + "full-cache newcomer should re-emit once interval has elapsed" + ); + assert!( + !should_emit_full_desync(30_000_000 + step, false, t), + "additional newcomers in the same interval tick must remain suppressed" + ); + } +} + +#[test] +fn full_cache_mode_override_emits_every_event() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for i in 0..10_000u64 { + assert!( + should_emit_full_desync(100_000_000 + i, true, now), + "desync_all_full override must bypass dedup and rate-limit suppression" + ); + } +} + +#[test] +fn report_desync_stats_follow_rate_limited_full_cache_policy() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let stats = Stats::new(); + let mut state = make_forensics_state(); + state.started_at = base_now; + + for i in 0..128u64 { + state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i; + let _ = report_desync_frame_too_large( + &state, + ProtoTag::Secure, + 3, + 1024, + 4096, + Some([0x16, 0x03, 0x03, 0x00]), + &stats, + ); + } + + assert_eq!( + stats.get_desync_total(), + 128, + "every detected desync must increment total counter" + ); + assert_eq!( + stats.get_desync_full_logged(), + 1, + "same-interval full-cache newcomer storm must allow only one full forensic emit" + ); + assert_eq!( + stats.get_desync_suppressed(), + 127, + "remaining same-interval full-cache newcomer events must be suppressed" + ); + + // After one full interval in real wall clock, a newcomer should emit again. 
+ thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20)); + state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64; + let _ = report_desync_frame_too_large( + &state, + ProtoTag::Secure, + 4, + 1024, + 4097, + Some([0x16, 0x03, 0x03, 0x01]), + &stats, + ); + + assert_eq!( + stats.get_desync_full_logged(), + 2, + "full forensic emission must recover after rate-limit interval" + ); +} + +#[test] +fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let emits = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + for worker_id in 0..32u64 { + let emits = Arc::clone(&emits); + workers.push(thread::spawn(move || { + for i in 0..512u64 { + let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i; + if should_emit_full_desync(key, false, base_now) { + emits.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must not panic"); + } + + assert_eq!( + emits.load(Ordering::Relaxed), + 1, + "concurrent same-interval full-cache storm must allow only one full forensic emit" + ); +} + +#[test] +fn light_fuzz_full_cache_rate_limit_oracle_matches_model() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD); + let mut model_last_emit: Option = None; + + for i in 0..4096u64 { + let jitter_ms: u64 = rng.random_range(0..=3000); + let t = base_now + TokioDuration::from_millis(jitter_ms); + let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::(); + let actual = should_emit_full_desync(key, false, t); + + let expected = match model_last_emit { + None => { + model_last_emit = Some(t); + true + } + Some(last) => { + match t.checked_duration_since(last) { + Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => { + model_last_emit = Some(t); + true + } + Some(_) => false, + None => { + // Match production fail-open behavior for non-monotonic synthetic input. + model_last_emit = Some(t); + true + } + } + } + }; + + assert_eq!( + actual, expected, + "full-cache rate-limit gate diverged from reference model under light fuzz" + ); + } +} + +#[test] +fn full_cache_gate_lock_poison_is_fail_closed_without_panic() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + // Poison the full-cache gate lock intentionally. 
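The poisoning below should make the gate fail closed; the essential shape, as a standalone sketch (the production gate above uses the same let-else):

use std::sync::Mutex;
use std::time::Instant;

fn gate_permits(gate: &Mutex<Option<Instant>>) -> bool {
    // A panic while the lock was held leaves it poisoned; lock() then
    // returns Err, and we suppress instead of panicking or failing open.
    let Ok(_last_emit) = gate.lock() else {
        return false;
    };
    true
}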
+ let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); + let _ = std::panic::catch_unwind(|| { + let _lock = gate.lock().expect("gate lock must be lockable before poison"); + panic!("intentional gate poison for fail-closed regression"); + }); + + let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now); + assert!( + !emitted, + "poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open" + ); + assert!( + dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, + "dedup cache must remain bounded even when gate lock is poisoned" + ); +} + +#[test] +fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + // First event seeds the gate. + assert!(should_emit_full_desync( + 0xABCD_0000_0000_0001, + false, + base_now + TokioDuration::from_millis(900) + )); + + // Synthetic earlier timestamp must not panic; it should fail-open and reset gate. + assert!(should_emit_full_desync( + 0xABCD_0000_0000_0002, + false, + base_now + TokioDuration::from_millis(100) + )); + + // Same instant again remains suppressed after reset. + assert!(!should_emit_full_desync( + 0xABCD_0000_0000_0003, + false, + base_now + TokioDuration::from_millis(100) + )); +} + #[test] fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { let _guard = desync_dedup_test_lock() @@ -338,8 +655,8 @@ fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { let newcomer_key = u64::MAX; let emitted = should_emit_full_desync(newcomer_key, false, base_now); assert!( - !emitted, - "new entry under full fresh cache must stay suppressed" + emitted, + "new entry under full fresh cache must emit after bounded eviction" ); assert!( dedup.get(&newcomer_key).is_some(), @@ -406,6 +723,24 @@ fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { panic!("expected at least one post-window sample to re-emit forensic record"); } +#[test] +#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"] +fn should_emit_full_desync_filters_duplicates() { + unimplemented!("Stub for M-04"); +} + +#[test] +#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"] +fn desync_dedup_eviction_under_map_full_condition() { + unimplemented!("Stub for M-04"); +} + +#[tokio::test] +#[ignore = "Tracking for M-05: Verify C2ME channel full path yields then sends under backpressure"] +async fn c2me_channel_full_path_yields_then_sends() { + unimplemented!("Stub for M-05"); +} + fn make_forensics_state() -> RelayForensicsState { RelayForensicsState { trace_id: 1, @@ -974,6 +1309,7 @@ async fn process_me_writer_response_ack_obeys_flush_policy() { &mut frame_buf, &stats, "user", + None, &bytes_me2c, 77, true, @@ -999,6 +1335,7 @@ async fn process_me_writer_response_ack_obeys_flush_policy() { &mut frame_buf, &stats, "user", + None, &bytes_me2c, 77, false, @@ -1038,6 +1375,7 @@ async fn process_me_writer_response_data_updates_byte_accounting() { &mut frame_buf, &stats, "user", + None, &bytes_me2c, 88, false, @@ -1061,6 +1399,162 @@ async fn process_me_writer_response_data_updates_byte_accounting() { ); } 
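The quota tests that follow pin down one admission rule; restated as a minimal standalone predicate (mirroring `quota_would_be_exceeded_for_user` above; names illustrative):

fn frame_fits_in_quota(used: u64, quota: u64, frame_len: u64) -> bool {
    // Reject when the user is already at/over quota, or when forwarding the
    // whole frame would cross it — frames are never partially forwarded.
    used < quota && frame_len <= quota.saturating_sub(used)
}

For example, used = 10, quota = 12, frame = 4 gives false (rejected before any ciphertext is written), and used = 3, quota = 4, frame = 4 gives false rather than splitting the frame.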
+#[tokio::test] +async fn process_me_writer_response_data_enforces_live_user_quota() { + let (writer_side, mut reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + stats.add_user_octets_from("quota-user", 10); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![1u8, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "quota-user", + Some(12), + &bytes_me2c, + 89, + false, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"), + "ME->client runtime path must terminate when live user quota is crossed" + ); + + let mut raw = [0u8; 1]; + assert!( + timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) + .await + .is_err(), + "quota exhaustion must not write any ciphertext to the client stream" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let user = "quota-race-user"; + + let (writer_side_a, _reader_side_a) = duplex(1024); + let (writer_side_b, _reader_side_b) = duplex(1024); + let mut writer_a = make_crypto_writer(writer_side_a); + let mut writer_b = make_crypto_writer(writer_side_b); + let mut frame_buf_a = Vec::new(); + let mut frame_buf_b = Vec::new(); + let rng_a = SecureRandom::new(); + let rng_b = SecureRandom::new(); + + let fut_a = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x11]), + }, + &mut writer_a, + ProtoTag::Intermediate, + &rng_a, + &mut frame_buf_a, + &stats, + user, + Some(1), + &bytes_me2c, + 91, + false, + false, + ); + let fut_b = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x22]), + }, + &mut writer_b, + ProtoTag::Intermediate, + &rng_b, + &mut frame_buf_b, + &stats, + user, + Some(1), + &bytes_me2c, + 92, + false, + false, + ); + + let (result_a, result_b) = tokio::join!(fut_a, fut_b); + + assert!( + matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") + || matches!(result_a, Ok(_)), + "concurrent quota test must complete without panicking" + ); + assert!( + matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") + || matches!(result_b, Ok(_)), + "concurrent quota test must complete without panicking" + ); + assert!( + stats.get_user_total_octets(user) <= 1, + "same-user concurrent middle-relay responses must not overshoot the configured quota" + ); +} + +#[tokio::test] +async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() { + let (writer_side, mut reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + stats.add_user_octets_to("partial-quota-user", 3); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![1u8, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "partial-quota-user", + Some(4), + &bytes_me2c, + 90, + false, + false, + ) + .await; + + assert!( 
+        matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"),
+        "ME->client runtime path must reject oversized payloads before writing"
+    );
+
+    let mut raw = [0u8; 1];
+    assert!(
+        timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw))
+            .await
+            .is_err(),
+        "oversized payloads must not leak any partial ciphertext to the client stream"
+    );
+}
+
 #[tokio::test]
 async fn middle_relay_abort_midflight_releases_route_gauge() {
     let stats = Arc::new(Stats::new());
diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs
index 06ce0d8..46a2b21 100644
--- a/src/proxy/relay.rs
+++ b/src/proxy/relay.rs
@@ -53,16 +53,17 @@
 use std::io;
 use std::pin::Pin;
-use std::sync::Arc;
-use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::{Arc, Mutex, OnceLock};
+use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
 use std::task::{Context, Poll};
 use std::time::Duration;
 
+use dashmap::DashMap;
 use tokio::io::{
     AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
 };
 use tokio::time::Instant;
 use tracing::{debug, trace, warn};
 
-use crate::error::Result;
+use crate::error::{ProxyError, Result};
 use crate::stats::Stats;
 use crate::stream::BufferPool;
 
@@ -205,6 +206,8 @@ struct StatsIo {
     counters: Arc<SharedCounters>,
     stats: Arc<Stats>,
     user: String,
+    quota_limit: Option<u64>,
+    quota_exceeded: Arc<AtomicBool>,
     epoch: Instant,
 }
 
@@ -214,11 +217,62 @@ impl StatsIo {
         counters: Arc<SharedCounters>,
         stats: Arc<Stats>,
         user: String,
+        quota_limit: Option<u64>,
+        quota_exceeded: Arc<AtomicBool>,
         epoch: Instant,
     ) -> Self {
         // Mark initial activity so the watchdog doesn't fire before data flows
         counters.touch(Instant::now(), epoch);
-        Self { inner, counters, stats, user, epoch }
+        Self {
+            inner,
+            counters,
+            stats,
+            user,
+            quota_limit,
+            quota_exceeded,
+            epoch,
+        }
+    }
+}
+
+#[derive(Debug)]
+struct QuotaIoSentinel;
+
+impl std::fmt::Display for QuotaIoSentinel {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("user data quota exceeded")
+    }
+}
+
+impl std::error::Error for QuotaIoSentinel {}
+
+fn quota_io_error() -> io::Error {
+    io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
+}
+
+fn is_quota_io_error(err: &io::Error) -> bool {
+    err.kind() == io::ErrorKind::PermissionDenied
+        && err
+            .get_ref()
+            .and_then(|source| source.downcast_ref::<QuotaIoSentinel>())
+            .is_some()
+}
+
+static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
+
+fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
+    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    if let Some(existing) = locks.get(user) {
+        return Arc::clone(existing.value());
+    }
+
+    let created = Arc::new(Mutex::new(()));
+    match locks.entry(user.to_string()) {
+        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
+        dashmap::mapref::entry::Entry::Vacant(entry) => {
+            entry.insert(Arc::clone(&created));
+            created
+        }
     }
 }
 
@@ -229,6 +283,32 @@ impl AsyncRead for StatsIo {
         buf: &mut ReadBuf<'_>,
     ) -> Poll<io::Result<()>> {
         let this = self.get_mut();
+        if this.quota_exceeded.load(Ordering::Relaxed) {
+            return Poll::Ready(Err(quota_io_error()));
+        }
+
+        let quota_lock = this
+            .quota_limit
+            .is_some()
+            .then(|| quota_user_lock(&this.user));
+        let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
+            match lock.try_lock() {
+                Ok(guard) => Some(guard),
+                Err(_) => {
+                    cx.waker().wake_by_ref();
+                    return Poll::Pending;
+                }
+            }
+        } else {
+            None
+        };
+
+        if let Some(limit) = this.quota_limit
+            && this.stats.get_user_total_octets(&this.user) >= limit
+        {
+            this.quota_exceeded.store(true, Ordering::Relaxed);
+            return Poll::Ready(Err(quota_io_error()));
+        }
 
         let before = buf.filled().len();
         match Pin::new(&mut this.inner).poll_read(cx, buf) {
@@ -243,6 +323,13 @@ impl AsyncRead for StatsIo {
                 this.stats.add_user_octets_from(&this.user, n as u64);
                 this.stats.increment_user_msgs_from(&this.user);
 
+                if let Some(limit) = this.quota_limit
+                    && this.stats.get_user_total_octets(&this.user) >= limit
+                {
+                    this.quota_exceeded.store(true, Ordering::Relaxed);
+                    return Poll::Ready(Err(quota_io_error()));
+                }
+
                 trace!(user = %this.user, bytes = n, "C->S");
             }
             Poll::Ready(Ok(()))
@@ -259,8 +346,46 @@ impl AsyncWrite for StatsIo {
         buf: &[u8],
     ) -> Poll<io::Result<usize>> {
         let this = self.get_mut();
+        if this.quota_exceeded.load(Ordering::Relaxed) {
+            return Poll::Ready(Err(quota_io_error()));
+        }
 
-        match Pin::new(&mut this.inner).poll_write(cx, buf) {
+        let quota_lock = this
+            .quota_limit
+            .is_some()
+            .then(|| quota_user_lock(&this.user));
+        let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
+            match lock.try_lock() {
+                Ok(guard) => Some(guard),
+                Err(_) => {
+                    cx.waker().wake_by_ref();
+                    return Poll::Pending;
+                }
+            }
+        } else {
+            None
+        };
+
+        let write_buf = if let Some(limit) = this.quota_limit {
+            let used = this.stats.get_user_total_octets(&this.user);
+            if used >= limit {
+                this.quota_exceeded.store(true, Ordering::Relaxed);
+                return Poll::Ready(Err(quota_io_error()));
+            }
+
+            let remaining = (limit - used) as usize;
+            if buf.len() > remaining {
+                // Fail closed: do not emit partial S->C payload when remaining
+                // quota cannot accommodate the pending write request.
+                this.quota_exceeded.store(true, Ordering::Relaxed);
+                return Poll::Ready(Err(quota_io_error()));
+            }
+            buf
+        } else {
+            buf
+        };
+
+        match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
             Poll::Ready(Ok(n)) => {
                 if n > 0 {
                     // S→C: data written to client
@@ -271,6 +396,13 @@ impl AsyncWrite for StatsIo {
                 this.stats.add_user_octets_to(&this.user, n as u64);
                 this.stats.increment_user_msgs_to(&this.user);
 
+                if let Some(limit) = this.quota_limit
+                    && this.stats.get_user_total_octets(&this.user) >= limit
+                {
+                    this.quota_exceeded.store(true, Ordering::Relaxed);
+                    return Poll::Ready(Err(quota_io_error()));
+                }
+
                 trace!(user = %this.user, bytes = n, "S->C");
             }
             Poll::Ready(Ok(n))
@@ -307,7 +439,8 @@ impl AsyncWrite for StatsIo {
 /// - Per-user stats: bytes and ops counted per direction
 /// - Periodic rate logging: every 10 seconds when active
 /// - Clean shutdown: both write sides are shut down on exit
-/// - Error propagation: I/O errors are returned as `ProxyError::Io`
+/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
+///   other I/O failures are returned as `ProxyError::Io`
 pub async fn relay_bidirectional<CR, CW, SR, SW>(
     client_reader: CR,
     client_writer: CW,
@@ -317,6 +450,7 @@ pub async fn relay_bidirectional(
     s2c_buf_size: usize,
     user: &str,
     stats: Arc<Stats>,
+    quota_limit: Option<u64>,
     _buffer_pool: Arc<BufferPool>,
 ) -> Result<()>
 where
@@ -327,6 +461,7 @@ where
 {
     let epoch = Instant::now();
     let counters = Arc::new(SharedCounters::new());
+    let quota_exceeded = Arc::new(AtomicBool::new(false));
     let user_owned = user.to_string();
 
     // ── Combine split halves into bidirectional streams ──────────────
@@ -339,12 +474,15 @@ where
         Arc::clone(&counters),
         Arc::clone(&stats),
         user_owned.clone(),
+        quota_limit,
+        Arc::clone(&quota_exceeded),
         epoch,
     );
 
     // ── Watchdog: activity timeout + periodic rate logging ──────────
     let wd_counters = Arc::clone(&counters);
     let wd_user = user_owned.clone();
+    let wd_quota_exceeded = Arc::clone(&quota_exceeded);
     let watchdog = async {
         let mut
prev_c2s: u64 = 0; @@ -356,6 +494,11 @@ where let now = Instant::now(); let idle = wd_counters.idle_duration(now, epoch); + if wd_quota_exceeded.load(Ordering::Relaxed) { + warn!(user = %wd_user, "User data quota reached, closing relay"); + return; + } + // ── Activity timeout ──────────────────────────────────── if idle >= ACTIVITY_TIMEOUT { let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed); @@ -439,6 +582,22 @@ where ); Ok(()) } + Some(Err(e)) if is_quota_io_error(&e) => { + let c2s = counters.c2s_bytes.load(Ordering::Relaxed); + let s2c = counters.s2c_bytes.load(Ordering::Relaxed); + warn!( + user = %user_owned, + c2s_bytes = c2s, + s2c_bytes = s2c, + c2s_msgs = c2s_ops, + s2c_msgs = s2c_ops, + duration_secs = duration.as_secs(), + "Data quota reached, closing relay" + ); + Err(ProxyError::DataQuotaExceeded { + user: user_owned.clone(), + }) + } Some(Err(e)) => { // I/O error in one of the directions let c2s = counters.c2s_bytes.load(Ordering::Relaxed); @@ -472,3 +631,7 @@ where } } } + +#[cfg(test)] +#[path = "relay_security_tests.rs"] +mod security_tests; diff --git a/src/proxy/relay_security_tests.rs b/src/proxy/relay_security_tests.rs new file mode 100644 index 0000000..7b985cb --- /dev/null +++ b/src/proxy/relay_security_tests.rs @@ -0,0 +1,972 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::future::poll_fn; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Mutex; +use std::task::{Context, Poll}; +use std::task::Waker; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +#[tokio::test] +async fn relay_bidirectional_enforces_live_user_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-user"; + stats.add_user_octets_from(user, 6); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(8), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0x10, 0x20, 0x30, 0x40]) + .await + .expect("client write must succeed"); + + let mut forwarded = [0u8; 4]; + let _ = timeout( + Duration::from_millis(200), + server_peer.read_exact(&mut forwarded), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"), + "relay must surface a typed quota error once live quota is exceeded" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() { + let stats = Arc::new(Stats::new()); + let quota_user = "quota-exhausted-user"; + stats.add_user_octets_from(quota_user, 1); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = 
tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0xde, 0xad, 0xbe, 0xef]) + .await + .expect("server write must succeed"); + + let mut observed = [0u8; 4]; + let forwarded = timeout( + Duration::from_millis(200), + client_peer.read_exact(&mut observed), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), + "no full server payload should be forwarded once quota is already exhausted" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() { + let stats = Arc::new(Stats::new()); + let quota_user = "partial-leak-user"; + stats.add_user_octets_from(quota_user, 3); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(4), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .expect("server write must succeed"); + + let mut observed = [0u8; 8]; + let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n > 0), + "quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() { + let stats = Arc::new(Stats::new()); + let quota_user = "zero-quota-user"; + + for payload_len in [1usize, 16, 512, 4096] { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(0), + Arc::new(BufferPool::new()), + )); + + let payload = vec![0x7f; payload_len]; + let _ = server_peer.write_all(&payload).await; + + let mut observed = vec![0u8; payload_len]; + let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under zero-quota cutoff") + .expect("relay task must not panic"); + 
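+        // A timed-out read, a read error, and a zero-length read all count as
+        // "nothing was forwarded" below; only a successful non-empty read fails.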
+ assert!( + !matches!(forwarded, Ok(Ok(n)) if n > 0), + "zero quota must not forward any server bytes for payload_len={payload_len}" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "zero quota must terminate with the typed quota error for payload_len={payload_len}" + ); + } +} + +#[tokio::test] +async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() { + let stats = Arc::new(Stats::new()); + let quota_user = "exact-boundary-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(4), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0x91, 0x92, 0x93, 0x94]) + .await + .expect("server write must succeed at exact quota boundary"); + + let mut observed = [0u8; 4]; + client_peer + .read_exact(&mut observed) + .await + .expect("client must receive the full payload at the exact quota boundary"); + assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]); + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish after exact boundary delivery") + .expect("relay task must not panic"); + + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must close with a typed quota error after reaching the exact boundary" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() { + let stats = Arc::new(Stats::new()); + let quota_user = "client-exhausted-user"; + stats.add_user_octets_from(quota_user, 1); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0x51, 0x52, 0x53, 0x54]) + .await + .expect("client write must succeed even when quota is already exhausted"); + + let mut observed = [0u8; 4]; + let forwarded = timeout( + Duration::from_millis(200), + server_peer.read_exact(&mut observed), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), + "client payload must not be fully forwarded once quota is already exhausted" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() { + let stats = Arc::new(Stats::new()); + let quota_user = "quota-fuzz-user"; + stats.add_user_octets_from(quota_user, 2); + + for payload_len in [1usize, 32, 1024, 8192] { + let (mut client_peer, relay_client) = 
duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(2), + Arc::new(BufferPool::new()), + )); + + let payload = vec![0xaa; payload_len]; + let _ = server_peer.write_all(&payload).await; + + let mut observed = vec![0u8; payload_len]; + let forwarded = timeout( + Duration::from_millis(200), + client_peer.read_exact(&mut observed), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n == payload_len), + "quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must keep returning the typed quota error for payload_len={payload_len}" + ); + } +} + +#[tokio::test] +async fn relay_bidirectional_terminates_on_activity_timeout() { + tokio::time::pause(); + let stats = Arc::new(Stats::new()); + let user = "timeout-user"; + + let (client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + None, // No quota + Arc::new(BufferPool::new()), + )); + + // Wait past the activity timeout threshold (1800 seconds) + buffer + tokio::time::sleep(Duration::from_secs(1805)).await; + + // Resume time to process timeouts + tokio::time::resume(); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must finish inside bounded timeout due to inactivity cutoff") + .expect("relay task must not panic"); + + assert!( + relay_result.is_ok(), + "relay should complete successfully on scheduled inactivity timeout" + ); + + // Verify client/server sockets are closed + drop(client_peer); + drop(server_peer); +} + +#[tokio::test] +async fn relay_bidirectional_watchdog_resists_premature_execution() { + tokio::time::pause(); + let stats = Arc::new(Stats::new()); + let user = "activity-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let mut relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + // Advance by half the timeout + tokio::time::sleep(Duration::from_secs(900)).await; + + // Provide activity + client_peer + .write_all(&[0xaa, 0xbb]) + .await + .expect("client write must succeed"); + client_peer.flush().await.unwrap(); + + // Advance by another half (total time since start is 1800, but since last activity is 900) + tokio::time::sleep(Duration::from_secs(900)).await; + + tokio::time::resume(); + + // 
The task must NOT have timed out; it should still be pending.
+    let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await;
+    assert!(
+        relay_result.is_err(),
+        "Relay must not exit prematurely as long as activity was received before timeout"
+    );
+
+    // Explicitly drop sockets to cleanly shut down the relay loop
+    drop(client_peer);
+    drop(server_peer);
+
+    let completion = timeout(Duration::from_secs(1), relay_task).await
+        .expect("relay task must complete cleanly after client disconnection")
+        .expect("relay task must not panic");
+    assert!(completion.is_ok(), "relay exits cleanly");
+}
+
+#[tokio::test]
+async fn relay_bidirectional_half_closure_terminates_cleanly() {
+    let stats = Arc::new(Stats::new());
+    let (client_peer, relay_client) = duplex(4096);
+    let (relay_server, server_peer) = duplex(4096);
+    let (client_reader, client_writer) = tokio::io::split(relay_client);
+    let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+    let relay_task = tokio::spawn(relay_bidirectional(
+        client_reader, client_writer, server_reader, server_writer, 1024, 1024, "half-close", stats, None, Arc::new(BufferPool::new()),
+    ));
+
+    // Half closure: drop the client completely but leave the server active.
+    drop(client_peer);
+
+    // The relay must not crash immediately; it stays open for the server -> client flush.
+    // Eventually dropping the server cleanly closes the task.
+    drop(server_peer);
+    timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn relay_bidirectional_zero_length_noise_fuzzing() {
+    let stats = Arc::new(Stats::new());
+    let (mut client_peer, relay_client) = duplex(4096);
+    let (relay_server, mut server_peer) = duplex(4096);
+    let (client_reader, client_writer) = tokio::io::split(relay_client);
+    let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+    let relay_task = tokio::spawn(relay_bidirectional(
+        client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz", stats, None, Arc::new(BufferPool::new()),
+    ));
+
+    // Flood with zero-length payloads (an edge case where stream framing logic can loop)
+    for _ in 0..100 {
+        client_peer.write_all(&[]).await.unwrap();
+    }
+    client_peer.write_all(&[1, 2, 3]).await.unwrap();
+    client_peer.flush().await.unwrap();
+
+    let mut buf = [0u8; 3];
+    server_peer.read_exact(&mut buf).await.unwrap();
+    assert_eq!(&buf, &[1, 2, 3]);
+
+    drop(client_peer);
+    drop(server_peer);
+    timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn relay_bidirectional_asymmetric_backpressure() {
+    let stats = Arc::new(Stats::new());
+    // Deliberately give the client stream a very small in-flight buffer
+    let (client_peer, relay_client) = duplex(1024);
+    let (relay_server, mut server_peer) = duplex(4096);
+    let (client_reader, client_writer) = tokio::io::split(relay_client);
+    let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+    let relay_task = tokio::spawn(relay_bidirectional(
+        client_reader, client_writer, server_reader, server_writer, 1024, 1024, "slowloris", stats, None, Arc::new(BufferPool::new()),
+    ));
+
+    let payload = vec![0xba; 65536]; // 64k payload
+
+    // The server attempts to push 64 KB into a relay whose client pipe only buffers 1 KB.
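+    // tokio::io::duplex(1024) caps the in-flight buffer at 1 KB, so once the
+    // client side stops draining, the relay's copy loop must stall and the
+    // timed write below is expected to hit its deadline.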
+    let write_res = tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await;
+
+    assert!(
+        write_res.is_err(),
+        "Relay backpressure must stall the server writer instead of buffering without bound when the client stream is full"
+    );
+
+    drop(client_peer);
+    drop(server_peer);
+
+    let completion = timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap();
+    // Either outcome (Ok on clean shutdown, or an Err such as BrokenPipe) is
+    // acceptable; the joins above already prove the task unwound without
+    // panicking despite the standing backpressure.
+    let _ = completion;
+}
+
+use rand::{Rng, SeedableRng, rngs::StdRng};
+
+#[tokio::test]
+async fn relay_bidirectional_light_fuzzing_temporal_jitter() {
+    tokio::time::pause();
+    let stats = Arc::new(Stats::new());
+    let (mut client_peer, relay_client) = duplex(4096);
+    let (relay_server, server_peer) = duplex(4096);
+    let (client_reader, client_writer) = tokio::io::split(relay_client);
+    let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+    let mut relay_task = tokio::spawn(relay_bidirectional(
+        client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz-user", stats, None, Arc::new(BufferPool::new()),
+    ));
+
+    let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
+
+    for _ in 0..10 {
+        // Vary timing significantly, up to 1600 seconds (the limit is 1800s)
+        let jitter = rng.random_range(100..1600);
+        tokio::time::sleep(Duration::from_secs(jitter)).await;
+
+        client_peer.write_all(&[0x11]).await.unwrap();
+        client_peer.flush().await.unwrap();
+
+        // Ensure the task has not died
+        let res = timeout(Duration::from_millis(10), &mut relay_task).await;
+        assert!(res.is_err(), "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses");
+    }
+
+    drop(client_peer);
+    drop(server_peer);
+    timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
+}
+
+struct FaultyReader {
+    error_once: Option<io::Error>,
+}
+
+struct TwoPartyGate {
+    arrivals: AtomicUsize,
+    total_bytes: AtomicUsize,
+    wakers: Mutex<Vec<Waker>>,
+}
+
+impl TwoPartyGate {
+    fn new() -> Self {
+        Self {
+            arrivals: AtomicUsize::new(0),
+            total_bytes: AtomicUsize::new(0),
+            wakers: Mutex::new(Vec::new()),
+        }
+    }
+
+    fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool {
+        if self.arrivals.load(Ordering::Relaxed) >= 2 {
+            return true;
+        }
+
+        let prev = self.arrivals.fetch_add(1, Ordering::AcqRel);
+        if prev + 1 >= 2 {
+            let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner());
+            for waker in wakers.drain(..) {
+                waker.wake();
+            }
+            true
+        } else {
+            let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner());
+            wakers.push(cx.waker().clone());
+            false
+        }
+    }
+
+    fn total_bytes(&self) -> usize {
+        self.total_bytes.load(Ordering::Relaxed)
+    }
+}
+
+struct GateWriter {
+    gate: Arc<TwoPartyGate>,
+    entered: bool,
+}
+
+impl GateWriter {
+    fn new(gate: Arc<TwoPartyGate>) -> Self {
+        Self {
+            gate,
+            entered: false,
+        }
+    }
+}
+
+impl AsyncWrite for GateWriter {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        if !self.entered {
+            self.entered = true;
+        }
+
+        if !self.gate.arrive_or_park(cx) {
+            return Poll::Pending;
+        }
+
+        self.gate
+            .total_bytes
+            .fetch_add(buf.len(), Ordering::Relaxed);
+        Poll::Ready(Ok(buf.len()))
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+struct GateReader {
+    gate: Arc<TwoPartyGate>,
+    entered: bool,
+    emitted: bool,
+}
+
+impl GateReader {
+    fn new(gate: Arc<TwoPartyGate>) -> Self {
+        Self {
+            gate,
+            entered: false,
+            emitted: false,
+        }
+    }
+}
+
+impl AsyncRead for GateReader {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        if self.emitted {
+            return Poll::Ready(Ok(()));
+        }
+
+        if !self.entered {
+            self.entered = true;
+        }
+
+        if !self.gate.arrive_or_park(cx) {
+            return Poll::Pending;
+        }
+
+        buf.put_slice(&[0x42]);
+        self.gate.total_bytes.fetch_add(1, Ordering::Relaxed);
+        self.emitted = true;
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
+    let stats = Arc::new(Stats::new());
+    let gate = Arc::new(TwoPartyGate::new());
+    let user = "concurrent-quota-write".to_string();
+
+    let writer_a = super::StatsIo::new(
+        GateWriter::new(Arc::clone(&gate)),
+        Arc::new(super::SharedCounters::new()),
+        Arc::clone(&stats),
+        user.clone(),
+        Some(1),
+        Arc::new(std::sync::atomic::AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let writer_b = super::StatsIo::new(
+        GateWriter::new(Arc::clone(&gate)),
+        Arc::new(super::SharedCounters::new()),
+        Arc::clone(&stats),
+        user.clone(),
+        Some(1),
+        Arc::new(std::sync::atomic::AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let task_a = tokio::spawn(async move {
+        let mut w = writer_a;
+        AsyncWriteExt::write_all(&mut w, &[0x01]).await
+    });
+    let task_b = tokio::spawn(async move {
+        let mut w = writer_b;
+        AsyncWriteExt::write_all(&mut w, &[0x02]).await
+    });
+
+    let (res_a, res_b) = tokio::join!(task_a, task_b);
+    let _ = res_a.expect("task a must join");
+    let _ = res_b.expect("task b must join");
+
+    assert!(
+        gate.total_bytes() <= 1,
+        "concurrent same-user writes must not forward more than one byte under quota=1"
+    );
+    assert!(
+        stats.get_user_total_octets(&user) <= 1,
+        "concurrent same-user writes must not account over limit"
+    );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
+    let stats = Arc::new(Stats::new());
+    let gate = Arc::new(TwoPartyGate::new());
+    let user = "concurrent-quota-read".to_string();
+
+    let reader_a = super::StatsIo::new(
+        GateReader::new(Arc::clone(&gate)),
+        Arc::new(super::SharedCounters::new()),
+        Arc::clone(&stats),
+        user.clone(),
+        Some(1),
+        Arc::new(std::sync::atomic::AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let reader_b = super::StatsIo::new(
+        GateReader::new(Arc::clone(&gate)),
+        Arc::new(super::SharedCounters::new()),
+        Arc::clone(&stats),
+        user.clone(),
+        Some(1),
+        Arc::new(std::sync::atomic::AtomicBool::new(false)),
+        tokio::time::Instant::now(),
+    );
+
+    let task_a = tokio::spawn(async move {
+        let mut r = reader_a;
+        let mut one = [0u8; 1];
+        AsyncReadExt::read_exact(&mut r, &mut one).await
+    });
+    let task_b = tokio::spawn(async move {
+        let mut r = reader_b;
+        let mut one = [0u8; 1];
+        AsyncReadExt::read_exact(&mut r, &mut one).await
+    });
+
+    let (res_a, res_b) = tokio::join!(task_a, task_b);
+    let _ = res_a.expect("task a must join");
+    let _ = res_b.expect("task b must join");
+
+    assert!(
+        gate.total_bytes() <= 1,
+        "concurrent same-user reads must not consume more than one byte under quota=1"
+    );
+    assert!(
+        stats.get_user_total_octets(&user) <= 1,
+        "concurrent same-user reads must not account over limit"
+    );
+}
+
+#[tokio::test]
+async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
+    let stats = Arc::new(Stats::new());
+    let user = "parallel-quota-user";
+
+    for _ in 0..128 {
+        let (mut client_peer_a, relay_client_a) = duplex(256);
+        let (relay_server_a, mut server_peer_a) = duplex(256);
+        let (mut client_peer_b, relay_client_b) = duplex(256);
+        let (relay_server_b, mut server_peer_b) = duplex(256);
+
+        let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a);
+        let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a);
+        let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b);
+        let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b);
+
+        let relay_a = tokio::spawn(relay_bidirectional(
+            client_reader_a,
+            client_writer_a,
+            server_reader_a,
+            server_writer_a,
+            64,
+            64,
+            user,
+            Arc::clone(&stats),
+            Some(1),
+            Arc::new(BufferPool::new()),
+        ));
+
+        let relay_b = tokio::spawn(relay_bidirectional(
+            client_reader_b,
+            client_writer_b,
+            server_reader_b,
+            server_writer_b,
+            64,
+            64,
+            user,
+            Arc::clone(&stats),
+            Some(1),
+            Arc::new(BufferPool::new()),
+        ));
+
+        let _ = tokio::join!(
+            client_peer_a.write_all(&[0x01]),
+            server_peer_a.write_all(&[0x02]),
+            client_peer_b.write_all(&[0x03]),
+            server_peer_b.write_all(&[0x04]),
+        );
+
+        let _ = timeout(Duration::from_millis(50), poll_fn(|cx| {
+            let mut one = [0u8; 1];
+            let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one));
+            Poll::Ready(())
+        }))
+        .await;
+
+        drop(client_peer_a);
+        drop(server_peer_a);
+        drop(client_peer_b);
+        drop(server_peer_b);
+
+        let _ = timeout(Duration::from_secs(1), relay_a).await;
+        let _ = timeout(Duration::from_secs(1), relay_b).await;
+
+        assert!(
+            stats.get_user_total_octets(user) <= 1,
+            "parallel relays must not exceed configured quota"
+        );
+    }
+}
+
+impl FaultyReader {
+    fn permission_denied_with_message(message: impl Into<String>) -> Self {
+        Self {
+            error_once: Some(io::Error::new(io::ErrorKind::PermissionDenied, message.into())),
+        }
+    }
+}
+
+impl AsyncRead for FaultyReader {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        _buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        if let Some(err) = self.error_once.take() {
+            return Poll::Ready(Err(err));
+        }
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test]
+async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() {
+    let stats = Arc::new(Stats::new());
+    let (client_peer, relay_client) = duplex(4096);
+    let (client_reader, client_writer) = tokio::io::split(relay_client);
+
+    let relay_result = relay_bidirectional(
+        client_reader,
+        client_writer,
+        FaultyReader::permission_denied_with_message("user data quota exceeded"),
+        tokio::io::sink(),
+        1024,
+        1024,
+        "non-quota-permission-denied",
+        Arc::clone(&stats),
+        None,
+        Arc::new(BufferPool::new()),
+    )
+    .await;
+
+    drop(client_peer);
+
+    assert!(
+        matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied),
+        "non-quota transport PermissionDenied errors must remain IO errors"
+    );
+}
+
+#[tokio::test]
+async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() {
+    let mut rng = StdRng::seed_from_u64(0xA11CE0B5);
+
+    for i in 0..128u64 {
+        let stats = Arc::new(Stats::new());
+        let (client_peer, relay_client) = duplex(1024);
+        let (client_reader, client_writer) = tokio::io::split(relay_client);
+
+        let random_len = rng.random_range(1..=48);
+        let mut msg = String::with_capacity(random_len);
+        for _ in 0..random_len {
+            let ch = (b'a' + (rng.random::<u8>() % 26)) as char;
+            msg.push(ch);
+        }
+        // Include the legacy quota string in a subset of fuzz cases to validate
+        // collision resistance against message-based classification.
+        if i % 7 == 0 {
+            msg = "user data quota exceeded".to_string();
+        }
+
+        let relay_result = relay_bidirectional(
+            client_reader,
+            client_writer,
+            FaultyReader::permission_denied_with_message(msg),
+            tokio::io::sink(),
+            1024,
+            1024,
+            "fuzz-perm-denied",
+            Arc::clone(&stats),
+            None,
+            Arc::new(BufferPool::new()),
+        )
+        .await;
+
+        drop(client_peer);
+
+        assert!(
+            matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied),
+            "transport PermissionDenied case must stay typed as IO regardless of message content"
+        );
+    }
+}
diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs
index 7e329c5..9140b39 100644
--- a/src/tls_front/emulator.rs
+++ b/src/tls_front/emulator.rs
@@ -103,7 +103,7 @@ pub fn build_emulated_server_hello(
     cached: &CachedTlsData,
     use_full_cert_payload: bool,
     rng: &SecureRandom,
-    _alpn: Option<Vec<u8>>,
+    alpn: Option<Vec<u8>>,
     new_session_tickets: u8,
 ) -> Vec<u8> {
     // --- ServerHello ---
@@ -198,8 +198,22 @@ pub fn build_emulated_server_hello(
     }
 
     let mut app_data = Vec::new();
+    let alpn_marker = alpn
+        .as_ref()
+        .filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize)
+        .map(|proto| {
+            let proto_list_len = 1usize + proto.len();
+            let ext_data_len = 2usize + proto_list_len;
+            let mut marker = Vec::with_capacity(4 + ext_data_len);
+            marker.extend_from_slice(&0x0010u16.to_be_bytes());
+            marker.extend_from_slice(&(ext_data_len as u16).to_be_bytes());
+            marker.extend_from_slice(&(proto_list_len as u16).to_be_bytes());
+            marker.push(proto.len() as u8);
+            marker.extend_from_slice(proto);
+            marker
+        });
     let mut payload_offset = 0usize;
-    for size in sizes {
+    for (idx, size) in sizes.into_iter().enumerate() {
         let mut rec = Vec::with_capacity(5 + size);
         rec.push(TLS_RECORD_APPLICATION);
         rec.extend_from_slice(&TLS_VERSION);
@@ -224,7 +238,20 @@ pub fn build_emulated_server_hello(
             }
         } else if size > 17 {
             let body_len = size - 17;
-            rec.extend_from_slice(&rng.bytes(body_len));
+            let mut body = Vec::with_capacity(body_len);
+            if idx == 0 && let Some(marker) = &alpn_marker {
+                if marker.len() <= body_len {
+                    body.extend_from_slice(marker);
+                    if body_len > marker.len() {
+                        body.extend_from_slice(&rng.bytes(body_len - marker.len()));
+                    }
+                } else {
+                    body.extend_from_slice(&rng.bytes(body_len));
+                }
+            } else {
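+                // Non-first records, or no embeddable ALPN marker: the record
+                // body is random filler only.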
+                body.extend_from_slice(&rng.bytes(body_len));
+            }
+            rec.extend_from_slice(&body);
             rec.push(0x16); // inner content type marker (handshake)
             rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag
         } else {
@@ -236,8 +263,9 @@ pub fn build_emulated_server_hello(
     // --- Combine ---
     // Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
     let mut tickets = Vec::new();
-    if new_session_tickets > 0 {
-        for _ in 0..new_session_tickets {
+    let ticket_count = new_session_tickets.min(4);
+    if ticket_count > 0 {
+        for _ in 0..ticket_count {
             let ticket_len: usize = rng.range(48) + 48;
             let mut rec = Vec::with_capacity(5 + ticket_len);
             rec.push(TLS_RECORD_APPLICATION);
@@ -264,6 +292,10 @@ pub fn build_emulated_server_hello(
     response
 }
 
+#[cfg(test)]
+#[path = "emulator_security_tests.rs"]
+mod security_tests;
+
 #[cfg(test)]
 mod tests {
     use std::time::SystemTime;
diff --git a/src/tls_front/emulator_security_tests.rs b/src/tls_front/emulator_security_tests.rs
new file mode 100644
index 0000000..c49d15a
--- /dev/null
+++ b/src/tls_front/emulator_security_tests.rs
@@ -0,0 +1,136 @@
+use std::time::SystemTime;
+
+use crate::crypto::SecureRandom;
+use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE};
+use crate::tls_front::emulator::build_emulated_server_hello;
+use crate::tls_front::types::{
+    CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
+};
+
+fn make_cached(cert_payload: Option<TlsCertPayload>) -> CachedTlsData {
+    CachedTlsData {
+        server_hello_template: ParsedServerHello {
+            version: [0x03, 0x03],
+            random: [0u8; 32],
+            session_id: Vec::new(),
+            cipher_suite: [0x13, 0x01],
+            compression: 0,
+            extensions: Vec::new(),
+        },
+        cert_info: None,
+        cert_payload,
+        app_data_records_sizes: vec![64],
+        total_app_data_len: 64,
+        behavior_profile: TlsBehaviorProfile {
+            change_cipher_spec_count: 1,
+            app_data_record_sizes: vec![64],
+            ticket_record_sizes: Vec::new(),
+            source: TlsProfileSource::Default,
+        },
+        fetched_at: SystemTime::now(),
+        domain: "example.com".to_string(),
+    }
+}
+
+fn first_app_data_payload(response: &[u8]) -> &[u8] {
+    let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
+    let ccs_start = 5 + hello_len;
+    let ccs_len = u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
+    let app_start = ccs_start + 5 + ccs_len;
+    let app_len = u16::from_be_bytes([response[app_start + 3], response[app_start + 4]]) as usize;
+    &response[app_start + 5..app_start + 5 + app_len]
+}
+
+#[test]
+fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
+    let cached = make_cached(None);
+    let rng = SecureRandom::new();
+    let oversized_alpn = vec![0xAB; u8::MAX as usize + 1];
+
+    let response = build_emulated_server_hello(
+        b"secret",
+        &[0x11; 32],
+        &[0x22; 16],
+        &cached,
+        true,
+        &rng,
+        Some(oversized_alpn),
+        0,
+    );
+
+    assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
+    let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
+    let ccs_start = 5 + hello_len;
+    assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
+    let app_start = ccs_start + 6;
+    assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
+
+    let payload = first_app_data_payload(&response);
+    let mut marker_prefix = Vec::new();
+    marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
+    marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
+    marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
+    marker_prefix.push(0xff);
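+    // A plausible (hypothetical) truncated encoding of the 256-byte name:
+    // extension type 0x0010 (ALPN), length words, a saturated u8 length byte,
+    // then a run of the 0xAB filler the oversized input would have carried.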
+    marker_prefix.extend_from_slice(&[0xab; 8]);
+    assert!(
+        !payload.starts_with(&marker_prefix),
+        "oversized ALPN must not be partially embedded into the emulated first application record"
+    );
+}
+
+#[test]
+fn emulated_server_hello_embeds_full_alpn_marker_when_body_can_fit() {
+    let cached = make_cached(None);
+    let rng = SecureRandom::new();
+
+    let response = build_emulated_server_hello(
+        b"secret",
+        &[0x31; 32],
+        &[0x41; 16],
+        &cached,
+        true,
+        &rng,
+        Some(b"h2".to_vec()),
+        0,
+    );
+
+    let payload = first_app_data_payload(&response);
+    let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
+    assert!(
+        payload.starts_with(&expected),
+        "when body has enough capacity, emulated first application record must include full ALPN marker"
+    );
+}
+
+#[test]
+fn emulated_server_hello_prefers_cert_payload_over_alpn_marker() {
+    let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
+    let cached = make_cached(Some(TlsCertPayload {
+        cert_chain_der: vec![vec![0x30, 0x01, 0x00]],
+        certificate_message: cert_msg.clone(),
+    }));
+    let rng = SecureRandom::new();
+
+    let response = build_emulated_server_hello(
+        b"secret",
+        &[0x32; 32],
+        &[0x42; 16],
+        &cached,
+        true,
+        &rng,
+        Some(b"h2".to_vec()),
+        0,
+    );
+
+    let payload = first_app_data_payload(&response);
+    let alpn_marker = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
+
+    assert!(
+        payload.starts_with(&cert_msg),
+        "when certificate payload is available, first record must start with cert payload bytes"
+    );
+    assert!(
+        !payload.starts_with(&alpn_marker),
+        "ALPN marker must not displace selected certificate payload"
+    );
+}
diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs
index 590c996..4e2a5c7 100644
--- a/src/transport/middle_proxy/mod.rs
+++ b/src/transport/middle_proxy/mod.rs
@@ -27,6 +27,8 @@ mod health_regression_tests;
 mod health_integration_tests;
 #[cfg(test)]
 mod health_adversarial_tests;
+#[cfg(test)]
+mod send_adversarial_tests;
 
 use bytes::Bytes;
 
diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs
index 56f3fbf..84e4e11 100644
--- a/src/transport/middle_proxy/pool.rs
+++ b/src/transport/middle_proxy/pool.rs
@@ -692,6 +692,7 @@ impl MePool {
         }
     }
 
+    #[allow(dead_code)]
     pub(super) fn draining_active_runtime(&self) -> u64 {
         self.draining_active_runtime.load(Ordering::Relaxed)
     }
diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs
index ea968b5..a22b98d 100644
--- a/src/transport/middle_proxy/registry.rs
+++ b/src/transport/middle_proxy/registry.rs
@@ -454,6 +454,7 @@ impl ConnRegistry {
         true
     }
 
+    #[allow(dead_code)]
     pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
         let inner = self.inner.read().await;
         let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs
index 0f9fed6..5e0e562 100644
--- a/src/transport/middle_proxy/send.rs
+++ b/src/transport/middle_proxy/send.rs
@@ -372,17 +372,20 @@ impl MePool {
         }
         let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
         let (payload, meta) = build_routed_payload(effective_our_addr);
-        match w.tx.try_send(WriterCommand::Data(payload.clone())) {
-            Ok(()) => {
-                self.stats.increment_me_writer_pick_success_try_total(pick_mode);
+        match w.tx.clone().try_reserve_owned() {
+            Ok(permit) => {
                 if !self.registry.bind_writer(conn_id, w.id, meta).await {
                     debug!(
                         conn_id,
                         writer_id = w.id,
before bind commit, retrying" + "ME writer disappeared before bind commit, pruning stale writer" ); + drop(permit); + self.remove_writer_and_close_clients(w.id).await; continue; } + permit.send(WriterCommand::Data(payload.clone())); + self.stats.increment_me_writer_pick_success_try_total(pick_mode); if w.generation < self.current_generation() { self.stats.increment_pool_stale_pick_total(); debug!( @@ -422,18 +425,21 @@ impl MePool { self.stats.increment_me_writer_pick_blocking_fallback_total(); let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port()); let (payload, meta) = build_routed_payload(effective_our_addr); - match w.tx.send(WriterCommand::Data(payload.clone())).await { - Ok(()) => { - self.stats - .increment_me_writer_pick_success_fallback_total(pick_mode); + match w.tx.clone().reserve_owned().await { + Ok(permit) => { if !self.registry.bind_writer(conn_id, w.id, meta).await { debug!( conn_id, writer_id = w.id, - "ME writer disappeared before fallback bind commit, retrying" + "ME writer disappeared before fallback bind commit, pruning stale writer" ); + drop(permit); + self.remove_writer_and_close_clients(w.id).await; continue; } + permit.send(WriterCommand::Data(payload.clone())); + self.stats + .increment_me_writer_pick_success_fallback_total(pick_mode); if w.generation < self.current_generation() { self.stats.increment_pool_stale_pick_total(); } diff --git a/src/transport/middle_proxy/send_adversarial_tests.rs b/src/transport/middle_proxy/send_adversarial_tests.rs new file mode 100644 index 0000000..6c80672 --- /dev/null +++ b/src/transport/middle_proxy/send_adversarial_tests.rs @@ -0,0 +1,263 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use super::pool::{MePool, MeWriter, WriterContour}; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool() -> (Arc, Arc) { + let general = GeneralConfig { + me_route_no_writer_mode: MeRouteNoWriterMode::AsyncRecoveryFailfast, + me_route_no_writer_wait_ms: 50, + me_writer_pick_mode: MeWriterPickMode::SortedRr, + me_deterministic_writer_sort: true, + ..GeneralConfig::default() + }; + + let rng = Arc::new(SecureRandom::new()); + let pool = MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + rng.clone(), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + 
+        general.me_single_endpoint_shadow_rotate_every_secs,
+        general.me_floor_mode,
+        general.me_adaptive_floor_idle_secs,
+        general.me_adaptive_floor_min_writers_single_endpoint,
+        general.me_adaptive_floor_min_writers_multi_endpoint,
+        general.me_adaptive_floor_recover_grace_secs,
+        general.me_adaptive_floor_writers_per_core_total,
+        general.me_adaptive_floor_cpu_cores_override,
+        general.me_adaptive_floor_max_extra_writers_single_per_core,
+        general.me_adaptive_floor_max_extra_writers_multi_per_core,
+        general.me_adaptive_floor_max_active_writers_per_core,
+        general.me_adaptive_floor_max_warm_writers_per_core,
+        general.me_adaptive_floor_max_active_writers_global,
+        general.me_adaptive_floor_max_warm_writers_global,
+        general.hardswap,
+        general.me_pool_drain_ttl_secs,
+        general.me_pool_drain_threshold,
+        general.effective_me_pool_force_close_secs(),
+        general.me_pool_min_fresh_ratio,
+        general.me_hardswap_warmup_delay_min_ms,
+        general.me_hardswap_warmup_delay_max_ms,
+        general.me_hardswap_warmup_extra_passes,
+        general.me_hardswap_warmup_pass_backoff_base_ms,
+        general.me_bind_stale_mode,
+        general.me_bind_stale_ttl_secs,
+        general.me_secret_atomic_snapshot,
+        general.me_deterministic_writer_sort,
+        general.me_writer_pick_mode,
+        general.me_writer_pick_sample_size,
+        MeSocksKdfPolicy::default(),
+        general.me_writer_cmd_channel_capacity,
+        general.me_route_channel_capacity,
+        general.me_route_backpressure_base_timeout_ms,
+        general.me_route_backpressure_high_timeout_ms,
+        general.me_route_backpressure_high_watermark_pct,
+        general.me_reader_route_data_wait_ms,
+        general.me_health_interval_ms_unhealthy,
+        general.me_health_interval_ms_healthy,
+        general.me_warn_rate_limit_ms,
+        general.me_route_no_writer_mode,
+        general.me_route_no_writer_wait_ms,
+        general.me_route_inline_recovery_attempts,
+        general.me_route_inline_recovery_wait_ms,
+    );
+
+    (pool, rng)
+}
+
+async fn insert_writer(
+    pool: &Arc<MePool>,
+    writer_id: u64,
+    writer_dc: i32,
+    addr: SocketAddr,
+    register_in_registry: bool,
+) -> mpsc::Receiver<WriterCommand> {
+    let (tx, rx) = mpsc::channel::<WriterCommand>(8);
+    let writer = MeWriter {
+        id: writer_id,
+        addr,
+        source_ip: addr.ip(),
+        writer_dc,
+        generation: pool.current_generation(),
+        contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())),
+        created_at: Instant::now(),
+        tx: tx.clone(),
+        cancel: CancellationToken::new(),
+        degraded: Arc::new(AtomicBool::new(false)),
+        rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
+        draining: Arc::new(AtomicBool::new(false)),
+        draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
+        drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
+        allow_drain_fallback: Arc::new(AtomicBool::new(false)),
+    };
+
+    pool.writers.write().await.push(writer);
+    {
+        let mut map = pool.proxy_map_v4.write().await;
+        map.entry(writer_dc)
+            .or_insert_with(Vec::new)
+            .push((addr.ip(), addr.port()));
+    }
+    pool.rebuild_endpoint_dc_map().await;
+    if register_in_registry {
+        pool.registry.register_writer(writer_id, tx).await;
+    }
+    rx
+}
+
+async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duration) -> usize {
+    let start = Instant::now();
+    let mut data_count = 0usize;
+    while Instant::now().duration_since(start) < budget {
+        let remaining = budget.saturating_sub(Instant::now().duration_since(start));
+        match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
+            Ok(Some(WriterCommand::Data(_))) => data_count += 1,
+            Ok(Some(WriterCommand::DataAndFlush(_))) => data_count += 1,
+            Ok(Some(WriterCommand::Close)) => {}
+            Ok(None) => break,
+            Err(_) => break,
+        }
+    }
+    data_count
+}
+
+#[tokio::test]
+async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
+    let (pool, _rng) = make_pool().await;
+    pool.rr.store(0, Ordering::Relaxed);
+
+    let (conn_id, _rx) = pool.registry.register().await;
+    let mut stale_rx = insert_writer(
+        &pool,
+        10,
+        2,
+        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 10)), 443),
+        false,
+    )
+    .await;
+    let mut live_rx = insert_writer(
+        &pool,
+        11,
+        2,
+        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 11)), 443),
+        true,
+    )
+    .await;
+
+    let result = pool
+        .send_proxy_req(
+            conn_id,
+            2,
+            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30000),
+            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
+            b"hello",
+            0,
+            None,
+        )
+        .await;
+
+    assert!(result.is_ok());
+    assert_eq!(recv_data_count(&mut stale_rx, Duration::from_millis(50)).await, 0);
+    assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
+
+    let bound = pool.registry.get_writer(conn_id).await;
+    assert!(bound.is_some());
+    assert_eq!(bound.expect("writer should be bound").writer_id, 11);
+}
+
+#[tokio::test]
+async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay() {
+    let (pool, _rng) = make_pool().await;
+    pool.rr.store(0, Ordering::Relaxed);
+
+    let (conn_id, _rx) = pool.registry.register().await;
+
+    let mut stale_rx_1 = insert_writer(
+        &pool,
+        21,
+        2,
+        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 21)), 443),
+        false,
+    )
+    .await;
+    let mut stale_rx_2 = insert_writer(
+        &pool,
+        22,
+        2,
+        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 22)), 443),
+        false,
+    )
+    .await;
+    let mut live_rx = insert_writer(
+        &pool,
+        23,
+        2,
+        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 23)), 443),
+        true,
+    )
+    .await;
+
+    let result = pool
+        .send_proxy_req(
+            conn_id,
+            2,
+            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30001),
+            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
+            b"storm",
+            0,
+            None,
+        )
+        .await;
+
+    assert!(result.is_ok());
+    assert_eq!(recv_data_count(&mut stale_rx_1, Duration::from_millis(50)).await, 0);
+    assert_eq!(recv_data_count(&mut stale_rx_2, Duration::from_millis(50)).await, 0);
+    assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1);
+
+    let writers = pool.writers.read().await;
+    let writer_ids = writers.iter().map(|w| w.id).collect::<Vec<_>>();
+    drop(writers);
+    assert_eq!(writer_ids, vec![23]);
+}
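+
+// The two tests above pin down the ordering that send.rs now follows: reserve
+// channel capacity first, commit the registry bind, and only then push the
+// payload through the permit, so a failed bind never leaves data queued on a
+// stale writer. A minimal sketch of the pattern (hypothetical surrounding
+// names; the real logic lives in MePool::send_proxy_req):
+//
+//     match writer_tx.clone().try_reserve_owned() {
+//         Ok(permit) => {
+//             if !registry.bind_writer(conn_id, writer_id, meta).await {
+//                 drop(permit); // nothing was sent yet; safe to prune and retry
+//             } else {
+//                 permit.send(WriterCommand::Data(payload)); // cannot fail after reserve
+//             }
+//         }
+//         Err(_) => { /* channel full or closed: fall through to the next candidate */ }
+//     }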