From 512bee6a8d31678a140c4034327ecfbb495da064 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Fri, 20 Mar 2026 16:43:50 +0400 Subject: [PATCH] Add security tests for middle relay idle policy and enhance stats tracking - Introduced a new test module for middle relay idle policy security tests, covering various scenarios including soft mark, hard close, and grace periods. - Implemented functions to create crypto readers and encrypt data for testing. - Enhanced the Stats struct to include counters for relay idle soft marks, hard closes, pressure evictions, and protocol desync closes. - Added corresponding increment and retrieval methods for the new stats fields. --- src/config/defaults.rs | 16 + src/config/load.rs | 40 + src/config/load_idle_policy_tests.rs | 78 ++ src/config/types.rs | 23 + src/metrics.rs | 75 ++ src/proxy/middle_relay.rs | 518 +++++++++++- ...middle_relay_idle_policy_security_tests.rs | 799 ++++++++++++++++++ src/stats/mod.rs | 40 + 8 files changed, 1571 insertions(+), 18 deletions(-) create mode 100644 src/config/load_idle_policy_tests.rs create mode 100644 src/proxy/middle_relay_idle_policy_security_tests.rs diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 1495dee..9f5da5f 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -91,6 +91,22 @@ pub(crate) fn default_handshake_timeout() -> u64 { 30 } +pub(crate) fn default_relay_idle_policy_v2_enabled() -> bool { + true +} + +pub(crate) fn default_relay_client_idle_soft_secs() -> u64 { + 120 +} + +pub(crate) fn default_relay_client_idle_hard_secs() -> u64 { + 360 +} + +pub(crate) fn default_relay_idle_grace_after_downstream_activity_secs() -> u64 { + 30 +} + pub(crate) fn default_connect_timeout() -> u64 { 10 } diff --git a/src/config/load.rs b/src/config/load.rs index 14799ed..b461434 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -328,6 +328,42 @@ impl ProxyConfig { )); } + if config.timeouts.client_handshake == 0 { + return Err(ProxyError::Config( + 
"timeouts.client_handshake must be > 0".to_string(), + )); + } + + if config.timeouts.relay_client_idle_soft_secs == 0 { + return Err(ProxyError::Config( + "timeouts.relay_client_idle_soft_secs must be > 0".to_string(), + )); + } + + if config.timeouts.relay_client_idle_hard_secs == 0 { + return Err(ProxyError::Config( + "timeouts.relay_client_idle_hard_secs must be > 0".to_string(), + )); + } + + if config.timeouts.relay_client_idle_hard_secs + < config.timeouts.relay_client_idle_soft_secs + { + return Err(ProxyError::Config( + "timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs" + .to_string(), + )); + } + + if config.timeouts.relay_idle_grace_after_downstream_activity_secs + > config.timeouts.relay_client_idle_hard_secs + { + return Err(ProxyError::Config( + "timeouts.relay_idle_grace_after_downstream_activity_secs must be <= timeouts.relay_client_idle_hard_secs" + .to_string(), + )); + } + if config.general.me_writer_cmd_channel_capacity == 0 { return Err(ProxyError::Config( "general.me_writer_cmd_channel_capacity must be > 0".to_string(), @@ -934,6 +970,10 @@ impl ProxyConfig { } } +#[cfg(test)] +#[path = "load_idle_policy_tests.rs"] +mod load_idle_policy_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/src/config/load_idle_policy_tests.rs b/src/config/load_idle_policy_tests.rs new file mode 100644 index 0000000..087fd75 --- /dev/null +++ b/src/config/load_idle_policy_tests.rs @@ -0,0 +1,78 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-idle-policy-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] 
+fn load_rejects_relay_hard_idle_smaller_than_soft_idle_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +relay_client_idle_soft_secs = 120 +relay_client_idle_hard_secs = 60 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with hard= timeouts.relay_client_idle_soft_secs"), + "error must explain the violated hard>=soft invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_relay_grace_larger_than_hard_idle_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +relay_client_idle_soft_secs = 60 +relay_client_idle_hard_secs = 120 +relay_idle_grace_after_downstream_activity_secs = 121 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with grace>hard must fail"); + let msg = err.to_string(); + assert!( + msg.contains("timeouts.relay_idle_grace_after_downstream_activity_secs must be <= timeouts.relay_client_idle_hard_secs"), + "error must explain the violated grace<=hard invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_zero_handshake_timeout_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +client_handshake = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with zero handshake timeout must fail"); + let msg = err.to_string(); + assert!( + msg.contains("timeouts.client_handshake must be > 0"), + "error must explain that handshake timeout must be positive, got: {msg}" + ); + + remove_temp_config(&path); +} diff --git a/src/config/types.rs b/src/config/types.rs index 965603e..3468e9a 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1276,6 +1276,24 @@ pub struct TimeoutsConfig { #[serde(default = "default_handshake_timeout")] pub client_handshake: u64, + /// Enables soft/hard relay client idle policy for middle-relay sessions. 
+ #[serde(default = "default_relay_idle_policy_v2_enabled")] + pub relay_idle_policy_v2_enabled: bool, + + /// Soft idle threshold for middle-relay client uplink activity in seconds. + /// Hitting this threshold marks the session as idle-candidate, but does not close it. + #[serde(default = "default_relay_client_idle_soft_secs")] + pub relay_client_idle_soft_secs: u64, + + /// Hard idle threshold for middle-relay client uplink activity in seconds. + /// Hitting this threshold closes the session. + #[serde(default = "default_relay_client_idle_hard_secs")] + pub relay_client_idle_hard_secs: u64, + + /// Additional grace in seconds added to hard idle window after recent downstream activity. + #[serde(default = "default_relay_idle_grace_after_downstream_activity_secs")] + pub relay_idle_grace_after_downstream_activity_secs: u64, + #[serde(default = "default_connect_timeout")] pub tg_connect: u64, @@ -1298,6 +1316,11 @@ impl Default for TimeoutsConfig { fn default() -> Self { Self { client_handshake: default_handshake_timeout(), + relay_idle_policy_v2_enabled: default_relay_idle_policy_v2_enabled(), + relay_client_idle_soft_secs: default_relay_client_idle_soft_secs(), + relay_client_idle_hard_secs: default_relay_client_idle_hard_secs(), + relay_idle_grace_after_downstream_activity_secs: + default_relay_idle_grace_after_downstream_activity_secs(), tg_connect: default_connect_timeout(), client_keepalive: default_keepalive(), client_ack: default_ack_timeout(), diff --git a/src/metrics.rs b/src/metrics.rs index f4f8a2e..b7a16f0 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -705,6 +705,69 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!( + out, + "# HELP telemt_relay_idle_soft_mark_total Middle-relay sessions marked as soft-idle candidates" + ); + let _ = writeln!(out, "# TYPE telemt_relay_idle_soft_mark_total counter"); + let _ = writeln!( + out, + "telemt_relay_idle_soft_mark_total {}", + if me_allows_normal 
{ + stats.get_relay_idle_soft_mark_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_idle_hard_close_total Middle-relay sessions closed by hard-idle policy" + ); + let _ = writeln!(out, "# TYPE telemt_relay_idle_hard_close_total counter"); + let _ = writeln!( + out, + "telemt_relay_idle_hard_close_total {}", + if me_allows_normal { + stats.get_relay_idle_hard_close_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_pressure_evict_total Middle-relay sessions evicted under resource pressure" + ); + let _ = writeln!(out, "# TYPE telemt_relay_pressure_evict_total counter"); + let _ = writeln!( + out, + "telemt_relay_pressure_evict_total {}", + if me_allows_normal { + stats.get_relay_pressure_evict_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_protocol_desync_close_total Middle-relay sessions closed due to protocol desync" + ); + let _ = writeln!( + out, + "# TYPE telemt_relay_protocol_desync_close_total counter" + ); + let _ = writeln!( + out, + "telemt_relay_protocol_desync_close_total {}", + if me_allows_normal { + stats.get_relay_protocol_desync_close_total() + } else { + 0 + } + ); + let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches"); let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter"); let _ = writeln!( @@ -1879,6 +1942,10 @@ mod tests { stats.increment_me_rpc_proxy_req_signal_response_total(); stats.increment_me_rpc_proxy_req_signal_close_sent_total(); stats.increment_me_idle_close_by_peer_total(); + stats.increment_relay_idle_soft_mark_total(); + stats.increment_relay_idle_hard_close_total(); + stats.increment_relay_pressure_evict_total(); + stats.increment_relay_protocol_desync_close_total(); stats.increment_user_connects("alice"); stats.increment_user_curr_connects("alice"); stats.add_user_octets_from("alice", 1024); @@ -1917,6 +1984,10 @@ mod tests { 
assert!(output.contains("telemt_me_rpc_proxy_req_signal_response_total 1")); assert!(output.contains("telemt_me_rpc_proxy_req_signal_close_sent_total 1")); assert!(output.contains("telemt_me_idle_close_by_peer_total 1")); + assert!(output.contains("telemt_relay_idle_soft_mark_total 1")); + assert!(output.contains("telemt_relay_idle_hard_close_total 1")); + assert!(output.contains("telemt_relay_pressure_evict_total 1")); + assert!(output.contains("telemt_relay_protocol_desync_close_total 1")); assert!(output.contains("telemt_user_connections_total{user=\"alice\"} 1")); assert!(output.contains("telemt_user_connections_current{user=\"alice\"} 1")); assert!(output.contains("telemt_user_octets_from_client{user=\"alice\"} 1024")); @@ -1974,6 +2045,10 @@ mod tests { assert!(output.contains("# TYPE telemt_upstream_connect_attempt_total counter")); assert!(output.contains("# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter")); assert!(output.contains("# TYPE telemt_me_idle_close_by_peer_total counter")); + assert!(output.contains("# TYPE telemt_relay_idle_soft_mark_total counter")); + assert!(output.contains("# TYPE telemt_relay_idle_hard_close_total counter")); + assert!(output.contains("# TYPE telemt_relay_pressure_evict_total counter")); + assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter")); assert!(output.contains("# TYPE telemt_me_writer_removed_total counter")); assert!(output.contains( "# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge" diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 7298cb4..c73944f 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,4 +1,5 @@ use std::collections::hash_map::RandomState; +use std::collections::{BTreeSet, HashMap}; use std::hash::BuildHasher; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; @@ -10,7 +11,7 @@ use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use 
tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex}; use tokio::time::timeout; -use tracing::{debug, trace, warn}; +use tracing::{debug, info, trace, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; @@ -38,6 +39,7 @@ const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; +const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); #[cfg(test)] const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); #[cfg(not(test))] @@ -53,6 +55,8 @@ static DESYNC_HASHER: OnceLock = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); static DESYNC_DEDUP_EVER_SATURATED: OnceLock = OnceLock::new(); static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); +static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new(); +static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); struct RelayForensicsState { trace_id: u64, @@ -66,6 +70,140 @@ struct RelayForensicsState { desync_all_full: bool, } +#[derive(Default)] +struct RelayIdleCandidateRegistry { + by_conn_id: HashMap, + ordered: BTreeSet<(u64, u64)>, + pressure_event_seq: u64, + pressure_consumed_seq: u64, +} + +#[derive(Clone, Copy)] +struct RelayIdleCandidateMeta { + mark_order_seq: u64, + mark_pressure_seq: u64, +} + +fn relay_idle_candidate_registry() -> &'static Mutex { + RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) +} + +fn mark_relay_idle_candidate(conn_id: u64) -> bool { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return false; + }; + + if guard.by_conn_id.contains_key(&conn_id) { + return false; + } + + let mark_order_seq = RELAY_IDLE_MARK_SEQ + .fetch_add(1, Ordering::Relaxed) + .saturating_add(1); + let meta = RelayIdleCandidateMeta { + mark_order_seq, + mark_pressure_seq: guard.pressure_event_seq, + }; + 
guard.by_conn_id.insert(conn_id, meta); + guard.ordered.insert((meta.mark_order_seq, conn_id)); + true +} + +fn clear_relay_idle_candidate(conn_id: u64) { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return; + }; + + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } +} + +#[cfg(test)] +fn oldest_relay_idle_candidate() -> Option { + let Ok(guard) = relay_idle_candidate_registry().lock() else { + return None; + }; + guard.ordered.iter().next().map(|(_, conn_id)| *conn_id) +} + +fn note_relay_pressure_event() { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return; + }; + guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1); +} + +fn relay_pressure_event_seq() -> u64 { + let Ok(guard) = relay_idle_candidate_registry().lock() else { + return 0; + }; + guard.pressure_event_seq +} + +fn maybe_evict_idle_candidate_on_pressure( + conn_id: u64, + seen_pressure_seq: &mut u64, + stats: &Stats, +) -> bool { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return false; + }; + + let latest_pressure_seq = guard.pressure_event_seq; + if latest_pressure_seq == *seen_pressure_seq { + return false; + } + *seen_pressure_seq = latest_pressure_seq; + + if latest_pressure_seq == guard.pressure_consumed_seq { + return false; + } + + if guard.ordered.is_empty() { + guard.pressure_consumed_seq = latest_pressure_seq; + return false; + } + + let oldest = guard + .ordered + .iter() + .next() + .map(|(_, candidate_conn_id)| *candidate_conn_id); + if oldest != Some(conn_id) { + return false; + } + + let Some(candidate_meta) = guard.by_conn_id.get(&conn_id).copied() else { + return false; + }; + + // Pressure events that happened before candidate soft-mark are stale for this candidate. 
+ if latest_pressure_seq == candidate_meta.mark_pressure_seq { + return false; + } + + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } + guard.pressure_consumed_seq = latest_pressure_seq; + stats.increment_relay_pressure_evict_total(); + true +} + +#[cfg(test)] +fn clear_relay_idle_pressure_state_for_testing() { + if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get() + && let Ok(mut guard) = registry.lock() + { + guard.by_conn_id.clear(); + guard.ordered.clear(); + guard.pressure_event_seq = 0; + guard.pressure_consumed_seq = 0; + } + RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed); +} + #[derive(Clone, Copy)] struct MeD2cFlushPolicy { max_frames: usize, @@ -74,6 +212,61 @@ struct MeD2cFlushPolicy { ack_flush_immediate: bool, } +#[derive(Clone, Copy)] +struct RelayClientIdlePolicy { + enabled: bool, + soft_idle: Duration, + hard_idle: Duration, + grace_after_downstream_activity: Duration, + legacy_frame_read_timeout: Duration, +} + +impl RelayClientIdlePolicy { + fn from_config(config: &ProxyConfig) -> Self { + Self { + enabled: config.timeouts.relay_idle_policy_v2_enabled, + soft_idle: Duration::from_secs(config.timeouts.relay_client_idle_soft_secs.max(1)), + hard_idle: Duration::from_secs(config.timeouts.relay_client_idle_hard_secs.max(1)), + grace_after_downstream_activity: Duration::from_secs( + config + .timeouts + .relay_idle_grace_after_downstream_activity_secs, + ), + legacy_frame_read_timeout: Duration::from_secs(config.timeouts.client_handshake.max(1)), + } + } + + #[cfg(test)] + fn disabled(frame_read_timeout: Duration) -> Self { + Self { + enabled: false, + soft_idle: Duration::from_secs(0), + hard_idle: Duration::from_secs(0), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: frame_read_timeout, + } + } +} + +struct RelayClientIdleState { + last_client_frame_at: Instant, + soft_idle_marked: bool, +} + +impl RelayClientIdleState { + fn 
new(now: Instant) -> Self { + Self { + last_client_frame_at: now, + soft_idle_marked: false, + } + } + + fn on_client_frame(&mut self, now: Instant) { + self.last_client_frame_at = now; + self.soft_idle_marked = false; + } +} + impl MeD2cFlushPolicy { fn from_config(config: &ProxyConfig) -> Self { Self { @@ -251,6 +444,7 @@ fn report_desync_frame_too_large( let bytes_me2c = state.bytes_me2c.load(Ordering::Relaxed); stats.increment_desync_total(); + stats.increment_relay_protocol_desync_close_total(); stats.observe_desync_frames_ok(frame_counter); if emit_full { stats.increment_desync_full_logged(); @@ -366,6 +560,7 @@ async fn enqueue_c2me_command( Ok(()) => Ok(()), Err(mpsc::error::TrySendError::Closed(cmd)) => Err(mpsc::error::SendError(cmd)), Err(mpsc::error::TrySendError::Full(cmd)) => { + note_relay_pressure_event(); // Cooperative yield reduces burst catch-up when the per-conn queue is near saturation. if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS { tokio::task::yield_now().await; @@ -483,6 +678,10 @@ where let translated_local_addr = me_pool.translate_our_addr(local_addr); let frame_limit = config.general.max_client_frame; + let relay_idle_policy = RelayClientIdlePolicy::from_config(&config); + let session_started_at = forensics.started_at; + let mut relay_idle_state = RelayClientIdleState::new(session_started_at); + let last_downstream_activity_ms = Arc::new(AtomicU64::new(0)); let c2me_channel_capacity = config .general @@ -525,6 +724,7 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); + let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); let me_writer = tokio::spawn(async move { @@ -542,6 +742,8 @@ where let mut batch_bytes = 0usize; let mut flush_immediately; + let first_is_downstream_activity = + matches!(&first, MeResponse::Data { .. 
} | MeResponse::Ack(_)); match process_me_writer_response( first, &mut writer, @@ -557,6 +759,10 @@ where false, ).await? { MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if first_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately = immediate; @@ -575,6 +781,8 @@ where break; }; + let next_is_downstream_activity = + matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( next, &mut writer, @@ -590,6 +798,10 @@ where true, ).await? { MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if next_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; @@ -608,6 +820,8 @@ where { match tokio::time::timeout(d2c_flush_policy.max_delay, me_rx_task.recv()).await { Ok(Some(next)) => { + let next_is_downstream_activity = + matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( next, &mut writer, @@ -623,6 +837,10 @@ where true, ).await? { MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if next_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; @@ -641,6 +859,8 @@ where break; }; + let extra_is_downstream_activity = + matches!(&extra, MeResponse::Data { .. 
} | MeResponse::Ack(_)); match process_me_writer_response( extra, &mut writer, @@ -656,6 +876,10 @@ where true, ).await? { MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if extra_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; @@ -689,7 +913,24 @@ where let mut client_closed = false; let mut frame_counter: u64 = 0; let mut route_watch_open = true; + let mut seen_pressure_seq = relay_pressure_event_seq(); loop { + if relay_idle_policy.enabled + && maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen_pressure_seq, stats.as_ref()) + { + info!( + conn_id, + trace_id = format_args!("0x{:016x}", trace_id), + user = %user, + "Middle-relay pressure eviction for idle-candidate session" + ); + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; + main_result = Err(ProxyError::Proxy( + "middle-relay session evicted under pressure (idle-candidate)".to_string(), + )); + break; + } + if let Some(cutover) = affected_cutover_state( &route_rx, RelayRouteMode::Middle, @@ -715,15 +956,18 @@ where route_watch_open = false; } } - payload_result = read_client_payload( + payload_result = read_client_payload_with_idle_policy( &mut crypto_reader, proto_tag, frame_limit, - Duration::from_secs(config.timeouts.client_handshake.max(1)), &buffer_pool, &forensics, &mut frame_counter, &stats, + &relay_idle_policy, + &mut relay_idle_state, + last_downstream_activity_ms.as_ref(), + session_started_at, ) => { match payload_result { Ok(Some((payload, quickack))) => { @@ -812,46 +1056,181 @@ where frames_ok = frame_counter, "ME relay cleanup" ); + clear_relay_idle_candidate(conn_id); me_pool.registry().unregister(conn_id).await; result } -async fn read_client_payload( +async fn read_client_payload_with_idle_policy( 
client_reader: &mut CryptoReader, proto_tag: ProtoTag, max_frame: usize, - frame_read_timeout: Duration, buffer_pool: &Arc, forensics: &RelayForensicsState, frame_counter: &mut u64, stats: &Stats, + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, ) -> Result> where R: AsyncRead + Unpin + Send + 'static, { - async fn read_exact_with_timeout( + async fn read_exact_with_policy( client_reader: &mut CryptoReader, buf: &mut [u8], - frame_read_timeout: Duration, + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, + forensics: &RelayForensicsState, + stats: &Stats, + read_label: &'static str, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, { - match timeout(frame_read_timeout, client_reader.read_exact(buf)).await { - Ok(Ok(_)) => Ok(()), - Ok(Err(e)) => Err(ProxyError::Io(e)), - Err(_) => Err(ProxyError::Io(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "middle-relay client frame read timeout", - ))), + fn hard_deadline( + idle_policy: &RelayClientIdlePolicy, + idle_state: &RelayClientIdleState, + session_started_at: Instant, + last_downstream_activity_ms: u64, + ) -> Instant { + let mut deadline = idle_state.last_client_frame_at + idle_policy.hard_idle; + if idle_policy.grace_after_downstream_activity.is_zero() { + return deadline; + } + + let downstream_at = session_started_at + Duration::from_millis(last_downstream_activity_ms); + if downstream_at > idle_state.last_client_frame_at { + let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity; + if grace_deadline > deadline { + deadline = grace_deadline; + } + } + deadline } + + let mut filled = 0usize; + while filled < buf.len() { + let timeout_window = if idle_policy.enabled { + let now = Instant::now(); + let downstream_ms = 
last_downstream_activity_ms.load(Ordering::Relaxed); + let hard_deadline = hard_deadline( + idle_policy, + idle_state, + session_started_at, + downstream_ms, + ); + if now >= hard_deadline { + clear_relay_idle_candidate(forensics.conn_id); + stats.increment_relay_idle_hard_close_total(); + let client_idle_secs = now + .saturating_duration_since(idle_state.last_client_frame_at) + .as_secs(); + let downstream_idle_secs = now + .saturating_duration_since(session_started_at + Duration::from_millis(downstream_ms)) + .as_secs(); + warn!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + client_idle_secs, + downstream_idle_secs, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay hard idle close" + ); + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!( + "middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}", + idle_policy.soft_idle.as_secs(), + idle_policy.hard_idle.as_secs(), + idle_policy.grace_after_downstream_activity.as_secs(), + ), + ))); + } + + if !idle_state.soft_idle_marked + && now.saturating_duration_since(idle_state.last_client_frame_at) + >= idle_policy.soft_idle + { + idle_state.soft_idle_marked = true; + if mark_relay_idle_candidate(forensics.conn_id) { + stats.increment_relay_idle_soft_mark_total(); + } + info!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay soft idle mark" + ); + } + + let 
soft_deadline = idle_state.last_client_frame_at + idle_policy.soft_idle; + let next_deadline = if idle_state.soft_idle_marked { + hard_deadline + } else { + soft_deadline.min(hard_deadline) + }; + let mut remaining = next_deadline.saturating_duration_since(now); + if remaining.is_zero() { + remaining = Duration::from_millis(1); + } + remaining.min(RELAY_IDLE_IO_POLL_MAX) + } else { + idle_policy.legacy_frame_read_timeout + }; + + let read_result = timeout(timeout_window, client_reader.read(&mut buf[filled..])).await; + match read_result { + Ok(Ok(0)) => { + return Err(ProxyError::Io(std::io::Error::from( + std::io::ErrorKind::UnexpectedEof, + ))); + } + Ok(Ok(n)) => { + filled = filled.saturating_add(n); + } + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) if !idle_policy.enabled => { + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!("middle-relay client frame read timeout while reading {read_label}"), + ))); + } + Err(_) => {} + } + } + + Ok(()) } loop { let (len, quickack, raw_len_bytes) = match proto_tag { ProtoTag::Abridged => { let mut first = [0u8; 1]; - match read_exact_with_timeout(client_reader, &mut first, frame_read_timeout).await { + match read_exact_with_policy( + client_reader, + &mut first, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "abridged.first_len_byte", + ) + .await + { Ok(()) => {} Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { return Ok(None); @@ -862,7 +1241,18 @@ where let quickack = (first[0] & 0x80) != 0; let len_words = if (first[0] & 0x7f) == 0x7f { let mut ext = [0u8; 3]; - read_exact_with_timeout(client_reader, &mut ext, frame_read_timeout).await?; + read_exact_with_policy( + client_reader, + &mut ext, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "abridged.extended_len", + ) + .await?; u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as 
usize } else { (first[0] & 0x7f) as usize @@ -875,7 +1265,19 @@ where } ProtoTag::Intermediate | ProtoTag::Secure => { let mut len_buf = [0u8; 4]; - match read_exact_with_timeout(client_reader, &mut len_buf, frame_read_timeout).await { + match read_exact_with_policy( + client_reader, + &mut len_buf, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "len_prefix", + ) + .await + { Ok(()) => {} Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { return Ok(None); @@ -903,6 +1305,7 @@ where proto = ?proto_tag, "Frame too small — corrupt or probe" ); + stats.increment_relay_protocol_desync_close_total(); return Err(ProxyError::Proxy(format!("Frame too small: {len}"))); } @@ -923,6 +1326,7 @@ where Some(payload_len) => payload_len, None => { stats.increment_secure_padding_invalid(); + stats.increment_relay_protocol_desync_close_total(); return Err(ProxyError::Proxy(format!( "Invalid secure frame length: {len}" ))); @@ -939,17 +1343,91 @@ where payload.reserve(len - current_cap); } payload.resize(len, 0); - read_exact_with_timeout(client_reader, &mut payload[..len], frame_read_timeout).await?; + read_exact_with_policy( + client_reader, + &mut payload[..len], + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "payload", + ) + .await?; // Secure Intermediate: strip validated trailing padding bytes. 
if proto_tag == ProtoTag::Secure { payload.truncate(secure_payload_len); } *frame_counter += 1; + idle_state.on_client_frame(Instant::now()); + clear_relay_idle_candidate(forensics.conn_id); return Ok(Some((payload, quickack))); } } +#[cfg(test)] +async fn read_client_payload_legacy( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + let now = Instant::now(); + let mut idle_state = RelayClientIdleState::new(now); + let last_downstream_activity_ms = AtomicU64::new(0); + let idle_policy = RelayClientIdlePolicy::disabled(frame_read_timeout); + read_client_payload_with_idle_policy( + client_reader, + proto_tag, + max_frame, + buffer_pool, + forensics, + frame_counter, + stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + now, + ) + .await +} + +#[cfg(test)] +async fn read_client_payload( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + read_client_payload_legacy( + client_reader, + proto_tag, + max_frame, + frame_read_timeout, + buffer_pool, + forensics, + frame_counter, + stats, + ) + .await +} + enum MeWriterResponseOutcome { Continue { frames: usize, @@ -1171,3 +1649,7 @@ where #[cfg(test)] #[path = "middle_relay_security_tests.rs"] mod security_tests; + +#[cfg(test)] +#[path = "middle_relay_idle_policy_security_tests.rs"] +mod idle_policy_security_tests; diff --git a/src/proxy/middle_relay_idle_policy_security_tests.rs b/src/proxy/middle_relay_idle_policy_security_tests.rs new file mode 100644 index 0000000..0efc904 --- /dev/null +++ b/src/proxy/middle_relay_idle_policy_security_tests.rs @@ 
-0,0 +1,799 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::atomic::AtomicU64; +use tokio::io::AsyncWriteExt; +use tokio::io::duplex; +use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xA000_0000 + conn_id, + conn_id, + user: format!("idle-test-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_millis(soft_ms), + hard_idle: Duration::from_millis(hard_ms), + grace_after_downstream_activity: Duration::from_millis(grace_ms), + legacy_frame_read_timeout: Duration::from_millis(hard_ms), + } +} + +fn idle_pressure_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> { + match idle_pressure_test_lock().lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + } +} + +#[tokio::test] +async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() { + let (reader, _writer) = duplex(1024); + 
let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(40, 120, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let start = TokioInstant::now(); + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("idle test must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + let err_text = match result { + Err(ProxyError::Io(ref e)) => e.to_string(), + _ => String::new(), + }; + assert!( + err_text.contains("middle-relay hard idle timeout"), + "hard close must expose a clear timeout reason" + ); + assert!( + start.elapsed() >= TokioDuration::from_millis(80), + "hard timeout must not trigger before idle deadline window" + ); + assert_eq!(stats.get_relay_idle_soft_mark_total(), 1); + assert_eq!(stats.get_relay_idle_hard_close_total(), 1); +} + +#[tokio::test] +async fn idle_policy_downstream_activity_grace_extends_hard_deadline() { + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(30, 60, 100); + let last_downstream_activity_ms = AtomicU64::new(20); + + let start = 
TokioInstant::now(); + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("grace test must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert!( + start.elapsed() >= TokioDuration::from_millis(100), + "recent downstream activity must extend hard idle deadline" + ); +} + +#[tokio::test] +async fn relay_idle_policy_disabled_keeps_legacy_timeout_behavior() { + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics(3, Instant::now()); + let mut frame_counter = 0u64; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + Duration::from_millis(60), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + let err_text = match result { + Err(ProxyError::Io(ref e)) => e.to_string(), + _ => String::new(), + }; + assert!( + err_text.contains("middle-relay client frame read timeout"), + "legacy mode must keep expected timeout reason" + ); + assert_eq!(stats.get_relay_idle_soft_mark_total(), 0); + assert_eq!(stats.get_relay_idle_hard_close_total(), 0); +} + +#[tokio::test] +async fn adversarial_partial_frame_trickle_cannot_bypass_hard_idle_close() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(4, session_started_at); + let mut 
frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(30, 90, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(12); + plaintext.extend_from_slice(&8u32.to_le_bytes()); + plaintext.extend_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]); + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted[..1]) + .await + .expect("must write a single trickle byte"); + + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("partial frame trickle test must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert_eq!(frame_counter, 0, "partial trickle must not count as a valid frame"); +} + +#[tokio::test] +async fn successful_client_frame_resets_soft_idle_mark() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(5, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + idle_state.soft_idle_marked = true; + let idle_policy = make_idle_policy(200, 300, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let payload = [9u8, 8, 7, 6, 5, 4, 3, 2]; + let mut plaintext = Vec::with_capacity(4 + payload.len()); + plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("must write 
full encrypted frame"); + + let read = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + .expect("frame read must succeed") + .expect("frame must be returned"); + + assert_eq!(read.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); + assert!( + !idle_state.soft_idle_marked, + "a valid client frame must clear soft-idle mark" + ); +} + +#[tokio::test] +async fn protocol_desync_small_frame_updates_reason_counter() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics(6, Instant::now()); + let mut frame_counter = 0u64; + + let mut plaintext = Vec::with_capacity(7); + plaintext.extend_from_slice(&3u32.to_le_bytes()); + plaintext.extend_from_slice(&[1u8, 2, 3]); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.expect("must write frame"); + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small"))); + assert_eq!(stats.get_relay_protocol_desync_close_total(), 1); +} + +#[tokio::test] +async fn stress_many_idle_sessions_fail_closed_without_hang() { + let mut tasks = Vec::with_capacity(24); + + for idx in 0..24u64 { + tasks.push(tokio::spawn(async move { + let (reader, _writer) = duplex(256); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(100 + idx, session_started_at); + let mut 
frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(20, 50, 10); + let last_downstream_activity_ms = AtomicU64::new(0); + + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("stress task must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert_eq!(stats.get_relay_idle_hard_close_total(), 1); + assert_eq!(frame_counter, 0); + })); + } + + for task in tasks { + task.await.expect("stress task must not panic"); + } +} + +#[test] +fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(10)); + assert!(mark_relay_idle_candidate(11)); + assert_eq!(oldest_relay_idle_candidate(), Some(10)); + + note_relay_pressure_event(); + + let mut seen_for_newer = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(11, &mut seen_for_newer, &stats), + "newer idle candidate must not be evicted while older candidate exists" + ); + assert_eq!(oldest_relay_idle_candidate(), Some(10)); + + let mut seen_for_oldest = 0u64; + assert!( + maybe_evict_idle_candidate_on_pressure(10, &mut seen_for_oldest, &stats), + "oldest idle candidate must be evicted first under pressure" + ); + assert_eq!(oldest_relay_idle_candidate(), Some(11)); + assert_eq!(stats.get_relay_pressure_evict_total(), 1); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn pressure_does_not_evict_without_new_pressure_signal() { + let _guard = acquire_idle_pressure_test_lock(); + 
clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(21)); + let mut seen = relay_pressure_event_seq(); + + assert!( + !maybe_evict_idle_candidate_on_pressure(21, &mut seen, &stats), + "without new pressure signal, candidate must stay" + ); + assert_eq!(stats.get_relay_pressure_evict_total(), 0); + assert_eq!(oldest_relay_idle_candidate(), Some(21)); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + let mut seen_per_conn = std::collections::HashMap::new(); + for conn_id in 1000u64..1064u64 { + assert!(mark_relay_idle_candidate(conn_id)); + seen_per_conn.insert(conn_id, 0u64); + } + + for expected in 1000u64..1064u64 { + note_relay_pressure_event(); + + let mut seen = *seen_per_conn + .get(&expected) + .expect("per-conn pressure cursor must exist"); + assert!( + maybe_evict_idle_candidate_on_pressure(expected, &mut seen, &stats), + "expected conn_id {expected} must be evicted next by deterministic FIFO ordering" + ); + seen_per_conn.insert(expected, seen); + + let next = if expected == 1063 { + None + } else { + Some(expected + 1) + }; + assert_eq!(oldest_relay_idle_candidate(), next); + } + + assert_eq!(stats.get_relay_pressure_evict_total(), 64); + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(301)); + assert!(mark_relay_idle_candidate(302)); + assert!(mark_relay_idle_candidate(303)); + + let mut seen_301 = 0u64; + let mut seen_302 = 0u64; + let mut seen_303 = 0u64; + + // Single pressure event should authorize at most one eviction 
globally. + note_relay_pressure_event(); + + let evicted_301 = maybe_evict_idle_candidate_on_pressure(301, &mut seen_301, &stats); + let evicted_302 = maybe_evict_idle_candidate_on_pressure(302, &mut seen_302, &stats); + let evicted_303 = maybe_evict_idle_candidate_on_pressure(303, &mut seen_303, &stats); + + let evicted_total = [evicted_301, evicted_302, evicted_303] + .iter() + .filter(|value| **value) + .count(); + + assert_eq!( + evicted_total, 1, + "single pressure event must not cascade-evict multiple idle candidates" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(401)); + assert!(mark_relay_idle_candidate(402)); + + let mut seen_oldest = 0u64; + let mut seen_next = 0u64; + + note_relay_pressure_event(); + + assert!( + maybe_evict_idle_candidate_on_pressure(401, &mut seen_oldest, &stats), + "oldest candidate must consume pressure budget first" + ); + + assert!( + !maybe_evict_idle_candidate_on_pressure(402, &mut seen_next, &stats), + "next candidate must not consume the same pressure budget" + ); + + assert_eq!( + stats.get_relay_pressure_evict_total(), + 1, + "single pressure budget must produce exactly one eviction" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + // Pressure happened before any idle candidate existed. 
+ note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(501)); + + let mut seen = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(501, &mut seen, &stats), + "stale pressure (before soft-idle mark) must not evict newly marked candidate" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(511)); + assert!(mark_relay_idle_candidate(512)); + assert!(mark_relay_idle_candidate(513)); + + let mut seen_511 = 0u64; + let mut seen_512 = 0u64; + let mut seen_513 = 0u64; + + let evicted = [ + maybe_evict_idle_candidate_on_pressure(511, &mut seen_511, &stats), + maybe_evict_idle_candidate_on_pressure(512, &mut seen_512, &stats), + maybe_evict_idle_candidate_on_pressure(513, &mut seen_513, &stats), + ] + .iter() + .filter(|value| **value) + .count(); + + assert_eq!( + evicted, 0, + "stale pressure event must not evict any candidate from a newly marked batch" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + + // Session A observed pressure while there were no candidates. + let mut seen_a = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(999_001, &mut seen_a, &stats), + "no candidate existed, so no eviction is possible" + ); + + // Candidate appears later; Session B must not be able to consume stale pressure. 
+ assert!(mark_relay_idle_candidate(521)); + let mut seen_b = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(521, &mut seen_b, &stats), + "once pressure is observed with empty candidate set, it must not be replayed later" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_must_not_survive_candidate_churn() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(531)); + clear_relay_idle_candidate(531); + assert!(mark_relay_idle_candidate(532)); + + let mut seen = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(532, &mut seen, &stats), + "stale pressure must not survive clear+remark churn cycles" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + { + let mut guard = relay_idle_candidate_registry() + .lock() + .expect("registry lock must be available"); + guard.pressure_event_seq = u64::MAX; + guard.pressure_consumed_seq = u64::MAX - 1; + } + + // A new pressure event should still be representable; saturating at MAX creates a permanent lockout. 
+ note_relay_pressure_event(); + let after = relay_pressure_event_seq(); + assert_ne!( + after, + u64::MAX, + "pressure sequence saturation must not permanently freeze event progression" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + { + let mut guard = relay_idle_candidate_registry() + .lock() + .expect("registry lock must be available"); + guard.pressure_event_seq = u64::MAX; + guard.pressure_consumed_seq = u64::MAX; + } + + note_relay_pressure_event(); + let first = relay_pressure_event_seq(); + note_relay_pressure_event(); + let second = relay_pressure_event_seq(); + + assert!( + second > first, + "distinct pressure events must remain distinguishable even at sequence boundary" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + let stats = Arc::new(Stats::new()); + let sessions = 16usize; + let rounds = 200usize; + let conn_ids: Vec<u64> = (10_000u64..10_000u64 + sessions as u64).collect(); + let mut seen_per_session = vec![0u64; sessions]; + + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + + for round in 0..rounds { + note_relay_pressure_event(); + + let mut joins = Vec::with_capacity(sessions); + for (idx, conn_id) in conn_ids.iter().enumerate() { + let mut seen = seen_per_session[idx]; + let conn_id = *conn_id; + let stats = stats.clone(); + joins.push(tokio::spawn(async move { + let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); + (idx, conn_id, seen, evicted) + })); + } + + let mut evicted_this_round = 0usize; + let 
mut evicted_conn = None; + for join in joins { + let (idx, conn_id, seen, evicted) = join.await.expect("race task must not panic"); + seen_per_session[idx] = seen; + if evicted { + evicted_this_round += 1; + evicted_conn = Some(conn_id); + } + } + + assert!( + evicted_this_round <= 1, + "round {round}: one pressure event must never produce more than one eviction" + ); + if let Some(conn) = evicted_conn { + assert!( + mark_relay_idle_candidate(conn), + "round {round}: evicted conn must be re-markable as idle candidate" + ); + } + } + + assert!( + stats.get_relay_pressure_evict_total() <= rounds as u64, + "eviction total must never exceed number of pressure events" + ); + assert!( + stats.get_relay_pressure_evict_total() > 0, + "parallel race must still observe at least one successful eviction" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + let stats = Arc::new(Stats::new()); + let sessions = 12usize; + let rounds = 120usize; + let conn_ids: Vec<u64> = (20_000u64..20_000u64 + sessions as u64).collect(); + let mut seen_per_session = vec![0u64; sessions]; + + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + + let mut expected_total_evictions = 0u64; + + for round in 0..rounds { + let empty_phase = round % 5 == 0; + if empty_phase { + for conn_id in &conn_ids { + clear_relay_idle_candidate(*conn_id); + } + } + + note_relay_pressure_event(); + + let mut joins = Vec::with_capacity(sessions); + for (idx, conn_id) in conn_ids.iter().enumerate() { + let mut seen = seen_per_session[idx]; + let conn_id = *conn_id; + let stats = stats.clone(); + joins.push(tokio::spawn(async move { + let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); + 
(idx, conn_id, seen, evicted) + })); + } + + let mut evicted_this_round = 0usize; + let mut evicted_conn = None; + for join in joins { + let (idx, conn_id, seen, evicted) = join.await.expect("burst race task must not panic"); + seen_per_session[idx] = seen; + if evicted { + evicted_this_round += 1; + evicted_conn = Some(conn_id); + } + } + + if empty_phase { + assert_eq!( + evicted_this_round, 0, + "round {round}: empty candidate phase must not allow stale-pressure eviction" + ); + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + } else { + assert!( + evicted_this_round <= 1, + "round {round}: pressure budget must cap at one eviction" + ); + if let Some(conn_id) = evicted_conn { + expected_total_evictions = expected_total_evictions.saturating_add(1); + assert!(mark_relay_idle_candidate(conn_id)); + } + } + } + + assert_eq!( + stats.get_relay_pressure_evict_total(), + expected_total_evictions, + "global pressure eviction counter must match observed per-round successful consumes" + ); + + clear_relay_idle_pressure_state_for_testing(); +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 3c79448..27461ef 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -99,6 +99,10 @@ pub struct Stats { me_handshake_reject_total: AtomicU64, me_reader_eof_total: AtomicU64, me_idle_close_by_peer_total: AtomicU64, + relay_idle_soft_mark_total: AtomicU64, + relay_idle_hard_close_total: AtomicU64, + relay_pressure_evict_total: AtomicU64, + relay_protocol_desync_close_total: AtomicU64, me_crc_mismatch: AtomicU64, me_seq_mismatch: AtomicU64, me_endpoint_quarantine_total: AtomicU64, @@ -525,6 +529,30 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn increment_relay_idle_soft_mark_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_idle_soft_mark_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_idle_hard_close_total(&self) { + if self.telemetry_me_allows_normal() { + 
self.relay_idle_hard_close_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_pressure_evict_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_pressure_evict_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_protocol_desync_close_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_protocol_desync_close_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_crc_mismatch(&self) { if self.telemetry_me_allows_normal() { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); @@ -1019,6 +1047,18 @@ impl Stats { pub fn get_me_idle_close_by_peer_total(&self) -> u64 { self.me_idle_close_by_peer_total.load(Ordering::Relaxed) } + pub fn get_relay_idle_soft_mark_total(&self) -> u64 { + self.relay_idle_soft_mark_total.load(Ordering::Relaxed) + } + pub fn get_relay_idle_hard_close_total(&self) -> u64 { + self.relay_idle_hard_close_total.load(Ordering::Relaxed) + } + pub fn get_relay_pressure_evict_total(&self) -> u64 { + self.relay_pressure_evict_total.load(Ordering::Relaxed) + } + pub fn get_relay_protocol_desync_close_total(&self) -> u64 { + self.relay_protocol_desync_close_total.load(Ordering::Relaxed) + } pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) } pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) } pub fn get_me_endpoint_quarantine_total(&self) -> u64 {