feat(proxy): enhance auth probe handling with IPv6 normalization and eviction logic

This commit is contained in:
David Osipov 2026-03-17 15:15:12 +04:00
parent 8821e38013
commit b2e15327fe
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
8 changed files with 608 additions and 77 deletions

View File

@ -381,7 +381,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
let mut msg = handshake.to_vec();
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let mut first_match: Option<TlsValidation> = None;
let mut first_match: Option<(&String, u32)> = None;
for (user, secret) in secrets {
let computed = sha256_hmac(secret, &msg);
@ -421,16 +421,16 @@ fn validate_tls_handshake_at_time_with_boot_cap(
}
if first_match.is_none() {
first_match = Some(TlsValidation {
user: user.clone(),
session_id: session_id.clone(),
digest,
timestamp,
});
first_match = Some((user, timestamp));
}
}
first_match
first_match.map(|(user, timestamp)| TlsValidation {
user: user.clone(),
session_id,
digest,
timestamp,
})
}
fn curve25519_prime() -> BigUint {

View File

@ -9,12 +9,19 @@ use crate::crypto::sha256_hmac;
/// [TLS_DIGEST_POS..+32] : digest = HMAC XOR [0..0 || timestamp_le]
/// [TLS_DIGEST_POS+32] : session_id_len = 32
/// [TLS_DIGEST_POS+33..+65] : session_id filler (0x42)
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
fn make_valid_tls_handshake_with_session_id(
secret: &[u8],
timestamp: u32,
session_id: &[u8],
) -> Vec<u8> {
let session_id_len = session_id.len();
assert!(session_id_len <= u8::MAX as usize);
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id);
// Zero the digest slot before computing HMAC (mirrors what validate does).
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
@ -34,6 +41,10 @@ fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
handshake
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32])
}
// ------------------------------------------------------------------
// Happy-path sanity
// ------------------------------------------------------------------
@ -311,6 +322,20 @@ fn too_short_handshake_rejected_without_panic() {
assert!(validate_tls_handshake(&[], &secrets, true).is_none());
}
/// Every truncated handshake shorter than the fixed header must be
/// rejected cleanly, with no panic from out-of-bounds slicing.
#[test]
fn all_prefix_lengths_below_minimum_rejected_without_panic() {
    let secrets = vec![("u".to_string(), b"s".to_vec())];
    let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
    for len in 0..min_len {
        let handshake = vec![0u8; len];
        let verdict = validate_tls_handshake(&handshake, &secrets, true);
        assert!(
            verdict.is_none(),
            "prefix length {len} below minimum must be rejected"
        );
    }
}
#[test]
fn claimed_session_id_overflows_buffer_rejected() {
let session_id_len: usize = 32;
@ -332,6 +357,30 @@ fn max_session_id_len_255_does_not_panic() {
assert!(validate_tls_handshake(&h, &secrets, true).is_none());
}
// A 1-byte session_id is the smallest non-empty id this layout can carry;
// validation must accept it and return the byte unchanged.
#[test]
fn one_byte_session_id_validates_and_is_preserved() {
let secret = b"sid_len_1_test";
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &[0xAB]);
let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true)
.expect("one-byte session_id handshake must validate");
// The validated result must carry the exact session_id bytes from the wire.
assert_eq!(result.session_id, vec![0xAB]);
}
// 255 is the largest session_id length encodable in the single length byte;
// a correctly signed handshake at that size must be accepted intact.
#[test]
fn max_session_id_len_255_with_valid_digest_is_accepted() {
let secret = b"sid_len_255_test";
let session_id = vec![0xCCu8; 255];
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id);
let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true)
.expect("session_id_len=255 with valid digest must validate");
assert_eq!(result.session_id.len(), 255);
assert_eq!(result.session_id, session_id);
}
// ------------------------------------------------------------------
// Adversarial digest values
// ------------------------------------------------------------------
@ -867,6 +916,23 @@ fn test_parse_tls_record_header() {
assert_eq!(result.1, 16384);
}
/// Record headers carrying version bytes other than the accepted TLS record
/// versions must be rejected by the header parser.
#[test]
fn parse_tls_record_header_rejects_invalid_versions() {
    let bad_headers: [[u8; 5]; 4] = [
        [0x16, 0x03, 0x00, 0x00, 0x10],
        [0x16, 0x02, 0x00, 0x00, 0x10],
        [0x16, 0x03, 0x02, 0x00, 0x10],
        [0x16, 0x04, 0x00, 0x00, 0x10],
    ];
    for header in bad_headers {
        // Bytes 1..3 hold the record-layer version under test.
        let version = [header[1], header[2]];
        assert!(
            parse_tls_record_header(&header).is_none(),
            "invalid TLS record version {:?} must be rejected",
            version
        );
    }
}
#[test]
fn test_gen_fake_x25519_key() {
let rng = crate::crypto::SecureRandom::new();
@ -1168,6 +1234,47 @@ fn extract_sni_rejects_when_extension_block_is_truncated() {
assert!(extract_sni_from_client_hello(&ch).is_none());
}
/// Corrupting the session_id length so it claims more bytes than the
/// ClientHello holds must make SNI extraction fail closed.
#[test]
fn extract_sni_rejects_session_id_len_overflow() {
    let mut client_hello = build_client_hello_with_exts(Vec::new(), "example.com");
    // Offset past record header (5) + handshake header (4) + version (2) + random (32).
    let session_id_len_offset = 5 + 4 + 2 + 32;
    client_hello[session_id_len_offset] = 255;
    assert!(extract_sni_from_client_hello(&client_hello).is_none());
}
// Overwriting the cipher_suites length with 0xFFFF claims more suite bytes
// than the buffer holds; SNI extraction must fail closed.
#[test]
fn extract_sni_rejects_cipher_suites_len_overflow() {
let mut ch = build_client_hello_with_exts(Vec::new(), "example.com");
// Offset past record header (5) + handshake header (4) + version (2) + random (32).
let sid_len_pos = 5 + 4 + 2 + 32;
let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize;
ch[cipher_len_pos] = 0xFF;
ch[cipher_len_pos + 1] = 0xFF;
assert!(extract_sni_from_client_hello(&ch).is_none());
}
// Inflating the compression_methods length past the end of the buffer must
// make SNI extraction fail closed rather than read out of bounds.
#[test]
fn extract_sni_rejects_compression_methods_len_overflow() {
let mut ch = build_client_hello_with_exts(Vec::new(), "example.com");
let sid_len_pos = 5 + 4 + 2 + 32;
let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize;
// Read the declared cipher_suites length to locate the compression block.
let cipher_len = u16::from_be_bytes([ch[cipher_len_pos], ch[cipher_len_pos + 1]]) as usize;
let comp_len_pos = cipher_len_pos + 2 + cipher_len;
ch[comp_len_pos] = 0xFF;
assert!(extract_sni_from_client_hello(&ch).is_none());
}
// ALPN extraction walks the same ClientHello preamble as SNI; a session_id
// length that overflows the buffer must yield an empty list, not a panic.
#[test]
fn extract_alpn_returns_empty_on_session_id_len_overflow() {
let mut alpn_data = Vec::new();
// ALPN extension body: u16 list length, then one length-prefixed entry "h2".
alpn_data.extend_from_slice(&3u16.to_be_bytes());
alpn_data.push(2);
alpn_data.extend_from_slice(b"h2");
let mut ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
let sid_len_pos = 5 + 4 + 2 + 32;
ch[sid_len_pos] = 255;
assert!(extract_alpn_from_client_hello(&ch).is_empty());
}
#[test]
fn extract_alpn_rejects_when_extension_block_is_truncated() {
let mut ext_blob = Vec::new();

View File

@ -4,7 +4,7 @@
use std::net::SocketAddr;
use std::collections::HashSet;
use std::net::IpAddr;
use std::net::{IpAddr, Ipv6Addr};
use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant};
@ -57,6 +57,16 @@ fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
AUTH_PROBE_STATE.get_or_init(DashMap::new)
}
/// Collapse an IPv6 source address to its /64 prefix so one throttle bucket
/// covers a whole delegated prefix; IPv4 addresses pass through unchanged.
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
    let IpAddr::V6(v6) = peer_ip else {
        // IPv4 peers are tracked per individual address.
        return peer_ip;
    };
    // Keep the top four 16-bit segments (the /64 network), zero the host part.
    let seg = v6.segments();
    IpAddr::V6(Ipv6Addr::new(seg[0], seg[1], seg[2], seg[3], 0, 0, 0, 0))
}
fn auth_probe_backoff(fail_streak: u32) -> Duration {
if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS {
return Duration::ZERO;
@ -75,6 +85,7 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
let Some(entry) = state.get(&peer_ip) else {
return false;
@ -88,6 +99,7 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
}
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
auth_probe_record_failure_with_state(state, peer_ip, now);
}
@ -114,7 +126,11 @@ fn auth_probe_record_failure_with_state(
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
let mut stale_keys = Vec::new();
let mut eviction_candidate = None;
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
if eviction_candidate.is_none() {
eviction_candidate = Some(*entry.key());
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(*entry.key());
}
@ -123,23 +139,22 @@ fn auth_probe_record_failure_with_state(
state.remove(&stale_key);
}
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
let Some(evict_key) = eviction_candidate else {
return;
};
state.remove(&evict_key);
}
}
state.insert(peer_ip, AuthProbeState {
fail_streak: 0,
blocked_until: now,
fail_streak: 1,
blocked_until: now + auth_probe_backoff(1),
last_seen: now,
});
if let Some(mut entry) = state.get_mut(&peer_ip) {
entry.fail_streak = 1;
entry.blocked_until = now + auth_probe_backoff(1);
}
}
fn auth_probe_record_success(peer_ip: IpAddr) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
state.remove(&peer_ip);
}
@ -153,6 +168,7 @@ fn clear_auth_probe_state_for_testing() {
#[cfg(test)]
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = AUTH_PROBE_STATE.get()?;
state.get(&peer_ip).map(|entry| entry.fail_streak)
}
@ -177,6 +193,12 @@ fn clear_warned_secrets_for_testing() {
}
}
#[cfg(test)]
fn warned_secrets_test_lock() -> &'static Mutex<()> {
    // Process-wide lock serializing tests that mutate the warned-secrets set.
    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    TEST_LOCK.get_or_init(Mutex::default)
}
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
let key = (name.to_string(), reason.to_string());
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));

