diff --git a/src/error.rs b/src/error.rs index 889cbcd..9a839dd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -225,6 +225,9 @@ pub enum ProxyError { #[error("ME connection lost")] MiddleConnectionLost, + #[error("Session terminated")] + RouteSwitched, + // ============= Config Errors ============= #[error("Config error: {0}")] Config(String), diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index e7b5185..5d7d9b6 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -32,6 +32,7 @@ pub struct UserIpTracker { limit_mode: Arc>, limit_window: Arc>, last_compact_epoch_secs: Arc, + cleanup_queue_len: Arc, cleanup_queue: Arc>>, cleanup_drain_lock: Arc>, } @@ -72,6 +73,7 @@ impl UserIpTracker { limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)), limit_window: Arc::new(RwLock::new(Duration::from_secs(30))), last_compact_epoch_secs: Arc::new(AtomicU64::new(0)), + cleanup_queue_len: Arc::new(AtomicU64::new(0)), cleanup_queue: Arc::new(Mutex::new(HashMap::new())), cleanup_drain_lock: Arc::new(AsyncMutex::new(())), } @@ -120,6 +122,9 @@ impl UserIpTracker { match self.cleanup_queue.lock() { Ok(mut queue) => { let count = queue.entry((user, ip)).or_insert(0); + if *count == 0 { + self.cleanup_queue_len.fetch_add(1, Ordering::Relaxed); + } *count = count.saturating_add(1); self.cleanup_deferred_releases .fetch_add(1, Ordering::Relaxed); @@ -127,6 +132,9 @@ impl UserIpTracker { Err(poisoned) => { let mut queue = poisoned.into_inner(); let count = queue.entry((user.clone(), ip)).or_insert(0); + if *count == 0 { + self.cleanup_queue_len.fetch_add(1, Ordering::Relaxed); + } *count = count.saturating_add(1); self.cleanup_deferred_releases .fetch_add(1, Ordering::Relaxed); @@ -156,6 +164,9 @@ impl UserIpTracker { } pub(crate) async fn drain_cleanup_queue(&self) { + if self.cleanup_queue_len.load(Ordering::Relaxed) == 0 { + return; + } let Ok(_drain_guard) = self.cleanup_drain_lock.try_lock() else { return; }; @@ -173,6 +184,7 @@ impl UserIpTracker { break; }; if let Some(count) = queue.remove(&key) { + self.cleanup_queue_len.fetch_sub(1, Ordering::Relaxed); drained.insert(key, count); } } @@ -191,6 +203,7 @@ impl UserIpTracker { break; }; if let Some(count) = queue.remove(&key) { + self.cleanup_queue_len.fetch_sub(1, Ordering::Relaxed); drained.insert(key, count); } } @@ -294,12 +307,17 @@ impl UserIpTracker { } } + pub async fn run_periodic_maintenance(self: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(1)); + loop { + interval.tick().await; + self.drain_cleanup_queue().await; + self.maybe_compact_empty_users().await; + } + } + pub async fn memory_stats(&self) -> UserIpTrackerMemoryStats { - let cleanup_queue_len = self - .cleanup_queue - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .len(); + let cleanup_queue_len = self.cleanup_queue_len.load(Ordering::Relaxed) as usize; let active_ips = self.active_ips.read().await; let recent_ips = self.recent_ips.read().await; let active_entries = active_ips.values().map(HashMap::len).sum(); diff --git a/src/maestro/listeners.rs b/src/maestro/listeners.rs index 501a476..f24fa37 100644 --- a/src/maestro/listeners.rs +++ b/src/maestro/listeners.rs @@ -13,7 +13,7 @@ use crate::config::{ProxyConfig, RstOnCloseMode}; use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::proxy::ClientHandler; -use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController}; +use crate::proxy::route_mode::RouteRuntimeController; use crate::proxy::shared_state::ProxySharedState; use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker}; use crate::stats::beobachten::BeobachtenStore; @@ -498,7 +498,7 @@ pub(crate) fn spawn_tcp_accept_loops( ); let route_switched = matches!( &e, - crate::error::ProxyError::Proxy(msg) if msg == ROUTE_SWITCH_ERROR_MSG + crate::error::ProxyError::RouteSwitched ); match (peer_close_reason, me_closed) { diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index d5b5f6b..b0c3e4e 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -78,6 +78,11 @@ pub(crate) async fn spawn_runtime_tasks( stats_maintenance.run_periodic_user_stats_maintenance().await; }); + let ip_tracker_maintenance = ip_tracker.clone(); + tokio::spawn(async move { + ip_tracker_maintenance.run_periodic_maintenance().await; + }); + let detected_ip_v4: Option = probe.detected_ipv4.map(IpAddr::V4); let detected_ip_v6: Option = probe.detected_ipv6.map(IpAddr::V6); debug!( diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 6bd2101..efebcd9 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -18,8 +18,7 @@ use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce}; use crate::proxy::route_mode::{ - ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, - cutover_stagger_delay, + RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; use crate::proxy::shared_state::{ ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState, @@ -360,7 +359,7 @@ where "Cutover affected direct session, closing client connection" ); tokio::time::sleep(delay).await; - break Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); + break Err(ProxyError::RouteSwitched); } tokio::select! { result = &mut relay_result => { diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 9a89083..d67531a 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -23,8 +23,7 @@ use crate::error::{ProxyError, Result}; use crate::protocol::constants::{secure_padding_len, *}; use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::route_mode::{ - ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, - cutover_stagger_delay, + RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; use crate::proxy::shared_state::{ ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState, @@ -1188,7 +1187,7 @@ where tokio::time::sleep(delay).await; let _ = me_pool.send_close(conn_id).await; me_pool.registry().unregister(conn_id).await; - return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); + return Err(ProxyError::RouteSwitched); } // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) @@ -1690,7 +1689,7 @@ where stats.as_ref(), ) .await; - main_result = Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); + main_result = Err(ProxyError::RouteSwitched); break; } diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs index 5aa7e91..a3a5d6c 100644 --- a/src/proxy/route_mode.rs +++ b/src/proxy/route_mode.rs @@ -4,8 +4,6 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::watch; -pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated"; - #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[repr(u8)] pub(crate) enum RelayRouteMode { diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 2a60b57..e819e4f 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -661,7 +661,7 @@ async fn integration_route_cutover_and_quota_overlap_fails_closed_and_releases_s assert!( matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(relay_result, Err(ProxyError::Proxy(ref msg)) if msg == crate::proxy::route_mode::ROUTE_SWITCH_ERROR_MSG), + || matches!(relay_result, Err(ProxyError::RouteSwitched)), "overlap race must fail closed via quota enforcement or generic cutover termination" ); diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index 9554752..f7ffd0d 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -1491,7 +1491,7 @@ async fn direct_relay_cutover_midflight_releases_route_gauge() { assert!( matches!( relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + Err(ProxyError::RouteSwitched) ), "client-visible cutover error must stay generic and avoid route-internal metadata" ); @@ -1631,7 +1631,7 @@ async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea assert!( matches!( relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + Err(ProxyError::RouteSwitched) ), "storm-cutover termination must remain generic for all direct sessions" ); @@ -1937,7 +1937,7 @@ async fn adversarial_direct_relay_cutover_integrity() { assert!( matches!( result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + Err(ProxyError::RouteSwitched) ), "Session must terminate with route switch error on cutover" );