From 0c6bb3a6416a29b9d513b83069601e04dec0e9d2 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 15:43:07 +0400 Subject: [PATCH] feat(proxy): implement auth probe eviction logic and corresponding tests --- src/proxy/handshake.rs | 21 ++++-- src/proxy/handshake_security_tests.rs | 14 ++++ src/proxy/masking.rs | 21 +++--- src/proxy/masking_security_tests.rs | 84 ++++++++++++++++++++++++ src/proxy/middle_relay.rs | 4 +- src/proxy/middle_relay_security_tests.rs | 47 +++++++++++++ 6 files changed, 172 insertions(+), 19 deletions(-) diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index c837b5b..142495c 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -7,6 +7,8 @@ use std::collections::HashSet; use std::net::{IpAddr, Ipv6Addr}; use std::sync::Arc; use std::sync::{Mutex, OnceLock}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::time::{Duration, Instant}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; @@ -84,6 +86,13 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool { now.duration_since(state.last_seen) > retention } +fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { + let mut hasher = DefaultHasher::new(); + peer_ip.hash(&mut hasher); + now.hash(&mut hasher); + hasher.finish() as usize +} + fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); @@ -126,11 +135,9 @@ fn auth_probe_record_failure_with_state( if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { let mut stale_keys = Vec::new(); - let mut eviction_candidate = None; + let mut eviction_candidates = Vec::new(); for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { - if eviction_candidate.is_none() { - eviction_candidate = Some(*entry.key()); - } + eviction_candidates.push(*entry.key()); if auth_probe_state_expired(entry.value(), now) { stale_keys.push(*entry.key()); } @@ -139,9 +146,11 @@ fn auth_probe_record_failure_with_state( state.remove(&stale_key); } if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - let Some(evict_key) = eviction_candidate else { + if eviction_candidates.is_empty() { return; - }; + } + let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len(); + let evict_key = eviction_candidates[idx]; state.remove(&evict_key); } } diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index c18e520..d8d8d3b 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -980,3 +980,17 @@ fn auth_probe_success_clears_whole_ipv6_prefix_bucket() { "success from the same /64 must clear the shared bucket" ); } + +#[test] +fn auth_probe_eviction_offset_varies_with_input() { + let now = Instant::now(); + let ip1 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10)); + let ip2 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 11)); + + let a = auth_probe_eviction_offset(ip1, now); + let b = auth_probe_eviction_offset(ip1, now); + let c = auth_probe_eviction_offset(ip2, now); + + assert_eq!(a, b, "same input must yield deterministic offset"); + assert_ne!(a, c, "different peer IPs should not collapse to one offset"); +} diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 4cffc37..9a23c5b 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -236,17 +236,16 @@ where return; } - let c2m = tokio::spawn(async move { - let _ = tokio::io::copy(&mut reader, &mut mask_write).await; - let _ = mask_write.shutdown().await; - }); - - let m2c = tokio::spawn(async move { - let _ = tokio::io::copy(&mut mask_read, &mut writer).await; - let _ = writer.shutdown().await; - }); - - let _ = tokio::join!(c2m, m2c); + let _ = tokio::join!( + async { + let _ = tokio::io::copy(&mut reader, &mut mask_write).await; + let _ = mask_write.shutdown().await; + }, + async { + let _ = tokio::io::copy(&mut mask_read, &mut writer).await; + let _ = writer.shutdown().await; + } + ); } /// Just consume all data from client without responding diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index ffbbd0e..25b6a76 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -1,5 +1,7 @@ use super::*; use crate::config::ProxyConfig; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; @@ -542,6 +544,54 @@ impl tokio::io::AsyncWrite for PendingWriter { } } +struct DropTrackedPendingReader { + dropped: Arc, +} + +impl tokio::io::AsyncRead for DropTrackedPendingReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Poll::Pending + } +} + +impl Drop for DropTrackedPendingReader { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + +struct DropTrackedPendingWriter { + dropped: Arc, +} + +impl tokio::io::AsyncWrite for DropTrackedPendingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl Drop for DropTrackedPendingWriter { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + #[tokio::test] async fn proxy_header_write_timeout_returns_false() { let mut writer = PendingWriter; @@ -645,3 +695,37 @@ async fn relay_to_mask_preserves_backend_response_after_client_half_close() { timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap(); timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap(); } + +#[tokio::test] +async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { + let reader_dropped = Arc::new(AtomicBool::new(false)); + let writer_dropped = Arc::new(AtomicBool::new(false)); + let mask_reader_dropped = Arc::new(AtomicBool::new(false)); + let mask_writer_dropped = Arc::new(AtomicBool::new(false)); + + let reader = DropTrackedPendingReader { + dropped: reader_dropped.clone(), + }; + let writer = DropTrackedPendingWriter { + dropped: writer_dropped.clone(), + }; + let mask_read = DropTrackedPendingReader { + dropped: mask_reader_dropped.clone(), + }; + let mask_write = DropTrackedPendingWriter { + dropped: mask_writer_dropped.clone(), + }; + + let timed = timeout( + Duration::from_millis(40), + relay_to_mask(reader, writer, mask_read, mask_write, b""), + ) + .await; + + assert!(timed.is_err(), "stalled relay must be bounded by timeout"); + + assert!(reader_dropped.load(Ordering::SeqCst)); + assert!(writer_dropped.load(Ordering::SeqCst)); + assert!(mask_reader_dropped.load(Ordering::SeqCst)); + assert!(mask_writer_dropped.load(Ordering::SeqCst)); +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 19007d8..091094d 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -807,8 +807,8 @@ where payload.truncate(secure_payload_len); } *frame_counter += 1; - let payload = payload.take().freeze(); - return Ok(Some((payload, quickack))); + let payload_bytes = Bytes::copy_from_slice(&payload[..]); + return Ok(Some((payload_bytes, quickack))); } } diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index 75f2fad..511a853 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -429,3 +429,50 @@ async fn read_client_payload_abridged_extended_len_sets_quickack() { assert_eq!(frame.len(), payload_len); assert_eq!(frame_counter, 1, "one abridged frame must be counted"); } + +#[tokio::test] +async fn read_client_payload_returns_buffer_to_pool_after_emit() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let pool = Arc::new(BufferPool::with_config(64, 8)); + pool.preallocate(1); + assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer"); + + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + // Force growth beyond default pool buffer size to catch ownership-take regressions. + let payload_len = 257usize; + let mut plaintext = Vec::with_capacity(4 + payload_len); + plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); + plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13))); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let _ = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + payload_len + 8, + TokioDuration::from_secs(1), + &pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("payload read must succeed") + .expect("frame must be present"); + + assert_eq!(frame_counter, 1); + let pool_stats = pool.stats(); + assert!( + pool_stats.pooled >= 1, + "emitted payload buffer must be returned to pool to avoid pool drain" + ); +}