View File

@ -84,7 +84,6 @@ fn make_valid_tls_client_hello_with_alpn(
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
clear_auth_probe_state_for_testing();
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
@ -369,6 +368,9 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() {
#[tokio::test]
async fn empty_decoded_secret_is_rejected() {
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
let config = test_config_with_secret_hex("");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
@ -393,6 +395,9 @@ async fn empty_decoded_secret_is_rejected() {
#[tokio::test]
async fn wrong_length_decoded_secret_is_rejected() {
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
let config = test_config_with_secret_hex("aa");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
@ -443,6 +448,12 @@ async fn invalid_mtproto_probe_does_not_pollute_replay_cache() {
#[tokio::test]
async fn mixed_secret_lengths_keep_valid_user_authenticating() {
let _probe_guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
clear_auth_probe_state_for_testing();
let good_secret = [0x22u8; 16];
@ -708,6 +719,9 @@ fn mode_policy_matrix_is_stable_for_all_tag_transport_mode_combinations() {
#[test]
fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() {
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1));
@ -755,8 +769,9 @@ async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() {
}
assert!(
auth_probe_is_throttled_for_testing(peer.ip()),
"invalid probe burst must activate per-IP pre-auth throttle"
auth_probe_fail_streak_for_testing(peer.ip())
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
"invalid probe burst must grow pre-auth failure streak to backoff threshold"
);
}
@ -855,7 +870,7 @@ fn auth_probe_capacity_prunes_stale_entries_for_new_ips() {
}
#[test]
fn auth_probe_capacity_stays_fail_closed_when_map_is_fresh_and_full() {
fn auth_probe_capacity_forces_bounded_eviction_when_map_is_fresh_and_full() {
let state = DashMap::new();
let now = Instant::now();
@ -880,12 +895,88 @@ fn auth_probe_capacity_stays_fail_closed_when_map_is_fresh_and_full() {
auth_probe_record_failure_with_state(&state, newcomer, now);
assert!(
state.get(&newcomer).is_none(),
"when all entries are fresh and full, new probes must not be admitted"
state.get(&newcomer).is_some(),
"when all entries are fresh and full, one bounded eviction must admit a new probe source"
);
assert_eq!(
state.len(),
AUTH_PROBE_TRACK_MAX_ENTRIES,
"auth probe map must stay at the configured cap"
"auth probe map must stay at the configured cap after forced eviction"
);
}
// Two IPv6 sources inside the same /64 must normalize to one map key and
// accumulate a shared failure streak.
#[test]
fn auth_probe_ipv6_is_bucketed_by_prefix_64() {
let state = DashMap::new();
let now = Instant::now();
// Same 2001:db8:abcd:1234::/64 prefix, different host halves.
let ip_a = IpAddr::V6("2001:db8:abcd:1234:1:2:3:4".parse().unwrap());
let ip_b = IpAddr::V6("2001:db8:abcd:1234:ffff:eeee:dddd:cccc".parse().unwrap());
auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now);
auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now);
let normalized = normalize_auth_probe_ip(ip_a);
assert_eq!(
state.len(),
1,
"IPv6 sources in the same /64 must share one pre-auth throttle bucket"
);
assert_eq!(
state.get(&normalized).map(|entry| entry.fail_streak),
Some(2),
"failures from the same /64 must accumulate in one throttle state"
);
}
// Sources in different /64 prefixes must keep independent throttle state.
#[test]
fn auth_probe_ipv6_different_prefixes_use_distinct_buckets() {
let state = DashMap::new();
let now = Instant::now();
// Identical host halves, differing fourth segment => distinct /64 prefixes.
let ip_a = IpAddr::V6("2001:db8:1111:2222:1:2:3:4".parse().unwrap());
let ip_b = IpAddr::V6("2001:db8:1111:3333:1:2:3:4".parse().unwrap());
auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now);
auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now);
assert_eq!(
state.len(),
2,
"different IPv6 /64 prefixes must not share throttle buckets"
);
assert_eq!(
state.get(&normalize_auth_probe_ip(ip_a)).map(|entry| entry.fail_streak),
Some(1)
);
assert_eq!(
state.get(&normalize_auth_probe_ip(ip_b)).map(|entry| entry.fail_streak),
Some(1)
);
}
// A successful auth from any address in a /64 must clear the throttle
// bucket shared by the whole prefix.
#[test]
fn auth_probe_success_clears_whole_ipv6_prefix_bucket() {
// Serialize against other tests that touch the global auth-probe map.
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let now = Instant::now();
// Failure and success come from different hosts in the same /64.
let ip_fail = IpAddr::V6("2001:db8:aaaa:bbbb:1:2:3:4".parse().unwrap());
let ip_success = IpAddr::V6("2001:db8:aaaa:bbbb:ffff:eeee:dddd:cccc".parse().unwrap());
auth_probe_record_failure(ip_fail, now);
assert_eq!(
auth_probe_fail_streak_for_testing(ip_fail),
Some(1),
"precondition: normalized prefix bucket must exist"
);
auth_probe_record_success(ip_success);
assert_eq!(
auth_probe_fail_streak_for_testing(ip_fail),
None,
"success from the same /64 must clear the shared bucket"
);
}

View File

@ -223,10 +223,10 @@ async fn relay_to_mask<R, W, MR, MW>(
initial_data: &[u8],
)
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
MR: AsyncRead + Unpin + Send,
MW: AsyncWrite + Unpin + Send,
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
MR: AsyncRead + Unpin + Send + 'static,
MW: AsyncWrite + Unpin + Send + 'static,
{
// Send initial data to mask host
if mask_write.write_all(initial_data).await.is_err() {
@ -236,39 +236,17 @@ where
return;
}
let mut client_buf = vec![0u8; MASK_BUFFER_SIZE];
let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
tokio::select! {
client_read = reader.read(&mut client_buf) => {
match client_read {
Ok(0) | Err(_) => {
let c2m = tokio::spawn(async move {
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
let _ = mask_write.shutdown().await;
break;
}
Ok(n) => {
if mask_write.write_all(&client_buf[..n]).await.is_err() {
break;
}
}
}
}
mask_read_res = mask_read.read(&mut mask_buf) => {
match mask_read_res {
Ok(0) | Err(_) => {
});
let m2c = tokio::spawn(async move {
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
let _ = writer.shutdown().await;
break;
}
Ok(n) => {
if writer.write_all(&mask_buf[..n]).await.is_err() {
break;
}
}
}
}
}
}
});
let _ = tokio::join!(c2m, m2c);
}
/// Just consume all data from client without responding

View File

@ -6,7 +6,7 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader};
use tokio::net::TcpListener;
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio::time::{timeout, Duration};
use tokio::time::{sleep, timeout, Duration};
#[tokio::test]
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
@ -548,3 +548,100 @@ async fn proxy_header_write_timeout_returns_false() {
let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await;
assert!(!ok, "Proxy header writes that never complete must time out");
}
// A stalled client->backend direction must not block backend->client
// traffic through relay_to_mask.
#[tokio::test]
async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stalls() {
let (mut client_feed_writer, client_feed_reader) = duplex(64);
let (mut client_visible_reader, client_visible_writer) = duplex(64);
let (mut backend_feed_writer, backend_feed_reader) = duplex(64);
// Make client->mask direction immediately active so the c2m path blocks on PendingWriter.
client_feed_writer.write_all(b"X").await.unwrap();
let relay = tokio::spawn(async move {
relay_to_mask(
client_feed_reader,
client_visible_writer,
backend_feed_reader,
PendingWriter,
b"",
)
.await;
});
// Allow relay tasks to start, then emulate mask backend response.
sleep(Duration::from_millis(20)).await;
backend_feed_writer.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
backend_feed_writer.shutdown().await.unwrap();
// 19 bytes: the exact length of the emulated response line above.
let mut observed = vec![0u8; 19];
timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed))
.await
.unwrap()
.unwrap();
assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n");
// One direction is still deliberately stalled; abort instead of joining.
relay.abort();
let _ = relay.await;
}
// After the client half-closes, the mask backend's full response must
// still be relayed back to the client side.
#[tokio::test]
async fn relay_to_mask_preserves_backend_response_after_client_half_close() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let request = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
// Minimal mask backend: read the exact request, reply, then close.
let backend_task = tokio::spawn({
let request = request.clone();
let response = response.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut observed_req = vec![0u8; request.len()];
stream.read_exact(&mut observed_req).await.unwrap();
assert_eq!(observed_req, request);
stream.write_all(&response).await.unwrap();
stream.shutdown().await.unwrap();
}
});
// Point the fallback config at the local TCP backend above.
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (mut client_write, client_read) = duplex(1024);
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
let beobachten = BeobachtenStore::new();
let fallback_task = tokio::spawn(async move {
handle_bad_client(
client_read,
client_visible_writer,
&request,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
// Half-close the client side; only the backend response should remain.
client_write.shutdown().await.unwrap();
let mut observed_resp = vec![0u8; response.len()];
timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed_resp))
.await
.unwrap()
.unwrap();
assert_eq!(observed_resp, response);
// Both tasks must finish on their own once the streams are closed.
timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap();
timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap();
}

