mirror of https://github.com/telemt/telemt.git
feat(proxy): implement auth probe eviction logic and corresponding tests
This commit is contained in:
parent
b2e15327fe
commit
0c6bb3a641
|
|
@ -7,6 +7,8 @@ use std::collections::HashSet;
|
||||||
use std::net::{IpAddr, Ipv6Addr};
|
use std::net::{IpAddr, Ipv6Addr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
@ -84,6 +86,13 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
|
||||||
now.duration_since(state.last_seen) > retention
|
now.duration_since(state.last_seen) > retention
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
peer_ip.hash(&mut hasher);
|
||||||
|
now.hash(&mut hasher);
|
||||||
|
hasher.finish() as usize
|
||||||
|
}
|
||||||
|
|
||||||
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
let state = auth_probe_state_map();
|
let state = auth_probe_state_map();
|
||||||
|
|
@ -126,11 +135,9 @@ fn auth_probe_record_failure_with_state(
|
||||||
|
|
||||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
let mut stale_keys = Vec::new();
|
let mut stale_keys = Vec::new();
|
||||||
let mut eviction_candidate = None;
|
let mut eviction_candidates = Vec::new();
|
||||||
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
||||||
if eviction_candidate.is_none() {
|
eviction_candidates.push(*entry.key());
|
||||||
eviction_candidate = Some(*entry.key());
|
|
||||||
}
|
|
||||||
if auth_probe_state_expired(entry.value(), now) {
|
if auth_probe_state_expired(entry.value(), now) {
|
||||||
stale_keys.push(*entry.key());
|
stale_keys.push(*entry.key());
|
||||||
}
|
}
|
||||||
|
|
@ -139,9 +146,11 @@ fn auth_probe_record_failure_with_state(
|
||||||
state.remove(&stale_key);
|
state.remove(&stale_key);
|
||||||
}
|
}
|
||||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
let Some(evict_key) = eviction_candidate else {
|
if eviction_candidates.is_empty() {
|
||||||
return;
|
return;
|
||||||
};
|
}
|
||||||
|
let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len();
|
||||||
|
let evict_key = eviction_candidates[idx];
|
||||||
state.remove(&evict_key);
|
state.remove(&evict_key);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -980,3 +980,17 @@ fn auth_probe_success_clears_whole_ipv6_prefix_bucket() {
|
||||||
"success from the same /64 must clear the shared bucket"
|
"success from the same /64 must clear the shared bucket"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_probe_eviction_offset_varies_with_input() {
|
||||||
|
let now = Instant::now();
|
||||||
|
let ip1 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10));
|
||||||
|
let ip2 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 11));
|
||||||
|
|
||||||
|
let a = auth_probe_eviction_offset(ip1, now);
|
||||||
|
let b = auth_probe_eviction_offset(ip1, now);
|
||||||
|
let c = auth_probe_eviction_offset(ip2, now);
|
||||||
|
|
||||||
|
assert_eq!(a, b, "same input must yield deterministic offset");
|
||||||
|
assert_ne!(a, c, "different peer IPs should not collapse to one offset");
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -236,17 +236,16 @@ where
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let c2m = tokio::spawn(async move {
|
let _ = tokio::join!(
|
||||||
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
|
async {
|
||||||
let _ = mask_write.shutdown().await;
|
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
|
||||||
});
|
let _ = mask_write.shutdown().await;
|
||||||
|
},
|
||||||
let m2c = tokio::spawn(async move {
|
async {
|
||||||
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
|
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
|
||||||
let _ = writer.shutdown().await;
|
let _ = writer.shutdown().await;
|
||||||
});
|
}
|
||||||
|
);
|
||||||
let _ = tokio::join!(c2m, m2c);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Just consume all data from client without responding
|
/// Just consume all data from client without responding
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
|
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
|
||||||
|
|
@ -542,6 +544,54 @@ impl tokio::io::AsyncWrite for PendingWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DropTrackedPendingReader {
|
||||||
|
dropped: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncRead for DropTrackedPendingReader {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &mut tokio::io::ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DropTrackedPendingReader {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.dropped.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DropTrackedPendingWriter {
|
||||||
|
dropped: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncWrite for DropTrackedPendingWriter {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &[u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DropTrackedPendingWriter {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.dropped.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn proxy_header_write_timeout_returns_false() {
|
async fn proxy_header_write_timeout_returns_false() {
|
||||||
let mut writer = PendingWriter;
|
let mut writer = PendingWriter;
|
||||||
|
|
@ -645,3 +695,37 @@ async fn relay_to_mask_preserves_backend_response_after_client_half_close() {
|
||||||
timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap();
|
timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap();
|
||||||
timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap();
|
timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
|
||||||
|
let reader_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let writer_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let mask_reader_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let mask_writer_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
|
let reader = DropTrackedPendingReader {
|
||||||
|
dropped: reader_dropped.clone(),
|
||||||
|
};
|
||||||
|
let writer = DropTrackedPendingWriter {
|
||||||
|
dropped: writer_dropped.clone(),
|
||||||
|
};
|
||||||
|
let mask_read = DropTrackedPendingReader {
|
||||||
|
dropped: mask_reader_dropped.clone(),
|
||||||
|
};
|
||||||
|
let mask_write = DropTrackedPendingWriter {
|
||||||
|
dropped: mask_writer_dropped.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let timed = timeout(
|
||||||
|
Duration::from_millis(40),
|
||||||
|
relay_to_mask(reader, writer, mask_read, mask_write, b""),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(timed.is_err(), "stalled relay must be bounded by timeout");
|
||||||
|
|
||||||
|
assert!(reader_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(writer_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(mask_reader_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(mask_writer_dropped.load(Ordering::SeqCst));
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -807,8 +807,8 @@ where
|
||||||
payload.truncate(secure_payload_len);
|
payload.truncate(secure_payload_len);
|
||||||
}
|
}
|
||||||
*frame_counter += 1;
|
*frame_counter += 1;
|
||||||
let payload = payload.take().freeze();
|
let payload_bytes = Bytes::copy_from_slice(&payload[..]);
|
||||||
return Ok(Some((payload, quickack)));
|
return Ok(Some((payload_bytes, quickack)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -429,3 +429,50 @@ async fn read_client_payload_abridged_extended_len_sets_quickack() {
|
||||||
assert_eq!(frame.len(), payload_len);
|
assert_eq!(frame.len(), payload_len);
|
||||||
assert_eq!(frame_counter, 1, "one abridged frame must be counted");
|
assert_eq!(frame_counter, 1, "one abridged frame must be counted");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_returns_buffer_to_pool_after_emit() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 8));
|
||||||
|
pool.preallocate(1);
|
||||||
|
assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(4096);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
// Force growth beyond default pool buffer size to catch ownership-take regressions.
|
||||||
|
let payload_len = 257usize;
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + payload_len);
|
||||||
|
plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
|
||||||
|
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13)));
|
||||||
|
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let _ = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
payload_len + 8,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
assert_eq!(frame_counter, 1);
|
||||||
|
let pool_stats = pool.stats();
|
||||||
|
assert!(
|
||||||
|
pool_stats.pooled >= 1,
|
||||||
|
"emitted payload buffer must be returned to pool to avoid pool drain"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue