mirror of https://github.com/telemt/telemt.git
feat(proxy): implement auth probe eviction logic and corresponding tests
This commit is contained in:
parent
b2e15327fe
commit
0c6bb3a641
|
|
@ -7,6 +7,8 @@ use std::collections::HashSet;
|
|||
use std::net::{IpAddr, Ipv6Addr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::time::{Duration, Instant};
|
||||
use dashmap::DashMap;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
|
|
@ -84,6 +86,13 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
|
|||
now.duration_since(state.last_seen) > retention
|
||||
}
|
||||
|
||||
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
peer_ip.hash(&mut hasher);
|
||||
now.hash(&mut hasher);
|
||||
hasher.finish() as usize
|
||||
}
|
||||
|
||||
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||
let state = auth_probe_state_map();
|
||||
|
|
@ -126,11 +135,9 @@ fn auth_probe_record_failure_with_state(
|
|||
|
||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
let mut stale_keys = Vec::new();
|
||||
let mut eviction_candidate = None;
|
||||
let mut eviction_candidates = Vec::new();
|
||||
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
||||
if eviction_candidate.is_none() {
|
||||
eviction_candidate = Some(*entry.key());
|
||||
}
|
||||
eviction_candidates.push(*entry.key());
|
||||
if auth_probe_state_expired(entry.value(), now) {
|
||||
stale_keys.push(*entry.key());
|
||||
}
|
||||
|
|
@ -139,9 +146,11 @@ fn auth_probe_record_failure_with_state(
|
|||
state.remove(&stale_key);
|
||||
}
|
||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
let Some(evict_key) = eviction_candidate else {
|
||||
if eviction_candidates.is_empty() {
|
||||
return;
|
||||
};
|
||||
}
|
||||
let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len();
|
||||
let evict_key = eviction_candidates[idx];
|
||||
state.remove(&evict_key);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -980,3 +980,17 @@ fn auth_probe_success_clears_whole_ipv6_prefix_bucket() {
|
|||
"success from the same /64 must clear the shared bucket"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_probe_eviction_offset_varies_with_input() {
|
||||
let now = Instant::now();
|
||||
let ip1 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10));
|
||||
let ip2 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 11));
|
||||
|
||||
let a = auth_probe_eviction_offset(ip1, now);
|
||||
let b = auth_probe_eviction_offset(ip1, now);
|
||||
let c = auth_probe_eviction_offset(ip2, now);
|
||||
|
||||
assert_eq!(a, b, "same input must yield deterministic offset");
|
||||
assert_ne!(a, c, "different peer IPs should not collapse to one offset");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -236,17 +236,16 @@ where
|
|||
return;
|
||||
}
|
||||
|
||||
let c2m = tokio::spawn(async move {
|
||||
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
|
||||
let _ = mask_write.shutdown().await;
|
||||
});
|
||||
|
||||
let m2c = tokio::spawn(async move {
|
||||
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
|
||||
let _ = writer.shutdown().await;
|
||||
});
|
||||
|
||||
let _ = tokio::join!(c2m, m2c);
|
||||
let _ = tokio::join!(
|
||||
async {
|
||||
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
|
||||
let _ = mask_write.shutdown().await;
|
||||
},
|
||||
async {
|
||||
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
|
||||
let _ = writer.shutdown().await;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
/// Just consume all data from client without responding
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use super::*;
|
||||
use crate::config::ProxyConfig;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
|
||||
|
|
@ -542,6 +544,54 @@ impl tokio::io::AsyncWrite for PendingWriter {
|
|||
}
|
||||
}
|
||||
|
||||
struct DropTrackedPendingReader {
|
||||
dropped: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncRead for DropTrackedPendingReader {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DropTrackedPendingReader {
|
||||
fn drop(&mut self) {
|
||||
self.dropped.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
struct DropTrackedPendingWriter {
|
||||
dropped: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncWrite for DropTrackedPendingWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DropTrackedPendingWriter {
|
||||
fn drop(&mut self) {
|
||||
self.dropped.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn proxy_header_write_timeout_returns_false() {
|
||||
let mut writer = PendingWriter;
|
||||
|
|
@ -645,3 +695,37 @@ async fn relay_to_mask_preserves_backend_response_after_client_half_close() {
|
|||
timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap();
|
||||
timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
|
||||
let reader_dropped = Arc::new(AtomicBool::new(false));
|
||||
let writer_dropped = Arc::new(AtomicBool::new(false));
|
||||
let mask_reader_dropped = Arc::new(AtomicBool::new(false));
|
||||
let mask_writer_dropped = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let reader = DropTrackedPendingReader {
|
||||
dropped: reader_dropped.clone(),
|
||||
};
|
||||
let writer = DropTrackedPendingWriter {
|
||||
dropped: writer_dropped.clone(),
|
||||
};
|
||||
let mask_read = DropTrackedPendingReader {
|
||||
dropped: mask_reader_dropped.clone(),
|
||||
};
|
||||
let mask_write = DropTrackedPendingWriter {
|
||||
dropped: mask_writer_dropped.clone(),
|
||||
};
|
||||
|
||||
let timed = timeout(
|
||||
Duration::from_millis(40),
|
||||
relay_to_mask(reader, writer, mask_read, mask_write, b""),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(timed.is_err(), "stalled relay must be bounded by timeout");
|
||||
|
||||
assert!(reader_dropped.load(Ordering::SeqCst));
|
||||
assert!(writer_dropped.load(Ordering::SeqCst));
|
||||
assert!(mask_reader_dropped.load(Ordering::SeqCst));
|
||||
assert!(mask_writer_dropped.load(Ordering::SeqCst));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -807,8 +807,8 @@ where
|
|||
payload.truncate(secure_payload_len);
|
||||
}
|
||||
*frame_counter += 1;
|
||||
let payload = payload.take().freeze();
|
||||
return Ok(Some((payload, quickack)));
|
||||
let payload_bytes = Bytes::copy_from_slice(&payload[..]);
|
||||
return Ok(Some((payload_bytes, quickack)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -429,3 +429,50 @@ async fn read_client_payload_abridged_extended_len_sets_quickack() {
|
|||
assert_eq!(frame.len(), payload_len);
|
||||
assert_eq!(frame_counter, 1, "one abridged frame must be counted");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_client_payload_returns_buffer_to_pool_after_emit() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("middle relay test lock must be available");
|
||||
|
||||
let pool = Arc::new(BufferPool::with_config(64, 8));
|
||||
pool.preallocate(1);
|
||||
assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer");
|
||||
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let stats = Stats::new();
|
||||
let forensics = make_forensics_state();
|
||||
let mut frame_counter = 0;
|
||||
|
||||
// Force growth beyond default pool buffer size to catch ownership-take regressions.
|
||||
let payload_len = 257usize;
|
||||
let mut plaintext = Vec::with_capacity(4 + payload_len);
|
||||
plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
|
||||
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13)));
|
||||
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
writer.write_all(&encrypted).await.unwrap();
|
||||
|
||||
let _ = read_client_payload(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Intermediate,
|
||||
payload_len + 8,
|
||||
TokioDuration::from_secs(1),
|
||||
&pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
)
|
||||
.await
|
||||
.expect("payload read must succeed")
|
||||
.expect("frame must be present");
|
||||
|
||||
assert_eq!(frame_counter, 1);
|
||||
let pool_stats = pool.stats();
|
||||
assert!(
|
||||
pool_stats.pooled >= 1,
|
||||
"emitted payload buffer must be returned to pool to avoid pool drain"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue