feat(proxy): implement auth probe eviction logic and corresponding tests

This commit is contained in:
David Osipov 2026-03-17 15:43:07 +04:00
parent b2e15327fe
commit 0c6bb3a641
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
6 changed files with 172 additions and 19 deletions

View File

@ -7,6 +7,8 @@ use std::collections::HashSet;
use std::net::{IpAddr, Ipv6Addr};
use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@ -84,6 +86,13 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
now.duration_since(state.last_seen) > retention
}
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
let mut hasher = DefaultHasher::new();
peer_ip.hash(&mut hasher);
now.hash(&mut hasher);
hasher.finish() as usize
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
@ -126,11 +135,9 @@ fn auth_probe_record_failure_with_state(
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
let mut stale_keys = Vec::new();
let mut eviction_candidate = None;
let mut eviction_candidates = Vec::new();
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
if eviction_candidate.is_none() {
eviction_candidate = Some(*entry.key());
}
eviction_candidates.push(*entry.key());
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(*entry.key());
}
@ -139,9 +146,11 @@ fn auth_probe_record_failure_with_state(
state.remove(&stale_key);
}
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
let Some(evict_key) = eviction_candidate else {
if eviction_candidates.is_empty() {
return;
};
}
let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len();
let evict_key = eviction_candidates[idx];
state.remove(&evict_key);
}
}

View File

@ -980,3 +980,17 @@ fn auth_probe_success_clears_whole_ipv6_prefix_bucket() {
"success from the same /64 must clear the shared bucket"
);
}
#[test]
fn auth_probe_eviction_offset_varies_with_input() {
let now = Instant::now();
let ip1 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10));
let ip2 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 11));
let a = auth_probe_eviction_offset(ip1, now);
let b = auth_probe_eviction_offset(ip1, now);
let c = auth_probe_eviction_offset(ip2, now);
assert_eq!(a, b, "same input must yield deterministic offset");
assert_ne!(a, c, "different peer IPs should not collapse to one offset");
}

View File

@ -236,17 +236,16 @@ where
return;
}
let c2m = tokio::spawn(async move {
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
let _ = mask_write.shutdown().await;
});
let m2c = tokio::spawn(async move {
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
let _ = writer.shutdown().await;
});
let _ = tokio::join!(c2m, m2c);
let _ = tokio::join!(
async {
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
let _ = mask_write.shutdown().await;
},
async {
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
let _ = writer.shutdown().await;
}
);
}
/// Just consume all data from client without responding

View File

@ -1,5 +1,7 @@
use super::*;
use crate::config::ProxyConfig;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
@ -542,6 +544,54 @@ impl tokio::io::AsyncWrite for PendingWriter {
}
}
struct DropTrackedPendingReader {
dropped: Arc<AtomicBool>,
}
impl tokio::io::AsyncRead for DropTrackedPendingReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}
}
impl Drop for DropTrackedPendingReader {
fn drop(&mut self) {
self.dropped.store(true, Ordering::SeqCst);
}
}
struct DropTrackedPendingWriter {
dropped: Arc<AtomicBool>,
}
impl tokio::io::AsyncWrite for DropTrackedPendingWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl Drop for DropTrackedPendingWriter {
fn drop(&mut self) {
self.dropped.store(true, Ordering::SeqCst);
}
}
#[tokio::test]
async fn proxy_header_write_timeout_returns_false() {
let mut writer = PendingWriter;
@ -645,3 +695,37 @@ async fn relay_to_mask_preserves_backend_response_after_client_half_close() {
timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap();
timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap();
}
#[tokio::test]
async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
let reader_dropped = Arc::new(AtomicBool::new(false));
let writer_dropped = Arc::new(AtomicBool::new(false));
let mask_reader_dropped = Arc::new(AtomicBool::new(false));
let mask_writer_dropped = Arc::new(AtomicBool::new(false));
let reader = DropTrackedPendingReader {
dropped: reader_dropped.clone(),
};
let writer = DropTrackedPendingWriter {
dropped: writer_dropped.clone(),
};
let mask_read = DropTrackedPendingReader {
dropped: mask_reader_dropped.clone(),
};
let mask_write = DropTrackedPendingWriter {
dropped: mask_writer_dropped.clone(),
};
let timed = timeout(
Duration::from_millis(40),
relay_to_mask(reader, writer, mask_read, mask_write, b""),
)
.await;
assert!(timed.is_err(), "stalled relay must be bounded by timeout");
assert!(reader_dropped.load(Ordering::SeqCst));
assert!(writer_dropped.load(Ordering::SeqCst));
assert!(mask_reader_dropped.load(Ordering::SeqCst));
assert!(mask_writer_dropped.load(Ordering::SeqCst));
}

View File

@ -807,8 +807,8 @@ where
payload.truncate(secure_payload_len);
}
*frame_counter += 1;
let payload = payload.take().freeze();
return Ok(Some((payload, quickack)));
let payload_bytes = Bytes::copy_from_slice(&payload[..]);
return Ok(Some((payload_bytes, quickack)));
}
}

View File

@ -429,3 +429,50 @@ async fn read_client_payload_abridged_extended_len_sets_quickack() {
assert_eq!(frame.len(), payload_len);
assert_eq!(frame_counter, 1, "one abridged frame must be counted");
}
#[tokio::test]
async fn read_client_payload_returns_buffer_to_pool_after_emit() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("middle relay test lock must be available");
let pool = Arc::new(BufferPool::with_config(64, 8));
pool.preallocate(1);
assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer");
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let stats = Stats::new();
let forensics = make_forensics_state();
let mut frame_counter = 0;
// Force growth beyond default pool buffer size to catch ownership-take regressions.
let payload_len = 257usize;
let mut plaintext = Vec::with_capacity(4 + payload_len);
plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13)));
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let _ = read_client_payload(
&mut crypto_reader,
ProtoTag::Intermediate,
payload_len + 8,
TokioDuration::from_secs(1),
&pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("payload read must succeed")
.expect("frame must be present");
assert_eq!(frame_counter, 1);
let pool_stats = pool.stats();
assert!(
pool_stats.pooled >= 1,
"emitted payload buffer must be returned to pool to avoid pool drain"
);
}