diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs
index 974d31c..abd54c4 100644
--- a/src/transport/middle_proxy/mod.rs
+++ b/src/transport/middle_proxy/mod.rs
@@ -29,6 +29,10 @@ mod health_integration_tests;
 mod health_adversarial_tests;
 #[cfg(test)]
 mod send_adversarial_tests;
+#[cfg(test)]
+mod pool_writer_security_tests;
+#[cfg(test)]
+mod pool_refill_security_tests;
 
 use bytes::Bytes;
 
diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs
index fc916f4..7808d3e 100644
--- a/src/transport/middle_proxy/pool_refill.rs
+++ b/src/transport/middle_proxy/pool_refill.rs
@@ -71,17 +71,31 @@ impl MePool {
         }
 
         if let Some((addr, expiry)) = earliest_quarantine {
+            let remaining = expiry.saturating_duration_since(now);
+            if remaining.is_zero() {
+                return vec![addr];
+            }
+            drop(guard);
             debug!(
                 %addr,
-                wait_ms = expiry.saturating_duration_since(now).as_millis(),
-                "All ME endpoints are quarantined for the DC group; retrying earliest one"
+                wait_ms = remaining.as_millis(),
+                "All ME endpoints quarantined; waiting for earliest to expire"
             );
+            tokio::time::sleep(remaining).await;
             return vec![addr];
         }
 
         Vec::new()
     }
 
