diff --git a/src/proxy/client.rs b/src/proxy/client.rs
index 849e409..5d32e34 100644
--- a/src/proxy/client.rs
+++ b/src/proxy/client.rs
@@ -878,7 +878,7 @@ impl RunningClientHandler {
         {
             Ok(reservation) => reservation,
             Err(e) => {
-                warn!(user = %user, error = %e, "User limit exceeded");
+                warn!(user = %user, error = %e, "User admission check failed");
                 return Err(e);
             }
         };
@@ -998,8 +998,8 @@ impl RunningClientHandler {
 
     #[cfg(test)]
     async fn check_user_limits_static(
-        user: &str,
-        config: &ProxyConfig,
+        user: &str,
+        config: &ProxyConfig,
         stats: &Stats,
         peer_addr: SocketAddr,
         ip_tracker: &UserIpTracker,
diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs
index 8bdb234..6ca2d4b 100644
--- a/src/proxy/client_security_tests.rs
+++ b/src/proxy/client_security_tests.rs
@@ -1420,6 +1420,105 @@ async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() {
     );
 }
 
+#[tokio::test]
+async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() {
+    let mut config = ProxyConfig::default();
+    config
+        .access
+        .user_max_tcp_conns
+        .insert("user".to_string(), 0);
+
+    let stats = Stats::new();
+    let ip_tracker = UserIpTracker::new();
+    let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap();
+
+    let result = RunningClientHandler::check_user_limits_static(
+        "user",
+        &config,
+        &stats,
+        peer_addr,
+        &ip_tracker,
+    )
+    .await;
+
+    assert!(matches!(
+        result,
+        Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user"
+    ));
+    assert_eq!(stats.get_user_curr_connects("user"), 0);
+    assert_eq!(ip_tracker.get_active_ip_count("user").await, 0);
+}
+
+#[tokio::test]
+async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak() {
+    let user = "rollback-storm-user";
+    let mut config = ProxyConfig::default();
+    config
+        .access
+        .user_max_tcp_conns
+        .insert(user.to_string(), 128);
+
+    let config = Arc::new(config);
+    let stats = Arc::new(Stats::new());
+    let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 1).await;
+
+    let keeper_peer: SocketAddr = "198.51.100.212:50002".parse().unwrap();
+    let keeper = RunningClientHandler::acquire_user_connection_reservation_static(
+        user,
+        &config,
+        stats.clone(),
+        keeper_peer,
+        ip_tracker.clone(),
+    )
+    .await
+    .expect("keeper reservation must succeed");
+
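+    // 64 tasks race from distinct IPs while the keeper pins the only allowed IP slot;
+    // every attempt must be rejected and must roll back its acquired TCP slot.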
+    let mut tasks = tokio::task::JoinSet::new();
+    for i in 0..64u8 {
+        let config = config.clone();
+        let stats = stats.clone();
+        let ip_tracker = ip_tracker.clone();
+        tasks.spawn(async move {
+            let peer = SocketAddr::new(
+                IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 101, i.saturating_add(1))),
+                41000 + i as u16,
+            );
+            let result = RunningClientHandler::acquire_user_connection_reservation_static(
+                user,
+                &config,
+                stats,
+                peer,
+                ip_tracker,
+            )
+            .await;
+            assert!(matches!(
+                result,
+                Err(ProxyError::ConnectionLimitExceeded { user }) if user == "rollback-storm-user"
+            ));
+        });
+    }
+
+    while let Some(joined) = tasks.join_next().await {
+        joined.unwrap();
+    }
+
+    assert_eq!(
+        stats.get_user_curr_connects(user),
+        1,
+        "failed distinct-IP attempts must rollback acquired user slots"
+    );
+    assert_eq!(
+        ip_tracker.get_active_ip_count(user).await,
+        1,
+        "failed distinct-IP attempts must not leave extra active IPs"
+    );
+
+    keeper.release().await;
+    assert_eq!(stats.get_user_curr_connects(user), 0);
+    assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
+}
+
 #[tokio::test]
 async fn explicit_reservation_release_cleans_user_and_ip_immediately() {
     let user = "release-user";
@@ -2990,3 +3089,478 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() {
         "Valid max-length ClientHello must not increment bad counter"
     );
 }
+
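+// Deterministic LCG (Knuth's MMIX multiplier) so the churn schedules below are
+// reproducible across runs.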
+fn lcg_next(state: &mut u64) -> u64 {
+    *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
+    *state
+}
+
+async fn wait_for_user_and_ip_zero(
+    stats: &Arc<Stats>,
+    ip_tracker: &Arc<UserIpTracker>,
+    user: &str,
+) {
+    tokio::time::timeout(Duration::from_secs(2), async {
+        loop {
+            if stats.get_user_curr_connects(user) == 0
+                && ip_tracker.get_active_ip_count(user).await == 0
+            {
+                break;
+            }
+            tokio::time::sleep(Duration::from_millis(5)).await;
+        }
+    })
+    .await
+    .expect("cleanup must converge to zero user and IP footprint");
+}
+
+async fn burst_acquire_distinct_ips(
+    user: &'static str,
+    config: Arc<ProxyConfig>,
+    stats: Arc<Stats>,
+    ip_tracker: Arc<UserIpTracker>,
+    third_octet: u8,
+    attempts: u16,
+) -> (Vec<UserConnectionReservation>, usize) {
+    let mut tasks = tokio::task::JoinSet::new();
+    for i in 0..attempts {
+        let config = config.clone();
+        let stats = stats.clone();
+        let ip_tracker = ip_tracker.clone();
+        tasks.spawn(async move {
+            let host = (i as u8).saturating_add(1);
+            let peer = SocketAddr::new(
+                IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third_octet, host)),
+                55000 + i,
+            );
+            RunningClientHandler::acquire_user_connection_reservation_static(
+                user,
+                &config,
+                stats,
+                peer,
+                ip_tracker,
+            )
+            .await
+        });
+    }
+
+    let mut successes = Vec::new();
+    let mut failures = 0usize;
+    while let Some(joined) = tasks.join_next().await {
+        match joined.expect("burst acquire task must not panic") {
+            Ok(reservation) => successes.push(reservation),
+            Err(err) => {
+                assert!(matches!(
+                    err,
+                    ProxyError::ConnectionLimitExceeded { user: ref denied_user }
+                        if denied_user == user
+                ));
+                failures = failures.saturating_add(1);
+            }
+        }
+    }
+
+    (successes, failures)
+}
+
+#[tokio::test]
+async fn deterministic_mixed_reservation_churn_preserves_counter_and_eventual_cleanup() {
+    let user = "deterministic-churn-user";
+    let mut config = ProxyConfig::default();
+    config
+        .access
+        .user_max_tcp_conns
+        .insert(user.to_string(), 12);
+
+    let config = Arc::new(config);
+    let stats = Arc::new(Stats::new());
+    let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 4).await;
+
+    let mut seed = 0xD1F2_A4C8_991B_77E1u64;
+    let mut reservations: Vec<Option<UserConnectionReservation>> = Vec::new();
+
+    for step in 0..220u64 {
+        let op = (lcg_next(&mut seed) % 100) as u8;
+        let active = reservations.iter().filter(|entry| entry.is_some()).count();
+
+        if active == 0 || op < 55 {
+            let ip_octet = (lcg_next(&mut seed) % 16 + 1) as u8;
+            let peer = SocketAddr::new(
+                IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 120, ip_octet)),
+                52000 + (step % 4000) as u16,
+            );
+            let result = RunningClientHandler::acquire_user_connection_reservation_static(
+                user,
+                &config,
+                stats.clone(),
+                peer,
+                ip_tracker.clone(),
+            )
+            .await;
+
+            if let Ok(reservation) = result {
+                reservations.push(Some(reservation));
+            } else {
+                assert!(matches!(
+                    result,
+                    Err(ProxyError::ConnectionLimitExceeded { user }) if user == "deterministic-churn-user"
+                ));
+            }
+        } else {
+            let selected = reservations
+                .iter()
+                .enumerate()
+                .filter(|(_, entry)| entry.is_some())
+                .map(|(idx, _)| idx)
+                .nth((lcg_next(&mut seed) as usize) % active)
+                .unwrap();
+
+            let reservation = reservations[selected].take().unwrap();
+            if op < 80 {
+                reservation.release().await;
+            } else {
+                std::thread::spawn(move || {
+                    drop(reservation);
+                })
+                .join()
+                .expect("cross-thread drop must not panic");
+            }
+        }
+
+        let live_slots =
+            reservations.iter().filter(|entry| entry.is_some()).count() as u64;
+        assert_eq!(
+            stats.get_user_curr_connects(user),
+            live_slots,
+            "current-connects counter must match number of live reservations"
+        );
+        assert!(
+            stats.get_user_curr_connects(user) <= 12,
+            "current-connects must stay within configured TCP limit"
+        );
+        assert!(
+            ip_tracker.get_active_ip_count(user).await <= 4,
+            "active unique IPs must stay within configured per-user IP limit"
+        );
+    }
+
+    for reservation in reservations.into_iter().flatten() {
+        reservation.release().await;
+    }
+    wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await;
+}
+
+#[tokio::test]
+async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() {
+    let user = "drop-storm-reacquire-user";
+    let mut config = ProxyConfig::default();
+    config
+        .access
+        .user_max_tcp_conns
+        .insert(user.to_string(), 64);
+
+    let config = Arc::new(config);
+    let stats = Arc::new(Stats::new());
+    let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 8).await;
+
+    let mut initial = Vec::new();
+    for i in 0..32u16 {
+        let ip_octet = (i % 8 + 1) as u8;
+        let peer = SocketAddr::new(
+            IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 120, ip_octet)),
+            53000 + i,
+        );
+        let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
+            user,
+            &config,
+            stats.clone(),
+            peer,
+            ip_tracker.clone(),
+        )
+        .await
+        .expect("initial reservation must succeed");
+        initial.push(reservation);
+    }
+
+    let mut second_half = initial.split_off(16);
+
+    let mut releases = Vec::new();
+    for reservation in initial {
+        releases.push(tokio::spawn(async move {
+            reservation.release().await;
+        }));
+    }
+    for release_task in releases {
+        release_task.await.expect("release task must not panic");
+    }
+
+    let mut drop_threads = Vec::new();
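+    // Drop the second half on dedicated OS threads so cleanup runs through Drop,
+    // off the async runtime.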
+    for reservation in second_half.drain(..) {
+        drop_threads.push(std::thread::spawn(move || {
+            drop(reservation);
+        }));
+    }
+    for drop_thread in drop_threads {
+        drop_thread
+            .join()
+            .expect("cross-thread drop worker must not panic");
+    }
+
+    wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await;
+
+    let mut reacquire_tasks = tokio::task::JoinSet::new();
+    for i in 0..16u16 {
+        let config = config.clone();
+        let stats = stats.clone();
+        let ip_tracker = ip_tracker.clone();
+        reacquire_tasks.spawn(async move {
+            let peer = SocketAddr::new(
+                IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 121, (i + 1) as u8)),
+                54000 + i,
+            );
+            RunningClientHandler::acquire_user_connection_reservation_static(
+                user,
+                &config,
+                stats,
+                peer,
+                ip_tracker,
+            )
+            .await
+        });
+    }
+
+    let mut acquired = Vec::new();
+    while let Some(joined) = reacquire_tasks.join_next().await {
+        match joined.expect("reacquire task must not panic") {
+            Ok(reservation) => acquired.push(reservation),
+            Err(err) => {
+                assert!(matches!(
+                    err,
+                    ProxyError::ConnectionLimitExceeded { user }
+                        if user == "drop-storm-reacquire-user"
+                ));
+            }
+        }
+    }
+
+    assert!(
+        acquired.len() <= 8,
+        "parallel distinct-IP reacquire wave must not exceed per-user unique IP limit"
+    );
+    for reservation in acquired {
+        reservation.release().await;
+    }
+    wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await;
+}
+
+#[tokio::test]
+async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() {
+    let user: &'static str = "scheduled-attack-user";
+    let mut config = ProxyConfig::default();
+    config
+        .access
+        .user_max_tcp_conns
+        .insert(user.to_string(), 6);
+
+    let config = Arc::new(config);
+    let stats = Arc::new(Stats::new());
+    let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 2).await;
+
+    let mut base = Vec::new();
+    for i in 0..5u16 {
+        let peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), 56000 + i);
+        let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
+            user,
+            &config,
+            stats.clone(),
+            peer,
+            ip_tracker.clone(),
+        )
+        .await
+        .expect("near-limit warmup reservation must succeed");
+        base.push(reservation);
+    }
+    assert_eq!(stats.get_user_curr_connects(user), 5);
+    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
+
+    let (wave1_success, wave1_fail) = burst_acquire_distinct_ips(
+        user,
+        config.clone(),
+        stats.clone(),
+        ip_tracker.clone(),
+        131,
+        32,
+    )
+    .await;
+    assert_eq!(wave1_success.len(), 1);
+    assert_eq!(wave1_fail, 31);
+    assert_eq!(stats.get_user_curr_connects(user), 6);
+    assert_eq!(ip_tracker.get_active_ip_count(user).await, 2);
+
+    let released = base.pop().expect("must have releasable reservation");
+    released.release().await;
+    for reservation in wave1_success {
+        reservation.release().await;
+    }
+
+    tokio::time::timeout(Duration::from_secs(1), async {
+        loop {
+            if stats.get_user_curr_connects(user) == 4
+                && ip_tracker.get_active_ip_count(user).await == 1
+            {
+                break;
+            }
+            tokio::time::sleep(Duration::from_millis(5)).await;
+        }
+    })
+    .await
+    .expect("window cleanup must settle to expected occupancy");
+
+    let (wave2_success, wave2_fail) = burst_acquire_distinct_ips(
+        user,
+        config,
+        stats.clone(),
+        ip_tracker.clone(),
+        132,
+        32,
+    )
+    .await;
+    assert_eq!(wave2_success.len(), 1);
+    assert_eq!(wave2_fail, 31);
+    assert_eq!(stats.get_user_curr_connects(user), 5);
+    assert_eq!(ip_tracker.get_active_ip_count(user).await, 2);
+
+    let tail = base.split_off(2);
+
+    let mut drop_threads = Vec::new();
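+    // Mixed teardown: the remaining base is dropped cross-thread while `tail` and
+    // the burst winner are released explicitly below.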
+    for reservation in base {
+        drop_threads.push(std::thread::spawn(move || {
+            drop(reservation);
+        }));
+    }
+    for drop_thread in drop_threads {
+        drop_thread
+            .join()
+            .expect("cross-thread scheduled cleanup must not panic");
+    }
+
+    for reservation in tail {
+        reservation.release().await;
+    }
+    for reservation in wave2_success {
+        reservation.release().await;
+    }
+
+    wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await;
+}
+
+#[tokio::test]
+async fn scheduled_mode_switch_burst_churn_preserves_limits_and_cleanup() {
+    let user: &'static str = "scheduled-mode-switch-user";
+    let mut config = ProxyConfig::default();
+    config
+        .access
+        .user_max_tcp_conns
+        .insert(user.to_string(), 10);
+
+    let config = Arc::new(config);
+    let stats = Arc::new(Stats::new());
+    let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 3).await;
+
+    let base_peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 140, 1)), 57000);
+    let mut base = Vec::new();
+    for i in 0..7u16 {
+        let peer = SocketAddr::new(base_peer.ip(), base_peer.port().saturating_add(i));
+        let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
+            user,
+            &config,
+            stats.clone(),
+            peer,
+            ip_tracker.clone(),
+        )
+        .await
+        .expect("base occupancy reservation must succeed");
+        base.push(reservation);
+    }
+
+    assert_eq!(stats.get_user_curr_connects(user), 7);
+    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
+
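+    // Alternate release-heavy and drop-heavy rounds while a replacement keeps the
+    // base occupancy constant between rounds.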
+    for round in 0..8u8 {
+        let (wave_success, wave_fail) = burst_acquire_distinct_ips(
+            user,
+            config.clone(),
+            stats.clone(),
+            ip_tracker.clone(),
+            141u8.saturating_add(round),
+            24,
+        )
+        .await;
+
+        assert!(
+            wave_success.len() <= 2,
+            "burst must not exceed available unique-IP headroom under limit=3"
+        );
+        assert_eq!(wave_success.len() + wave_fail, 24);
+        assert_eq!(
+            stats.get_user_curr_connects(user),
+            7 + wave_success.len() as u64,
+            "slot counter must reflect base occupancy plus successful burst leases"
+        );
+        assert!(ip_tracker.get_active_ip_count(user).await <= 3);
+
+        if round % 2 == 0 {
+            for reservation in wave_success {
+                reservation.release().await;
+            }
+            let rotated = base.pop().expect("base rotation reservation must exist");
+            rotated.release().await;
+        } else {
+            for reservation in wave_success {
+                std::thread::spawn(move || {
+                    drop(reservation);
+                })
+                .join()
+                .expect("drop-heavy burst cleanup thread must not panic");
+            }
+            let rotated = base.pop().expect("base rotation reservation must exist");
+            std::thread::spawn(move || {
+                drop(rotated);
+            })
+            .join()
+            .expect("drop-heavy base cleanup thread must not panic");
+        }
+
+        let replacement = RunningClientHandler::acquire_user_connection_reservation_static(
+            user,
+            &config,
+            stats.clone(),
+            base_peer,
+            ip_tracker.clone(),
+        )
+        .await
+        .expect("base replacement reservation must succeed after each round");
+        base.push(replacement);
+
+        tokio::time::timeout(Duration::from_secs(1), async {
+            loop {
+                if stats.get_user_curr_connects(user) == 7
+                    && ip_tracker.get_active_ip_count(user).await <= 1
+                {
+                    break;
+                }
+                tokio::time::sleep(Duration::from_millis(5)).await;
+            }
+        })
+        .await
+        .expect("round cleanup must converge to steady base occupancy");
+    }
+
+    for reservation in base {
+        reservation.release().await;
+    }
+    wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await;
+}
diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs
index edc9598..1c2c648 100644
--- a/src/transport/middle_proxy/health.rs
+++ b/src/transport/middle_proxy/health.rs
@@ -186,9 +186,7 @@ pub(super) async fn reap_draining_writers(
         }
     }
 
-    let mut active_draining_writer_ids = HashSet::with_capacity(draining_writers.len());
     for writer in draining_writers {
-        active_draining_writer_ids.insert(writer.id);
         if drain_ttl_secs > 0
             && writer.draining_started_at_epoch_secs != 0
             && now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs
@@ -214,12 +212,9 @@
         {
             warn!(writer_id = writer.id, "Drain timeout, force-closing");
             force_close_writer_ids.push(writer.id);
-            active_draining_writer_ids.remove(&writer.id);
         }
     }
 
-    warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id));
-
     let close_budget = health_drain_close_budget();
     let requested_force_close = force_close_writer_ids.len();
    let requested_empty_close = empty_writer_ids.len();
@@ -257,6 +252,18 @@
             "ME draining close backlog deferred to next health cycle"
         );
     }
+
+    // Keep warn cooldown state for draining writers still present in the pool;
+    // drop state only once a writer is actually removed.
+    let active_draining_writer_ids = {
+        let writers = pool.writers.read().await;
+        writers
+            .iter()
+            .filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed))
+            .map(|writer| writer.id)
+            .collect::<HashSet<_>>()
+    };
+    warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id));
 }
 
 pub(super) fn health_drain_close_budget() -> usize {
diff --git a/src/transport/middle_proxy/health_adversarial_tests.rs b/src/transport/middle_proxy/health_adversarial_tests.rs
index 675005a..cd06fdf 100644
--- a/src/transport/middle_proxy/health_adversarial_tests.rs
+++ b/src/transport/middle_proxy/health_adversarial_tests.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::collections::HashSet;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
@@ -181,6 +182,40 @@ async fn sorted_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
     ids
 }
 
+fn lcg_next(state: &mut u64) -> u64 {
+    *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
+    *state
+}
+
+async fn draining_writer_ids(pool: &Arc<MePool>) -> HashSet<u64> {
+    pool.writers
+        .read()
+        .await
+        .iter()
+        .filter(|writer| writer.draining.load(Ordering::Relaxed))
+        .map(|writer| writer.id)
+        .collect::<HashSet<u64>>()
+}
+
+async fn set_writer_runtime_state(
+    pool: &Arc<MePool>,
+    writer_id: u64,
+    draining: bool,
+    drain_started_at_epoch_secs: u64,
+    drain_deadline_epoch_secs: u64,
+) {
+    let writers = pool.writers.read().await;
+    if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
+        writer.draining.store(draining, Ordering::Relaxed);
+        writer
+            .draining_started_at_epoch_secs
+            .store(drain_started_at_epoch_secs, Ordering::Relaxed);
+        writer
+            .drain_deadline_epoch_secs
+            .store(drain_deadline_epoch_secs, Ordering::Relaxed);
+    }
+}
+
 #[tokio::test]
 async fn reap_draining_writers_clears_warn_state_when_pool_empty() {
     let (pool, _rng) = make_pool(128, 1, 1).await;
@@ -430,6 +465,149 @@ async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() {
     assert!(writer_count(&pool).await <= threshold as usize);
 }
 
+#[tokio::test]
+async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() {
+    let threshold = 9u64;
+    let (pool, _rng) = make_pool(threshold, 1, 1).await;
+    let mut warn_next_allowed = HashMap::new();
+    let mut seed = 0x9E37_79B9_7F4A_7C15u64;
+    let mut next_writer_id = 20_000u64;
+    let now_epoch_secs = MePool::now_epoch_secs();
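+
+    // Seed writers across every bound/unbound and expired/no-deadline combination
+    // before the churn rounds begin.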
+    for writer_id in 1..=72u64 {
+        let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 };
+        let deadline = if writer_id % 5 == 0 {
+            now_epoch_secs.saturating_sub(1)
+        } else {
+            0
+        };
+        insert_draining_writer(
+            &pool,
+            writer_id,
+            now_epoch_secs.saturating_sub(500).saturating_add(writer_id),
+            bound_clients,
+            deadline,
+        )
+        .await;
+    }
+
+    for _round in 0..90 {
+        reap_draining_writers(&pool, &mut warn_next_allowed).await;
+
+        let draining_ids = draining_writer_ids(&pool).await;
+        assert!(
+            warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
+            "warn-state keys must always be a subset of live draining writers"
+        );
+
+        let writer_ids = sorted_writer_ids(&pool).await;
+        if writer_ids.is_empty() {
+            continue;
+        }
+
+        let remove_n = (lcg_next(&mut seed) % 3) as usize;
+        for writer_id in writer_ids.iter().copied().take(remove_n) {
+            let _ = pool.remove_writer_and_close_clients(writer_id).await;
+        }
+
+        let survivors = sorted_writer_ids(&pool).await;
+        if !survivors.is_empty() {
+            let idx = (lcg_next(&mut seed) as usize) % survivors.len();
+            let target = survivors[idx];
+            set_writer_runtime_state(&pool, target, false, 0, 0).await;
+        }
+
+        let survivors = sorted_writer_ids(&pool).await;
+        if survivors.len() > 1 {
+            let idx = (lcg_next(&mut seed) as usize) % survivors.len();
+            let target = survivors[idx];
+            let expired_deadline = if lcg_next(&mut seed) & 1 == 0 {
+                now_epoch_secs.saturating_sub(1)
+            } else {
+                0
+            };
+            set_writer_runtime_state(
+                &pool,
+                target,
+                true,
+                now_epoch_secs.saturating_sub(120),
+                expired_deadline,
+            )
+            .await;
+        }
+
+        let inject_n = (lcg_next(&mut seed) % 4) as usize;
+        for _ in 0..inject_n {
+            let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 };
+            let deadline = if lcg_next(&mut seed) & 1 == 0 {
+                now_epoch_secs.saturating_sub(1)
+            } else {
+                0
+            };
+            insert_draining_writer(
+                &pool,
+                next_writer_id,
+                now_epoch_secs.saturating_sub(240),
+                bound_clients,
+                deadline,
+            )
+            .await;
+            next_writer_id = next_writer_id.saturating_add(1);
+        }
+    }
+
+    for _ in 0..64 {
+        reap_draining_writers(&pool, &mut warn_next_allowed).await;
+        if writer_count(&pool).await <= threshold as usize {
+            break;
+        }
+    }
+
+    assert!(writer_count(&pool).await <= threshold as usize);
+    let draining_ids = draining_writer_ids(&pool).await;
+    assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id)));
+}
+
+#[tokio::test]
+async fn reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() {
+    let (pool, _rng) = make_pool(64, 1, 1).await;
+    let now_epoch_secs = MePool::now_epoch_secs();
+
+    for writer_id in 1..=24u64 {
+        insert_draining_writer(
+            &pool,
+            writer_id,
+            now_epoch_secs.saturating_sub(240),
+            1,
+            0,
+        )
+        .await;
+    }
+
+    let mut warn_next_allowed = HashMap::new();
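+    // Flip draining membership every round; retained warn-state keys must track
+    // the live draining set, never a stale snapshot.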
+    for round in 0..48u64 {
+        for writer_id in 1..=24u64 {
+            let draining = (writer_id + round) % 3 != 0;
+            set_writer_runtime_state(
+                &pool,
+                writer_id,
+                draining,
+                now_epoch_secs.saturating_sub(120),
+                0,
+            )
+            .await;
+        }
+
+        reap_draining_writers(&pool, &mut warn_next_allowed).await;
+
+        let draining_ids = draining_writer_ids(&pool).await;
+        assert!(
+            warn_next_allowed.keys().all(|id| draining_ids.contains(id)),
+            "warn-state map must not retain entries for writers outside draining set"
+        );
+    }
+}
+
 #[test]
 fn health_drain_close_budget_is_within_expected_bounds() {
     let budget = health_drain_close_budget();
diff --git a/src/transport/middle_proxy/health_regression_tests.rs b/src/transport/middle_proxy/health_regression_tests.rs
index 05a8e6a..fe73670 100644
--- a/src/transport/middle_proxy/health_regression_tests.rs
+++ b/src/transport/middle_proxy/health_regression_tests.rs
@@ -168,6 +168,21 @@ async fn current_writer_ids(pool: &Arc<MePool>) -> Vec<u64> {
     writer_ids
 }
 
+async fn writer_exists(pool: &Arc<MePool>, writer_id: u64) -> bool {
+    pool.writers
+        .read()
+        .await
+        .iter()
+        .any(|writer| writer.id == writer_id)
+}
+
+async fn set_writer_draining(pool: &Arc<MePool>, writer_id: u64, draining: bool) {
+    let writers = pool.writers.read().await;
+    if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) {
+        writer.draining.store(draining, Ordering::Relaxed);
+    }
+}
+
 #[tokio::test]
 async fn reap_draining_writers_drops_warn_state_for_removed_writer() {
     let pool = make_pool(128).await;
@@ -257,6 +272,123 @@ async fn reap_draining_writers_limits_closes_per_health_tick() {
     assert_eq!(pool.writers.read().await.len(), writer_total - close_budget);
 }
 
+#[tokio::test]
+async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() {
+    let pool = make_pool(0).await;
+    let now_epoch_secs = MePool::now_epoch_secs();
+    let close_budget = health_drain_close_budget();
+    let writer_total = close_budget.saturating_add(5);
+    for writer_id in 1..=writer_total as u64 {
+        insert_draining_writer(
+            &pool,
+            writer_id,
+            now_epoch_secs.saturating_sub(60),
+            1,
+            now_epoch_secs.saturating_sub(1),
+        )
+        .await;
+    }
+    let target_writer_id = writer_total as u64;
+    let mut warn_next_allowed = HashMap::new();
+    warn_next_allowed.insert(
+        target_writer_id,
+        Instant::now() + Duration::from_secs(300),
+    );
+
+    reap_draining_writers(&pool, &mut warn_next_allowed).await;
+
+    assert!(writer_exists(&pool, target_writer_id).await);
+    assert!(warn_next_allowed.contains_key(&target_writer_id));
+}
+
+#[tokio::test]
+async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() {
+    let pool = make_pool(1).await;
+    let now_epoch_secs = MePool::now_epoch_secs();
+    let close_budget = health_drain_close_budget();
+    let writer_total = close_budget.saturating_add(6);
+    for writer_id in 1..=writer_total as u64 {
+        insert_draining_writer(
+            &pool,
+            writer_id,
+            now_epoch_secs.saturating_sub(300).saturating_add(writer_id),
+            1,
+            0,
+        )
+        .await;
+    }
+    let target_writer_id = writer_total.saturating_sub(1) as u64;
+    let mut warn_next_allowed = HashMap::new();
+    warn_next_allowed.insert(
+        target_writer_id,
+        Instant::now() + Duration::from_secs(300),
+    );
+
+    reap_draining_writers(&pool, &mut warn_next_allowed).await;
+
+    assert!(writer_exists(&pool, target_writer_id).await);
+    assert!(warn_next_allowed.contains_key(&target_writer_id));
+}
+
+#[tokio::test]
+async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() {
+    let pool = make_pool(128).await;
+    let now_epoch_secs = MePool::now_epoch_secs();
+    insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await;
+
+    let mut warn_next_allowed = HashMap::new();
+    warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300));
+
+    set_writer_draining(&pool, 71, false).await;
+    reap_draining_writers(&pool, &mut warn_next_allowed).await;
+
+    assert!(writer_exists(&pool, 71).await);
+    assert!(
+        !warn_next_allowed.contains_key(&71),
+        "warn cooldown state must be dropped after writer leaves draining state"
+    );
+}
+
+#[tokio::test]
+async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() {
+    let pool = make_pool(0).await;
+    let now_epoch_secs = MePool::now_epoch_secs();
+    let close_budget = health_drain_close_budget();
+    let writer_total = close_budget.saturating_mul(2).saturating_add(1);
+    for writer_id in 1..=writer_total as u64 {
+        insert_draining_writer(
+            &pool,
+            writer_id,
+            now_epoch_secs.saturating_sub(120),
+            1,
+            now_epoch_secs.saturating_sub(1),
+        )
+        .await;
+    }
+
+    let tail_writer_id = writer_total as u64;
+    let mut warn_next_allowed = HashMap::new();
+    warn_next_allowed.insert(
+        tail_writer_id,
+        Instant::now() + Duration::from_secs(300),
+    );
+
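+    // Three ticks: two budget-limited passes defer the tail writer, the third
+    // finally removes it and clears its warn state.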
+    reap_draining_writers(&pool, &mut warn_next_allowed).await;
+    assert!(writer_exists(&pool, tail_writer_id).await);
+    assert!(warn_next_allowed.contains_key(&tail_writer_id));
+
+    reap_draining_writers(&pool, &mut warn_next_allowed).await;
+    assert!(writer_exists(&pool, tail_writer_id).await);
+    assert!(warn_next_allowed.contains_key(&tail_writer_id));
+
+    reap_draining_writers(&pool, &mut warn_next_allowed).await;
+    assert!(!writer_exists(&pool, tail_writer_id).await);
+    assert!(
+        !warn_next_allowed.contains_key(&tail_writer_id),
+        "warn cooldown state must clear once writer is actually removed"
+    );
+}
+
 #[tokio::test]
 async fn reap_draining_writers_backlog_drains_across_ticks() {
     let pool = make_pool(128).await;