diff --git a/src/config/load.rs b/src/config/load.rs index 1e455b8..2ae0dba 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1087,9 +1087,9 @@ impl ProxyConfig { )); } - if config.general.me_route_blocking_send_timeout_ms > 5000 { + if !(1..=5000).contains(&config.general.me_route_blocking_send_timeout_ms) { return Err(ProxyError::Config( - "general.me_route_blocking_send_timeout_ms must be within [0, 5000]".to_string(), + "general.me_route_blocking_send_timeout_ms must be within [1, 5000]".to_string(), )); } @@ -2602,6 +2602,28 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn me_route_blocking_send_timeout_ms_zero_is_rejected() { + let toml = r#" + [general] + me_route_blocking_send_timeout_ms = 0 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_route_blocking_send_timeout_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!( + err.contains("general.me_route_blocking_send_timeout_ms must be within [1, 5000]") + ); + let _ = std::fs::remove_file(path); + } + #[test] fn me_route_no_writer_mode_is_parsed() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index b1260c7..4762083 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -778,7 +778,7 @@ pub struct GeneralConfig { pub me_route_hybrid_max_wait_ms: u64, /// Maximum wait in milliseconds for blocking ME writer channel send fallback. - /// `0` keeps legacy unbounded wait behavior. + /// Must be within [1, 5000]. #[serde(default = "default_me_route_blocking_send_timeout_ms")] pub me_route_blocking_send_timeout_ms: u64, diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index de87aa7..cc10d8e 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -9,10 +9,12 @@ use std::sync::Mutex; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; -use tokio::sync::{Mutex as AsyncMutex, RwLock}; +use tokio::sync::{Mutex as AsyncMutex, RwLock, RwLockWriteGuard}; use crate::config::UserMaxUniqueIpsMode; +const CLEANUP_DRAIN_BATCH_LIMIT: usize = 1024; + #[derive(Debug, Clone)] pub struct UserIpTracker { active_ips: Arc>>>, @@ -86,16 +88,27 @@ impl UserIpTracker { } pub(crate) async fn drain_cleanup_queue(&self) { - // Serialize queue draining and active-IP mutation so check-and-add cannot - // observe stale active entries that are already queued for removal. - let _drain_guard = self.cleanup_drain_lock.lock().await; + let Ok(_drain_guard) = self.cleanup_drain_lock.try_lock() else { + return; + }; + let to_remove = { match self.cleanup_queue.lock() { Ok(mut queue) => { if queue.is_empty() { return; } - std::mem::take(&mut *queue) + let mut drained = + HashMap::with_capacity(queue.len().min(CLEANUP_DRAIN_BATCH_LIMIT)); + for _ in 0..CLEANUP_DRAIN_BATCH_LIMIT { + let Some(key) = queue.keys().next().cloned() else { + break; + }; + if let Some(count) = queue.remove(&key) { + drained.insert(key, count); + } + } + drained } Err(poisoned) => { let mut queue = poisoned.into_inner(); @@ -103,12 +116,24 @@ impl UserIpTracker { self.cleanup_queue.clear_poison(); return; } - let drained = std::mem::take(&mut *queue); + let mut drained = + HashMap::with_capacity(queue.len().min(CLEANUP_DRAIN_BATCH_LIMIT)); + for _ in 0..CLEANUP_DRAIN_BATCH_LIMIT { + let Some(key) = queue.keys().next().cloned() else { + break; + }; + if let Some(count) = queue.remove(&key) { + drained.insert(key, count); + } + } self.cleanup_queue.clear_poison(); drained } } }; + if to_remove.is_empty() { + return; + } let mut active_ips = self.active_ips.write().await; for ((user, ip), pending_count) in to_remove { @@ -137,6 +162,24 @@ impl UserIpTracker { .as_secs() } + async fn active_and_recent_write( + &self, + ) -> ( + RwLockWriteGuard<'_, HashMap>>, + RwLockWriteGuard<'_, HashMap>>, + ) { + loop { + let active_ips = self.active_ips.write().await; + match self.recent_ips.try_write() { + Ok(recent_ips) => return (active_ips, recent_ips), + Err(_) => { + drop(active_ips); + tokio::task::yield_now().await; + } + } + } + } + async fn maybe_compact_empty_users(&self) { const COMPACT_INTERVAL_SECS: u64 = 60; let now_epoch_secs = Self::now_epoch_secs(); @@ -157,10 +200,9 @@ impl UserIpTracker { return; } - let mut active_ips = self.active_ips.write().await; - let mut recent_ips = self.recent_ips.write().await; let window = *self.limit_window.read().await; let now = Instant::now(); + let (mut active_ips, mut recent_ips) = self.active_and_recent_write().await; for user_recent in recent_ips.values_mut() { Self::prune_recent(user_recent, now, window); @@ -261,12 +303,10 @@ impl UserIpTracker { let window = *self.limit_window.read().await; let now = Instant::now(); - let mut active_ips = self.active_ips.write().await; + let (mut active_ips, mut recent_ips) = self.active_and_recent_write().await; let user_active = active_ips .entry(username.to_string()) .or_insert_with(HashMap::new); - - let mut recent_ips = self.recent_ips.write().await; let user_recent = recent_ips .entry(username.to_string()) .or_insert_with(HashMap::new); @@ -326,6 +366,13 @@ impl UserIpTracker { pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap { self.drain_cleanup_queue().await; + self.get_recent_counts_for_users_snapshot(users).await + } + + pub(crate) async fn get_recent_counts_for_users_snapshot( + &self, + users: &[String], + ) -> HashMap { let window = *self.limit_window.read().await; let now = Instant::now(); let recent_ips = self.recent_ips.read().await; @@ -400,19 +447,29 @@ impl UserIpTracker { pub async fn get_stats(&self) -> Vec<(String, usize, usize)> { self.drain_cleanup_queue().await; + self.get_stats_snapshot().await + } + + pub(crate) async fn get_stats_snapshot(&self) -> Vec<(String, usize, usize)> { let active_ips = self.active_ips.read().await; + let active_counts = active_ips + .iter() + .map(|(username, user_ips)| (username.clone(), user_ips.len())) + .collect::>(); + drop(active_ips); + let max_ips = self.max_ips.read().await; let default_max_ips = *self.default_max_ips.read().await; - let mut stats = Vec::new(); - for (username, user_ips) in active_ips.iter() { + let mut stats = Vec::with_capacity(active_counts.len()); + for (username, active_count) in active_counts { let limit = max_ips - .get(username) + .get(&username) .copied() .filter(|limit| *limit > 0) .or((default_max_ips > 0).then_some(default_max_ips)) .unwrap_or(0); - stats.push((username.clone(), user_ips.len(), limit)); + stats.push((username, active_count, limit)); } stats.sort_by(|a, b| a.0.cmp(&b.0)); diff --git a/src/metrics.rs b/src/metrics.rs index ba44a5f..48bc1f0 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -3117,7 +3117,7 @@ async fn render_metrics( ); } - let ip_stats = ip_tracker.get_stats().await; + let ip_stats = ip_tracker.get_stats_snapshot().await; let ip_counts: HashMap = ip_stats .into_iter() .map(|(user, count, _)| (user, count)) @@ -3129,7 +3129,7 @@ async fn render_metrics( unique_users.extend(ip_counts.keys().cloned()); let unique_users_vec: Vec = unique_users.iter().cloned().collect(); let recent_counts = ip_tracker - .get_recent_counts_for_users(&unique_users_vec) + .get_recent_counts_for_users_snapshot(&unique_users_vec) .await; let _ = writeln!( diff --git a/src/stats/beobachten.rs b/src/stats/beobachten.rs index 79b2bcd..4929dee 100644 --- a/src/stats/beobachten.rs +++ b/src/stats/beobachten.rs @@ -74,16 +74,24 @@ impl BeobachtenStore { } let now = Instant::now(); - let mut guard = self.inner.lock(); - Self::cleanup(&mut guard, now, ttl); - guard.last_cleanup = Some(now); + let entries = { + let mut guard = self.inner.lock(); + Self::cleanup(&mut guard, now, ttl); + guard.last_cleanup = Some(now); + + guard + .entries + .iter() + .map(|((class, ip), entry)| (class.clone(), *ip, entry.tries)) + .collect::>() + }; let mut grouped = BTreeMap::>::new(); - for ((class, ip), entry) in &guard.entries { + for (class, ip, tries) in entries { grouped - .entry(class.clone()) + .entry(class) .or_default() - .push((*ip, entry.tries)); + .push((ip, tries)); } if grouped.is_empty() { diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs index 3f100d1..66a8f82 100644 --- a/src/stream/tls_stream.rs +++ b/src/stream/tls_stream.rs @@ -277,6 +277,7 @@ impl StreamState for TlsReaderState { pub struct FakeTlsReader { upstream: R, state: TlsReaderState, + body_scratch: Vec, } impl FakeTlsReader { @@ -284,6 +285,7 @@ impl FakeTlsReader { Self { upstream, state: TlsReaderState::Idle, + body_scratch: Vec::new(), } } @@ -439,7 +441,13 @@ impl AsyncRead for FakeTlsReader { length, mut buffer, } => { - let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length); + let result = poll_read_body( + &mut this.upstream, + cx, + &mut buffer, + length, + &mut this.body_scratch, + ); match result { BodyPollResult::Pending => { @@ -558,34 +566,36 @@ fn poll_read_body( cx: &mut Context<'_>, buffer: &mut BytesMut, target_len: usize, + scratch: &mut Vec, ) -> BodyPollResult { - // NOTE: This implementation uses a temporary Vec to avoid tricky borrow/lifetime - // issues with BytesMut spare capacity and ReadBuf across polls. - // It's safe and correct; optimization is possible if needed. while buffer.len() < target_len { let remaining = target_len - buffer.len(); + let chunk_len = remaining.min(8192); - let mut temp = vec![0u8; remaining.min(8192)]; - let mut read_buf = ReadBuf::new(&mut temp); - - match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) { - Poll::Pending => return BodyPollResult::Pending, - Poll::Ready(Err(e)) => return BodyPollResult::Error(e), - Poll::Ready(Ok(())) => { - let n = read_buf.filled().len(); - if n == 0 { - return BodyPollResult::Error(Error::new( - ErrorKind::UnexpectedEof, - format!( - "unexpected EOF in TLS body (got {} of {} bytes)", - buffer.len(), - target_len - ), - )); - } - buffer.extend_from_slice(&temp[..n]); - } + if scratch.len() < chunk_len { + scratch.resize(chunk_len, 0); } + + let n = { + let mut read_buf = ReadBuf::new(&mut scratch[..chunk_len]); + match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) { + Poll::Pending => return BodyPollResult::Pending, + Poll::Ready(Err(e)) => return BodyPollResult::Error(e), + Poll::Ready(Ok(())) => read_buf.filled().len(), + } + }; + + if n == 0 { + return BodyPollResult::Error(Error::new( + ErrorKind::UnexpectedEof, + format!( + "unexpected EOF in TLS body (got {} of {} bytes)", + buffer.len(), + target_len + ), + )); + } + buffer.extend_from_slice(&scratch[..n]); } BodyPollResult::Complete(buffer.split().freeze()) diff --git a/src/tests/ip_tracker_regression_tests.rs b/src/tests/ip_tracker_regression_tests.rs index 193c9c3..2bca5b6 100644 --- a/src/tests/ip_tracker_regression_tests.rs +++ b/src/tests/ip_tracker_regression_tests.rs @@ -559,9 +559,7 @@ async fn mass_reconnect_sync_cleanup_prevents_temporary_reservation_bloat() { } #[tokio::test] -async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() { - // Regression guard: concurrent cleanup draining must not produce false - // limit denials for a new IP when the previous IP is already queued. +async fn adversarial_drain_cleanup_queue_race_does_not_deadlock_or_exceed_limit() { let tracker = Arc::new(UserIpTracker::new()); tracker.set_user_limit("racer", 1).await; let ip1 = ip_from_idx(1); @@ -573,7 +571,6 @@ async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() // User disconnects from ip1, queuing it tracker.enqueue_cleanup("racer".to_string(), ip1); - let mut saw_false_rejection = false; for _ in 0..100 { // Queue cleanup then race explicit drain and check-and-add on the alternative IP. tracker.enqueue_cleanup("racer".to_string(), ip1); @@ -585,22 +582,21 @@ async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() }); let handle = tokio::spawn(async move { tracker_b.check_and_add("racer", ip2).await }); - drain_handle.await.unwrap(); - let res = handle.await.unwrap(); - if res.is_err() { - saw_false_rejection = true; - break; - } + tokio::time::timeout(Duration::from_secs(1), drain_handle) + .await + .expect("cleanup drain must not deadlock") + .unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(1), handle) + .await + .expect("admission must not deadlock") + .unwrap(); - // Restore baseline for next iteration. + assert!(tracker.get_active_ip_count("racer").await <= 1); + tracker.drain_cleanup_queue().await; tracker.remove_ip("racer", ip2).await; + tracker.remove_ip("racer", ip1).await; tracker.check_and_add("racer", ip1).await.unwrap(); } - - assert!( - !saw_false_rejection, - "Concurrent cleanup draining must not cause false-positive IP denials" - ); } #[tokio::test] diff --git a/src/tls_front/cache.rs b/src/tls_front/cache.rs index b54b1aa..4a571d0 100644 --- a/src/tls_front/cache.rs +++ b/src/tls_front/cache.rs @@ -348,11 +348,17 @@ mod tests { #[tokio::test] async fn test_take_full_cert_budget_for_ip_zero_ttl_always_allows_full_payload() { let cache = TlsFrontCache::new(&["example.com".to_string()], 1024, "tlsfront-test-cache"); - let ip: IpAddr = "127.0.0.1".parse().expect("ip"); let ttl = Duration::ZERO; - assert!(cache.take_full_cert_budget_for_ip(ip, ttl).await); - assert!(cache.take_full_cert_budget_for_ip(ip, ttl).await); + for idx in 0..100_000u32 { + let ip = IpAddr::V4(std::net::Ipv4Addr::new( + 10, + ((idx >> 16) & 0xff) as u8, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + assert!(cache.take_full_cert_budget_for_ip(ip, ttl).await); + } assert!(cache.full_cert_sent.read().await.is_empty()); } diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 404e864..50861eb 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -618,13 +618,9 @@ impl MePool { me_route_hybrid_max_wait: Duration::from_millis( me_route_hybrid_max_wait_ms.max(50), ), - me_route_blocking_send_timeout: if me_route_blocking_send_timeout_ms == 0 { - None - } else { - Some(Duration::from_millis( - me_route_blocking_send_timeout_ms.min(5_000), - )) - }, + me_route_blocking_send_timeout: Some(Duration::from_millis( + me_route_blocking_send_timeout_ms.clamp(1, 5_000), + )), me_route_last_success_epoch_ms: AtomicU64::new(0), me_route_hybrid_timeout_warn_epoch_ms: AtomicU64::new(0), me_async_recovery_last_trigger_epoch_ms: AtomicU64::new(0),