mirror of
https://github.com/telemt/telemt.git
synced 2026-04-30 17:04:11 +03:00
Hot-path Cleanup and Timeout Invariants
This commit is contained in:
@@ -1087,9 +1087,9 @@ impl ProxyConfig {
|
||||
));
|
||||
}
|
||||
|
||||
if config.general.me_route_blocking_send_timeout_ms > 5000 {
|
||||
if !(1..=5000).contains(&config.general.me_route_blocking_send_timeout_ms) {
|
||||
return Err(ProxyError::Config(
|
||||
"general.me_route_blocking_send_timeout_ms must be within [0, 5000]".to_string(),
|
||||
"general.me_route_blocking_send_timeout_ms must be within [1, 5000]".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -2602,6 +2602,28 @@ mod tests {
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn me_route_blocking_send_timeout_ms_zero_is_rejected() {
|
||||
let toml = r#"
|
||||
[general]
|
||||
me_route_blocking_send_timeout_ms = 0
|
||||
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_me_route_blocking_send_timeout_zero_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let err = ProxyConfig::load(&path).unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("general.me_route_blocking_send_timeout_ms must be within [1, 5000]")
|
||||
);
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn me_route_no_writer_mode_is_parsed() {
|
||||
let toml = r#"
|
||||
|
||||
@@ -778,7 +778,7 @@ pub struct GeneralConfig {
|
||||
pub me_route_hybrid_max_wait_ms: u64,
|
||||
|
||||
/// Maximum wait in milliseconds for blocking ME writer channel send fallback.
|
||||
/// `0` keeps legacy unbounded wait behavior.
|
||||
/// Must be within [1, 5000].
|
||||
#[serde(default = "default_me_route_blocking_send_timeout_ms")]
|
||||
pub me_route_blocking_send_timeout_ms: u64,
|
||||
|
||||
|
||||
@@ -9,10 +9,12 @@ use std::sync::Mutex;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::{Mutex as AsyncMutex, RwLock};
|
||||
use tokio::sync::{Mutex as AsyncMutex, RwLock, RwLockWriteGuard};
|
||||
|
||||
use crate::config::UserMaxUniqueIpsMode;
|
||||
|
||||
const CLEANUP_DRAIN_BATCH_LIMIT: usize = 1024;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UserIpTracker {
|
||||
active_ips: Arc<RwLock<HashMap<String, HashMap<IpAddr, usize>>>>,
|
||||
@@ -86,16 +88,27 @@ impl UserIpTracker {
|
||||
}
|
||||
|
||||
pub(crate) async fn drain_cleanup_queue(&self) {
|
||||
// Serialize queue draining and active-IP mutation so check-and-add cannot
|
||||
// observe stale active entries that are already queued for removal.
|
||||
let _drain_guard = self.cleanup_drain_lock.lock().await;
|
||||
let Ok(_drain_guard) = self.cleanup_drain_lock.try_lock() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let to_remove = {
|
||||
match self.cleanup_queue.lock() {
|
||||
Ok(mut queue) => {
|
||||
if queue.is_empty() {
|
||||
return;
|
||||
}
|
||||
std::mem::take(&mut *queue)
|
||||
let mut drained =
|
||||
HashMap::with_capacity(queue.len().min(CLEANUP_DRAIN_BATCH_LIMIT));
|
||||
for _ in 0..CLEANUP_DRAIN_BATCH_LIMIT {
|
||||
let Some(key) = queue.keys().next().cloned() else {
|
||||
break;
|
||||
};
|
||||
if let Some(count) = queue.remove(&key) {
|
||||
drained.insert(key, count);
|
||||
}
|
||||
}
|
||||
drained
|
||||
}
|
||||
Err(poisoned) => {
|
||||
let mut queue = poisoned.into_inner();
|
||||
@@ -103,12 +116,24 @@ impl UserIpTracker {
|
||||
self.cleanup_queue.clear_poison();
|
||||
return;
|
||||
}
|
||||
let drained = std::mem::take(&mut *queue);
|
||||
let mut drained =
|
||||
HashMap::with_capacity(queue.len().min(CLEANUP_DRAIN_BATCH_LIMIT));
|
||||
for _ in 0..CLEANUP_DRAIN_BATCH_LIMIT {
|
||||
let Some(key) = queue.keys().next().cloned() else {
|
||||
break;
|
||||
};
|
||||
if let Some(count) = queue.remove(&key) {
|
||||
drained.insert(key, count);
|
||||
}
|
||||
}
|
||||
self.cleanup_queue.clear_poison();
|
||||
drained
|
||||
}
|
||||
}
|
||||
};
|
||||
if to_remove.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut active_ips = self.active_ips.write().await;
|
||||
for ((user, ip), pending_count) in to_remove {
|
||||
@@ -137,6 +162,24 @@ impl UserIpTracker {
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
async fn active_and_recent_write(
|
||||
&self,
|
||||
) -> (
|
||||
RwLockWriteGuard<'_, HashMap<String, HashMap<IpAddr, usize>>>,
|
||||
RwLockWriteGuard<'_, HashMap<String, HashMap<IpAddr, Instant>>>,
|
||||
) {
|
||||
loop {
|
||||
let active_ips = self.active_ips.write().await;
|
||||
match self.recent_ips.try_write() {
|
||||
Ok(recent_ips) => return (active_ips, recent_ips),
|
||||
Err(_) => {
|
||||
drop(active_ips);
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn maybe_compact_empty_users(&self) {
|
||||
const COMPACT_INTERVAL_SECS: u64 = 60;
|
||||
let now_epoch_secs = Self::now_epoch_secs();
|
||||
@@ -157,10 +200,9 @@ impl UserIpTracker {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut active_ips = self.active_ips.write().await;
|
||||
let mut recent_ips = self.recent_ips.write().await;
|
||||
let window = *self.limit_window.read().await;
|
||||
let now = Instant::now();
|
||||
let (mut active_ips, mut recent_ips) = self.active_and_recent_write().await;
|
||||
|
||||
for user_recent in recent_ips.values_mut() {
|
||||
Self::prune_recent(user_recent, now, window);
|
||||
@@ -261,12 +303,10 @@ impl UserIpTracker {
|
||||
let window = *self.limit_window.read().await;
|
||||
let now = Instant::now();
|
||||
|
||||
let mut active_ips = self.active_ips.write().await;
|
||||
let (mut active_ips, mut recent_ips) = self.active_and_recent_write().await;
|
||||
let user_active = active_ips
|
||||
.entry(username.to_string())
|
||||
.or_insert_with(HashMap::new);
|
||||
|
||||
let mut recent_ips = self.recent_ips.write().await;
|
||||
let user_recent = recent_ips
|
||||
.entry(username.to_string())
|
||||
.or_insert_with(HashMap::new);
|
||||
@@ -326,6 +366,13 @@ impl UserIpTracker {
|
||||
|
||||
pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap<String, usize> {
|
||||
self.drain_cleanup_queue().await;
|
||||
self.get_recent_counts_for_users_snapshot(users).await
|
||||
}
|
||||
|
||||
pub(crate) async fn get_recent_counts_for_users_snapshot(
|
||||
&self,
|
||||
users: &[String],
|
||||
) -> HashMap<String, usize> {
|
||||
let window = *self.limit_window.read().await;
|
||||
let now = Instant::now();
|
||||
let recent_ips = self.recent_ips.read().await;
|
||||
@@ -400,19 +447,29 @@ impl UserIpTracker {
|
||||
|
||||
pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
|
||||
self.drain_cleanup_queue().await;
|
||||
self.get_stats_snapshot().await
|
||||
}
|
||||
|
||||
pub(crate) async fn get_stats_snapshot(&self) -> Vec<(String, usize, usize)> {
|
||||
let active_ips = self.active_ips.read().await;
|
||||
let active_counts = active_ips
|
||||
.iter()
|
||||
.map(|(username, user_ips)| (username.clone(), user_ips.len()))
|
||||
.collect::<Vec<_>>();
|
||||
drop(active_ips);
|
||||
|
||||
let max_ips = self.max_ips.read().await;
|
||||
let default_max_ips = *self.default_max_ips.read().await;
|
||||
|
||||
let mut stats = Vec::new();
|
||||
for (username, user_ips) in active_ips.iter() {
|
||||
let mut stats = Vec::with_capacity(active_counts.len());
|
||||
for (username, active_count) in active_counts {
|
||||
let limit = max_ips
|
||||
.get(username)
|
||||
.get(&username)
|
||||
.copied()
|
||||
.filter(|limit| *limit > 0)
|
||||
.or((default_max_ips > 0).then_some(default_max_ips))
|
||||
.unwrap_or(0);
|
||||
stats.push((username.clone(), user_ips.len(), limit));
|
||||
stats.push((username, active_count, limit));
|
||||
}
|
||||
|
||||
stats.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
@@ -3117,7 +3117,7 @@ async fn render_metrics(
|
||||
);
|
||||
}
|
||||
|
||||
let ip_stats = ip_tracker.get_stats().await;
|
||||
let ip_stats = ip_tracker.get_stats_snapshot().await;
|
||||
let ip_counts: HashMap<String, usize> = ip_stats
|
||||
.into_iter()
|
||||
.map(|(user, count, _)| (user, count))
|
||||
@@ -3129,7 +3129,7 @@ async fn render_metrics(
|
||||
unique_users.extend(ip_counts.keys().cloned());
|
||||
let unique_users_vec: Vec<String> = unique_users.iter().cloned().collect();
|
||||
let recent_counts = ip_tracker
|
||||
.get_recent_counts_for_users(&unique_users_vec)
|
||||
.get_recent_counts_for_users_snapshot(&unique_users_vec)
|
||||
.await;
|
||||
|
||||
let _ = writeln!(
|
||||
|
||||
@@ -74,16 +74,24 @@ impl BeobachtenStore {
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let mut guard = self.inner.lock();
|
||||
Self::cleanup(&mut guard, now, ttl);
|
||||
guard.last_cleanup = Some(now);
|
||||
let entries = {
|
||||
let mut guard = self.inner.lock();
|
||||
Self::cleanup(&mut guard, now, ttl);
|
||||
guard.last_cleanup = Some(now);
|
||||
|
||||
guard
|
||||
.entries
|
||||
.iter()
|
||||
.map(|((class, ip), entry)| (class.clone(), *ip, entry.tries))
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let mut grouped = BTreeMap::<String, Vec<(IpAddr, u64)>>::new();
|
||||
for ((class, ip), entry) in &guard.entries {
|
||||
for (class, ip, tries) in entries {
|
||||
grouped
|
||||
.entry(class.clone())
|
||||
.entry(class)
|
||||
.or_default()
|
||||
.push((*ip, entry.tries));
|
||||
.push((ip, tries));
|
||||
}
|
||||
|
||||
if grouped.is_empty() {
|
||||
|
||||
@@ -277,6 +277,7 @@ impl StreamState for TlsReaderState {
|
||||
pub struct FakeTlsReader<R> {
|
||||
upstream: R,
|
||||
state: TlsReaderState,
|
||||
body_scratch: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<R> FakeTlsReader<R> {
|
||||
@@ -284,6 +285,7 @@ impl<R> FakeTlsReader<R> {
|
||||
Self {
|
||||
upstream,
|
||||
state: TlsReaderState::Idle,
|
||||
body_scratch: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -439,7 +441,13 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
length,
|
||||
mut buffer,
|
||||
} => {
|
||||
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
|
||||
let result = poll_read_body(
|
||||
&mut this.upstream,
|
||||
cx,
|
||||
&mut buffer,
|
||||
length,
|
||||
&mut this.body_scratch,
|
||||
);
|
||||
|
||||
match result {
|
||||
BodyPollResult::Pending => {
|
||||
@@ -558,34 +566,36 @@ fn poll_read_body<R: AsyncRead + Unpin>(
|
||||
cx: &mut Context<'_>,
|
||||
buffer: &mut BytesMut,
|
||||
target_len: usize,
|
||||
scratch: &mut Vec<u8>,
|
||||
) -> BodyPollResult {
|
||||
// NOTE: This implementation uses a temporary Vec to avoid tricky borrow/lifetime
|
||||
// issues with BytesMut spare capacity and ReadBuf across polls.
|
||||
// It's safe and correct; optimization is possible if needed.
|
||||
while buffer.len() < target_len {
|
||||
let remaining = target_len - buffer.len();
|
||||
let chunk_len = remaining.min(8192);
|
||||
|
||||
let mut temp = vec![0u8; remaining.min(8192)];
|
||||
let mut read_buf = ReadBuf::new(&mut temp);
|
||||
|
||||
match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) {
|
||||
Poll::Pending => return BodyPollResult::Pending,
|
||||
Poll::Ready(Err(e)) => return BodyPollResult::Error(e),
|
||||
Poll::Ready(Ok(())) => {
|
||||
let n = read_buf.filled().len();
|
||||
if n == 0 {
|
||||
return BodyPollResult::Error(Error::new(
|
||||
ErrorKind::UnexpectedEof,
|
||||
format!(
|
||||
"unexpected EOF in TLS body (got {} of {} bytes)",
|
||||
buffer.len(),
|
||||
target_len
|
||||
),
|
||||
));
|
||||
}
|
||||
buffer.extend_from_slice(&temp[..n]);
|
||||
}
|
||||
if scratch.len() < chunk_len {
|
||||
scratch.resize(chunk_len, 0);
|
||||
}
|
||||
|
||||
let n = {
|
||||
let mut read_buf = ReadBuf::new(&mut scratch[..chunk_len]);
|
||||
match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) {
|
||||
Poll::Pending => return BodyPollResult::Pending,
|
||||
Poll::Ready(Err(e)) => return BodyPollResult::Error(e),
|
||||
Poll::Ready(Ok(())) => read_buf.filled().len(),
|
||||
}
|
||||
};
|
||||
|
||||
if n == 0 {
|
||||
return BodyPollResult::Error(Error::new(
|
||||
ErrorKind::UnexpectedEof,
|
||||
format!(
|
||||
"unexpected EOF in TLS body (got {} of {} bytes)",
|
||||
buffer.len(),
|
||||
target_len
|
||||
),
|
||||
));
|
||||
}
|
||||
buffer.extend_from_slice(&scratch[..n]);
|
||||
}
|
||||
|
||||
BodyPollResult::Complete(buffer.split().freeze())
|
||||
|
||||
@@ -559,9 +559,7 @@ async fn mass_reconnect_sync_cleanup_prevents_temporary_reservation_bloat() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() {
|
||||
// Regression guard: concurrent cleanup draining must not produce false
|
||||
// limit denials for a new IP when the previous IP is already queued.
|
||||
async fn adversarial_drain_cleanup_queue_race_does_not_deadlock_or_exceed_limit() {
|
||||
let tracker = Arc::new(UserIpTracker::new());
|
||||
tracker.set_user_limit("racer", 1).await;
|
||||
let ip1 = ip_from_idx(1);
|
||||
@@ -573,7 +571,6 @@ async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections()
|
||||
// User disconnects from ip1, queuing it
|
||||
tracker.enqueue_cleanup("racer".to_string(), ip1);
|
||||
|
||||
let mut saw_false_rejection = false;
|
||||
for _ in 0..100 {
|
||||
// Queue cleanup then race explicit drain and check-and-add on the alternative IP.
|
||||
tracker.enqueue_cleanup("racer".to_string(), ip1);
|
||||
@@ -585,22 +582,21 @@ async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections()
|
||||
});
|
||||
let handle = tokio::spawn(async move { tracker_b.check_and_add("racer", ip2).await });
|
||||
|
||||
drain_handle.await.unwrap();
|
||||
let res = handle.await.unwrap();
|
||||
if res.is_err() {
|
||||
saw_false_rejection = true;
|
||||
break;
|
||||
}
|
||||
tokio::time::timeout(Duration::from_secs(1), drain_handle)
|
||||
.await
|
||||
.expect("cleanup drain must not deadlock")
|
||||
.unwrap();
|
||||
let _ = tokio::time::timeout(Duration::from_secs(1), handle)
|
||||
.await
|
||||
.expect("admission must not deadlock")
|
||||
.unwrap();
|
||||
|
||||
// Restore baseline for next iteration.
|
||||
assert!(tracker.get_active_ip_count("racer").await <= 1);
|
||||
tracker.drain_cleanup_queue().await;
|
||||
tracker.remove_ip("racer", ip2).await;
|
||||
tracker.remove_ip("racer", ip1).await;
|
||||
tracker.check_and_add("racer", ip1).await.unwrap();
|
||||
}
|
||||
|
||||
assert!(
|
||||
!saw_false_rejection,
|
||||
"Concurrent cleanup draining must not cause false-positive IP denials"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -348,11 +348,17 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_take_full_cert_budget_for_ip_zero_ttl_always_allows_full_payload() {
|
||||
let cache = TlsFrontCache::new(&["example.com".to_string()], 1024, "tlsfront-test-cache");
|
||||
let ip: IpAddr = "127.0.0.1".parse().expect("ip");
|
||||
let ttl = Duration::ZERO;
|
||||
|
||||
assert!(cache.take_full_cert_budget_for_ip(ip, ttl).await);
|
||||
assert!(cache.take_full_cert_budget_for_ip(ip, ttl).await);
|
||||
for idx in 0..100_000u32 {
|
||||
let ip = IpAddr::V4(std::net::Ipv4Addr::new(
|
||||
10,
|
||||
((idx >> 16) & 0xff) as u8,
|
||||
((idx >> 8) & 0xff) as u8,
|
||||
(idx & 0xff) as u8,
|
||||
));
|
||||
assert!(cache.take_full_cert_budget_for_ip(ip, ttl).await);
|
||||
}
|
||||
|
||||
assert!(cache.full_cert_sent.read().await.is_empty());
|
||||
}
|
||||
|
||||
@@ -618,13 +618,9 @@ impl MePool {
|
||||
me_route_hybrid_max_wait: Duration::from_millis(
|
||||
me_route_hybrid_max_wait_ms.max(50),
|
||||
),
|
||||
me_route_blocking_send_timeout: if me_route_blocking_send_timeout_ms == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(Duration::from_millis(
|
||||
me_route_blocking_send_timeout_ms.min(5_000),
|
||||
))
|
||||
},
|
||||
me_route_blocking_send_timeout: Some(Duration::from_millis(
|
||||
me_route_blocking_send_timeout_ms.clamp(1, 5_000),
|
||||
)),
|
||||
me_route_last_success_epoch_ms: AtomicU64::new(0),
|
||||
me_route_hybrid_timeout_warn_epoch_ms: AtomicU64::new(0),
|
||||
me_async_recovery_last_trigger_epoch_ms: AtomicU64::new(0),
|
||||
|
||||
Reference in New Issue
Block a user