+    #[cfg(test)]
+    pub(super) async fn connectable_endpoints_for_test(
+        &self,
+        endpoints: &[SocketAddr],
+    ) -> Vec<SocketAddr> {
+        self.connectable_endpoints(endpoints).await
+    }
+
     pub(super) async fn has_refill_inflight_for_dc_key(&self, key: RefillDcKey) -> bool {
         let guard = self.refill_inflight_dc.lock().await;
         guard.contains(&key)
diff --git a/src/transport/middle_proxy/pool_refill_security_tests.rs b/src/transport/middle_proxy/pool_refill_security_tests.rs
new file mode 100644
index 0000000..cd49270
--- /dev/null
+++ b/src/transport/middle_proxy/pool_refill_security_tests.rs
@@ -0,0 +1,150 @@
+use std::collections::HashMap;
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
+use crate::crypto::SecureRandom;
+use crate::network::probe::NetworkDecision;
+use crate::stats::Stats;
+
+use super::pool::MePool;
+
+async fn make_pool() -> Arc<MePool> {
+    let general = GeneralConfig::default();
+
+    MePool::new(
+        None,
+        vec![1u8; 32],
+        None,
+        false,
+        None,
+        Vec::new(),
+        1,
+        None,
+        12,
+        1200,
+        HashMap::new(),
+        HashMap::new(),
+        None,
+        NetworkDecision::default(),
+        None,
+        Arc::new(SecureRandom::new()),
+        Arc::new(Stats::default()),
+        general.me_keepalive_enabled,
+        general.me_keepalive_interval_secs,
+        general.me_keepalive_jitter_secs,
+        general.me_keepalive_payload_random,
+        general.rpc_proxy_req_every,
+        general.me_warmup_stagger_enabled,
+        general.me_warmup_step_delay_ms,
+        general.me_warmup_step_jitter_ms,
+        general.me_reconnect_max_concurrent_per_dc,
+        general.me_reconnect_backoff_base_ms,
+        general.me_reconnect_backoff_cap_ms,
+        general.me_reconnect_fast_retry_count,
+        general.me_single_endpoint_shadow_writers,
+        general.me_single_endpoint_outage_mode_enabled,
+        general.me_single_endpoint_outage_disable_quarantine,
+        general.me_single_endpoint_outage_backoff_min_ms,
+        general.me_single_endpoint_outage_backoff_max_ms,
+        general.me_single_endpoint_shadow_rotate_every_secs,
+        general.me_floor_mode,
+        general.me_adaptive_floor_idle_secs,
+        general.me_adaptive_floor_min_writers_single_endpoint,
+        general.me_adaptive_floor_min_writers_multi_endpoint,
+        general.me_adaptive_floor_recover_grace_secs,
+        general.me_adaptive_floor_writers_per_core_total,
+        general.me_adaptive_floor_cpu_cores_override,
+        general.me_adaptive_floor_max_extra_writers_single_per_core,
+        general.me_adaptive_floor_max_extra_writers_multi_per_core,
+        general.me_adaptive_floor_max_active_writers_per_core,
+        general.me_adaptive_floor_max_warm_writers_per_core,
+        general.me_adaptive_floor_max_active_writers_global,
+        general.me_adaptive_floor_max_warm_writers_global,
+        general.hardswap,
+        general.me_pool_drain_ttl_secs,
+        general.me_instadrain,
+        general.me_pool_drain_threshold,
+        general.effective_me_pool_force_close_secs(),
+        general.me_pool_min_fresh_ratio,
+        general.me_hardswap_warmup_delay_min_ms,
+        general.me_hardswap_warmup_delay_max_ms,
+        general.me_hardswap_warmup_extra_passes,
+        general.me_hardswap_warmup_pass_backoff_base_ms,
+        general.me_bind_stale_mode,
+        general.me_bind_stale_ttl_secs,
+        general.me_secret_atomic_snapshot,
+        general.me_deterministic_writer_sort,
+        MeWriterPickMode::default(),
+        general.me_writer_pick_sample_size,
+        MeSocksKdfPolicy::default(),
+        general.me_writer_cmd_channel_capacity,
+        general.me_route_channel_capacity,
+        general.me_route_backpressure_base_timeout_ms,
+        general.me_route_backpressure_high_timeout_ms,
+        general.me_route_backpressure_high_watermark_pct,
+        general.me_reader_route_data_wait_ms,
+        general.me_health_interval_ms_unhealthy,
+        general.me_health_interval_ms_healthy,
+        general.me_warn_rate_limit_ms,
+        MeRouteNoWriterMode::default(),
+        general.me_route_no_writer_wait_ms,
+        general.me_route_inline_recovery_attempts,
+        general.me_route_inline_recovery_wait_ms,
+    )
+}
+
+#[tokio::test]
+async fn connectable_endpoints_waits_until_quarantine_expires() {
+    let pool = make_pool().await;
+    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 31, 0, 11)), 443);
+
+    {
+        let mut guard = pool.endpoint_quarantine.lock().await;
+        guard.insert(addr, Instant::now() + Duration::from_millis(80));
+    }
+
+    let started = Instant::now();
+    let endpoints = pool.connectable_endpoints_for_test(&[addr]).await;
+    let elapsed = started.elapsed();
+
+    assert_eq!(endpoints, vec![addr]);
+    assert!(
+        elapsed >= Duration::from_millis(50),
+        "single-endpoint DC should honor quarantine before retry"
+    );
+}
+
+#[tokio::test]
+async fn connectable_endpoints_releases_quarantine_lock_before_sleep() {
+    let pool = make_pool().await;
+    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 31, 0, 12)), 443);
+
+    {
+        let mut guard = pool.endpoint_quarantine.lock().await;
+        guard.insert(addr, Instant::now() + Duration::from_millis(120));
+    }
+
+    let pool_for_task = Arc::clone(&pool);
+    let task = tokio::spawn(async move { pool_for_task.connectable_endpoints_for_test(&[addr]).await });
+
+    tokio::time::sleep(Duration::from_millis(10)).await;
+
+    let quarantine_check = tokio::time::timeout(
+        Duration::from_millis(40),
+        pool.is_endpoint_quarantined(addr),
+    )
+    .await;
+    assert!(
+        quarantine_check.is_ok(),
+        "quarantine lock must not be held while waiting for expiry"
+    );
+    assert!(quarantine_check.expect("timeout"));
+
+    let endpoints = tokio::time::timeout(Duration::from_millis(300), task)
+        .await
+        .expect("connectable_endpoints task timed out")
+        .expect("task join failed");
+    assert_eq!(endpoints, vec![addr]);
+}
diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs
index 5b23d7f..dbab191 100644
--- a/src/transport/middle_proxy/pool_writer.rs
+++ b/src/transport/middle_proxy/pool_writer.rs
@@ -240,21 +240,24 @@ impl MePool {
                 stats_reader_close.increment_me_idle_close_by_peer_total();
                 info!(writer_id, "ME socket closed by peer on idle writer");
             }
-            if let Some(pool) = pool.upgrade()
-                && cleanup_for_reader
-                    .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                    .is_ok()
+            if cleanup_for_reader
+                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
+                .is_ok()
             {
-                pool.remove_writer_and_close_clients(writer_id).await;
+                if let Some(pool) = pool.upgrade() {
+                    pool.remove_writer_and_close_clients(writer_id).await;
+                } else {
+                    // Pool is already gone during shutdown; do a local writer list cleanup only.
+                    let mut ws = writers_arc.write().await;
+                    ws.retain(|w| w.id != writer_id);
+                    debug!(writer_id, remaining = ws.len(), "Writer removed during pool shutdown");
+                }
             }
             if let Err(e) = res {
                 if !idle_close_by_peer {
                     warn!(error = %e, "ME reader ended");
                 }
             }
-            let mut ws = writers_arc.write().await;
-            ws.retain(|w| w.id != writer_id);
-            info!(remaining = ws.len(), "Dead ME writer removed from pool");
         });
 
         let pool_ping = Arc::downgrade(self);
@@ -357,12 +360,13 @@ impl MePool {
                     stats_ping.increment_me_keepalive_failed();
                     debug!("ME ping failed, removing dead writer");
                     cancel_ping.cancel();
-                    if let Some(pool) = pool_ping.upgrade()
-                        && cleanup_for_ping
-                            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                            .is_ok()
+                    if cleanup_for_ping
+                        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
+                        .is_ok()
                     {
-                        pool.remove_writer_and_close_clients(writer_id).await;
+                        if let Some(pool) = pool_ping.upgrade() {
+                            pool.remove_writer_and_close_clients(writer_id).await;
+                        }
                     }
                     break;
                 }