View File

@ -7,7 +7,7 @@ use std::time::{Duration, Instant};
#[cfg(test)]
use std::sync::Mutex;
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch};
@ -107,7 +107,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
let mut stale_keys = Vec::new();
let mut eviction_candidate = None;
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
if eviction_candidate.is_none() {
eviction_candidate = Some(*entry.key());
}
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
stale_keys.push(*entry.key());
}
@ -116,6 +120,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
dedup.remove(&stale_key);
}
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
let Some(evict_key) = eviction_candidate else {
return false;
};
dedup.remove(&evict_key);
dedup.insert(key, now);
return false;
}
}
@ -784,25 +793,22 @@ where
len
};
let chunk_cap = buffer_pool.buffer_size().max(1024);
let mut payload = BytesMut::with_capacity(len.min(chunk_cap));
let mut remaining = len;
while remaining > 0 {
let chunk_len = remaining.min(chunk_cap);
let mut chunk = buffer_pool.get();
chunk.resize(chunk_len, 0);
read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout)
.await?;
payload.extend_from_slice(&chunk[..chunk_len]);
remaining -= chunk_len;
let mut payload = buffer_pool.get();
payload.clear();
let current_cap = payload.capacity();
if current_cap < len {
payload.reserve(len - current_cap);
}
payload.resize(len, 0);
read_exact_with_timeout(client_reader, &mut payload[..len], frame_read_timeout).await?;
// Secure Intermediate: strip validated trailing padding bytes.
if proto_tag == ProtoTag::Secure {
payload.truncate(secure_payload_len);
}
*frame_counter += 1;
return Ok(Some((payload.freeze(), quickack)));
let payload = payload.take().freeze();
return Ok(Some((payload, quickack)));
}
}

View File

@ -101,7 +101,7 @@ fn desync_dedup_cache_is_bounded() {
assert!(
!should_emit_full_desync(u64::MAX, false, now),
"new key above cap must be suppressed to bound memory"
"new key above cap must remain suppressed to avoid log amplification"
);
assert!(
@ -110,6 +110,26 @@ fn desync_dedup_cache_is_bounded() {
);
}
// Once the dedup cache is at capacity with fresh entries, continued churn
// of brand-new keys must stay suppressed instead of amplifying log output.
#[test]
fn desync_dedup_full_cache_churn_stays_suppressed() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let now = Instant::now();
// Fill the cache to its cap; every first-seen key may emit.
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
assert!(should_emit_full_desync(key, false, now));
}
// Fresh keys beyond the cap must all be suppressed while entries are fresh.
for offset in 0..2048u64 {
assert!(
!should_emit_full_desync(u64::MAX - offset, false, now),
"fresh full-cache churn must remain suppressed under pressure"
);
}
}
fn make_forensics_state() -> RelayForensicsState {
RelayForensicsState {
trace_id: 1,
@ -199,3 +219,213 @@ async fn read_client_payload_times_out_on_payload_stall() {
"stalled payload body read must time out"
);
}
// A frame several buffer-pool chunks long must be read back byte-exact.
#[tokio::test]
async fn read_client_payload_large_intermediate_frame_is_exact() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("middle relay test lock must be available");
let (reader, mut writer) = duplex(262_144);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let forensics = make_forensics_state();
let mut frame_counter = 0;
// Force the payload to span multiple pool buffers (at least three chunks).
let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537);
// Intermediate framing: u32 LE length prefix followed by the payload bytes.
let mut plaintext = Vec::with_capacity(4 + payload_len);
plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31)));
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let read = read_client_payload(
&mut crypto_reader,
ProtoTag::Intermediate,
payload_len + 16,
TokioDuration::from_secs(1),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("payload read must succeed")
.expect("frame must be present");
let (frame, quickack) = read;
assert!(!quickack, "quickack flag must be unset");
assert_eq!(frame.len(), payload_len, "payload size must match wire length");
// Verify every byte against the deterministic generator pattern above.
for (idx, byte) in frame.iter().enumerate() {
assert_eq!(*byte, (idx as u8).wrapping_mul(31));
}
assert_eq!(frame_counter, 1, "exactly one frame must be counted");
}
// Secure frames carry trailing padding on the wire; the reader must return
// only the payload bytes ahead of the padding tail.
#[tokio::test]
async fn read_client_payload_secure_strips_tail_padding_bytes() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("middle relay test lock must be available");
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let forensics = make_forensics_state();
let mut frame_counter = 0;
let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd];
let tail = [0xeeu8, 0xff, 0x99];
// Declared wire length covers payload plus the padding tail.
let wire_len = payload.len() + tail.len();
let mut plaintext = Vec::with_capacity(4 + wire_len);
plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes());
plaintext.extend_from_slice(&payload);
plaintext.extend_from_slice(&tail);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let read = read_client_payload(
&mut crypto_reader,
ProtoTag::Secure,
1024,
TokioDuration::from_secs(1),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("secure payload read must succeed")
.expect("secure frame must be present");
let (frame, quickack) = read;
assert!(!quickack, "quickack flag must be unset");
// Only the payload must survive; the tail bytes are stripped.
assert_eq!(frame.as_ref(), &payload);
assert_eq!(frame_counter, 1, "one secure frame must be counted");
}
// A secure frame whose declared wire length is below the 4-byte minimum
// must be rejected by the frame-too-small guard, not underflow a subtraction.
#[tokio::test]
async fn read_client_payload_secure_rejects_wire_len_below_4() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("middle relay test lock must be available");
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let forensics = make_forensics_state();
let mut frame_counter = 0;
// Declared length 3 with exactly three trailing bytes on the wire.
let mut plaintext = Vec::with_capacity(7);
plaintext.extend_from_slice(&3u32.to_le_bytes());
plaintext.extend_from_slice(&[1u8, 2, 3]);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let result = read_client_payload(
&mut crypto_reader,
ProtoTag::Secure,
1024,
TokioDuration::from_secs(1),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")),
"secure wire length below 4 must be fail-closed by the frame-too-small guard"
);
}
// A zero-length Intermediate frame must be skipped: the reader delivers
// the following real frame and counts only that one.
#[tokio::test]
async fn read_client_payload_intermediate_skips_zero_len_frame() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("middle relay test lock must be available");
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let forensics = make_forensics_state();
let mut frame_counter = 0;
let payload = [7u8, 6, 5, 4, 3, 2, 1, 0];
// Wire stream: a len=0 frame immediately followed by a real frame.
let mut plaintext = Vec::with_capacity(4 + 4 + payload.len());
plaintext.extend_from_slice(&0u32.to_le_bytes());
plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes());
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let read = read_client_payload(
&mut crypto_reader,
ProtoTag::Intermediate,
1024,
TokioDuration::from_secs(1),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("intermediate payload read must succeed")
.expect("frame must be present");
let (frame, quickack) = read;
assert!(!quickack, "quickack flag must be unset");
assert_eq!(frame.as_ref(), &payload);
assert_eq!(frame_counter, 1, "zero-length frame must be skipped");
}
// Abridged extended-length header: marker byte with the 0x80 quickack bit
// set, then the length in 4-byte words as three little-endian bytes. The
// reader must deliver the full payload and propagate the quickack flag.
#[tokio::test]
async fn read_client_payload_abridged_extended_len_sets_quickack() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("middle relay test lock must be available");
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let forensics = make_forensics_state();
let mut frame_counter = 0;
// 130 words forces the extended (3-byte) length encoding.
let payload_len = 4 * 130;
let len_words = (payload_len / 4) as u32;
let mut plaintext = Vec::with_capacity(1 + 3 + payload_len);
plaintext.push(0xff | 0x80);
let lw = len_words.to_le_bytes();
plaintext.extend_from_slice(&lw[..3]);
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17)));
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let read = read_client_payload(
&mut crypto_reader,
ProtoTag::Abridged,
payload_len + 16,
TokioDuration::from_secs(1),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("abridged payload read must succeed")
.expect("frame must be present");
let (frame, quickack) = read;
assert!(quickack, "quickack bit must be propagated from abridged header");
assert_eq!(frame.len(), payload_len);
assert_eq!(frame_counter, 1, "one abridged frame must be counted");
}