@@ -548,13 +552,16 @@ impl MePool {
         if let Some(tx) = close_tx {
             let _ = tx.send(WriterCommand::Close).await;
         }
+        if let Some(addr) = removed_addr
+            && let Some(uptime) = removed_uptime
+        {
+            // Quarantine flapping endpoints regardless of draining state.
+            self.maybe_quarantine_flapping_endpoint(addr, uptime).await;
+        }
         if trigger_refill
             && let Some(addr) = removed_addr
             && let Some(writer_dc) = removed_dc
         {
-            if let Some(uptime) = removed_uptime {
-                self.maybe_quarantine_flapping_endpoint(addr, uptime).await;
-            }
             self.trigger_immediate_refill_for_dc(addr, writer_dc);
         }
         conns
diff --git a/src/transport/middle_proxy/pool_writer_security_tests.rs b/src/transport/middle_proxy/pool_writer_security_tests.rs
new file mode 100644
index 0000000..61f291c
--- /dev/null
+++ b/src/transport/middle_proxy/pool_writer_security_tests.rs
@@ -0,0 +1,171 @@
+use std::collections::HashMap;
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
+use std::time::{Duration, Instant};
+
+use tokio::sync::mpsc;
+use tokio_util::sync::CancellationToken;
+
+use super::codec::WriterCommand;
+use super::pool::{MePool, MeWriter, WriterContour};
+use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
+use crate::crypto::SecureRandom;
+use crate::network::probe::NetworkDecision;
+use crate::stats::Stats;
+
+async fn make_pool() -> Arc<MePool> {
+    let general = GeneralConfig::default();
+
+    MePool::new(
+        None,
+        vec![1u8; 32],
+        None,
+        false,
+        None,
+        Vec::new(),
+        1,
+        None,
+        12,
+        1200,
+        HashMap::new(),
+        HashMap::new(),
+        None,
+        NetworkDecision::default(),
+        None,
+        Arc::new(SecureRandom::new()),
+        Arc::new(Stats::default()),
+        general.me_keepalive_enabled,
+        general.me_keepalive_interval_secs,
+        general.me_keepalive_jitter_secs,
+        general.me_keepalive_payload_random,
+        general.rpc_proxy_req_every,
+        general.me_warmup_stagger_enabled,
+        general.me_warmup_step_delay_ms,
+        general.me_warmup_step_jitter_ms,
+        general.me_reconnect_max_concurrent_per_dc,
+        general.me_reconnect_backoff_base_ms,
+        general.me_reconnect_backoff_cap_ms,
+        general.me_reconnect_fast_retry_count,
+        general.me_single_endpoint_shadow_writers,
+        general.me_single_endpoint_outage_mode_enabled,
+        general.me_single_endpoint_outage_disable_quarantine,
+        general.me_single_endpoint_outage_backoff_min_ms,
+        general.me_single_endpoint_outage_backoff_max_ms,
+        general.me_single_endpoint_shadow_rotate_every_secs,
+        general.me_floor_mode,
+        general.me_adaptive_floor_idle_secs,
+        general.me_adaptive_floor_min_writers_single_endpoint,
+        general.me_adaptive_floor_min_writers_multi_endpoint,
+        general.me_adaptive_floor_recover_grace_secs,
+        general.me_adaptive_floor_writers_per_core_total,
+        general.me_adaptive_floor_cpu_cores_override,
+        general.me_adaptive_floor_max_extra_writers_single_per_core,
+        general.me_adaptive_floor_max_extra_writers_multi_per_core,
+        general.me_adaptive_floor_max_active_writers_per_core,
+        general.me_adaptive_floor_max_warm_writers_per_core,
+        general.me_adaptive_floor_max_active_writers_global,
+        general.me_adaptive_floor_max_warm_writers_global,
+        general.hardswap,
+        general.me_pool_drain_ttl_secs,
+        general.me_instadrain,
+        general.me_pool_drain_threshold,
+        general.effective_me_pool_force_close_secs(),
+        general.me_pool_min_fresh_ratio,
+        general.me_hardswap_warmup_delay_min_ms,
+        general.me_hardswap_warmup_delay_max_ms,
+        general.me_hardswap_warmup_extra_passes,
+        general.me_hardswap_warmup_pass_backoff_base_ms,
+        general.me_bind_stale_mode,
+        general.me_bind_stale_ttl_secs,
+        general.me_secret_atomic_snapshot,
+        general.me_deterministic_writer_sort,
+        MeWriterPickMode::default(),
+        general.me_writer_pick_sample_size,
+        MeSocksKdfPolicy::default(),
+        general.me_writer_cmd_channel_capacity,
+        general.me_route_channel_capacity,
+        general.me_route_backpressure_base_timeout_ms,
+        general.me_route_backpressure_high_timeout_ms,
+        general.me_route_backpressure_high_watermark_pct,
+        general.me_reader_route_data_wait_ms,
+        general.me_health_interval_ms_unhealthy,
+        general.me_health_interval_ms_healthy,
+        general.me_warn_rate_limit_ms,
+        MeRouteNoWriterMode::default(),
+        general.me_route_no_writer_wait_ms,
+        general.me_route_inline_recovery_attempts,
+        general.me_route_inline_recovery_wait_ms,
+    )
+}
+
+async fn insert_writer(
+    pool: &Arc<MePool>,
+    writer_id: u64,
+    writer_dc: i32,
+    addr: SocketAddr,
+    draining: bool,
+    created_at: Instant,
+) {
+    let (tx, _rx) = mpsc::channel::<WriterCommand>(8);
+    let contour = if draining {
+        WriterContour::Draining
+    } else {
+        WriterContour::Active
+    };
+    let writer = MeWriter {
+        id: writer_id,
+        addr,
+        source_ip: addr.ip(),
+        writer_dc,
+        generation: pool.current_generation(),
+        contour: Arc::new(AtomicU8::new(contour.as_u8())),
+        created_at,
+        tx: tx.clone(),
+        cancel: CancellationToken::new(),
+        degraded: Arc::new(AtomicBool::new(false)),
+        rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
+        draining: Arc::new(AtomicBool::new(draining)),
+        draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
+        drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
+        allow_drain_fallback: Arc::new(AtomicBool::new(false)),
+    };
+
+    pool.writers.write().await.push(writer);
+    pool.registry.register_writer(writer_id, tx).await;
+    pool.conn_count.fetch_add(1, Ordering::Relaxed);
+}
+
+#[tokio::test]
+async fn remove_draining_writer_still_quarantines_flapping_endpoint() {
+    let pool = make_pool().await;
+    let writer_id = 77;
+    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 77)), 443);
+    insert_writer(
+        &pool,
+        writer_id,
+        2,
+        addr,
+        true,
+        Instant::now() - Duration::from_secs(1),
+    )
+    .await;
+
+    pool.remove_writer_and_close_clients(writer_id).await;
+
+    let writer_still_present = pool
+        .writers
+        .read()
+        .await
+        .iter()
+        .any(|writer| writer.id == writer_id);
+    assert!(
+        !writer_still_present,
+        "writer must be removed from pool after cleanup"
+    );
+    assert!(
+        pool.is_endpoint_quarantined(addr).await,
+        "draining removals must still quarantine flapping endpoints"
+    );
+    assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0);
+}